In [1]:
import os
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from tqdm.auto import tqdm


def _cycle(iterable):
    """Yield items from ``iterable`` forever, restarting it each time it is exhausted."""
    while True:
        for item in iterable:
            yield item


def _accuracy(network, dataset, device, N=2000, batch_size=50):
    """Return the fraction of correct argmax predictions on a random sample of ``dataset``.

    At most ``N // batch_size`` shuffled batches of ``batch_size`` samples are
    evaluated. Runs without autograd and returns a Python float.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    correct = 0
    total = 0
    with torch.no_grad():
        for x, labels in islice(loader, N // batch_size):
            logits = network(x.to(device))
            predicted = torch.argmax(logits, dim=1)
            correct += torch.sum(predicted == labels.to(device)).item()
            total += x.size(0)
    return correct / total


def _squared_error(network, dataset, device, N=2000, batch_size=50, num_classes=10):
    """Return the per-sample summed squared error against one-hot targets.

    Uses the same random-subsample scheme as ``_accuracy``. The one-hot table is
    created on ``device`` so it can be combined with the network output there
    (a CPU-side table fails as soon as the model lives on a GPU).
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    one_hot_table = torch.eye(num_classes, device=device)
    squared_error = 0.0
    total = 0
    with torch.no_grad():
        for x, labels in islice(loader, N // batch_size):
            logits = network(x.to(device))
            targets = one_hot_table[labels.to(device)]
            squared_error += torch.sum((logits - targets) ** 2).item()
            total += x.size(0)
    return squared_error / total


def _squared_l2_norm(model):
    """Return the sum of squared entries over all parameters of ``model`` (no autograd)."""
    with torch.no_grad():
        return sum(torch.sum(p ** 2) for p in model.parameters())


def _rescale_(model, factor):
    """Multiply every parameter of ``model`` in place by ``factor`` (no autograd)."""
    with torch.no_grad():
        for p in model.parameters():
            p.mul_(factor)


def train(params):
    """Train an MLP on an MNIST subset with a fixed parameter norm and save the best metrics.

    Args:
        params: indexable pair ``(data_size, alpha)``. ``data_size`` is converted
            with ``int`` and selects the first ``data_size`` MNIST training
            samples; ``alpha`` is converted with ``float`` and scales all freshly
            initialised parameters before training.

    After every optimizer step the parameters are rescaled so that their total
    squared L2 norm equals the value measured right after the initial rescale.
    Every 50 steps accuracy and squared-error loss are estimated on random
    subsamples of the train subset and the test set. The best value of each
    metric is written to ``./mnist_landscape/<metric>_<data_size>_<alpha>.txt``.

    Raises:
        ValueError: if ``data_size`` is smaller than 1 or larger than the MNIST
            training set.
    """
    data_size = int(params[0])
    alpha = float(params[1])

    steps = 10000
    eval_every = 50
    width = 200

    # Create the output directory up front so a missing folder cannot discard a
    # finished training run; exist_ok tolerates concurrent workers.
    os.makedirs("./mnist_landscape", exist_ok=True)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    to_tensor = torchvision.transforms.ToTensor()
    full_train = torchvision.datasets.MNIST(root="/tmp", train=True, transform=to_tensor, download=True)
    test = torchvision.datasets.MNIST(root="/tmp", train=False, transform=to_tensor, download=True)

    if not 1 <= data_size <= len(full_train):
        raise ValueError(
            "data_size must be between 1 and {}, got {}".format(len(full_train), data_size)
        )

    train_subset = torch.utils.data.Subset(full_train, range(data_size))
    train_loader = torch.utils.data.DataLoader(train_subset, batch_size=100, shuffle=True)

    mlp = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, width),
        nn.ReLU(),
        nn.Linear(width, width),
        nn.ReLU(),
        nn.Linear(width, 10),
    ).to(device)

    # Scale the initialisation, then remember its squared norm as the target
    # that every later step is projected back onto.
    _rescale_(mlp, alpha)
    target_sq_norm = _squared_l2_norm(mlp)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(mlp.parameters(), lr=1e-3, weight_decay=0.0)

    one_hots = torch.eye(10, 10).to(device)

    mlp.eval()
    print("Initial accuracy: {0:.4f}".format(_accuracy(mlp, test, device)))

    best_train_loss = 1e4
    best_test_loss = 1e4
    best_train_acc = 0.
    best_test_acc = 0.

    mlp.train()
    pbar = tqdm(islice(_cycle(train_loader), steps), total=steps)
    for step, (x, label) in enumerate(pbar):
        optimizer.zero_grad()
        loss_train = loss_fn(mlp(x.to(device)), one_hots[label.to(device)])
        loss_train.backward()
        optimizer.step()

        # Rescale weights such that the weight norm remains constant in training.
        _rescale_(mlp, torch.sqrt(target_sq_norm / _squared_l2_norm(mlp)))

        if step % eval_every == 0:
            mlp.eval()
            train_acc = _accuracy(mlp, train_subset, device)
            test_acc = _accuracy(mlp, test, device)
            train_loss = _squared_error(mlp, train_subset, device)
            test_loss = _squared_error(mlp, test, device)
            mlp.train()

            best_train_acc = max(best_train_acc, train_acc)
            best_test_acc = max(best_test_acc, test_acc)
            best_train_loss = min(best_train_loss, train_loss)
            best_test_loss = min(best_test_loss, test_loss)
            pbar.set_description("{:3.3f} | {:3.3f} | {:3.3f} | {:3.3f}".format(train_acc, test_acc, train_loss, test_loss))

    best_metrics = {
        "trainacc": best_train_acc,
        "testacc": best_test_acc,
        "trainloss": best_train_loss,
        "testloss": best_test_loss,
    }
    for metric_name, value in best_metrics.items():
        np.savetxt("./mnist_landscape/%s_%d_%.2f.txt" % (metric_name, data_size, alpha), np.array([value]))
    


In [None]:
import numpy as np
from multiprocess import Pool

# Training-set sizes: 22 log-spaced values from 10 to 10^4, then five log-spaced
# values strictly between 10^4 and 10^5, and finally 60000.
small_sizes = [int(value) for value in 10 ** np.linspace(1, 4, num=22)]
large_sizes = [int(value) for value in list(10 ** np.linspace(4, 5, num=8)[1:6]) + [60000]]
data_sizes = small_sizes + large_sizes

# Initialisation scales: 21 log-spaced values from 0.1 to 10.
alphas = 10 ** np.linspace(-1, 1, num=21)

# Full Cartesian grid; each entry of `params` is one (data_size, alpha) pair.
xx, yy = np.meshgrid(data_sizes, alphas)
params = list(np.column_stack((xx.ravel(), yy.ravel())))

if __name__ == '__main__':
    # Run one training job per grid point across 11 worker processes.
    with Pool(11) as pool:
        sweep_results = pool.map(train, params)
        print(sweep_results)