In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as f

from datasets import load_dataset 
from torch.utils.data import DataLoader
from tnn import Trainer, Model, Landscape

In [2]:
dataset = load_dataset("ylecun/mnist", num_proc=2)
train = dataset.get("train")
test = dataset.get("test")

In [3]:
def to_numpy(example):
    arr = np.reshape(example["image"], -1) / 255.0
    example["input"] = arr
    return example

train_dataset = train.map(to_numpy, num_proc=2).select_columns(["input", "label"])
test_dataset = test.map(to_numpy, num_proc=2).select_columns(["input", "label"])

In [4]:
def collate_fn(batch):
    inputs = torch.tensor([ex["input"] for ex in batch]).float()
    labels = torch.tensor([ex["label"] for ex in batch]).long()
    return inputs, labels

trainloader = data.DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=False, collate_fn=collate_fn, num_workers=2)
testloader = data.DataLoader(test_dataset, batch_size=1024, shuffle=False, drop_last=False, collate_fn=collate_fn, num_workers=2)

In [5]:
class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(784, 512)
        self.drop_1 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(512, 512)
        self.drop_2 = nn.Dropout(0.25)
        self.linear_3 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.drop_1(f.relu(self.linear_1(x)))
        x = self.drop_2(f.relu(self.linear_2(x)))
        return dict(logits=f.relu(self.linear_3(x)))

In [6]:
model = Model(MLP())
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), momentum=0.9)

In [None]:
trainer = Trainer(model, optim, loss_fn, trainloader, testloader, path="./train.h5", verbose=True)

In [None]:
metrics = trainer.train(epochs=3)

In [7]:
model = Model(MLP())
landscape = Landscape.from_file("./train.h5", model, loss_fn, testloader, device="cuda", path="./train.h5", verbose=25)

In [None]:
data = landscape.create_meshgrid(resolution=25, endpoints=(-10.0, 10.0), mode="pca")

meshgrid creation using pca
model using cuda
meshgrid creation started
(iter: 25): iter loss: 4.2272
