# Baseline

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import time
import matplotlib.pyplot as plt
import os
from sklearn.metrics import classification_report  # For additional classification metrics

# --- Configuration & Hyperparameters ---
# Optimiser / regularisation settings shared by every dataset and every run.
LEARNING_RATE = 1e-3
BATCH_SIZE = 128
EPOCHS = 20
DROPOUT_RATE = 0.5

# Adam settings; handed to torch.optim.Adam as betas=(ADAM_BETA_1, ADAM_BETA_2)
# and eps=ADAM_EPSILON (PyTorch names the epsilon argument `eps`).
ADAM_BETA_1, ADAM_BETA_2 = 0.9, 0.999
ADAM_EPSILON = 1e-7

NUM_RUNS = 3                     # independent runs per dataset; mean/std are reported over these
PLOT_SAVE_DIR = "plots_pytorch"  # directory the figures are written to

# --- Device Selection ---
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# --- Dataset Configurations (Adapted for PyTorch/Torchvision) ---
# One entry per dataset. `load_fn` is the torchvision dataset class, `input_shape`
# is the channels-first (C, H, W) size images are resized to, and
# `target_transform` is forwarded to the dataset constructor (None = labels as-is).
DATASETS = {
    "mnist": dict(
        name="mnist",
        load_fn=torchvision.datasets.MNIST,
        input_shape=(1, 28, 28),
        num_classes=10,
        target_transform=None,  # labels are already integers
    ),
    "fashion_mnist": dict(
        name="fashion_mnist",
        load_fn=torchvision.datasets.FashionMNIST,
        input_shape=(1, 28, 28),
        num_classes=10,
        target_transform=None,
    ),
    "cifar10": dict(
        name="cifar10",
        load_fn=torchvision.datasets.CIFAR10,
        input_shape=(3, 32, 32),
        num_classes=10,
        target_transform=None,
    ),
    "cifar100": dict(
        name="cifar100",
        load_fn=torchvision.datasets.CIFAR100,
        input_shape=(3, 32, 32),
        num_classes=100,
        target_transform=None,
    ),
    "oxford_iiit_pet": dict(
        name="oxford_iiit_pet",
        load_fn=torchvision.datasets.OxfordIIITPet,
        input_shape=(3, 128, 128),  # images are resized to this size
        num_classes=37,
        # Labels are 0-36, already suitable for CrossEntropyLoss
        target_transform=None,
    ),
}

# --- Preprocessing Function (using torchvision.transforms) ---
def get_transforms(input_shape):
    """Build the preprocessing pipeline used for both the training and test split.

    Args:
        input_shape: Target ``(C, H, W)``. Only ``H`` and ``W`` are used (for the
            resize). The channel count is NOT enforced here: the dataset loader
            is assumed to already yield images with ``C`` channels.

    Returns:
        A ``transforms.Compose`` that resizes to ``(H, W)`` and then applies
        ``ToTensor`` (float tensor scaled to [0, 1], layout C, H, W).
    """
    spatial_size = input_shape[1:]  # (H, W)

    # NOTE(review): no grayscale<->RGB conversion is applied. If a dataset's
    # native channel count ever differs from input_shape[0], add e.g.
    # transforms.Grayscale(num_output_channels=input_shape[0]) before ToTensor.
    pipeline = [
        transforms.Resize(spatial_size),
        transforms.ToTensor(),
    ]

    # Optional per-channel normalisation, e.g. for CIFAR:
    # pipeline.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(pipeline)

# --- Model Building Function (PyTorch nn.Module) ---
class ClassifierCNN(nn.Module):
    """Two-stage CNN classifier that returns raw logits.

    Layout: [Conv3x3(32) -> ReLU -> MaxPool2] -> [Conv3x3(64) -> ReLU -> MaxPool2]
    -> Flatten -> Linear(64) -> ReLU -> Dropout -> Linear(num_classes).

    Args:
        input_shape: ``(C, H, W)`` of the input images.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability applied before the final linear layer.
    """

    def __init__(self, input_shape, num_classes, dropout_rate):
        super().__init__()
        in_channels, height, width = input_shape

        def conv_stage(c_in, c_out):
            # 3x3 'same' conv keeps H, W; the 2x2/stride-2 pool then floor-halves them.
            return nn.Sequential(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding='same'),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

        self.conv_block1 = conv_stage(in_channels, 32)
        self.conv_block2 = conv_stage(32, 64)

        # Two pooling stages -> each spatial dim is floor-halved twice.
        pooled_h = height // 2 // 2
        pooled_w = width // 2 // 2

        self.fc_block = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=64 * pooled_h * pooled_w, out_features=64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            # Raw logits: nn.CrossEntropyLoss applies log-softmax itself.
            nn.Linear(in_features=64, out_features=num_classes),
        )

    def forward(self, x):
        """Map an ``(N, C, H, W)`` batch to ``(N, num_classes)`` logits."""
        for stage in (self.conv_block1, self.conv_block2, self.fc_block):
            x = stage(x)
        return x

# --- Plotting Functions (Adapted for list-based history) ---
def plot_mean_std_curves(histories, dataset_name, epochs, save_dir="plots_pytorch"):
    """Plot mean ± 1 std of training and validation accuracy across runs and save it.

    Args:
        histories: List of per-run dicts, each with ``'accuracy'`` and
            ``'val_accuracy'`` sequences (one value per epoch). Other keys
            (e.g. ``'loss'``) are ignored here.
        dataset_name: Used in the title and in the output filename.
        epochs: Expected number of epochs per run. Curves are padded to
            ``max(epochs, longest curve)`` by repeating their last value.
        save_dir: Directory the PNG is written to (created if missing).
    """
    if not histories:
        print(f"No run histories for {dataset_name}; skipping accuracy plot.")
        return

    train_curves = [np.asarray(h['accuracy'], dtype=float) for h in histories]
    val_curves = [np.asarray(h['val_accuracy'], dtype=float) for h in histories]

    # Common x-axis length: at least `epochs`, and never shorter than the longest
    # recorded curve (a longer curve would otherwise give np.pad a negative width).
    max_len = max([epochs] + [curve.size for curve in train_curves + val_curves])

    def _pad_to_max(curve):
        # 'edge' repeats a run's last value; an empty curve has no edge, so use NaN.
        if curve.size == 0:
            return np.full(max_len, np.nan)
        return np.pad(curve, (0, max_len - curve.size), mode='edge')

    train_matrix = np.stack([_pad_to_max(c) for c in train_curves])
    val_matrix = np.stack([_pad_to_max(c) for c in val_curves])

    mean_acc, std_acc = train_matrix.mean(axis=0), train_matrix.std(axis=0)
    mean_val_acc, std_val_acc = val_matrix.mean(axis=0), val_matrix.std(axis=0)

    epoch_range = range(1, max_len + 1)

    fig = plt.figure(figsize=(12, 6))
    try:
        plt.title(f'{dataset_name.upper()} - Mean Accuracy over {len(histories)} Runs (PyTorch)')

        plt.plot(epoch_range, mean_acc, label='Mean Training Accuracy', color='blue')
        plt.fill_between(epoch_range, mean_acc - std_acc, mean_acc + std_acc,
                         alpha=0.2, color='blue', label='Training Acc ±1 std dev')
        plt.plot(epoch_range, mean_val_acc, label='Mean Validation Accuracy', color='orange')
        plt.fill_between(epoch_range, mean_val_acc - std_val_acc, mean_val_acc + std_val_acc,
                         alpha=0.2, color='orange', label='Validation Acc ±1 std dev')

        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.ylim(0, 1.05)
        plt.legend()
        plt.grid(True)

        try:
            os.makedirs(save_dir, exist_ok=True)
            filename = os.path.join(save_dir, f"results_{dataset_name}_accuracy_plot_pytorch.png")
            plt.savefig(filename, dpi=300, bbox_inches='tight')
            print(f"Accuracy plot saved to: {filename}")
        except OSError as e:
            print(f"Error saving plot to {save_dir}: {e}")
    finally:
        plt.close(fig)  # always release the figure, even if plotting fails

def plot_loss_curves_log(history, dataset_name, save_dir="plots_pytorch"):
    """Plot one run's training and validation loss per epoch and save the figure.

    Args:
        history: Dict with equal-length ``'loss'`` and ``'val_loss'`` sequences.
        dataset_name: Used in the title and in the output filename.
        save_dir: Directory the PNG is written to (created if missing).

    Note: the y-axis is linear; the log scale is left disabled (see comment below)
    even though the function/file name mention "log".
    """
    xs = range(1, len(history['loss']) + 1)

    plt.figure(figsize=(12, 6))
    plt.title(f'{dataset_name.upper()} - Loss Curves (Last Run) (PyTorch)')
    series = (
        ('loss', 'Training Loss', 'red'),
        ('val_loss', 'Validation Loss', 'green'),
    )
    for key, curve_label, curve_color in series:
        plt.plot(xs, history[key], label=curve_label, color=curve_color)
    # plt.yscale('log') # Optional: log scale
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    try:
        os.makedirs(save_dir, exist_ok=True)
        out_path = os.path.join(save_dir, f"results_{dataset_name}_loss_log_plot_pytorch.png")
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        print(f"Loss plot saved to: {out_path}")
    except OSError as err:
        print(f"Error saving plot to {save_dir}: {err}")
    plt.close()

# --- Training Loop Function ---
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass of ``model`` over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is put in train mode and updated after every
    batch; without one the pass runs in eval mode under ``torch.no_grad()``.
    Both statistics are weighted by batch size, so a smaller final batch does not
    skew the epoch average. Returns ``(nan, nan)`` for an empty loader.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    grad_context = torch.enable_grad() if is_training else torch.no_grad()
    with grad_context:
        for inputs, labels in loader:
            inputs = inputs.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # criterion uses mean reduction, so scale back up to a per-batch sum.
            loss_sum += loss.item() * batch_size
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        return float("nan"), float("nan")
    return loss_sum / n_seen, n_correct / n_seen


def train_on_dataset(dataset_key, save_plot_dir):
    """Load a dataset, train ``NUM_RUNS`` fresh ``ClassifierCNN`` models, and report results.

    For each run the model is trained for ``EPOCHS`` epochs with Adam and
    ``nn.CrossEntropyLoss``; after every epoch it is evaluated on the dataset's
    test split. NOTE(review): the reported "validation"/"test" metrics both come
    from the test split, so they are not an independent held-out estimate.

    Side effects: downloads data to ``./data``, prints progress/summary, writes
    accuracy and loss plots to ``save_plot_dir``, and prints an sklearn
    classification report for the last run's model.

    Args:
        dataset_key: Key into ``DATASETS``.
        save_plot_dir: Directory for the plots.
    """
    print(f"\n--- Processing Dataset: {dataset_key.upper()} (PyTorch) ---")
    start_time_total = time.time()

    config = DATASETS[dataset_key]
    input_shape = config["input_shape"]
    num_classes = config["num_classes"]
    dataset_cls = config["load_fn"]

    # --- Data Loading and Preprocessing ---
    common_kwargs = dict(
        root='./data',
        download=True,
        transform=get_transforms(input_shape),
        target_transform=config["target_transform"],
    )
    if dataset_key == "oxford_iiit_pet":
        # OxfordIIITPet selects its split via `split=` rather than `train=`.
        train_dataset = dataset_cls(split='trainval', **common_kwargs)
        test_dataset = dataset_cls(split='test', **common_kwargs)
    else:
        train_dataset = dataset_cls(train=True, **common_kwargs)
        test_dataset = dataset_cls(train=False, **common_kwargs)

    # Pinned host memory only helps host->GPU copies; on CPU it is useless.
    use_pinned = DEVICE.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=2, pin_memory=use_pinned)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=2, pin_memory=use_pinned)

    print(f"Loaded dataset: {config['name']} with {len(train_dataset)} training samples and {len(test_dataset)} test samples.")

    # --- Model Training over Multiple Runs ---
    histories = []
    test_accuracies = []
    run_times = []
    final_model = None  # model from the last run, used for the extra metrics

    for run in range(NUM_RUNS):
        print(f"\n--- Run {run + 1}/{NUM_RUNS} ---")
        start_time_run = time.time()

        model = ClassifierCNN(input_shape, num_classes, DROPOUT_RATE).to(DEVICE)
        criterion = nn.CrossEntropyLoss()  # expects raw logits
        optimizer = optim.Adam(model.parameters(),
                               lr=LEARNING_RATE,
                               betas=(ADAM_BETA_1, ADAM_BETA_2),
                               eps=ADAM_EPSILON)

        run_history = {'loss': [], 'accuracy': [], 'val_loss': [], 'val_accuracy': []}
        epoch_val_acc = float("nan")  # stays NaN if EPOCHS == 0

        for epoch in range(EPOCHS):
            epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, optimizer)
            epoch_val_loss, epoch_val_acc = _run_epoch(model, test_loader, criterion)

            run_history['loss'].append(epoch_loss)
            run_history['accuracy'].append(epoch_acc)
            run_history['val_loss'].append(epoch_val_loss)
            run_history['val_accuracy'].append(epoch_val_acc)

            print(f"Epoch [{epoch + 1}/{EPOCHS}] Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f}")

        # --- End of Run ---
        histories.append(run_history)
        test_accuracies.append(epoch_val_acc)  # final epoch's test-split accuracy
        run_time = time.time() - start_time_run
        run_times.append(run_time)
        print(f"Run {run + 1} completed in {run_time:.2f} seconds. Final Test Accuracy: {epoch_val_acc:.4f}")
        final_model = model

    avg_acc = np.mean(test_accuracies)
    std_acc = np.std(test_accuracies)
    avg_time = np.mean(run_times)
    total_time = time.time() - start_time_total

    print(f"\n--- {dataset_key.upper()} Final Results ({NUM_RUNS} Runs) (PyTorch) ---")
    print(f"Individual Test Accuracies: {[f'{acc:.4f}' for acc in test_accuracies]}")
    print(f"Average Test Accuracy: {avg_acc:.4f}")
    print(f"Standard Deviation of Test Accuracy: {std_acc:.4f}")
    print(f"Average Run Time: {avg_time:.2f} seconds")
    print(f"Total Time for {dataset_key.upper()}: {total_time:.2f} seconds")

    # --- Plotting Accuracy and Loss Curves ---
    if histories:
        plot_mean_std_curves(histories, dataset_key, EPOCHS, save_dir=save_plot_dir)
        plot_loss_curves_log(histories[-1], dataset_key, save_dir=save_plot_dir)  # last run's loss

    # --- Additional Classification Metrics (using the last run's model) ---
    if final_model is None:
        print("Warning: No final model state available for computing additional metrics.")
        return

    print("\n=== Additional Classification Metrics (sklearn) ===")
    final_model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = final_model(inputs.to(DEVICE, non_blocking=True))
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # zero_division=0: a class with no predicted samples gets 0 instead of a warning.
    print(classification_report(all_labels, all_preds, digits=4, zero_division=0))

# --- Main Execution ---
if __name__ == "__main__":
    # DEVICE is chosen once at module level, so there is no per-run GPU setup here.

    # Make sure the plot directory exists; fall back to the CWD if it cannot be created.
    try:
        os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
    except OSError as e:
        print(f"Could not create plot directory '{PLOT_SAVE_DIR}': {e}. Using current directory.")
        PLOT_SAVE_DIR = "."
    else:
        print(f"Plots will be saved in '{PLOT_SAVE_DIR}/' directory.")

    for dataset in DATASETS:
        train_on_dataset(dataset, save_plot_dir=PLOT_SAVE_DIR)

    print("\n--- All Dataset Processing Complete (PyTorch) ---")

Using device: cuda
Plots will be saved in 'plots_pytorch/' directory.

--- Processing Dataset: MNIST (PyTorch) ---


100%|██████████| 9.91M/9.91M [00:18<00:00, 545kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 66.1kB/s]
100%|██████████| 1.65M/1.65M [00:06<00:00, 245kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 2.88MB/s]


Loaded dataset: mnist with 60000 training samples and 10000 test samples.

--- Run 1/3 ---
Epoch [1/20] Train Loss: 0.4648 Acc: 0.8529 | Val Loss: 0.0842 Acc: 0.9744
Epoch [2/20] Train Loss: 0.1771 Acc: 0.9463 | Val Loss: 0.0534 Acc: 0.9823
Epoch [3/20] Train Loss: 0.1306 Acc: 0.9604 | Val Loss: 0.0394 Acc: 0.9869
Epoch [4/20] Train Loss: 0.1081 Acc: 0.9667 | Val Loss: 0.0380 Acc: 0.9872
Epoch [5/20] Train Loss: 0.0975 Acc: 0.9698 | Val Loss: 0.0359 Acc: 0.9894
Epoch [6/20] Train Loss: 0.0862 Acc: 0.9736 | Val Loss: 0.0315 Acc: 0.9903
Epoch [7/20] Train Loss: 0.0789 Acc: 0.9754 | Val Loss: 0.0324 Acc: 0.9902
Epoch [8/20] Train Loss: 0.0738 Acc: 0.9769 | Val Loss: 0.0334 Acc: 0.9893
Epoch [9/20] Train Loss: 0.0638 Acc: 0.9791 | Val Loss: 0.0314 Acc: 0.9913
Epoch [10/20] Train Loss: 0.0578 Acc: 0.9814 | Val Loss: 0.0279 Acc: 0.9913
Epoch [11/20] Train Loss: 0.0517 Acc: 0.9823 | Val Loss: 0.0266 Acc: 0.9909
Epoch [12/20] Train Loss: 0.0478 Acc: 0.9837 | Val Loss: 0.0289 Acc: 0.9912
Epoch 

100%|██████████| 26.4M/26.4M [00:02<00:00, 11.3MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 213kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.40MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 11.6MB/s]

Loaded dataset: fashion_mnist with 60000 training samples and 10000 test samples.

--- Run 1/3 ---





Epoch [1/20] Train Loss: 0.7566 Acc: 0.7235 | Val Loss: 0.4328 Acc: 0.8417
Epoch [2/20] Train Loss: 0.5082 Acc: 0.8138 | Val Loss: 0.3752 Acc: 0.8642
Epoch [3/20] Train Loss: 0.4482 Acc: 0.8357 | Val Loss: 0.3309 Acc: 0.8792
Epoch [4/20] Train Loss: 0.4087 Acc: 0.8510 | Val Loss: 0.3118 Acc: 0.8864
Epoch [5/20] Train Loss: 0.3836 Acc: 0.8590 | Val Loss: 0.2947 Acc: 0.8942
Epoch [6/20] Train Loss: 0.3608 Acc: 0.8659 | Val Loss: 0.2827 Acc: 0.8974
Epoch [7/20] Train Loss: 0.3439 Acc: 0.8721 | Val Loss: 0.2755 Acc: 0.8982
Epoch [8/20] Train Loss: 0.3276 Acc: 0.8786 | Val Loss: 0.2649 Acc: 0.9034
Epoch [9/20] Train Loss: 0.3124 Acc: 0.8840 | Val Loss: 0.2823 Acc: 0.8976
Epoch [10/20] Train Loss: 0.3020 Acc: 0.8871 | Val Loss: 0.2566 Acc: 0.9082
Epoch [11/20] Train Loss: 0.2894 Acc: 0.8931 | Val Loss: 0.2446 Acc: 0.9113
Epoch [12/20] Train Loss: 0.2783 Acc: 0.8948 | Val Loss: 0.2458 Acc: 0.9111
Epoch [13/20] Train Loss: 0.2679 Acc: 0.8999 | Val Loss: 0.2463 Acc: 0.9097
Epoch [14/20] Train L

100%|██████████| 170M/170M [00:57<00:00, 2.96MB/s] 


Loaded dataset: cifar10 with 50000 training samples and 10000 test samples.

--- Run 1/3 ---
Epoch [1/20] Train Loss: 1.8068 Acc: 0.3275 | Val Loss: 1.4730 Acc: 0.4741
Epoch [2/20] Train Loss: 1.5411 Acc: 0.4333 | Val Loss: 1.3301 Acc: 0.5346
Epoch [3/20] Train Loss: 1.4168 Acc: 0.4836 | Val Loss: 1.2090 Acc: 0.5748
Epoch [4/20] Train Loss: 1.3287 Acc: 0.5153 | Val Loss: 1.1473 Acc: 0.5996
Epoch [5/20] Train Loss: 1.2598 Acc: 0.5431 | Val Loss: 1.0937 Acc: 0.6207
Epoch [6/20] Train Loss: 1.2123 Acc: 0.5609 | Val Loss: 1.0357 Acc: 0.6384
Epoch [7/20] Train Loss: 1.1752 Acc: 0.5774 | Val Loss: 1.0309 Acc: 0.6338
Epoch [8/20] Train Loss: 1.1442 Acc: 0.5874 | Val Loss: 0.9866 Acc: 0.6553
Epoch [9/20] Train Loss: 1.1107 Acc: 0.6008 | Val Loss: 0.9609 Acc: 0.6729
Epoch [10/20] Train Loss: 1.0872 Acc: 0.6102 | Val Loss: 0.9636 Acc: 0.6729
Epoch [11/20] Train Loss: 1.0660 Acc: 0.6158 | Val Loss: 0.9239 Acc: 0.6787
Epoch [12/20] Train Loss: 1.0445 Acc: 0.6223 | Val Loss: 0.9194 Acc: 0.6785
Epoc

100%|██████████| 169M/169M [00:17<00:00, 9.86MB/s] 


Loaded dataset: cifar100 with 50000 training samples and 10000 test samples.

--- Run 1/3 ---
Epoch [1/20] Train Loss: 4.3862 Acc: 0.0352 | Val Loss: 3.9690 Acc: 0.1003
Epoch [2/20] Train Loss: 4.0742 Acc: 0.0651 | Val Loss: 3.7945 Acc: 0.1342
Epoch [3/20] Train Loss: 3.9537 Acc: 0.0781 | Val Loss: 3.6501 Acc: 0.1625
Epoch [4/20] Train Loss: 3.8973 Acc: 0.0868 | Val Loss: 3.5931 Acc: 0.1744
Epoch [5/20] Train Loss: 3.8453 Acc: 0.0922 | Val Loss: 3.5193 Acc: 0.1895
Epoch [6/20] Train Loss: 3.8185 Acc: 0.0937 | Val Loss: 3.4550 Acc: 0.1968
Epoch [7/20] Train Loss: 3.7884 Acc: 0.0969 | Val Loss: 3.4643 Acc: 0.1927
Epoch [8/20] Train Loss: 3.7644 Acc: 0.1019 | Val Loss: 3.4469 Acc: 0.1998
Epoch [9/20] Train Loss: 3.7440 Acc: 0.1049 | Val Loss: 3.4307 Acc: 0.2137
Epoch [10/20] Train Loss: 3.7190 Acc: 0.1060 | Val Loss: 3.3915 Acc: 0.2128
Epoch [11/20] Train Loss: 3.7012 Acc: 0.1072 | Val Loss: 3.3336 Acc: 0.2183
Epoch [12/20] Train Loss: 3.6784 Acc: 0.1126 | Val Loss: 3.2887 Acc: 0.2232
Epo

100%|██████████| 792M/792M [00:32<00:00, 24.6MB/s] 
100%|██████████| 19.2M/19.2M [00:01<00:00, 11.5MB/s]


Loaded dataset: oxford_iiit_pet with 3680 training samples and 3669 test samples.

--- Run 1/3 ---
Epoch [1/20] Train Loss: 3.6806 Acc: 0.0272 | Val Loss: 3.6140 Acc: 0.0273
Epoch [2/20] Train Loss: 3.6137 Acc: 0.0223 | Val Loss: 3.6136 Acc: 0.0273
Epoch [3/20] Train Loss: 3.6101 Acc: 0.0288 | Val Loss: 3.5978 Acc: 0.0401
Epoch [4/20] Train Loss: 3.5973 Acc: 0.0378 | Val Loss: 3.5824 Acc: 0.0523
Epoch [5/20] Train Loss: 3.5760 Acc: 0.0418 | Val Loss: 3.5547 Acc: 0.0548
Epoch [6/20] Train Loss: 3.5621 Acc: 0.0408 | Val Loss: 3.5402 Acc: 0.0537
Epoch [7/20] Train Loss: 3.5606 Acc: 0.0424 | Val Loss: 3.5241 Acc: 0.0526
Epoch [8/20] Train Loss: 3.5447 Acc: 0.0418 | Val Loss: 3.5082 Acc: 0.0561
Epoch [9/20] Train Loss: 3.5348 Acc: 0.0394 | Val Loss: 3.5041 Acc: 0.0624
Epoch [10/20] Train Loss: 3.5110 Acc: 0.0437 | Val Loss: 3.4600 Acc: 0.0597
Epoch [11/20] Train Loss: 3.4910 Acc: 0.0489 | Val Loss: 3.4577 Acc: 0.0725
Epoch [12/20] Train Loss: 3.4953 Acc: 0.0459 | Val Loss: 3.4515 Acc: 0.063

# LearnableLoss

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torch.autograd.functional import jacobian # For potential Hessian calculation alternative
import numpy as np
import time
import datetime # Although not used in the original logic, kept import if needed
import matplotlib.pyplot as plt
import os
from sklearn.metrics import classification_report # additional classification metrics
import warnings

# Ignore every UserWarning for the rest of the session (intended for matplotlib).
# NOTE(review): the filter is global, so it also hides UserWarnings raised by
# torch/torchvision; narrow it with `module=` if those should stay visible.
warnings.simplefilter("ignore", UserWarning)

# Seed the torch and NumPy RNGs for reproducibility.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Stricter (but slower) cuDNN determinism, if ever needed:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

# --- Configuration & Hyperparameters ---
LEARNING_RATE_CLASSIFIER = 1e-3    # optimiser LR for the classifier
LEARNING_RATE_LOSS_NETWORK = 1e-3  # optimiser LR for the learnable loss network
BATCH_SIZE = 1024
EPOCHS = 20                        # adjust as needed
LAMBDA_CONST = 0.1                 # weight for the constraint losses
NUM_RUNS = 3                       # runs averaged in the reported metrics
PLOT_SAVE_DIR = "plots_learnable_loss_detailed_pytorch"

# --- Device Selection ---
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# --- Dataset Configurations (Adapted for PyTorch/Torchvision) ---
# One entry per dataset. `load_fn` is the torchvision dataset class, `input_shape`
# is the channels-first (C, H, W) size images are resized to, and
# `target_transform` is forwarded to the dataset constructor (None = labels as-is).
DATASETS = {
    "mnist": dict(
        name="mnist",
        load_fn=torchvision.datasets.MNIST,
        input_shape=(1, 28, 28),
        num_classes=10,
        target_transform=None,  # labels are already integers
    ),
    "fashion_mnist": dict(
        name="fashion_mnist",
        load_fn=torchvision.datasets.FashionMNIST,
        input_shape=(1, 28, 28),
        num_classes=10,
        target_transform=None,
    ),
    "cifar10": dict(
        name="cifar10",
        load_fn=torchvision.datasets.CIFAR10,
        input_shape=(3, 32, 32),
        num_classes=10,
        target_transform=None,
    ),
    "cifar100": dict(
        name="cifar100",
        load_fn=torchvision.datasets.CIFAR100,
        input_shape=(3, 32, 32),
        num_classes=100,
        target_transform=None,
    ),
    "oxford_iiit_pet": dict(
        name="oxford_iiit_pet",
        load_fn=torchvision.datasets.OxfordIIITPet,
        input_shape=(3, 128, 128),  # images are resized to this size
        num_classes=37,
        # Labels are 0-36, already suitable for CrossEntropyLoss if used directly
        target_transform=None,
    ),
}

# --- Preprocessing Function (using torchvision.transforms) ---
def get_transforms(input_shape):
    """Return the preprocessing pipeline: resize to (H, W), then ``ToTensor``.

    Args:
        input_shape: ``(C, H, W)``; only ``H`` and ``W`` are used.

    Returns:
        A ``transforms.Compose`` producing float tensors in [0, 1], layout C, H, W.
    """
    resize_step = transforms.Resize(input_shape[1:])
    tensor_step = transforms.ToTensor()
    # Normalisation could be appended here if needed, e.g. for CIFAR:
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return transforms.Compose([resize_step, tensor_step])

# --- Model Building Functions (PyTorch nn.Module) ---
class Classifier(nn.Module):
    """Three-stage CNN that outputs class probabilities (softmax over dim 1).

    Layout: conv(32)+BN+ReLU+pool -> conv(64)+BN+ReLU+pool -> conv(128)+BN+ReLU
    -> flatten -> fc(128)+ReLU -> dropout(0.5) -> fc(num_classes) -> softmax.

    Args:
        input_shape: ``(C, H, W)`` of the input images.
        num_classes: Length of the output probability vector.
    """

    def __init__(self, input_shape, num_classes):
        super().__init__()
        channels, height, width = input_shape
        # 'same' convs keep H, W; each of the two 2x2/stride-2 pools floor-halves them.
        out_h, out_w = height // 2 // 2, width // 2 // 2

        self.conv1 = nn.Conv2d(channels, 32, kernel_size=3, padding='same')
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding='same')
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding='same')
        self.bn3 = nn.BatchNorm2d(128)  # no pooling after the third conv

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(128 * out_h * out_w, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)
        # Probabilities rather than logits: the custom learnable loss consumes p directly.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return an ``(N, num_classes)`` tensor of class probabilities."""
        pooled_stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
        )
        for conv, norm, pool in pooled_stages:
            x = pool(F.relu(norm(conv(x))))
        x = F.relu(self.bn3(self.conv3(x)))
        hidden = F.relu(self.fc1(self.flatten(x)))
        logits = self.fc2(self.dropout(hidden))
        return self.softmax(logits)

class LossNetwork(nn.Module):
    """MLP that maps a concatenated ``[p, y_one_hot]`` row to a scalar energy.

    Layers: Linear(2*num_classes, 64)+BN+ReLU -> Linear(64, 32)+BN+ReLU ->
    Linear(32, 16)+ReLU -> Linear(16, 1), with no activation on the output.

    Args:
        num_classes: Number of classes; the input width is ``2 * num_classes``.

    NOTE(review): BatchNorm1d cannot normalise a single row in train mode, so
    single-row inputs (as in the visualisation code) need ``eval()``.
    """

    def __init__(self, num_classes):
        super().__init__()
        in_features = 2 * num_classes
        self.fc1 = nn.Linear(in_features, 64)
        self.bn1 = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(64, 32)
        self.bn2 = nn.BatchNorm1d(32)
        self.fc3 = nn.Linear(32, 16)
        self.fc4 = nn.Linear(16, 1)

    def forward(self, x):
        """Return an ``(N, 1)`` energy for an ``(N, 2 * num_classes)`` input."""
        hidden = x
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = F.relu(norm(linear(hidden)))
        hidden = F.relu(self.fc3(hidden))
        return self.fc4(hidden)

# --- Helper for Hessian Trace (using autograd.grad) ---
def get_hessian_trace(grad_E, p):
    """Return the per-row trace of the Jacobian of ``grad_E`` with respect to ``p``.

    When ``grad_E`` is dE/dp computed with ``create_graph=True``, this is the
    trace of the Hessian of E: ``sum_i d grad_E[:, i] / d p[:, i]``.

    Args:
        grad_E: Tensor indexed as ``grad_E[:, i]`` (same column count as ``p``)
            that is part of an autograd graph built from ``p``.
        p: 2-D tensor ``(rows, columns)`` to differentiate with respect to.

    Returns:
        Tensor of shape ``(rows,)``. It is built with ``create_graph=True`` so it
        can itself be used inside a loss. A column ``i`` for which
        ``grad_E[:, i]`` does not depend on ``p`` (e.g. E linear in p)
        contributes 0 instead of raising.

    Note:
        Each column is differentiated as ``grad_E[:, i].sum()``, so entry ``b``
        equals ``sum_i sum_b' d grad_E[b', i] / d p[b, i]``. That is the
        per-sample trace only if rows do not interact. NOTE(review): BatchNorm in
        train mode mixes rows; confirm this is acceptable for the caller.
    """
    if not p.requires_grad:
        # Only possible for a leaf tensor. NOTE(review): grad_E must already have
        # been computed from p with grad tracking; enabling it here is too late to
        # connect the two, so this mainly avoids an error in autograd.grad.
        p.requires_grad_(True)

    # Start from a zero tensor so the result is always a (rows,) tensor,
    # even when p has no columns.
    trace_val = p.new_zeros(p.shape[0])
    for i in range(p.shape[1]):
        (second_deriv,) = torch.autograd.grad(
            grad_E[:, i].sum(),
            p,
            create_graph=True,  # keep the trace differentiable for use in a loss
            retain_graph=True,  # the same graph is reused for every column
            allow_unused=True,  # grad_E[:, i] may not depend on p at all
        )
        if second_deriv is not None:
            trace_val = trace_val + second_deriv[:, i]  # diagonal element
    return trace_val

# --- Training Step Function (Combined logic) ---
# No separate function creation needed, integrated into epoch loop

# --- Evaluation Step Function (Combined logic) ---
# No separate function creation needed, integrated into epoch loop

# --- Detailed Plotting Function ---
def plot_detailed_learnable_loss_metrics(
    dataset_name, epochs, num_runs, save_dir, num_classes,
    mean_train_total_loss, std_train_total_loss,
    mean_test_loss, std_test_loss,
    mean_test_acc, std_test_acc,
    mean_nonneg_loss, mean_convex_loss, mean_lips_loss,
    final_loss_network_state # Pass state dict
    ):
    """Generate and save a 4-panel summary figure for the learnable-loss experiment.

    Panels:
        (a) mean ± std training total loss vs. validation task loss;
        (b) mean ± std validation accuracy;
        (c) mean constraint penalties (log y-axis when any curve has a positive
            minimum, otherwise linear);
        (d) energy E(p, y) of the last run's loss network along a 1-D slice of the
            probability simplex, with the true class fixed to 0.

    Args:
        dataset_name: Used in titles and the output filename.
        epochs: Number of points per curve; the x-axis is 1..epochs.
        num_runs: Only shown in the figure title.
        save_dir: Output directory (created if missing).
        num_classes: Width of p and y; the loss network takes 2 * num_classes inputs.
        mean_* / std_*: Per-epoch curves of length ``epochs``. They are assumed to
            be NumPy arrays, since ``mean - std`` arithmetic is applied — TODO
            confirm at the caller.
        final_loss_network_state: ``state_dict`` for ``LossNetwork(num_classes)``.
    """
    epoch_range = range(1, epochs + 1)
    fig = plt.figure(figsize=(14, 12))
    try:
        plt.suptitle(f'{dataset_name.upper()} - Detailed Learnable Loss Metrics (Averaged over {num_runs} Runs) (PyTorch)',
                     fontsize=16, y=1.02)

        # (a) Training Total Loss / Validation Loss
        plt.subplot(2, 2, 1)
        plt.plot(epoch_range, mean_train_total_loss, label='Mean Training Total Loss', color='purple')
        plt.fill_between(epoch_range, mean_train_total_loss - std_train_total_loss, mean_train_total_loss + std_train_total_loss,
                         alpha=0.2, color='purple')
        plt.plot(epoch_range, mean_test_loss, label='Mean Validation Task Loss', color='orange')
        plt.fill_between(epoch_range, mean_test_loss - std_test_loss, mean_test_loss + std_test_loss,
                         alpha=0.2, color='orange')
        plt.title("Avg. Training Total vs Validation Task Loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.legend()
        plt.grid(True)

        # (b) Validation Accuracy
        plt.subplot(2, 2, 2)
        plt.plot(epoch_range, mean_test_acc, label='Mean Validation Accuracy', color='blue')
        plt.fill_between(epoch_range, mean_test_acc - std_test_acc, mean_test_acc + std_test_acc,
                         alpha=0.2, color='blue')
        plt.title("Avg. Validation Accuracy")
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")
        plt.ylim(0, 1.05)
        plt.legend()
        plt.grid(True)

        # (c) Constraint Penalties (log scale when possible)
        plt.subplot(2, 2, 3)
        plt.plot(epoch_range, mean_nonneg_loss, label="Avg. Non-negativity Loss", marker='.')
        plt.plot(epoch_range, mean_convex_loss, label="Avg. Convexity Loss", marker='.')
        plt.plot(epoch_range, mean_lips_loss, label="Avg. Lipschitz Loss", marker='.')
        # A log axis needs a positive lower bound. Keep only positive per-curve
        # minima (NaN compares False, so a curve containing NaN is skipped). If
        # none qualify, fall back to a linear axis; calling min() on the empty
        # selection would raise ValueError.
        positive_minima = [
            m for m in (np.min(mean_nonneg_loss), np.min(mean_convex_loss), np.min(mean_lips_loss))
            if m > 0
        ]
        use_log_scale = len(positive_minima) > 0
        if use_log_scale:
            plt.ylim(bottom=min(positive_minima) * 0.1)
            plt.yscale('log')
        else:
            print("Warning: Non-positive values encountered in constraint losses, using linear scale.")

        plt.title("Avg. Constraint Penalties")
        plt.xlabel("Epochs")
        plt.ylabel("Avg. Penalty Value" + (" (log scale)" if use_log_scale else ""))
        plt.legend()
        plt.grid(True)

        # (d) Learned Loss Function Visualization (from last run)
        plt.subplot(2, 2, 4)
        energies = []
        n_points = 50
        p_true_class_values = np.linspace(1e-3, 1.0 - 1e-3, n_points)
        y_fixed = np.zeros((1, num_classes), dtype=np.float32)
        y_fixed[0, 0] = 1.0  # true class is 0 for the visualisation
        y_fixed_tensor = torch.tensor(y_fixed).to(DEVICE)

        loss_vis_net = LossNetwork(num_classes).to(DEVICE)
        loss_vis_net.load_state_dict(final_loss_network_state)
        loss_vis_net.eval()  # BatchNorm uses running stats, so single-row inputs work

        if num_classes > 1:
            # Spread the remaining probability mass evenly over the other classes.
            remaining_prob = (1.0 - p_true_class_values) / (num_classes - 1)
        else:
            # Single class: there are no other classes to receive the remaining mass.
            remaining_prob = np.zeros_like(p_true_class_values)

        with torch.no_grad():
            for i, p0 in enumerate(p_true_class_values):
                p_np = np.full((1, num_classes), remaining_prob[i], dtype=np.float32)
                p_np[0, 0] = p0
                # Renormalise so the row sums to 1 (guards float rounding / num_classes == 1).
                p_np = p_np / np.sum(p_np, axis=1, keepdims=True)
                p_tensor = torch.tensor(p_np).to(DEVICE)

                loss_net_input = torch.cat([p_tensor, y_fixed_tensor], dim=1)
                try:
                    E_val = loss_vis_net(loss_net_input)
                    energies.append(E_val.cpu().numpy().squeeze())
                except Exception as e:
                    # Best-effort visualisation: record a gap instead of aborting the figure.
                    print(f"Warning: Error during loss visualization: {e}")
                    energies.append(np.nan)

        energies = np.array(energies)
        if not np.all(np.isnan(energies)) and len(energies) > 0:
            plt.plot(p_true_class_values, energies, marker='.')
            min_energy_idx = np.nanargmin(energies)
            if not np.isnan(energies[min_energy_idx]):
                plt.scatter(p_true_class_values[min_energy_idx], energies[min_energy_idx], color='red', zorder=5,
                            label=f'Min E at p[0]~{p_true_class_values[min_energy_idx]:.2f}')
            plt.legend()
        else:
            plt.text(0.5, 0.5, 'Visualization Error or No Data', horizontalalignment='center',
                     verticalalignment='center', transform=plt.gca().transAxes)

        plt.title(f"Learned Loss E(p,y) vs p[true_class=0] (Last Run)")
        plt.xlabel("p[0] (Probability for true class 0)")
        plt.ylabel("Energy E(p,y)")
        plt.grid(True)

        plt.tight_layout(rect=[0, 0.03, 1, 0.97])
        try:
            os.makedirs(save_dir, exist_ok=True)
            filename = os.path.join(save_dir, f"results_{dataset_name}_learnable_loss_detailed_plots_pytorch.png")
            plt.savefig(filename, dpi=300, bbox_inches='tight')
            print(f"Detailed plot saved to: {filename}")
        except OSError as e:
            print(f"Error saving detailed plot to {save_dir}: {e}")
    finally:
        plt.close(fig)  # always release the figure, even if a panel fails


# --- F1 Score Plotting Function ---
def plot_f1_scores(dataset_name, save_dir, report_dict, num_classes):
    """Generate and save a bar plot of per-class F1 scores.

    Args:
        dataset_name: Dataset key; used in the plot title and the output filename.
        save_dir: Directory the PNG is written to. It is created if missing
            (an empty string means the current directory).
        report_dict: Output of ``sklearn.metrics.classification_report(...,
            output_dict=True)``. Per-class entries are looked up by the string
            form of the integer class index ('0', '1', ...). Classes absent from
            the report -- or a ``None`` report -- are plotted with an F1 of 0.
        num_classes: Number of classes; one bar is drawn for each of
            classes 0 .. num_classes-1.

    Saving errors (``OSError``) are reported on stdout instead of raised, and
    the matplotlib figure is always closed, even if plotting fails.
    """
    report_dict = report_dict or {}
    class_labels = [str(i) for i in range(num_classes)]
    f1_scores = [report_dict.get(label, {}).get('f1-score', 0) for label in class_labels]

    # Widen the figure as the number of classes grows so bars stay readable.
    plt.figure(figsize=(max(8, num_classes * 0.3), 5))
    try:
        bars = plt.bar(class_labels, f1_scores, color='dodgerblue', alpha=0.8)
        plt.xlabel("Class Label")
        plt.ylabel("F1-Score")
        plt.title(f"Per-Class F1 Scores for {dataset_name.upper()} (Last Run) (PyTorch)")
        plt.grid(True, linestyle='--', alpha=0.6, axis='y')
        plt.ylim(0, 1.05)

        if num_classes <= 20:
            # Few classes: annotate every bar with its F1 value.
            for bar in bars:
                height = bar.get_height()
                plt.text(bar.get_x() + bar.get_width() / 2.0, height + 0.01, f'{height:.2f}',
                         va='bottom', ha='center')
        else:
            # Many classes: annotations would overlap, so rotate/shrink tick labels instead.
            plt.xticks(rotation=90, fontsize=8)

        plt.tight_layout()
        f1_plot_filename = os.path.join(save_dir, f"f1_scores_{dataset_name}_pytorch.png")
        try:
            # Ensure the target directory exists (os.makedirs('') would fail,
            # so only create it when a directory was actually given).
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            plt.savefig(f1_plot_filename, dpi=300, bbox_inches='tight')
            print(f"F1 score plot saved to: {f1_plot_filename}")
        except OSError as e:
            print(f"Error saving F1 score plot: {e}")
    finally:
        plt.close()  # Always release the figure, even if plotting raised.

# ---------------------------
# Main Execution Block
# ---------------------------
if __name__ == "__main__":
    # Script entry point. For each dataset in DATASETS this block:
    #   1. loads the train/test splits using get_transforms(input_shape),
    #   2. trains a fresh Classifier jointly with a fresh LossNetwork for EPOCHS
    #      epochs, NUM_RUNS times, recording per-epoch metrics for every run,
    #   3. prints mean ± std (across runs) of the last-epoch metrics,
    #   4. for the last run, saves the detailed learnable-loss plot, prints an
    #      sklearn classification report and saves a per-class F1 plot.
    #
    # NOTE(review): several names used below are defined outside this block and
    # are not visible here: F, get_transforms, Classifier, LossNetwork,
    # get_hessian_trace, LEARNING_RATE_CLASSIFIER, LEARNING_RATE_LOSS_NETWORK,
    # LAMBDA_CONST. `F` is not in the file-header imports; presumably it is
    # torch.nn.functional imported in a later notebook cell -- confirm.

    # GPU check is done globally via DEVICE definition

    # Create plot directory
    try:
        os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
        print(f"Plots will be saved in '{PLOT_SAVE_DIR}/' directory.")
    except OSError as e:
        print(f"Could not create plot directory '{PLOT_SAVE_DIR}': {e}. Plots will fallback to current directory.")
        # This `if` block runs at module scope, so this rebinds the global that
        # the plotting calls further down read.
        PLOT_SAVE_DIR = "."

    # Loop through Datasets
    for dataset_key, config in DATASETS.items():
        print(f"\n{'='*20} Processing Dataset: {dataset_key.upper()} {'='*20}")
        start_time_dataset = time.time()

        input_shape = config["input_shape"]
        num_classes = config["num_classes"]
        DatasetClass = config["load_fn"]
        target_transform = config["target_transform"]

        # Data Loading & Preprocessing
        print("Loading and preprocessing data...")
        transform = get_transforms(input_shape)

        # Handle different dataset loading arguments
        try:
            if dataset_key == "oxford_iiit_pet":
                # OxfordIIITPet selects its split with `split=` instead of `train=`.
                train_dataset = DatasetClass(root='./data', split='trainval', download=True, transform=transform, target_transform=target_transform)
                test_dataset = DatasetClass(root='./data', split='test', download=True, transform=transform, target_transform=target_transform)
            else:
                train_dataset = DatasetClass(root='./data', train=True, download=True, transform=transform, target_transform=target_transform)
                test_dataset = DatasetClass(root='./data', train=False, download=True, transform=transform, target_transform=target_transform)

            # NOTE(review): pin_memory=True is set unconditionally; it only helps
            # host-to-GPU copies, so consider tying it to DEVICE.type == "cuda".
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
            print(f"Data loading complete. Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

        except Exception as e:
            # Any failure while building the datasets/loaders skips this dataset
            # rather than aborting the whole script.
            print(f"ERROR: Failed to load dataset {dataset_key}. Skipping. Error: {e}")
            continue # Skip to next dataset

        # Multiple Runs Loop
        # Each all_runs_* list collects one inner list per run, holding one
        # averaged float per epoch.
        all_runs_train_total_loss_hist = []
        all_runs_train_nonneg_loss_hist = []
        all_runs_train_convex_loss_hist = []
        all_runs_train_lips_loss_hist = []
        all_runs_test_loss_hist = [] # Stores task loss on test set
        all_runs_test_acc_hist = []
        run_times = []
        final_loss_network_state_last_run = None  # For visualization
        final_classifier_state_last_run = None    # For final metrics/plot

        for run in range(NUM_RUNS):
            print(f"\n--- Starting Run {run + 1}/{NUM_RUNS} for {dataset_key.upper()} ---")
            start_time_run = time.time()

            # Build new models and optimizers (each run starts from fresh weights)
            classifier = Classifier(input_shape, num_classes).to(DEVICE)
            loss_network = LossNetwork(num_classes).to(DEVICE)
            optimizer_classifier = optim.Adam(classifier.parameters(), lr=LEARNING_RATE_CLASSIFIER)
            optimizer_loss_network = optim.Adam(loss_network.parameters(), lr=LEARNING_RATE_LOSS_NETWORK)

            # Task loss: BCE between (clamped) class probabilities and one-hot targets.
            # Despite the name, it is used both as the training task term
            # (L_task below) and for test-set evaluation.
            # BinaryCrossEntropy is suitable here as p is softmax and y is one-hot
            # NOTE(review): "p is softmax" is assumed about Classifier's output -- confirm.
            eval_task_loss_fn = nn.BCELoss()

            run_train_total_losses = []
            run_train_nonneg_losses = []
            run_train_convex_losses = []
            run_train_lips_losses = []
            run_test_losses = []
            run_test_accuracies = []

            for epoch in range(EPOCHS):
                epoch_start_time = time.time()
                # --- Training Phase ---
                classifier.train()
                loss_network.train()
                epoch_train_total_loss = 0.0
                epoch_train_nonneg_loss = 0.0
                epoch_train_convex_loss = 0.0
                epoch_train_lips_loss = 0.0
                train_batches = 0

                for x_batch, y_batch_labels in train_loader:
                    x_batch = x_batch.to(DEVICE)
                    # Labels are integers, convert to one-hot for loss network input & BCE task loss
                    # (encoded with num_classes columns before moving to DEVICE).
                    y_one_hot = F.one_hot(y_batch_labels, num_classes=num_classes).float().to(DEVICE)

                    # Zero gradients
                    optimizer_classifier.zero_grad()
                    optimizer_loss_network.zero_grad()

                    # Forward pass classifier
                    p = classifier(x_batch)
                    p = torch.clamp(p, 1e-7, 1.0 - 1e-7) # Clamp for numerical stability with BCE/log

                    # Prepare input for loss network: [p | y_one_hot] along the class axis.
                    # p is NOT detached, so anything computed from E also backpropagates
                    # into the classifier.
                    loss_net_input = torch.cat((p, y_one_hot), dim=1)
                    # Ensure input requires grad for gradient calculation w.r.t. itself
                    loss_net_input.requires_grad_(True)

                    # Forward pass loss network
                    E = loss_network(loss_net_input) # Energy output

                    # Calculate grad_E w.r.t loss_net_input.
                    # E is summed to get a scalar for autograd.grad; presumably E holds one
                    # energy per sample, so each row of the gradient belongs to its own
                    # sample -- confirm against LossNetwork.
                    # create_graph=True needed for Hessian calculation later (double backward)
                    grad_E_full = torch.autograd.grad(E.sum(), loss_net_input, create_graph=True)[0]
                    grad_E = grad_E_full[:, :num_classes] # Extract gradient w.r.t p (first num_classes columns)

                    # Calculate Hessian trace (trace(d(grad_E)/dp)) via get_hessian_trace,
                    # which is defined elsewhere; presumably one trace value per sample.
                    # Ensure p requires grad for this step
                    if not p.requires_grad:
                         p.requires_grad_(True) # Should be set by default if params require grad
                    hessian_trace = get_hessian_trace(grad_E, p)


                    # Calculate loss components
                    L_task = eval_task_loss_fn(p, y_one_hot) # Task loss (BCE on probs/one-hot)
                    # Penalize negative energies (squared hinge on -E).
                    L_nonneg = torch.mean(torch.square(F.relu(-E)))
                    # Penalize a negative Hessian trace (a soft convexity-in-p constraint).
                    L_convex = torch.mean(torch.square(F.relu(-hessian_trace)))
                    # Push the per-sample norm of dE/dp toward 1 (soft Lipschitz constraint).
                    grad_norm = torch.linalg.norm(grad_E, dim=1)
                    L_lips = torch.mean(torch.square(grad_norm - 1.0))
                    L_constraint = L_nonneg + L_convex + L_lips

                    # Total loss
                    # NOTE(review): the classifier is optimized on the BCE task loss; E
                    # reaches total_loss only through the constraint terms (which also flow
                    # into the classifier because p is not detached) -- confirm this is the
                    # intended objective.
                    total_loss = L_task + LAMBDA_CONST * L_constraint

                    # Backward pass (computes gradients for BOTH networks)
                    total_loss.backward()

                    # Optimize
                    optimizer_classifier.step()
                    optimizer_loss_network.step()

                    # Accumulate epoch losses (use .item() to detach from graph)
                    epoch_train_total_loss += total_loss.item()
                    epoch_train_nonneg_loss += L_nonneg.item()
                    epoch_train_convex_loss += L_convex.item()
                    epoch_train_lips_loss += L_lips.item()
                    train_batches += 1

                    # --- Explicitly delete tensors and clear cache if memory issues arise ---
                    # del loss_net_input, E, grad_E_full, grad_E, hessian_trace, total_loss
                    # if DEVICE == 'cuda': torch.cuda.empty_cache()


                # --- Evaluation Phase ---
                # Only the classifier is evaluated; the loss network is not used here.
                classifier.eval()
                loss_network.eval() # Not strictly needed if only classifier used, but good practice
                epoch_test_loss = 0.0
                correct_test = 0
                total_test = 0
                test_batches = 0
                with torch.no_grad():
                    for x_batch_test, y_batch_test_labels in test_loader:
                        x_batch_test = x_batch_test.to(DEVICE)
                        y_batch_test_labels = y_batch_test_labels.to(DEVICE)
                        # Convert labels to one-hot for consistency with eval task loss function
                        # (labels are already on DEVICE, so the trailing .to(DEVICE) is a no-op).
                        y_batch_test_one_hot = F.one_hot(y_batch_test_labels, num_classes=num_classes).float().to(DEVICE)

                        p_test = classifier(x_batch_test)
                        p_test = torch.clamp(p_test, 1e-7, 1.0 - 1e-7) # Clamp for BCE

                        test_loss = eval_task_loss_fn(p_test, y_batch_test_one_hot)
                        epoch_test_loss += test_loss.item()

                        # Calculate accuracy using integer labels (prediction = argmax over dim 1)
                        _, predicted = torch.max(p_test.data, 1)
                        total_test += y_batch_test_labels.size(0)
                        correct_test += (predicted == y_batch_test_labels).sum().item()
                        test_batches += 1

                # Calculate average losses and accuracy for the epoch
                # (losses are means of per-batch values; accuracy is over all test samples).
                avg_train_total_loss = epoch_train_total_loss / train_batches
                avg_train_nonneg_loss = epoch_train_nonneg_loss / train_batches
                avg_train_convex_loss = epoch_train_convex_loss / train_batches
                avg_train_lips_loss = epoch_train_lips_loss / train_batches
                avg_test_loss = epoch_test_loss / test_batches
                avg_test_acc = correct_test / total_test

                run_train_total_losses.append(avg_train_total_loss)
                run_train_nonneg_losses.append(avg_train_nonneg_loss)
                run_train_convex_losses.append(avg_train_convex_loss)
                run_train_lips_losses.append(avg_train_lips_loss)
                run_test_losses.append(avg_test_loss)
                run_test_accuracies.append(avg_test_acc)

                # Epoch time includes both the training and the evaluation phase.
                epoch_duration = time.time() - epoch_start_time
                print(f"  Epoch {epoch+1:03d}/{EPOCHS} | Train Total Loss: {avg_train_total_loss:.4f} "
                      f"| Test Acc: {avg_test_acc:.4%} | Test Task Loss: {avg_test_loss:.4f} | Time: {epoch_duration:.2f}s")

            # --- End of Run ---
            all_runs_train_total_loss_hist.append(run_train_total_losses)
            all_runs_train_nonneg_loss_hist.append(run_train_nonneg_losses)
            all_runs_train_convex_loss_hist.append(run_train_convex_losses)
            all_runs_train_lips_loss_hist.append(run_train_lips_losses)
            all_runs_test_loss_hist.append(run_test_losses)
            all_runs_test_acc_hist.append(run_test_accuracies)

            run_time = time.time() - start_time_run
            run_times.append(run_time)
            print(f"--- Run {run + 1} completed in {run_time:.2f} seconds. Final Test Accuracy: {run_test_accuracies[-1]:.4%} ---")
            if run == NUM_RUNS - 1: # Save the state of the last run's models
                # state_dict() holds references to the live parameter tensors, not
                # copies; that is safe here because these models are not trained again.
                final_loss_network_state_last_run = loss_network.state_dict()
                final_classifier_state_last_run = classifier.state_dict()

        # --- Aggregating and Plotting ---
        def calc_mean_std(history_list, max_len):
            """Return per-epoch (mean, std) across runs, each a length-max_len array.

            Shorter histories are padded by repeating their last value and longer
            ones are truncated to max_len. NaN entries are ignored by
            nanmean/nanstd. If history_list is empty, two all-NaN arrays are returned.
            """
            # Ensure all lists are numpy arrays before padding
            history_list_np = [np.array(hist) for hist in history_list]
            # Pad sequences to max_len (EPOCHS)
            padded_list = [np.pad(hist, (0, max_len - len(hist)), 'edge') if len(hist) < max_len else hist[:max_len] for hist in history_list_np]
            if not padded_list: # Handle case where no runs completed
                return np.full(max_len, np.nan), np.full(max_len, np.nan)
            mean_vals = np.nanmean(padded_list, axis=0) # Use nanmean for safety
            std_vals = np.nanstd(padded_list, axis=0)   # Use nanstd for safety
            return mean_vals, std_vals

        mean_train_total, std_train_total = calc_mean_std(all_runs_train_total_loss_hist, EPOCHS)
        mean_test_loss, std_test_loss = calc_mean_std(all_runs_test_loss_hist, EPOCHS)
        mean_test_acc, std_test_acc = calc_mean_std(all_runs_test_acc_hist, EPOCHS)
        mean_nonneg, std_nonneg = calc_mean_std(all_runs_train_nonneg_loss_hist, EPOCHS)
        mean_convex, std_convex = calc_mean_std(all_runs_train_convex_loss_hist, EPOCHS)
        mean_lips, std_lips = calc_mean_std(all_runs_train_lips_loss_hist, EPOCHS)

        # Print final averaged metrics (last epoch)
        print("\n=== Averaged Final Epoch Metrics (PyTorch) ===")
        # Check if metrics have valid values (not NaN) before printing
        if not np.isnan(mean_train_total[-1]): print(f"Train Total Loss: {mean_train_total[-1]:.4f} ± {std_train_total[-1]:.4f}")
        if not np.isnan(mean_test_loss[-1]): print(f"Test Task Loss:   {mean_test_loss[-1]:.4f} ± {std_test_loss[-1]:.4f}") # Clarified label
        if not np.isnan(mean_test_acc[-1]): print(f"Test Accuracy:    {mean_test_acc[-1]*100:.2f}% ± {std_test_acc[-1]*100:.2f}%")
        if not np.isnan(mean_nonneg[-1]): print(f"Non-Negativity Loss:{mean_nonneg[-1]:.4f} ± {std_nonneg[-1]:.4f}")
        if not np.isnan(mean_convex[-1]): print(f"Convexity Loss:   {mean_convex[-1]:.4f} ± {std_convex[-1]:.4f}")
        if not np.isnan(mean_lips[-1]): print(f"Lipschitz Loss:   {mean_lips[-1]:.4f} ± {std_lips[-1]:.4f}")
        print(f"Average Run Time: {np.mean(run_times):.2f} seconds")

        # Detailed plots (using last run's final loss network state)
        if final_loss_network_state_last_run is not None:
            plot_detailed_learnable_loss_metrics(
                dataset_name=dataset_key, epochs=EPOCHS, num_runs=NUM_RUNS, save_dir=PLOT_SAVE_DIR,
                num_classes=num_classes,
                mean_train_total_loss=mean_train_total, std_train_total_loss=std_train_total,
                mean_test_loss=mean_test_loss, std_test_loss=std_test_loss,
                mean_test_acc=mean_test_acc, std_test_acc=std_test_acc,
                mean_nonneg_loss=mean_nonneg, mean_convex_loss=mean_convex, mean_lips_loss=mean_lips,
                final_loss_network_state=final_loss_network_state_last_run
            )
        else:
            print("Warning: Could not generate detailed plot as final loss network state was not available.")


        # Compute additional classification metrics and F1 plot using the final classifier from the last run
        classification_report_str = "N/A"
        report_dict = {}
        if final_classifier_state_last_run is not None:
            print("\n=== Generating Final Classification Report & F1 Plot (Last Run) ===")
            # Rebuild a fresh Classifier and load the last run's saved weights into it.
            final_classifier = Classifier(input_shape, num_classes).to(DEVICE)
            final_classifier.load_state_dict(final_classifier_state_last_run)
            final_classifier.eval()

            # Collect argmax predictions and integer labels over the whole test set.
            all_preds = []
            all_labels = []
            with torch.no_grad():
                for x_batch, y_batch_labels in test_loader:
                    x_batch = x_batch.to(DEVICE)
                    outputs = final_classifier(x_batch)
                    _, predicted = torch.max(outputs.data, 1)
                    all_preds.extend(predicted.cpu().numpy())
                    all_labels.extend(y_batch_labels.cpu().numpy())

            print("\nClassification Metrics (sklearn):")
            # Use zero_division=0 to avoid warnings/errors for classes with no support
            report_dict = classification_report(all_labels, all_preds, digits=4, output_dict=True, zero_division=0)
            classification_report_str = classification_report(all_labels, all_preds, digits=4, zero_division=0)
            print(classification_report_str)

            # Generate and save F1 score plot
            plot_f1_scores(dataset_key, PLOT_SAVE_DIR, report_dict, num_classes)

        else:
            print("Warning: No final classifier state available for computing final metrics or F1 plot.")

        end_time_dataset = time.time()
        print(f"--- Completed processing {dataset_key.upper()} in {end_time_dataset - start_time_dataset:.2f} seconds ---")

    print(f"\n{'='*20} All Dataset Processing Complete (PyTorch) {'='*20}")

Using device: cuda
Plots will be saved in 'plots_learnable_loss_detailed_pytorch/' directory.

Loading and preprocessing data...
Data loading complete. Train samples: 60000, Test samples: 10000

--- Starting Run 1/3 for MNIST ---
  Epoch 001/20 | Train Total Loss: 0.0607 | Test Acc: 97.7100% | Test Task Loss: 0.0131 | Time: 6.42s
  Epoch 002/20 | Train Total Loss: 0.0146 | Test Acc: 98.2500% | Test Task Loss: 0.0095 | Time: 5.61s
  Epoch 003/20 | Train Total Loss: 0.0099 | Test Acc: 98.8300% | Test Task Loss: 0.0060 | Time: 5.65s
  Epoch 004/20 | Train Total Loss: 0.0078 | Test Acc: 99.2000% | Test Task Loss: 0.0045 | Time: 5.71s
  Epoch 005/20 | Train Total Loss: 0.0063 | Test Acc: 98.8200% | Test Task Loss: 0.0070 | Time: 5.85s
  Epoch 006/20 | Train Total Loss: 0.0070 | Test Acc: 99.1900% | Test Task Loss: 0.0050 | Time: 5.68s
  Epoch 007/20 | Train Total Loss: 0.0059 | Test Acc: 98.9300% | Test Task Loss: 0.0058 | Time: 5.73s
  Epoch 008/20 | Train Total Loss: 0.0051 | Test Acc: 99