# 🧪 Evaluate Best ResNet50 Sweep Model on Test Set

In [1]:
import wandb
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os, matplotlib.pyplot as plt
import numpy as np

In [None]:
# Locate the best sweep run on Weights & Biases and fetch its saved checkpoint.
best_run_id = "alokgaurav04-indian-institute-of-technology-madras/inat-cnn-sweep/runs/6h0r51rw"

api = wandb.Api()
run = api.run(best_run_id)

# The checkpoint file name embeds the run id; the same name is used later
# when the weights are read back from the working directory.
checkpoint_name = f"model_resnet50_{run.id}.pth"
model_file = run.file(checkpoint_name)
model_file.download(replace=True)

In [None]:
# Model factory: ResNet50 with a new classification head (weights are loaded in a later cell)
def load_resnet50_finetune(num_classes, freeze_until=0, pretrained=True):
    """Build a ResNet50 whose final layer outputs ``num_classes`` logits.

    Args:
        num_classes: Output size of the replacement ``fc`` layer.
        freeze_until: Number of leading top-level child modules (counted
            from 1 in ``model.children()`` order) whose parameters get
            ``requires_grad = False``. ``0`` freezes nothing.
        pretrained: Forwarded to ``models.resnet50``. Pass ``False`` to skip
            fetching pretrained weights, e.g. when a checkpoint is loaded
            into the model right afterwards. Defaults to ``True`` so existing
            callers behave as before.

    Returns:
        The model with a freshly created ``nn.Linear`` head. The new head is
        created after the freezing pass, so it is always trainable.
    """
    model = models.resnet50(pretrained=pretrained)

    # Children are visited in order, so we can stop at the first one that
    # lies beyond the frozen prefix.
    for position, child in enumerate(model.children(), start=1):
        if position > freeze_until:
            break
        for param in child.parameters():
            param.requires_grad = False

    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

In [None]:
# Set device and load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 10
config = run.config

# freeze_until only toggles requires_grad inside load_resnet50_finetune, which
# has no effect on evaluation, so a run config without that key should not
# abort the notebook: fall back to 0 (freeze nothing).
best_model = load_resnet50_finetune(
    num_classes=NUM_CLASSES,
    freeze_until=config.get("freeze_until", 0),
)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from runs you trust; on PyTorch versions that
# support it, consider passing weights_only=True.
state_dict = torch.load(f"model_resnet50_{run.id}.pth", map_location=device)
best_model.load_state_dict(state_dict)
best_model.to(device)
best_model.eval()

In [None]:
# Test loader setup: images are read from the "val" sub-folder of DATA_DIR,
# with one sub-directory per class (ImageFolder layout).
DATA_DIR = "/content/inaturalist_12K"
test_dir = os.path.join(DATA_DIR, "val")

# NOTE(review): no transforms.Normalize step is applied here. Confirm this
# matches the preprocessing used when the checkpoint was trained.
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

test_dataset = datasets.ImageFolder(root=test_dir, transform=test_transforms)
# shuffle=False keeps the evaluation order fixed across passes.
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [None]:
# Accuracy Evaluation: count top-1 matches over the whole test loader.
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = best_model(inputs)
        # Index of the largest logit along dim 1 is the predicted class.
        preds = outputs.max(dim=1).indices
        correct += int((preds == labels).sum())
        total += labels.shape[0]

test_accuracy = correct / total
print(f"✅ Test Accuracy of Best Sweep Model: {test_accuracy * 100:.2f}%")

In [None]:
# Visualization Grid
def imshow(inp, title=None):
    """Draw one image tensor on the current matplotlib axes.

    The tensor is moved to the CPU, converted to a NumPy array and its axes
    are reordered with ``(1, 2, 0)`` before plotting (assumes a channel-first
    3-D image tensor -- TODO confirm). Axis ticks are hidden. When ``title``
    is truthy it is shown above the image at font size 8.
    """
    image = np.transpose(inp.cpu().numpy(), (1, 2, 0))
    plt.imshow(image)
    plt.axis('off')
    if title:
        plt.title(title, fontsize=8)

# Predict on the first test batch and show each image with its predicted and
# true class name.
inputs, labels = next(iter(test_loader))
inputs, labels = inputs.to(device), labels.to(device)
# Inference only: no_grad avoids building an autograd graph for this batch.
with torch.no_grad():
    outputs = best_model(inputs)
_, preds = torch.max(outputs, 1)
classes = test_dataset.classes

# Plain Python ints for indexing the class-name list.
pred_ids = preds.tolist()
true_ids = labels.tolist()

# Show at most 30 images, but never more than the batch actually holds
# (a small dataset or a smaller batch size would otherwise raise IndexError).
num_shown = min(30, inputs.size(0))

plt.figure(figsize=(20, 10))
for i in range(num_shown):
    plt.subplot(10, 3, i + 1)
    imshow(inputs[i], title=f"Pred: {classes[pred_ids[i]]}\nTrue: {classes[true_ids[i]]}")
plt.tight_layout()
plt.show()

In [None]:
# Confusion Matrix
# Built with numpy and drawn with matplotlib, both already imported by this
# notebook. The previous version called confusion_matrix and sns, neither of
# which is imported anywhere, so the cell failed with NameError.
classes = test_dataset.classes  # re-derived here so the cell runs on its own
num_classes = len(classes)

all_preds, all_labels = [], []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = best_model(inputs)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# cm[t, p] counts samples whose true class is t and predicted class is p.
# Sizing by len(classes) keeps rows/columns aligned with the tick labels even
# if some class never occurs. np.add.at accumulates repeated (t, p) pairs,
# which plain fancy-index assignment would not.
cm = np.zeros((num_classes, num_classes), dtype=np.int64)
np.add.at(
    cm,
    (np.asarray(all_labels, dtype=np.int64), np.asarray(all_preds, dtype=np.int64)),
    1,
)

plt.figure(figsize=(10, 8))
plt.imshow(cm, cmap='Blues')
plt.colorbar()

# Annotate every cell with its count; switch to white text on dark cells.
threshold = cm.max() / 2
for true_idx in range(num_classes):
    for pred_idx in range(num_classes):
        count = int(cm[true_idx, pred_idx])
        plt.text(
            pred_idx,
            true_idx,
            str(count),
            ha='center',
            va='center',
            color='white' if count > threshold else 'black',
        )

tick_positions = np.arange(num_classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.xticks(tick_positions, classes, rotation=45)
plt.yticks(tick_positions, classes, rotation=45)
plt.tight_layout()
plt.show()