In [None]:
# Setup: imports, project-root import path, and RNG seeding.
import os
import random
import shutil
import sys
from pathlib import Path

import numpy as np
import torch
import torchvision
from PIL import Image

print('cwd:', os.getcwd())

# Make the project root importable so the `src.*` imports in later cells resolve.
# The root is taken as the parent of this file's directory when `__file__` is
# defined, otherwise as the parent of the current working directory.
_base_dir = os.path.dirname(__file__) if '__file__' in globals() else '.'
project_root = os.path.abspath(os.path.join(_base_dir, '..'))
# Keep the root at the front of sys.path, but do not stack a duplicate entry
# every time this cell is re-run.
if sys.path[:1] != [project_root]:
    sys.path.insert(0, project_root)

# Seed every RNG imported here so the random image colours generated below are
# repeatable across "Restart & Run All".
# NOTE(review): the src.* helpers presumably draw from these same global RNGs
# for splitting/sampling/training - confirm if full determinism is required.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Report the versions this smoke test ran against.
print('python', sys.version.split()[0])
print('torch', torch.__version__)
print('torchvision', torchvision.__version__)
print('PIL', Image.__version__)

In [None]:
# Build a tiny synthetic image tree on disk (every image is a solid-colour RGB square)
root = Path('data_synthetic')
if root.exists():
    # wipe any previous run so the file counts below are exact
    shutil.rmtree(root)
for top_level in ('plants', 'diseases'):
    (root / top_level).mkdir(parents=True, exist_ok=True)

# Species name -> number of images to write; the counts are deliberately imbalanced
species = {'Cashew': 10, 'Cassava': 4, 'Maize': 2, 'Tomato': 8}
for species_name, n_images in species.items():
    species_dir = root / 'plants' / species_name
    species_dir.mkdir(parents=True, exist_ok=True)
    file_stem = species_name.lower()
    for idx in range(n_images):
        # three independent draws -> one random (R, G, B) fill colour per image
        rgb = tuple(random.randint(0, 255) for _channel in range(3))
        Image.new('RGB', (300, 300), rgb).save(species_dir / f'{file_stem}_{idx}.jpg')

# Disease images for Cashew only: (class folder, file prefix, image count, fill colour).
# LeafSpot gets few samples and Healthy gets many, so the classes are imbalanced.
cashew_dir = root / 'diseases' / 'Cashew'
disease_specs = [
    ('LeafSpot', 'leaf', 3, (255, 0, 0)),
    ('Healthy', 'healthy', 12, (0, 255, 0)),
]
for class_name, _prefix, _n, _rgb in disease_specs:
    (cashew_dir / class_name).mkdir(parents=True, exist_ok=True)
for class_name, prefix, n_images, rgb in disease_specs:
    for idx in range(n_images):
        Image.new('RGB', (400, 400), rgb).save(cashew_dir / class_name / f'{prefix}_{idx}.jpg')

print('Created synthetic dataset at', root)

In [None]:
# Prepare train/val/test splits with the src.data helpers and print per-split label counts
from src.data import prepare_plant_dataset, prepare_disease_dataset, list_images, make_weighted_sampler

SPLIT_NAMES = ('train', 'val', 'test')


def _label_counts(samples):
    """Return a {label: count} dict for an iterable of (item, label) pairs.

    Labels are tallied with np.unique, so the keys come out in sorted order.
    """
    labels = [label for _item, label in samples]
    values, freqs = np.unique(labels, return_counts=True)
    return dict(zip(values.tolist(), freqs.tolist()))


# Plant-species dataset
plants = prepare_plant_dataset(str(root / 'plants'))
print('Species meta:', plants['meta'])
for split in SPLIT_NAMES:
    print(f'{split} counts ->', _label_counts(plants[split]))

# Disease dataset for the Cashew species
cashew = prepare_disease_dataset('Cashew', data_dir=str(root / 'diseases'))
print('Cashew disease classes:', cashew['meta']['classes'])
for split in SPLIT_NAMES:
    print(f'Cashew {split} counts ->', _label_counts(cashew[split]))

In [None]:
# Demonstrate the label distribution produced by the weighted sampler
from torch.utils.data import DataLoader

from src.data import SimpleImageDataset, get_transforms

# Number of single-sample batches drawn to estimate the sampled distribution.
# This is the single source of truth for both the loop and the printed message.
N_DRAWS = 200

train_samples = plants['train']
# make_weighted_sampler returns the sampler plus per-class weights; the weights
# are reused by the training cells as `class_weights=`.
sampler, class_weights = make_weighted_sampler(train_samples)
print('class_weights (for loss):', class_weights)

# Every class gets the default 'plant' transform unless overridden below.
tf_default = get_transforms('plant')

# Minority rule: a class is a minority when its train count is strictly below
# the median of the per-class train counts.
labels = [lab for _, lab in train_samples]
uniques, counts = np.unique(labels, return_counts=True)
median = np.median(counts)
transform_map = {}
for cls, cnt in zip(uniques, counts):
    if cnt < median:
        # minority plant classes get the 'disease' transform (heavier augmentation)
        transform_map[int(cls)] = get_transforms('disease')
ds = SimpleImageDataset(train_samples, transform=tf_default, transform_map=transform_map)
loader = DataLoader(ds, batch_size=1, sampler=sampler)

# Draw N_DRAWS samples and tally their labels. One pass over the loader may
# yield fewer than N_DRAWS items, so restart the iterator whenever it runs out.
counts_sampled = {}
it = iter(loader)
for _draw in range(N_DRAWS):
    try:
        _xb, yb, _path = next(it)
    except StopIteration:
        it = iter(loader)
        _xb, yb, _path = next(it)
    lab = int(yb.item())
    counts_sampled[lab] = counts_sampled.get(lab, 0) + 1
print(f'Sampled label distribution ({N_DRAWS} draws):', counts_sampled)

In [None]:
# Forward-pass smoke test: run one batch through EfficientNet-B0 (pretrained=False)
from src.model import load_efficientnet_b0

model = load_efficientnet_b0(num_classes=4, pretrained=False)
model.eval()

# Take the first batch from the sampler-backed loader; no gradients are needed here
batch_x, batch_y, batch_paths = next(iter(loader))
with torch.no_grad():
    output = model(batch_x)
print('out shape:', output.shape)

In [None]:
# Two-step training smoke test (very small): 1 head epoch followed by 1 finetune epoch
from src.train import train_two_step

# The val/test datasets use the default transform only (no per-class transform map)
val_ds = SimpleImageDataset(plants['val'], transform=tf_default)
test_ds = SimpleImageDataset(plants['test'], transform=tf_default)

# Tiny dataloaders: train draws through the weighted sampler, val/test are unshuffled
dls = dict(
    train=DataLoader(ds, batch_size=4, sampler=sampler),
    val=DataLoader(val_ds, batch_size=4, shuffle=False),
    test=DataLoader(test_ds, batch_size=4, shuffle=False),
)
device = torch.device('cpu')
model_small = load_efficientnet_b0(num_classes=4, pretrained=False)

# Settings for this run, collected in one place
two_step_kwargs = dict(
    epochs_head=1,
    epochs_finetune=1,
    lr_head=1e-3,
    lr_ft=1e-4,
    class_weights=class_weights,
    mixup_alpha=0.2,
)
model_trained, history = train_two_step(model_small, dls, device, **two_step_kwargs)
print('history keys:', history.keys())
print('history val_f1:', history.get('val_f1'))

In [None]:
# Same two-step smoke test with use_focal=True, to confirm that code path runs
# (1 epoch per step; lr_head / lr_ft are left at train_two_step's defaults)
model_small2 = load_efficientnet_b0(num_classes=4, pretrained=False)
focal_kwargs = dict(
    epochs_head=1,
    epochs_finetune=1,
    class_weights=class_weights,
    use_focal=True,
    focal_gamma=2.0,
    mixup_alpha=0.2,
)
model_trained2, history2 = train_two_step(model_small2, dls, device, **focal_kwargs)
print('focal history val_f1:', history2.get('val_f1'))

## Summary
- This notebook created a small synthetic dataset, demonstrated that `make_weighted_sampler` oversamples minority classes, applied class-specific heavy augmentations for minority labels, and ran a 1-epoch two-step training smoke test with MixUp and focal-loss code paths.
- Next steps: run the same notebook on your real `data/` tree (adjusting the paths), increase the number of epochs, enable GPU, and tune the hyperparameters (MixUp alpha, focal gamma) as needed.