# Model Evaluation

This notebook fetches the best model from a Weights & Biases (W&B) sweep and evaluates it on the CIFAR-10 test set.


In [None]:
import wandb
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Setup: which W&B sweep to evaluate, and where to run the model.
PROJECT_NAME = "cifar10_mlops_project"
SWEEP_ID = "YOUR_SWEEP_ID_HERE"  # TODO: paste the sweep ID produced by notebook 02 here
ENTITY = None  # W&B username (or team) that owns PROJECT_NAME

# Prefer the GPU whenever CUDA is usable; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Initialize the W&B public API client.
api = wandb.Api()


def find_checkpoint(artifact_dir, run_id):
    """Return the path of the model checkpoint inside a downloaded artifact, or None.

    Looks first for ``model_best_<run_id>.pth`` (the file name this notebook
    expects). If that file is missing, falls back to the artifact's only
    ``*.pth`` file. Returns None when neither yields an unambiguous match, so the
    caller can report the problem here instead of failing later in ``torch.load``.
    """
    from pathlib import Path

    root = Path(artifact_dir)
    expected = root / f"model_best_{run_id}.pth"
    if expected.is_file():
        return str(expected)
    candidates = sorted(root.glob("*.pth"))
    if len(candidates) == 1:
        return str(candidates[0])
    return None


# Get the sweep and its best run.
if SWEEP_ID == "YOUR_SWEEP_ID_HERE":
    print("Please set the SWEEP_ID variable in the previous cell.")
else:
    # Formatting ENTITY = None directly would yield the path "None/<project>/<sweep>";
    # fall back to the default entity of the logged-in API key instead.
    entity = ENTITY or api.default_entity
    sweep = api.sweep(f"{entity}/{PROJECT_NAME}/{SWEEP_ID}")
    best_run = sweep.best_run()  # may be None, e.g. when the sweep has no runs yet
    if best_run is None:
        print("No best run found for this sweep.")
    else:
        print(f"Best Run ID: {best_run.id}")
        print(f"Best Run Accuracy: {best_run.summary.get('val_acc')}")

        # Use the first artifact of type "model" logged by the best run.
        model_artifact = next(
            (a for a in best_run.logged_artifacts() if a.type == "model"), None
        )
        if model_artifact is None:
            print("No model artifact found for the best run.")
        else:
            print(f"Downloading artifact: {model_artifact.name}")
            artifact_dir = model_artifact.download()
            checkpoint = find_checkpoint(artifact_dir, best_run.id)
            if checkpoint is None:
                print(f"No unambiguous .pth checkpoint found in {artifact_dir}.")
            else:
                # Only bind model_path once the file is known to exist; the next
                # cell uses its presence to decide whether to load the model.
                model_path = checkpoint
                print(f"Model downloaded to: {model_path}")

In [None]:
# Load data: only the CIFAR-10 test split is needed for evaluation.
# Per-channel normalisation statistics -- presumably the same ones notebook 02
# trained with (NOTE(review): confirm they match the training transform).
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)]
)
testset = torchvision.datasets.CIFAR10(
    root='../data/raw',
    train=False,
    download=True,
    transform=transform_test,
)
# Keep the original sample order so predictions line up with the dataset.
testloader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=100)

# Class names in CIFAR-10 label order: classes[i] is the name of label i.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Re-define the model architecture so the downloaded weights can be loaded into it.
def build_model(num_classes=10):
    """Return a randomly initialised ResNet-18 whose head outputs ``num_classes`` logits.

    The architecture must match the one whose ``state_dict`` is loaded below;
    presumably notebook 02 trained this same stock ResNet-18 with a replaced
    ``fc`` layer -- NOTE(review): confirm.
    """
    # weights=None: no pretrained download (replaces the deprecated pretrained=False).
    model = torchvision.models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


if 'model_path' in locals():
    # weights_only=True restricts unpickling to tensors/containers, which is all a
    # state_dict needs and avoids executing arbitrary pickled code from the artifact.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    # Load into a temporary name: `model` is bound only after the weights load
    # successfully, so the evaluation cell never scores an untrained network.
    _loaded = build_model()
    _loaded.load_state_dict(state_dict)
    model = _loaded.to(device).eval()
    print("Model loaded successfully.")
else:
    print("Model path not defined. Did artifact download succeed?")

In [None]:
# Evaluate the loaded model on the test set.
def collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` and return ``(predictions, labels)`` as NumPy arrays.

    Puts the model in eval mode and disables autograd. Each prediction is the
    arg-max over dimension 1 of the model output. Only the inputs are moved to
    ``device``; labels stay on the CPU. Returns two empty int64 arrays when
    ``loader`` yields no batches.
    """
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            preds.append(outputs.argmax(dim=1).cpu())
            targets.append(labels.cpu())
    if not preds:
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)
    # One concatenation instead of extending a Python list element by element.
    return torch.cat(preds).numpy(), torch.cat(targets).numpy()


if 'model' in locals():
    all_preds, all_labels = collect_predictions(model, testloader, device)

    # Pin the label set to every class index so the report and the heatmap stay
    # aligned with `classes` even if some class is absent from labels/predictions.
    label_ids = list(range(len(classes)))

    # Metrics
    print(classification_report(all_labels, all_preds, labels=label_ids, target_names=classes))

    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=classes, yticklabels=classes, cmap='Blues')
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    plt.show()