In [None]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import json
import numpy as np
from src.utils import load_env_vars
from src.dataset import Cifar10DataManager
from src.model import build_model
from src.training import train_epoch, validate
from torch.utils.data import DataLoader, Subset, ConcatDataset

# Setup: read the W&B coordinates from the project's env loader and pick the
# compute device (GPU when available, otherwise CPU).
env = load_env_vars()
PROJECT_NAME = env.get("WANDB_PROJECT", "cifar10_mlops_project")
# Left as None when WANDB_ENTITY is not configured.
ENTITY = env.get("WANDB_ENTITY")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Checking for Feedback Artifacts...")

# 1. Initialize Run
# The entity is passed explicitly so the run is created under the same entity
# that the rest of this script uses when addressing artifacts. With
# entity=None, wandb.init behaves exactly as if the argument were omitted.
run = wandb.init(
    project=PROJECT_NAME,
    entity=ENTITY,
    job_type="automated_retraining",
    tags=["retrain"],
)

# 2. Check/Download Feedback
# Note: In a real automated pipeline, this would be triggered by a webhook or a
# specific pipeline orchestrator. Here, we just pull the latest version of the
# feedback artifact created in the previous step.

# Default to "no feedback" so every path below leaves a usable value behind.
feedback_data = []

# Only prefix the entity when one is configured: formatting a None entity into
# the name would request "None/<project>/cifar10-feedback:latest", which can
# never resolve and would silently disable retraining.
artifact_name = f"{PROJECT_NAME}/cifar10-feedback:latest"
if ENTITY:
    artifact_name = f"{ENTITY}/{artifact_name}"

try:
    artifact = run.use_artifact(artifact_name, type='dataset')
    artifact_dir = artifact.download()
    feedback_path = os.path.join(artifact_dir, "feedback_v1.npy")

    if os.path.exists(feedback_path):
        feedback_data = np.load(feedback_path)
        print(f"Found {len(feedback_data)} feedback samples to integrate.")
    else:
        print("No feedback data found.")
except Exception as e:
    # Deliberately best-effort: a missing artifact, a network error or an
    # unreadable file all mean "nothing to retrain on", not a crash. The cause
    # is printed so a misconfiguration remains visible.
    print(f"No feedback artifact found: {e}")
    feedback_data = []

class FeedbackDataset(torch.utils.data.Dataset):
    """Expose the feedback-selected samples of a dataset, optionally transformed.

    Args:
        subset: Indexable dataset whose items unpack as ``(image, label)``.
        feedback_indices: Iterable of rows; the first element of each row is
            converted to ``int`` and used as an index into ``subset``.
        transform: Optional callable applied to each image before it is
            returned.

    NOTE(review): the label returned comes from ``subset``; any label stored
    in the feedback rows themselves is not used. Confirm this is intended.
    """

    def __init__(self, subset, feedback_indices, transform=None):
        self.subset = subset
        # Positions inside `subset`, one per feedback row.
        self.indices = [int(row[0]) for row in feedback_indices]
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the dense 0..N-1 index into the position inside `subset`.
        sim_idx = self.indices[idx]
        image, label = self.subset[sim_idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


# Assume failure until the whole step has completed, so that an exception
# anywhere below still closes the W&B run and marks it as failed.
exit_code = 1
try:
    # feedback_data may be a NumPy array, whose truth value is ambiguous, so
    # the length is tested explicitly.
    if len(feedback_data) > 0:
        # 3. Load Base Dataset & Config
        with open("../artifacts/best_config.json", "r", encoding="utf-8") as f:
            config = json.load(f)
        batch_size = config['batch_size']
        architecture = config['architecture_option']

        dm = Cifar10DataManager(data_dir="../data")
        train_loader_base, test_loader = dm.get_loaders(batch_size, architecture)

        # 4. Create "Augmented" Train Set (Base + Feedback)
        # The feedback rows reference samples of the held-out SIMULATION split,
        # so a dataset is built from exactly those indices and given the same
        # training transform as the base data (the simulation data is expected
        # to hold raw, untransformed images).
        sim_data_subset = dm.get_simulation_data()
        train_transform, _ = dm.get_transforms(architecture)
        feedback_ds = FeedbackDataset(sim_data_subset, feedback_data, transform=train_transform)

        full_train_set = ConcatDataset([train_loader_base.dataset, feedback_ds])
        # workers=0 avoids issues in some colab/win environments with specific loaders
        full_train_loader = DataLoader(
            full_train_set,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
        )

        print(f"Retraining on {len(full_train_set)} samples (Original + Feedback)...")

        # 5. Retrain Model (Fine-tune Best Model)
        model = build_model(architecture).to(device)
        # NOTE: torch.load unpickles the file; only load checkpoints produced
        # by this project's own training runs.
        model.load_state_dict(torch.load("../models/model_best_sweep.pth", map_location=device))

        # Conservative learning rate: this is a fine-tune of an already trained model.
        fine_tune_lr = 0.001
        num_epochs = 2
        optimizer = optim.SGD(model.parameters(), lr=fine_tune_lr, momentum=0.9)
        criterion = nn.CrossEntropyLoss()

        # Quick retrain, evaluating on the test loader after every epoch.
        for epoch in range(num_epochs):
            print(f"Retraining Epoch {epoch+1}...")
            loss = train_epoch(model, full_train_loader, criterion, optimizer, device)
            _, acc = validate(model, test_loader, criterion, device)
            print(f"Epoch {epoch+1} Loss: {loss:.4f} Acc: {acc:.2f}%")
            run.log({"retrain_loss": loss, "retrain_acc": acc})

        # Save Retrained Model
        retrained_path = "../models/model_retrained_v2.pth"
        torch.save(model.state_dict(), retrained_path)

        # Version Model v2
        art = wandb.Artifact("model-retrained-v2", type="model")
        art.add_file(retrained_path)
        run.log_artifact(art)
        print("Retrained model v2 saved and logged.")
    else:
        print("Skipping retraining.")
    exit_code = 0
finally:
    run.finish(exit_code=exit_code)