In [None]:
import os
import json
import glob
import wandb
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset, DataLoader

# ==========================================
# 1. Setup & Auth
# ==========================================
# SECURITY: never commit API keys. The key that used to be hard-coded here was
# published with this notebook and must be revoked/rotated in the W&B settings.
# Provide the key through the WANDB_API_KEY environment variable instead
# (or run `wandb login` once beforehand so stored credentials are used).
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
PROJECT_NAME = "cifar10_mlops_project"
ENTITY = "esi-sba-dz"
# With key=None, wandb.login falls back to stored credentials or prompts interactively.
wandb.login(key=WANDB_API_KEY)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# ==========================================
# 2. Helpers (No Download)
# ==========================================
class Cifar10DataManager:
    """Build the CIFAR-10 evaluation DataLoader from files already on disk.

    Nothing is fetched over the network: torchvision is called with
    ``download=False``. The CIFAR-10 test split must therefore already live
    under ``data_dir``, together with the saved subset indices at
    ``<data_dir>/processed/test_indices.npy``.
    """

    def __init__(self, data_dir="./data"):
        # Root directory handed to torchvision.datasets.CIFAR10.
        self.data_dir = data_dir
        # Per-channel (RGB) normalisation statistics. Presumably these must match
        # the values used at training time -- verify before changing them.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2023, 0.1994, 0.2010)

    def get_loaders(self, batch_size, architecture_option='standard'):
        """Return a non-shuffled DataLoader over the saved test subset.

        Args:
            batch_size: Number of samples per batch.
            architecture_option: ``'upsample'`` prepends ``Resize(224)`` to the
                pipeline; any other value keeps the native image resolution.

        Returns:
            A ``DataLoader`` over ``Subset(cifar10_test, saved_indices)``.
        """
        pipeline = [transforms.ToTensor(), transforms.Normalize(self.mean, self.std)]
        if architecture_option == 'upsample':
            pipeline = [transforms.Resize(224)] + pipeline
        eval_transform = transforms.Compose(pipeline)

        cifar_test = torchvision.datasets.CIFAR10(
            root=self.data_dir,
            train=False,
            download=False,  # STRICT: never download; files must already be on disk
            transform=eval_transform,
        )

        saved_indices = np.load(os.path.join(self.data_dir, "processed", "test_indices.npy"))
        return DataLoader(Subset(cifar_test, saved_indices), batch_size=batch_size, shuffle=False)

def _resnet18(pretrained):
    """Instantiate torchvision's resnet18 through whichever weights API is available."""
    weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
    if weights_enum is None:
        # torchvision < 0.13 only understands the boolean `pretrained` flag.
        return torchvision.models.resnet18(pretrained=pretrained)
    # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
    # DEFAULT resolves to the same ImageNet weights that pretrained=True loaded.
    return torchvision.models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)


def build_model(architecture_option='standard', pretrained=True):
    """Build a ResNet-18 whose final layer emits 10 logits (one per CIFAR-10 class).

    Args:
        architecture_option: ``'modified'`` replaces the stem conv with a
            3x3 / stride-1 conv and removes the stem max-pool. ``'standard'``
            and ``'upsample'`` keep torchvision's stock stem; ``'upsample'`` is
            presumably paired with the data pipeline resizing inputs to 224.
            NOTE(review): any other value silently falls back to the stock
            stem -- a typo in the config will not raise here.
        pretrained: Start from torchvision's ImageNet weights (the previous,
            default behaviour). Pass ``False`` when a full ``state_dict`` is
            loaded right afterwards, to skip a redundant weight download.

    Returns:
        A ``torch.nn.Module`` with ``fc`` replaced by ``Linear(in_features, 10)``.
    """
    model = _resnet18(pretrained)
    if architecture_option == 'modified':
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        model.maxpool = nn.Identity()
    # 'upsample' (and 'standard') need no stem change.
    model.fc = nn.Linear(model.fc.in_features, 10)
    return model

In [None]:
# ==========================================
# 3. Fetch Artifacts
# ==========================================
run = wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type="evaluation")

print("Downloading Data Artifact...")
run.use_artifact(f'{ENTITY}/{PROJECT_NAME}/cifar10_dataset:latest', type='dataset').download(root="./data")

# Resolve Model
api = wandb.Api()
# An explicit SWEEP_ID takes precedence. Previously the env var was read only when
# the project had *no* sweeps, in which case no sweep ID from this project could
# exist, so the fallback could never succeed.
sweep_id = os.getenv("SWEEP_ID")
if not sweep_id:
    sweeps = api.project(PROJECT_NAME, entity=ENTITY).sweeps()
    if not sweeps:
        raise RuntimeError(f"No sweeps found in {ENTITY}/{PROJECT_NAME}; set SWEEP_ID explicitly.")
    # NOTE(review): relies on the API's list order to choose the sweep -- confirm it is the intended one.
    sweep_id = sweeps[0].id

best_run = api.sweep(f"{ENTITY}/{PROJECT_NAME}/{sweep_id}").best_run()
if best_run is None:
    raise RuntimeError(f"Sweep {sweep_id} has no run to evaluate (best_run() returned None).")
config = best_run.config

print(f"Downloading Model from: {best_run.name}")
logged = list(best_run.logged_artifacts())
# Prefer artifacts typed 'model' so a logged table/dataset is not picked by accident;
# fall back to the original "first logged artifact" behaviour if none are typed 'model'.
candidates = [artifact for artifact in logged if artifact.type == "model"] or logged
if not candidates:
    raise RuntimeError(f"Run {best_run.name} has no logged artifacts.")
model_dir = candidates[0].download(root="./models")
# Sort so the chosen checkpoint does not depend on filesystem listing order.
checkpoints = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
if not checkpoints:
    raise FileNotFoundError(f"No .pth checkpoint found in {model_dir}")
model_path = checkpoints[0]

In [None]:
# ==========================================
# 4. Evaluation
# ==========================================
EVAL_BATCH_SIZE = 100
CLASS_NAMES = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# Sweep configs that never set the option fall back to the helpers' own default.
arch_option = config.get('architecture_option', 'standard')

dm = Cifar10DataManager(data_dir="./data")
test_loader = dm.get_loaders(EVAL_BATCH_SIZE, arch_option)

model = build_model(arch_option).to(device)
# NOTE(security): torch.load unpickles arbitrary Python objects. The checkpoint comes
# from this project's own W&B artifacts; if that source is not fully trusted, pass
# weights_only=True (torch >= 1.13) to restrict loading to tensors.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

all_preds, all_labels = [], []
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs.to(device))
        # .tolist() yields plain Python ints for W&B serialisation.
        all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
        all_labels.extend(labels.tolist())

# Guard the division in case the saved test subset is empty.
num_correct = sum(pred == label for pred, label in zip(all_preds, all_labels))
test_accuracy = num_correct / len(all_labels) if all_labels else float("nan")
print(f"Test accuracy: {test_accuracy:.4f} ({num_correct}/{len(all_labels)})")

try:
    wandb.log({
        "test_accuracy": test_accuracy,
        "confusion_matrix": wandb.plot.confusion_matrix(
            y_true=all_labels, preds=all_preds,
            class_names=CLASS_NAMES,
        ),
    })
finally:
    # Close the run even if logging fails, so it is not left dangling in W&B.
    run.finish()
print("Evaluation Complete.")