In [1]:
import torch
import time
import random
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms.functional as F
from torchvision.datasets import VOCDetection
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# --- Training hyper-parameters -------------------------------------------
NUM_EPOCHS = 5        # number of training epochs
BATCH_SIZE = 4        # samples per batch
PATIENCE = 2          # patience value (presumably for the LR scheduler / early stopping -- confirm at use site)

# --- Data loading / logging ----------------------------------------------
NUM_WORKERS = 4       # worker count reported below (presumably passed to DataLoader -- confirm)
PRINT_INTERVAL = 100  # logging interval (units depend on the training loop, not shown here)

# --- Compute device: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Echo the effective configuration so it is captured in the cell output.
print(f"Config => epochs: {NUM_EPOCHS}, batch_size: {BATCH_SIZE}, workers: {NUM_WORKERS}, patience: {PATIENCE}")
print(f"Using device: {DEVICE}")

Config => epochs: 5, batch_size: 4, workers: 4, patience: 2
Using device: cuda


In [None]:
# Ordered class names; a name's position in this list is its 0-based class index.
PLANT_CLASSES = [
    'Apple Scab Leaf',
    'Apple leaf',
    'Apple rust leaf',
    'Bell_pepper leaf',
    'Bell_pepper leaf spot',
    'Blueberry leaf',
    'Cherry leaf',
    'Corn Gray leaf spot',
    'Corn leaf blight',
    'Corn rust leaf',
    'Peach leaf',
    'Potato leaf',
    'Potato leaf early blight',
    'Potato leaf late blight',
    'Raspberry leaf',
    'Soyabean leaf',
    'Squash Powdery mildew leaf',
    'Strawberry leaf',
    'Tomato Early blight leaf',
    'Tomato Septoria leaf spot',
    'Tomato leaf',
    'Tomato leaf bacterial spot',
    'Tomato leaf late blight',
    'Tomato leaf mosaic virus',
    'Tomato leaf yellow virus',
    'Tomato mold leaf',
    'Tomato two spotted spider mites leaf',
    'grape leaf',
    'grape leaf black rot',
]

# Class name -> 0-based class index, in the same order as PLANT_CLASSES.
PLANT_DICT = {name: idx for idx, name in enumerate(PLANT_CLASSES)}

In [None]:
def prepare_sample(img, target):
    """Augment one VOC-style sample and convert it to detection-model inputs.

    Applies a random horizontal flip (p=0.5) and a random brightness /
    contrast / saturation jitter (p=0.5, each factor drawn from [0.8, 1.2]),
    then converts the image to a tensor and the annotation to box / label
    tensors.

    Unlike a naive in-place flip, the caller's ``target`` dict is never
    modified: box coordinates are parsed into a local list and the flip is
    applied to that list only, so reusing or caching ``target`` is safe.

    Args:
        img: Image accepted by ``torchvision.transforms.functional`` that
            exposes a ``.width`` attribute (expected to be a PIL image).
        target: Annotation dict where ``target['annotation']['object']`` is
            either one object dict or a list of them; each object has a
            ``'name'`` and a ``'bndbox'`` with ``xmin``/``ymin``/``xmax``/``ymax``
            values convertible with ``int()``.

    Returns:
        ``(img_tensor, {"boxes": boxes, "labels": labels})`` where ``boxes``
        is a float32 tensor of shape ``(N, 4)`` in ``[xmin, ymin, xmax, ymax]``
        order (``(0, 4)`` when there are no objects) and ``labels`` is an
        int64 tensor of shape ``(N,)`` holding ``PLANT_CLASSES`` index + 1
        (index 0 is left for the background class).

    Raises:
        ValueError: If an object's name is not in ``PLANT_CLASSES`` or a
            coordinate cannot be parsed as an integer.
    """
    # Normalise 'object' to a list once (a lone object may arrive as a bare dict).
    objs = target['annotation']['object']
    if not isinstance(objs, list):
        objs = [objs]

    # Parse coordinates into a local list so augmentation never writes back
    # into the caller's annotation.
    coords = [
        [int(o['bndbox']['xmin']), int(o['bndbox']['ymin']),
         int(o['bndbox']['xmax']), int(o['bndbox']['ymax'])]
        for o in objs
    ]

    # Random horizontal flip: mirror the image and the x-extent of every box.
    if random.random() < 0.5:
        w = img.width
        img = F.hflip(img)
        coords = [
            [w - xmax, ymin, w - xmin, ymax]
            for xmin, ymin, xmax, ymax in coords
        ]

    # Random color jitter
    if random.random() < 0.5:
        img = F.adjust_brightness(img, random.uniform(0.8, 1.2))
        img = F.adjust_contrast(img, random.uniform(0.8, 1.2))
        img = F.adjust_saturation(img, random.uniform(0.8, 1.2))

    img_tensor = F.to_tensor(img)

    # reshape(-1, 4) keeps the (N, 4) contract even when N == 0; a bare
    # torch.tensor([]) would be shape (0,), which detection models reject.
    boxes = torch.tensor(coords, dtype=torch.float32).reshape(-1, 4)

    # Labels are shifted by +1 so that 0 stays reserved for background.
    labels = torch.tensor(
        [PLANT_CLASSES.index(o['name']) + 1 for o in objs],
        dtype=torch.int64,
    )
    return img_tensor, {"boxes": boxes, "labels": labels}

def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a tuple of per-field tuples.

    For example ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``. No stacking is performed, so the
    fields may have differing sizes across samples.
    """
    per_field = zip(*batch)
    return tuple(per_field)