## Notebook for running individual trainings and comparing them

Imports and (optional) seeding for reproducibility:

In [None]:
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor


from layer_insertion_loop import layer_insertion_loop
from train_and_test_ import train, check_testerror
from nets import feed_forward, two_weight_resnet
from utils import ema

# Seed every random number generator this notebook relies on (Python's
# `random`, NumPy and PyTorch) with the same value so runs are reproducible.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(1)

# Limit PyTorch's intra-op CPU parallelism to 8 threads.
torch.set_num_threads(8)

Define and list hyperparameters:

In [None]:
# --- layer-insertion trainings ---
fix_width = 100               # width of every hidden layer
no_iters = 2                  # passed as `iters` to layer_insertion_loop
lr_decrease_after_li = 0.8    # passed as `decrease_after_li` to layer_insertion_loop
epochs = [100, 50, 50]        # epochs per training phase
wanted_testerror = 2.0        # target test error passed to the training routines
_type = 'fwd'                 # network family: 'fwd' or 'res2'
act_fun = nn.ReLU
interval_testerror = 1        # passed as `check_testerror_between`

batch_size = 20  # alternatives tried: 200; 60000 for full batch

lr_init = 1e-1
optimizer_type = 'SGD'
lrscheduler_type = 'StepLR'
lrscheduler_args = dict(step_size=10, gamma=0.1)

# --- classical (fixed-architecture) training ---
epochs_classical = sum(epochs)  # same total epoch budget as the layer-insertion runs
lr_init_classical = lr_init
lrscheduler_args_classical = dict(step_size=10, gamma=0.1)

Load dataset:

In [None]:
def _load_mnist(train):
    """Return the MNIST train (``train=True``) or test split as tensors.

    The data is stored under ``./data`` and downloaded there if missing.
    """
    return datasets.MNIST(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )


# Download training and test data from open datasets.
training_data = _load_mnist(train=True)
test_data = _load_mnist(train=False)


In [None]:
# Report the number of minibatch iterations per epoch and the dataset sizes.
# The training DataLoader is built with the default drop_last=False, so the
# last (possibly smaller) minibatch is kept. One epoch therefore has
# ceil(len(training_data) / batch_size) iterations; -(-a // b) is integer
# ceiling division.
iters_per_epoch = -(-len(training_data) // batch_size)
print('no of iterations in one epoch:', iters_per_epoch)
print(len(training_data))
print(len(test_data))

Create dataloaders:

In [None]:
# Training loader: minibatches of `batch_size` in fixed dataset order
# (shuffle defaults to False).
# NOTE(review): SGD training usually uses shuffle=True; presumably left off so
# that all compared runs see the same minibatch sequence — confirm intended.
train_dataloader = DataLoader(training_data, batch_size=batch_size)

# Test loader: the whole test set as a single batch (10000 samples for MNIST).
# The batch size is derived from the dataset size instead of being hard-coded;
# max(1, ...) keeps DataLoader valid for an empty dataset.
test_dataloader = DataLoader(test_data, batch_size=max(1, len(test_data)))

Build models:

In [None]:
# Architecture of the initial network for the layer-insertion trainings:
# one hidden layer of width `fix_width`.
kwargs_net = dict(
    hidden_layers=1,
    dim_hidden_layers=fix_width,
    act_fun=act_fun,
    type=_type,
)

# One input per pixel of a 28x28 MNIST image; 10 output classes (digits).
dim_in = 28 * 28
dim_out = 10

In [None]:
# Architecture for the classical comparison run: same width, activation and
# network type as the layer-insertion runs, but with 3 hidden layers from the start.
kwargs_net_classical = dict(
    hidden_layers=3,
    dim_hidden_layers=fix_width,
    act_fun=act_fun,
    type=_type,
)

Determine which trainings are run:

In [None]:
# Toggle each training run: T1 = layer insertion 'abs max',
# T2 = layer insertion 'abs min', T3 = classical training.
T1 = T2 = T3 = True

## First training: layer insertion (mode 'abs max')

In [None]:
# build model
# Pick the constructor matching `_type`. An unknown type raises immediately
# instead of leaving `model_init` undefined, or silently reusing a model left
# over from an earlier run of this notebook.
if _type == 'fwd':
    model_init = feed_forward(dim_in, dim_out, **kwargs_net)
elif _type == 'res2':
    model_init = two_weight_resnet(dim_in, dim_out, **kwargs_net)
else:
    raise ValueError(f"unknown network type {_type!r}; expected 'fwd' or 'res2'")

# Snapshot of the initial parameters as a plain tensor: detached from autograd
# and copied, so the second training can restart from exactly these weights.
param_init = torch.nn.utils.parameters_to_vector(model_init.parameters()).detach().clone()

Check initial test error of the model

In [None]:
initial_test_error = check_testerror(test_dataloader, model_init)  # before any training
print(initial_test_error)

In [None]:
# train ali 1: layer-insertion training with mode 'abs max', starting from model_init.
if T1:
    li_settings1 = {
        'iters': no_iters,
        'epochs': epochs,
        'model': model_init,
        'kwargs_net': kwargs_net,
        'dim_in': dim_in,
        'dim_out': dim_out,
        'train_dataloader': train_dataloader,
        'test_dataloader': test_dataloader,
        'lr_init': lr_init,
        'wanted_test_error': wanted_testerror,
        'mode': 'abs max',
        'optimizer_type': optimizer_type,
        'lrschedule_type': lrscheduler_type,
        'lrscheduler_args': lrscheduler_args,
        'check_testerror_between': interval_testerror,
        'decrease_after_li': lr_decrease_after_li,
        'print_param_flag': False,
        'start_with_backtracking': None,
        'v2': False,
    }
    # Unpack the five results under the names used by the later plotting cells.
    results1 = layer_insertion_loop(**li_settings1)
    model1, mb_losses1, test_errors_short1, test_errors1, exit_flag1 = results1

In [None]:
# Test error of the first trained model. Guarded like the training cell:
# model1 only exists when T1 ran.
if T1:
    print(check_testerror(test_dataloader, model1))

Plot results:

In [None]:
# Segment lengths (in minibatch-loss entries) used to colour the loss curves by
# training phase. Phase k contributes epochs[k] * (minibatches per epoch)
# entries, and a single extra entry is assumed between consecutive phases.
# NOTE(review): presumably the loss logged at each layer insertion — confirm
# against layer_insertion_loop.
# len(train_dataloader) also counts a final partial minibatch, so the segments
# stay aligned when batch_size does not divide the dataset size.
batches_per_epoch = len(train_dataloader)
end_list = []
for phase_epochs in epochs:
    if end_list:  # separator entry only *between* phases, never after the last one
        end_list.append(1)
    end_list.append(phase_epochs * batches_per_epoch)

if T1:
    plt.figure(figsize=(20, 5))
    xfull = range(1, len(mb_losses1) + 1)
    yfull = mb_losses1
    begin = 0
    for seg_len in end_list:
        end = begin + seg_len
        # one plot call per segment, so matplotlib gives each phase its own colour
        plt.plot(xfull[begin:end], yfull[begin:end], 'o')
        begin = end
    plt.xlabel('minibatch iterations')
    plt.ylabel('minibatch loss')
    plt.yscale('log')
    # A single explicit grid call: a bare plt.grid() *toggles* visibility and
    # would switch off gridlines enabled by an earlier call on the same axes.
    plt.grid(True, which='major', axis='both', zorder=-1.0)
    # Zero reference line. NOTE(review): zero cannot be shown on a log-scaled
    # y-axis, so this line is presumably not visible.
    plt.plot(np.zeros(len(mb_losses1)))

In [None]:
# Test errors recorded during the first training. Guarded because
# test_errors1 only exists when T1 ran.
if T1:
    plt.figure(figsize=(20, 5))
    plt.plot(test_errors1, 'o')
    plt.ylim(bottom=0)
    plt.grid(True)  # explicit: a bare plt.grid() toggles visibility
    plt.xlabel('epochs')
    plt.ylabel('test error')

## Second training: layer insertion (mode 'abs min') from the same initial weights

In [None]:
# build net
# An unknown `_type` raises immediately instead of leaving `model_init2`
# undefined or stale.
if _type == 'fwd':
    model_init2 = feed_forward(dim_in, dim_out, **kwargs_net)
elif _type == 'res2':
    model_init2 = two_weight_resnet(dim_in, dim_out, **kwargs_net)
else:
    raise ValueError(f"unknown network type {_type!r}; expected 'fwd' or 'res2'")

# Start from the same initial weights as the first training.
# vector_to_parameters makes each parameter a view into the vector passed in.
# Passing a copy keeps training of model_init2 from overwriting param_init in
# place, so re-running this cell still starts from the initial weights.
torch.nn.utils.vector_to_parameters(param_init.detach().clone(), model_init2.parameters())

In [None]:
# Second layer-insertion training: mode 'abs min', starting from model_init2.
if T2:
    li_settings2 = {
        'iters': no_iters,
        'epochs': epochs,
        'model': model_init2,
        'kwargs_net': kwargs_net,
        'dim_in': dim_in,
        'dim_out': dim_out,
        'train_dataloader': train_dataloader,
        'test_dataloader': test_dataloader,
        'lr_init': lr_init,
        'wanted_test_error': wanted_testerror,
        'mode': 'abs min',
        'optimizer_type': optimizer_type,
        'lrschedule_type': lrscheduler_type,
        'lrscheduler_args': lrscheduler_args,
        'check_testerror_between': interval_testerror,
        'decrease_after_li': lr_decrease_after_li,
        'print_param_flag': False,
        'start_with_backtracking': None,
        'v2': False,
    }
    # Unpack the five results under the names used by the later plotting cells.
    results2 = layer_insertion_loop(**li_settings2)
    model2, mb_losses2, test_errors_short2, test_errors2, exit_flag2 = results2

Plot results:

In [None]:
# Minibatch losses of the second training, one colour per training phase.
# The segment lengths come from `end_list`, built in the first plotting cell.
if T2:
    plt.figure(figsize=(20, 5))
    xfull = range(1, len(mb_losses2) + 1)
    yfull = mb_losses2
    begin = 0
    for seg_len in end_list:
        end = begin + seg_len
        # one plot call per segment, so matplotlib gives each phase its own colour
        plt.plot(xfull[begin:end], yfull[begin:end], 'o')
        begin = end
    plt.xlabel('minibatch iterations')
    plt.ylabel('minibatch loss')
    plt.yscale('log')
    # A single explicit grid call: a bare plt.grid() *toggles* visibility and
    # would switch off gridlines enabled by an earlier call on the same axes.
    plt.grid(True, which='major', axis='both', zorder=-1.0)
    # Zero reference line. NOTE(review): zero cannot be shown on a log-scaled
    # y-axis, so this line is presumably not visible.
    plt.plot(np.zeros(len(mb_losses2)))

In [None]:
# Test errors recorded during the second training. Guarded because
# test_errors2 only exists when T2 ran.
if T2:
    plt.figure(figsize=(20, 5))
    plt.plot(test_errors2, 'o')
    plt.ylim(bottom=0)
    plt.grid(True)  # explicit: a bare plt.grid() toggles visibility
    plt.xlabel('epochs')
    plt.ylabel('test error')

## Third training: classical training of a fixed-depth network

In [None]:
# build model
# Network for the classical run. An unknown `_type` raises immediately instead
# of leaving `model_classical` undefined or stale.
if _type == 'fwd':
    model_classical = feed_forward(dim_in, dim_out, **kwargs_net_classical)
elif _type == 'res2':
    model_classical = two_weight_resnet(dim_in, dim_out, **kwargs_net_classical)
else:
    raise ValueError(f"unknown network type {_type!r}; expected 'fwd' or 'res2'")

In [None]:
# build optimizer
# Only SGD and StepLR are wired up for the classical run. Any other setting
# raises here, instead of leaving optimizer/scheduler undefined (or stale from
# an earlier run) for the training cell.
if optimizer_type == 'SGD':
    optimizer_classical = torch.optim.SGD(model_classical.parameters(), lr=lr_init_classical)
else:
    raise ValueError(
        f"unsupported optimizer_type {optimizer_type!r} for the classical run; "
        "only 'SGD' is implemented")

# build lr scheduler
if lrscheduler_type == 'StepLR':
    step_size = lrscheduler_args_classical['step_size']
    gamma = lrscheduler_args_classical['gamma']
    lrscheduler_classical = torch.optim.lr_scheduler.StepLR(
        optimizer_classical, step_size=step_size, gamma=gamma)
else:
    raise ValueError(
        f"unsupported lrscheduler_type {lrscheduler_type!r} for the classical run; "
        "only 'StepLR' is implemented")


In [None]:
# Classical training of the fixed-depth network for the full epoch budget.
if T3:
    print('training classically on model', model_classical)
    classical_settings = dict(
        train_dataloader=train_dataloader,
        epochs=epochs_classical,
        optimizer=optimizer_classical,
        scheduler=lrscheduler_classical,
        wanted_testerror=wanted_testerror,
        start_with_backtracking=None,
        check_testerror_between=interval_testerror,
        test_dataloader=test_dataloader,
        print_param_flag=False,
    )
    (mblosses_classical, lr_end,
     test_error_classical, exit_flag_classical) = train(model_classical, **classical_settings)

Plot results:

In [None]:
# Minibatch losses of the classical run, with the same labels, log scale and
# grid as the layer-insertion loss plots.
if T3:
    plt.figure(figsize=(20, 5))
    plt.plot(mblosses_classical, 'o')
    plt.xlabel('minibatch iterations')
    plt.ylabel('minibatch loss')
    plt.yscale('log')
    plt.grid(True)

In [None]:
# Test errors recorded during the classical run. Guarded because
# test_error_classical only exists when T3 ran.
if T3:
    plt.figure(figsize=(20, 5))
    plt.plot(test_error_classical, 'o')
    plt.ylim(bottom=0)
    plt.grid(True)  # explicit: a bare plt.grid() toggles visibility
    plt.xlabel('epochs')
    plt.ylabel('test error')

## Comparison of the trainings as plots:


In [None]:
# Overlay the raw minibatch losses of every training that was run. Disabled
# runs (T1/T2/T3 False) are skipped instead of raising NameError.
# The plotting order matches the colour order used before: zero line,
# 'ali min', 'ali max', 'classical coarse'.
loss_curves = []
if T2:
    loss_curves.append((mb_losses2, 'ali min'))
if T1:
    loss_curves.append((mb_losses1, 'ali max'))
if T3:
    loss_curves.append((mblosses_classical, 'classical coarse'))

plt.figure(figsize=(20, 5))
# Zero reference line spanning the longest curve. NOTE(review): zero cannot be
# shown on the log-scaled y-axis, so this line is presumably not visible.
plt.plot(np.zeros(max((len(curve) for curve, _ in loss_curves), default=0)))

for curve, label in loss_curves:
    plt.plot(curve, label=label)
plt.yscale('log')

if loss_curves:  # legend() warns when there are no labelled artists
    plt.legend()
plt.show()

In [None]:
# Exponential moving average (via utils.ema) of each run's minibatch losses,
# for a less noisy comparison. Only runs that were executed are smoothed.
smooth_factor = 0.99
if T1:
    s1 = ema(mb_losses1, smooth_factor)
if T2:
    s2 = ema(mb_losses2, smooth_factor)
if T3:
    s3 = ema(mblosses_classical, smooth_factor)

In [None]:
# Overlay the smoothed minibatch losses of every training that was run, in
# the same order (and hence colours) as the raw-loss comparison.
smoothed_curves = []
if T2:
    smoothed_curves.append((s2, 'ali min'))
if T1:
    smoothed_curves.append((s1, 'ali max'))
if T3:
    smoothed_curves.append((s3, 'classical coarse'))

plt.figure(figsize=(20, 5))
# Zero reference line spanning the longest curve. NOTE(review): zero cannot be
# shown on the log-scaled y-axis, so this line is presumably not visible.
plt.plot(np.zeros(max((len(curve) for curve, _ in smoothed_curves), default=0)))

for curve, label in smoothed_curves:
    plt.plot(curve, label=label)
plt.yscale('log')

if smoothed_curves:  # legend() warns when there are no labelled artists
    plt.legend()
plt.show()

In [None]:
# Overlay the test errors of every training that was run. The legend labels
# are the same as in the loss comparisons above.
plt.figure(figsize=(20, 5))
plotted_any = False
if T1:
    plt.plot(test_errors1, 'o', label='ali max')
    plotted_any = True
if T2:
    plt.plot(test_errors2, 'o', label='ali min')
    plotted_any = True
if T3:
    plt.plot(test_error_classical, 'o', label='classical coarse')
    plotted_any = True
plt.grid(True)  # explicit: a bare plt.grid() toggles visibility
plt.ylim(bottom=0)
plt.xlabel('epochs')
plt.ylabel('test error')
if plotted_any:  # legend() warns when there are no labelled artists
    plt.legend()