# Diseases Code

In [1]:
import os
import random
from pathlib import Path
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import timm
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
import numpy as np
from PIL import Image

# ******** CONFIG ********
# Dataset root; the train/valid subfolders are read with ImageFolder below,
# so each is expected to contain one sub-directory per class.
DATA_DIR = "data"
TRAIN_DIR = os.path.join(DATA_DIR, "train")
VAL_DIR = os.path.join(DATA_DIR, "valid")
# Checkpoint file written whenever validation loss reaches a new minimum.
MODEL_SAVE = "best_wheat_model.pth"
BATCH_SIZE = 32  # used for both the training and validation DataLoader
IMG_SIZE = 224  # side length of the square crop fed to the network
EPOCHS = 20  # upper bound; early stopping may end training sooner
LR = 1e-4  # initial AdamW learning rate (halved by ReduceLROnPlateau)
WEIGHT_DECAY = 1e-5  # AdamW weight decay
MODEL_NAME = "efficientnet_b0"  # architecture name passed to timm.create_model
# Batches and the model are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 4  # DataLoader worker processes
PATIENCE = 4  # consecutive epochs without a new best val loss before stopping
# ***********************

def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    Args:
        seed: Integer seed applied to every generator.

    The CUDA generators are seeded as well when DEVICE names a CUDA device.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not DEVICE.startswith("cuda"):
        return
    torch.cuda.manual_seed_all(seed)


# Seed once, before any dataset, loader or model is created.
set_seed()

# Transforms
# Both pipelines end with the same tensor conversion + per-channel
# normalisation (these are the mean/std values commonly used with
# ImageNet-pretrained backbones).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_normalize = transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)

# Training: random crop/flip/colour augmentation before normalisation.
_train_steps = [
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    _normalize,
]
train_tf = transforms.Compose(_train_steps)

# Validation: deterministic resize to 110% of IMG_SIZE, then a centre crop.
_val_steps = [
    transforms.Resize(int(IMG_SIZE * 1.1)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    _normalize,
]
val_tf = transforms.Compose(_val_steps)

# Datasets
train_ds = datasets.ImageFolder(TRAIN_DIR, transform=train_tf)
val_ds = datasets.ImageFolder(VAL_DIR, transform=val_tf)

# ImageFolder derives the label index of each class from the sub-directory
# names of the folder it is given, independently for every dataset. If the
# two splits do not contain exactly the same class folders, validation
# labels would silently refer to different classes than training labels,
# so fail fast instead of reporting meaningless metrics.
if val_ds.class_to_idx != train_ds.class_to_idx:
    raise ValueError(
        f"Class folders differ between {TRAIN_DIR} and {VAL_DIR}: "
        f"{train_ds.classes} vs {val_ds.classes}"
    )

CLASSES = train_ds.classes
NUM_CLASSES = len(CLASSES)
print(f"Classes: {CLASSES}")

# Pinned host memory only helps host-to-GPU copies; requesting it on a
# CPU-only run has no benefit, so enable it only when training on CUDA.
_PIN_MEMORY = DEVICE.startswith("cuda")

# Training batches are reshuffled every epoch; validation order is fixed.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=_PIN_MEMORY,
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=_PIN_MEMORY,
)

# Model
def build_model(num_classes):
    """Create the MODEL_NAME backbone with a head sized for `num_classes`.

    Args:
        num_classes: Number of output classes of the classification head.

    Returns:
        The timm model, with pretrained backbone weights and a freshly
        initialised classifier producing `num_classes` logits.
    """
    # Let timm build the head itself instead of swapping `.classifier` /
    # `.fc` by hand: the manual swap assumed that attribute is a layer
    # exposing `in_features`, which does not hold for every architecture.
    # timm skips the pretrained classifier weights when the class count
    # differs from the pretrained one.
    return timm.create_model(MODEL_NAME, pretrained=True, num_classes=num_classes)

# Build the network and place it on the training device.
model = build_model(NUM_CLASSES)
model = model.to(DEVICE)

# Loss, optimizer, scheduler
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)
# Halve the learning rate once the monitored value (validation loss, see the
# training loop) has not decreased for more than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
)

# Training & validation loops
def validate(model, loader):
    """Evaluate `model` on every batch of `loader` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: DataLoader yielding (inputs, targets) batches.

    Returns:
        Tuple ``(avg_loss, acc, f1, trues, preds)``: the loss averaged over
        ``len(loader.dataset)`` samples, accuracy, weighted F1, and the lists
        of true and predicted labels in loader order.
    """
    model.eval()
    y_pred, y_true = [], []
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            logits = model(inputs)
            # criterion returns the batch mean; re-weight by batch size so the
            # final division yields a per-sample average.
            batch_mean = criterion(logits, targets).item()
            loss_sum += batch_mean * inputs.size(0)
            y_pred.extend(logits.argmax(dim=1).cpu().numpy())
            y_true.extend(targets.cpu().numpy())
    return (
        loss_sum / len(loader.dataset),
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average='weighted'),
        y_true,
        y_pred,
    )

# Early-stopping bookkeeping: lowest validation loss seen so far and the
# number of consecutive epochs that failed to beat it.
best_val_loss = float('inf')
no_improve = 0

for epoch in range(1, EPOCHS + 1):
    # ---- one pass over the training data ----
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for xb, yb in pbar:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        # Weight the batch mean by batch size to get a per-sample average.
        running_loss += batch_loss * xb.size(0)
        pbar.set_postfix(loss=batch_loss)
    train_loss = running_loss / len(train_loader.dataset)

    # ---- validation; the scheduler monitors validation loss ----
    val_loss, val_acc, val_f1, trues, preds = validate(model, val_loader)
    scheduler.step(val_loss)
    print(f"Epoch {epoch} summary — train_loss: {train_loss:.4f} val_loss: {val_loss:.4f} val_acc: {val_acc:.4f} val_f1: {val_f1:.4f}")

    # ---- no new best: count the miss and possibly stop ----
    if not val_loss < best_val_loss:
        no_improve += 1
        if no_improve >= PATIENCE:
            print("Early stopping triggered.")
            break
        continue

    # ---- new best: checkpoint everything inference needs ----
    best_val_loss = val_loss
    checkpoint = {
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'classes': CLASSES,
        'img_size': IMG_SIZE,
    }
    torch.save(checkpoint, MODEL_SAVE)
    print("Saved best model.")
    no_improve = 0

# Final evaluation and report
# The in-memory weights are those of the last epoch; restore the checkpoint
# with the lowest validation loss before reporting.
print("Loading best model for final evaluation.")
ckpt = torch.load(MODEL_SAVE, map_location=DEVICE)
best_weights = ckpt['model_state']
model.load_state_dict(best_weights)

final_metrics = validate(model, val_loader)
val_loss, val_acc, val_f1, trues, preds = final_metrics
print(f"Final val_loss: {val_loss:.4f} val_acc: {val_acc:.4f} val_f1: {val_f1:.4f}")

# Per-class precision/recall/F1, labelled with the class folder names.
print("Classification report:")
report_text = classification_report(trues, preds, target_names=CLASSES)
print(report_text)

print("Confusion matrix:")
conf_mat = confusion_matrix(trues, preds)
print(conf_mat)

# Single image inference utility
def predict_image(path, model, classes, transform):
    """Classify a single image file.

    Args:
        path: Path of the image to classify.
        model: Trained network; it is switched to eval mode.
        classes: Sequence of class names indexed by the model's output index.
        transform: Preprocessing applied to the RGB PIL image; its result is
            given a leading batch dimension before being fed to the model.

    Returns:
        Dict with ``'class'`` (predicted class name), ``'prob'`` (softmax
        probability of that class) and ``'probs'`` (all class probabilities
        as a list, in `classes` order).
    """
    model.eval()
    # Image.open is lazy and keeps the file open; the context manager makes
    # sure the handle is released even for multi-frame files or when decoding
    # fails. convert() returns an independent in-memory copy.
    with Image.open(path) as raw:
        img = raw.convert('RGB')
    x = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = model(x)
        probs = torch.softmax(out, dim=1).cpu().numpy()[0]
        idx = int(out.argmax(dim=1).cpu().numpy()[0])
    return {'class': classes[idx], 'prob': float(probs[idx]), 'probs': probs.tolist()}

if __name__ == "__main__":
    # example inference. set image path
    test_img = "test_samples/sample1.jpg"
    if not os.path.exists(test_img):
        # Nothing to classify: just report where the checkpoint went.
        print(f"Training complete. Model saved to {MODEL_SAVE}. To test inference, provide an image at {test_img}")
    else:
        result = predict_image(test_img, model, CLASSES, val_tf)
        print("Prediction:", result)

Classes: ['Aphid', 'Black Rust', 'Blast', 'Brown Rust', 'Common Root Rot', 'Fusarium Head Blight', 'Healthy', 'Leaf Blight', 'Mildew', 'Mite', 'Septoria', 'Smut', 'Stem fly', 'Tan spot', 'Yellow Rust']


Epoch 1/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [28:30<00:00,  4.17s/it, loss=0.365]


Epoch 1 summary — train_loss: 1.1813 val_loss: 1.1314 val_acc: 0.7000 val_f1: 0.6800
Saved best model.


Epoch 2/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [29:05<00:00,  4.26s/it, loss=0.797]


Epoch 2 summary — train_loss: 0.5279 val_loss: 0.8592 val_acc: 0.7833 val_f1: 0.7675
Saved best model.


Epoch 3/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [29:43<00:00,  4.35s/it, loss=0.678]


Epoch 3 summary — train_loss: 0.3747 val_loss: 0.7609 val_acc: 0.8467 val_f1: 0.8310
Saved best model.


Epoch 4/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [32:06<00:00,  4.70s/it, loss=0.161]


Epoch 4 summary — train_loss: 0.2932 val_loss: 0.6338 val_acc: 0.8900 val_f1: 0.8757
Saved best model.


Epoch 5/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [30:59<00:00,  4.53s/it, loss=0.294]


Epoch 5 summary — train_loss: 0.2265 val_loss: 0.6555 val_acc: 0.8767 val_f1: 0.8646


Epoch 6/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [28:49<00:00,  4.22s/it, loss=0.331]


Epoch 6 summary — train_loss: 0.1899 val_loss: 0.6725 val_acc: 0.8833 val_f1: 0.8660


Epoch 7/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [28:48<00:00,  4.22s/it, loss=0.447]


Epoch 7 summary — train_loss: 0.1608 val_loss: 0.5944 val_acc: 0.9000 val_f1: 0.8830
Saved best model.


Epoch 8/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [28:50<00:00,  4.22s/it, loss=0.1]


Epoch 8 summary — train_loss: 0.1469 val_loss: 0.6695 val_acc: 0.8967 val_f1: 0.8774


Epoch 9/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [29:17<00:00,  4.29s/it, loss=0.227]


Epoch 9 summary — train_loss: 0.1299 val_loss: 0.6030 val_acc: 0.9233 val_f1: 0.9115


Epoch 10/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [32:03<00:00,  4.69s/it, loss=0.188]


Epoch 10 summary — train_loss: 0.1185 val_loss: 0.6982 val_acc: 0.9167 val_f1: 0.9022


Epoch 11/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [31:54<00:00,  4.67s/it, loss=0.165]


Epoch 11 summary — train_loss: 0.0966 val_loss: 0.6280 val_acc: 0.9267 val_f1: 0.9105
Early stopping triggered.
Loading best model for final evaluation.




Final val_loss: 0.5944 val_acc: 0.9000 val_f1: 0.8830
Classification report:
                      precision    recall  f1-score   support

               Aphid       0.95      0.95      0.95        20
          Black Rust       1.00      1.00      1.00        20
               Blast       1.00      0.95      0.97        20
          Brown Rust       0.94      0.80      0.86        20
     Common Root Rot       1.00      0.95      0.97        20
Fusarium Head Blight       0.91      1.00      0.95        20
             Healthy       0.33      0.05      0.09        20
         Leaf Blight       0.90      0.95      0.93        20
              Mildew       0.95      1.00      0.98        20
                Mite       1.00      0.90      0.95        20
            Septoria       1.00      1.00      1.00        20
                Smut       0.95      1.00      0.98        20
            Stem fly       1.00      1.00      1.00        20
            Tan spot       0.90      0.95      0.93   

# Inference Code

In [None]:
import os
import torch
import timm
import numpy as np
from PIL import Image
from torchvision import transforms

# ========== CONFIG ==========
# Checkpoint file to load; the keys read below ("classes", "img_size",
# "model_state") must all be present in it.
MODEL_PATH = "best_wheat_model.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Architecture handed to timm.create_model; the checkpoint's state dict is
# loaded into this model, so it has to be the architecture that was trained.
MODEL_NAME = "efficientnet_b0"
# =============================

# ----- Load checkpoint -----
# NOTE(review): torch.load deserialises with pickle (subject to the installed
# PyTorch version's `weights_only` default) — only load checkpoints from a
# trusted source.
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
CLASSES = ckpt["classes"]  # class names, indexed by model output index
IMG_SIZE = ckpt["img_size"]  # crop size used to build the preprocessing below

# ----- Build same model -----
def build_model(num_classes):
    """Create an uninitialised MODEL_NAME network with `num_classes` outputs.

    Args:
        num_classes: Number of output classes of the classification head.

    Returns:
        The timm model with randomly initialised weights (pretrained=False);
        the caller is expected to load trained weights into it.
    """
    # Let timm size the head itself instead of swapping `.classifier` /
    # `.fc` by hand: the manual swap assumed that attribute is a layer
    # exposing `in_features`, which does not hold for every architecture.
    return timm.create_model(MODEL_NAME, pretrained=False, num_classes=num_classes)

# Rebuild the network, move it to DEVICE, then restore the trained weights
# and switch to inference mode.
model = build_model(len(CLASSES))
model = model.to(DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

# ----- Transform (must match training) -----
# Deterministic preprocessing: resize to 110% of IMG_SIZE, centre-crop to
# IMG_SIZE, convert to a tensor and normalise each channel.
_preprocess_steps = [
    transforms.Resize(int(IMG_SIZE * 1.1)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
val_tf = transforms.Compose(_preprocess_steps)

# ----- Prediction function -----
def predict_image(image_path):
    """Classify one image file with the module-level model.

    Args:
        image_path: Path of the image to classify.

    Returns:
        Dict with ``"predicted_class"`` (class name), ``"confidence"``
        (softmax probability of that class) and ``"class_probabilities"``
        (mapping of every class name to its probability).

    Raises:
        FileNotFoundError: If `image_path` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"{image_path} not found.")
    # Image.open is lazy and keeps the file open; the context manager makes
    # sure the handle is released even for multi-frame files or when decoding
    # fails. convert() returns an independent in-memory copy.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    # Add the batch dimension the model expects.
    x = val_tf(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        out = model(x)
        probs = torch.softmax(out, dim=1).cpu().numpy()[0]
        idx = int(np.argmax(probs))
        return {
            "predicted_class": CLASSES[idx],
            "confidence": float(probs[idx]),
            "class_probabilities": dict(zip(CLASSES, [float(p) for p in probs]))
        }

if __name__ == "__main__":
    # Folder of images to classify; it is created if it does not exist yet.
    test_dir = "/home/cipher/Documents/Diseases/data/test/aphid_test"
    os.makedirs(test_dir, exist_ok=True)
    print(f"Testing on: {test_dir}")
    print("Classes:", CLASSES)

    # Keep only files with a supported image extension (case-insensitive),
    # in the order the directory listing returns them.
    image_exts = (".jpg", ".jpeg", ".png")
    image_names = [
        entry for entry in os.listdir(test_dir)
        if entry.lower().endswith(image_exts)
    ]
    found = bool(image_names)

    for fname in image_names:
        path = os.path.join(test_dir, fname)
        result = predict_image(path)
        print(f"{fname}: {result['predicted_class']} ({result['confidence']:.3f})")

    if not found:
        print("No image files found in test_dir.")

Testing on: /home/cipher/Documents/Diseases/data/test/aphid_test
Classes: ['Aphid', 'Black Rust', 'Blast', 'Brown Rust', 'Common Root Rot', 'Fusarium Head Blight', 'Healthy', 'Leaf Blight', 'Mildew', 'Mite', 'Septoria', 'Smut', 'Stem fly', 'Tan spot', 'Yellow Rust']
aphid_69.png: Brown Rust (0.337)
aphid_66.png: Aphid (1.000)
aphid_30.png: Aphid (0.988)
aphid_61.png: Aphid (0.999)
aphid_65.png: Aphid (0.999)
aphid_38.png: Aphid (1.000)
aphid_44.png: Aphid (0.990)
aphid_63.png: Aphid (0.966)
aphid_34.png: Aphid (0.999)
aphid_40.png: Aphid (1.000)
aphid_41.png: Aphid (0.806)
aphid_60.png: Aphid (1.000)
aphid_55.png: Aphid (0.985)
aphid_71.png: Aphid (0.994)
aphid_27.png: Aphid (0.998)
aphid_70.png: Aphid (1.000)
aphid_57.png: Aphid (0.991)
aphid_54.png: Aphid (0.995)
aphid_49.png: Aphid (0.633)
aphid_37.png: Aphid (1.000)
aphid_28.png: Aphid (1.000)
aphid_72.png: Aphid (0.999)
aphid_39.png: Aphid (0.979)
aphid_73.png: Aphid (0.899)
aphid_76.png: Aphid (0.999)
aphid_51.png: Aphid (0.996)
