In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

# Setup
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from src.data.cifar_loader import load_cifar_batch
from src.data.cifar_dataset import CIFAR10Dataset
from src.models.resnet import ResNet18CIFAR
from src.data.label_noise import inject_label_noise # Import the tool we just made

# --- CONFIG ---
NOISE_RATE = 0.2  # Fraction of training labels to poison (0.2 -> 20%)
NOISE_TYPE = 'symmetric'  # 'symmetric' or 'asymmetric'
EPOCHS = 30  # Need enough epochs to see the "Memorization Spike"
SEED = 42  # Fixes noise injection, weight init and batch order so runs are comparable

# Seed every RNG the pipeline could draw from. Which RNG inject_label_noise uses is
# not visible from this notebook, so Python's, NumPy's and PyTorch's are all seeded.
import random

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# 1. PREPARE DATA
print("Loading CIFAR-10 data...")
# Directory containing the python-pickle CIFAR-10 batches (data_batch_1..5, test_batch).
# Set the CIFAR10_DIR environment variable to point elsewhere without editing this cell.
DATA_PATH = os.environ.get(
    'CIFAR10_DIR', '/Users/daulet/Desktop/data centric ai/cifar-10-batches-py'
)

# --- A. Load Training Data (Batches 1-5) ---
all_train_images = []
all_train_labels = []

for i in range(1, 6):
    fpath = os.path.join(DATA_PATH, f'data_batch_{i}')
    if os.path.exists(fpath):
        imgs, lbls = load_cifar_batch(fpath)
        all_train_images.append(imgs)
        all_train_labels.extend(lbls)
    else:
        # A single missing batch is tolerated (training uses the rest), but is reported.
        print(f"Warning: Could not find {fpath}")

if not all_train_images:
    # Fail with a clear message instead of np.concatenate's "need at least one array".
    raise FileNotFoundError(
        f"No CIFAR-10 training batches (data_batch_1..5) found in {DATA_PATH!r}"
    )

x_train = np.concatenate(all_train_images)
y_train = np.array(all_train_labels)

# --- B. Load Test Data ---
test_fpath = os.path.join(DATA_PATH, 'test_batch')
if not os.path.exists(test_fpath):
    # Every later step needs x_test / y_test, so stop here rather than continue undefined.
    raise FileNotFoundError(f"CIFAR-10 test batch not found at {test_fpath!r}")

x_test, y_test = load_cifar_batch(test_fpath)
x_test = np.array(x_test)
y_test = np.array(y_test)

print(f"Data Loaded: {len(x_train)} training images, {len(x_test)} test images")

# --- C. Define Transforms (Must match your baseline) ---
import torchvision.transforms as T

# Per-channel normalization statistics shared by the train and test pipelines.
# NOTE(review): these std values are the ones widely copied in CIFAR-10 examples; the
# per-channel pixel std over the training set is commonly reported as roughly
# (0.247, 0.243, 0.261). Kept unchanged because they must match the baseline run.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2023, 0.1994, 0.2010)

# Training: random flip + padded random crop (augmentation), then tensor + normalize.
train_transform = T.Compose(
    [
        T.RandomHorizontalFlip(p=0.5),
        T.RandomCrop(32, padding=4),
        T.ToTensor(),
        T.Normalize(CIFAR_MEAN, CIFAR_STD),
    ]
)

# Evaluation: deterministic -- tensor + normalize only.
test_transform = T.Compose([T.ToTensor(), T.Normalize(CIFAR_MEAN, CIFAR_STD)])

# 2. INJECT POISON
print(f"Injecting {NOISE_RATE*100}% {NOISE_TYPE} noise...")
y_train_noisy, noisy_idx = inject_label_noise(y_train, NOISE_TYPE, NOISE_RATE)
y_train_noisy = np.array(y_train_noisy)

# Sanity check: memorization is measured on noisy_idx, so report how many of those
# samples really carry a wrong label. If the noise helper can "flip" a label into its
# own class, those samples would silently inflate the memorization ratio.
_noisy_pos = np.asarray(noisy_idx, dtype=np.int64)
_n_flipped = int(np.count_nonzero(y_train_noisy[_noisy_pos] != y_train[_noisy_pos]))
print(
    f"Noisy subset: {len(_noisy_pos)} samples, {_n_flipped} with label != true label "
    f"({100 * np.mean(y_train_noisy != y_train):.1f}% of all training labels differ)"
)

# Create Datasets
# Training set: noisy labels + random augmentation.
train_dataset = CIFAR10Dataset(x_train, y_train_noisy, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Probe set: only the poisoned samples, with the deterministic test-time transform.
# Not shuffled, so batches arrive in noisy_idx order and can be matched back to
# the original indices / true labels.
noisy_subset = Subset(CIFAR10Dataset(x_train, y_train_noisy, transform=test_transform), noisy_idx)
noisy_loader = DataLoader(noisy_subset, batch_size=128, shuffle=False)

# Test Set (Clean)
test_dataset = CIFAR10Dataset(x_test, y_test, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# 3. SETUP MODEL
# CIFAR-adapted ResNet-18 trained with plain SGD + momentum (no weight decay, no LR
# schedule); the loss is computed against whatever labels the train loader yields.
model = ResNet18CIFAR(num_classes=10)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9)

# 4. TRACKING STORAGE
# Per-epoch metrics, in percent.
history = dict(
    clean_test_acc=[],  # accuracy on the clean test set
    memorization_ratio=[],  # noisy samples predicted as their NOISY label
    learning_ratio=[],  # Bonus: noisy samples predicted as their TRUE label (if tracked)
)

def _accuracy(net, loader):
    """Return top-1 accuracy (percent) of ``net`` over ``loader``, or NaN if it is empty.

    Runs without gradient tracking; the caller is responsible for ``net.eval()``.
    """
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            pred = net(imgs).argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total if total else float('nan')


# Hidden TRUE labels of the poisoned samples, aligned with noisy_loader's output:
# the loader is not shuffled and Subset yields samples in noisy_idx order, while
# y_train still holds the clean labels (noise was written to y_train_noisy).
y_true_noisy = torch.as_tensor(
    np.asarray(y_train)[np.asarray(noisy_idx, dtype=np.int64)], dtype=torch.long
).to(device)

print("Starting Timeline Analysis...")

for epoch in range(EPOCHS):
    # --- TRAIN ---
    model.train()
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(imgs), labels)  # We train on the LIE (noisy label)
        loss.backward()
        optimizer.step()

    # --- EVALUATE (The "Probe") ---
    model.eval()

    # A. Clean Test Accuracy (generalization)
    clean_acc = _accuracy(model, test_loader)
    history['clean_test_acc'].append(clean_acc)

    # B. On the poisoned data:
    #    memorization = model predicts the NOISY label y_tilde (fitting the lie)
    #    learning     = model predicts the hidden TRUE label y (ignoring the lie)
    mem_correct = 0
    learn_correct = 0
    total_noisy = 0
    with torch.no_grad():
        for imgs, noisy_labels in noisy_loader:
            imgs, noisy_labels = imgs.to(device), noisy_labels.to(device)
            pred = model(imgs).argmax(dim=1)
            n = noisy_labels.size(0)
            # Running offset into y_true_noisy; relies on noisy_loader being unshuffled.
            true_labels = y_true_noisy[total_noisy:total_noisy + n]
            mem_correct += (pred == noisy_labels).sum().item()
            learn_correct += (pred == true_labels).sum().item()
            total_noisy += n

    if total_noisy:
        mem_ratio = 100 * mem_correct / total_noisy
        learn_ratio = 100 * learn_correct / total_noisy
    else:
        # Empty noisy subset (e.g. NOISE_RATE = 0): nothing to probe, avoid ZeroDivisionError.
        mem_ratio = learn_ratio = float('nan')
    history['memorization_ratio'].append(mem_ratio)
    history['learning_ratio'].append(learn_ratio)

    print(
        f"Epoch {epoch+1} | Clean Acc: {clean_acc:.1f}% | "
        f"Memorization: {mem_ratio:.1f}% | Learning: {learn_ratio:.1f}%"
    )

# 5. PLOT THE CURVES
# Metrics are recorded after each epoch, so plot them against epochs 1..N (matching
# the "Epoch k" log lines) instead of matplotlib's default 0-based list index.
n_epochs_logged = len(history['clean_test_acc'])
epochs_axis = range(1, n_epochs_logged + 1)

plt.figure(figsize=(10, 6))
plt.plot(epochs_axis, history['clean_test_acc'],
         label='Clean Test Acc (Generalization)', color='green')
plt.plot(epochs_axis, history['memorization_ratio'],
         label='Memorization Ratio (Fitting Noise)', color='red', linestyle='--')
learning = history.get('learning_ratio', [])
if learning and len(learning) == n_epochs_logged:
    # Drawn only when the training loop recorded the true-label metric every epoch.
    plt.plot(epochs_axis, learning,
             label='Learning Ratio (True Label on Noisy Data)', color='blue', linestyle=':')
plt.xlabel('Epochs')
plt.ylabel('Percentage')
plt.title(f'The Learning Timeline ({NOISE_TYPE.capitalize()} Noise)')
plt.legend()
plt.grid(True)
plt.show()

Loading CIFAR-10 data...
Data Loaded: 50000 training images, 10000 test images
Injecting 20.0% symmetric noise...
Starting Timeline Analysis...
Epoch 1 | Clean Acc: 46.9% | Memorization: 6.1%
Epoch 2 | Clean Acc: 50.6% | Memorization: 5.6%
Epoch 3 | Clean Acc: 62.9% | Memorization: 4.6%
Epoch 4 | Clean Acc: 62.9% | Memorization: 4.4%
Epoch 5 | Clean Acc: 70.1% | Memorization: 3.8%
Epoch 6 | Clean Acc: 71.7% | Memorization: 3.6%
Epoch 7 | Clean Acc: 75.2% | Memorization: 3.3%
