In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.utils.data as data
import torch.optim as optim
import numpy as np

from datasets import load_dataset
from src.landscape import Landscape
from src.trainer import Trainer
from src.plot import Plot

In [None]:
def pre_process(example):
    """Flatten the ``"input"`` entry of a single example to one dimension.

    The example mapping is modified in place (its ``"input"`` value is
    replaced by the flattened array) and the same object is returned, which
    is the shape of callback that ``Dataset.map`` is given in this notebook.
    """
    example["input"] = np.reshape(example["input"], (-1,))
    return example

In [None]:
# Fetch MNIST and pull out its two splits.
# NOTE(review): trust_remote_code=True lets the dataset's loading script run
# locally — confirm it is still required for this dataset/library version.
mnist = load_dataset("mnist", trust_remote_code=True)
train = mnist.get("train")
test = mnist.get("test")

In [None]:
# Build the processed splits directly from `mnist` using only non-mutating
# calls (`with_format` returns a new dataset, unlike the in-place `set_format`).
# This keeps the cell idempotent: re-running it no longer fails because the
# "image" column was already renamed away on the objects bound to `train`/`test`.
train = (
    mnist["train"]
    .with_format(type="numpy", columns=["image", "label"])
    .rename_column("image", "input")
    # pre_process flattens each example's "input" to 1-D.
    .map(pre_process, num_proc=4)
)
test = (
    mnist["test"]
    .with_format(type="numpy", columns=["image", "label"])
    .rename_column("image", "input")
    .map(pre_process, num_proc=4)
)

In [None]:
# Prefer the GPU when one is visible to PyTorch.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Inputs: numpy -> float32 tensor, drop size-1 dims, then scale by 1/255
# (presumably mapping 0-255 pixel values into [0, 1] — confirm against the data).
train_inputs = torch.from_numpy(train["input"]).to(torch.float32).squeeze().div(255.0)
test_inputs = torch.from_numpy(test["input"]).to(torch.float32).squeeze().div(255.0)

# Labels: numpy -> int64 tensor, the dtype nn.CrossEntropyLoss takes for class indices.
train_labels = torch.from_numpy(train["label"]).to(torch.int64)
test_labels = torch.from_numpy(test["label"]).to(torch.int64)

In [None]:
# Pair each input tensor with its label tensor so a DataLoader can batch them.
train_dataset, test_dataset = (
    data.TensorDataset(train_inputs, train_labels),
    data.TensorDataset(test_inputs, test_labels),
)

In [None]:
class Model(nn.Module):
    """Fully connected classifier for flattened 28x28 inputs.

    Architecture: Linear(784 -> 512) -> Dropout(0.5) -> ReLU
                  -> Linear(512 -> 512) -> Dropout(0.25) -> ReLU
                  -> Linear(512 -> 10).

    ``forward`` returns raw (unnormalised) logits, which is what
    ``nn.CrossEntropyLoss`` expects as input.
    """

    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(28 * 28, 512)
        self.drop_1 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(512, 512)
        self.drop_2 = nn.Dropout(0.25)
        self.linear_3 = nn.Linear(512, 10)

    def forward(self, x):
        """Compute class logits for a batch of flattened inputs.

        Args:
            x: tensor whose last dimension has 28 * 28 features.

        Returns:
            Tensor with 10 logits along the last dimension.
        """
        x = f.relu(self.drop_1(self.linear_1(x)))
        x = f.relu(self.drop_2(self.linear_2(x)))
        # No activation on the output layer: a ReLU here would clamp every
        # negative logit to 0 and block its gradient, distorting the softmax
        # that CrossEntropyLoss applies internally.
        return self.linear_3(x)


In [None]:
# Training hyper-parameters and objective.
batch_size = 256
loss_fn = nn.CrossEntropyLoss()

# Network to be trained, optimised with plain SGD at a learning rate of 1e-2.
model = Model()
optimizer = optim.SGD(params=model.parameters(), lr=1e-2)

In [None]:
# Mini-batch loader for training; reshuffled every epoch.
# Page-locked host memory only pays off when batches are copied to a GPU, and
# requesting it without CUDA makes DataLoader emit a warning, so it is enabled
# only when CUDA is available.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    pin_memory=torch.cuda.is_available(),
)

# Evaluation loader: the whole test set as a single batch. Shuffling is
# disabled so every evaluation sees the samples in the same order.
test_loader = data.DataLoader(
    test_dataset,
    batch_size=len(test_dataset),
    shuffle=False,
    drop_last=False,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Wire the model, optimizer and loss into the project's Trainer.
# NOTE(review): write=True presumably makes it persist to the three paths
# below — confirm in src.trainer.
trainer = Trainer(
    model,
    optimizer,
    loss_fn,
    write=True,
    metric_path="./metrics.h5",
    param_path="./params.pt",
    traj_path="./traj.pt",
)

In [None]:
# Train for 50 epochs on `device`; the call returns a (metrics, trajectory) pair.
metrics, trajectory = trainer.train(
    train_loader,
    test_loader,
    epochs=50,
    device=device,
    print_every=10,
)

In [None]:
# Rebind `model` to a freshly constructed Model and hand it, with the loss
# function and the saved parameter/trajectory files, to Landscape.
model = Model()
landscape = Landscape.from_files(
    model,
    loss_fn,
    param_path="./params.pt",
    traj_path="./traj.pt",
)

In [None]:
# Evaluate the loss surface on the test loader in "pca" mode.
# Note: this rebinds `trajectory`, replacing the value returned by trainer.train.
(X, Y, Z), trajectory = landscape.create_landscape(
    test_loader,
    mode="pca",
    print_every=25,
)

In [None]:
# Build the plotting helper from the mesh file on disk.
# NOTE(review): no visible cell names "./landscape.h5" as an output; presumably
# it is create_landscape's default save path — confirm in src.landscape.
plot = Plot.from_files(
    mesh_path="./landscape.h5",
)

In [None]:
# 3-D surface view, passing "./loss-landscape.png" as the target file path.
plot.plot_surface_3D(
    file_path="./loss-landscape.png",
)

In [None]:
# Contour view with 40 levels and the trajectory overlay requested.
plot.plot_contour(
    levels=40,
    plot_trajectory=True,
    file_path="./loss-contour.png",
)

In [None]:
# Animated contour with 40 levels, passing "./animated-contour.gif" as the target file path.
plot.animate_contour(
    levels=40,
    file_path="./animated-contour.gif",
)