In [None]:
import os
import time
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import timm
from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# --- Configuration ---
# Tunable settings for the whole notebook; change them here rather than inline.
BATCH_SIZE = 64  # Samples per batch for every DataLoader built in this notebook
MODEL_NAME = 'vit_tiny_patch16_224'  # Architecture name handed to timm.create_model
LEARNING_RATE = 0.001  # Learning rate of the AdamW optimizer that trains the head
NUM_EPOCHS = 8  # Number of train+val passes requested from train_model
SUBSET_RATIO = 0.1  # Fraction of each split kept by create_subset (0.1 keeps 10% of the images)
WEIGHT_DECAY = 0.01  # Passed to AdamW as weight_decay

# --- Data Transforms with Augmentation ---
# Per-channel normalisation statistics shared by both pipelines.
# NOTE(review): these look like the usual ImageNet mean/std; the pretrained timm
# checkpoint may declare different values in its pretrained config — TODO confirm.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, random augmentations, tensor conversion and
# normalisation, with RandomErasing applied last on the normalised tensor.
_train_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
    transforms.RandomErasing(p=0.2),
]

# Validation pipeline: no random steps, only resize + tensor + normalisation.
_val_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]

data_transforms = {
    'train': transforms.Compose(_train_steps),
    'val': transforms.Compose(_val_steps),
}

# --- Data Loading with Subset ---
train_dir = 'vit_data/train'
val_dir = 'vit_data/val'

# Load full datasets
# Each split is read from its directory variable above, so editing train_dir or
# val_dir is enough to point the notebook at a different location.
split_dirs = {'train': train_dir, 'val': val_dir}
full_datasets = {
    split: datasets.ImageFolder(split_dirs[split], data_transforms[split])
    for split in ('train', 'val')
}

# Create subsets for faster training
def create_subset(dataset, ratio, seed=None):
    """Return a random ``Subset`` holding a fraction of ``dataset``.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        ratio: Fraction of samples to keep, between 0 and 1 inclusive.
        seed: Optional integer. When given, indices are drawn from a dedicated
            generator seeded with it, so the same subset is selected on every
            call; when ``None`` the global NumPy random state is used.

    Returns:
        A ``torch.utils.data.Subset`` over ``int(len(dataset) * ratio)``
        distinct, randomly chosen indices.

    Raises:
        ValueError: If ``ratio`` lies outside [0, 1].
    """
    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be between 0 and 1, got {ratio}")

    total = len(dataset)
    subset_size = int(total * ratio)
    if seed is None:
        indices = np.random.choice(total, subset_size, replace=False)
    else:
        # A private generator keeps the draw independent of any other use of
        # NumPy's global random state elsewhere in the notebook.
        indices = np.random.default_rng(seed).choice(total, subset_size, replace=False)
    return Subset(dataset, indices)

# Random subset of every split, drawn in the order train -> val.
image_datasets = {
    split: create_subset(full_datasets[split], SUBSET_RATIO)
    for split in ('train', 'val')
}

# Only the training split is shuffled: sample order does not change the
# validation metrics, so the validation loader iterates in a fixed order.
dataloaders = {
    split: DataLoader(
        image_datasets[split],
        batch_size=BATCH_SIZE,
        shuffle=(split == 'train'),
        num_workers=2,
    )
    for split in ('train', 'val')
}

# Sample counts are used by the training loop to average loss and accuracy.
dataset_sizes = {split: len(subset) for split, subset in image_datasets.items()}
class_names = full_datasets['train'].classes

print(f"Dataset sizes: {dataset_sizes}")
print(f"Class names: {class_names}")
print(f"Number of classes: {len(class_names)}")

Dataset sizes: {'train': 27564, 'val': 5900}
Class names: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy']
Number of classes: 17


In [None]:
# --- Model Setup with Dropout ---
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Pretrained ViT created through timm with drop_rate=0.2.
model = timm.create_model(MODEL_NAME, pretrained=True, drop_rate=0.2)

# Freeze every pretrained parameter; the head assigned below is created
# afterwards, so its parameters are the only trainable ones.
model.requires_grad_(False)

# New classification head: dropout followed by a linear layer with one output
# per class, sized from the original head's input width.
num_classes = len(class_names)
num_ftrs = model.head.in_features
head_layers = [
    nn.Dropout(0.3),
    nn.Linear(num_ftrs, num_classes),
]
model.head = nn.Sequential(*head_layers)
model = model.to(device)

# AdamW over the head parameters only, and cross-entropy with label smoothing.
optimizer = optim.AdamW(
    model.head.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

print("Model initialized with regularization. Training only the head.")

Using device: cpu
Model initialized. Training only the head.


In [None]:
# --- Training Function with History Tracking ---
def train_model(model, criterion, optimizer, num_epochs=10):
    """Train ``model`` and return it with its best validation weights loaded.

    Every epoch runs a 'train' phase followed by a 'val' phase over the
    module-level ``dataloaders``. Batches are moved to the module-level
    ``device`` and per-phase averages are taken over ``dataset_sizes[phase]``.

    Args:
        model: Network to optimise; it is updated in place.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimiser stepped once per training batch.
        num_epochs: Number of train+val passes to run.

    Returns:
        ``(model, history)``: the same ``model`` object, holding the weights of
        the epoch with the highest validation accuracy, and a dict mapping
        'train_loss', 'train_acc', 'val_loss' and 'val_acc' to per-epoch lists
        of floats.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # History for plotting
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_training = phase == 'train'
            # train(True) enables training behaviour, train(False) is eval()
            model.train(is_training)

            running_loss = 0.0
            # Plain Python int, so the accuracy below is well defined even if
            # the loader yields no batches.
            running_corrects = 0

            # Iterate over data
            for inputs, labels in tqdm(dataloaders[phase], desc=f'{phase} Phase'):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Gradients are only produced, and so only reset, when training
                if is_training:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    # argmax of the raw logits: softmax is monotonic, so
                    # applying it first cannot change the winning class
                    predicted_classes = torch.argmax(outputs, dim=1)

                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # Weight the batch-mean loss by batch size for a per-sample average
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (predicted_classes == labels).sum().item()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Store history
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            # Snapshot the weights whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model, history

# --- Start Training ---
# train_model returns the model object it was given, so model_ft refers to the
# same network as `model`, now holding the weights of the best validation epoch;
# history holds the per-epoch loss and accuracy lists.
# If this cell is interrupted before train_model returns, neither name is
# (re)assigned and later cells will see whatever an earlier run left behind.
model_ft, history = train_model(model, criterion, optimizer, num_epochs=NUM_EPOCHS)

Epoch 0/7
----------


train Phase:   0%|          | 0/216 [00:00<?, ?it/s]

train Loss: 0.5961 Acc: 0.8370


val Phase:   0%|          | 0/47 [00:00<?, ?it/s]

val Loss: 0.1642 Acc: 0.9537

Epoch 1/7
----------


train Phase:   0%|          | 0/216 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# --- Save the Model ---
# Only the state_dict (the parameter/buffer tensors) is written, not the model
# object itself: to reload, rebuild the same architecture, including the
# replaced head, and call load_state_dict on it.
model_save_path = "vit_model.pth"
torch.save(model_ft.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")


Model saved to vit_model.pth


In [None]:
# --- Test the Model ---
test_dir = 'vit_data/test'
test_dataset = datasets.ImageFolder(test_dir, data_transforms['val'])
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Labels are integer indices into each dataset's own class list, so the test
# folder must expose exactly the training classes, in the same order, for the
# comparison with the model's predictions to mean anything.
if test_dataset.classes != class_names:
    raise ValueError(
        "Test classes do not match training classes: "
        f"{test_dataset.classes} vs {class_names}"
    )

# Evaluate the model on the test set
model_ft.eval()
test_corrects = 0
test_total = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Testing Phase'):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model_ft(inputs)
        # argmax of the raw logits: softmax is monotonic, so applying it first
        # cannot change the predicted class
        predicted_classes = torch.argmax(outputs, dim=1)

        test_corrects += (predicted_classes == labels).sum().item()
        test_total += labels.size(0)

test_acc = test_corrects / test_total
print(f'Test Accuracy: {test_acc:.4f}')

Testing Phase:   0%|          | 0/186 [00:00<?, ?it/s]

Test Accuracy: 0.0044


In [None]:
# --- Plot Training History ---
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# One panel per metric: (axis, history-key suffix, axis label, panel title).
panels = [
    (ax1, 'loss', 'Loss', 'Training and Validation Loss'),
    (ax2, 'acc', 'Accuracy', 'Training and Validation Accuracy'),
]
for axis, key, metric_label, panel_title in panels:
    axis.plot(history[f'train_{key}'], label=f'Train {metric_label}', marker='o', linewidth=2)
    axis.plot(history[f'val_{key}'], label=f'Val {metric_label}', marker='s', linewidth=2)
    axis.set_xlabel('Epoch', fontsize=12)
    axis.set_ylabel(metric_label, fontsize=12)
    axis.set_title(panel_title, fontsize=14, fontweight='bold')
    axis.legend(fontsize=10)
    axis.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print final metrics: the last recorded epoch of every history series.
banner = "=" * 50
final_rows = [
    ("Final Train Loss", 'train_loss'),
    ("Final Train Acc", 'train_acc'),
    ("Final Val Loss", 'val_loss'),
    ("Final Val Acc", 'val_acc'),
]
print("\n" + banner)
print("FINAL TRAINING METRICS")
print(banner)
for row_label, history_key in final_rows:
    print(f"{row_label}: {history[history_key][-1]:.4f}")
print(banner)

In [None]:
# --- Debugging Test Dataset ---
# Check test dataset size and class names
print(f"Test dataset size: {len(test_dataset)}")
print(f"Test class names: {test_dataset.classes}")

# Inspect predictions and labels for the first few batches only; printing every
# batch of the test set floods the cell output.
# NOTE(review): test_loader is built with shuffle=False, so these batches are
# presumably all from the first class folder — confirm before drawing conclusions.
max_debug_batches = 3
model_ft.eval()  # make sure dropout is disabled for inference
with torch.no_grad():  # inference only: skip building the autograd graph
    for batch_idx, (inputs, labels) in enumerate(test_loader):
        if batch_idx >= max_debug_batches:
            break
        outputs = model_ft(inputs.to(device))
        predicted_classes = torch.argmax(outputs, dim=1)
        print(f"Predicted: {predicted_classes.tolist()}, Actual: {labels.tolist()}")
# --- Confusion Matrix ---

all_preds = []
all_labels = []

model_ft.eval()  # make sure dropout is disabled for inference
with torch.no_grad():  # inference only: skip building the autograd graph
    for inputs, labels in test_loader:
        outputs = model_ft(inputs.to(device))
        predicted_classes = torch.argmax(outputs, dim=1)
        all_preds.extend(predicted_classes.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Pin the label set to every class index so the matrix always has one row and
# one column per entry of test_dataset.classes, matching the tick labels below
# even if some class never appears among the labels or predictions.
class_indices = list(range(len(test_dataset.classes)))
cm = confusion_matrix(all_labels, all_preds, labels=class_indices)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=test_dataset.classes, yticklabels=test_dataset.classes)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

# --- Fine-Tune the Model ---
# Make every parameter of the network trainable again (backbone and head) and
# replace the optimizer with Adam over all of them at a small learning rate.
# NOTE(review): this acts on `model`, while evaluation above uses `model_ft`;
# the two are the same object only when train_model ran in this session — confirm.
model.requires_grad_(True)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
print("Model backbone unfrozen and optimizer updated for fine-tuning.")

Test dataset size: 5924
Test class names: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy']
Predicted: [9, 11, 2, 2, 9, 2, 2, 14, 9, 13, 2, 1, 2, 2, 0, 2, 15, 3, 2, 2, 2, 9, 3, 9, 2, 9, 3, 2, 2, 2, 2, 3], Actual: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Predicted: [2, 1, 9, 9, 2, 1, 1, 11, 1, 2, 15, 2, 2, 2, 2, 11, 2, 2, 1, 2, 2, 2, 1, 1, 2, 15, 2, 3, 2, 2, 3, 3], Actual: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Predicted: [1

KeyboardInterrupt: 