In [1]:
from sklearn.datasets import load_iris
import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Dataset
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../')))
from sklearn.decomposition import PCA
from optimizers.AdaFisher import AdaFisher
from optimizers.kfac import KFACOptimizer
from asdl.precondition import PreconditioningConfig, ShampooGradientMaker
from optimizers.AdaHessian import Adahessian
from optimizers.Adam import Adam
from optimizers.sgd import SGD
from sklearn.preprocessing import StandardScaler
from matplotlib import pyplot as plt
import numpy as np
from scipy.interpolate import griddata
from matplotlib.ticker import FormatStrFormatter
import numpy as np
from tqdm import tqdm

# Load the dataset, reduce it to two dimensions with PCA, and standardize it

In [None]:
# Iris feature matrix and integer class labels.
iris = load_iris()
data, target = iris['data'], iris['target']

# Project onto two principal components so every sample is 2-D (the MLP's
# first layer takes exactly two inputs), then standardize each component.
pca = PCA(n_components=2)
data_red = pca.fit_transform(data)
scaler = StandardScaler()
data_scaled = scaler.fit_transform(data_red)

# PyTorch inputs: float32 features and int64 (long) class indices, as
# expected by nn.CrossEntropyLoss targets.
data_tensor, labels_tensor = (
    torch.tensor(data_scaled, dtype=torch.float32),
    torch.tensor(target, dtype=torch.long),
)

# Create the Iris dataset and data loader

In [None]:
class IrisDataset(Dataset):
    """Map-style dataset that pairs each feature row with its label.

    Args:
        data: indexable collection of feature rows (a float tensor in this notebook).
        labels: indexable collection of labels, one per row; its length defines
            the dataset size.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        n_samples = len(self.labels)
        return n_samples

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# The whole (PCA-reduced, standardized) Iris set, with no held-out split,
# served in shuffled mini-batches of 10.
train_dataset = IrisDataset(data_tensor, labels_tensor)
train_loader = DataLoader(dataset=train_dataset, batch_size=10, shuffle=True)

# Create a toy MLP classifier

In [None]:
class MLP(nn.Module):
    """Tiny 2 -> 1 -> ``target_size`` network used for the trajectory plots.

    The hidden layer has a single unit, so ``fc1.weight`` holds exactly two
    numbers (W1, W2). These are the coordinates recorded during training and
    plotted on the loss landscape. All weights start at 0.5 and all biases
    at 0.

    NOTE(review): ``forward`` returns softmax probabilities, but the training
    cell feeds them to ``nn.CrossEntropyLoss``, which expects raw logits and
    applies log-softmax itself. Softmax is therefore effectively applied
    twice. It is left unchanged so the experiment's results are not altered.

    Args:
        target_size: number of output classes.
    """

    def __init__(self, target_size: int):
        super().__init__()
        # Registration order (fc1, fc2, softmax, relu) fixes the parameter
        # order that optimizers and state_dict see, so keep it.
        self.fc1 = nn.Linear(2, 1)
        self.fc2 = nn.Linear(1, target_size)
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()
        # Deterministic start: constant weights, zero biases.
        with torch.no_grad():
            for layer in (self.fc1, self.fc2):
                layer.weight.fill_(0.5)
                layer.bias.zero_()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return self.softmax(logits)

# Training function that records the mean loss and the first-layer weights after each epoch

In [None]:
# One separate MLP per optimizer, so each optimizer updates its own copy.
# The models are created in the order AdaHessian, Adam, AdaFisher, Shampoo, KFAC.
model_AdaHesian, model_Adam, model_AdaFisher, model_Shampoo, model_kfac = (
    MLP(target_size=3) for _ in range(5)
)

In [None]:
def train_model(model, train_loader, optimizer, loss_fn, epochs, gm=None, init_weight=(0.2, 0.2)):
    """Train ``model``, recording its first-layer weights and mean loss per epoch.

    Before training, ``model.fc1.weight`` is overwritten in place with
    ``init_weight``, so every run starts from the same (W1, W2) point.

    Args:
        model: network with an ``fc1`` linear layer whose weight has as many
            entries as ``init_weight`` (``MLP`` uses ``nn.Linear(2, 1)``).
        train_loader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        optimizer: optimizer whose ``step()`` updates ``model``'s parameters. An
            optimizer whose class is named ``"Adahessian"`` gets a
            ``create_graph=True`` backward pass.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        epochs: number of passes over ``train_loader``.
        gm: optional ASDL-style gradient maker (e.g. ``ShampooGradientMaker``).
            When given, it runs the forward/backward pass, and gradients are
            clipped to total norm 1 before ``optimizer.step()``.
        init_weight: starting values for ``model.fc1.weight`` (row-major).

    Returns:
        ``(weight_history, loss_history)``: one flattened copy of ``fc1.weight``
        per epoch, and one mean per-batch loss per epoch (NaN if the loader is empty).
    """
    # Write the start point in place, so any optimizer or hook holding a
    # reference to this parameter still sees the same tensor.
    with torch.no_grad():
        start = torch.tensor(init_weight, dtype=model.fc1.weight.dtype)
        model.fc1.weight.copy_(start.reshape_as(model.fc1.weight))

    weight_history = []
    loss_history = []
    for _ in range(epochs):
        total_loss = 0.0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            if gm is not None:
                # The gradient maker performs the forward and backward pass
                # itself, so no separate forward pass is needed here.
                dummy_y = gm.setup_model_call(model, inputs)
                gm.setup_loss_call(loss_fn, dummy_y, targets)
                _outputs, loss = gm.forward_and_backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            else:
                loss = loss_fn(model(inputs), targets)
                # AdaHessian differentiates through the gradients
                # (Hessian-vector products), which requires create_graph=True.
                loss.backward(create_graph=type(optimizer).__name__ == "Adahessian")
            optimizer.step()
            total_loss += loss.item()
        # ndarray.flatten() returns a copy, so later updates don't alter history.
        weight_history.append(model.fc1.weight.detach().cpu().numpy().flatten())
        n_batches = len(train_loader)
        loss_history.append(total_loss / n_batches if n_batches else float("nan"))
    return weight_history, loss_history

# Train the model with each optimizer (KFAC, Adam, AdaFisher, AdaHessian and Shampoo)

In [None]:
# ASDL preconditioning settings, read only by the Shampoo gradient maker below.
# NOTE(review): data_size presumably should match the DataLoader batch size (10).
config = PreconditioningConfig(data_size=10,
                               damping=1e-12,
                               preconditioner_upd_interval=100,
                               curvature_upd_interval=100,
                               ema_decay=-1,
                               ignore_modules=[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm])

optimizers = {
    'KFAC': KFACOptimizer(model_kfac, lr=0.001),
    'Adam': Adam(model_Adam.parameters(), lr=0.001),
    'AdaFisher': AdaFisher(model_AdaFisher, lr=0.001),
    'AdaHessian': Adahessian(model_AdaHesian.parameters(), lr=0.01),
    'Shampoo': SGD(model_Shampoo.parameters(), lr=0.001, momentum=0.9)
}
# The model that each optimizer above was built on.
models = {
    'KFAC': model_kfac,
    'Adam': model_Adam,
    'AdaFisher': model_AdaFisher,
    'AdaHessian': model_AdaHesian,
    'Shampoo': model_Shampoo,
}

loss = nn.CrossEntropyLoss()
device = 'cpu'  # not used below; all tensors stay on the CPU
EPOCHS = 20
SEED = 0  # re-applied before each run so results are reproducible
optimizer_results = {}
for name, opt in tqdm(optimizers.items()):
    run_model = models[name]
    # Re-initialise the Linear layers from the same RNG state for every
    # optimizer, so all runs start from identical fc1.bias / fc2 values
    # (train_model overwrites fc1.weight with a fixed start point).
    torch.manual_seed(SEED)
    run_model.apply(lambda m: m.reset_parameters() if hasattr(m, 'reset_parameters') else None)
    # Only Shampoo routes its forward/backward through ASDL's gradient maker.
    gm = ShampooGradientMaker(run_model, config) if name == "Shampoo" else None
    optimizer_results[name] = train_model(run_model, train_loader, opt, loss, EPOCHS, gm)


# Plot the results

In [None]:

def _draw_trajectory(ax, weights, color, label, width):
    """Draw arrows on ``ax`` from each recorded (W1, W2) point to the next one."""
    steps = np.diff(weights, axis=0)
    ax.quiver(weights[:-1, 0], weights[:-1, 1], steps[:, 0], steps[:, 1],
              scale_units='xy', angles='xy', scale=1,
              color=color, label=label, width=width, alpha=1)


def loss_landscape_visualization(zoom_window: bool = False, results=None,
                                 zoom_names=('KFAC', 'Adam')):
    """Plot each optimizer's fc1-weight trajectory over an interpolated loss surface.

    The background is not an evaluation of the true loss surface. It is a
    cubic ``scipy.interpolate.griddata`` interpolation of the per-epoch mean
    losses at the (W1, W2) points the runs actually visited, pooled over all
    optimizers. Grid points outside the convex hull of those samples are NaN
    and stay blank.

    Args:
        zoom_window: if True, add an inset that zooms in on the trajectories
            named in ``zoom_names``.
        results: mapping ``name -> (weight_history, loss_history)`` as returned
            by ``train_model``. Defaults to the notebook-level ``optimizer_results``.
        zoom_names: optimizers drawn in the inset; names missing from
            ``results`` are skipped.

    Raises:
        ValueError: if ``results`` is empty.
    """
    if results is None:
        results = optimizer_results
    if not results:
        raise ValueError("no optimizer results to plot")

    weights_flat = np.vstack([np.vstack(weights) for weights, _ in results.values()])
    loss_flat = np.concatenate([np.asarray(losses, dtype=float) for _, losses in results.values()])

    x_min, y_min = weights_flat.min(axis=0)
    x_max, y_max = weights_flat.max(axis=0)
    grid_x, grid_y = np.mgrid[x_min:x_max:200j, y_min:y_max:200j]
    grid_z = griddata(weights_flat, loss_flat, (grid_x, grid_y), method='cubic')

    fig, ax = plt.subplots(figsize=(12, 10))
    contour = ax.contourf(grid_x, grid_y, grid_z, levels=1000, cmap='jet', alpha=0.3)
    cbar = fig.colorbar(contour, ax=ax, pad=0.01)
    cbar.formatter = FormatStrFormatter('%.2f')  # two-decimal tick labels
    cbar.update_ticks()
    cbar.ax.set_ylabel('Loss Value', fontsize=24)
    cbar.ax.tick_params(labelsize='large')

    colors = {'AdaFisher': 'black', 'KFAC': 'purple', 'Adam': 'cyan', 'AdaHessian': 'green', "Shampoo": "blue"}
    for name, (weights, _) in results.items():
        _draw_trajectory(ax, np.vstack(weights), colors.get(name, 'gray'), name, width=0.004)

    ax.set_title('Weight Trajectories on Loss Landscape', fontsize=24, fontweight='bold')
    ax.set_xlabel('$W_1$', fontsize=30, fontweight='bold')
    ax.set_ylabel('$W_2$', fontsize=30, fontweight='bold')
    ax.legend(fontsize='xx-large', loc='upper right')
    ax.tick_params(labelsize='large')
    ax.grid(True, linestyle='--', alpha=0.5)

    if zoom_window:
        inset_names = [name for name in zoom_names if name in results]
        if inset_names:
            # Inset placed in figure coordinates (left, bottom, width, height).
            axins = fig.add_axes([0.1, 0.15, 0.25, 0.25], facecolor='w')
            axins.contourf(grid_x, grid_y, grid_z, levels=1000, cmap='jet', alpha=0.3)
            inset_weights = []
            for name in inset_names:
                weights = np.vstack(results[name][0])
                _draw_trajectory(axins, weights, colors.get(name, 'gray'), name, width=0.01)
                inset_weights.append(weights)
            # Fit the limits to all zoomed trajectories together, so none is
            # clipped by the limits of the last one drawn.
            stacked = np.vstack(inset_weights)
            axins.set_xlim(stacked[:, 0].min(), stacked[:, 0].max() + 0.3)
            axins.set_ylim(stacked[:, 1].min(), stacked[:, 1].max())
            axins.set_xticks([])
            axins.set_yticks([])

    fig.tight_layout()
    plt.show()

In [None]:
# Draw the full landscape without the zoomed-in inset.
loss_landscape_visualization(zoom_window=False)