In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm
import matplotlib.pyplot as plt
import numpy as np

# Seed the torch and NumPy RNGs before anything random happens so that the
# stochastic steps below (shuffling, random crop/flip) start from a fixed state.
# NOTE: this does not by itself guarantee bit-identical results on GPU.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Set device and hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 16
num_epochs = 10
learning_rate = 0.0001
num_classes = 100

# Per-channel normalisation constants, defined once so the train and test
# pipelines cannot drift apart.
norm_mean = (0.507, 0.487, 0.441)
norm_std = (0.267, 0.256, 0.276)

# Data preparation with resize to match ViT input size (224x224)
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),           # Resize to 224x224
    transforms.RandomCrop(224, padding=4),    # Crop to 224x224 after padding
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Evaluation pipeline: same resize and normalisation, no random augmentation.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),            # Resize to 224x224
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1, persistent_workers=True)

# Initialize the Vision Transformer model
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes)
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Multiply the learning rate by 0.1 every 5 scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Track training accuracy for plotting
train_accuracies = []

In [None]:
# Training function with per-epoch training accuracy calculation
def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs, then print its test accuracy.

    Each epoch runs one optimisation pass over ``train_loader``, prints the
    epoch's training accuracy, appends it to the module-level
    ``train_accuracies`` list, and steps ``scheduler`` once. After the last
    epoch the model is switched to eval mode and scored on ``test_loader``.

    Batches are moved to the module-level ``device``. Returns ``None``.
    """
    for epoch in range(num_epochs):
        model.train()
        hits, seen = 0, 0

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            # Standard step: clear gradients, forward, backward, update.
            optimizer.zero_grad()
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            batch_loss.backward()
            optimizer.step()

            # Accuracy is accumulated from the same forward pass used for the update.
            top1 = logits.max(1)[1]
            hits += (top1 == batch_y).sum().item()
            seen += batch_y.size(0)

        # Record and report this epoch's training accuracy (percentage).
        train_accuracy = 100 * hits / seen
        train_accuracies.append(train_accuracy)
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Accuracy: {train_accuracy:.2f}%")

        scheduler.step()

    # Single evaluation pass once all epochs are done.
    model.eval()
    test_hits, test_seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            top1 = model(batch_x).max(1)[1]
            test_hits += (top1 == batch_y).sum().item()
            test_seen += batch_y.size(0)

    test_accuracy = 100 * test_hits / test_seen
    print(f"Final Test Accuracy: {test_accuracy:.2f}%")

# Start training with the objects built in the setup cell; progress is printed per epoch.
train_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
)


In [None]:
# Prints a marker when this cell runs (presumably to signal that the training cell above has finished — confirm intent).
print("done")

In [None]:
# Plot the per-epoch training accuracy collected in `train_accuracies`,
# with epochs numbered from 1 on the x-axis.
epoch_numbers = np.arange(1, len(train_accuracies) + 1)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epoch_numbers, train_accuracies, label='Train Accuracy')
ax.set_xlabel('Number of Epochs')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Epochs vs Accuracy(%)')
ax.legend()
ax.grid()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Function to display misclassified images
def show_misclassified_images(
    test_loader,
    misclassified_indices,
    misclassified_labels,
    misclassified_preds,
    mean=(0.507, 0.487, 0.441),
    std=(0.267, 0.256, 0.276),
    max_images=None,
):
    """Plot misclassified samples with their true and predicted labels.

    Images are fetched from ``test_loader.dataset`` and are therefore already
    normalised by the dataset's transform. The normalisation is undone here
    (``image * std + mean``) and the result clamped to [0, 1] so that
    ``imshow`` renders the real colours instead of clipping out-of-range values.

    Args:
        test_loader: Object exposing an indexable ``dataset`` that yields
            ``(image, label)`` pairs, with ``image`` a (C, H, W) tensor.
        misclassified_indices: Dataset indices of the samples to show.
        misclassified_labels: True label for each index (same order).
        misclassified_preds: Predicted label for each index (same order).
        mean: Per-channel mean that was used to normalise the images.
        std: Per-channel standard deviation that was used to normalise the images.
        max_images: Optional cap on the number of figures; ``None`` shows all.
    """
    # Shape (C, 1, 1) so the statistics broadcast over height and width.
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)

    for i, idx in enumerate(misclassified_indices):
        if max_images is not None and i >= max_images:
            break

        image, _ = test_loader.dataset[idx]
        # Undo Normalize, then clamp to the [0, 1] range imshow expects for floats.
        image = (image.cpu() * std_t + mean_t).clamp(0.0, 1.0)
        image = image.permute(1, 2, 0)  # Convert from (C, H, W) to (H, W, C) for plotting

        plt.figure(figsize=(3, 3))
        plt.imshow(image.numpy())
        plt.title(f"True: {misclassified_labels[i]}, Pred: {misclassified_preds[i]}")
        plt.axis('off')
        plt.show()

# Modified testing code to track misclassified images
def test_model(model, test_loader):
    model.eval()
    correct_test = 0
    total_test = 0
    misclassified_indices = []
    misclassified_labels = []
    misclassified_preds = []

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            
            correct_test += predicted.eq(labels).sum().item()
            total_test += labels.size(0)
            
            # Find misclassified indices
            misclassified_mask = ~predicted.eq(labels)
            for i in range(len(misclassified_mask)):
                if misclassified_mask[i]:
                    misclassified_indices.append(batch_idx * test_loader.batch_size + i)
                    misclassified_labels.append(labels[i].item())
                    misclassified_preds.append(predicted[i].item())

    test_accuracy = 100 * correct_test / total_test
    print(f"Final Test Accuracy: {test_accuracy:.2f}%")
    print(f"Total Misclassified Images: {len(misclassified_indices)}")
    
    return misclassified_indices, misclassified_labels, misclassified_preds

# Evaluate the trained model, then plot every sample it got wrong.
(
    misclassified_indices,
    misclassified_labels,
    misclassified_preds,
) = test_model(model=model, test_loader=test_loader)

show_misclassified_images(
    test_loader=test_loader,
    misclassified_indices=misclassified_indices,
    misclassified_labels=misclassified_labels,
    misclassified_preds=misclassified_preds,
)
