In [None]:
import warnings
warnings.filterwarnings("ignore")
import wandb
import torch
from src.utils import load_env_vars
from src.dataset import Cifar10DataManager
from src.training import run_training_sweep
from src.model import build_model
import os

# Load Environment Variables
# The API key is exported into the process environment; the project and
# entity names are kept as notebook-level constants for the cells below.
env = load_env_vars()
os.environ["WANDB_API_KEY"] = env["WANDB_API_KEY"]
PROJECT_NAME, ENTITY = env["WANDB_PROJECT"], env["WANDB_ENTITY"]

print("Environment configured.")

## Step 1: Data Versioning

We download the data and implement the "3-way Split":

1.  **Train:** 50,000 images
2.  **Test:** 8,000 images (for evaluation)
3.  **Simulation:** 2,000 images (HELD OUT for live traffic simulation)

We then log this initial state as Artifact v1.


In [None]:
# Create the data manager and build the initial split.
dm = Cifar10DataManager()
# The four values returned by the split are discarded in this cell.
_, _, _, _ = dm.prepare_initial_split()

# Open a dedicated W&B run and log the split through the data manager.
run = wandb.init(
    project=PROJECT_NAME,
    job_type="data_versioning",
    name="dataset_v1_creation",
)
dm.log_dataset_artifact(
    run,
    name="cifar10-split-indices",
    description="Contains numpy indices for 8k Test and 2k Simulation split",
)
wandb.finish()
print("Step 1 Complete: Dataset v1 created and logged.")

## Step 2: Experimentation & Training (Hyperparameter Sweep)

We run a Bayesian optimization sweep to find the best model.

- **Architecture Options:** Standard, Upsample (Option A), Modified (Option B). To keep the demo fast, the sweep below samples only Standard and Modified.
- **Optimizers:** SGD, Adam.


In [None]:
# Sweep definition: Bayesian search that maximizes the logged `val_acc` metric.
sweep_config = {
    "method": "bayes",
    "metric": {"name": "val_acc", "goal": "maximize"},
    "parameters": {
        # Continuous range for the learning rate.
        "learning_rate": {"min": 0.001, "max": 0.1},
        # Discrete choices.
        "batch_size": {"values": [64, 128]},
        "optimizer": {"values": ["adam", "sgd"]},
        # Limiting options for speed in demo.
        "architecture_option": {"values": ["standard", "modified"]},
        # Fixed value: small epoch count for demo speed.
        "epochs": {"value": 2},
    },
}

# Create the sweep under the explicit entity as well as the project. The
# best-run lookup later in this cell addresses the sweep as
# f"{ENTITY}/{PROJECT_NAME}/{sweep_id}"; without `entity=` the sweep is created
# under the account's default entity, which may not be ENTITY.
sweep_id = wandb.sweep(sweep_config, entity=ENTITY, project=PROJECT_NAME)
print(f"Sweep ID: {sweep_id}")

# Run Agent (Running just 3 runs for demo), pointed at the same entity/project
# so the bare sweep id resolves to the sweep created above.
wandb.agent(
    sweep_id,
    run_training_sweep,
    entity=ENTITY,
    project=PROJECT_NAME,
    count=3,
)

# Select the sweep's best run through the W&B public API and persist its
# hyperparameter config for the retraining step.
# Note: only the config is saved here; no model weights are logged or
# downloaded by this cell.
import json

api = wandb.Api()
sweep = api.sweep(f"{ENTITY}/{PROJECT_NAME}/{sweep_id}")
best_run = sweep.best_run()
if best_run is None:
    # best_run() yields None when the sweep has no runs to rank; fail with a
    # clear message instead of an AttributeError on `best_run.name`.
    raise RuntimeError(
        f"Sweep {sweep_id} has no runs to rank; cannot select a best config."
    )
print(f"Best Run: {best_run.name} with Acc: {best_run.summary.get('val_acc')}")

# Store Best Config for Automated Retraining later
best_config = best_run.config
with open("best_config.json", "w") as f:
    json.dump(best_config, f)

print("Step 2 Complete: Best model found and config saved.")

## Step 3: Evaluation

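Visualize the performance of the best model. In this demo the cell logs a placeholder confusion matrix; a real implementation would run inference on the 8k test set.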


In [None]:
# In a real scenario, we load weights.
# Here we simulate the evaluation visualization using W&B
run = wandb.init(
    project=PROJECT_NAME,
    job_type="evaluation",
    notes="Best model evaluation",
)

# Placeholder rows where actual == predicted for classes 0..9.
# Real implementation would run inference on Test Set (8k).
data = [[cls, cls] for cls in range(10)]
# NOTE(review): `table` is constructed but never logged in this cell.
table = wandb.Table(data=data, columns=["Actual", "Predicted"])

# Log a dummy confusion matrix for demonstration (since we are orchestrating):
# three classes, all predicted correctly.
conf_mat = wandb.plot.confusion_matrix(
    probs=None,
    y_true=[0, 1, 2],
    preds=[0, 1, 2],
    class_names=["a", "b", "c"],
)
wandb.log({"conf_mat": conf_mat})
wandb.finish()
print("Step 3 Complete: Evaluation results logged.")

## Step 4: Deployment & Simulation (Feedback Loop)

We launch a FastAPI app in a background thread, send it traffic from the "Simulation Set" (2k images held out; the demo sends the first 20), identify misclassifications, and collect them into a "Feedback Dataset".


In [None]:
from fastapi import FastAPI
import uvicorn
import threading
import requests
import time
import shutil

# 1. Define App (In-Notebook)
app = FastAPI()

# Build the architecture named in the best sweep config and switch it to
# eval mode. No weights are loaded here; a real deployment would load them
# from the best artifact (e.g. via simulation_model.load_state_dict...).
simulation_model = build_model(best_config["architecture_option"])
simulation_model.eval()

# POST /predict: demo endpoint. The request body is accepted but not used;
# the response carries a random class id in 0..9 and a fixed confidence.
@app.post("/predict")
def predict(data: dict):
    import random

    predicted_class = random.randint(0, 9)
    return {"prediction": predicted_class, "confidence": 0.95}

# 2. Run Server
def run_server():
    """Serve `app` with uvicorn on 127.0.0.1:8005, logging errors only."""
    server_options = {"host": "127.0.0.1", "port": 8005, "log_level": "error"}
    uvicorn.run(app, **server_options)


# Run the server on a daemon thread so it does not block the notebook.
t = threading.Thread(target=run_server, daemon=True)
t.start()
# NOTE(review): fixed 3 s pause before sending traffic; server startup is
# not actually verified.
time.sleep(3)

# 3. Simulate Traffic & labeling
# Send the first 20 simulation samples to the local endpoint and collect the
# ones whose prediction disagrees with the ground-truth label.
feedback_samples = []
sim_data = dm.get_simulation_data()
print("Simulating traffic on 20 samples...")

for i in range(20):
    # Only the label is needed; the request carries the sample index, not
    # the image itself.
    _img, label = sim_data[i]
    resp = requests.post(
        "http://127.0.0.1:8005/predict",
        json={"index": i},
        timeout=10,  # do not hang the notebook if the server is unresponsive
    )
    # Surface HTTP errors directly instead of a KeyError on "prediction".
    resp.raise_for_status()
    pred = resp.json()["prediction"]

    if pred != label:
        # SIMULATING HUMAN FEEDBACK
        # We "Expertly" label it (which is just the ground truth 'label')
        feedback_samples.append((i, label))  # Store index and correct label

print(f"Feedback Loop: Found {len(feedback_samples)} misclassifications. Adding to dataset v2.")

# 4. Create Dataset v2 (Log artifact with new 'feedback' file)
# `np` is used below but is not imported by the visible notebook cells, so it
# is imported here to avoid a NameError at np.save.
import numpy as np

run = wandb.init(project=PROJECT_NAME, job_type="dataset_update", name="dataset_v2_creation")
# Persist the collected (index, label) pairs, then attach the file to a
# dataset artifact and log it on this run.
np.save("feedback_indices.npy", feedback_samples)
artifact = wandb.Artifact("cifar10-feedback-indices", type="dataset")
artifact.add_file("feedback_indices.npy")
run.log_artifact(artifact)
wandb.finish()

print("Step 4 Complete: Feedback gathered and Dataset v2 created.")

## Step 5: Automated Retraining

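We pick up the new dataset version and retrain with the best configuration saved from the sweep (for example, Option B if that is what the sweep selected) on the updated data. In this demo the retraining itself is simulated.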


In [None]:
import time

# Load Best Config (written to disk by the sweep step)
with open("best_config.json") as f:
    retrain_config = json.load(f)

print(f"Retraining with Best Config: {retrain_config}")

# Initialize Retraining Run
run = wandb.init(
    project=PROJECT_NAME,
    config=retrain_config,
    tags=["retrain", "v2"],
    job_type="automated_retraining",
)

# In a real implementation:
# 1. Load train_loader v1
# 2. Load feedback data
# 3. Concatenate datasets
# 4. Run training loop

# Simulating Retraining: no training happens here, only a short pause.
print("Retraining started...")
time.sleep(2)
print("Retraining complete.")

# Hard-coded placeholder metric (simulated improvement), logged as val_acc.
final_acc = 85.5
wandb.log({"val_acc": final_acc})
wandb.finish()

print(f"Step 5 Complete: Model Retrained on v2. New Accuracy: {final_acc}%")