In [None]:
# ==============================================================================
# Core ML Task: Robust Loss Functions for Noisy Labels (CIFAR-10)
# ==============================================================================

# --- Standard Libraries ---
import math
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- PyTorch Libraries ---
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import datasets
from torchvision import transforms

# --- Environment Setup ---
# Detect Google Colab: `google.colab.drive` is only importable there, and it
# is needed later to mount Google Drive when results are saved.
COLAB_ENV = False
try:
    from google.colab import drive
except ImportError:
    pass
else:
    COLAB_ENV = True

# Report the library versions and the detected runtime environment.
for _label, _value in (
    ("PyTorch Version", torch.__version__),
    ("Torchvision Version", torchvision.__version__),
    ("Running in Colab", COLAB_ENV),
):
    print(f"{_label}: {_value}")


# ==============================================================================
# Configuration
# ==============================================================================

# --- Training Parameters ---
EPOCHS = 50            # Number of epochs for each training run
LEARNING_RATE = 0.1    # Initial learning rate for SGD
BATCH_SIZE = 128       # Training batch size
SEED = 42              # Random seed for reproducibility
WEIGHT_DECAY = 5e-4    # Weight decay for the SGD optimizer

# --- Experiment Selection ---
RUN_SYMMETRIC = True    # Run the symmetric-noise experiments
RUN_ASYMMETRIC = False  # Run the asymmetric-noise experiments (bonus)

# Keys of `all_losses` (built in the experiment-setup section) to run.
# None runs every defined loss.
# Example: LOSSES_TO_RUN = ['CE', 'NCE (s=1)', 'APL (NCE+MAE)']
LOSSES_TO_RUN = ['CE', 'NCE (s=1)', 'APL (NCE+MAE)', 'APL (NCE+RCE)']

# Noise rates (eta) swept for each noise type.
SYMMETRIC_NOISE_RATES = [0.0, 0.2, 0.4, 0.6, 0.8]
ASYMMETRIC_NOISE_RATES = [0.0, 0.1, 0.2, 0.3, 0.4]

# --- Model Selection ---
# The training function picks SimpleCNN if USE_SIMPLE_CNN, otherwise ResNet9
# if USE_RESNET9, otherwise ResNet18.
USE_SIMPLE_CNN = False
USE_RESNET9 = True

# --- Results Saving ---
SAVE_RESULTS = False  # True enables periodic checkpointing of results
RESULTS_DIR = 'CoreML_Results_ResNet9'  # local default location
if SAVE_RESULTS:
    if COLAB_ENV:
        # On Colab, persist results to Google Drive instead of the VM disk.
        RESULTS_DIR = '/content/drive/MyDrive/CoreML_Results_ResNet9'
        try:
            print("Attempting to mount Google Drive...")
            drive.mount('/content/drive')
            os.makedirs(RESULTS_DIR, exist_ok=True)
            print(f"Google Drive mounted. Results directory: {RESULTS_DIR}")
        except Exception as e:
            print(f"WARN: Drive mount/creation failed: {e}. Disabling saving.")
            SAVE_RESULTS = False
    else:
        # Saving locally: make sure the directory exists up front.
        os.makedirs(RESULTS_DIR, exist_ok=True)
        print(f"Results directory (local): {RESULTS_DIR}")


# ==============================================================================
# 1. Data Preparation
# ==============================================================================

print("\nSetting up data transformations...")
# Per-channel normalization statistics shared by the train and test pipelines.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training transforms add augmentation (random padded crop + horizontal flip).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Test transforms only convert to tensor and normalize.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

print("Loading CIFAR-10 dataset...")
try:
    # Raw (untransformed) training data: the training transform is applied by
    # the NoisyCIFAR10 wrapper that relabels this dataset per experiment.
    train_data_clean = datasets.CIFAR10(
        root="data",
        train=True,
        download=True,
        transform=None
    )
    # Test data is used as-is with the test transforms and clean labels.
    test_data = datasets.CIFAR10(
        root="data",
        train=False,
        download=True,
        transform=transform_test
    )
    class_names = train_data_clean.classes
    num_classes = len(class_names)
    print(
        f"CIFAR-10 loaded: {len(train_data_clean)} train, "
        f"{len(test_data)} test images."
    )
except Exception as e:
    print(f"FATAL: Error loading CIFAR-10 dataset: {e}")
    # Raise SystemExit rather than calling the site-provided `exit()`: under
    # IPython/Jupyter/Colab `exit()` only requests a kernel shutdown and does
    # not stop the running cell, so execution would continue with undefined
    # dataset names. In a plain script both exit with status 1.
    raise SystemExit(1) from e


# --- Noisy Dataset Class ---
class NoisyCIFAR10(Dataset):
    """
    A Dataset wrapper for CIFAR-10 that injects label noise.

    Labels are corrupted once, at construction time, with a private
    ``np.random.RandomState`` seeded with ``seed + int(eta * 100)``. The same
    (seed, eta) therefore always yields the same noisy labels. The global
    NumPy RNG is left untouched, so building a dataset does not perturb other
    NumPy-random code.

    Args:
        dataset: The original dataset. It must expose ``targets`` (or
            ``samples``, a sequence of ``(item, label)`` pairs) and return
            ``(img, label)`` when indexed.
        noise_type (str): 'symmetric' flips a label to a uniformly chosen
            *different* class. 'asymmetric' applies class-dependent flips
            (see ``_ASYM_NOISE_MAP``).
        eta (float): Probability that an eligible label is flipped.
        num_classes (int): Number of classes in the dataset.
        transform (callable, optional): Transform applied to each image.
        seed (int): Base random seed for noise generation.
    """

    # Asymmetric flips as source -> target class index, using torchvision's
    # CIFAR-10 ordering: truck->automobile, bird->airplane, deer->horse,
    # cat<->dog.
    _ASYM_NOISE_MAP = {9: 1, 2: 0, 4: 7, 3: 5, 5: 3}

    def __init__(
        self, dataset, noise_type='symmetric', eta=0.0,
        num_classes=10, transform=None, seed=42
    ):
        super().__init__()
        self.dataset = dataset
        self.noise_type = noise_type
        self.eta = eta
        self.num_classes = num_classes
        self.transform = transform
        self.seed = seed

        try:
            original_labels = np.array(dataset.targets)
        except AttributeError:
            original_labels = np.array([s[1] for s in dataset.samples])

        self.original_labels = original_labels.copy()
        self.noisy_labels = self._create_noisy_labels()

    def _create_noisy_labels(self):
        """Return a noisy copy of ``self.original_labels``."""
        labels = self.original_labels.copy()
        # A private generator produces exactly the draws the former global
        # `np.random.seed(...)` call did, without resetting global NumPy state.
        rng = np.random.RandomState(self.seed + int(self.eta * 100))

        if self.eta > 0:
            if self.noise_type == 'symmetric':
                self._add_symmetric_noise(labels, rng)
            elif self.noise_type == 'asymmetric':
                self._add_asymmetric_noise(labels, rng)
            else:
                print(f"Warning: Unknown noise_type '{self.noise_type}'.")
        return labels

    def _add_symmetric_noise(self, labels, rng=None):
        """Flip each label with prob. ``eta`` to a different class (in place).

        Args:
            labels: Integer label array, modified in place.
            rng: Random source with ``rand``/``choice``. Defaults to the global
                ``np.random`` module for backward compatibility.
        """
        rng = np.random if rng is None else rng
        mask = rng.rand(len(labels)) < self.eta
        for i in np.flatnonzero(mask):
            candidates = [c for c in range(self.num_classes) if c != labels[i]]
            if candidates:
                labels[i] = rng.choice(candidates)

    def _add_asymmetric_noise(self, labels, rng=None):
        """Flip source-class labels with prob. ``eta`` via ``_ASYM_NOISE_MAP``.

        Eligibility is decided from the *original* labels. One uniform draw is
        consumed per sample, eligible or not, so the random stream stays
        aligned with the sample index.

        Args:
            labels: Integer label array, modified in place.
            rng: Random source with ``rand``. Defaults to ``np.random``.
        """
        rng = np.random if rng is None else rng
        eligible = np.isin(self.original_labels, list(self._ASYM_NOISE_MAP))
        mask = rng.rand(len(labels)) < self.eta
        for i in np.flatnonzero(eligible & mask):
            current_label = int(labels[i])
            if current_label in self._ASYM_NOISE_MAP:
                labels[i] = self._ASYM_NOISE_MAP[current_label]

    def __getitem__(self, index):
        """Return ``(image, noisy_label)``; the label is a ``torch.long`` tensor."""
        img, _ = self.dataset[index]
        label = self.noisy_labels[index]
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label).long()

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)


# ==============================================================================
# 2. Loss Functions
# ==============================================================================
print("\nDefining Loss Functions...")

class CrossEntropyLoss(nn.Module):
    """Mean-reduced cross entropy behind the ``forward(logits, labels)`` API.

    Wraps ``nn.CrossEntropyLoss()`` with default arguments so plain CE can be
    used interchangeably with the custom losses in this section.
    """

    def __init__(self):
        super().__init__()
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Return the mean cross entropy of ``logits`` w.r.t. integer ``labels``."""
        criterion = self.cross_entropy
        loss_value = criterion(logits, labels)
        return loss_value


class FocalLoss(nn.Module):
    """Focal loss: ``alpha * (1 - p_t) ** gamma * CE`` per sample.

    Args:
        alpha: Constant multiplier applied to every sample's loss.
        gamma: Focusing exponent; ``gamma=0`` gives ``alpha``-scaled CE.
        reduction: 'mean', 'sum' or 'none' (per-sample losses).

    Raises:
        ValueError: If ``reduction`` is not a supported mode.
    """

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        valid_modes = ['mean', 'sum', 'none']
        if reduction not in valid_modes:
            raise ValueError(f"Invalid reduction type: {reduction}")
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, logits, labels):
        per_sample_ce = F.cross_entropy(logits, labels, reduction='none')
        # CE = -log(p_t), so the target-class probability is exp(-CE).
        p_true = (-per_sample_ce).exp()
        weighted = self.alpha * (1 - p_true) ** self.gamma * per_sample_ce

        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'sum':
            return weighted.sum()
        return weighted


class NormalizedCrossEntropy(nn.Module):
    """Normalized Cross Entropy (NCE), averaged over the batch.

    Per sample::

        NCE = -log p_y / sum_j -log p_j,   p = softmax(logits / scale)

    Note: ``scale`` divides the logits (a softmax temperature); it does not
    multiply the loss value. ``num_classes`` is stored but not used by
    ``forward``.
    """

    def __init__(self, num_classes=10, scale=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.scale = scale

    def forward(self, logits, labels):
        neg_log_probs = -F.log_softmax(logits / self.scale, dim=1)
        target_term = neg_log_probs.gather(1, labels.unsqueeze(1)).squeeze(1)
        # Guard against division by ~0.
        normaliser = neg_log_probs.sum(dim=1).clamp(min=1e-6)
        return (target_term / normaliser).mean()


class NormalizedFocalLoss(nn.Module):
    """Normalized Focal Loss (NFL).

    Per sample::

        NFL = alpha * FL_y / sum_j FL_j,   FL_j = (1 - p_j) ** gamma * (-log p_j)

    with ``p = softmax(logits / scale)``. The numerator is one of the
    non-negative summands of the denominator, so for ``alpha >= 0`` the
    per-sample loss lies in ``[0, alpha]``.

    Args:
        num_classes: Number of classes (stored; not used by ``forward``).
        alpha: Multiplier applied to the normalized loss.
        gamma: Focusing exponent.
        scale: Softmax temperature; logits are divided by it.
        reduction: 'mean', 'sum' or 'none' (per-sample losses).

    Raises:
        ValueError: If ``reduction`` is not a supported mode.
    """
    def __init__(
        self, num_classes=10, alpha=1.0, gamma=2.0, scale=1.0,
        reduction='mean'
    ):
        super().__init__()
        if reduction not in ['mean', 'sum', 'none']:
            raise ValueError(f"Invalid reduction type: {reduction}")
        self.alpha = alpha
        self.gamma = gamma
        self.scale = scale
        self.num_classes = num_classes
        self.reduction = reduction

    def forward(self, logits, labels):
        # Every per-class focal term comes from log_softmax, so the target
        # term (numerator) is exactly one summand of the denominator.
        # Previously the numerator used log() of a probability clamped to
        # [1e-6, 1 - 1e-6]. Outside that range the clamp zeroed the
        # numerator's gradient, and the denominator's gradient alone then
        # pushed p_y further *down* on confidently-wrong samples.
        log_probs = F.log_softmax(logits / self.scale, dim=1)
        probs = log_probs.exp()
        focal_terms = (1 - probs) ** self.gamma * (-log_probs)

        numerator = focal_terms.gather(1, labels.unsqueeze(1)).squeeze(1)
        denominator = focal_terms.sum(dim=1).clamp(min=1e-6)
        loss = self.alpha * numerator / denominator

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else: # 'none'
            return loss


class MeanAbsoluteErrorLoss(nn.Module):
    """Mean Absolute Error between softmax output and one-hot label (passive).

    Per sample this is ``sum_k |p_k - y_k| / num_classes``, which equals
    ``2 * (1 - p_y) / num_classes``. The result is averaged over the batch.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, logits, labels):
        probs = F.softmax(logits, dim=1)
        target = F.one_hot(labels, num_classes=self.num_classes).float()
        l1_per_sample = (probs - target).abs().sum(dim=1)
        normalised = l1_per_sample / self.num_classes
        return normalised.mean()


class ReverseCrossEntropy(nn.Module):
    """Reverse Cross Entropy Loss (Passive Loss).

    RCE is cross entropy with prediction and target swapped::

        RCE = -sum_k p_k * log(q_k)

    Here ``p = softmax(logits)`` and ``q`` is the one-hot label. Because
    ``log 0`` is undefined, ``q`` is clamped to ``[label_floor, 1]``. The
    loss is then ``-log(label_floor) * sum_{k != y} p_k``, i.e. roughly
    ``-log(label_floor) * (1 - p_y)``: bounded and linear in ``p_y``.

    Args:
        num_classes: Number of classes.
        label_floor: Lower clamp applied to the one-hot target (default 1e-4).
    """
    def __init__(self, num_classes=10, label_floor=1e-4):
        super().__init__()
        self.num_classes = num_classes
        self.label_floor = label_floor

    def forward(self, logits, labels):
        # The previous body computed sum_{k != y} -log(1 - p_k). That is a
        # complementary-label loss, unbounded as any p_k -> 1, not RCE.
        probs = F.softmax(logits, dim=1)
        probs = torch.clamp(probs, min=1e-7, max=1.0)
        labels_one_hot = F.one_hot(
            labels, num_classes=self.num_classes
        ).float()
        # Clamp zeros to label_floor so log() is finite. The true class
        # contributes p_y * log(1) = 0.
        log_target = torch.log(
            torch.clamp(labels_one_hot, min=self.label_floor, max=1.0)
        )
        rce_per_sample = -(probs * log_target).sum(dim=1)
        return rce_per_sample.mean()


class APL(nn.Module):
    """Active-Passive Loss: ``alpha * active_loss + beta * passive_loss``.

    Args:
        active_loss: Criterion called as ``active_loss(logits, labels)``.
        passive_loss: Criterion called as ``passive_loss(logits, labels)``.
        alpha: Weight of the active term.
        beta: Weight of the passive term.

    Returns (from ``forward``):
        The weighted sum. If either component is NaN/Inf, a warning is
        printed and the (then non-finite) weighted sum is returned unchanged.
    """
    def __init__(
        self, active_loss: nn.Module, passive_loss: nn.Module,
        alpha: float = 1.0, beta: float = 1.0
    ):
        super().__init__()
        self.active_loss = active_loss
        self.passive_loss = passive_loss
        self.alpha = alpha
        self.beta = beta

    def forward(self, logits, labels):
        active_l = self.active_loss(logits, labels)
        passive_l = self.passive_loss(logits, labels)
        total = self.alpha * active_l + self.beta * passive_l

        if not (torch.isfinite(active_l).all() and torch.isfinite(passive_l).all()):
            print(
                f"WARN: NaN/Inf in APL! A:{active_l.item():.4f}, "
                f"P:{passive_l.item():.4f}."
            )
            # Propagate the non-finite value so callers that guard against
            # NaN/Inf (as this file's training loop does) can skip the batch.
            # The former constant 1000.0 sentinel was detached from the model,
            # so it yielded no gradients. It also passed such guards and was
            # still added to logged losses and accuracy counts.
            return total

        return total


# ==============================================================================
# 3. Model Architectures
# ==============================================================================
print("\nDefining Model Architectures...")

# --- Simple CNN ---
class SimpleCNN(nn.Module):
    """Three conv stages (3x3 conv + ReLU + 2x2 max-pool) and a 2-layer MLP head.

    Each stage halves the spatial size. The head's first Linear expects
    128 x 4 x 4 features, i.e. 32x32 inputs.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        stages = []
        # Modules are created in the same order as an explicit listing, so
        # parameter initialisation consumes the RNG identically.
        for in_ch, out_ch in ((3, 32), (32, 64), (64, 128)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=False),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(inplace=False),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


# --- ResNet Components ---
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with ``padding=1``.

    With ``stride=1`` the spatial size is preserved.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)

class BasicBlock(nn.Module): # Used in ResNet18
    """Two-conv residual block: conv-BN-ReLU, conv-BN, add shortcut, ReLU.

    When the block changes stride or width, the shortcut is a 1x1 strided
    conv + BN projection; otherwise it is an empty ``nn.Sequential``
    (identity).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))


class ResNet(nn.Module): # Base class for ResNet architectures
    """CIFAR-style ResNet: stride-1 3x3 stem, four stages, avg-pool, linear head.

    Args:
        block: Residual block class with an ``expansion`` attribute, built as
            ``block(in_planes, planes, stride)``.
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(
            3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        # Stages layer1..layer4: width doubles while stride 2 halves resolution.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, first_stride) in enumerate(stage_specs, start=1):
            setattr(
                self, f"layer{idx}",
                self._make_layer(block, width, layers[idx - 1], stride=first_stride),
            )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        # Must run after all layers exist so every conv/norm is initialised.
        self._initialize_weights()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks; only the first uses ``stride``. Updates ``in_planes``."""
        layer_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in layer_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def _initialize_weights(self):
        """Kaiming-normal conv weights; unit-scale, zero-shift normalisation layers."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu'
                )
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = torch.flatten(self.avgpool(out), 1)
        return self.fc(pooled)


def ResNet18(num_classes=10):
    """Build ResNet-18 for CIFAR-10: BasicBlock, two blocks in each of four stages."""
    blocks_per_stage = [2] * 4
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)


# --- ResNet9 Implementation (Using Functional ReLU) ---
def conv_bn(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Return ``nn.Sequential(Conv2d(bias=False), BatchNorm2d)``.

    The conv has no bias because the following BatchNorm supplies the shift.
    """
    conv = nn.Conv2d(
        in_channels, out_channels,
        kernel_size=kernel_size, stride=stride, padding=padding, bias=False,
    )
    norm = nn.BatchNorm2d(out_channels)
    return nn.Sequential(conv, norm)

class ResNet9Block(nn.Module):
    """ResNet9 residual block: two conv-BN layers with an identity shortcut.

    Channels and spatial size are preserved (``conv_bn`` defaults: 3x3,
    stride 1, padding 1), so the input can be added directly.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = conv_bn(channels, channels)
        self.conv2 = conv_bn(channels, channels)

    def forward(self, x):
        branch = self.conv2(F.relu(self.conv1(x)))
        # Shortcut is added before the block's final ReLU.
        return F.relu(branch + x)

class ResNet9(nn.Module):
    """ResNet9 for CIFAR-10: conv-BN stem, three widen/pool/residual stages, head.

    Submodule names (``prep``, ``layerN_conv``, ``layerN_pool``,
    ``layerN_res``, ``pool``, ``flat``, ``fc``) define the state-dict keys.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.prep = conv_bn(3, 64)  # stem: 3 -> 64 channels

        # Each stage widens (conv-BN), halves resolution (2x2 max-pool) and
        # refines (residual block): 32x32 -> 16x16 -> 8x8 -> 4x4.
        stage_channels = ((64, 128), (128, 256), (256, 512))
        for idx, (c_in, c_out) in enumerate(stage_channels, start=1):
            setattr(self, f"layer{idx}_conv", conv_bn(c_in, c_out))
            setattr(self, f"layer{idx}_pool", nn.MaxPool2d(2))
            setattr(self, f"layer{idx}_res", ResNet9Block(c_out))

        # Head: 4x4 max-pool to 1x1, flatten, linear classifier.
        self.pool = nn.MaxPool2d(4)
        self.flat = nn.Flatten()
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = F.relu(self.prep(x))
        stages = (
            (self.layer1_conv, self.layer1_pool, self.layer1_res),
            (self.layer2_conv, self.layer2_pool, self.layer2_res),
            (self.layer3_conv, self.layer3_pool, self.layer3_res),
        )
        for conv, pool, res in stages:
            # ReLU after max-pool equals max-pool after ReLU (both monotone).
            x = res(F.relu(pool(conv(x))))
        return self.fc(self.flat(self.pool(x)))

print("Model definitions complete.")


# ==============================================================================
# 4. Training and Evaluation Loop
# ==============================================================================
print("\nDefining Training and Evaluation Function...")

def train_eval_model(
    loss_fn_instance: nn.Module,
    loss_name: str,
    noise_type: str = 'symmetric',
    eta: float = 0.2,
    epochs: int = EPOCHS,
    lr: float = LEARNING_RATE,
    batch_size: int = BATCH_SIZE,
    seed: int = SEED
) -> float:
    """Train and evaluate a model for one experimental configuration.

    Builds a NoisyCIFAR10 training set with the requested label noise.
    Instantiates the architecture chosen by the module-level flags
    (USE_SIMPLE_CNN, then USE_RESNET9, otherwise ResNet18). Trains with SGD
    (momentum 0.9, WEIGHT_DECAY) under a cosine-annealed learning rate and
    evaluates on the clean test set after every epoch.

    Args:
        loss_fn_instance: Criterion called as ``criterion(logits, labels)``;
            it is moved to the training device.
        loss_name: Display name used only in log messages.
        noise_type: 'symmetric' or 'asymmetric' (forwarded to NoisyCIFAR10).
        eta: Label-noise rate.
        epochs: Number of training epochs (also the cosine schedule length).
        lr: Initial SGD learning rate.
        batch_size: Training batch size (evaluation uses twice this).
        seed: Seed for torch/NumPy and for label-noise generation.
        (Default values are the module-level constants as they were when this
        function was defined.)

    Returns:
        The highest per-epoch test accuracy (percent) seen during training.
        NOTE(review): the epoch is effectively selected on the test set, so
        this is an optimistic estimate; consider also reporting final-epoch
        accuracy.
    """
    print(
        f"\n--- Training Start: {loss_name} | "
        f"Noise: {noise_type}@{eta*100:.0f}% | Epochs: {epochs} ---"
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Seed before building the dataset, loaders and model so that weight
    # initialisation and shuffling are repeatable for a given `seed`.
    # Statement order below therefore matters for reproducibility.
    torch.manual_seed(seed)
    np.random.seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed_all(seed)

    # Training set with injected label noise; the test set keeps clean labels.
    noisy_trainset = NoisyCIFAR10(
        train_data_clean, noise_type=noise_type, eta=eta,
        num_classes=num_classes, transform=transform_train, seed=seed
    )
    # drop_last=True: every optimisation step sees a full batch.
    trainloader = DataLoader(
        noisy_trainset, batch_size=batch_size, shuffle=True,
        num_workers=2, pin_memory=(device.type == 'cuda'), drop_last=True
    )
    # Evaluation uses twice the training batch size and a fixed order.
    testloader = DataLoader(
        test_data, batch_size=batch_size * 2, shuffle=False,
        num_workers=2, pin_memory=(device.type == 'cuda')
    )

    # --- Model Selection ---
    # Precedence: USE_SIMPLE_CNN, then USE_RESNET9, otherwise ResNet18.
    if USE_SIMPLE_CNN:
        model = SimpleCNN(num_classes=num_classes).to(device)
    elif USE_RESNET9:
        model = ResNet9(num_classes=num_classes).to(device)
    else:
        model = ResNet18(num_classes=num_classes).to(device)
    print(f"Using Model: {type(model).__name__}")

    # --- Optimizer, Scheduler, Loss ---
    optimizer = optim.SGD(
        model.parameters(), lr=lr, momentum=0.9, weight_decay=WEIGHT_DECAY
    )
    # Cosine decay over the whole run; stepped once per epoch (end of loop).
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = loss_fn_instance.to(device)

    best_test_acc = 0.0
    start_time = time.time()
    print(
        f"Optimizer: SGD(lr={lr:.1e}, momentum=0.9, wd={WEIGHT_DECAY:.1e})"
    )
    print(f"Scheduler: CosineAnnealingLR(T_max={epochs})")

    # --- Training Loop ---
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        num_batches = len(trainloader)

        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Non-finite loss: skip this update entirely. Skipped batches add
            # nothing to running_loss or train accuracy, but they still count
            # in num_batches, so the logged epoch loss is averaged over them too.
            if torch.isnan(loss) or torch.isinf(loss):
                print(
                    f"WARN - E{epoch+1:02d} B{i+1}/{num_batches}: "
                    f"NaN/Inf loss ({loss.item():.4f})! Skipping batch."
                )
                continue

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Train accuracy is measured against the (possibly noisy) labels.
            with torch.no_grad():
                 _, predicted_train = torch.max(outputs.data, 1)
                 total_train += labels.size(0)
                 correct_train += (predicted_train == labels).sum().item()

        # --- Evaluation Step ---
        # Clean-label test accuracy with the model in eval mode.
        model.eval()
        correct_test = 0
        total_test = 0
        with torch.no_grad():
            for inputs, labels in testloader:
                 inputs, labels = inputs.to(device), labels.to(device)
                 outputs = model(inputs)
                 _, predicted = torch.max(outputs.data, 1)
                 total_test += labels.size(0)
                 correct_test += (predicted == labels).sum().item()

        # --- Log Epoch Metrics ---
        epoch_train_loss = running_loss / num_batches if num_batches > 0 else 0
        epoch_train_acc = 100.0 * correct_train / total_train if total_train > 0 else 0
        epoch_test_acc = 100.0 * correct_test / total_test if total_test > 0 else 0
        best_test_acc = max(best_test_acc, epoch_test_acc)
        # Read before scheduler.step(): this is the LR used during this epoch.
        current_lr = scheduler.get_last_lr()[0]

        print(
            f"E {epoch+1:02d}/{epochs} | L:{epoch_train_loss:.4f} | "
            f"TrA:{epoch_train_acc:6.2f}% | TeA:{epoch_test_acc:6.2f}% "
            f"(B:{best_test_acc:6.2f}%) | LR:{current_lr:.6f}"
        )
        scheduler.step()

    # --- End of Training ---
    end_time = time.time()
    elapsed_time_min = (end_time - start_time) / 60
    print(
        f"--- Training End: {loss_name} (Eta:{eta:.2f}). "
        f"Best Test Acc: {best_test_acc:.2f}%. "
        f"Duration: {elapsed_time_min:.2f} min ---"
    )
    return best_test_acc

print("Training function defined.")


# ==============================================================================
# 5. Experiment Setup
# ==============================================================================
print("\n" + "="*20 + " Experiment Configuration " + "="*20)

# Registry of every available criterion, keyed by the display name used in
# logs, plots and tables. LOSSES_TO_RUN selects from these keys.
all_losses = {
    'CE': CrossEntropyLoss(),
    'FL (g=1)': FocalLoss(gamma=1.0),
    'FL (g=2)': FocalLoss(gamma=2.0),
    'NCE (s=1)': NormalizedCrossEntropy(num_classes=num_classes, scale=1.0),
    'NFL (g=1,s=1)': NormalizedFocalLoss(
        num_classes, alpha=1.0, gamma=1.0, scale=1.0
    ),
    'NFL (g=2,s=1)': NormalizedFocalLoss(
        num_classes, alpha=1.0, gamma=2.0, scale=1.0
    ),
    'MAE': MeanAbsoluteErrorLoss(num_classes=num_classes),
    'RCE': ReverseCrossEntropy(num_classes=num_classes),
    'APL (NCE+MAE)': APL(
        NormalizedCrossEntropy(num_classes, scale=1.0),
        MeanAbsoluteErrorLoss(num_classes),
        alpha=1.0, beta=1.0,
    ),
    'APL (NCE+RCE)': APL(
        NormalizedCrossEntropy(num_classes, scale=1.0),
        ReverseCrossEntropy(num_classes),
        alpha=1.0, beta=0.5,
    ),
    'APL (NFLg1+MAE)': APL(
        NormalizedFocalLoss(num_classes, gamma=1.0, scale=1.0),
        MeanAbsoluteErrorLoss(num_classes),
        alpha=1.0, beta=1.0,
    ),
}

if LOSSES_TO_RUN is None:
    losses_to_test = all_losses
    print("Running ALL defined loss functions.")
else:
    losses_to_test = {}
    for name in LOSSES_TO_RUN:
        if name in all_losses:
            losses_to_test[name] = all_losses[name]
        else:
            print(f"Warning: Loss '{name}' not found. Skipping.")
    if not losses_to_test:
        print("FATAL: No valid losses selected.")
        # Raise SystemExit instead of calling `exit()`: under IPython/Colab,
        # `exit()` does not stop the currently running cell.
        raise SystemExit(1)
    print(f"Running SELECTED losses: {list(losses_to_test.keys())}")

# Same precedence as the model selection in `train_eval_model`
# (SimpleCNN, then ResNet9, otherwise ResNet18). The previous expression
# tested USE_RESNET9 first and mislabelled runs when both flags were True.
if USE_SIMPLE_CNN:
    model_name_str = "SimpleCNN"
elif USE_RESNET9:
    model_name_str = "ResNet9"
else:
    model_name_str = "ResNet18"
print(f"Model Architecture: {model_name_str}")
print(f"Epochs per run: {EPOCHS}")
print(f"Symmetric Noise Rates: {SYMMETRIC_NOISE_RATES}")
if RUN_ASYMMETRIC:
    print(f"Asymmetric Noise Rates: {ASYMMETRIC_NOISE_RATES}")
print(f"Saving Results: {SAVE_RESULTS}" + (f" to {RESULTS_DIR}" if SAVE_RESULTS else ""))
print("="*60)


# ==============================================================================
# 6. Run Experiments
# ==============================================================================
print("\n *** Starting Experiment Execution ***\n")
results_symmetric = {}
results_asymmetric = {}


def _load_saved_results(path, tag):
    """Return results previously saved at ``path``, or ``{}`` if absent/unreadable."""
    if not os.path.exists(path):
        return {}
    try:
        loaded = torch.load(path)
        print(f"Loaded {tag} results from {path}")
        return loaded
    except Exception as e:
        print(f"WARN: Err load {tag}:{e}")
        return {}


if SAVE_RESULTS:
    # Resume from earlier checkpoints: losses with complete results are skipped.
    sym_file = os.path.join(RESULTS_DIR, 'results_symmetric.pth')
    asym_file = os.path.join(RESULTS_DIR, 'results_asymmetric.pth')
    results_symmetric = _load_saved_results(sym_file, "sym")
    results_asymmetric = _load_saved_results(asym_file, "asym")


def save_experiment_results():
    """Checkpoint the results dicts to RESULTS_DIR when SAVE_RESULTS is enabled.

    Failures are reported, not raised, so a failed save never aborts a run.
    """
    if not SAVE_RESULTS:
        return
    try:
        print("\nCheckpoint: Saving results...")
        os.makedirs(RESULTS_DIR, exist_ok=True)
        sym_path = os.path.join(RESULTS_DIR, 'results_symmetric.pth')
        asym_path = os.path.join(RESULTS_DIR, 'results_asymmetric.pth')
        torch.save(results_symmetric, sym_path)
        if RUN_ASYMMETRIC:
            torch.save(results_asymmetric, asym_path)
        print("Checkpoint saved.")
    except Exception as e:
        print(f"WARN: Err saving results checkpoint: {e}")


def _run_noise_sweep(results, noise_rates, noise_type, header, tag):
    """Train every selected loss across ``noise_rates``, storing into ``results``.

    A loss is skipped when ``results`` already holds a complete list for it.
    Results for a loss are stored (and checkpointed) only after its whole
    sweep finishes.
    """
    print(f"\n=== Processing {header} Noise Experiments ===")
    for loss_name, loss_instance in losses_to_test.items():
        existing = results.get(loss_name)
        if isinstance(existing, list) and len(existing) == len(noise_rates):
            print(f" -> Skip {loss_name}({tag})-Exist")
            continue
        print(f" --> Running {loss_name}({tag})...")
        accuracies = [
            train_eval_model(loss_instance, loss_name, noise_type, eta)
            for eta in noise_rates
        ]
        results[loss_name] = accuracies
        save_experiment_results()


try:
    if RUN_SYMMETRIC:
        _run_noise_sweep(
            results_symmetric, SYMMETRIC_NOISE_RATES,
            'symmetric', "Symmetric", "Sym",
        )
    if RUN_ASYMMETRIC:
        _run_noise_sweep(
            results_asymmetric, ASYMMETRIC_NOISE_RATES,
            'asymmetric', "Asymmetric", "Asym",
        )
finally:
    # Runs on normal completion and on interruption (e.g. KeyboardInterrupt).
    print("\nExperiment execution phase finished or interrupted.")
    save_experiment_results()


# ==============================================================================
# 7. Plot Results
# ==============================================================================
print("\nGenerating Plots...")

# Marker/colour/linestyle per loss name; APL combinations are drawn dashed.
loss_styles = {
    'CE': {'marker': 'o', 'color': 'C0', 'linestyle': '-'},
    'FL (g=1)': {'marker': 'P', 'color': 'C1', 'linestyle': '-'},
    'FL (g=2)': {'marker': 'X', 'color': 'C2', 'linestyle': '-'},
    'NCE (s=1)': {'marker': 's', 'color': 'C3', 'linestyle': '-'},
    'NFL (g=1,s=1)': {'marker': 'd', 'color': 'C4', 'linestyle': '-'},
    'NFL (g=2,s=1)': {'marker': 'D', 'color': 'C5', 'linestyle': '-'},
    'MAE': {'marker': '^', 'color': 'C6', 'linestyle': '-'},
    'RCE': {'marker': 'v', 'color': 'C7', 'linestyle': '-'},
    'APL (NCE+MAE)': {'marker': 's', 'color': 'C3', 'linestyle': '--'},
    'APL (NCE+RCE)': {'marker': 'v', 'color': 'C7', 'linestyle': '--'},
    'APL (NFLg1+MAE)': {'marker': 'd', 'color': 'C4', 'linestyle': '--'},
}
# Fallback style for losses without an entry above.
default_style = {'marker': '*', 'color': 'black', 'linestyle': ':'}


def plot_results(results, noise_rates, title_suffix):
    """Plot best test accuracy vs. noise rate for each complete result series.

    Args:
        results: Mapping of loss name -> list of accuracies (one per rate).
        noise_rates: Noise rates for the x-axis; series of any other length
            are skipped.
        title_suffix: Noise-type label used in the title, e.g. "Symmetric".
    """
    valid_losses = [
        name for name, data in results.items()
        if data and isinstance(data, list) and len(data) == len(noise_rates)
    ]
    if not valid_losses:
        print(f"No plottable results for {title_suffix}.")
        return

    plt.figure(figsize=(12, 8))
    # Same precedence as the model selection in `train_eval_model`
    # (SimpleCNN, then ResNet9, otherwise ResNet18). The previous expression
    # tested USE_RESNET9 first and mislabelled plots when both flags were set.
    if USE_SIMPLE_CNN:
        model_name_plot = "SimpleCNN"
    elif USE_RESNET9:
        model_name_plot = "ResNet9"
    else:
        model_name_plot = "ResNet18"

    for loss_name in valid_losses:
        style = loss_styles.get(loss_name, default_style)
        plt.plot(
            noise_rates, results[loss_name], label=loss_name,
            marker=style.get('marker', '*'), color=style.get('color'),
            linestyle=style.get('linestyle', ':'),
        )

    plt.xlabel('Noise Rate (η)')
    plt.ylabel('Best Test Accuracy (%)')
    plt.title(f'Model Perf Under {title_suffix} Noise\n(E:{EPOCHS}, M:{model_name_plot})')
    # One legend column per six entries.
    num_entries = len(valid_losses)
    num_cols = 1 if num_entries <= 6 else math.ceil(num_entries / 6)
    plt.legend(loc='best', fontsize='small', ncol=num_cols)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.ylim(bottom=0, top=100)
    plt.xticks(noise_rates)
    plt.tight_layout()
    plt.show()


if RUN_SYMMETRIC:
    plot_results(results_symmetric, SYMMETRIC_NOISE_RATES, "Symmetric")
if RUN_ASYMMETRIC:
    plot_results(results_asymmetric, ASYMMETRIC_NOISE_RATES, "Asymmetric")


# ==============================================================================
# 8. Print Final Results Tables
# ==============================================================================
print("\n--- Final Results Tables ---")


def print_results_table(results_dict, rates, name, column_order=None):
    """Print a table of best test accuracy (rows: noise rate, columns: loss).

    Args:
        results_dict: Mapping of loss name -> list of accuracies, one per rate.
            Entries that are not lists of length ``len(rates)`` are skipped.
        rates: Noise rates, used as the row labels.
        name: Noise-type label used in the heading, e.g. "Symmetric".
        column_order: Optional iterable of loss names giving the preferred
            column order. Defaults to the keys of the module-level
            ``all_losses``. Columns not listed are appended afterwards.
    """
    print(f"\n{name} Noise Results (Best Test Accuracy %):")
    valid_results = {
        k: v for k, v in results_dict.items()
        if isinstance(v, list) and len(v) == len(rates)
    }
    if not valid_results:
        print("  No valid results.")
        return
    try:
        # Label the rows with the noise rates directly. The previous
        # `df.reindex(index=rates)` looked the rates up in the default
        # 0..n-1 RangeIndex, so every row except eta=0.0 became NaN.
        df = pd.DataFrame(
            valid_results, index=pd.Index(rates, name='Noise Rate (η)')
        )
        preferred = all_losses if column_order is None else column_order
        ordered_cols = [col for col in preferred if col in df.columns]
        remaining_cols = [col for col in df.columns if col not in ordered_cols]
        df = df[ordered_cols + remaining_cols]
        print(df.to_string(float_format="%.2f"))
    except Exception as e:
        print(f"  Table Error:{e}\n  Raw Data:{results_dict}")
# Print one table per enabled noise type, symmetric first.
for _enabled, _results, _rates, _label in (
    (RUN_SYMMETRIC, results_symmetric, SYMMETRIC_NOISE_RATES, "Symmetric"),
    (RUN_ASYMMETRIC, results_asymmetric, ASYMMETRIC_NOISE_RATES, "Asymmetric"),
):
    if _enabled:
        print_results_table(_results, _rates, _label)

print("\n--- Experiment Run Complete ---")

PyTorch Version: 2.6.0+cu124
Torchvision Version: 0.21.0+cu124
Running in Colab: True

Setting up data transformations...
Loading CIFAR-10 dataset...


100%|██████████| 170M/170M [00:12<00:00, 13.1MB/s]


CIFAR-10 loaded: 50000 train, 10000 test images.

Defining Loss Functions...

Defining Model Architectures...
Model definitions complete.

Defining Training and Evaluation Function...
Training function defined.

Running SELECTED losses: ['CE', 'NCE (s=1)', 'APL (NCE+MAE)', 'APL (NCE+RCE)']
Model Architecture: ResNet9
Epochs per run: 50
Symmetric Noise Rates: [0.0, 0.2, 0.4, 0.6, 0.8]
Saving Results: False

 *** Starting Experiment Execution ***


=== Processing Symmetric Noise Experiments ===
 --> Running CE(Sym)...

--- Training Start: CE | Noise: symmetric@0% | Epochs: 50 ---
Using device: cuda
Using Model: ResNet9
Optimizer: SGD(lr=1.0e-01, momentum=0.9, wd=5.0e-04)
Scheduler: CosineAnnealingLR(T_max=50)
E 01/50 | L:3.5731 | TrA: 15.99% | TeA: 19.77% (B: 19.77%) | LR:0.100000
E 02/50 | L:2.0438 | TrA: 20.06% | TeA: 22.12% (B: 22.12%) | LR:0.099901
E 03/50 | L:1.9187 | TrA: 23.83% | TeA: 22.89% (B: 22.89%) | LR:0.099606
E 04/50 | L:1.7899 | TrA: 29.45% | TeA: 27.69% (B: 27.69%) | LR: