In [1]:
import os
import time
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm
from tqdm.notebook import tqdm

# --- Configuration ---
BATCH_SIZE = 32
MODEL_NAME = 'vit_tiny_patch16_224'  # timm model identifier passed to timm.create_model
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
SEED = 42

# Seed torch's global RNG so that the new classifier head's initialisation and
# the training-loader shuffle order are reproducible on Restart & Run All.
torch.manual_seed(SEED)

# --- Data Transforms ---
# Both splits share the same deterministic pipeline: resize to the ViT's
# 224x224 input, convert to a tensor, then normalise per channel. The
# comprehension still builds a separate Compose object for each split.
# NOTE(review): these mean/std values are the torchvision ImageNet statistics.
# timm's pretrained config for this ViT may specify different normalisation
# (the augreg ViT weights are documented with mean=std=0.5) -- confirm via
# timm.data.resolve_data_config before changing, since the saved head was
# trained with the values below.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    for split in ('train', 'val')
}

# --- Data Loading ---
# Expected layout: <DATA_DIR>/<split>/<class_name>/<image files>.
DATA_DIR = 'vit_data'
train_dir = os.path.join(DATA_DIR, 'train')
val_dir = os.path.join(DATA_DIR, 'val')

image_datasets = {
    split: datasets.ImageFolder(split_dir, data_transforms[split])
    for split, split_dir in (('train', train_dir), ('val', val_dir))
}

# Shuffle only the training split. Evaluation order cannot change the metrics,
# and a fixed order keeps validation passes reproducible.
dataloaders = {
    split: DataLoader(image_datasets[split], batch_size=BATCH_SIZE,
                      shuffle=(split == 'train'), num_workers=4)
    for split in ['train', 'val']
}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# ImageFolder derives label indices independently for each split from that
# split's sub-folders. A missing or extra class folder in val would therefore
# silently shift every label index, so fail loudly instead.
assert image_datasets['val'].classes == class_names, (
    "train/val class folders differ; label indices would not line up"
)

print(f"Dataset sizes: {dataset_sizes}")
print(f"Class names: {class_names}")
print(f"Number of classes: {len(class_names)}")

Dataset sizes: {'train': 25431, 'val': 5525}
Class names: ['Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy']
Number of classes: 16


In [2]:
# --- Model Setup ---
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the pretrained ViT and freeze every existing parameter, so that only
# the classifier head attached below receives gradient updates.
model = timm.create_model(MODEL_NAME, pretrained=True)
model.requires_grad_(False)

# Swap the pretrained classifier for a fresh linear layer sized to our classes.
# Newly created parameters require grad by default.
num_classes = len(class_names)
num_ftrs = model.head.in_features
model.head = nn.Linear(num_ftrs, num_classes)
model = model.to(device)

# The optimizer sees only the head's parameters, because the backbone is frozen.
optimizer = optim.Adam(model.head.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

print("Model initialized. Training only the head.")

Using device: cpu
Model initialized. Training only the head.


In [3]:
# --- Training Function ---
def _run_phase(model, criterion, optimizer, loader, phase, device):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    In the ``'train'`` phase, gradients are enabled and ``optimizer`` steps once
    per batch. In any other phase, the model is only evaluated. Both metrics are
    averaged over the number of samples actually seen. They are ``0.0`` when the
    loader yields no batches, which avoids a division by zero on an empty split.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_corrects = 0
    n_seen = 0

    for inputs, labels in tqdm(loader, desc=f'{phase} Phase'):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        # Softmax is monotonic, so argmax over the raw logits picks the same
        # class as argmax over the probabilities. The softmax is skipped.
        preds = outputs.argmax(dim=1)
        batch_size = inputs.size(0)
        # criterion returns the batch mean (default reduction), so re-weight it.
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        return 0.0, 0.0
    return running_loss / n_seen, running_corrects / n_seen


def train_model(model, criterion, optimizer, num_epochs=10, *, loaders=None, device=None):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Each epoch runs a ``'train'`` phase followed by a ``'val'`` phase and prints
    the loss and accuracy of each.

    Args:
        model: ``torch.nn.Module`` to train. It is modified in place.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        num_epochs: Number of train+val epochs to run.
        loaders: Mapping with ``'train'`` and ``'val'`` iterables of
            ``(inputs, labels)`` batches. Defaults to the notebook-level
            ``dataloaders`` dict.
        device: Device that batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        The same ``model`` object, with the ``state_dict`` from the epoch of
        highest validation accuracy loaded. If no epoch beats ``0.0``, the
        initial weights are loaded.
    """
    if loaders is None:
        loaders = dataloaders  # notebook global built in the data-loading cell
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_phase(
                model, criterion, optimizer, loaders[phase], phase, device
            )
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    # '.4f' (4 decimals), matching the per-epoch lines; the old '4f' meant width 4.
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model

# --- Start Training ---
# train_model reloads the weights from the best-validation-accuracy epoch into
# `model` before returning it, so model_ft refers to that same, restored model.
model_ft = train_model(model, criterion, optimizer, NUM_EPOCHS)


Epoch 0/9
----------


train Phase:   0%|          | 0/795 [00:00<?, ?it/s]

train Loss: 0.2635 Acc: 0.9234


val Phase:   0%|          | 0/173 [00:00<?, ?it/s]

val Loss: 0.0910 Acc: 0.9732

Epoch 1/9
----------


train Phase:   0%|          | 0/795 [00:00<?, ?it/s]

train Loss: 0.0676 Acc: 0.9787


val Phase:   0%|          | 0/173 [00:00<?, ?it/s]

val Loss: 0.0720 Acc: 0.9772

Epoch 2/9
----------


train Phase:   0%|          | 0/795 [00:00<?, ?it/s]

train Loss: 0.0493 Acc: 0.9849


val Phase:   0%|          | 0/173 [00:00<?, ?it/s]

val Loss: 0.0586 Acc: 0.9803

Epoch 3/9
----------


train Phase:   0%|          | 0/795 [00:00<?, ?it/s]

train Loss: 0.0410 Acc: 0.9875


val Phase:   0%|          | 0/173 [00:00<?, ?it/s]

val Loss: 0.0576 Acc: 0.9803

Epoch 4/9
----------


train Phase:   0%|          | 0/795 [00:00<?, ?it/s]

train Loss: 0.0368 Acc: 0.9876


val Phase:   0%|          | 0/173 [00:00<?, ?it/s]

val Loss: 0.0543 Acc: 0.9832

Epoch 5/9
----------


train Phase:   0%|          | 0/795 [00:00<?, ?it/s]

train Loss: 0.0350 Acc: 0.9875


val Phase:   0%|          | 0/173 [00:00<?, ?it/s]

val Loss: 0.0640 Acc: 0.9806

Epoch 6/9
----------


train Phase:   0%|          | 0/795 [00:00<?, ?it/s]

train Loss: 0.0326 Acc: 0.9884


val Phase:   0%|          | 0/173 [00:00<?, ?it/s]

val Loss: 0.0515 Acc: 0.9830

Epoch 7/9
----------


train Phase:   0%|          | 0/795 [00:00<?, ?it/s]

train Loss: 0.0308 Acc: 0.9890


val Phase:   0%|          | 0/173 [00:00<?, ?it/s]

val Loss: 0.0535 Acc: 0.9799

Epoch 8/9
----------


train Phase:   0%|          | 0/795 [00:00<?, ?it/s]

train Loss: 0.0289 Acc: 0.9897


val Phase:   0%|          | 0/173 [00:00<?, ?it/s]

val Loss: 0.0624 Acc: 0.9817

Epoch 9/9
----------


train Phase:   0%|          | 0/795 [00:00<?, ?it/s]

train Loss: 0.0268 Acc: 0.9905


val Phase:   0%|          | 0/173 [00:00<?, ?it/s]

val Loss: 0.0582 Acc: 0.9819

Training complete in 300m 31s
Best val Acc: 0.983167


In [4]:
# --- Save the Model ---
# Only the learned parameters (state_dict) are written, not the module object.
# Before calling load_state_dict, re-create the model with the same replaced head.
model_save_path = "vit_model.pth"
torch.save(model_ft.state_dict(), model_save_path)
print("Model saved to %s" % model_save_path)

Model saved to vit_model.pth
