In [3]:
import json
import random
import sys
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Subset
from tqdm.notebook import tqdm

# Make the project's src/ package importable before the local imports below.
sys.path.insert(0, '../src')

from plant_care_ai.data.dataset import PlantNetDataset
from plant_care_ai.data.preprocessing import get_training_pipeline, get_inference_pipeline
from plant_care_ai.models.resnet18 import Resnet18
from plant_care_ai.models.effecientnetv2 import create_efficientnetv2

In [4]:
# Environment check: PyTorch build, CUDA availability and, if present, the GPU name.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.1+cu128
CUDA available: False


In [11]:
CONFIG = {
    # --- model ---
    "model": "efficientnetv2",  # "resnet18" or "efficientnetv2"
    "variant": "b0",  # EfficientNetV2 variant; not passed to Resnet18

    # --- class / sample subset ---
    "subset_classes": 50,  # keep the first N class ids in sorted order
    # NOTE(review): not read by any subsetting code shown in this notebook; the
    # per-split caps below are what actually limit samples. Kept so saved
    # config.json files keep the same keys.
    "subset_samples_per_class": 50,
    "train_samples_per_class": 50,  # max training images per class
    "val_samples_per_class": 15,  # max validation images per class

    # --- optimisation ---
    "batch_size": 32,
    "epochs": 15,
    "lr": 0.001,
    "weight_decay": 0.01,
    "label_smoothing": 0.1,
    "seed": 42,  # seeds the Python, NumPy and PyTorch RNGs (see below)
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # --- data ---
    "data_dir": "../data/plantnet_300K",
    "img_size": 224,
    "augm_strength": 0.7,  # passed to get_training_pipeline

    # --- outputs ---
    "checkpoint_dir": "../checkpoints/notebook_exp",
    "experiment_name": f"exp_{int(time.time())}",  # unique per run (Unix timestamp)
}

# Seed every RNG the pipeline may draw from, so DataLoader shuffling, weight init
# and (assuming they use one of these RNGs) augmentations repeat across
# Restart & Run All. Cells that build the loaders/model must run after this one.
random.seed(CONFIG["seed"])
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])

In [6]:
# Preprocessing: the training pipeline takes an augmentation strength,
# validation uses the inference pipeline.
train_tfm = get_training_pipeline(CONFIG["img_size"], CONFIG["augm_strength"])
val_tfm = get_inference_pipeline(CONFIG["img_size"])

# Full PlantNet splits; the class/sample subset is carved out in a later cell.
train_dataset = PlantNetDataset(CONFIG["data_dir"], "train", train_tfm)
val_dataset = PlantNetDataset(CONFIG["data_dir"], "val", val_tfm)

# Deterministic class subset: the first N class ids in sorted order.
# NOTE(review): training assumes these ids map to label indices 0..N-1 inside
# PlantNetDataset -- confirm the dataset's class-to-index order is also sorted.
all_classes = list(train_dataset.classes)
all_classes.sort()
selected_classes = all_classes[0:CONFIG["subset_classes"]]

print(selected_classes)


Loaded 243916 samples, 1081 classes
Loaded 31118 samples, 1081 classes
['1355868', '1355920', '1355932', '1355936', '1355937', '1355955', '1355959', '1355961', '1355978', '1355990', '1356003', '1356022', '1356037', '1356055', '1356075', '1356076', '1356111', '1356126', '1356138', '1356257', '1356278', '1356279', '1356309', '1356379', '1356380', '1356382', '1356420', '1356421', '1356428', '1356469', '1356692', '1356781', '1356816', '1356847', '1356901', '1357330', '1357331', '1357367', '1357379', '1357506', '1357635', '1357652', '1357677', '1357681', '1357682', '1357705', '1358094', '1358095', '1358096', '1358097']


In [14]:
def _subset_indices(dataset, classes, max_per_class):
    """Select at most ``max_per_class`` sample indices for each class in ``classes``.

    ``dataset.paths`` is walked in its stored order and each entry is unpacked
    as ``(_, species_id)``, so every class keeps its *first* ``max_per_class``
    samples (no shuffling here). Classes with fewer samples keep all of them.

    Returns:
        ``(indices, counts)`` -- the selected dataset indices and a dict mapping
        each species id to how many of its samples were kept.
    """
    wanted = set(classes)  # O(1) membership instead of scanning a list per sample
    indices = []
    counts = {}
    for idx, (_, species_id) in enumerate(dataset.paths):
        if species_id not in wanted:
            continue
        seen = counts.get(species_id, 0)
        if seen < max_per_class:
            indices.append(idx)
            counts[species_id] = seen + 1
    return indices, counts


train_indices, train_class_counts = _subset_indices(
    train_dataset, selected_classes, CONFIG["train_samples_per_class"]
)
val_indices, val_class_counts = _subset_indices(
    val_dataset, selected_classes, CONFIG["val_samples_per_class"]
)

train_subset = Subset(train_dataset, train_indices)
val_subset = Subset(val_dataset, val_indices)

print(f"Train samples: {len(train_subset)}")
print(f"Val samples: {len(val_subset)}")

# Pinned host memory only speeds up host->GPU copies; on a CPU-only run PyTorch
# just warns about it, so enable it only when training on CUDA.
use_pin_memory = torch.device(CONFIG["device"]).type == "cuda"

train_loader = DataLoader(
    train_subset,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    num_workers=2,
    pin_memory=use_pin_memory,
)

val_loader = DataLoader(
    val_subset,
    batch_size=CONFIG["batch_size"],
    shuffle=False,
    num_workers=2,
    pin_memory=use_pin_memory,
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

Train samples: 1915
Val samples: 520
Train batches: 60
Val batches: 17


In [16]:
num_classes = len(selected_classes)

# Build the requested architecture with a classification head sized for the subset.
if CONFIG["model"] == "resnet18":
    model = Resnet18(num_classes=num_classes)
    model_name = "ResNet18"
elif CONFIG["model"] == "efficientnetv2":
    model = create_efficientnetv2(
        variant=CONFIG["variant"],
        num_classes=num_classes,
    )
    model_name = f"EfficientNetV2-{CONFIG['variant'].upper()}"
else:
    raise ValueError(f"Unknown model: {CONFIG['model']}")

model = model.to(CONFIG["device"])

# Parameter statistics. The size estimate uses each tensor's real element size
# rather than assuming 4-byte float32, so it stays right for half-precision weights.
# NOTE(review): the recorded output reports ~69.8M parameters for "B0", which looks
# large for a B0-sized network -- verify create_efficientnetv2 honours `variant`.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

print(f"Classes: {num_classes}")
print(f"\nModel: {model_name}")
print(f"\tParameters: {total_params:,}")
print(f"\tTrainable parameters: {trainable_params:,}")
print(f"\tModel size: {param_bytes / (1024**2):.2f} MB")
print(f"\tDevice: {CONFIG['device']}")

50

Model: EfficientNetV2-B0
	Parameters: 69,770,010
	Trainable parameters: 69,770,010
	Model size: 266.15 MB
	Device: cpu


In [17]:
# Cross-entropy with label smoothing for the classification head.
criterion = nn.CrossEntropyLoss(label_smoothing=CONFIG["label_smoothing"])

# AdamW over every model parameter.
optimizer = AdamW(model.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["weight_decay"])

# Cosine decay over the whole run (T_max = number of epochs; the training loop
# calls scheduler.step() once per epoch), flooring at eta_min.
scheduler = CosineAnnealingLR(optimizer, T_max=CONFIG["epochs"], eta_min=1e-6)

print(
    "\nOptimizer: AdamW",
    f"\tLearning rate: {CONFIG['lr']}",
    f"\tWeight decay: {CONFIG['weight_decay']}",
    "\nScheduler: CosineAnnealingLR",
    f"\tT_max: {CONFIG['epochs']}",
    "\nLoss: CrossEntropyLoss",
    f"\tLabel smoothing: {CONFIG['label_smoothing']}",
    sep="\n",
)


Optimizer: AdamW
	Learning rate: 0.001
	Weight decay: 0.01

Scheduler: CosineAnnealingLR
	T_max: 15

Loss: CrossEntropyLoss
	Label smoothing: 0.1


In [18]:
# One output directory per run, named after the experiment id.
checkpoint_dir = Path(CONFIG["checkpoint_dir"]).joinpath(CONFIG["experiment_name"])
checkpoint_dir.mkdir(parents=True, exist_ok=True)
print(f"\nCheckpoint dir: {checkpoint_dir}")

# Store the exact run configuration next to the checkpoints.
(checkpoint_dir / "config.json").write_text(json.dumps(CONFIG, indent=2))

# Per-epoch metrics, appended to by the training loop.
history = {
    metric: []
    for metric in ("train_loss", "train_acc", "val_loss", "val_acc", "val_top5", "lr")
}

# Best validation top-1 accuracy (percent) seen so far.
best_acc = 0.0


Checkpoint dir: ../checkpoints/notebook_exp/exp_1767623821


In [19]:
def train_epoch(model, loader, criterion, optimizer, device, max_grad_norm=1.0):
    """Run one training epoch over ``loader``.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            average over the batch (``reduction="mean"``, the
            ``nn.CrossEntropyLoss`` default) -- TODO confirm for custom losses.
        optimizer: optimizer stepped once per batch.
        device: device each batch is moved to.
        max_grad_norm: total gradient-norm clipping threshold applied before
            every optimizer step.

    Returns:
        ``(avg_loss, acc)``: the per-sample mean training loss and the top-1
        accuracy in percent.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc="Train", leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        # Bound the update size before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        # criterion returns a batch mean; re-weight by batch size so the epoch
        # average is a true per-sample mean even when the last batch is short.
        total_loss += loss.item() * batch_size
        _, pred = outputs.max(1)
        total += batch_size
        correct += pred.eq(labels).sum().item()

        pbar.set_postfix({
            "loss": f"{loss.item():.4f}",
            "acc": f"{100.*correct/total:.1f}%"
        })

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return total_loss / total, 100.0 * correct / total


In [20]:
def validate(model, loader, criterion, device, topk=5):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            average over the batch (``reduction="mean"``, the
            ``nn.CrossEntropyLoss`` default) -- TODO confirm for custom losses.
        device: device each batch is moved to.
        topk: k for the top-k accuracy; clamped to the number of output classes.

    Returns:
        ``(avg_loss, top1_acc, topk_acc)``: per-sample mean loss, top-1 accuracy
        and top-k accuracy, both in percent.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct_top1 = 0
    correct_topk = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc="Val", leave=False)
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight the batch-mean loss by batch size so a short final batch
            # does not skew the epoch average.
            total_loss += loss.item() * batch_size
            total += batch_size

            # top-1
            _, pred = outputs.max(1)
            correct_top1 += pred.eq(labels).sum().item()

            # top-k: topk indices are distinct per row, so each row matches at most once.
            k = min(topk, outputs.size(1))
            _, topk_idx = outputs.topk(k, dim=1, largest=True, sorted=True)
            correct_topk += topk_idx.eq(labels.unsqueeze(1)).sum().item()

            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "acc": f"{100.*correct_top1/total:.1f}%"
            })

    if total == 0:
        raise ValueError("validate: loader yielded no samples")

    avg_loss = total_loss / total
    top1_acc = 100.0 * correct_top1 / total
    topk_acc = 100.0 * correct_topk / total

    return avg_loss, top1_acc, topk_acc

In [21]:
print(f"Model: {model_name}")
print(f"Epochs: {CONFIG['epochs']}")
print(f"Dataset: {len(train_subset)} train, {len(val_subset)} val")

start_time = time.time()
completed_epochs = 0  # epochs fully finished; recorded in last.pth

try:
    for epoch in range(1, CONFIG["epochs"] + 1):
        print(f"Epoch {epoch}/{CONFIG['epochs']}")

        # Read the LR *before* scheduler.step(): reading it afterwards would log
        # the next epoch's LR instead of the one this epoch trained with.
        current_lr = optimizer.param_groups[0]["lr"]

        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, CONFIG["device"]
        )

        val_loss, val_acc, val_top5 = validate(
            model, val_loader, criterion, CONFIG["device"]
        )

        scheduler.step()

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["val_top5"].append(val_top5)
        history["lr"].append(current_lr)

        print(f"Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"Val:   Loss={val_loss:.4f}, Top-1={val_acc:.2f}%, Top-5={val_top5:.2f}%")
        print(f"LR: {current_lr:.6f}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                # Scheduler state is needed to resume the cosine schedule correctly.
                "scheduler_state_dict": scheduler.state_dict(),
                "val_acc": val_acc,
                "val_top5": val_top5,
                "config": CONFIG,
            }, checkpoint_dir / "best.pth")
            print(f"Model saved: (acc: {val_acc:.2f}%)")

        completed_epochs = epoch
finally:
    # Runs on normal completion and on interruption (e.g. KeyboardInterrupt), so the
    # latest weights and the partial history are not lost. After an interrupt the
    # weights may include updates from a partially finished epoch beyond
    # `completed_epochs`.
    torch.save({
        "epoch": completed_epochs,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "history": history,
        "config": CONFIG,
    }, checkpoint_dir / "last.pth")

    with open(checkpoint_dir / "history.json", "w") as f:
        json.dump(history, f, indent=2)

    elapsed = time.time() - start_time
    print(
        f"\nCompleted {completed_epochs}/{CONFIG['epochs']} epochs in "
        f"{elapsed / 60:.1f} min (best val top-1: {best_acc:.2f}%)"
    )

Model: EfficientNetV2-B0
Epochs: 15
Dataset: 1915 train, 520 val
Epoch 1/15


Train:   0%|          | 0/60 [00:00<?, ?it/s]



KeyboardInterrupt: 

In [23]:
# Inference sanity check: predict a few validation images and compare the
# predicted class indices with the ground-truth labels.
N_SHOW = 8  # number of samples from the first validation batch to display

model.eval()

# First batch of the validation loader (built with shuffle=False).
test_images, test_labels = next(iter(val_loader))
test_images = test_images[:N_SHOW].to(CONFIG["device"])
test_labels = test_labels[:N_SHOW]

with torch.no_grad():
    outputs = model(test_images)
    probs = torch.softmax(outputs, dim=1)
    confidences, predictions = torch.max(probs, dim=1)

# Convert to Python values in one transfer each instead of one .item() per element.
true_labels = test_labels.tolist()
pred_labels = predictions.tolist()
conf_pct = [c * 100 for c in confidences.tolist()]

print("\nSample predictions:")
print(f"{'#':<3} {'True':<8} {'Pred':<8} {'Conf':<8} {'Status'}")
print("-" * 40)

correct = 0
for i, (true_label, pred_label, confidence) in enumerate(
    zip(true_labels, pred_labels, conf_pct), start=1
):
    is_correct = true_label == pred_label
    correct += int(is_correct)
    status = "Correct" if is_correct else "Wrong"
    print(f"{i:<3} {true_label:<8} {pred_label:<8} {confidence:>6.1f}% {status}")

print(f"\nBatch accuracy: {100*correct/len(true_labels):.1f}%")


Sample predictions:
#   True     Pred     Conf     Status
----------------------------------------
1   5        0           2.0% Wrong
2   5        0           2.0% Wrong
3   9        0           2.0% Wrong
4   9        0           2.0% Wrong
5   9        0           2.0% Wrong
6   9        0           2.0% Wrong
7   9        0           2.0% Wrong
8   9        0           2.0% Wrong

Batch accuracy: 0.0%
