In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import copy
import math
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import levy_stable
from torch.optim.optimizer import Optimizer
import torch.optim as optim

# Hyperparameters
LEARNING_RATE = 0.01  # passed as `lr` to the optimizers built in the main training section
CLIP = 5  # symmetric bound: sampled Levy-stable noise is clipped to [-CLIP, CLIP]
NOISE_INIT = 0.005  # default / configured `noise_init` of the noisy optimizers
NOISE_DECAY = 0.55  # exponent of the (1 + num_steps) noise-scale decay
EPOCHS = 10  # number of iterations of the main training loop

# =====================================
# 1. Optimizer Definitions
# =====================================

class AnnealisingOptimiser(Optimizer):
    """
    Gradient descent with additive, clipped, symmetric alpha-stable noise.

    The update is fixed and non-adaptive:
        p <- p - lr * grad + lr ** (1 / alpha) * noise
    where ``noise`` is drawn from ``scipy.stats.levy_stable`` and clipped to
    ``[-CLIP, CLIP]``.  Two schedules are driven by the step counter:
        alpha = 2 - exp(-k * num_steps)
        sigma = sqrt(noise_init) / (1 + num_steps) ** (noise_decay / 2)
    i.e. the noise variance decays per step.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate (per parameter group).
        alpha: value exposed as ``self.alpha`` until the first ``step``; the
            schedule above overwrites it on every step.
        rho: kept in the group defaults and as ``self.rho``; the update does
            not read it.
        decay: kept as ``self.decay``; the update does not read it.
        noise_init: initial noise variance.
        noise_decay: decay exponent of the noise variance.
        weight_decay: L2 coefficient added to the gradient before the update.
        k: rate of the alpha schedule.
        record_noise: when True (default) every step appends the full noise
            vector to ``self.noise_samples``.  That list grows by one array of
            ``total parameter count`` float32 values per step, so disable it
            for long runs if the history is not needed.
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 alpha=1.2,
                 rho=0.05,
                 decay=0.95,
                 noise_init=NOISE_INIT,
                 noise_decay=NOISE_DECAY,
                 weight_decay=0,
                 k=1/5,
                 record_noise=True,
                 **kwargs):
        defaults = dict(lr=lr, alpha=alpha, rho=rho, weight_decay=weight_decay)
        super(AnnealisingOptimiser, self).__init__(params, defaults)
        # Expose alpha and rho as attributes as well as group defaults, so
        # `optimizer.alpha` / `optimizer.rho` can be read at any time (callers
        # in this notebook read `optimizer.rho`).
        self.alpha = alpha
        self.rho = rho
        self.decay = decay
        self.noise_init = noise_init
        self.noise_decay = noise_decay
        self.beta = 0  # symmetric noise
        self.num_steps = 1
        self.current_sigma = noise_init
        self.record_noise = record_noise
        self.noise_samples = []
        self.epoch = 0
        self.k = k

    def update_epoch(self):
        """Increment the epoch counter (bookkeeping only)."""
        self.epoch += 1

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one noisy update.

        Args:
            closure: required callable that zeroes gradients, computes the
                loss, calls ``backward()`` and returns the loss.  ``step``
                itself never calls ``backward``.

        Returns:
            The loss as a Python float (``loss.item()``).

        Raises:
            ValueError: if ``closure`` is None.
        """
        if closure is None:
            raise ValueError("Requires closure for loss computation")
        # step() runs under no_grad, so gradients must be re-enabled for the closure.
        closure = torch.enable_grad()(closure)
        loss = closure()

        # Update alpha and noise scale
        self.alpha = 2 - math.exp(-self.k * self.num_steps)
        self.current_sigma = np.sqrt(self.noise_init) / ((1 + self.num_steps) ** (self.noise_decay/2))

        # Draw one flat noise vector for all parameters, then slice it per tensor.
        total_elems = sum(p.numel() for group in self.param_groups for p in group['params'])
        big_noise = levy_stable.rvs(self.alpha, self.beta, scale=self.current_sigma, size=total_elems)
        np.clip(big_noise, -CLIP, CLIP, out=big_noise)
        device = self.param_groups[0]['params'][0].device
        big_noise = torch.from_numpy(big_noise).float().to(device=device)
        if self.record_noise:
            self.noise_samples.append(big_noise.detach().cpu().numpy())

        idx_start = 0
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                # Parameters without a gradient are skipped and consume no noise.
                if p.grad is None:
                    continue
                if group['weight_decay'] != 0:
                    p.grad.data.add_(p.data, alpha=group['weight_decay'])
                elem_count = p.numel()
                noise_slice = big_noise[idx_start:idx_start + elem_count].view_as(p.data)
                idx_start += elem_count
                # Noise enters with step size lr ** (1 / alpha) rather than lr.
                noise_coeff = (lr ** (1 / self.alpha)) * noise_slice
                p.data.add_(p.grad.data, alpha=-lr)
                p.data.add_(noise_coeff)
        self.num_steps += 1
        return loss.item()

class ImprovedGaussianOptimizer(Optimizer):
    """
    An optimizer that applies Gaussian noise (mean=0) to the gradient update.

    Each step performs ``p <- p - lr * (grad + noise)`` with
    ``noise ~ N(0, sigma^2)`` and
    ``sigma = noise_init / (1 + num_steps) ** noise_decay``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate (per parameter group).
        noise_init: initial noise standard deviation.
        noise_decay: decay exponent of the noise standard deviation.
        weight_decay: L2 coefficient added to the gradient before the update.
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 noise_init=0.01,
                 noise_decay=NOISE_DECAY,
                 weight_decay=0,
                 **kwargs):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super(ImprovedGaussianOptimizer, self).__init__(params, defaults)
        self.noise_init = noise_init
        self.noise_decay = noise_decay
        self.num_steps = 1
        self.epoch = 0
        self.current_sigma = noise_init

    def update_epoch(self):
        """Increment the epoch counter (bookkeeping only)."""
        self.epoch += 1

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one noisy update.

        Args:
            closure: required callable that zeroes gradients, computes the
                loss, calls ``backward()`` and returns the loss.

        Returns:
            The loss as a Python float (``loss.item()``).

        Raises:
            ValueError: if ``closure`` is None.
        """
        if closure is None:
            raise ValueError("Requires closure for loss computation")
        # step() runs under no_grad, so gradients must be re-enabled for the closure.
        closure = torch.enable_grad()(closure)
        loss = closure()
        self.current_sigma = self.noise_init / ((1 + self.num_steps) ** self.noise_decay)
        # One flat noise vector for all parameters, sliced per tensor below.
        total_elems = sum(p.numel() for group in self.param_groups for p in group['params'])
        device = self.param_groups[0]['params'][0].device
        noise = torch.randn(total_elems, device=device) * self.current_sigma
        idx_start = 0
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                # Honour the weight_decay option, as the other noisy optimizers
                # in this file do.
                if group['weight_decay'] != 0:
                    p.grad.data.add_(p.data, alpha=group['weight_decay'])
                elem_count = p.numel()
                noise_slice = noise[idx_start:idx_start + elem_count].view_as(p.data)
                idx_start += elem_count
                p.data.add_(p.grad.data + noise_slice, alpha=-lr)
        self.num_steps += 1
        return loss.item()

class AdaptiveAnnealisingOptimizer(Optimizer):
    """
    An optimizer that adapts its alpha parameter using a logistic mapping from the EMA of the Hessian trace (sharpness)
    to the interval [alpha_min, alpha_max].

    Per step it (1) estimates the Hessian trace with Hutchinson probes,
    (2) records ``log1p(|trace|)`` as the sharpness and updates ``alpha``,
    (3) draws clipped symmetric alpha-stable noise and (4) applies a gradient
    step plus that noise.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate (per parameter group).
        momentum: momentum factor; 0 disables the momentum buffer.
        alpha: initial tail index (replaced by ``alpha_min`` on the first step).
        alpha_min, alpha_max: range of the logistic mapping.
        window_size: number of recent EMA values averaged once enough history exists.
        noise_init: initial noise variance.
        noise_decay: decay exponent of the noise variance.
        weight_decay: L2 coefficient added to the gradient before the update.
        num_hutchinson_samples: number of Rademacher probes per trace estimate (>= 1).
        logistic_scale, logistic_center: slope and midpoint of the logistic mapping.
        record_noise: when True (default) every step appends the full float64
            noise vector to ``self.noise_samples``; disable for long runs if
            the history is not needed, since the list grows without bound.

    Raises:
        ValueError: if ``num_hutchinson_samples`` is smaller than 1.
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 momentum=0.0,
                 alpha=1.0,
                 alpha_min=1.0,
                 alpha_max=2.0,
                 window_size=20,
                 noise_init=0.001,
                 noise_decay=0.55,
                 weight_decay=0,
                 num_hutchinson_samples=1,
                 logistic_scale=0.5,
                 logistic_center=3.0,
                 record_noise=True):
        # Fail early instead of dividing by zero inside compute_hessian_trace.
        if num_hutchinson_samples < 1:
            raise ValueError("num_hutchinson_samples must be at least 1")
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super(AdaptiveAnnealisingOptimizer, self).__init__(params, defaults)
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max
        self.window_size = window_size
        self.noise_init = noise_init
        self.noise_decay = noise_decay
        self.num_hutchinson_samples = num_hutchinson_samples
        self.logistic_scale = logistic_scale
        self.logistic_center = logistic_center
        self.alpha = alpha
        self.beta = 0  # symmetric noise
        self.num_steps = 1
        self.current_sigma = noise_init
        self.record_noise = record_noise
        self.noise_samples = []
        self.sharpness_list = []
        self.ema_list = []
        self.alpha_list = []
        self.epoch = 0
        self.ema_sharpness = None

    def update_epoch(self):
        """Increment the epoch counter (bookkeeping only)."""
        self.epoch += 1

    def compute_hessian_trace(self, loss, params):
        """Hutchinson estimate of the trace of the Hessian of ``loss`` w.r.t. ``params``.

        Averages ``r^T H r`` over ``num_hutchinson_samples`` random +/-1 probe
        vectors ``r``.  The graph of ``loss`` is retained so the caller can
        still call ``loss.backward()`` afterwards.

        Returns:
            A scalar tensor with the trace estimate.
        """
        # The first-order gradients do not depend on the probe, so build them once.
        grads = torch.autograd.grad(loss, params, create_graph=True, retain_graph=True)
        trace_est = 0.0
        for _ in range(self.num_hutchinson_samples):
            r = [torch.randint(0, 2, p.shape, device=p.device).float()*2 - 1 for p in params]
            dot = sum((g * r_i).sum() for g, r_i in zip(grads, r))
            # Gradient of (grad . r) w.r.t. params is the Hessian-vector product H r.
            hv = torch.autograd.grad(dot, params, retain_graph=True)
            trace_est += sum((r_i * hv_i).sum() for r_i, hv_i in zip(r, hv))
        trace_est /= self.num_hutchinson_samples
        return trace_est

    def adapt_alpha(self, sharpness):
        """Update the sharpness EMA and move ``alpha`` towards its logistic target.

        The first call only seeds the EMA and sets ``alpha = alpha_min``.
        Later calls map the (windowed) EMA through a logistic curve onto
        ``[alpha_min, alpha_max]`` and move ``alpha`` 10% of the way there.
        Both ``ema_list`` and ``alpha_list`` gain one entry per call.
        """
        alpha_ema = 0.1
        alpha_smoothing = 0.1
        if self.ema_sharpness is None:
            self.ema_sharpness = sharpness
            self.ema_list.append(self.ema_sharpness)
            self.alpha = self.alpha_min
            self.alpha_list.append(self.alpha)
            return
        self.ema_sharpness = (1 - alpha_ema)*self.ema_sharpness + alpha_ema*sharpness
        self.ema_list.append(self.ema_sharpness)
        if len(self.ema_list) < self.window_size:
            mean_recent = self.ema_sharpness
        else:
            mean_recent = np.mean(self.ema_list[-self.window_size:])
        A = self.alpha_min
        B = self.alpha_max
        k = self.logistic_scale
        c = self.logistic_center
        logistic_val = 1.0 / (1.0 + math.exp(-k*(mean_recent-c)))
        new_alpha_raw = A + (B-A)*logistic_val
        self.alpha = self.alpha + alpha_smoothing*(new_alpha_raw - self.alpha)
        self.alpha_list.append(self.alpha)

    def step(self, closure=None):
        """Perform one adaptive noisy update.

        Args:
            closure: required callable that zeroes gradients and returns the
                loss WITHOUT calling ``backward()``; this method needs the
                intact graph for the Hessian estimate and calls ``backward``
                itself.

        Returns:
            The loss as a Python float (``loss.item()``).

        Raises:
            ValueError: if ``closure`` is None.
        """
        if closure is None:
            raise ValueError("AdaptiveAnnealisingOptimizer requires a closure.")
        with torch.enable_grad():
            loss = closure()
            params = [p for group in self.param_groups for p in group['params'] if p.requires_grad]
            hessian_trace = self.compute_hessian_trace(loss, params)
            # log1p(|trace|) compresses the range of the sharpness signal.
            sharpness_value = torch.log1p(torch.abs(hessian_trace)).item()
            self.sharpness_list.append(sharpness_value)
            self.adapt_alpha(sharpness_value)
            loss.backward()
        with torch.no_grad():
            self.current_sigma = np.sqrt(self.noise_init) / ((1 + self.num_steps)**(self.noise_decay/2))
            total_elems = sum(p.numel() for group in self.param_groups for p in group['params'])
            big_noise = levy_stable.rvs(self.alpha, self.beta, scale=self.current_sigma, size=total_elems)
            np.clip(big_noise, -CLIP, CLIP, out=big_noise)
            if self.record_noise:
                self.noise_samples.append(big_noise)
            device = self.param_groups[0]['params'][0].device
            big_noise = torch.from_numpy(big_noise).float().to(device=device)
            idx_start = 0
            for group in self.param_groups:
                lr = group['lr']
                momentum = group['momentum']
                for p in group['params']:
                    # Parameters without a gradient are skipped and consume no noise.
                    if p.grad is None:
                        continue
                    if group['weight_decay'] != 0:
                        p.grad.data.add_(p.data, alpha=group['weight_decay'])
                    elem_count = p.numel()
                    noise_slice = big_noise[idx_start:idx_start+elem_count].view_as(p.grad.data)
                    idx_start += elem_count
                    noise_coeff = (lr**(1/self.alpha)) * noise_slice
                    if momentum != 0:
                        state = self.state[p]
                        if 'momentum_buffer' not in state:
                            buf = state['momentum_buffer'] = torch.zeros_like(p.data)
                        else:
                            buf = state['momentum_buffer']
                        # With momentum the noise is accumulated in the buffer
                        # together with the gradient and scaled by lr again.
                        buf.mul_(momentum).add_(p.grad.data + noise_coeff)
                        p.data.add_(buf, alpha=-lr)
                    else:
                        p.data.add_(p.grad.data, alpha=-lr)
                        p.data.add_(noise_coeff)
            self.num_steps += 1
        return loss.item()

class SAM(Optimizer):
    """
    Implements Sharpness-Aware Minimization (SAM).
    This optimizer wraps a base optimizer (e.g. SGD) and perturbs the weights along the ascent direction.

    One ``step`` evaluates the closure twice: once at the current weights to
    find the ascent direction, and once at the perturbed weights to obtain the
    gradient that the base optimizer then applies at the restored weights.

    Args:
        params: iterable of parameters or parameter-group dicts.
        base_optimizer: optimizer class (not instance), e.g. ``torch.optim.SGD``.
        rho: radius of the perturbation.
        adaptive: if True, use the element-wise scaled (ASAM-style) perturbation.
        **kwargs: forwarded to ``base_optimizer`` (e.g. ``lr``).
    """
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share a single param_groups list and expose the base optimizer's
        # defaults, so both optimizers always see the same groups.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)
        self.rho = rho
        self.adaptive = adaptive

    @torch.no_grad()
    def _grad_norm(self):
        """Return the global L2 norm of the (optionally |w|-scaled) gradients."""
        shared_device = self.param_groups[0]["params"][0].device
        per_param_norms = [
            ((torch.abs(p) if self.adaptive else 1.0)*p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"] if p.grad is not None
        ]
        # torch.stack([]) raises; with no gradients the norm is simply zero.
        if not per_param_norms:
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(per_param_norms), p=2)

    @torch.no_grad()
    def first_step(self, zero_grad=True):
        """Move every parameter to ``w + e(w)`` and remember ``e(w)`` in the state."""
        grad_norm = self._grad_norm()
        # Identical for every group, so compute it once.
        scale = self.rho/(grad_norm+1e-12)
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Adaptive (ASAM) perturbation is T_w^2 * g * scale with T_w = |w|,
                # matching the |w| * g norm used in _grad_norm.
                e_w = (torch.pow(p, 2) if self.adaptive else 1.0)*p.grad*scale
                p.add_(e_w)
                self.state[p]["e_w"] = e_w
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=True):
        """Undo the perturbation, then let the base optimizer apply the current gradients."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def step(self, closure=None):
        """Perform one full SAM update.

        Args:
            closure: required callable that zeroes gradients and returns the
                loss WITHOUT calling ``backward()``; this method calls
                ``backward`` on both evaluations itself.

        Returns:
            The unperturbed loss as a Python float (``loss.item()``).

        Raises:
            ValueError: if ``closure`` is None.
        """
        if closure is None:
            raise ValueError("SAM requires a closure.")
        closure = torch.enable_grad()(closure)
        loss = closure()
        loss.backward()
        self.first_step(zero_grad=True)
        loss_perturbed = closure()
        loss_perturbed.backward()
        self.second_step(zero_grad=True)
        return loss.item()

# =====================================
# 2. Simple MLP Model Definition (MNIST)
# =====================================

class SimpleMLP(nn.Module):
    """Fully connected ReLU classifier: 784 -> 50 -> 50 -> 50 -> 50 -> 10.

    Every ``nn.Linear`` weight and bias is initialised to exactly zero.
    NOTE(review): a zero start makes all hidden units identical; presumably
    this is a deliberate hard starting point for the noisy optimizers - confirm.
    """

    def __init__(self):
        super().__init__()
        layer_widths = [28*28, 50, 50, 50, 50]
        self.hidden_layers = nn.ModuleList([
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(layer_widths[:-1], layer_widths[1:])
        ])
        self.output_layer = nn.Linear(50, 10)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Zero the weight (and bias, when present) of every Linear sub-module."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.zeros_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Flatten the input to rows of 784 values and return the 10 class logits."""
        activations = x.view(-1, 28*28)
        for hidden in self.hidden_layers:
            activations = F.relu(hidden(activations))
        return self.output_layer(activations)

# =====================================
# 3. Data Loading (MNIST)
# =====================================

# Convert images to tensors, then normalise with the constants (0.1307, 0.3081).
_mnist_normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([transforms.ToTensor(), _mnist_normalize])

# Both splits share the download location and the preprocessing pipeline.
_mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

# Training batches are reshuffled every epoch; evaluation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================================
# 4. Train & Test Utilities
# =====================================

def train_sgd(model, loader, optimizer, epoch, losses):
    """Train ``model`` for one pass over ``loader`` with a closure-free optimizer.

    Appends the mean per-batch cross-entropy loss of this pass to ``losses``.
    ``epoch`` is accepted for signature parity with the other train helpers
    and is not used.
    """
    model.train()
    batch_losses = []
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        batch_loss = F.cross_entropy(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    losses.append(sum(batch_losses, 0.0)/len(batch_losses))

def train_noise(model, loader, optimizer, epoch, losses, noise_list):
    """Train for one pass with an optimizer whose ``step`` takes a closure.

    The closure zeroes the gradients, runs forward and backward, and returns
    the loss; ``optimizer.step`` is expected to return the loss as a number.
    Appends the mean per-batch loss to ``losses`` and the optimizer's
    ``current_sigma`` after the pass to ``noise_list``.  ``epoch`` is unused.
    """
    model.train()
    batch_losses = []
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        def closure():
            optimizer.zero_grad()
            batch_loss = F.cross_entropy(model(inputs), labels)
            batch_loss.backward()
            return batch_loss

        batch_losses.append(optimizer.step(closure=closure))
    losses.append(sum(batch_losses, 0.0)/len(batch_losses))
    noise_list.append(optimizer.current_sigma)

def train_adaptive(model, loader, optimizer, epoch, losses, alpha_list, sharpness_list, noise_list):
    """Train for one pass with an optimizer that runs ``backward`` itself.

    The closure only zeroes gradients and returns the loss.  After the pass
    this appends: the mean per-batch loss to ``losses``; the optimizer's most
    recent sharpness (0.0 when none was recorded) to ``sharpness_list``; its
    most recent alpha (falling back to ``optimizer.alpha``) to ``alpha_list``;
    and ``optimizer.current_sigma`` to ``noise_list``.  ``epoch`` is unused.
    """
    model.train()
    batch_losses = []
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        def closure():
            optimizer.zero_grad()
            return F.cross_entropy(model(inputs), labels)

        batch_losses.append(optimizer.step(closure=closure))
    losses.append(sum(batch_losses, 0.0)/len(batch_losses))

    recorded_sharpness = optimizer.sharpness_list
    sharpness_list.append(recorded_sharpness[-1] if recorded_sharpness else 0.0)
    recorded_alpha = optimizer.alpha_list
    alpha_list.append(recorded_alpha[-1] if recorded_alpha else optimizer.alpha)
    noise_list.append(optimizer.current_sigma)

def evaluate_model(model, loader):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Switches the model to eval mode (it is not switched back here).

    Returns:
        ``(mean_loss, accuracy_percent)``, both normalised by
        ``len(loader.dataset)``; the loss is the summed cross-entropy.
    """
    model.eval()
    summed_loss = 0.0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            summed_loss += F.cross_entropy(logits, labels, reduction='sum').item()
            predictions = logits.argmax(dim=1)
            n_correct += predictions.eq(labels.view_as(predictions)).sum().item()
    n_samples = len(loader.dataset)
    return summed_loss/n_samples, 100.0*n_correct/n_samples

def train_sam(model, loader, optimizer, epoch, losses, sam_rho_list):
    """Train for one pass with a SAM-style optimizer.

    The closure only zeroes gradients and returns the loss; the optimizer is
    expected to call ``backward`` itself.  ``optimizer.rho`` is appended to
    ``sam_rho_list`` after every batch, and the mean per-batch loss is appended
    to ``losses`` after the pass.  ``epoch`` is unused.
    """
    model.train()
    batch_losses = []
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        def closure():
            optimizer.zero_grad()
            return F.cross_entropy(model(inputs), labels)

        batch_losses.append(optimizer.step(closure=closure))
        sam_rho_list.append(optimizer.rho)
    losses.append(sum(batch_losses, 0.0)/len(batch_losses))

# =====================================
# 5. Main Training Loop & Metric Collection
# =====================================

# Every run starts from the same weights: build one reference model and keep
# a deep copy of its state dict.
model_ref = SimpleMLP().to(device)
initial_state = copy.deepcopy(model_ref.state_dict())


def _fresh_model():
    """Return a new SimpleMLP on `device` loaded with the shared initial weights."""
    net = SimpleMLP().to(device)
    net.load_state_dict(initial_state)
    return net


def _annealising_optimizer(model, k):
    """Build an AnnealisingOptimiser for `model`; only the alpha-schedule rate `k` varies."""
    return AnnealisingOptimiser(
        model.parameters(),
        lr=LEARNING_RATE,
        alpha=1.7,
        rho=0.05,
        decay=0.95,
        noise_init=NOISE_INIT,
        noise_decay=NOISE_DECAY,
        k=k,
        weight_decay=0
    )


# ----- Instantiate Models & Optimizers -----
# NOTE: every "k = ..." run below uses AnnealisingOptimiser; no HT-Hutch,
# HT-Power or SAM optimizer is instantiated in this section.

# Annealising optimizer, k = 1/1000
model_k1000 = _fresh_model()
optimizer_k1000 = _annealising_optimizer(model_k1000, k=1/1000)

# Plain SGD baseline
model_sgd = _fresh_model()
optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=LEARNING_RATE, momentum=0.0, weight_decay=0)

# Gaussian-noise baseline
model_gaussian = _fresh_model()
optimizer_gaussian = ImprovedGaussianOptimizer(
    model_gaussian.parameters(),
    lr=LEARNING_RATE,
    noise_init=NOISE_INIT,
    noise_decay=NOISE_DECAY,
    weight_decay=0
)

# Annealising optimizer, k = 1/5
model_k5 = _fresh_model()
optimizer_k5 = _annealising_optimizer(model_k5, k=1/5)

# Annealising optimizer, k = 1/10
model_k10 = _fresh_model()
optimizer_k10 = _annealising_optimizer(model_k10, k=1/10)

# Annealising optimizer, k = 1/50
model_k50 = _fresh_model()
optimizer_k50 = _annealising_optimizer(model_k50, k=1/50)

# Sharpness-adaptive annealising optimizer
model_adaptive = _fresh_model()
optimizer_adaptive = AdaptiveAnnealisingOptimizer(
    model_adaptive.parameters(),
    lr=LEARNING_RATE,
    momentum=0.0,
    alpha_min=1.0,
    alpha_max=2.0,
    window_size=50,
    noise_init=NOISE_INIT,
    noise_decay=NOISE_DECAY,
    weight_decay=0,
    num_hutchinson_samples=1
)

# Annealising optimizer, k = 1/100
model_k100 = _fresh_model()
optimizer_k100 = _annealising_optimizer(model_k100, k=1/100)

# ----- Initialize Metric Lists -----
# One list per optimizer and metric; each gains one entry per epoch
# (train losses are appended inside the train_* helpers).
sgd_train_losses, sgd_test_losses, sgd_test_errors, sgd_train_acc = [], [], [], []
gauss_train_losses, gauss_test_losses, gauss_test_errors, gauss_train_acc, gauss_noise = [], [], [], [], []
k1000_train_losses, k1000_test_losses, k1000_test_errors, k1000_train_acc = [], [], [], []
k5_train_losses, k5_test_losses, k5_test_errors, k5_train_acc, k5_noise = [], [], [], [], []
k10_train_losses, k10_test_losses, k10_test_errors, k10_train_acc = [], [], [], []
k50_train_losses, k50_test_losses, k50_test_errors, k50_train_acc = [], [], [], []
adaptive_train_losses, adaptive_test_losses, adaptive_test_errors, adaptive_train_acc = [], [], [], []
adaptive_alpha, adaptive_sharpness, adaptive_noise = [], [], []
k100_train_losses, k100_test_losses, k100_test_errors, k100_train_acc, k100_noise = [], [], [], [], []
k100_rho_list = []


def _evaluate_and_log(tag, model, ep, train_acc, test_losses, test_errors):
    """Evaluate `model` on the train and test loaders, record the metrics and print one line.

    Appends the train accuracy (%), the test loss and the test error ratio
    (1 - accuracy / 100) to the given lists.
    """
    train_loss, train_accuracy = evaluate_model(model, train_loader)
    train_acc.append(train_accuracy)
    test_loss, test_accuracy = evaluate_model(model, test_loader)
    test_losses.append(test_loss)
    test_errors.append(1.0 - (test_accuracy/100.0))
    print(f"[{tag}] Epoch {ep} | Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.2f}% | Test Loss: {test_loss:.4f} | Test Acc: {test_accuracy:.2f}%")


# ----- Main Training Loop -----
for ep in range(1, EPOCHS + 1):
    # SGD
    train_sgd(model_sgd, train_loader, optimizer_sgd, ep, sgd_train_losses)
    _evaluate_and_log("SGD", model_sgd, ep, sgd_train_acc, sgd_test_losses, sgd_test_errors)

    # Gaussian
    train_noise(model_gaussian, train_loader, optimizer_gaussian, ep, gauss_train_losses, gauss_noise)
    _evaluate_and_log("Gaussian", model_gaussian, ep, gauss_train_acc, gauss_test_losses, gauss_test_errors)

    # Annealising, k = 1/1000 (noise scale not recorded)
    train_noise(model_k1000, train_loader, optimizer_k1000, ep, k1000_train_losses, [])
    _evaluate_and_log("k=1/1000", model_k1000, ep, k1000_train_acc, k1000_test_losses, k1000_test_errors)

    # Annealising, k = 1/5
    train_noise(model_k5, train_loader, optimizer_k5, ep, k5_train_losses, k5_noise)
    _evaluate_and_log("k=1/5", model_k5, ep, k5_train_acc, k5_test_losses, k5_test_errors)

    # Annealising, k = 1/10 (noise scale not recorded)
    train_noise(model_k10, train_loader, optimizer_k10, ep, k10_train_losses, [])
    _evaluate_and_log("k=1/10", model_k10, ep, k10_train_acc, k10_test_losses, k10_test_errors)

    # Annealising, k = 1/50 (noise scale not recorded)
    train_noise(model_k50, train_loader, optimizer_k50, ep, k50_train_losses, [])
    _evaluate_and_log("k=1/50", model_k50, ep, k50_train_acc, k50_test_losses, k50_test_errors)

    # Adaptive
    train_adaptive(model_adaptive, train_loader, optimizer_adaptive, ep, adaptive_train_losses, adaptive_alpha, adaptive_sharpness, adaptive_noise)
    _evaluate_and_log("Adaptive", model_adaptive, ep, adaptive_train_acc, adaptive_test_losses, adaptive_test_errors)

    # Annealising, k = 1/100
    train_noise(model_k100, train_loader, optimizer_k100, ep, k100_train_losses, k100_noise)
    _evaluate_and_log("k=1/100", model_k100, ep, k100_train_acc, k100_test_losses, k100_test_errors)
    # AnnealisingOptimiser keeps rho in its param-group defaults, so read it
    # from there rather than from an attribute.
    k100_rho_list.append(optimizer_k100.param_groups[0]['rho'])

    print("--------------------------------------------------")

epochs_list = list(range(1, EPOCHS + 1))

# =====================================
# 6. Overlaid Plots for All Optimizers (MNIST)
# =====================================

# Fixed colour per optimizer so curves are comparable across figures.
_CURVE_COLORS = {
    "SGD": 'blue',
    "Gaussian": 'red',
    "k = 1/1000": 'black',
    "k = 1/5": 'green',
    "k = 1/10": 'purple',
    "k = 1/50": 'orange',
    "Adaptive": 'brown',
    "k = 1/100": 'magenta',
}


def _series(sgd, gaussian, k1000, k5, k10, k50, adaptive, k100):
    """Map each legend label to its per-epoch values, in the standard legend order."""
    return {
        "SGD": sgd,
        "Gaussian": gaussian,
        "k = 1/1000": k1000,
        "k = 1/5": k5,
        "k = 1/10": k10,
        "k = 1/50": k50,
        "Adaptive": adaptive,
        "k = 1/100": k100,
    }


def _plot_all_optimizers(series, ylabel, title, filename, log_y=False):
    """Draw one curve per entry of `series` against the epoch index, save as `filename` and show.

    `series` maps legend label -> per-epoch values and is drawn in dict order.
    """
    plt.figure(figsize=(8,6))
    for label, values in series.items():
        plt.plot(epochs_list, values, label=label, linewidth=2, color=_CURVE_COLORS[label])
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    if log_y:
        plt.yscale("log")
    plt.legend()
    plt.tight_layout()
    plt.savefig(filename, bbox_inches="tight")
    plt.show()


_plot_all_optimizers(
    _series(sgd_train_losses, gauss_train_losses, k1000_train_losses, k5_train_losses,
            k10_train_losses, k50_train_losses, adaptive_train_losses, k100_train_losses),
    "Training Loss (Cross-Entropy)",
    "Training Loss - All Optimizers (MNIST)",
    "mnist_all_train_loss.pdf",
)

_plot_all_optimizers(
    _series(sgd_test_losses, gauss_test_losses, k1000_test_losses, k5_test_losses,
            k10_test_losses, k50_test_losses, adaptive_test_losses, k100_test_losses),
    "Test Loss",
    "Test Loss - All Optimizers (MNIST)",
    "mnist_all_test_loss.pdf",
)

# Test accuracy in percent, recovered from the stored error ratios.
sgd_test_acc = [(1 - err)*100 for err in sgd_test_errors]
gauss_test_acc = [(1 - err)*100 for err in gauss_test_errors]
k1000_test_acc = [(1 - err)*100 for err in k1000_test_errors]
k5_test_acc = [(1 - err)*100 for err in k5_test_errors]
k10_test_acc = [(1 - err)*100 for err in k10_test_errors]
k50_test_acc = [(1 - err)*100 for err in k50_test_errors]
adaptive_test_acc = [(1 - err)*100 for err in adaptive_test_errors]
k100_test_acc = [(1 - err)*100 for err in k100_test_errors]

_plot_all_optimizers(
    _series(sgd_test_acc, gauss_test_acc, k1000_test_acc, k5_test_acc,
            k10_test_acc, k50_test_acc, adaptive_test_acc, k100_test_acc),
    "Test Accuracy (%)",
    "Test Accuracy - All Optimizers (MNIST)",
    "mnist_all_test_accuracy.pdf",
)

# Includes the k = 1/1000 run, like every other "All Optimizers" figure.
_plot_all_optimizers(
    _series(sgd_train_acc, gauss_train_acc, k1000_train_acc, k5_train_acc,
            k10_train_acc, k50_train_acc, adaptive_train_acc, k100_train_acc),
    "Training Accuracy (%)",
    "Training Accuracy - All Optimizers (MNIST)",
    "mnist_all_train_accuracy.pdf",
)

# This figure lists k = 1/1000 directly after SGD in the legend.
_plot_all_optimizers(
    {
        "SGD": sgd_test_errors,
        "k = 1/1000": k1000_test_errors,
        "Gaussian": gauss_test_errors,
        "k = 1/5": k5_test_errors,
        "k = 1/10": k10_test_errors,
        "k = 1/50": k50_test_errors,
        "Adaptive": adaptive_test_errors,
        "k = 1/100": k100_test_errors,
    },
    "Test Error (Misclassification Ratio)",
    "Test Error - All Optimizers (MNIST) [Log Scale]",
    "mnist_all_test_error.pdf",
    log_y=True,
)

print("Done!\nPlots saved: mnist_all_train_loss.pdf, mnist_all_test_loss.pdf, mnist_all_test_accuracy.pdf, mnist_all_train_accuracy.pdf, mnist_all_test_error.pdf")


In [None]:
import matplotlib.pyplot as plt

def sharpness_alpha(optimizer):
    """Plot an optimizer's recorded sharpness and alpha histories on twin y-axes.

    Reads ``optimizer.sharpness_list`` (left axis, solid blue) and
    ``optimizer.alpha_list`` (right axis, dashed red), both indexed by their
    position in the list, saves the figure as "YES_sharpness_vs_alpha.pdf"
    and shows it.
    """
    sharpness_trace = optimizer.sharpness_list
    alpha_trace = optimizer.alpha_list
    steps = list(range(len(sharpness_trace)))

    # Font sizes shared by the figure elements.
    fonts = {"label": 16, "title": 18, "legend": 14, "tick": 14}

    fig, sharp_ax = plt.subplots(figsize=(10, 6))

    # Left axis: sharpness.
    sharp_ax.plot(steps, sharpness_trace, color='blue', label='Sharpness', linewidth=2)
    sharp_ax.set_xlabel('Iteration', fontsize=fonts["label"])
    sharp_ax.set_ylabel('Sharpness', fontsize=fonts["label"], color='blue')
    sharp_ax.tick_params(axis='y', labelcolor='blue', labelsize=fonts["tick"])
    sharp_ax.tick_params(axis='x', labelsize=fonts["tick"])

    # Right axis: alpha, sharing the x-axis with the sharpness curve.
    alpha_ax = sharp_ax.twinx()
    alpha_ax.plot(steps, alpha_trace, color='red', linestyle='--', label='Alpha', linewidth=2)
    alpha_ax.set_ylabel('Alpha', fontsize=fonts["label"], color='red')
    alpha_ax.tick_params(axis='y', labelcolor='red', labelsize=fonts["tick"])

    plt.title('Sharpness vs Alpha over Iterations', fontsize=fonts["title"])

    # A single legend that covers the curves of both axes.
    handles, names = [], []
    for axis in (sharp_ax, alpha_ax):
        axis_handles, axis_names = axis.get_legend_handles_labels()
        handles = handles + axis_handles
        names = names + axis_names
    sharp_ax.legend(handles, names, fontsize=fonts["legend"], loc='upper right')

    plt.tight_layout()
    plt.savefig("YES_sharpness_vs_alpha.pdf", bbox_inches="tight")
    plt.show()


In [None]:
# Plot the sharpness and alpha histories recorded by the adaptive optimizer.
sharpness_alpha(optimizer_adaptive)