In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm
import os

from config import TRAIN_JSON_DIR, TRAIN_IMG_DIR, BATCH_SIZE, LR, EPOCHS, DEVICE, CHECKPOINT_DIR
from dataset import PlantDataset, train_transform, val_transform
from model.efficientnet import MultiTaskEfficientNet

Using device: cuda


In [2]:
# ── 체크포인트 디렉토리 ─────────────────────────────────────
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# ── 데이터로더 생성 ─────────────────────────────────────────
dataset = PlantDataset(TRAIN_JSON_DIR, TRAIN_IMG_DIR, transform=train_transform)
n_val   = int(len(dataset) * 0.2)
n_trn   = len(dataset) - n_val
train_ds, val_ds = random_split(dataset, [n_trn, n_val])

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [5]:
# ── 모델·옵티마이저·손실 함수·스케쥴러 설정 ─────────────────────────────────
model     = MultiTaskEfficientNet().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

Loaded pretrained weights for efficientnet-b0


In [6]:
# ── 체크포인트용 변수 ───────────────────────────────────────
best_score = 0.0  # 최고 검증 정확도 평균

# ── 학습 & 검증 루프 ────────────────────────────────────────
for epoch in range(1, EPOCHS+1):
    # ----- train -----
    model.train()
    total_loss = 0.0
    correct_d = correct_g = total_samples = 0

    for imgs, d_labels, g_labels in tqdm(train_loader, desc=f"[Train Epoch {epoch}]"):
        imgs, d_labels, g_labels = imgs.to(DEVICE), d_labels.to(DEVICE), g_labels.to(DEVICE)
        d_logits, g_logits = model(imgs)

        loss = criterion(d_logits, d_labels) + criterion(g_logits, g_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss    += loss.item()
        _, d_pred     = d_logits.max(1)
        _, g_pred     = g_logits.max(1)
        correct_d    += (d_pred == d_labels).sum().item()
        correct_g    += (g_pred == g_labels).sum().item()
        total_samples += imgs.size(0)

    scheduler.step()

    train_loss   = total_loss / len(train_loader)
    train_acc_d  = correct_d / total_samples
    train_acc_g  = correct_g / total_samples
    print(f"Epoch {epoch:02d} ▶ [Train] Loss: {train_loss:.4f}, 질병Acc: {train_acc_d:.4f}, 생장Acc: {train_acc_g:.4f}")

    # ----- validation -----
    model.eval()
    correct_d = correct_g = total_samples = 0
    with torch.no_grad():
        for imgs, d_labels, g_labels in tqdm(val_loader):
            imgs, d_labels, g_labels = imgs.to(DEVICE), d_labels.to(DEVICE), g_labels.to(DEVICE)
            d_logits, g_logits = model(imgs)
            _, d_pred = d_logits.max(1)
            _, g_pred = g_logits.max(1)
            correct_d    += (d_pred == d_labels).sum().item()
            correct_g    += (g_pred == g_labels).sum().item()
            total_samples += imgs.size(0)

    val_acc_d = correct_d / total_samples
    val_acc_g = correct_g / total_samples
    print(f"[Val] 질병 Acc: {val_acc_d:.4f}, 생장 Acc: {val_acc_g:.4f}")

    # ----- checkpoint 저장 -----
    avg_val_acc = (val_acc_d + val_acc_g) / 2
    if avg_val_acc > best_score:
        best_score = avg_val_acc
        ckpt_path = os.path.join(CHECKPOINT_DIR, f'best_epoch{epoch:02d}_acc{avg_val_acc:.4f}.pth')
        torch.save(model.state_dict(), ckpt_path)
        print(f"Saved new best model ▶ {ckpt_path}\n")
    else:
        print()  # 줄바꿈

[Train Epoch 1]: 100%|██████████| 349/349 [01:49<00:00,  3.18it/s]


Epoch 01 ▶ [Train] Loss: 0.5890, 질병Acc: 0.9357, 생장Acc: 0.9553
[Val] 질병 Acc: 0.9810, 생장 Acc: 0.9896
Saved new best model ▶ /home/kyun/25s/Capstone_EE/plantmate/checkpoints/best_epoch01_acc0.9853.pth



[Train Epoch 2]: 100%|██████████| 349/349 [01:33<00:00,  3.72it/s]


Epoch 02 ▶ [Train] Loss: 0.0955, 질병Acc: 0.9814, 생장Acc: 0.9909
[Val] 질병 Acc: 0.9871, 생장 Acc: 0.9917
Saved new best model ▶ /home/kyun/25s/Capstone_EE/plantmate/checkpoints/best_epoch02_acc0.9894.pth



[Train Epoch 3]: 100%|██████████| 349/349 [01:38<00:00,  3.55it/s]


Epoch 03 ▶ [Train] Loss: 0.0614, 질병Acc: 0.9881, 생장Acc: 0.9940
[Val] 질병 Acc: 0.9907, 생장 Acc: 0.9885
Saved new best model ▶ /home/kyun/25s/Capstone_EE/plantmate/checkpoints/best_epoch03_acc0.9896.pth



[Train Epoch 4]: 100%|██████████| 349/349 [01:45<00:00,  3.31it/s]


Epoch 04 ▶ [Train] Loss: 0.0392, 질병Acc: 0.9929, 생장Acc: 0.9962
[Val] 질병 Acc: 0.9882, 생장 Acc: 0.9896



[Train Epoch 5]: 100%|██████████| 349/349 [01:46<00:00,  3.27it/s]


Epoch 05 ▶ [Train] Loss: 0.0259, 질병Acc: 0.9952, 생장Acc: 0.9978
[Val] 질병 Acc: 0.9907, 생장 Acc: 0.9943
Saved new best model ▶ /home/kyun/25s/Capstone_EE/plantmate/checkpoints/best_epoch05_acc0.9925.pth



[Train Epoch 6]: 100%|██████████| 349/349 [01:51<00:00,  3.14it/s]


Epoch 06 ▶ [Train] Loss: 0.0190, 질병Acc: 0.9961, 생장Acc: 0.9983
[Val] 질병 Acc: 0.9896, 생장 Acc: 0.9921



[Train Epoch 7]: 100%|██████████| 349/349 [01:49<00:00,  3.20it/s]


Epoch 07 ▶ [Train] Loss: 0.0133, 질병Acc: 0.9973, 생장Acc: 0.9993
[Val] 질병 Acc: 0.9910, 생장 Acc: 0.9950
Saved new best model ▶ /home/kyun/25s/Capstone_EE/plantmate/checkpoints/best_epoch07_acc0.9930.pth



[Train Epoch 8]: 100%|██████████| 349/349 [01:51<00:00,  3.13it/s]


Epoch 08 ▶ [Train] Loss: 0.0118, 질병Acc: 0.9980, 생장Acc: 0.9987
[Val] 질병 Acc: 0.9903, 생장 Acc: 0.9950



[Train Epoch 9]: 100%|██████████| 349/349 [01:49<00:00,  3.19it/s]


Epoch 09 ▶ [Train] Loss: 0.0083, 질병Acc: 0.9987, 생장Acc: 0.9994
[Val] 질병 Acc: 0.9910, 생장 Acc: 0.9953
Saved new best model ▶ /home/kyun/25s/Capstone_EE/plantmate/checkpoints/best_epoch09_acc0.9932.pth



[Train Epoch 10]: 100%|██████████| 349/349 [01:52<00:00,  3.09it/s]


Epoch 10 ▶ [Train] Loss: 0.0091, 질병Acc: 0.9986, 생장Acc: 0.9995
[Val] 질병 Acc: 0.9910, 생장 Acc: 0.9953



[Train Epoch 11]: 100%|██████████| 349/349 [01:48<00:00,  3.22it/s]


Epoch 11 ▶ [Train] Loss: 0.0076, 질병Acc: 0.9987, 생장Acc: 0.9995
[Val] 질병 Acc: 0.9914, 생장 Acc: 0.9950



[Train Epoch 12]: 100%|██████████| 349/349 [01:47<00:00,  3.24it/s]


Epoch 12 ▶ [Train] Loss: 0.0069, 질병Acc: 0.9990, 생장Acc: 0.9994
[Val] 질병 Acc: 0.9907, 생장 Acc: 0.9946



[Train Epoch 13]: 100%|██████████| 349/349 [01:50<00:00,  3.16it/s]


Epoch 13 ▶ [Train] Loss: 0.0058, 질병Acc: 0.9992, 생장Acc: 0.9996
[Val] 질병 Acc: 0.9892, 생장 Acc: 0.9964



[Train Epoch 14]: 100%|██████████| 349/349 [01:45<00:00,  3.31it/s]


Epoch 14 ▶ [Train] Loss: 0.0053, 질병Acc: 0.9994, 생장Acc: 0.9995
[Val] 질병 Acc: 0.9910, 생장 Acc: 0.9946



[Train Epoch 15]: 100%|██████████| 349/349 [01:46<00:00,  3.27it/s]


Epoch 15 ▶ [Train] Loss: 0.0037, 질병Acc: 0.9994, 생장Acc: 0.9998
[Val] 질병 Acc: 0.9910, 생장 Acc: 0.9946



[Train Epoch 16]: 100%|██████████| 349/349 [01:47<00:00,  3.23it/s]


Epoch 16 ▶ [Train] Loss: 0.0044, 질병Acc: 0.9995, 생장Acc: 0.9996
[Val] 질병 Acc: 0.9907, 생장 Acc: 0.9946



[Train Epoch 17]: 100%|██████████| 349/349 [01:45<00:00,  3.32it/s]


Epoch 17 ▶ [Train] Loss: 0.0032, 질병Acc: 0.9996, 생장Acc: 0.9996
[Val] 질병 Acc: 0.9896, 생장 Acc: 0.9953



[Train Epoch 18]: 100%|██████████| 349/349 [01:45<00:00,  3.31it/s]


Epoch 18 ▶ [Train] Loss: 0.0027, 질병Acc: 0.9996, 생장Acc: 0.9998


KeyboardInterrupt: 