In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os
import copy

# --- CONFIGURATION ---
# Every tunable setting for the run lives in this dict; later cells read it by key.
config = {
    'data_path': r'dataset_split/',                      # root holding train/, val/, test/ ImageFolder splits
    'model_name': 'eva02_tiny_patch14_224.mim_in22k',    # timm model identifier
    'batch_size': 16,
    'img_size': 224,                                     # square side length fed to the model
    'weight_decay': 0.01,                                # AdamW weight decay (both phases)
    'epochs': 80,                                        # max epochs of Phase 2 (full fine-tuning)
    'num_workers': 3,
    'pin_memory': True,
    'patience': 12,                                      # Phase 2 epochs without val-acc gain before early stop
    'drop_rate': 0.4,                                    # passed to timm.create_model
    'drop_path_rate': 0.2,                               # passed to timm.create_model
    'head_lr': 1e-3,                                     # Phase 1 (head-only) learning rate
    'head_epochs': 10,                                   # Phase 1 epochs
    'full_train_lr': 5e-5,                               # Phase 2 peak learning rate
    'lr_warmup_epochs': 5,                               # Phase 2 linear-warmup epochs
    'label_smoothing': 0.1,                              # CrossEntropyLoss label smoothing
    'seed': 42,                                          # RNG seed for repeatable runs
}

# Seed the global RNGs before any data loading or model creation so shuffling, augmentation
# and weight initialisation repeat across runs (cuDNN kernels may still be non-deterministic).
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

In [2]:
# --- DATA PREPARATION, TRAINING, AND VALIDATION FUNCTIONS ---

def get_data_loaders(data_path, img_size, batch_size, num_workers, pin_memory,
                     mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build ImageFolder datasets and DataLoaders for the train/val/test splits.

    Args:
        data_path: Directory containing ``train``, ``val`` and ``test`` sub-folders,
            each in ImageFolder layout (one sub-folder per class).
        img_size: Side length of the square images produced by the transforms.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per DataLoader.
        pin_memory: Forwarded to each DataLoader.
        mean, std: Per-channel normalisation statistics. The defaults are the ImageNet
            values this notebook has always used. NOTE(review): the pretrained model may
            have been trained with different stats (timm's ``resolve_data_config`` reports
            them) — confirm before changing, since it alters preprocessing.

    Returns:
        (train_loader, val_loader, test_loader, class_names), where ``class_names`` is
        the train split's class list (index i == label i).

    Raises:
        ValueError: if the val or test split does not have exactly the same class folders
            as train. ImageFolder assigns label indices from each split's own sorted folder
            names, so a mismatch would silently shift labels between splits.
    """
    normalize = transforms.Normalize(mean=list(mean), std=list(std))
    # Training-only augmentation; evaluation uses a deterministic resize.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.TrivialAugmentWide(),
        transforms.ToTensor(),
        normalize,
    ])
    val_test_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.ImageFolder(root=os.path.join(data_path, 'train'), transform=train_transform)
    val_dataset = datasets.ImageFolder(root=os.path.join(data_path, 'val'), transform=val_test_transform)
    test_dataset = datasets.ImageFolder(root=os.path.join(data_path, 'test'), transform=val_test_transform)

    # Guard against label-index drift between splits (see Raises above).
    for split_name, split_dataset in (('val', val_dataset), ('test', test_dataset)):
        if split_dataset.classes != train_dataset.classes:
            missing = sorted(set(train_dataset.classes) - set(split_dataset.classes))
            extra = sorted(set(split_dataset.classes) - set(train_dataset.classes))
            raise ValueError(
                f"Class folders in '{split_name}' do not match 'train' "
                f"(missing: {missing}, extra: {extra}); label indices would be inconsistent."
            )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
    class_names = train_dataset.classes
    print(f"Found classes: {class_names}")
    return train_loader, val_loader, test_loader, class_names

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass of ``model`` over every batch in ``dataloader``.

    Each batch is moved to ``device``, gradients are cleared, the loss is
    back-propagated and ``optimizer`` takes one step.

    Returns:
        The epoch's average loss per sample: sum(batch_loss * batch_size) divided by
        ``len(dataloader.dataset)``. NOTE(review): this weighting assumes ``criterion``
        returns the batch *mean* loss — confirm if a different reduction is used.
    """
    model.train()
    total_loss = 0.0
    for batch_inputs, batch_labels in tqdm(dataloader, desc="Training"):
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch is not over-counted.
        total_loss += batch_loss.item() * batch_inputs.size(0)
    return total_loss / len(dataloader.dataset)

def validate_one_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` in eval mode without tracking gradients.

    Returns:
        (mean_loss, accuracy): the per-sample average of ``criterion`` (batch loss
        weighted by batch size, divided by ``len(dataloader.dataset)``) and the
        fraction of samples whose arg-max prediction equals the label, as a float.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    with torch.no_grad():
        progress = tqdm(dataloader, desc="Validating")
        for batch_inputs, batch_labels in progress:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)
            loss_sum += batch_loss.item() * batch_inputs.size(0)
            # Index of the highest logit along the class dimension is the prediction.
            predicted = torch.max(logits, 1)[1]
            n_correct += torch.sum(predicted == batch_labels.data)
    dataset_size = len(dataloader.dataset)
    accuracy = n_correct.double() / dataset_size
    return loss_sum / dataset_size, accuracy.item()

# Signals that this cell ran to completion, i.e. the three helper functions above are defined.
print("✅ Helper functions are defined.")

✅ Helper functions are defined.


In [3]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Build the train/val/test loaders from the settings in `config`.
train_loader, val_loader, test_loader, class_names = get_data_loaders(
    data_path=config['data_path'],
    img_size=config['img_size'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    pin_memory=config['pin_memory'],
)
num_classes = len(class_names)

# Create the pretrained timm model with one output per discovered class and the
# configured dropout / stochastic-depth rates, then move it to the chosen device.
model = timm.create_model(
    config['model_name'],
    pretrained=True,
    num_classes=num_classes,
    drop_rate=config['drop_rate'],
    drop_path_rate=config['drop_path_rate'],
)
model.to(device)

# Label-smoothed cross-entropy is shared by both training phases and validation.
criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])

print("\n✅ Model and data are ready for training!")

Using device: cuda
Found classes: ['Alapadmam(1)', 'Aralam(1)', 'Ardhachandran(1)', 'Ardhapathaka(1)', 'Bramaram(1)', 'Chandrakala(1)', 'Chaturam(1)', 'Hamsapaksha(1)', 'Hamsasyam(1)', 'Kangulam(1)', 'Kapith(1)', 'Katakamukha_1', 'Katakamukha_2', 'Katakamukha_3', 'Katrimukha(1)', 'Mayura(1)', 'Mrigasirsha(1)', 'Mukulam(1)', 'Mushti(1)', 'Padmakosha(1)', 'Pathaka(1)', 'Sarpasirsha(1)', 'Shukatundam(1)', 'Sikharam(1)', 'Simhamukham(1)', 'Suchi(1)', 'Tamarachudam(1)', 'Tripathaka(1)', 'Trishulam(1)']

✅ Model and data are ready for training!


In [5]:
# --- PHASE 1: HEAD TRAINING ---
# Only `model.head` is trainable in this phase; every other parameter is frozen.
print("\n--- Starting Phase 1: Head Training ---")
model.requires_grad_(False)
model.head.requires_grad_(True)
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=config['head_lr'],
    weight_decay=config['weight_decay'],
)
# One cosine cycle spanning the whole head-training phase.
scheduler = CosineAnnealingLR(optimizer, T_max=config['head_epochs'], eta_min=1e-6)
best_val_acc = 0.0
# Snapshot of the best weights so far; Phase 2 starts from this.
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(config['head_epochs']):
    print(f"\nEpoch {epoch+1}/{config['head_epochs']}")
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device)
    scheduler.step()
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(model.state_dict(), 'best_model.pth')
        print("✅ New best model saved!")




--- Starting Phase 1: Head Training ---

Epoch 1/10


Training: 100%|██████████| 593/593 [00:22<00:00, 26.14it/s]
Validating: 100%|██████████| 74/74 [00:11<00:00,  6.41it/s]


Train Loss: 3.4509 | Val Loss: 3.2755 | Val Acc: 0.1090
✅ New best model saved!

Epoch 2/10


Training:   0%|          | 0/593 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [6]:
# --- PHASE 2: FULL NETWORK FINE-TUNING ---
print("\n--- Starting Phase 2: Full Fine-Tuning ---")
for param in model.parameters(): param.requires_grad = True
# Start fine-tuning from the best head-only weights found in Phase 1.
model.load_state_dict(best_model_wts)
optimizer = optim.AdamW(model.parameters(), lr=config['full_train_lr'], weight_decay=config['weight_decay'])
# Two schedulers share this optimizer, and exactly one is stepped per epoch (see loop).
# The cosine decay covers the epochs left after warmup.
scheduler = CosineAnnealingLR(optimizer, T_max=config['epochs'] - config['lr_warmup_epochs'], eta_min=1e-6)
# Linear warmup: LR factor (e+1)/lr_warmup_epochs for the first lr_warmup_epochs steps, then 1.
# Constructed after the cosine scheduler on purpose: construction applies factor(0) to the LR.
warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda e: (e+1)/config['lr_warmup_epochs'] if e < config['lr_warmup_epochs'] else 1)
patience_counter = 0

for epoch in range(config['epochs']):
    print(f"\nEpoch {epoch+1}/{config['epochs']}")
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device)
    if epoch < config['lr_warmup_epochs']: warmup_scheduler.step()
    else: scheduler.step()
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    # best_val_acc carries over from Phase 1, so only genuine improvements are saved.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Keep the in-memory snapshot in sync with the checkpoint on disk, as Phase 1 does.
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(model.state_dict(), 'best_model.pth')
        print("✅ New best model saved!")
        patience_counter = 0
    else:
        patience_counter += 1
    if patience_counter >= config['patience']:
        print("Early stopping triggered.")
        break

# Leave `model` holding the best weights (matching best_model.pth), not the last epoch's.
model.load_state_dict(best_model_wts)
print("\n🎉 Training finished!")


--- Starting Phase 2: Full Fine-Tuning ---

Epoch 1/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.37it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  8.96it/s]


Train Loss: 3.3926 | Val Loss: 2.8749 | Val Acc: 0.2053
✅ New best model saved!

Epoch 2/80


Training: 100%|██████████| 593/593 [00:48<00:00, 12.18it/s]
Validating: 100%|██████████| 74/74 [00:12<00:00,  5.82it/s]


Train Loss: 3.1457 | Val Loss: 2.2534 | Val Acc: 0.4106
✅ New best model saved!

Epoch 3/80


Training: 100%|██████████| 593/593 [00:55<00:00, 10.67it/s]
Validating: 100%|██████████| 74/74 [00:12<00:00,  5.75it/s]


Train Loss: 2.5200 | Val Loss: 1.5938 | Val Acc: 0.6721
✅ New best model saved!

Epoch 4/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.39it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  9.19it/s]


Train Loss: 1.9451 | Val Loss: 1.1706 | Val Acc: 0.8279
✅ New best model saved!

Epoch 5/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.42it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  9.23it/s]


Train Loss: 1.5846 | Val Loss: 0.9799 | Val Acc: 0.8893
✅ New best model saved!

Epoch 6/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.44it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  9.12it/s]


Train Loss: 1.3812 | Val Loss: 0.8804 | Val Acc: 0.9208
✅ New best model saved!

Epoch 7/80


Training: 100%|██████████| 593/593 [00:48<00:00, 12.24it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  9.25it/s]


Train Loss: 1.2181 | Val Loss: 0.8250 | Val Acc: 0.9429
✅ New best model saved!

Epoch 8/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.40it/s]
Validating: 100%|██████████| 74/74 [00:09<00:00,  7.88it/s]


Train Loss: 1.1327 | Val Loss: 0.7978 | Val Acc: 0.9608
✅ New best model saved!

Epoch 9/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.45it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  9.12it/s]


Train Loss: 1.0842 | Val Loss: 0.7885 | Val Acc: 0.9532

Epoch 10/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.40it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  8.73it/s]


Train Loss: 1.0366 | Val Loss: 0.7567 | Val Acc: 0.9625
✅ New best model saved!

Epoch 11/80


Training: 100%|██████████| 593/593 [00:48<00:00, 12.32it/s]
Validating: 100%|██████████| 74/74 [00:07<00:00,  9.28it/s]


Train Loss: 0.9975 | Val Loss: 0.7606 | Val Acc: 0.9625

Epoch 12/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.47it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  9.20it/s]


Train Loss: 0.9861 | Val Loss: 0.7198 | Val Acc: 0.9813
✅ New best model saved!

Epoch 13/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.42it/s]
Validating: 100%|██████████| 74/74 [00:07<00:00,  9.25it/s]


Train Loss: 0.9575 | Val Loss: 0.7284 | Val Acc: 0.9761

Epoch 14/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.51it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  8.94it/s]


Train Loss: 0.9453 | Val Loss: 0.7071 | Val Acc: 0.9838
✅ New best model saved!

Epoch 15/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.53it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  9.16it/s]


Train Loss: 0.9255 | Val Loss: 0.7039 | Val Acc: 0.9855
✅ New best model saved!

Epoch 16/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.43it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  9.18it/s]


Train Loss: 0.8999 | Val Loss: 0.6908 | Val Acc: 0.9864
✅ New best model saved!

Epoch 17/80


Training: 100%|██████████| 593/593 [00:47<00:00, 12.45it/s]
Validating: 100%|██████████| 74/74 [00:08<00:00,  9.12it/s]


Train Loss: 0.9050 | Val Loss: 0.6917 | Val Acc: 0.9864

Epoch 18/80


Training:   1%|          | 4/593 [00:06<16:45,  1.71s/it]  


KeyboardInterrupt: 

In [1]:
# --- FINAL EVALUATION ON TEST SET (WITH IMPROVED PLOTTING) ---
# Depends on the imports, `config`, `device`, `test_loader`, `class_names` and `num_classes`
# from the earlier cells. After a kernel restart, run those cells first (training can be skipped
# if best_model.pth already exists).
print("\n--- Evaluating on Test Set ---")

# Create a fresh instance of the architecture; the weights come from the checkpoint below.
evaluation_model = timm.create_model(
    config['model_name'], pretrained=False, num_classes=num_classes
)
# Load the best weights saved during training. map_location lets a checkpoint written on a GPU
# be restored on whatever `device` is available now. weights_only=True loads tensors only
# instead of unpickling arbitrary objects from the file.
best_state_dict = torch.load('best_model.pth', map_location=device, weights_only=True)
evaluation_model.load_state_dict(best_state_dict)
evaluation_model.to(device)
evaluation_model.eval()

all_preds, all_labels = [], []
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Testing"):
        inputs = inputs.to(device)
        outputs = evaluation_model(inputs)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

# Pin the label set to every class index. The report and the confusion matrix then always have
# one row/column per entry of `class_names`, even when a class is absent from the test labels
# and the predictions. Without this, sklearn infers the labels from the data: the report raises
# on a target_names length mismatch, and the matrix shrinks and misaligns with the tick labels.
label_indices = list(range(num_classes))

print(f"\nFinal Test Accuracy: {accuracy_score(all_labels, all_preds) * 100:.2f}%")
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, labels=label_indices, target_names=class_names, zero_division=0))

# --- Generate Improved Confusion Matrix ---
cm = confusion_matrix(all_labels, all_preds, labels=label_indices)

# A large figure leaves room for the many class labels.
fig, ax = plt.subplots(figsize=(20, 18))
sns.heatmap(
    cm,
    annot=True,             # show the count in each cell
    fmt='d',                # counts are integers
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    annot_kws={"size": 12}, # font size of the in-cell numbers
    ax=ax,
)

# Vertical predicted-class labels, horizontal true-class labels.
ax.tick_params(axis='x', labelrotation=90)
ax.tick_params(axis='y', labelrotation=0)

ax.set_xlabel('Predicted Label', fontsize=14)
ax.set_ylabel('True Label', fontsize=14)
ax.set_title('Confusion Matrix - EVA-02 ViT', fontsize=16)

# Fit everything inside the figure without overlap, then save a high-resolution copy.
fig.tight_layout()
fig.savefig('confusion_matrix_eva02_improved.png', dpi=300)
print("\nImproved confusion matrix saved to confusion_matrix_eva02_improved.png")
plt.show()


--- Evaluating on Test Set ---


NameError: name 'timm' is not defined