In [None]:
# ✅ 2. Install required packages
!pip install tensorflow scikit-learn matplotlib seaborn pillow opencv-python --quiet


In [4]:
rm -rf sample_data

In [1]:
!git clone https://github.com/Ashfinn/tomato-leaf.git

Cloning into 'tomato-leaf'...
remote: Enumerating objects: 36105, done.[K
remote: Counting objects: 100% (319/319), done.[K
remote: Compressing objects: 100% (316/316), done.[K
remote: Total 36105 (delta 4), reused 315 (delta 3), pack-reused 35786 (from 1)[K
Receiving objects: 100% (36105/36105), 538.25 MiB | 23.20 MiB/s, done.
Resolving deltas: 100% (1373/1373), done.
Updating files: 100% (32035/32035), done.


# Comprehensive list of models for comparative analysis
# Modern CNNs
```
    'poolformer_s12',               # PoolFormer - simple, efficient CNN alternative
    
    # Hybrid CNN-Transformers
    'coat_lite_tiny',               # CoaT-Lite - co-scale conv-attentional transformer (combines conv and self-attention)
    'levit_128s',                   # LeViT - efficient hybrid CNN/ViT
    
    # Pure Vision Transformers
    'vit_tiny_patch16_224',         # Vanilla ViT - smallest standard ViT
    'vit_small_patch16_224',        # Vanilla ViT - slightly larger
    'beit_base_patch16_224',        # BEiT - masked image modeling pre-trained
    'crossvit_tiny_240',            # CrossViT - multiple size patches
    'pvt_tiny',                     # Pyramid Vision Transformer
    'twins_pcpvt_base',             # Twins - local-global attention
    'xcit_tiny_12_224',             # XCiT - cross-covariance attention
    'tnt_s_patch16_224',            # TNT - Transformer in Transformer
    
    # Efficient/Edge Models
    'maxvit_nano_rw_256',           # MaxViT - multi-axis (block + grid) attention for efficiency
```


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, SubsetRandomSampler
from PIL import Image
import os
import timm
from tqdm import tqdm

# --- 1. Configuration ---
import random

IMG_SIZE = 256     # Input resolution fed to the network; matches the `_256` suffix of the model below
SEED = 42          # One seed shared by the train/val split and by torch's RNGs
BATCH_SIZE = 32
NUM_WORKERS = 4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Seed torch as well as `random`: the sampler order, the random augmentations and
# the freshly initialised classifier head all draw from torch's generator.
torch.manual_seed(SEED)

# Models to fine-tune in this run (each entry must be a valid timm model name).
MODEL_NAMES = ['maxvit_nano_rw_256']

# --- 2. Data Preparation and Splitting ---

# DATA_DIR must point to the folder that directly contains one sub-folder per class.
DATA_DIR = "tomato-leaf/dataset"

# ImageNet channel statistics, shared by both pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training transforms include data augmentation to improve model robustness.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# Validation transforms are deterministic: resize, centre-crop, normalise.
val_transforms = transforms.Compose([
    # Shorter side is scaled to 256, or to IMG_SIZE when that is larger, so the
    # centre crop below never has to pad. With IMG_SIZE = 256 this is Resize(256).
    transforms.Resize(max(256, IMG_SIZE)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# Load the dataset once without transforms, only to read class names and its size.
full_dataset_no_transform = datasets.ImageFolder(root=DATA_DIR)

NUM_CLASSES = len(full_dataset_no_transform.classes)
class_to_idx = full_dataset_no_transform.class_to_idx
idx_to_class = {v: k for k, v in class_to_idx.items()}
print(f"Found {NUM_CLASSES} classes: {list(class_to_idx.keys())}")

# Fraction of the images used for training; the remainder is held out for validation.
TRAIN_SPLIT_RATIO = 0.8

# Shuffle all sample indices with a fixed seed, then cut at the split point.
dataset_size = len(full_dataset_no_transform)
indices = list(range(dataset_size))
random.seed(SEED)  # for reproducibility
random.shuffle(indices)

split_point = int(TRAIN_SPLIT_RATIO * dataset_size)
train_indices, val_indices = indices[:split_point], indices[split_point:]

# Samplers restrict each loader to its own share of the indices.
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# Two ImageFolder views over the same files, differing only in their transforms.
# Indices are interchangeable between them because both are built from DATA_DIR.
train_dataset = datasets.ImageFolder(root=DATA_DIR, transform=train_transforms)
val_dataset = datasets.ImageFolder(root=DATA_DIR, transform=val_transforms)

# Pinned host memory only helps host-to-GPU copies, so enable it only when on CUDA.
PIN_MEMORY = DEVICE.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"Data split into {len(train_indices)} training images and {len(val_indices)} validation images.")


# --- 3. Model Loading and Fine-tuning Setup ---
def load_and_prepare_model(model_name, num_classes, pretrained=True):
    """Create a timm model with a ``num_classes``-way head and move it to DEVICE.

    Input preprocessing is deliberately not derived from the model's pretrained
    config here; callers are expected to supply their own transforms.

    Args:
        model_name: Model identifier passed to ``timm.create_model``.
        num_classes: Number of output classes for the classifier head.
        pretrained: Forwarded to ``timm.create_model`` to request pretrained weights.

    Returns:
        The created model, placed on DEVICE.
    """
    network = timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=num_classes,
    ).to(DEVICE)
    print(f"Loaded {model_name} with {num_classes} classes.")
    return network

# --- 4. Training Function ---
def _run_epoch(model, loader, criterion, desc, optimizer=None):
    """Run one pass over ``loader``: a training pass if ``optimizer`` is given, else evaluation.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(0.0, 0.0)`` when the
        loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is equivalent to model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(training):
        for images, labels in progress:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # The criterion returns a batch mean; weight it by the batch size so the
            # epoch figure is a true per-sample mean even if the last batch is short.
            loss_sum += loss.item() * images.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

            if training:
                progress.set_postfix(loss=loss.item())

    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen


def train_model(model, train_loader, val_loader, num_epochs, learning_rate):
    """Fine-tune ``model`` and checkpoint the weights with the best validation accuracy.

    Uses cross-entropy loss, Adam, and a ReduceLROnPlateau schedule driven by the
    validation loss. Whenever validation accuracy improves, the state dict is
    written to ``{name}_best.pth`` in the working directory, where ``name`` is
    ``model.name`` if that attribute exists and the model's class name otherwise.

    Args:
        model: Network to train. It is not moved by this function, so it must
            already be on DEVICE (batches are moved there before the forward pass).
        train_loader: Iterable of ``(images, labels)`` batches used for optimisation.
        val_loader: Iterable of ``(images, labels)`` batches used for evaluation.
        num_epochs: Number of full passes over ``train_loader``.
        learning_rate: Initial learning rate for Adam.

    Returns:
        The same ``model`` object, holding the weights of the final epoch (which
        are not necessarily the best ones; reload the checkpoint for those).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # `verbose=True` is deprecated in recent PyTorch releases; learning-rate
    # reductions are reported explicitly after scheduler.step() instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

    # Do not require callers to attach `.name` before training.
    run_name = getattr(model, "name", type(model).__name__)

    print(f"Starting training for {num_epochs} epochs...")
    best_val_accuracy = 0.0

    for epoch in range(num_epochs):
        epoch_loss, epoch_accuracy = _run_epoch(
            model, train_loader, criterion,
            desc=f"Epoch {epoch+1}/{num_epochs} (Train)", optimizer=optimizer,
        )
        print(f"Epoch {epoch+1} Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_accuracy:.4f}")

        val_epoch_loss, val_epoch_accuracy = _run_epoch(
            model, val_loader, criterion,
            desc=f"Epoch {epoch+1}/{num_epochs} (Val)",
        )
        print(f"Epoch {epoch+1} Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_accuracy:.4f}")

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(val_epoch_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

        if val_epoch_accuracy > best_val_accuracy:
            best_val_accuracy = val_epoch_accuracy
            torch.save(model.state_dict(), f"{run_name}_best.pth")
            print(f"Validation accuracy improved. Saving best model for {run_name}.")

    print("Training finished!")
    return model

# --- 5. Inference Function ---
def predict_image(model, image_path, transform, idx_to_class):
    """Classify a single image file.

    Note: the model is switched to eval mode and left that way.

    Args:
        model: Network used for the forward pass; expected to be on DEVICE already,
            since only the input tensor is moved there.
        image_path: Path of the image file to open.
        transform: Callable applied to the RGB PIL image; its result is given a
            leading batch dimension before the forward pass.
        idx_to_class: Mapping from predicted class index to class name.

    Returns:
        ``(predicted_class, confidence)`` where ``confidence`` is the softmax
        probability of the predicted class as a Python float.
    """
    model.eval()

    # Image.open is lazy and holds the file handle; the context manager releases it
    # as soon as the RGB copy (which owns its own pixel data) has been made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        # torch.max returns both the top probability and its index in one call.
        top_probability, predicted_idx = torch.max(probabilities, 1)

    predicted_class = idx_to_class[predicted_idx.item()]
    confidence = top_probability.item()

    return predicted_class, confidence


# --- Main Execution Block ---
if __name__ == "__main__":
    # Fine-tuning hyper-parameters applied to every model in MODEL_NAMES.
    NUM_EPOCHS = 10
    LEARNING_RATE = 0.001

    print("\n--- Preparing DataLoaders ---")

    # train_loader and val_loader are built once above and reused for every model,
    # so all models see the same split and the same transforms.

    for model_name in MODEL_NAMES:
        print(f"\n--- Starting fine-tuning for {model_name} ---")

        model = load_and_prepare_model(model_name, NUM_CLASSES, pretrained=True)
        model.name = model_name  # Used to name the checkpoint file

        train_model(model, train_loader, val_loader, num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE)

        print(f"Loading best weights for {model_name}...")
        try:
            # map_location keeps the load working when the checkpoint was written on a
            # GPU runtime and is read back on a CPU-only one (or vice versa).
            model.load_state_dict(torch.load(f"{model.name}_best.pth", map_location=DEVICE))
        except FileNotFoundError:
            print(f"Warning: No best model saved for {model.name}. Using last epoch's weights.")

        print(f"\n--- Performing inference with {model.name} ---")
        if len(val_indices) > 0:
            # val_dataset.imgs lists (image_path, class_idx) for the whole dataset;
            # val_indices selects the entries that belong to the validation split.
            sample_image_path = val_dataset.imgs[val_indices[0]][0]
            print(f"Using sample image: {os.path.basename(sample_image_path)}") # Print just filename for brevity
            inference_transform = val_transforms

            predicted_class, confidence = predict_image(model, sample_image_path, inference_transform, idx_to_class)
            print(f"Predicted Class: {predicted_class}")
            print(f"Confidence: {confidence:.4f}")
        else:
            print("No validation images available for inference demonstration.")
        print("-" * 50)

Using device: cuda
Found 10 classes: ['Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']
Data split into 12808 training images and 3203 validation images.

--- Preparing DataLoaders ---

--- Starting fine-tuning for maxvit_nano_rw_256 ---




Loaded maxvit_nano_rw_256 with 10 classes.
Starting training for 10 epochs...


Epoch 1/10 (Train): 100%|██████████| 401/401 [03:46<00:00,  1.77it/s, loss=1.18]


Epoch 1 Train Loss: 1.3394, Train Acc: 0.5411


Epoch 1/10 (Val): 100%|██████████| 101/101 [00:18<00:00,  5.52it/s]


Epoch 1 Val Loss: 0.9118, Val Acc: 0.6460
Validation accuracy improved. Saving best model for maxvit_nano_rw_256.


Epoch 2/10 (Train): 100%|██████████| 401/401 [03:46<00:00,  1.77it/s, loss=0.517]


Epoch 2 Train Loss: 0.5956, Train Acc: 0.8033


Epoch 2/10 (Val): 100%|██████████| 101/101 [00:18<00:00,  5.53it/s]


Epoch 2 Val Loss: 0.5373, Val Acc: 0.8002
Validation accuracy improved. Saving best model for maxvit_nano_rw_256.


Epoch 3/10 (Train): 100%|██████████| 401/401 [03:45<00:00,  1.78it/s, loss=0.817]


Epoch 3 Train Loss: 0.4423, Train Acc: 0.8542


Epoch 3/10 (Val): 100%|██████████| 101/101 [00:17<00:00,  5.67it/s]


Epoch 3 Val Loss: 0.3124, Val Acc: 0.8976
Validation accuracy improved. Saving best model for maxvit_nano_rw_256.


Epoch 4/10 (Train): 100%|██████████| 401/401 [03:45<00:00,  1.78it/s, loss=0.397]


Epoch 4 Train Loss: 0.3631, Train Acc: 0.8799


Epoch 4/10 (Val): 100%|██████████| 101/101 [00:17<00:00,  5.66it/s]


Epoch 4 Val Loss: 0.2229, Val Acc: 0.9204
Validation accuracy improved. Saving best model for maxvit_nano_rw_256.


Epoch 5/10 (Train):  56%|█████▌    | 223/401 [02:06<01:39,  1.79it/s, loss=0.487]