In [None]:
import os
import json
import numpy as np
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import glob
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset, Subset

# ==========================================
# 1. Setup
# ==========================================
# SECURITY: the API key is read from the environment instead of being stored in
# the notebook. A literal key was previously committed on this line; treat that
# key as leaked and rotate it in the W&B account settings.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
PROJECT_NAME = "cifar10_mlops_project"
ENTITY = "esi-sba-dz"
# When the variable is unset the key is None and wandb falls back to its own
# credential lookup (e.g. a previous `wandb login`) or prompts for one.
wandb.login(key=WANDB_API_KEY)

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



KeyboardInterrupt: 

In [None]:
# ==========================================
# 2. Helpers (Strict No-Download)
# ==========================================
class Cifar10DataManager:
    """Builds CIFAR-10 datasets from files that already exist under ``data_dir``.

    Every dataset is created with ``download=False``, so nothing is fetched
    here; the CIFAR-10 files must already be on disk.
    """

    def __init__(self, data_dir="./data"):
        """Remember the dataset root and the per-channel normalisation stats."""
        self.data_dir = data_dir
        # Per-channel values handed to transforms.Normalize below.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2023, 0.1994, 0.2010)

    def get_loader_for_retrain(self, batch_size, architecture_option='standard'):
        """Return ``(train_set, train_tf)`` for the CIFAR-10 training split.

        Despite the name, this returns a dataset and its transform, not a
        DataLoader. ``batch_size`` is accepted but not used by this method.

        Args:
            batch_size: Unused here; kept for the existing call signature.
            architecture_option: ``'upsample'`` resizes images to 224 and
                flips them; any other value applies a random flip followed by
                a random 32-pixel crop.

        Returns:
            A tuple of the training dataset and the transform attached to it.
        """
        to_normalized_tensor = [
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ]

        if architecture_option == 'upsample':
            augmentations = [transforms.Resize(224), transforms.RandomHorizontalFlip()]
        else:
            # RandomCrop's second positional argument is the padding.
            augmentations = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, 4)]

        train_tf = transforms.Compose(augmentations + to_normalized_tensor)

        # STRICT: download=False
        train_set = torchvision.datasets.CIFAR10(
            root=self.data_dir,
            train=True,
            download=False,
            transform=train_tf,
        )
        return train_set, train_tf

    def get_simulation_raw(self):
        """Return the CIFAR-10 ``train=False`` split with no transform applied."""
        raw_test_set = torchvision.datasets.CIFAR10(
            root=self.data_dir,
            train=False,
            download=False,
        )
        return raw_test_set
        
def build_model(architecture_option='standard', pretrained=True, num_classes=10):
    """Create a ResNet-18 whose final layer outputs ``num_classes`` logits.

    Args:
        architecture_option: ``'modified'`` replaces the first convolution
            with a new 3x3, stride-1 layer (so that layer starts from a fresh
            initialisation). Any other value, including ``'standard'`` and
            ``'upsample'``, keeps torchvision's stock first convolution.
        pretrained: Forwarded to ``torchvision.models.resnet18``. Pass
            ``False`` when the weights are about to be replaced with
            ``load_state_dict`` so that no pretrained weights are fetched.
        num_classes: Output size of the replacement ``fc`` layer.

    Returns:
        The configured ``torchvision`` ResNet-18 module.
    """
    model = torchvision.models.resnet18(pretrained=pretrained)

    if architecture_option == 'modified':
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

    # Swap the classification head, reusing the backbone's feature width.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Iterable yielding ``(inputs, labels)`` batches. It does not
            need to support ``len()`` because batches are counted as they
            are consumed.
        criterion: Callable computing the loss from ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The average of the per-batch loss values as a float, or ``0.0`` when
        ``loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # Accumulate inside the loop so every batch contributes to the mean.
        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Nothing was trained on; avoid a ZeroDivisionError.
        return 0.0
    return running_loss / num_batches

In [None]:
# ==========================================
# 3. Execution
# ==========================================
class FeedbackDS(torch.utils.data.Dataset):
    """Serves the samples of ``raw`` selected by ``inds`` through transform ``tf``."""

    def __init__(self, raw, inds, tf):
        self.raw = raw
        # assumes the first element of each feedback row is an index into
        # `raw` -- TODO confirm against the code that writes feedback_v1.npy
        self.inds = [int(i[0]) for i in inds]
        self.tf = tf

    def __len__(self):
        return len(self.inds)

    def __getitem__(self, i):
        # The label comes from `raw` itself; only the image is transformed.
        img, label = self.raw[self.inds[i]]
        return self.tf(img), label


run = wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type="retrain", tags=["retrain"])

# Assume failure until the whole cell has completed, so that an exception or a
# manual interrupt still closes the W&B run and marks it as failed.
exit_code = 1
try:
    try:
        print("Downloading Feedback...")
        f_art = run.use_artifact(f'{ENTITY}/{PROJECT_NAME}/cifar10-feedback:latest').download(root=".")
        feedback = np.load(os.path.join(f_art, "feedback_v1.npy"))
    except Exception as exc:
        # Best-effort: a missing feedback artifact just skips retraining.
        # `Exception` (not a bare except) lets KeyboardInterrupt propagate.
        print("No feedback found.")
        print(f"  reason: {type(exc).__name__}: {exc}")
        feedback = []

    if len(feedback) > 0:
        print("Downloading Baseline Data...")
        run.use_artifact(f'{ENTITY}/{PROJECT_NAME}/cifar10_dataset:latest').download("./data")

        api = wandb.Api()
        sweeps = api.project(PROJECT_NAME, entity=ENTITY).sweeps()
        # NOTE(review): this takes the first sweep listed for the project;
        # confirm that is the sweep the baseline should come from.
        best_run = api.sweep(f"{ENTITY}/{PROJECT_NAME}/{sweeps[0].id}").best_run()
        if best_run is None:
            raise RuntimeError("The selected sweep has no best run to retrain from.")
        config = best_run.config

        print("Downloading Baseline Model...")
        # NOTE(review): presumably the first logged artifact is the model
        # checkpoint; verify against the training notebook.
        m_dir = best_run.logged_artifacts()[0].download(root="./models")
        pth_files = sorted(glob.glob(os.path.join(m_dir, "*.pth")))
        if not pth_files:
            raise FileNotFoundError(f"No .pth checkpoint found in {m_dir}")
        m_path = pth_files[0]

        # Dataset Merge
        dm = Cifar10DataManager()
        base_train, tf = dm.get_loader_for_retrain(config['batch_size'], config['architecture_option'])
        raw_sim = dm.get_simulation_raw()

        fb_ds = FeedbackDS(raw_sim, feedback, tf)
        loader = DataLoader(ConcatDataset([base_train, fb_ds]), batch_size=config['batch_size'], shuffle=True)

        # Retrain
        model = build_model(config['architecture_option']).to(device)
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints that come from a trusted artifact.
        model.load_state_dict(torch.load(m_path, map_location=device))

        opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        crit = nn.CrossEntropyLoss()

        print("Fine-tuning...")
        for e in range(2):
            epoch_loss = train_epoch(model, loader, crit, opt, device)
            print(f"Epoch {e+1} Loss: {epoch_loss:.4f}")
            run.log({"retrain_loss": epoch_loss})

        torch.save(model.state_dict(), "retrained.pth")
        art = wandb.Artifact("retrained-model", type="model")
        art.add_file("retrained.pth")
        run.log_artifact(art)
        print("Retraining Complete.")
    else:
        print("Skipping.")

    exit_code = 0
finally:
    run.finish(exit_code=exit_code)