In [None]:
from collections import defaultdict
from itertools import islice
import random
import math

import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torchvision

import matplotlib.pyplot as plt

In [None]:
# --- Dataset sizes -----------------------------------------------------------
train_points = 1000        # leading examples of the MNIST train split used for training
test_points = 1000         # leading examples of the MNIST test split used for evaluation
robustness_points = 500    # test-split examples after `test_points`, used by the robustness probe

# --- Optimisation ------------------------------------------------------------
optimization_steps = 100_000
batch_size = 200
lr = 1e-3
weight_decay = 0.01
initialization_scale = 8.0  # every freshly initialised parameter is multiplied by this

# --- Architecture ------------------------------------------------------------
depth = 3
width = 200

# --- Bookkeeping -------------------------------------------------------------
download_directory = "."
log_freq = math.ceil(optimization_steps / 150)  # metrics are evaluated every `log_freq` steps

device = 'cpu'
dtype = torch.float64
seed = 0

In [None]:
# Seed every random number generator the notebook touches so that a
# Restart & Run All reproduces the same run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# New floating-point tensors (and hence the model parameters) use `dtype`.
torch.set_default_dtype(dtype)

In [None]:
# Training split: download MNIST if needed, keep the first `train_points`
# examples, and serve them in shuffled mini-batches.
train = torchvision.datasets.MNIST(
    root=download_directory,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
sub_train = torch.utils.data.Subset(train, range(train_points))
train_loader = torch.utils.data.DataLoader(sub_train, batch_size=batch_size, shuffle=True)

# Test split: the first `test_points` examples are the evaluation set.
test = torchvision.datasets.MNIST(
    root=download_directory,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
sub_test = torch.utils.data.Subset(test, range(test_points))

In [None]:
def cycle(iterable):
    """Yield the items of ``iterable`` endlessly.

    The iterable is re-iterated from scratch after every full pass, so an
    object that produces a new ordering on each iteration will do so on
    every pass.
    """
    while True:
        yield from iterable

def compute_accuracy(network, dataset, device, N=2000, batch_size=50):
    """Return the fraction of examples that ``network`` classifies correctly.

    At most ``N`` examples are considered (capped at ``len(dataset)``). They
    are drawn in shuffled batches of ``batch_size`` and only complete batches
    are evaluated, i.e. ``N // batch_size`` of them. The predicted class is
    the argmax of the network output along dimension 1.

    Returns:
        The accuracy as a Python float in [0, 1].
    """
    n_points = min(len(dataset), N)
    per_batch = min(batch_size, n_points)
    loader = torch.utils.data.DataLoader(dataset, batch_size=per_batch, shuffle=True)
    n_batches = n_points // per_batch

    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in islice(loader, n_batches):
            predictions = network(inputs.to(device)).argmax(dim=1)
            hits += (predictions == targets.to(device)).sum()
            seen += inputs.size(0)
    # `hits` is a tensor, so this is a tensor division converted to a float.
    return (hits / seen).item()

def compute_loss(network, dataset, device, N=2000, batch_size=50):
    """Return the mean-squared error of ``network`` against one-hot targets.

    At most ``N`` examples are considered (capped at ``len(dataset)``). They
    are drawn in shuffled batches of ``batch_size`` and only complete batches
    are evaluated, i.e. ``N // batch_size`` of them. Targets are rows of a
    10x10 identity matrix selected by the integer labels.

    The result is the average of ``nn.MSELoss()`` over all evaluated
    examples, so it is directly comparable to a loss computed with
    ``nn.MSELoss()`` on a single batch.

    Returns:
        The mean loss as a Python float.
    """
    with torch.no_grad():
        N = min(len(dataset), N)
        batch_size = min(batch_size, N)
        dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        loss_fn = nn.MSELoss()
        one_hots = torch.eye(10, 10).to(device)
        total = 0.0
        points = 0
        for x, labels in islice(dataset_loader, N // batch_size):
            y = network(x.to(device))
            batch_points = len(labels)
            # MSELoss already averages over the batch; weight it by the batch
            # size so that dividing by `points` yields a true mean rather than
            # a value shrunk by a factor of `batch_size`.
            total += loss_fn(y, one_hots[labels.to(device)]).item() * batch_points
            points += batch_points
        return total / points

In [None]:
# Assemble the MLP: flatten the image, then stack `depth` linear layers.
# Every layer is followed by a ReLU except the last one (when depth > 1),
# which maps to the 10 outputs compared against the one-hot targets.
layers = [nn.Flatten()]
for i in range(depth):
    fan_in = 784 if i == 0 else width
    if i > 0 and i == depth - 1:
        layers.append(nn.Linear(fan_in, 10))
    else:
        layers += [nn.Linear(fan_in, width), nn.ReLU()]
mlp = nn.Sequential(*layers).to(device)

# Rescale every parameter (weights and biases alike) of the fresh network.
with torch.no_grad():
    for p in mlp.parameters():
        p.data = p.data * initialization_scale

optimizer = torch.optim.AdamW(mlp.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
# Metric histories; entry k of each list was recorded at step log_steps[k].
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []
norms = []
last_layer_norms = []
log_steps = []

# Each checkpoint is written at most once per run.
memorised_saved = False
generalised_saved = False

# Start a fresh log for this run so that checkpoint entries are never appended
# to the log of an earlier run.
with open('save_log.txt', 'w') as save_log_file:
    save_log_file.write(f'initialization_scale={initialization_scale}\n')

steps = 0
one_hots = torch.eye(10, 10).to(device)
loss_fn = nn.MSELoss()
with tqdm(total=optimization_steps) as pbar:
    for x, labels in islice(cycle(train_loader), optimization_steps):
        # Evaluate densely at the start (every step, then every 10th step)
        # and every `log_freq` steps afterwards.
        if (steps < 30) or (steps < 150 and steps % 10 == 0) or steps % log_freq == 0:
            train_losses.append(compute_loss(mlp, sub_train, device, N=len(sub_train)))
            train_accuracies.append(compute_accuracy(mlp, sub_train, device, N=len(sub_train)))
            test_losses.append(compute_loss(mlp, sub_test, device, N=len(sub_test)))
            test_accuracies.append(compute_accuracy(mlp, sub_test, device, N=len(sub_test)))
            log_steps.append(steps)
            with torch.no_grad():
                # L2 norm of all parameters, and of the last layer alone.
                total = sum(torch.pow(p, 2).sum() for p in mlp.parameters())
                norms.append(float(np.sqrt(total.item())))
                last_layer = sum(torch.pow(p, 2).sum() for p in mlp[-1].parameters())
                last_layer_norms.append(float(np.sqrt(last_layer.item())))
            pbar.set_description("L: {0:1.1e}|{1:1.1e}. A: {2:2.1f}%|{3:2.1f}%".format(
                train_losses[-1],
                test_losses[-1],
                train_accuracies[-1] * 100,
                test_accuracies[-1] * 100))

            # The accuracies only change when they are re-evaluated, so the
            # checkpoint conditions need to be checked only here.
            if train_accuracies[-1] > 0.99 and not memorised_saved:
                with open('save_log.txt', 'a') as save_log_file:
                    save_log_file.write(f'memorised:step={steps},train_accuracy={train_accuracies[-1]},test_accuracy={test_accuracies[-1]}\n')
                torch.save(mlp.state_dict(), 'memorised_model.pth')
                memorised_saved = True

            if test_accuracies[-1] > 0.85 and not generalised_saved:
                with open('save_log.txt', 'a') as save_log_file:
                    save_log_file.write(f'generalised:step={steps},train_accuracy={train_accuracies[-1]},test_accuracy={test_accuracies[-1]}')
                torch.save(mlp.state_dict(), 'generalised_model.pth')
                generalised_saved = True
                # Training stops as soon as the generalising checkpoint exists.
                break

        optimizer.zero_grad()
        y = mlp(x.to(device))
        loss = loss_fn(y, one_hots[labels.to(device)])
        loss.backward()
        optimizer.step()
        steps += 1
        pbar.update(1)

    # Fallback: if the test-accuracy threshold was never reached, still store
    # the final weights so that 'generalised_model.pth' always belongs to this
    # run. (When the threshold was reached, the checkpoint and its log line
    # have already been written inside the loop and are not repeated.)
    if not generalised_saved:
        with open('save_log.txt', 'a') as save_log_file:
            save_log_file.write(f'generalised:step={steps},train_accuracy={train_accuracies[-1]},test_accuracy={test_accuracies[-1]}')
        torch.save(mlp.state_dict(), 'generalised_model.pth')

In [None]:
# Train/test accuracy against optimisation step on a logarithmic x-axis.
fig, ax = plt.subplots()
ax.plot(log_steps, train_accuracies, color='red', label='train')
ax.plot(log_steps, test_accuracies, color='green', label='test')
ax.set_xscale('log')
ax.set_xlim(10, None)  # hide the first few steps, which the log axis would stretch
ax.set_xlabel("Optimization Steps")
ax.set_ylabel("Accuracy")
ax.legend(loc=(0.015, 0.65))
# The title reflects the configured architecture instead of fixed numbers.
ax.set_title(f"depth-{depth} width-{width} ReLU MLP on MNIST\nUnconstrained Optimization α = {initialization_scale}", fontsize=11)
fig.tight_layout()
fig.savefig(f'train_test_acc_init_{initialization_scale}.png')

In [None]:
def robustness_score(model, perturbation_increment=0.01, checks=5, max_perturbation=None):
    """Estimate, per example, how large a random L2 perturbation the model withstands.

    The examples are the ``robustness_points`` items of ``test`` that follow
    the first ``test_points`` items. For each example the perturbation radius
    starts at 0 and grows by ``perturbation_increment`` for as long as the
    prediction equals the label under ``checks`` independent random
    perturbations of that radius. Each perturbation direction is drawn
    uniformly from [-1, 1) per pixel and rescaled to unit L2 norm.

    This is an empirical probe based on random sampling, not a formal
    robustness certificate.

    Args:
        model: Network mapping an image batch to class scores.
        perturbation_increment: Step by which the radius is increased; must be > 0.
        checks: Number of random perturbations tried per radius; must be >= 1.
        max_perturbation: Optional upper bound at which the search stops.
            ``None`` (the default) leaves the search unbounded.

    Returns:
        A NumPy array with one entry per example: the first radius at which
        some random perturbation changed the prediction (0 for examples that
        are misclassified without any perturbation), or the radius at which
        the ``max_perturbation`` bound stopped the search.

    Raises:
        ValueError: If ``checks`` < 1 or ``perturbation_increment`` <= 0.
    """
    if checks < 1:
        raise ValueError("checks must be at least 1")
    if perturbation_increment <= 0:
        raise ValueError("perturbation_increment must be positive")

    robustness_set = torch.utils.data.Subset(test, range(test_points, test_points + robustness_points))
    robustness_loader = torch.utils.data.DataLoader(robustness_set, batch_size=1, shuffle=False)
    perturbations = np.zeros(len(robustness_loader))

    # Evaluate on whatever device the model's parameters live on.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    def verification(img, label, perturbation):
        """Return True if all `checks` random perturbations keep the prediction equal to `label`."""
        for _ in range(checks):
            perturbation_img = 2 * torch.rand_like(img) - 1
            perturbation_img /= torch.norm(perturbation_img)
            if model(img + perturbation * perturbation_img).argmax() != label:
                return False
        return True

    pbar = tqdm(enumerate(robustness_loader), total=len(robustness_loader))
    # Only predictions are needed, so no autograd graph has to be recorded.
    with torch.no_grad():
        for k, (img, label) in pbar:
            img = img.to(model_device)
            label = label.to(model_device)
            perturbation = 0
            while verification(img, label, perturbation):
                if max_perturbation is not None and perturbation >= max_perturbation:
                    break
                pbar.set_description(f'perturbation={perturbation:.2f}')
                perturbation += perturbation_increment
            perturbations[k] = perturbation
    return perturbations

In [None]:
def build_mlp():
    """Create a freshly initialised MLP with the layout used for the saved checkpoints.

    Every call creates new layer objects, so two models built with this
    function never share parameters.
    """
    layers = [nn.Flatten()]
    for i in range(depth):
        if i == 0:
            layers.append(nn.Linear(784, width))
            layers.append(nn.ReLU())
        elif i == depth - 1:
            layers.append(nn.Linear(width, 10))
        else:
            layers.append(nn.Linear(width, width))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers).to(device)


# Each checkpoint is loaded into its own network. Building both models from a
# single list of layers would make them share parameters, so loading the second
# checkpoint would silently overwrite the first model.
mlp_memorised = build_mlp()
mlp_memorised.load_state_dict(torch.load('memorised_model.pth', map_location=device))
mlp_memorised.eval()
robustness_memorised = robustness_score(mlp_memorised)

mlp_generalised = build_mlp()
mlp_generalised.load_state_dict(torch.load('generalised_model.pth', map_location=device))
mlp_generalised.eval()
robustness_generalised = robustness_score(mlp_generalised)

In [None]:
# Distribution of the robustness radii for both checkpoints. Zero entries
# (examples misclassified without any perturbation) are left out.
plt.hist(robustness_memorised[robustness_memorised != 0], alpha=0.5, label='memorised', bins=20, density=True)
plt.hist(robustness_generalised[robustness_generalised != 0], alpha=0.5, label='generalised', bins=20, density=True)
plt.title('"Verified" L2 Robustness')
plt.xlabel('L2 perturbation radius')
plt.ylabel('Density')
plt.legend()
# Save before showing: once the figure has been shown, a later savefig would
# write an empty image.
plt.savefig(f'verification_histogram_{initialization_scale}.png')
plt.show()