In [None]:
import glob
import json
import os
import threading
import time

import nest_asyncio
import numpy as np
import requests
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import uvicorn
import wandb
from fastapi import FastAPI, HTTPException
from torch.utils.data import Subset

# ==========================================
# 1. Setup
# ==========================================
# Credentials come from the environment rather than being hard-coded: a key
# embedded in a notebook leaks through version control and shared copies.
# The key that was previously written here should be treated as compromised
# and revoked in the W&B settings.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
PROJECT_NAME = "cifar10_mlops_project"
ENTITY = "esi-sba-dz"
# With key=None, wandb.login uses its normal credential lookup (the
# WANDB_API_KEY env var, ~/.netrc, or an interactive prompt).
wandb.login(key=WANDB_API_KEY)
# Patch asyncio so event loops can be nested inside Jupyter's already-running
# loop (the FastAPI/uvicorn server is started from this notebook below).
nest_asyncio.apply()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# ==========================================
# 2. Helpers (Strict No-Download)
# ==========================================
class Cifar10DataManager:
    """Provide CIFAR-10 evaluation transforms and the local simulation subset.

    Nothing is ever downloaded: the dataset and the simulation index file must
    already exist under ``data_dir`` (in this notebook that directory is
    populated from a W&B dataset artifact).
    """

    def __init__(self, data_dir="./data"):
        self.data_dir = data_dir
        # Per-channel (R, G, B) normalisation statistics. Presumably the same
        # constants used at training time -- keep the two in sync.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2023, 0.1994, 0.2010)

    def get_transform(self, architecture_option='standard'):
        """Return the eval pipeline: [Resize(224)] -> ToTensor -> Normalize.

        The resize step is prepended only for the 'upsample' architecture;
        every other option gets the pipeline without resizing.
        """
        pipeline = []
        if architecture_option == 'upsample':
            pipeline.append(transforms.Resize(224))
        pipeline.extend((transforms.ToTensor(), transforms.Normalize(self.mean, self.std)))
        return transforms.Compose(pipeline)

    def get_simulation_data(self):
        """Return the simulation subset of the CIFAR-10 test split.

        Subset positions are read from ``<data_dir>/processed/sim_indices.npy``.
        Fails if the dataset is not already present locally, because
        ``download=False`` is enforced.
        """
        # STRICT: download=False
        base_split = torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=False)
        index_file = os.path.join(self.data_dir, "processed", "sim_indices.npy")
        return Subset(base_split, np.load(index_file))

def build_model(architecture_option='standard', pretrained=True):
    """Build a ResNet-18 whose classifier head outputs 10 CIFAR-10 logits.

    Args:
        architecture_option: One of:
            - 'standard': stock ResNet-18 stem (7x7 stride-2 conv + max-pool).
            - 'modified': 3x3 stride-1 stem conv and no max-pool, so small
              (32x32) inputs are downsampled less aggressively.
            - 'upsample': stock stem; the input is instead resized to 224 by
              ``Cifar10DataManager.get_transform('upsample')``.
        pretrained: Forwarded to ``torchvision.models.resnet18``. Defaults to
            True (the original behaviour). Pass False when a fine-tuned state
            dict is loaded right afterwards; otherwise ImageNet weights are
            downloaded only to be overwritten.

    Returns:
        The ``nn.Module`` (not moved to any device).

    Raises:
        ValueError: if ``architecture_option`` is not one of the options above.
            An unknown option previously fell through silently to 'standard'.
    """
    valid_options = ('standard', 'modified', 'upsample')
    if architecture_option not in valid_options:
        raise ValueError(
            f"Unknown architecture_option {architecture_option!r}; "
            f"expected one of {valid_options}"
        )

    model = torchvision.models.resnet18(pretrained=pretrained)
    if architecture_option == 'modified':
        # Replace the ImageNet stem so 32x32 images keep spatial resolution.
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        model.maxpool = nn.Identity()
    # 'standard' and 'upsample' keep the stock stem; 'upsample' differs only in
    # its input transform.
    model.fc = nn.Linear(model.fc.in_features, 10)
    return model

In [None]:
# ==========================================
# 3. Fetch Resources
# ==========================================
run = wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type="deploy_prep")
try:
    print("Fetching Data Artifact...")
    # Downloaded into ./data, which Cifar10DataManager reads with download=False.
    run.use_artifact(f'{ENTITY}/{PROJECT_NAME}/cifar10_dataset:latest').download("./data")

    # Fetch Model: the best run of a sweep in this project.
    api = wandb.Api()
    sweeps = api.project(PROJECT_NAME, entity=ENTITY).sweeps()
    if not sweeps:
        raise RuntimeError(f"No sweeps found in {ENTITY}/{PROJECT_NAME}")
    # NOTE(review): the ordering of sweeps() is not established here; sweeps[0]
    # is assumed to be the sweep to deploy -- confirm, or pin the sweep id.
    sweep_id = sweeps[0].id
    best_run = api.sweep(f"{ENTITY}/{PROJECT_NAME}/{sweep_id}").best_run()
    if best_run is None:
        raise RuntimeError(f"Sweep {sweep_id} has no best run yet")
    config = best_run.config

    print("Fetching Model Artifact...")
    artifacts = list(best_run.logged_artifacts())
    if not artifacts:
        raise RuntimeError(f"Run {best_run.id} logged no artifacts")
    # The first logged artifact is not necessarily the checkpoint; prefer one
    # typed "model" and fall back to the first artifact (original behaviour).
    model_artifact = next((a for a in artifacts if a.type == "model"), artifacts[0])
    model_dir = model_artifact.download(root="./models")
    # Sorted so the pick is deterministic if the artifact holds several files.
    checkpoints = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if not checkpoints:
        raise FileNotFoundError(
            f"No .pth checkpoint in artifact {model_artifact.name} ({model_dir})"
        )
    model_path = checkpoints[0]

    model = build_model(config['architecture_option']).to(device)
    # SECURITY: torch.load unpickles the file. Only load checkpoints from a
    # trusted source (consider weights_only=True on torch >= 1.13).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
finally:
    # Close the W&B run even if one of the fetch steps above fails.
    run.finish()

dm = Cifar10DataManager()
val_transform = dm.get_transform(config['architecture_option'])

In [None]:
# ==========================================
# 4. FastAPI Server
# ==========================================
SERVER_HOST = "127.0.0.1"
SERVER_PORT = 8000
SERVER_START_TIMEOUT_S = 30.0

app = FastAPI()

# Load the simulation subset once. Calling dm.get_simulation_data() inside the
# endpoint re-read the whole CIFAR-10 test split from disk on every request,
# even though the data never changes while the server runs.
server_sim_data = dm.get_simulation_data()


@app.get("/health")
def health():
    """Liveness probe; polled by wait_for_server() to detect readiness."""
    return {"status": "ok"}


@app.post("/predict")
def predict(payload: dict):
    """Classify one image from the simulation subset.

    Expects a JSON body ``{"index": <int>}``, where ``index`` is a 0-based
    position in the simulation subset. Returns the predicted class id and its
    softmax probability.

    Raises:
        HTTPException(422): ``index`` is missing or not an integer.
        HTTPException(404): ``index`` is outside the subset.
    """
    from fastapi import HTTPException

    idx = payload.get("index")
    # bool is a subclass of int; reject it so {"index": true} is not item 1.
    if not isinstance(idx, int) or isinstance(idx, bool):
        raise HTTPException(status_code=422, detail="'index' must be an integer")
    if not 0 <= idx < len(server_sim_data):
        raise HTTPException(
            status_code=404,
            detail=f"index {idx} out of range [0, {len(server_sim_data)})",
        )

    image, _ = server_sim_data[idx]
    tensor = val_transform(image).unsqueeze(0).to(device)  # add a batch dimension

    with torch.no_grad():
        output = model(tensor)
        conf, pred = torch.max(torch.nn.functional.softmax(output, dim=1), 1)

    return {"prediction": int(pred.item()), "confidence": float(conf.item())}


def start_server():
    """Run uvicorn (blocks); intended to be the target of a daemon thread."""
    uvicorn.run(app, host=SERVER_HOST, port=SERVER_PORT, log_level="error")


def wait_for_server(timeout=SERVER_START_TIMEOUT_S, interval=0.25):
    """Poll /health until the server answers, instead of sleeping a fixed time.

    Raises:
        RuntimeError: if the server does not answer within ``timeout`` seconds.
    """
    url = f"http://{SERVER_HOST}:{SERVER_PORT}/health"
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            if requests.get(url, timeout=1).ok:
                return
        except requests.exceptions.RequestException:
            pass  # not listening yet; keep polling until the deadline
        time.sleep(interval)
    raise RuntimeError(f"Server did not become ready within {timeout} s")


threading.Thread(target=start_server, daemon=True).start()
print("Server starting...")
wait_for_server()

In [None]:
# ==========================================
# 5. Simulation Loop
# ==========================================
SIM_SAMPLES = 30          # number of simulated prediction requests
SIM_SEED = 42             # fixed for reproducibility; set None for a fresh sample
PREDICT_URL = "http://127.0.0.1:8000/predict"
REQUEST_TIMEOUT_S = 10    # per-request HTTP timeout, seconds

wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type="deployment_simulation")
try:
    sim_data = dm.get_simulation_data()
    feedback_data = []  # (subset index, true label) for every misprediction
    table = wandb.Table(columns=["index", "pred", "truth", "conf", "correct"])
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    rng = np.random.default_rng(SIM_SEED)
    # Sampling without replacement cannot draw more items than the subset holds.
    n_requests = min(SIM_SAMPLES, len(sim_data))
    for idx in rng.choice(len(sim_data), n_requests, replace=False):
        idx = int(idx)
        _, gt = sim_data[idx]
        resp = requests.post(PREDICT_URL, json={"index": idx}, timeout=REQUEST_TIMEOUT_S)
        resp.raise_for_status()  # surface server errors instead of a KeyError below
        result = resp.json()

        pred = result["prediction"]
        correct = (pred == gt)
        table.add_data(idx, classes[pred], classes[gt], result["confidence"], correct)

        if not correct:
            feedback_data.append((idx, int(gt)))

    wandb.log({"simulation_results": table})

    if feedback_data:
        print(f"Captured {len(feedback_data)} errors.")
        np.save("feedback_v1.npy", feedback_data)
        art = wandb.Artifact("cifar10-feedback", type="dataset")
        art.add_file("feedback_v1.npy")
        wandb.log_artifact(art)
finally:
    # Close the run even if a request or the logging step fails.
    wandb.finish()