In [None]:
!pip3 install functorch

In [None]:
from collections import OrderedDict
import math

from functorch import jacrev

import torch
from torch import nn
from torch import autograd
from torch.nn.functional import cross_entropy
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import optim

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0, 1)])
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=trans)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=trans)

In [None]:
batch_size = 128
t_batch_size = 128

In [None]:
trainloader = DataLoader(
    dataset=mnist_trainset,
    batch_size=batch_size,
    shuffle=True
)
testloader = DataLoader(
    dataset=mnist_testset,
    batch_size=t_batch_size,
    shuffle=False
)

In [None]:
x_train = []
y_train = []
for data_point in mnist_trainset:
    x_train.append(torch.flatten(data_point[0][0]))
    y_train.append(data_point[1])
x_train = tuple(x_train)
x_train = torch.stack(x_train)
# x_train = torch.flatten(x_train, start_dim=1)
x_train = x_train.to(device)
y_train = torch.LongTensor(y_train).to(device)

In [None]:
x_test = []
y_test = []
for data_point in mnist_testset:
    x_test.append(data_point[0][0])
    y_test.append(data_point[1])
x_test = torch.stack(x_test)
x_test = torch.flatten(x_test, start_dim=1)
y_test = torch.IntTensor(y_test)

In [None]:
hidden_size = 500

class MLP(torch.nn.Module):

    def __init__(self):
        super(MLP, self).__init__()
        self.model_logits = nn.Sequential(
            OrderedDict([
                ('linear1', nn.Linear(28*28, hidden_size)),
                ('relu1', nn.ReLU()),
                ('linear2', nn.Linear(hidden_size, hidden_size)),
                ('relu2', nn.ReLU()),
                ('linear3', nn.Linear(hidden_size, hidden_size)),
                ('relu3', nn.ReLU()),
                ('linear4', nn.Linear(hidden_size, hidden_size)),
                ('relu4', nn.ReLU()),
                ('linear5', nn.Linear(hidden_size, hidden_size)),
                ('relu5', nn.ReLU()),
                ('linear6', nn.Linear(hidden_size, hidden_size)),
                ('relu6', nn.ReLU()),
                ('linear7', nn.Linear(hidden_size, hidden_size)),
                ('relu7', nn.ReLU()),
                ('linear8', nn.Linear(hidden_size, 10))
            ])
        )
        self.apply(self._init_weights)
        self.softmax = nn.Softmax(dim=1)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            input_size = module.in_features
            scale = 1/math.sqrt(input_size)
            torch.nn.init.trunc_normal_(module.weight, mean=0.0, std=scale, a=-2*scale, b=2*scale)
            if module.bias is not None:
                module.bias.data.zero_()


    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = self.model_logits(x)
        x = self.softmax(x)
        return x

In [None]:
_ = torch.manual_seed(0)
mlp = MLP()
mlp.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(mlp.parameters(), lr=.02, weight_decay=0)

x_test = x_test.to(device)
y_test = y_test.type(torch.LongTensor).to(device)

training_accs = []
validation_accs = []
training_losses = []
validation_losses = []

print("Starting training...")

for epoch in range(200):  # loop over the dataset multiple times

    training_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = mlp(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item() * batch_size
        training_loss += batch_loss
    
    print("Epoch", epoch+1, "finished.")

    training_loss = training_loss/60000
    training_losses.append(training_loss)
    print("Training loss is:", training_loss)

    y_train_pred = mlp(x_train)
    train_target = torch.argmax(y_train_pred, dim=1)
    train_acc = torch.sum(train_target==y_train).item()/len(train_target)
    training_accs.append(train_acc)
    print("Training accuracy is:", train_acc)
    
    y_pred = mlp(x_test).to(device)

    loss = criterion(y_pred, y_test)
    valid_loss = loss.item()
    validation_losses.append(valid_loss)
    print("Validation loss is:", valid_loss)


    target = torch.argmax(y_pred, dim=1)
    acc = torch.sum(target==y_test)
    validation_accs.append(acc)
    print("Validation accuracy is:", torch.sum(target==y_test).item()/(len(target)))

    l2_norm = 0
    for layer in mlp.model_logits:
        if isinstance(layer, nn.Linear):
            t = layer.weight
            l2_norm += torch.linalg.matrix_norm(t).item()
    print("L2 norm is", l2_norm, "\n")

    # print(f'Epoch {epoch+1} \n Training Loss: {running_loss} \n Validation Loss: {valid_loss / len(testloader)}')
    # if epoch%20 == 0:
    #     gcs = []
    #     i = 0
    #     while i < 10:
    #         jacobian_logits = jacrev(mlp.model_logits)(x_train[i])
    #         frob_square_norms = torch.square(torch.linalg.matrix_norm(jacobian_logits))
    #         gcs.append(frob_square_norms.item())
    #         i += 1
    #     gc = sum(gcs)/len(gcs)
    #     print("Approximate GC:", gc)
#     print(frob_square_norms)
#     print("EPOCH", epoch+1,"GEOMETRIC COMPLEXITY:")

print('Finished Training')

Starting training...
Epoch 1 finished.
Training loss is: 2.3037174743652344
Training accuracy is: 0.17813333333333334
Validation loss is: 2.302281141281128
Validation accuracy is: 0.179
L2 norm is 140.52698469161987 

Epoch 2 finished.
Training loss is: 2.303323304748535
Training accuracy is: 0.14108333333333334
Validation loss is: 2.3018643856048584
Validation accuracy is: 0.1444
L2 norm is 140.52956652641296 

Epoch 3 finished.
Training loss is: 2.3028902155558266
Training accuracy is: 0.13623333333333335
Validation loss is: 2.3013722896575928
Validation accuracy is: 0.1386
L2 norm is 140.53464150428772 

Epoch 4 finished.
Training loss is: 2.302337832132975
Training accuracy is: 0.16415
Validation loss is: 2.300708770751953
Validation accuracy is: 0.1664
L2 norm is 140.54297375679016 

Epoch 5 finished.
Training loss is: 2.3015418340047202
Training accuracy is: 0.23051666666666668
Validation loss is: 2.299710273742676
Validation accuracy is: 0.2343
L2 norm is 140.55613565444946 

Ep

In [None]:
from google.colab import drive
import os

drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
os.chdir('/content/drive/My Drive/')

In [None]:
os.getcwd()

'/content/drive/MyDrive'

In [None]:
import pickle

training_stats = [
    training_accs,
    validation_accs,
    training_losses,
    validation_losses
]

output_file = open('6HL_500W_LR02_MNIST_MODEL_DATA.bin', 'wb')
pickle.dump(training_stats, output_file)
output_file.close()

In [None]:
d = pickle.load(open('6HL_500W_LR02_MNIST_MODEL_DATA.bin', 'rb'))
d

[[0.17813333333333334,
  0.14108333333333334,
  0.13623333333333335,
  0.16415,
  0.23051666666666668,
  0.29636666666666667,
  0.31211666666666665,
  0.2058,
  0.3596666666666667,
  0.6227,
  0.7253,
  0.74265,
  0.7515,
  0.75555,
  0.7624833333333333,
  0.7634333333333333,
  0.7655,
  0.7710166666666667,
  0.7732666666666667,
  0.7759,
  0.7750833333333333,
  0.7791166666666667,
  0.7799666666666667,
  0.7819833333333334,
  0.78325,
  0.7840833333333334,
  0.78375,
  0.78675,
  0.7874333333333333,
  0.7853333333333333,
  0.79025,
  0.7902833333333333,
  0.7915166666666666,
  0.7928,
  0.79275,
  0.7934666666666667,
  0.7935,
  0.7948333333333333,
  0.7959833333333334,
  0.8602,
  0.8762166666666666,
  0.9429666666666666,
  0.9449666666666666,
  0.9548166666666666,
  0.9564333333333334,
  0.9633166666666667,
  0.96695,
  0.97065,
  0.9718,
  0.9723666666666667,
  0.97515,
  0.97425,
  0.97195,
  0.9774333333333334,
  0.97425,
  0.9791666666666666,
  0.9787333333333333,
  0.9785,
  0.