# In [None]:
import os
import json
import numpy as np
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import glob
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset, Subset
from dotenv import load_dotenv

# ==========================================
# 1. Environment & Setup
# ==========================================
# Pull configuration/secrets from a .env file located in the working directory.
load_dotenv(os.path.join(os.getcwd(), ".env"))

# W&B destination. ENTITY is None whenever WANDB_ENTITY is not set.
PROJECT_NAME = os.getenv("WANDB_PROJECT", "cifar10_mlops_project")
ENTITY = os.getenv("WANDB_ENTITY")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ==========================================
# 2. Shared Code (Inlined)
# ==========================================

class Cifar10DataManager:
    """Build CIFAR-10 transforms, loaders and simulation data from ``data_dir``.

    Besides the torchvision CIFAR-10 files, ``data_dir`` must contain
    ``processed/test_indices.npy`` (the evaluation subset of the test split used
    by :meth:`get_loaders`) and ``processed/sim_indices.npy`` (the subset of the
    test split returned by :meth:`get_simulation_data`).
    """

    def __init__(self, data_dir="./data"):
        self.data_dir = data_dir
        # Per-channel normalisation statistics. Checkpoints trained through this
        # class saw inputs normalised with exactly these values, so changing them
        # alters what previously trained models receive.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2023, 0.1994, 0.2010)

    def _load_indices(self, filename):
        """Load ``<data_dir>/processed/<filename>``; raise FileNotFoundError if absent."""
        path = os.path.join(self.data_dir, "processed", filename)
        if not os.path.exists(path):
            raise FileNotFoundError(f"Indices not found at {path}")
        return np.load(path)

    def get_transforms(self, architecture_option='standard'):
        """Return ``(train_transform, test_transform)`` for an architecture option.

        Training adds a random horizontal flip and a padded 32x32 random crop.
        With ``architecture_option='upsample'`` both pipelines resize to 224.
        On the training side the resize runs *after* the crop. The previous
        order (resize, then ``RandomCrop(32)``) cut 32x32 patches out of the
        224-pixel image, so train and test inputs had different sizes.
        """
        to_normalized_tensor = [
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ]
        augmentation = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
        ]
        resize = [transforms.Resize(224)] if architecture_option == 'upsample' else []

        train_transforms = augmentation + resize + to_normalized_tensor
        test_transforms = resize + to_normalized_tensor
        return transforms.Compose(train_transforms), transforms.Compose(test_transforms)

    def get_loaders(self, batch_size, architecture_option='standard', num_workers=2):
        """Return ``(train_loader, test_loader)``.

        The train loader shuffles the full CIFAR-10 training split. The test
        loader iterates, in order, only the test-split samples listed in
        ``processed/test_indices.npy``.

        Args:
            batch_size: Batch size for both loaders.
            architecture_option: Forwarded to :meth:`get_transforms`.
            num_workers: DataLoader worker processes for both loaders.

        Raises:
            FileNotFoundError: if ``processed/test_indices.npy`` is missing.
        """
        train_transform, test_transform = self.get_transforms(architecture_option)
        train_set = torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=train_transform)
        test_set = torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=test_transform)
        eval_subset = Subset(test_set, self._load_indices("test_indices.npy"))

        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        test_loader = DataLoader(eval_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return train_loader, test_loader

    def get_simulation_data(self):
        """Return the untransformed test-split samples listed in ``processed/sim_indices.npy``.

        Raises:
            FileNotFoundError: if ``processed/sim_indices.npy`` is missing.
        """
        test_set_raw = torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=None)
        return Subset(test_set_raw, self._load_indices("sim_indices.npy"))

def build_model(architecture_option='standard', num_classes=10, pretrained=True):
    """Build a torchvision ResNet-18 with a freshly initialised classifier head.

    Args:
        architecture_option: ``'modified'`` replaces the stem conv with a 3x3,
            stride-1 conv and drops the initial max-pool (presumably to keep
            spatial resolution for small CIFAR inputs). Any other value keeps
            torchvision's stock stem.
        num_classes: Output size of the new final linear layer.
        pretrained: Start from torchvision's ImageNet weights (IMAGENET1K_V1).
            The replaced layers are always randomly initialised.

    Returns:
        The model, left on the CPU.
    """
    weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates `pretrained=`. IMAGENET1K_V1 is the
        # documented equivalent of `pretrained=True`.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        model = torchvision.models.resnet18(weights=weights)
    else:
        # Older torchvision only understands the legacy flag.
        model = torchvision.models.resnet18(pretrained=pretrained)

    if architecture_option == 'modified':
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        model.maxpool = nn.Identity()

    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Network to train; it is switched to train mode here.
        loader: Iterable of ``(inputs, labels)`` batches that also supports ``len()``.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar tensor.
        optimizer: Optimizer over the model's parameters, stepped once per batch.
        device: Device every batch is moved to.

    Returns:
        The unweighted average of the per-batch loss values, as a float.
        ``ZeroDivisionError`` propagates if ``loader`` is empty.
    """
    model.train()
    batch_losses = []
    for batch_inputs, batch_labels in loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        loader: Iterable of ``(inputs, labels)`` batches that also supports ``len()``.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar tensor.
        device: Device every batch is moved to.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``. The loss is the unweighted mean
        of per-batch losses. Accuracy is ``100 * correct / samples``, using the
        arg-max over dim 1 as the prediction.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item()
            predictions = logits.max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()
    mean_loss = loss_sum / len(loader)
    accuracy = 100 * n_correct / n_seen
    return mean_loss, accuracy

# ==========================================
# 3. Automated Retraining Logic
# ==========================================

# Fine-tuning hyper-parameters for the retraining pass.
RETRAIN_EPOCHS = 2
RETRAIN_LR = 0.001
RETRAIN_MOMENTUM = 0.9


def _wandb_path(name):
    """Qualify *name* as ``[entity/]project/name`` for W&B lookups.

    The entity prefix is omitted when ENTITY is unset or empty, so W&B falls
    back to its default entity. Otherwise the path would contain a literal
    ``"None"`` entity.
    """
    prefix = [ENTITY] if ENTITY else []
    return "/".join(prefix + [PROJECT_NAME, name])


def _resolve_best_run(api):
    """Return the best run of the sweep to fine-tune from, or None to skip.

    Uses ``$SWEEP_ID`` when set, otherwise the first sweep the W&B API lists
    for the project. Prints the reason whenever it returns None.
    """
    sweep_id = os.getenv("SWEEP_ID")
    if not sweep_id:
        sweeps = api.project(PROJECT_NAME, entity=ENTITY or None).sweeps()
        if len(sweeps) > 0:
            # NOTE(review): assumes the API lists the most recent sweep first — confirm.
            sweep_id = sweeps[0].id
            print(f"Using latest sweep: {sweep_id}")
    if not sweep_id:
        print("Skipping: No Sweep ID found.")
        return None
    best = api.sweep(_wandb_path(sweep_id)).best_run()
    if best is None:
        print(f"Skipping: sweep {sweep_id} has no runs to retrain from.")
    return best


class FeedbackDataset(torch.utils.data.Dataset):
    """Serve the simulation samples referenced by the feedback rows.

    Defined at module level (not inside an ``if``) so it is always importable
    and picklable.
    """

    def __init__(self, subset, feedback_indices, transform=None):
        self.subset = subset
        # Column 0 of each feedback row is used as a position within `subset`.
        # NOTE(review): assumes 2-D feedback rows — confirm the artifact schema.
        self.indices = [int(x[0]) for x in feedback_indices]
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        image, label = self.subset[self.indices[idx]]
        if self.transform:
            image = self.transform(image)
        return image, label


print("Checking for Feedback Artifacts...")
run = wandb.init(project=PROJECT_NAME, entity=ENTITY or None, job_type="automated_retraining", tags=["retrain"])

# 3.1 Download Feedback Data (best-effort: any failure means "no feedback yet").
try:
    artifact = run.use_artifact(_wandb_path("cifar10-feedback:latest"), type='dataset')
    artifact_dir = artifact.download(root=".")
    feedback_path = os.path.join(artifact_dir, "feedback_v1.npy")

    if os.path.exists(feedback_path):
        feedback_data = np.load(feedback_path)
        print(f"Found {len(feedback_data)} feedback samples to integrate.")
    else:
        print(f"Feedback artifact has no feedback_v1.npy in {artifact_dir}.")
        feedback_data = []
except Exception as e:
    print(f"No feedback artifact found. (Normally this means no failures reported yet). Info: {e}")
    feedback_data = []

# 3.2 Fetch Base Configuration. best_run stays None on every skip path, so the
# script falls through to run.finish() instead of calling exit() (which would
# also shut down a notebook kernel).
best_run = None
if len(feedback_data) > 0:
    api = wandb.Api()
    best_run = _resolve_best_run(api)
else:
    print("No feedback data found. Retraining skipped.")

if best_run is not None:
    config = best_run.config

    # 3.3 Ensure Dataset Available (both index files are read below).
    required_indices = [
        os.path.join("./data", "processed", name) for name in ("sim_indices.npy", "test_indices.npy")
    ]
    if not all(os.path.exists(path) for path in required_indices):
        print("Downloading dataset artifact...")
        run.use_artifact(_wandb_path("cifar10_dataset:latest")).download("./data")

    dm = Cifar10DataManager(data_dir="./data")
    train_loader_base, test_loader = dm.get_loaders(config['batch_size'], config['architecture_option'])

    # 3.4 Build "Augmented" Dataset (Train + Feedback)
    sim_data_subset = dm.get_simulation_data()  # raw PIL images, no transform applied
    train_transform, _ = dm.get_transforms(config['architecture_option'])
    feedback_ds = FeedbackDataset(sim_data_subset, feedback_data, transform=train_transform)

    full_train_set = ConcatDataset([train_loader_base.dataset, feedback_ds])
    full_train_loader = DataLoader(full_train_set, batch_size=config['batch_size'], shuffle=True, num_workers=0)
    print(f"Retraining on {len(full_train_set)} samples...")

    # 3.5 Load Base Model Weights
    model_artifacts = [a for a in best_run.logged_artifacts() if a.type == "model"]
    if not model_artifacts:
        raise RuntimeError(f"Run {best_run.id} logged no artifact of type 'model'.")
    model_dir = model_artifacts[0].download(root="./models")
    checkpoints = sorted(glob.glob(os.path.join(model_dir, "*.pth")))  # sorted: deterministic pick
    if not checkpoints:
        raise FileNotFoundError(f"No .pth checkpoint found in {model_dir}")
    model_path = checkpoints[0]

    # A strict load_state_dict overwrites every weight, so skip the ImageNet download.
    model = build_model(config['architecture_option'], pretrained=False).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))

    # 3.6 Fine-tune
    optimizer = optim.SGD(model.parameters(), lr=RETRAIN_LR, momentum=RETRAIN_MOMENTUM)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(RETRAIN_EPOCHS):
        print(f"Retraining Epoch {epoch+1}...")
        loss = train_epoch(model, full_train_loader, criterion, optimizer, device)
        val_loss, acc = validate(model, test_loader, criterion, device)
        print(f"Epoch {epoch+1} Loss: {loss:.4f} Acc: {acc:.2f}%")
        run.log({"retrain_loss": loss, "retrain_acc": acc, "retrain_val_loss": val_loss})

    # 3.7 Save New Version
    os.makedirs("./models", exist_ok=True)
    retrained_path = "./models/model_retrained_v2.pth"
    torch.save(model.state_dict(), retrained_path)

    art = wandb.Artifact(
        "model-retrained-v2",
        type="model",
        metadata={
            "base_run_id": best_run.id,
            "feedback_samples": len(feedback_ds),
            "epochs": RETRAIN_EPOCHS,
        },
    )
    art.add_file(retrained_path)
    run.log_artifact(art)
    print("Retrained model v2 saved and logged.")

run.finish()