https://jtuckerk.github.io/loss_landscape.html <br>
https://github.com/tomgoldstein/loss-landscape

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

class SimpleCNN(nn.Module):
    """Small convolutional classifier for single-channel images.

    Two ``conv -> ReLU -> 2x2 max-pool`` stages (1 -> 16 -> 32 channels)
    feed a 128-unit hidden linear layer and a 10-way output layer.
    ``fc1`` takes ``32 * 7 * 7`` features, which is what a 28x28 input
    yields after the two pooling steps (the 3x3 convolutions use
    ``padding=1`` and therefore keep the spatial size).
    """

    def __init__(self):
        super().__init__()

        # Feature extractor, stage 1: 1 -> 16 channels, then halve H and W.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Feature extractor, stage 2: 16 -> 32 channels, then halve again.
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head.
        self.fc1 = nn.Linear(in_features=32 * 7 * 7, out_features=128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores, one row per sample."""
        features = self.conv1(x)
        features = self.relu1(features)
        features = self.pool1(features)

        features = self.conv2(features)
        features = self.relu2(features)
        features = self.pool2(features)

        # Collapse everything except the batch dimension for the linear head.
        flat = features.view(features.size(0), -1)
        hidden = self.relu3(self.fc1(flat))
        return self.fc2(hidden)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Model, loss and optimizer.
model = SimpleCNN()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def _mnist_loader(train, shuffle):
    """Return a DataLoader (batch size 100) over one split of MNIST.

    The split is downloaded into ``./data`` if it is not already there.
    Images are converted to tensors and normalised with mean 0.1307 and
    std 0.3081 (presumably the MNIST training-set statistics - confirm).
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    split = datasets.MNIST('./data', train=train, download=True,
                           transform=preprocessing)
    return torch.utils.data.DataLoader(split, batch_size=100, shuffle=shuffle)


# Training batches are reshuffled every epoch; the test split keeps its order.
train_loader = _mnist_loader(train=True, shuffle=True)
test_loader = _mnist_loader(train=False, shuffle=False)

# --- Training configuration ---------------------------------------------
num_epochs = 2
# NOTE(review): not referenced by the code in this cell; kept so that any
# other cell relying on the name keeps working.
checkpoint_interval = 1
# Upper bound on the number of weight snapshots recorded per epoch.
saves_per_epoch = 10

# --- Trajectory storage (one entry per snapshot) --------------------------
weights_matrix = []   # each entry: list of per-parameter numpy arrays (CPU)
trajectory_loss = []  # mini-batch training loss recorded with each snapshot

# Snapshot cadence derived from `saves_per_epoch`: ceil division guarantees
# no more than `saves_per_epoch` snapshots per epoch, and max(1, ...) keeps
# the modulo below valid for very short loaders.
batches_per_epoch = len(train_loader)
save_every = max(1, -(-batches_per_epoch // saves_per_epoch))
snapshots_per_epoch = len(range(0, batches_per_epoch, save_every))
total_snapshots = num_epochs * snapshots_per_epoch

for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % save_every == 0:
            # Detached CPU copies, so later optimizer steps cannot modify
            # the stored snapshot. `weights` stays a module-level name
            # because other cells may read it.
            weights = [w.detach().clone().cpu().numpy() for w in model.parameters()]
            weights_matrix.append(weights)
            # The recorded loss was computed before optimizer.step(), i.e.
            # it belongs to the weights just before this snapshot's update.
            trajectory_loss.append(loss.item())
            print(f"Saving {len(weights_matrix)}/{total_snapshots}")

In [None]:
import numpy as np

# Unroll every saved snapshot into one long vector (parameter arrays joined
# end to end, in the order they were stored) and pile the vectors up row by
# row.
flat_weights = [
    np.concatenate([param.ravel() for param in snapshot])
    for snapshot in weights_matrix
]
weights_matrix_np = np.stack(flat_weights, axis=0)  # (n_snapshots, n_parameters)

from sklearn.decomposition import PCA

# Project the snapshots onto their first two principal components.
pca = PCA(n_components=2)
reduced_weights = pca.fit_transform(weights_matrix_np)  # (n_snapshots, 2)

# Regular grid in PCA coordinates: grid_size ** 2 points covering
# [-10, 10] x [-10, 10].
grid_size = 10
grid_range = np.linspace(-10, 10, num=grid_size)
xx, yy = np.meshgrid(grid_range, grid_range)
all_points = np.stack([xx.ravel(), yy.ravel()], axis=1)  # (grid_size ** 2, 2)

print(all_points.shape)

# Lift each 2-D grid point back into the full parameter space.
mapped_points = pca.inverse_transform(all_points)  # (grid_size ** 2, n_parameters)


In [None]:
# Re-evaluate every saved snapshot on the full test set.
#
# Per-parameter shapes and element counts are read from the model itself, so
# this cell does not depend on the `weights` loop variable left behind by the
# training cell.
param_shapes = [tuple(p.shape) for p in model.parameters()]
split_points = np.cumsum([p.numel() for p in model.parameters()])[:-1]

# The loop below overwrites the model parameters in place; keep a copy so
# they can be put back afterwards (also when an error interrupts the loop).
saved_params = [p.detach().clone() for p in model.parameters()]

trajectory_loss_reevaluated_landscape = []

with torch.no_grad():
    try:
        for idx, grid_point in enumerate(weights_matrix_np):
            # Cut the flat vector back into one array per parameter tensor.
            grid_weights = [
                chunk.reshape(shape)
                for chunk, shape in zip(np.split(grid_point, split_points), param_shapes)
            ]
            for w, new_value in zip(model.parameters(), grid_weights):
                w.copy_(torch.tensor(new_value, device=device))

            # criterion returns a per-batch mean, so weight it by the batch
            # size to obtain the mean loss over the whole test set.
            total_loss = 0
            total_samples = 0
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                total_loss += loss.item() * len(data)
                total_samples += len(data)
            trajectory_loss_reevaluated_landscape.append(total_loss / total_samples)
            print(f"{idx+1} {trajectory_loss_reevaluated_landscape[-1]}")
    finally:
        # Put back the parameters the model had before this cell ran.
        for w, original in zip(model.parameters(), saved_params):
            w.copy_(original)

In [None]:
# Evaluate the test-set loss at every grid point of the PCA plane.
#
# Per-parameter shapes and element counts are read from the model itself, so
# this cell does not depend on the `weights` loop variable left behind by the
# training cell.
param_shapes = [tuple(p.shape) for p in model.parameters()]
split_points = np.cumsum([p.numel() for p in model.parameters()])[:-1]

# The loop below overwrites the model parameters in place; keep a copy so
# they can be put back afterwards (also when an error interrupts the loop).
saved_params = [p.detach().clone() for p in model.parameters()]

loss_landscape = []

with torch.no_grad():
    try:
        for idx, grid_point in enumerate(mapped_points):
            # Cut the flat vector back into one array per parameter tensor.
            grid_weights = [
                chunk.reshape(shape)
                for chunk, shape in zip(np.split(grid_point, split_points), param_shapes)
            ]
            for w, new_value in zip(model.parameters(), grid_weights):
                # copy_ converts to the parameter's dtype if the grid point
                # arrives in a different floating-point precision.
                w.copy_(torch.tensor(new_value, device=device))

            # criterion returns a per-batch mean, so weight it by the batch
            # size to obtain the mean loss over the whole test set.
            total_loss = 0
            total_samples = 0
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                total_loss += loss.item() * len(data)
                total_samples += len(data)
            loss_landscape.append(total_loss / total_samples)
            print(f"{idx+1} {loss_landscape[-1]}")
    finally:
        # Put back the parameters the model had before this cell ran.
        for w, original in zip(model.parameters(), saved_params):
            w.copy_(original)

# One loss value per grid point, arranged as a (grid_size, grid_size) array.
# Assumes mapped_points is ordered like the raveled meshgrid it was built from.
loss_landscape_grid = np.array(loss_landscape).reshape(grid_size, grid_size)

In [None]:
# Compare the loss logged during training with the loss obtained by
# re-evaluating the same snapshots afterwards.
trajectory_loss_np = np.array(trajectory_loss)
trajectory_loss_reevaluated_np = np.array(trajectory_loss_reevaluated_landscape)

# Signed gap per snapshot; both difference measures are derived from it.
signed_gap = trajectory_loss_np - trajectory_loss_reevaluated_np
absolute_diff = np.abs(signed_gap)
# Relative to the loss that was logged during training.
relative_diff = np.abs(signed_gap / trajectory_loss_np)

# Per-snapshot values first, then the averages.
print("Absolute differences:", absolute_diff)
print("Relative differences:", relative_diff)
print("Mean absolute difference:", absolute_diff.mean())
print("Mean relative difference:", relative_diff.mean())


In [None]:
import plotly.graph_objs as go

# Marker styling along the trajectory: the first snapshot is a large blue
# marker, the last one a large green marker, and everything in between is a
# smaller red marker.
n_points = len(reduced_weights)
colors = ["blue", *["red"] * (n_points - 2), "green"]
sizes = [8, *[5] * (n_points - 2), 8]

# Axis titles and the colour axis referenced by the surface trace.
layout = go.Layout(
    scene={
        "xaxis_title": 'PC1',
        "yaxis_title": 'PC2',
        "zaxis_title": ' Loss',
    },
    coloraxis={"colorbar": {"title": "Loss magnitude"}, "colorscale": "Viridis"},
)

# Loss evaluated on the regular grid in PCA coordinates.
surface = go.Surface(
    x=xx,
    y=yy,
    z=loss_landscape_grid,
    opacity=0.8,
    name="grid point",
    coloraxis="coloraxis",
    colorscale="Viridis",
)

# Saved snapshots, placed at their PCA coordinates and re-evaluated loss.
trajectory = go.Scatter3d(
    x=reduced_weights[:, 0],
    y=reduced_weights[:, 1],
    z=trajectory_loss_reevaluated_landscape,
    mode='markers+lines',
    marker={"color": colors, "size": sizes},
    line={"color": "red"},
    name="Training Trajectory",
)

# Assemble the interactive figure and display it.
fig = go.Figure(data=[surface, trajectory], layout=layout)
fig.show()


In [None]:
"""import plotly.io as pio

# Save the figure as an interactive HTML file
pio.write_html(fig, 'output_figure.html')"""