In [None]:
!pip install -q lion-pytorch

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from lion_pytorch import Lion
import matplotlib.pyplot as plt
import time
import pandas as pd
from IPython.display import display
from torch.amp import autocast, GradScaler
import numpy as np
from model_utils import SimpleCNN, train_and_evaluate


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using:", device)

# Shared preprocessing: tensor conversion followed by normalisation with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Training split: shuffled batches of 128.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

# Held-out split: fixed order, batches of 100.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
test_loader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

def train_and_time(optimizer_name, num_epochs=30, target_acc=70.0, model_name="simplecnn"):
    """Run ``train_and_evaluate`` once and measure how long the call takes.

    The module-level ``trainloader`` is passed as the training loader and the
    module-level ``test_loader`` as the validation loader.

    Args:
        optimizer_name: Optimizer identifier forwarded to ``train_and_evaluate``.
        num_epochs: Number of epochs forwarded to ``train_and_evaluate``.
        target_acc: Target accuracy forwarded to ``train_and_evaluate``.
        model_name: Architecture identifier forwarded to ``train_and_evaluate``.

    Returns:
        A ``(results, elapsed)`` tuple: whatever ``train_and_evaluate``
        returned, and the elapsed time of the call in seconds.
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments during a long training run.
    start_time = time.perf_counter()
    results = train_and_evaluate(
        optimizer_name,
        num_epochs=num_epochs,
        target_acc=target_acc,
        train_loader=trainloader,
        val_loader=test_loader,
        model_name=model_name,
    )
    elapsed = time.perf_counter() - start_time
    return results, elapsed


def plot_metric(metric_name, results_dict, title):
    """Draw one curve per run for the per-epoch series ``metric_name``.

    Args:
        metric_name: Key looked up in every run's result mapping; also used,
            with underscores replaced by spaces and title-cased, as the
            y-axis label.
        results_dict: Mapping of run label -> result mapping.
        title: Figure title.
    """
    _, axes = plt.subplots(figsize=(10, 5))
    for run_label, run in results_dict.items():
        axes.plot(run[metric_name], label=run_label.upper())
    axes.set_xlabel("Epoch")
    axes.set_ylabel(metric_name.replace("_", " ").title())
    axes.set_title(title)
    axes.legend()
    axes.grid(True)
    plt.show()

def plot_times(results_dict):
    """Draw one curve per run from each run's ``"epoch_times"`` series.

    Args:
        results_dict: Mapping of run label -> result mapping containing an
            ``"epoch_times"`` sequence.
    """
    _, axes = plt.subplots(figsize=(10, 5))
    for run_label, run in results_dict.items():
        axes.plot(run["epoch_times"], label=run_label.upper())
    axes.set_xlabel("Epoch")
    axes.set_ylabel("Time (seconds)")
    axes.set_title("Epochs vs Learning Time")
    axes.legend()
    axes.grid(True)
    plt.show()

# Train the SimpleCNN architecture once per optimizer (30 epochs each).
# Each run is bound to its own pair of names, so results from completed runs
# stay available even if a later run fails.
print("\n==== ARCHITECTURE: SIMPLECNN ====")
results_adam, time_adam = train_and_time("adam", num_epochs=30, model_name="simplecnn")
results_sgd, time_sgd = train_and_time("sgd", num_epochs=30, model_name="simplecnn")
results_lion, time_lion = train_and_time("lion", num_epochs=30, model_name="simplecnn")

# Same three optimizers, this time with the ResNet18 architecture.
print("\n==== ARCHITECTURE: RESNET18 ====")
results_resnet_adam, time_resnet_adam = train_and_time("adam", num_epochs=30, model_name="resnet18")
results_resnet_sgd, time_resnet_sgd = train_and_time("sgd", num_epochs=30, model_name="resnet18")
results_resnet_lion, time_resnet_lion = train_and_time("lion", num_epochs=30, model_name="resnet18")

def epochs_to_threshold(results, threshold=70.0):
    """Return the 1-based epoch at which train accuracy first meets ``threshold``.

    Args:
        results: Mapping with a per-epoch ``"train_acc"`` sequence.
        threshold: Accuracy value an epoch must reach or exceed.

    Returns:
        The 1-based position of the first qualifying entry, or ``None`` when
        no entry reaches the threshold.
    """
    qualifying = (
        position
        for position, accuracy in enumerate(results["train_acc"], start=1)
        if accuracy >= threshold
    )
    return next(qualifying, None)

# Group the finished runs by architecture; each key is "<optimizer> (<arch>)".
all_results_simplecnn = {
    f"{optimizer} (SimpleCNN)": history
    for optimizer, history in (
        ("adam", results_adam),
        ("sgd", results_sgd),
        ("lion", results_lion),
    )
}
all_results_resnet = {
    f"{optimizer} (ResNet)": history
    for optimizer, history in (
        ("adam", results_resnet_adam),
        ("sgd", results_resnet_sgd),
        ("lion", results_resnet_lion),
    )
}


# SimpleCNN report: first epoch at which train accuracy hit 70%, the
# training curves, then the final test metrics of every run.
print("Epochs to reach 70% train accuracy:")
for name, res in all_results_simplecnn.items():
    ep = epochs_to_threshold(res, threshold=70.0)
    print(f"{name.upper()}: reached in {ep} epochs" if ep else f"{name.upper()}: not reached")

for metric, title in (
    ("train_loss", "Train Loss Comparison"),
    ("train_acc", "Train Accuracy Comparison"),
):
    plot_metric(metric, all_results_simplecnn, title)
plot_times(all_results_simplecnn)

for name, res in all_results_simplecnn.items():
    print(
        f"{name.upper()} - Final Test Loss: {res['test_loss']:.4f}, Test Accuracy: {res['test_acc']:.2f}%, "
        f"Epochs to reach 70% acc: {res['epochs_to_target_acc']}"
    )


# ResNet report: first epoch at which train accuracy hit 70%, the training
# curves, then the final test metrics of every run.
print("Epochs to reach 70% train accuracy:")
for name, res in all_results_resnet.items():
    ep = epochs_to_threshold(res, threshold=70.0)
    print(f"{name.upper()}: reached in {ep} epochs" if ep else f"{name.upper()}: not reached")

for metric, title in (
    ("train_loss", "Train Loss Comparison"),
    ("train_acc", "Train Accuracy Comparison"),
):
    plot_metric(metric, all_results_resnet, title)
plot_times(all_results_resnet)

for name, res in all_results_resnet.items():
    print(
        f"{name.upper()} - Final Test Loss: {res['test_loss']:.4f}, Test Accuracy: {res['test_acc']:.2f}%, "
        f"Epochs to reach 70% acc: {res['epochs_to_target_acc']}"
    )


def print_summary(optimizer_name, results, training_time, threshold=70.0):
    """Print a short text report for one training run.

    Args:
        optimizer_name: Label of the run; printed upper-cased in the header.
        results: Mapping with a scalar ``'test_acc'`` and a per-epoch
            ``'train_acc'`` sequence.
        training_time: Elapsed training time in seconds.
        threshold: Accuracy level passed to ``epochs_to_threshold``, which
            scans the ``'train_acc'`` history.
    """
    print(f"--- {optimizer_name.upper()} Summary ---")
    # The final test accuracy is reported once, directly under the header.
    print(f"Final Test Accuracy: {results['test_acc']:.2f}%")
    epoch_thresh = epochs_to_threshold(results, threshold)
    if epoch_thresh is not None:
        print(f"Epoch to reach {threshold}% accuracy: {epoch_thresh}")
    else:
        # 'test_acc' is a scalar, so the number of trained epochs has to come
        # from the per-epoch 'train_acc' history; len() on the scalar raises
        # TypeError.
        print(f"Did NOT reach {threshold}% accuracy in {len(results['train_acc'])} epochs")
    print(f"Training time (seconds): {training_time:.2f}")

# One text summary per run, ordered by optimizer and then by architecture.
for label, history, elapsed in (
    ("adam (SimpleCNN)", results_adam, time_adam),
    ("adam (ResNet)", results_resnet_adam, time_resnet_adam),
    ("sgd (SimpleCNN)", results_sgd, time_sgd),
    ("sgd (ResNet)", results_resnet_sgd, time_resnet_sgd),
    ("lion (SimpleCNN)", results_lion, time_lion),
    ("lion (ResNet)", results_resnet_lion, time_resnet_lion),
):
    print_summary(label, history, elapsed)


def extract_metrics_at_epochs(results_dict, optimizer_name, epochs=(10, 20, 30)):
    """Build table rows with the train accuracy/loss at selected epochs.

    Args:
        results_dict: Mapping with per-epoch ``'train_acc'`` and
            ``'train_loss'`` sequences.
        optimizer_name: Value stored in every row's ``"Optimizer"`` column.
        epochs: 1-based epoch numbers to report. Epochs outside the recorded
            history (smaller than 1, or larger than the number of recorded
            epochs) are skipped.

    Returns:
        A list of dicts with the keys ``"Epoch"``, ``"Optimizer"``,
        ``"Train Accuracy"`` (formatted ``"xx.xx%"``) and ``"Train Loss"``
        (formatted with four decimals), in the order of ``epochs``.
    """
    accuracies = results_dict['train_acc']
    losses = results_dict['train_loss']
    recorded = min(len(accuracies), len(losses))
    return [
        {
            "Epoch": epoch,
            "Optimizer": optimizer_name,
            "Train Accuracy": f"{accuracies[epoch - 1]:.2f}%",
            "Train Loss": f"{losses[epoch - 1]:.4f}",
        }
        for epoch in epochs
        # Without this guard epoch 0 would wrap around to the last entry and
        # an epoch beyond the history would raise IndexError.
        if 1 <= epoch <= recorded
    ]

# Collect the per-epoch rows of the three SimpleCNN runs into one table.
table_data = [
    row
    for history, label in (
        (results_adam, "Adam"),
        (results_sgd, "SGD"),
        (results_lion, "Lion"),
    )
    for row in extract_metrics_at_epochs(history, label)
]

df_comparison = pd.DataFrame(table_data)

def color_alternating_rows(row):
    """Return one background-colour CSS string per cell of ``row``.

    The colour is picked from a three-entry palette using ``row.name``
    modulo 3, so consecutive integer row labels cycle through the palette.
    """
    palette = ('white', '#fff59d', '#bbdefb')
    shade = palette[row.name % 3]
    return [f'background-color: {shade}' for _ in range(len(row))]

# Show the comparison table; color_alternating_rows is applied row by row
# (axis=1) to produce the cell background colours.
display(df_comparison.style.apply(color_alternating_rows, axis=1))


def summarize_result(name, results, training_time):
    """Condense one training run into a row of pre-formatted strings.

    Args:
        name: Label stored in the ``"Model (Opt)"`` column.
        results: Mapping with ``"train_loss"`` (per-epoch sequence),
            ``"epochs_to_target_acc"``, ``"test_acc"`` and ``"test_loss"``.
        training_time: Elapsed training time in seconds.

    Returns:
        A dict with the summary columns; the numeric values are formatted as
        strings, except ``"Epochs to 70%"`` which is passed through as-is.
    """
    # Spread of the training loss over the last five recorded epochs.
    tail_spread = np.std(results["train_loss"][-5:])

    row = {"Model (Opt)": name}
    row["Time (s)"] = f"{training_time:.1f}"
    row["Epochs to 70%"] = results["epochs_to_target_acc"]
    row["Final Acc (%)"] = f"{results['test_acc']:.2f}"
    row["Final Loss"] = f"{results['test_loss']:.4f}"
    row["Loss Std (Last 5)"] = f"{tail_spread:.4f}"
    return row

# One formatted summary row per (architecture, optimizer) run.
summary_rows = [
    summarize_result(label, history, elapsed)
    for label, history, elapsed in (
        ("SimpleCNN (Adam)", results_adam, time_adam),
        ("SimpleCNN (SGD)", results_sgd, time_sgd),
        ("SimpleCNN (Lion)", results_lion, time_lion),
        ("ResNet18 (Adam)", results_resnet_adam, time_resnet_adam),
        ("ResNet18 (SGD)", results_resnet_sgd, time_resnet_sgd),
        ("ResNet18 (Lion)", results_resnet_lion, time_resnet_lion),
    )
]

df_summary = pd.DataFrame(summary_rows)
display(df_summary)



# Numeric summary rows, one per (architecture, optimizer) run.
# Every column of a row is derived from that row's own result mapping, so the
# ResNet18 rows report the loss spread of the ResNet18 runs rather than the
# SimpleCNN runs.
summary_data = [
    {
        "Model (Opt)": label,
        "Time (s)": round(elapsed, 1),
        "Epochs to 70%": run["epochs_to_target_acc"],
        "Final Acc (%)": round(run["test_acc"], 2),
        "Final Loss": round(run["test_loss"], 4),
        # Standard deviation of the training loss over the last five epochs.
        "Loss Std (Last 5)": round(np.std(run["train_loss"][-5:]), 4),
    }
    for label, run, elapsed in (
        ("SimpleCNN (Adam)", results_adam, time_adam),
        ("SimpleCNN (SGD)", results_sgd, time_sgd),
        ("SimpleCNN (Lion)", results_lion, time_lion),
        ("ResNet18 (Adam)", results_resnet_adam, time_resnet_adam),
        ("ResNet18 (SGD)", results_resnet_sgd, time_resnet_sgd),
        ("ResNet18 (Lion)", results_resnet_lion, time_resnet_lion),
    )
]


# Rebind df_summary to a frame built from the numeric summary_data rows.
df_summary = pd.DataFrame(summary_data)

print("### Summary table by optimizers and architectures:")
display(df_summary)
