In [None]:
import sys
import os
import json
import numpy as np
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import glob
from torch.utils.data import DataLoader, ConcatDataset
from dotenv import load_dotenv

# 1. Load .env from the current working directory (normally the notebook's directory).
load_dotenv(os.path.join(os.getcwd(), ".env"))

# 2. Make the project root (the directory that contains 'src') importable.
# Try the working directory first, then its parent (e.g. when running from a notebooks/ folder).
for _project_root in (os.getcwd(), os.path.abspath(os.path.join(os.getcwd(), ".."))):
    if os.path.exists(os.path.join(_project_root, "src")):
        # Put the root first so this project's 'src' wins over any other 'src' on sys.path.
        # The membership check keeps re-runs of this cell from piling up duplicate entries.
        if _project_root not in sys.path:
            sys.path.insert(0, _project_root)
        break

from src.dataset import Cifar10DataManager
from src.model import build_model
from src.training import train_epoch, validate

# Setup
PROJECT_NAME = os.getenv("WANDB_PROJECT", "cifar10_mlops_project")
# Optional; when unset it is resolved from the initialized run below.
ENTITY = os.getenv("WANDB_ENTITY", None)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Checking for Feedback Artifacts...")

# 1. Initialize Run
run = wandb.init(
    project=PROJECT_NAME,
    entity=ENTITY,
    job_type="automated_retraining",
    tags=["retrain"],
)
# Without this fallback an unset WANDB_ENTITY would yield artifact paths like
# "None/<project>/..." that can never resolve.
ENTITY = ENTITY or run.entity

# 2. Check/Download Feedback
feedback_data = []
try:
    artifact = run.use_artifact(f'{ENTITY}/{PROJECT_NAME}/cifar10-feedback:latest', type='dataset')
    # NOTE(review): root="." downloads the artifact's files straight into the working
    # directory; any same-named files there are overwritten.
    artifact_dir = artifact.download(root=".")
    feedback_path = os.path.join(artifact_dir, "feedback_v1.npy")

    if os.path.exists(feedback_path):
        feedback_data = np.load(feedback_path)
        print(f"Found {len(feedback_data)} feedback samples to integrate.")
    else:
        print("No feedback data found.")
except Exception as e:
    # Deliberately best-effort: a missing feedback artifact (e.g. no feedback collected
    # yet) is not an error -- it just means retraining is skipped below.
    print(f"No feedback artifact found (or no new feedback): {e}")
    feedback_data = []

# Retraining hyper-parameters; overridable from the environment (e.g. CI/CD).
RETRAIN_EPOCHS = int(os.getenv("RETRAIN_EPOCHS", "2"))
RETRAIN_LR = float(os.getenv("RETRAIN_LR", "0.001"))
RETRAINED_MODEL_PATH = "./models/model_retrained_v2.pth"


class FeedbackDataset(torch.utils.data.Dataset):
    """Dataset view over the simulation-pool samples referenced by feedback rows.

    Args:
        subset: indexable dataset returning ``(image, label)`` pairs.
        feedback_indices: sequence of feedback rows. If a row is a sequence, its
            first element is used as the index into ``subset``. If it is a scalar,
            the scalar itself is used.
        transform: optional callable applied to each image.
    """

    def __init__(self, subset, feedback_indices, transform=None):
        self.subset = subset
        # NOTE(review): only the index column is used and labels come from `subset`;
        # if feedback rows also carry a user-corrected label it is ignored -- confirm.
        self.indices = [int(row[0]) if np.ndim(row) else int(row) for row in feedback_indices]
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        image, label = self.subset[self.indices[idx]]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


def _resolve_sweep_id(api, entity):
    """Return the sweep to retrain from: $SWEEP_ID if set, else the first listed sweep, else None."""
    sweep_id = os.getenv("SWEEP_ID")
    if sweep_id:
        return sweep_id
    sweeps = api.project(PROJECT_NAME, entity=entity).sweeps()
    if len(sweeps) > 0:
        # NOTE(review): assumes the API lists the most recent sweep first -- confirm ordering.
        print(f"Using default latest sweep: {sweeps[0].id}")
        return sweeps[0].id
    return None


def _fetch_base_config_and_model(api, entity):
    """Return ``(config, model_path)`` for the sweep's best run.

    Returns None after printing why when there is nothing to retrain from.
    """
    sweep_id = _resolve_sweep_id(api, entity)
    if not sweep_id:
        print("Skipping: No Sweep ID found to base retraining on.")
        return None

    sweep = api.sweep(f"{entity}/{PROJECT_NAME}/{sweep_id}")
    best_run = sweep.best_run()
    if best_run is None:
        print(f"Skipping: Sweep {sweep_id} has no best run to base retraining on.")
        return None

    model_artifact = next((a for a in best_run.logged_artifacts() if a.type == "model"), None)
    if model_artifact is None:
        print(f"Skipping: Best run {best_run.id} logged no model artifact.")
        return None

    model_dir = model_artifact.download(root="./models")
    # sorted() makes the choice deterministic when the artifact holds several checkpoints.
    checkpoints = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if not checkpoints:
        print(f"Skipping: No .pth checkpoint found in {model_dir}.")
        return None
    return best_run.config, checkpoints[0]


# Stays 1 unless the pipeline below completes, so a crash is reported as a failed run.
run_exit_code = 1
try:
    base = None
    if len(feedback_data) > 0:
        # 3. Fetch Best Config & Model from W&B (Cloud Source of Truth)
        api = wandb.Api()
        base = _fetch_base_config_and_model(api, ENTITY or run.entity)
    else:
        print("Skipping retraining - no feedback data.")

    if base is not None:
        config, model_path = base

        # Data Setup. NOTE: assumes ./data is already populated (e.g. by
        # dm.prepare_initial_split or an earlier pipeline step); the dataset could
        # instead be pulled from the 'cifar10_dataset:latest' artifact.
        dm = Cifar10DataManager(data_dir="./data")
        train_loader_base, test_loader = dm.get_loaders(config['batch_size'], config['architecture_option'])

        # 4. Create Augmented Dataset
        sim_data_subset = dm.get_simulation_data()
        train_transform, _ = dm.get_transforms(config['architecture_option'])
        # NOTE(review): train_transform is applied to samples from get_simulation_data();
        # confirm that subset yields untransformed images, or the transform runs twice.
        feedback_ds = FeedbackDataset(sim_data_subset, feedback_data, transform=train_transform)

        full_train_set = ConcatDataset([train_loader_base.dataset, feedback_ds])
        full_train_loader = DataLoader(full_train_set, batch_size=config['batch_size'], shuffle=True, num_workers=0)
        print(f"Retraining on {len(full_train_set)} samples...")

        # 5. Retrain (fine-tune starting from the best sweep checkpoint)
        model = build_model(config['architecture_option']).to(device)
        # weights_only=True: the checkpoint comes from a downloaded artifact, so refuse
        # to unpickle arbitrary objects -- only tensors/state_dict are expected here.
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

        optimizer = optim.SGD(model.parameters(), lr=RETRAIN_LR, momentum=0.9)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(RETRAIN_EPOCHS):
            print(f"Retraining Epoch {epoch+1}...")
            loss = train_epoch(model, full_train_loader, criterion, optimizer, device)
            _, acc = validate(model, test_loader, criterion, device)
            print(f"Epoch {epoch+1} Loss: {loss:.4f} Acc: {acc:.2f}%")
            run.log({"retrain_loss": loss, "retrain_acc": acc})

        os.makedirs(os.path.dirname(RETRAINED_MODEL_PATH), exist_ok=True)
        torch.save(model.state_dict(), RETRAINED_MODEL_PATH)

        art = wandb.Artifact("model-retrained-v2", type="model")
        art.add_file(RETRAINED_MODEL_PATH)
        run.log_artifact(art)
        print("Retrained model v2 saved and logged.")

    run_exit_code = 0
finally:
    # Always close the W&B run exactly once (replaces the old sys.exit(0) early exit,
    # which raises SystemExit inside a notebook kernel).
    run.finish(exit_code=run_exit_code)