In [None]:
import torch
import timm
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Device configuration: the current CUDA device when one is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (CPU and CUDA) so the new classifier head's initialisation,
# the random crops/flips and the DataLoader shuffle order repeat across
# Restart & Run All. (cuDNN kernels may still be non-deterministic.)
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Training hyperparameters
batch_size = 16         # samples per optimisation step (train and test loaders)
num_epochs = 10         # full passes over the training set
learning_rate = 1e-4    # initial AdamW learning rate
num_classes = 100       # CIFAR-100 has 100 fine-grained labels

# Define the model: a pretrained ConvNeXt-V2 base from timm. Passing num_classes
# asks timm for a classifier head sized for the 100 CIFAR-100 classes.
# nn.Module.to() returns the module itself, so the move to `device` is chained.
model = timm.create_model(
    'convnextv2_base.fcmae_ft_in1k',
    pretrained=True,
    num_classes=num_classes,
).to(device)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
# Multiply the LR by 0.1 every 5 scheduler steps. The training loop steps it once
# per epoch, so with num_epochs = 10 epochs 1-5 use learning_rate and 6-10 use
# learning_rate / 10.
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Data preparation
# Per-channel normalization statistics, defined once so the train and test
# pipelines cannot drift apart.
CIFAR100_MEAN = (0.507, 0.487, 0.441)
CIFAR100_STD = (0.267, 0.256, 0.276)

transform_train = transforms.Compose([
    # Pad-and-crop at CIFAR's native 32x32 resolution, *then* upsample. The
    # standard 4-pixel jitter is meant for 32x32 images. Applied after the
    # 224x224 resize, it would shift the image by under 2% of its width.
    transforms.RandomCrop(32, padding=4),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

# Deterministic test pipeline: same resize and normalization, no augmentation.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

# Load datasets
train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

# Data loaders. Pinned host memory speeds up host-to-GPU copies, so request it
# only when CUDA is available.
use_pinned_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=1, pin_memory=use_pinned_memory)
# shuffle=False keeps iteration order equal to dataset order, which the
# misclassification analysis relies on to map positions back to images.
# NOTE: the training loop also uses this test split as its per-epoch validation
# set, so "validation" accuracy is not a held-out estimate.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=1, pin_memory=use_pinned_memory)

# Track accuracies and losses for plotting (one entry appended per epoch)
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

In [None]:
def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    With an ``optimizer`` the pass trains: ``model.train()`` mode, with forward,
    backward and an optimizer step per batch. Without one it evaluates:
    ``model.eval()`` mode with autograd disabled.

    The loss is averaged per *sample*. Each batch's mean loss (CrossEntropyLoss's
    default reduction) is weighted by the batch size, so a short final batch does
    not skew the epoch figure when the dataset size is not a multiple of the
    batch size.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, seen = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()

            n = targets.size(0)
            loss_sum += loss.item() * n
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            seen += n

    if seen == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / seen, 100.0 * correct / seen


for epoch in range(num_epochs):
    train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    # Validation step (the test split doubles as the validation set)
    val_loss, val_acc = run_epoch(model, test_loader, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    print(f"Train Loss: {train_losses[-1]:.4f}, Train Accuracy: {train_acc:.2f}%")
    print(f"Validation Loss: {val_losses[-1]:.4f}, Validation Accuracy: {val_acc:.2f}%")

    # StepLR counts epochs: step once per epoch, after that epoch's optimizer updates.
    scheduler.step()

In [None]:
# Training and validation loss per epoch. Each x-axis is derived from that
# series' recorded history rather than num_epochs. The plot then still works after
# an interrupted run, or after re-running the training cell (which keeps
# appending to the same lists).
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
ax.grid()
plt.show()


In [None]:

# Plot training and validation accuracy on shared axes, as the title promises.
# Each x-axis follows its series' length, so re-runs or interrupted training
# do not cause a length mismatch.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Training Accuracy')
ax.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Training and Validation Accuracy')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Training accuracy on its own axes, one point per recorded epoch.
fig, ax = plt.subplots(figsize=(10, 5))
epoch_numbers = np.arange(1, len(train_accuracies) + 1)
ax.plot(epoch_numbers, train_accuracies, label='Train Accuracy')
ax.set(xlabel='Epoch', ylabel='Accuracy (%)', title='Training Accuracy over Epochs')
ax.legend()
ax.grid()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Function to display misclassified images
def show_misclassified_images(test_loader, misclassified_indices, misclassified_labels, misclassified_preds):
    for i, idx in enumerate(misclassified_indices):
        image, label = test_loader.dataset[idx]
        image = image.permute(1, 2, 0)  # Convert from (C, H, W) to (H, W, C) for plotting
        
        plt.figure(figsize=(3, 3))
        plt.imshow(image.cpu().numpy())
        plt.title(f"True: {misclassified_labels[i]}, Pred: {misclassified_preds[i]}")
        plt.axis('off')
        plt.show()

# Modified testing code to track misclassified images
def test_model(model, test_loader):
    model.eval()
    correct_test = 0
    total_test = 0
    misclassified_indices = []
    misclassified_labels = []
    misclassified_preds = []

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            
            correct_test += predicted.eq(labels).sum().item()
            total_test += labels.size(0)
            
            # Find misclassified indices
            misclassified_mask = ~predicted.eq(labels)
            for i in range(len(misclassified_mask)):
                if misclassified_mask[i]:
                    misclassified_indices.append(batch_idx * test_loader.batch_size + i)
                    misclassified_labels.append(labels[i].item())
                    misclassified_preds.append(predicted[i].item())

    test_accuracy = 100 * correct_test / total_test
    print(f"Final Test Accuracy: {test_accuracy:.2f}%")
    print(f"Total Misclassified Images: {len(misclassified_indices)}")
    
    return misclassified_indices, misclassified_labels, misclassified_preds

# Run the test and show a sample of the misclassified images. The number of
# errors may be in the hundreds or more, so display only the first few rather
# than opening one figure per error.
NUM_MISCLASSIFIED_TO_SHOW = 16
misclassified_indices, misclassified_labels, misclassified_preds = test_model(model, test_loader)
show_misclassified_images(
    test_loader,
    misclassified_indices[:NUM_MISCLASSIFIED_TO_SHOW],
    misclassified_labels[:NUM_MISCLASSIFIED_TO_SHOW],
    misclassified_preds[:NUM_MISCLASSIFIED_TO_SHOW],
)
