In [None]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import wandb
import torch
import torchvision
import uvicorn
from fastapi import FastAPI
import threading
import requests
import time
import numpy as np
import json
import glob
import nest_asyncio
from src.utils import load_env_vars
from src.dataset import Cifar10DataManager
from src.model import build_model

nest_asyncio.apply()
env = load_env_vars()
PROJECT_NAME = env.get("WANDB_PROJECT", "cifar10_mlops_project")
ENTITY = env.get("WANDB_ENTITY", None)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Setup Inference App
app = FastAPI()

# Load Config & Model
with open("../artifacts/best_config.json", "r") as f:
    config = json.load(f)

print("Downloading Real Model for Simulation...")
# Fetch from W&B using sweep_id to find best run
with open("../artifacts/sweep_id.txt", "r") as f:
    sweep_id = f.read().strip()

api = wandb.Api()
# An unset ENTITY would otherwise be formatted into the path as the literal "None".
sweep_entity = ENTITY or api.default_entity
sweep = api.sweep(f"{sweep_entity}/{PROJECT_NAME}/{sweep_id}")
best_run = sweep.best_run()
if best_run is None:
    raise RuntimeError(f"Sweep {sweep_id} has no runs to pick a best model from")

artifacts = best_run.logged_artifacts()
model_artifact = next((a for a in artifacts if a.type == "model"), None)
if model_artifact is None:
    raise RuntimeError(f"Best run {best_run.id} logged no artifact of type 'model'")

model_dir = model_artifact.download(root="../models")
# Sort so the chosen checkpoint does not depend on filesystem listing order.
checkpoints = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
if not checkpoints:
    raise FileNotFoundError(f"No .pth checkpoint found in {model_dir}")
model_path = checkpoints[0]

print(f"Loading model: {model_path}")
model = build_model(config['architecture_option']).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Helper transforms: must match the validation transforms used at training time,
# including normalization, for predictions to be meaningful.
dm = Cifar10DataManager(data_dir="../data")
_, val_transform = dm.get_transforms(config['architecture_option'])

@app.post("/predict")
def predict(payload: dict):
    """Classify one image from the simulation hold-out set.

    Args:
        payload: JSON body; ``payload["index"]`` selects the sample from
            ``dm.get_simulation_data()``.

    Returns:
        ``{"prediction": <class id>, "confidence": <max softmax probability>}``,
        or ``{"error": <message>}`` for an invalid index or any other failure.
    """
    try:
        idx = payload.get("index")
        # bool is a subclass of int, so exclude it explicitly.
        if not isinstance(idx, int) or isinstance(idx, bool):
            return {"error": f"'index' must be an integer, got {idx!r}"}

        # NOTE(review): fetched on every request; cost depends on Cifar10DataManager.
        sim_data = dm.get_simulation_data()
        # Reject negatives as well: sequence-style indexing may silently wrap them.
        if not 0 <= idx < len(sim_data):
            return {"error": f"'index' {idx} out of range [0, {len(sim_data)})"}

        image, _ = sim_data[idx]
        input_tensor = val_transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_tensor)
            probs = torch.nn.functional.softmax(output, dim=1)
            conf, pred = torch.max(probs, 1)

        return {
            "prediction": int(pred.item()),
            "confidence": float(conf.item())
        }
    except Exception as e:
        return {"error": str(e)}

# 2. Start Server
def run_server():
    """Serve the module-level FastAPI ``app`` on 127.0.0.1:8000 (blocking call).

    Used as the target of a background thread so the notebook stays interactive.
    """
    uvicorn_options = {"host": "127.0.0.1", "port": 8000, "log_level": "error"}
    uvicorn.run(app, **uvicorn_options)

import socket

# If port 8000 is already taken (e.g. by a server left over from a previous run of
# this cell), uvicorn cannot bind inside the thread; the thread exits and the
# requests below would reach whatever is already listening. Fail fast instead.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as _probe:
    _probe.settimeout(1.0)
    if _probe.connect_ex(("127.0.0.1", 8000)) == 0:
        raise RuntimeError(
            "Port 8000 is already in use (possibly a server from a previous run of "
            "this cell). Restart the kernel before re-running."
        )

server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
print("Services started... Waiting for the server to accept connections")

# Poll instead of sleeping a fixed time: continue as soon as uvicorn is listening,
# and fail loudly if the thread dies or startup exceeds the deadline.
_deadline = time.monotonic() + 30
while True:
    try:
        socket.create_connection(("127.0.0.1", 8000), timeout=1.0).close()
        break
    except OSError:
        if not server_thread.is_alive():
            raise RuntimeError("Inference server thread exited during startup") from None
        if time.monotonic() > _deadline:
            raise TimeoutError("Inference server did not start listening within 30s") from None
        time.sleep(0.2)

# 3. Real Simulation Loop
N_SIM_SAMPLES = 30  # number of hold-out images sent to the server
SIM_SEED = 0  # fixes which samples are drawn, so runs are reproducible
PREDICT_URL = "http://127.0.0.1:8000/predict"
REQUEST_TIMEOUT_S = 30

wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type="deployment_simulation")
try:
    sim_data = dm.get_simulation_data() # access the 2k holdout set
    feedback_data = [] # (image_index, correct_label) for every misclassified sample
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    table = wandb.Table(columns=["index", "prediction", "ground_truth", "confidence", "correct"])

    # Guard against a hold-out set smaller than the requested sample count.
    n_samples = min(N_SIM_SAMPLES, len(sim_data))
    print(f"Running Simulation on {n_samples} random samples...")
    rng = np.random.default_rng(SIM_SEED)
    indices = rng.choice(len(sim_data), n_samples, replace=False)

    n_evaluated = 0
    for idx in indices:
        idx = int(idx)  # numpy integer -> plain int for JSON and the W&B table
        try:
            # Ground Truth
            _, gt_label = sim_data[idx]
            gt_label = int(gt_label)

            # Request Prediction
            resp = requests.post(PREDICT_URL, json={"index": idx}, timeout=REQUEST_TIMEOUT_S)
            res = resp.json()
            # The /predict handler reports failures as an {"error": ...} body.
            if "error" in res:
                print(f"Idx {idx}: server error: {res['error']}")
                continue

            pred = res["prediction"]
            conf = res["confidence"]
            is_correct = (pred == gt_label)
            n_evaluated += 1

            print(f"Idx {idx}: Truth={classes[gt_label]} | Pred={classes[pred]} ({conf:.2f}) -> {'✅' if is_correct else '❌'}")

            table.add_data(idx, classes[pred], classes[gt_label], conf, is_correct)

            if not is_correct:
                # FEEDBACK LOOP: Capture failure
                feedback_data.append((idx, gt_label))

        except Exception as e:
            print(f"Request failed: {e}")

    wandb.log({"simulation_results": table})

    # 4. Create Feedback Artifact (v2 Dataset Increment)
    if feedback_data:
        print(f"\nCreating Feedback Artifact with {len(feedback_data)} new labeled samples...")
        # Stored as an (n, 2) int64 array of (index, label) rows.
        np.save("feedback_v1.npy", np.asarray(feedback_data, dtype=np.int64))

        artifact = wandb.Artifact("cifar10-feedback", type="dataset")
        artifact.add_file("feedback_v1.npy")
        wandb.log_artifact(artifact)
        # NOTE(review): nothing in this cell starts retraining; presumably downstream
        # automation reacts to the new artifact version — confirm.
        print("Feedback artifact logged. Triggering Automated Retraining...")
    elif n_evaluated == 0:
        print("No predictions succeeded; skipping feedback artifact.")
    else:
        print("No errors found! No retraining needed.")
finally:
    # Always close the run, even if the loop or artifact logging raised.
    wandb.finish()

In [None]:
import wandb
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from fastapi import FastAPI
import uvicorn
import threading
import requests
import time
import numpy as np
from PIL import Image
import io
import asyncio
import nest_asyncio
import os

# Apply nest_asyncio to allow running uvicorn in a notebook
nest_asyncio.apply()

# Configuration: honor the W&B environment variables (as the first cell does via
# load_env_vars), falling back to the previous hard-coded defaults.
PROJECT_NAME = os.environ.get("WANDB_PROJECT", "cifar10_mlops_project")

# --- AUTOMATION: READ SWEEP ID FROM FILE ---
try:
    with open("../artifacts/sweep_id.txt", "r") as f:
        SWEEP_ID = f.read().strip()
    print(f"Loaded Sweep ID: {SWEEP_ID}")
except FileNotFoundError:
    SWEEP_ID = "YOUR_SWEEP_ID" 
    print("Sweep ID file not found. Please run notebook 02 first or set manually.")

ENTITY = os.environ.get("WANDB_ENTITY")  # None when unset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# 1. Load Best Model from W&B Registry
def load_best_model():
    """Download the best sweep run's checkpoint from W&B and load it into a ResNet-18.

    Reads the module-level ``ENTITY``, ``PROJECT_NAME``, ``SWEEP_ID`` and ``device``.
    The standard ResNet-18 layout (10-class head) is tried first. If the state dict
    does not fit, it falls back to a 3x3 stride-1 ``conv1`` with no max-pool.

    Returns:
        The model in eval mode on ``device``, or ``None`` if loading fails
        (the reason is printed).
    """
    import glob

    api = wandb.Api()
    try:
        # An unset ENTITY would otherwise be formatted into the path as "None".
        entity = ENTITY or api.default_entity
        sweep = api.sweep(f"{entity}/{PROJECT_NAME}/{SWEEP_ID}")
        best_run = sweep.best_run()
        if best_run is None:
            print(f"Sweep {SWEEP_ID} has no runs; nothing to load.")
            return None
        print(f"Loading best model from run: {best_run.id}")

        model_artifact = next(
            (a for a in best_run.logged_artifacts() if a.type == "model"), None
        )
        if model_artifact is None:
            print(f"Run {best_run.id} logged no artifact of type 'model'.")
            return None

        artifact_dir = model_artifact.download()
        model_path = os.path.join(artifact_dir, f"model_best_{best_run.id}.pth")
        if not os.path.exists(model_path):
            # Fall back to any checkpoint in the artifact rather than failing outright.
            candidates = sorted(glob.glob(os.path.join(artifact_dir, "*.pth")))
            if not candidates:
                print(f"No .pth checkpoint found in {artifact_dir}")
                return None
            model_path = candidates[0]

        # Read the checkpoint once; it is reused for both architecture attempts.
        state_dict = torch.load(model_path, map_location=device)

        # Note: ideally the architecture option would be stored alongside the model.
        model = torchvision.models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, 10)
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            # load_state_dict raises RuntimeError on key/shape mismatch: retry with
            # the modified stem used by the 'modified' sweep architecture.
            model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            model.maxpool = nn.Identity()
            model.load_state_dict(state_dict)

        model.to(device)
        model.eval()
        print("Model loaded successfully!")
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

model = load_best_model()
# Fail here rather than letting every /predict request error on a None model.
if model is None:
    raise RuntimeError(
        "load_best_model() returned None; check SWEEP_ID, the W&B entity/project, "
        "and that the best run logged a model artifact."
    )

In [None]:
# 2. Define FastAPI App
app = FastAPI()

# CIFAR-10 label names, indexed by class id.
classes = tuple("plane car bird cat deer dog frog horse ship truck".split())

# Inference preprocessing: resize, convert to tensor, normalize per channel.
# NOTE(review): these constants must match training-time normalization — TODO confirm.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)
transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    ]
)

# CIFAR-10 test split, built on first request and reused afterwards. Rebuilding it
# per request re-reads the data files and re-runs the download integrity check.
_cifar_testset = None


def _get_cifar_testset():
    """Return the CIFAR-10 test split, constructing it only on first use."""
    global _cifar_testset
    if _cifar_testset is None:
        _cifar_testset = torchvision.datasets.CIFAR10(
            root='../data/raw', train=False, download=True, transform=transform
        )
    return _cifar_testset


@app.post("/predict")
def predict_image(data: dict):
    """Classify one CIFAR-10 test image selected by index (simulation endpoint).

    Declared as a plain ``def`` so FastAPI runs this blocking dataset/model work in
    its threadpool instead of on the event loop.

    Args:
        data: JSON body; ``data["index"]`` (default 0) selects the test image.

    Returns:
        ``{"prediction", "confidence", "ground_truth", "correct"}`` with class names,
        or ``{"error": <message>}`` for an invalid index or any other failure.
    """
    # In prod, you'd handle file uploads or base64 strings instead of an index.
    try:
        idx = data.get("index", 0)
        # bool is a subclass of int, so exclude it explicitly.
        if not isinstance(idx, int) or isinstance(idx, bool):
            return {"error": f"'index' must be an integer, got {idx!r}"}

        testset = _get_cifar_testset()
        # Reject negatives too: dataset indexing could otherwise wrap them.
        if not 0 <= idx < len(testset):
            return {"error": f"'index' {idx} out of range [0, {len(testset)})"}

        image, label = testset[idx]
        image = image.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(image)
            probs = torch.nn.functional.softmax(output, dim=1)
            confidence, predicted = torch.max(probs, 1)

        pred_id = predicted.item()
        return {
            "prediction": classes[pred_id],
            "confidence": confidence.item(),
            "ground_truth": classes[label],
            "correct": pred_id == label
        }
    except Exception as e:
        return {"error": str(e)}

# 3. Utilities to run Server in Notebook
def run_server():
    """Serve the module-level FastAPI ``app`` on 127.0.0.1:8000 (blocking call).

    Used as the target of a background thread so the notebook stays interactive.
    """
    uvicorn_options = {"host": "127.0.0.1", "port": 8000, "log_level": "warning"}
    uvicorn.run(app, **uvicorn_options)

import socket

# Port 8000 may already be served by the inference server started earlier in this
# notebook. That server runs a different app whose /predict response lacks
# "ground_truth"/"correct". uvicorn cannot bind a second time, so its thread would
# exit and the requests below would reach the stale server. Fail fast instead.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as _probe:
    _probe.settimeout(1.0)
    if _probe.connect_ex(("127.0.0.1", 8000)) == 0:
        raise RuntimeError(
            "Port 8000 is already in use (likely by the inference server started "
            "earlier in this notebook). Restart the kernel and run only this "
            "pipeline, or stop the other server first."
        )

# Start server in a separate thread
server_thread = threading.Thread(target=run_server)
server_thread.daemon = True
server_thread.start()

# Poll instead of sleeping a fixed time: continue as soon as uvicorn is listening,
# and fail loudly if the thread dies or startup exceeds the deadline.
_deadline = time.monotonic() + 30
while True:
    try:
        socket.create_connection(("127.0.0.1", 8000), timeout=1.0).close()
        break
    except OSError:
        if not server_thread.is_alive():
            raise RuntimeError("FastAPI server thread exited during startup") from None
        if time.monotonic() > _deadline:
            raise TimeoutError("FastAPI server did not start listening within 30s") from None
        time.sleep(0.2)
print("FastAPI server started at http://127.0.0.1:8000")

In [None]:
# 4. Simulate Prediction Requests and Log to W&B
N_REQUESTS = 10
SAMPLE_POOL_SIZE = 1000  # indices are drawn from [0, SAMPLE_POOL_SIZE) of the test set
REQUEST_SEED = 0  # fixes which images are requested, so runs are reproducible
PREDICT_URL = "http://127.0.0.1:8000/predict"
REQUEST_TIMEOUT_S = 30

wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type="production-monitoring")
try:
    # Create a W&B Table to log requests
    columns = ["request_id", "input_index", "prediction", "confidence", "ground_truth", "correct"]
    prediction_table = wandb.Table(columns=columns)

    correct_count = 0
    successful_requests = 0
    total_requests = N_REQUESTS
    rng = np.random.default_rng(REQUEST_SEED)

    print(f"Simulating {total_requests} user requests...")

    for i in range(total_requests):
        # Randomly pick an image index from test set
        idx = int(rng.integers(0, SAMPLE_POOL_SIZE))
        payload = {"index": idx}

        try:
            response = requests.post(PREDICT_URL, json=payload, timeout=REQUEST_TIMEOUT_S)
            result = response.json()

            if "error" in result:
                print(f"Request {i+1} failed: {result['error']}")
                continue

            print(f"Req {i+1}: Pred={result['prediction']}, True={result['ground_truth']} ({result['correct']})")

            # Log to Table
            prediction_table.add_data(
                i+1,
                idx,
                result['prediction'],
                result['confidence'],
                result['ground_truth'],
                result['correct']
            )

            successful_requests += 1
            if result['correct']:
                correct_count += 1

        except Exception as e:
            print(f"Connection failed: {e}")

    # Accuracy is measured over answered requests only; failed requests are not
    # misclassifications and are reported separately.
    metrics = {
        "inference_requests": prediction_table,
        "successful_requests": successful_requests,
    }
    if successful_requests:
        prod_accuracy = correct_count / successful_requests
        print(f"\nProduction Accuracy: {prod_accuracy * 100:.1f}% "
              f"({correct_count}/{successful_requests} answered, {total_requests} sent)")
        metrics["production_accuracy"] = prod_accuracy
    else:
        print("\nNo requests succeeded; production accuracy not computed.")

    # Log accumulated metrics to W&B
    wandb.log(metrics)
finally:
    # Always close the run, even if the loop or logging raised.
    wandb.finish()