In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision.transforms as transforms
import torchvision.datasets as datasets

from sklearn import metrics
from sklearn import decomposition
from sklearn import manifold
import tqdm as tqdm

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd


import copy
import random
import time


def calculate_accuracy(y_pred, y):
    """Return the share of samples whose highest-scoring class (argmax over dim 1) equals the label.

    Args:
        y_pred: per-class scores, one row per sample.
        y: integer labels, one per sample.

    Returns:
        A 0-dim float tensor; call ``.item()`` for a Python float.
    """
    predicted = y_pred.argmax(dim=1, keepdim=True)
    n_correct = (predicted == y.view_as(predicted)).sum()
    batch_size = y.shape[0]
    return n_correct.float() / batch_size

def train(model, iterator, optimizer, criterion, device):
    """Run one optimisation pass of `model` over every batch in `iterator`.

    Each batch is moved to `device`; then gradients are zeroed, the loss is
    back-propagated and `optimizer` takes a step.

    Returns:
        (mean batch loss, mean batch accuracy). Both are averaged per batch, so
        a short final batch counts as much as a full one.
    """
    model.train()
    loss_sum, acc_sum = 0, 0

    for inputs, targets in tqdm.tqdm(iterator, desc="Training", leave=False):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_acc = calculate_accuracy(outputs, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        acc_sum += batch_acc.item()

    n_batches = len(iterator)
    return loss_sum / n_batches, acc_sum / n_batches


def evaluate(model, iterator, criterion, device):
    """Score `model` on every batch of `iterator` in eval mode, without tracking gradients.

    Returns:
        (mean batch loss, mean batch accuracy, counts), where ``counts`` is a dict
        mapping each predicted class index (argmax over dim 1) to the number of
        samples the model assigned to it. Classes that were never predicted are absent.
    """
    model.eval()
    loss_sum = 0
    acc_sum = 0
    pred_counts = dict()

    with torch.no_grad():
        for x, y in tqdm.tqdm(iterator, desc="Evaluating", leave=False):
            x, y = x.to(device), y.to(device)
            outputs = model(x)

            # Histogram of predicted classes over the whole iterator.
            for label in outputs.argmax(1).cpu().data.numpy():
                pred_counts[label] = pred_counts.get(label, 0) + 1

            loss_sum += criterion(outputs, y).item()
            acc_sum += calculate_accuracy(outputs, y).item()

    n_batches = len(iterator)
    return loss_sum / n_batches, acc_sum / n_batches, pred_counts

def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into whole minutes and seconds.

    Both parts are truncated toward zero with ``int``.

    Returns:
        (minutes, seconds) as ints.
    """
    span = end_time - start_time
    whole_minutes = int(span / 60)
    leftover_seconds = span - 60 * whole_minutes
    return whole_minutes, int(leftover_seconds)

def plot_images(images):
    """Draw the first k*k images as a k-by-k grid, where k = int(sqrt(len(images))).

    Each image is reshaped to 28x28, moved to the CPU and shown with the 'bone'
    colormap and hidden axes. Images beyond the first k*k are not drawn.
    """
    side = int(np.sqrt(len(images)))
    fig = plt.figure()
    for idx in range(side * side):
        ax = fig.add_subplot(side, side, idx + 1)
        ax.imshow(images[idx].view(28, 28).cpu().numpy(), cmap='bone')
        ax.axis('off')

class MyMLP(nn.Module):
    """Fully-connected network of ``L`` linear layers used in the entropy experiments.

    Layer ``l`` applies, in order:
    1. normalization (skipped for ``l == 0``);
    2. the activation;
    3. the ``l``-th linear map.
    A softmax follows the last linear map, so ``forward`` returns per-class
    probabilities.

    Args:
        D: hidden width.
        L: number of linear layers (must be >= 1).
        res: stored as ``self.res``; no method of this class reads it.
        activ: one of ``'lin'``, ``'relu'``, ``'tanh'``.
        norm: one of:
            - ``''``: no normalization;
            - ``'LN'``: L2-normalize each sample (row);
            - ``'BN1'``: L2-normalize each feature (column) over the batch;
            - ``'BN2'``: center each feature over the batch, then L2-normalize it.
        D_in: input dimension (defaults to ``D`` when falsy).
        D_out: output dimension (defaults to ``D`` when falsy).
    """

    def __init__(self, D, L, res, activ, norm, D_in = None, D_out=None):
        super().__init__()
        assert(activ in ['lin', 'relu','tanh'])
        assert(norm in ['', 'LN','BN1','BN2'])
        assert L >= 1, 'MyMLP needs at least one layer'
        if not D_in:
            D_in = D
        if not D_out:
            D_out = D
        self.norm = norm
        self.activ = activ
        # Layer l maps in_dims[l] -> out_dims[l]. The first layer reads D_in, the last
        # writes D_out, and every other layer is D -> D. Deriving both ends from the layer
        # index keeps L == 1 correct (a single D_in -> D_out layer); overwriting fcs[0] and
        # then fcs[-1] would lose the D_in input size in that case.
        in_dims = [D_in] + [D] * (L - 1)
        out_dims = [D] * (L - 1) + [D_out]
        self.fcs = [nn.Linear(n_in, n_out) for n_in, n_out in zip(in_dims, out_dims)]
        for fc in self.fcs:
            # Gaussian weights with std 1/sqrt(fan_in); biases keep nn.Linear's default init.
            fan_in = fc.weight.shape[1]
            with torch.no_grad():
                fc.weight.copy_(torch.normal(0, 1, size=fc.weight.shape) / np.sqrt(fan_in))

        # self.fcs is a plain list, so register each layer explicitly. This makes the
        # parameters visible to the optimizer, and state_dict keys stay
        # 'fc_<i>.weight' / 'fc_<i>.bias'.
        for li, fc in enumerate(self.fcs):
            self.add_module("fc_"+str(li), fc)
        self.L = L
        self.D = D
        # NOTE(review): `res` (presumably a residual-connection weight) is stored but
        # never read by any method below; confirm whether residuals were intended.
        self.res = res

    def activation(self, h):
        """Apply the configured nonlinearity ('lin' returns `h` unchanged)."""
        if self.activ=='lin':
            return h
        elif self.activ=='relu':
            return F.relu(h)
        elif self.activ=='tanh':
            return torch.tanh(h)

    def normalize(self, h):
        """Apply the configured normalization to a (batch, features) tensor.

        ``F.normalize`` divides by ``max(norm, 1e-12)``. An all-zero row or column
        therefore stays zero instead of becoming 0/0 = NaN. This happens, for
        example, with BN2 on a batch of one sample, where centering zeroes everything.
        """
        if self.norm=='LN':
            return F.normalize(h, p=2, dim=1)
        if self.norm=='BN2':
            return F.normalize(h - h.mean(0, keepdim=True), p=2, dim=0)
        if self.norm=='BN1':
            return F.normalize(h, p=2, dim=0)
        return h

    def layer_update(self, l,h):
        """Apply layer `l`: normalize (not for l == 0), activate, apply fcs[l]; softmax after the last layer."""
        h2 = h
        if l>0:
            h2 = self.normalize(h2)
        h2 = self.activation(h2)
        h2 = self.fcs[l](h2)
        if l==self.L-1:
            # NOTE(review): the training cells feed this output to nn.CrossEntropyLoss,
            # which applies log-softmax itself. Confirm the double softmax is intended.
            h2 = torch.softmax(h2,1)
        return h2

    def full_forward(self, h):
        """Run the network and collect every intermediate representation.

        Returns:
            A list of ``L + 1`` NumPy arrays: the flattened input followed by
            the output of each layer.
        """
        h = h.reshape(h.shape[0],-1) # flatten images to vectors
        H = [h.detach().cpu().numpy()]
        for l in range(self.L):
            h = self.layer_update(l,h)
            H.append(h.detach().cpu().numpy())
        return H

    def forward(self, h):
        """Flatten each sample to a vector and return the softmax output of the last layer."""
        h = h.reshape(h.shape[0],-1) # flatten images to vectors
        for l in range(self.L):
            h = self.layer_update(l,h)
        return h
    
def show_layers(Hidden, Num=None,subplot=True,title=False,save_path=None):
    """Scatter-plot a subset of per-layer representations.

    Args:
        Hidden: list of arrays, one per layer. For example, the list returned by
            ``MyMLP.full_forward``, whose entry 0 is the flattened input.
        Num: how many entries of ``Hidden`` to draw, evenly spaced over the list
            (defaults to all of them).
        subplot: when True, draw the selected layers side by side in one new figure.
            Otherwise clear the current figure before each layer, so only the last
            selected layer remains visible.
        title: when True, title each panel with the index (into ``Hidden``) of the
            layer it shows.
        save_path: when given, save the current figure to this path after drawing.
    """
    if not Num:
        Num = len(Hidden)
    inds = np.linspace(0,len(Hidden)-1,Num).astype(np.int32)
    if subplot:
        fig = plt.figure(figsize=(2*Num,2))
    for Hi, l in enumerate(inds):
        H = Hidden[l]
        # Drop leading singleton axes, e.g. a (1, n, d) array becomes (n, d).
        while H.shape[0]==1:
            H = H[0]
        if subplot:
            ax = fig.add_subplot(1,Num,Hi+1)
        else:
            plt.clf()
            ax = plt.gca()

        # Plot entry 0 against entry 1 along the leading axis, with marker size 2.
        ax.scatter(H[0],H[1],2)
        if title:
            # Label with the actual layer index, not the panel position: they differ
            # whenever Num < len(Hidden).
            ax.set_title(f'Layer = {l}')
    if save_path:
        plt.gcf().savefig(save_path)
                      

In [None]:
# Seed every global RNG so weight init, augmentation draws and results are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

ROOT = '.data'
raw_train_data = datasets.MNIST(root=ROOT,
                                train=True,
                                download=True)

# Pixel statistics of the raw training images, rescaled from [0, 255] to [0, 1].
mean = raw_train_data.data.float().mean() / 255
std = raw_train_data.data.float().std() / 255
print(f'Calculated mean: {mean}')
print(f'Calculated std: {std}')

train_transforms = transforms.Compose([
                            transforms.RandomRotation(5, fill=(0,)),
                            transforms.RandomCrop(28, padding=2),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=[mean], std=[std])
                                      ])

test_transforms = transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize(mean=[mean], std=[std])
                                     ])

train_data = datasets.MNIST(root=ROOT,
                            train=True,
                            download=True,
                            transform=train_transforms)

test_data = datasets.MNIST(root=ROOT,
                           train=False,
                           download=True,
                           transform=test_transforms)

# Binary task: digits 0 and 1 only. The model's output units are indexed by label,
# so Classes must be 0..k-1 (use list(range(10)) for the full task).
Classes = [0, 1]


def subset_by_class(dataset, classes):
    """Return a lazy view of `dataset` containing only samples whose label is in `classes`.

    Indices are selected from ``dataset.targets``, so no image is decoded here.
    A Subset re-applies the dataset's transform on every access. The random
    training augmentations are therefore re-drawn each epoch instead of being
    frozen, as they would be if the transformed samples were copied into a list once.
    """
    keep = [i for i, label in enumerate(dataset.targets.tolist()) if label in classes]
    return data.Subset(dataset, keep)


train_data = subset_by_class(train_data, Classes)
test_data = subset_by_class(test_data, Classes)

print(f'Number of training examples: {len(train_data)}')
print(f'Number of testing examples: {len(test_data)}')

N_IMAGES = 25

images = [test_data[i][0] for i in range(N_IMAGES)]

plot_images(images)

In [None]:
# Sanity check: number of training samples labelled 0 or 1 (equals len(train_data) when Classes == [0, 1]).
sum(1 for _, label in train_data if label <= 1)

In [None]:
def class_freq(num_classes,classes, prob=True, label = ''):
    """Print and return a one-line summary of a predicted-class histogram.

    Args:
        num_classes: dict mapping a predicted class index to its count (as returned
            by ``evaluate``). Keys must be integers in ``range(len(classes))``.
        classes: the class list; only its length is used.
        prob: when truthy, turn counts into frequencies. If there are no counts
            at all, the entries stay 0 instead of becoming NaN.
        label: prefix of the summary line.

    Returns:
        The printed summary, e.g. ``'Valid: f[0]=0.750, f[1]=0.250'``.
    """
    freq = np.zeros(len(classes))
    for c, n in num_classes.items():
        freq[c] = n
    total = np.sum(freq)
    if prob and total > 0:
        freq = freq / total
    # Label each entry with its own class index. The previous code reused the
    # leftover loop variable `c`, which tagged every entry with the last class seen.
    s = ', '.join(f'f[{i}]={f:.3f}' for i, f in enumerate(freq))
    s = label + ': ' + s
    print(s)
    return s
def calc_H(num_classes, classes):
    """Entropy (bits) of a predicted-class histogram and its total-variation distance from uniform.

    Args:
        num_classes: dict mapping a predicted class index to its count. Keys must be
            integers in ``range(len(classes))``.
        classes: the class list; only its length is used.

    Returns:
        (entropy, tv). Both are computed after adding ``0.01 / total`` to every count,
        so that classes that were never predicted do not give log(0). Both are NaN
        when the histogram is empty.
    """
    freq = np.zeros(len(classes))
    for c, n in num_classes.items():
        freq[c] = n
    total = np.sum(freq)
    if total == 0:
        # No predictions: neither quantity is defined. Return NaN without dividing by zero.
        return float('nan'), float('nan')
    freq = freq + 1e-2/total # to avoid NaN in log
    freq = freq / np.sum(freq)
    # Total variation distance to the uniform distribution: half the L1 distance.
    # (Equal to the previous mean-based formula for 2 classes, correct for any k.)
    tv = 0.5 * np.sum(np.abs(freq - 1.0/len(freq)))
    return -np.sum(freq * np.log2(freq)),tv

repeat = 50
EPOCHS = 10

D_IN = 28*28
num_classes = len(Classes)
WIDTH = 200
LAYERS = 15

RES = .0
BATCH_SIZE = 16
ACTIVATION = 'relu'
NORM = ''
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

COLUMNS = ['epoch','batch_size','width','layers','normalization','activation',
           'H_train','H_val','TV_train','TV_test',
           'train_acc','val_acc','train_loss','val_loss']

# Loaders and loss do not depend on any loop variable, so build them once.
train_iterator = data.DataLoader(train_data,batch_size=BATCH_SIZE)
test_iterator = data.DataLoader(test_data,batch_size=BATCH_SIZE)
criterion = nn.CrossEntropyLoss().to(device)

# One tuple per (repeat, normalization, epoch). The DataFrame is built once at the
# end, so numeric columns get numeric dtypes instead of object.
records = []
for ri in range(repeat):
    print(ri, repeat)
    for NORM in ['','BN2']:
        model = MyMLP(L=LAYERS,D=WIDTH,res=RES,activ=ACTIVATION, norm=NORM, D_in=D_IN, D_out=num_classes)
        model = model.to(device)

        best_valid_loss = float('inf')
        optimizer = optim.SGD(model.parameters(),lr=1e-3)

        # Epoch 0 only evaluates the freshly initialized model.
        for epoch in range(EPOCHS+1):
            if epoch>0:
                train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)
            else:
                train_loss, train_acc = None, None
            # A single test-set pass yields both the validation metrics and the test
            # histogram (evaluation is deterministic, so a second pass adds nothing).
            valid_loss, valid_acc, test_num_classes = evaluate(model, test_iterator, criterion, device)
            class_freq(test_num_classes, Classes, label='Valid. freqs')
            _, _, train_num_classes = evaluate(model, train_iterator, criterion, device)
            H_train,TV_train = calc_H(train_num_classes, Classes)
            H_test,TV_test = calc_H(test_num_classes, Classes)
            records.append((epoch, BATCH_SIZE, WIDTH, LAYERS, NORM, ACTIVATION,
                            H_train, H_test, TV_train, TV_test,
                            train_acc, valid_acc, train_loss, valid_loss))

            # Checkpoint the best model seen so far (trained epochs only).
            if epoch > 0 and valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), 'tut1-model.pt')

df = pd.DataFrame(records, columns=COLUMNS)

In [None]:
sns.set_palette("Set1")

# Human-readable legend labels. Assigning the whole column (instead of chained
# indexing like df.normalization[mask] = ...) actually updates `df` under pandas
# Copy-on-Write, avoids SettingWithCopyWarning, and is idempotent on re-run.
df['normalization'] = df['normalization'].replace({'': 'Vanilla', 'BN2': 'BN'})

fig, axes = plt.subplots(1,2,figsize=(10,4))
sns.lineplot(ax=axes[0], data=df, x='epoch',y='H_train',hue='normalization',marker='o')
axes[0].set_ylabel('Model entropy (bits)')
axes[0].set_title('(A) Model entropy vs. epoch',fontsize=12)
sns.lineplot(ax=axes[1], data=df, x='epoch',y='train_acc',hue='normalization',marker='o')
# Panel B plots training accuracy, so title it accordingly (it previously said "Loss").
axes[1].set_title('(B) Accuracy vs. epoch',fontsize=12)
axes[1].set_ylabel('Accuracy')

fig.savefig('entropy_vs_SGD.pdf')

In [None]:

# Entropy of the predicted-class histogram of freshly initialized (untrained) models,
# swept over width, depth and normalization.
RES = .0
repeat = 10

# Defined here rather than inherited from the training cell, so this cell runs on its own.
criterion = nn.CrossEntropyLoss().to(device)

records2 = []
for ri in range(repeat):
    print(ri, repeat)
    for BATCH_SIZE in [16]:
        # Loaders depend only on the batch size, so build them once per batch size.
        train_iterator = data.DataLoader(train_data,batch_size=BATCH_SIZE)
        test_iterator = data.DataLoader(test_data,batch_size=BATCH_SIZE)
        for WIDTH in [50,100]:
            for LAYERS in [2,5,10]:
                for ACTIVATION in ['relu']:  # full sweep: ['lin','tanh','relu']
                    for NORM in ['','BN2']:
                        model = MyMLP(L=LAYERS,D=WIDTH,res=RES,activ=ACTIVATION, norm=NORM, D_in=D_IN, D_out=num_classes)
                        model = model.to(device)

                        # Only the predicted-class histograms are needed; loss and accuracy are discarded.
                        _, _, train_num_classes = evaluate(model, train_iterator, criterion, device)
                        _, _, test_num_classes = evaluate(model, test_iterator, criterion, device)

                        H_train,TV_train = calc_H(train_num_classes, Classes)
                        H_test,TV_test = calc_H(test_num_classes, Classes)
                        records2.append((BATCH_SIZE,WIDTH,LAYERS, NORM,ACTIVATION,H_train,H_test, TV_train,TV_test))

df2 = pd.DataFrame(records2, columns=['batch_size','width','layers','normalization','activation', 'H_train','H_test', 'TV_train','TV_test'])

In [None]:
sns.set_theme(style="white")

# Human-readable labels. Whole-column assignment works under pandas Copy-on-Write,
# unlike chained indexing (df2.col[mask] = ...), and is idempotent on re-run.
df2['normalization'] = df2['normalization'].replace({'': 'Vanilla', 'BN2': 'BN'})
df2['activation'] = df2['activation'].replace({'lin': 'linear'})

sub = df2.loc[df2.activation=='relu'].rename(columns={'H_train': 'Model entropy'})
sns.set_palette("Set1")

# One panel per width in the sweep, lettered (A), (B), ...
widths = sorted(df2.width.unique())
fig, axes = plt.subplots(1, len(widths), figsize=(6*len(widths),5), squeeze=False)
for axi, (ax, w) in enumerate(zip(axes[0], widths)):
    sns.lineplot(ax=ax, data=sub.loc[sub.width==w],x='layers',y='Model entropy',hue='normalization',marker='o')
    ax.set_xlabel('depth')
    if axi==0:
        ax.set_ylabel('Model entropy (bits)')
    else:
        ax.set_ylabel(None)
        ax.set_yticklabels([])
    letter = chr(ord('A') + axi)
    ax.set_title(f'({letter}) $width = {w}$',fontsize=15)
    # The entropy of a k-class prediction histogram is at most log2(k) bits.
    ax.set_ylim(0, np.log2(len(Classes)))
fig.savefig('entropy_vs_depth.pdf')