In [None]:
import os, re
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm

# Configuration

In [None]:
# --- Paths -------------------------------------------------------------
# Folder that receives the per-epoch checkpoint files; created up front so
# later cells can list it or write into it without extra checks.
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
# Dataset root; each class name below is joined onto this path.
data_root = 'PlantVillage'

# --- Labels ------------------------------------------------------------
# The position of a name in this list is the integer label of that class.
tomato_classes = [
    "Tomato__Tomato_YellowLeaf__Curl_Virus",
    "Tomato__Tomato_mosaic_virus",
    "Tomato__Target_Spot",
    "Tomato_Spider_mites_Two_spotted_spider_mite",
    "Tomato_Septoria_leaf_spot",
    "Tomato_Leaf_Mold",
    "Tomato_Late_blight",
    "Tomato_healthy",
    "Tomato_Early_blight",
    "Tomato_Bacterial_spot",
]

# --- Hyper-parameters --------------------------------------------------
# Mini-batch size, number of training epochs, initial learning rate.
BATCH, EPOCHS, LR = 64, 20, 1e-3

# --- Hardware ----------------------------------------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
torch.backends.cudnn.benchmark = True

# Checkpoint

In [None]:
def get_last_checkpoint(model_name, directory=None):
    """Return the highest epoch number among a model's saved checkpoints.

    A file counts as a checkpoint only when its whole name has the form
    ``<model_name>_ep<N>.pth``.

    Args:
        model_name: Filename prefix identifying the model. It is matched
            literally, so regex metacharacters in the name are safe.
        directory: Folder to scan. When omitted, the module-level
            ``checkpoint_dir`` is used.

    Returns:
        The largest ``N`` found, as an int, or 0 when no file matches.
    """
    if directory is None:
        directory = checkpoint_dir
    # re.escape stops characters such as "." in the model name from acting as
    # wildcards; fullmatch (rather than match) rejects names with trailing
    # text such as "model_ep3.pth.bak".
    pattern = re.compile(rf"{re.escape(model_name)}_ep(\d+)\.pth")
    matches = (pattern.fullmatch(fname) for fname in os.listdir(directory))
    return max((int(m.group(1)) for m in matches if m), default=0)

# Loading Dataset

In [None]:
# Lower-cased file extensions that are accepted as images.
EXT = {".jpg",".jpeg",".png",".bmp",".tif",".tiff"}


class TomatoDS(Dataset):
    """Dataset of labelled images read from one sub-directory per class.

    Args:
        root: Directory that contains the class sub-directories.
        classes: Sequence of sub-directory names; the position of a name in
            this sequence is the integer label of its images.
        tf: Optional callable applied to each PIL image before it is returned.

    Attributes:
        samples: List of ``(file_path, label)`` pairs, ordered by class and
            then by file name.
    """

    def __init__(self, root, classes, tf=None):
        self.tf = tf
        self.idx = {c: i for i, c in enumerate(classes)}
        self.samples = []
        for c in classes:
            class_dir = os.path.join(root, c)
            # os.listdir returns entries in arbitrary order; sorting keeps
            # sample indices stable, so a seeded index split selects the
            # same files on every run and on every filesystem.
            for fn in sorted(os.listdir(class_dir)):
                if os.path.splitext(fn)[1].lower() in EXT:
                    self.samples.append((os.path.join(class_dir, fn), self.idx[c]))

    def __len__(self):
        """Return the number of image files found."""
        return len(self.samples)

    def __getitem__(self, i):
        """Load sample ``i`` and return ``(image, label)``.

        The image is converted to RGB and, if a transform was supplied,
        passed through it.
        """
        path, label = self.samples[i]
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.tf:
            img = self.tf(img)
        return img, label

# Transforms

In [None]:
# Training pipeline: random crop / flip / colour jitter for augmentation,
# then tensor conversion and per-channel normalisation.
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224,scale=(0.8,1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2,0.2,0.2,0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225]),
])
# Validation pipeline: deterministic resize + centre crop, same normalisation.
val_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225]),
])

# Transform-free pass used only to count the samples and draw the split.
full_ds = TomatoDS(data_root, tomato_classes, tf=None)
n = len(full_ds)
# Fixed seed so the 80/20 permutation is the same on every run.
idx = np.random.RandomState(42).permutation(n)
split = int(0.8*n)
tr_idx, va_idx = idx[:split], idx[split:]

# Two more views of the same files, differing only in the transform applied.
train_base = TomatoDS(data_root, tomato_classes, tf=train_tf)
val_base   = TomatoDS(data_root, tomato_classes, tf=val_tf)
# The indices above were drawn against full_ds. They only partition the data
# correctly if every scan listed the same files in the same order; otherwise
# training and validation images could overlap without any error.
assert train_base.samples == full_ds.samples == val_base.samples, \
    "dataset scans disagree; train/val split indices would be misaligned"

train_ds = Subset(train_base, tr_idx)
val_ds   = Subset(val_base,   va_idx)

# Pinned host memory only helps host-to-GPU copies; requesting it on a
# CPU-only machine just triggers a warning.
use_pinned = DEVICE.type == 'cuda'

# num_workers=0 avoids multiprocessing pickling issues on Windows.
train_ld = DataLoader(train_ds, batch_size=BATCH, shuffle=True,
                      num_workers=0, pin_memory=use_pinned)
val_ld   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False,
                      num_workers=0, pin_memory=use_pinned)

# Attention Block

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Each channel is averaged over its spatial extent, the resulting vector
    is passed through a two-layer bottleneck ending in a sigmoid, and the
    input is scaled channel-wise by that output.

    Args:
        channels: Number of channels of the input feature map.
        r: Reduction ratio of the bottleneck; the hidden width is
            ``channels // r``, but never less than 1.
    """

    def __init__(self, channels, r=16):
        super().__init__()
        # channels // r is 0 whenever channels < r, which would create a
        # zero-width bottleneck and a constant gate; keep at least one unit.
        hidden = max(1, channels // r)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc   = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale a 4-D input ``x`` channel-wise and return a same-shape tensor."""
        b, c, _, _ = x.shape
        squeezed = self.pool(x).view(b, c)         # one average per channel
        gate = self.fc(squeezed).view(b, c, 1, 1)  # per-channel weight in (0, 1)
        return x * gate

# Model

In [None]:
class TomatoNet(nn.Module):
    """EfficientNet-B0 feature extractor followed by an SE gate and a linear head.

    Args:
        num_classes: Number of output logits produced by the classifier.
    """

    def __init__(self, num_classes):
        super().__init__()
        pretrained = models.EfficientNet_B0_Weights.DEFAULT
        # Keep only the convolutional trunk of the backbone.
        self.features = models.efficientnet_b0(weights=pretrained).features
        self.se = SEBlock(channels=1280)
        self.pool = nn.AdaptiveAvgPool2d(1)
        head_layers = [nn.Dropout(0.3), nn.Linear(1280, num_classes)]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        gated = self.se(self.features(x))
        pooled = self.pool(gated)
        # Collapse the pooled map to one feature vector per sample.
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

# Build the network with one output per tomato class, then move it to DEVICE.
model = TomatoNet(num_classes=len(tomato_classes))
model = model.to(DEVICE)

# Training Loop

In [None]:
# Cross-entropy objective, Adam optimiser, and a cosine learning-rate schedule
# that is stepped once per epoch over the whole run.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

for ep in range(EPOCHS):
    model.train()
    bar = tqdm(train_ld, desc=f"Epoch {ep+1}/{EPOCHS}", leave=False)
    # tl: loss summed over samples, tc: correct predictions, seen: samples so far.
    tl, tc, seen = 0.0, 0, 0
    for imgs, labels in bar:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        batch_n = imgs.size(0)
        tl += loss.item()*batch_n
        tc += (out.argmax(1)==labels).sum().item()
        seen += batch_n
        # Average over the samples actually processed. bar.n is only refreshed
        # when tqdm redraws, and the final batch can be smaller than BATCH, so
        # neither is a reliable denominator.
        bar.set_postfix(loss=tl/seen, acc=tc/seen)
    scheduler.step()
    # One weights-only checkpoint per epoch, numbered from 1.
    torch.save(model.state_dict(),
               os.path.join(checkpoint_dir, f"tomatonet_ep{ep+1}.pth"))

# Final Eval


In [None]:
# Switch to inference behaviour (dropout off, batch-norm statistics frozen).
model.eval()
pred_batches, label_batches = [], []
with torch.no_grad():
    for imgs, labels in tqdm(val_ld, desc="Val"):
        imgs = imgs.to(DEVICE)
        logits = model(imgs)
        preds  = logits.argmax(1).cpu()
        pred_batches.append(preds)
        label_batches.append(labels)
all_preds  = torch.cat(pred_batches).numpy()
all_labels = torch.cat(label_batches).numpy()

print("Validation Accuracy:", accuracy_score(all_labels, all_preds))
# Pass the full label set explicitly: without it, classification_report
# raises when a class is missing from the validation data, because the number
# of observed labels no longer matches the number of target_names.
print("\nClassification Report:\n",
      classification_report(all_labels, all_preds,
                            labels=list(range(len(tomato_classes))),
                            target_names=tomato_classes))

Val: 100%|█████████████████████████████████████████████████████████████████████████████| 51/51 [00:10<00:00,  4.85it/s]

Validation Accuracy: 0.9962535123321886

Classification Report:
                                              precision    recall  f1-score   support

      Tomato__Tomato_YellowLeaf__Curl_Virus       1.00      1.00      1.00       622
                Tomato__Tomato_mosaic_virus       1.00      1.00      1.00        76
                        Tomato__Target_Spot       0.99      1.00      0.99       273
Tomato_Spider_mites_Two_spotted_spider_mite       1.00      0.99      0.99       332
                  Tomato_Septoria_leaf_spot       1.00      1.00      1.00       368
                           Tomato_Leaf_Mold       1.00      1.00      1.00       186
                         Tomato_Late_blight       1.00      0.99      0.99       388
                             Tomato_healthy       1.00      1.00      1.00       337
                        Tomato_Early_blight       0.98      1.00      0.99       210
                      Tomato_Bacterial_spot       1.00      1.00      1.00       411


