Script used for training on the two state-of-the-art (s.o.t.) tasks of the paper (the last two tasks of Table 1):
- VGG model trained on CIFAR-10, from the Torch blog: http://torch.ch/blog/2015/07/30/cifar.html
- WRN-28-10 with dropout 0.3, trained on CIFAR-100

Parameters that can be looped on:
- task solved
- optimization method used (SGD or an adaptive gradient method, optionally with the `_layca` or `_weight_decay` suffix)

The results (training history, layer-wise angle deviation curves, test performance) are saved in a dictionary.
A result can be easily found through: results[task][optimizer]

In [2]:
%matplotlib inline
import mpld3
mpld3.enable_notebook()

%load_ext autoreload

In [34]:
%autoreload
import warnings
import os
import time

import math as m
import numpy as np
np.random.seed(1)

import matplotlib
import matplotlib.pyplot as plt
import pickle

from experiment_utils import history_todict, get_val_split
from rotation_rate_utils import LayerwiseAngleDeviationCurves

from layca_optimizers import SGD

from import_task import import_task
from get_training_utils import get_training_schedule_sot, get_optimizer, get_learning_rate_multipliers

from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator

In [4]:
# utilities for storing the results in pickle files
def load_results():
    """Return the results dictionary stored in 'results.p'.

    An empty dictionary is returned when 'results.p' is not an existing
    regular file; otherwise the file is unpickled and its content is
    returned unchanged.
    """
    if os.path.isfile('results.p'):
        with open('results.p', 'rb') as results_file:
            return pickle.load(results_file)
    return {}

def dump_results(results):
    """Persist the results dictionary to 'results.p' without risking data loss.

    results: mapping holding all the results; it is converted to a plain
        dict before being pickled.

    The data is first pickled to a temporary file which then replaces
    'results.p' in a single step. If the conversion or the pickling fails
    (or the process is interrupted while writing), the previously saved
    'results.p' is left untouched instead of being truncated. The exception,
    if any, is propagated to the caller.
    """
    # convert before touching the disk, so that an invalid argument cannot
    # leave any file behind
    snapshot = dict(results)
    tmp_path = 'results.p.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(snapshot, f)
        # atomic swap: readers see either the old or the new file, never a
        # partially written one
        os.replace(tmp_path, 'results.p')
    finally:
        # only still present if something went wrong before the swap
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def update_results(path, new_data):
    """Merge new_data into the stored results at the location given by path.

    path: sequence of keys followed one after the other, starting from the
        root of the results dictionary, to reach the dictionary to modify.
        An empty path targets the root itself. Every key must already exist.
    new_data: dictionary whose (key, value) pairs are merged into the target
        with dict.update, so values of existing keys are overwritten.

    The stored results are re-loaded and written back in full on every call.
    """
    stored = load_results()
    target = stored
    for key in path:
        target = target[key]
    target.update(new_data)
    dump_results(stored)

In [21]:
# text file to which progress lines are appended to monitor the experiment
monitor_file = 'monitor_experiment.txt'

# True: results are stored in the pickle file
# False: results are only kept in the in-memory `results` dictionary
save_results = True
if not save_results:
    results = dict()

In [None]:
# Main experiment loop. For each (task, optimizer) pair: build and compile the
# model, train it with data augmentation, evaluate it on the test set, then
# store the training history, the layer-wise angle deviation curves and the
# test performance under results[task][optimizer] (in the pickle file if
# save_results is True, otherwise in the in-memory `results` dictionary).
tasks = ['C10-CNN2','C100-WRN']
optimizers = ['SGD','SGD_layca','SGD_weight_decay'] # also available: Adam, RMSProp, Adagrad

for task in tasks:
    x_train, y_train, x_test, y_test, get_model = import_task(task)
    # NOTE(review): the test set doubles as validation data, so the per-epoch
    # validation metrics recorded during training are computed on the test set.
    x_val, y_val = x_test, y_test
    
    # create an empty entry for the task the first time it is seen
    if save_results:
        results = load_results()
        if task not in results.keys():
            update_results([],{task:{}})
    elif task not in results.keys():
        results.update({task:{}})
    
    for optimizer in optimizers:
        
        # create a placeholder entry for this (task, optimizer) pair the first
        # time it is seen; it is replaced by the actual results after training
        if save_results:
            results = load_results()
            if optimizer not in results[task].keys():
                update_results([task],{optimizer:{'history':{'history':{'val_acc':[-1]}}}}) # save a bad initial performance
        elif optimizer not in results[task].keys():
            results[task].update({optimizer:{'history':{'history':{'val_acc':[-1]}}}})
     
        # the duration written to the monitor file is measured from here and
        # covers model creation, training, evaluation and result saving
        start = time.time()
        # the weight decay parameter is taken from their original implementation and is specified in import_task.py (0.0005 for both tasks)
        # weight decay is set to 0 unless the optimizer name contains 'weight_decay'
        model = get_model(weight_decay = 0.) if 'weight_decay' not in optimizer else get_model()

        batch_size = 128
        # learning rate schedule is taken from their original implementation and is specified in get_training_utils.py
        epochs, lr_scheduler = get_training_schedule_sot(task,optimizer)
        # verbosity level passed to both fit_generator and evaluate
        verbose = 0

        batch_frequency = int((x_train.shape[0]/batch_size))+5 # higher value than # of batches per epoch means once per epoch
        # callback whose `memory` attribute is saved with the results under the 'ladc' key
        ladc = LayerwiseAngleDeviationCurves(batch_frequency = batch_frequency)

        callbacks = [lr_scheduler, ladc]
    
        # C100-WRN + SGD is the only case where nesterov momentum is used (cfr. original implementation)
        if task == 'C100-WRN' and optimizer in ['SGD','SGD_weight_decay']: 
            opt = SGD(lr=0.1, momentum=0.9, nesterov=True) # lr is specified in lr_scheduler, not here
        else:
            opt = get_optimizer(optimizer, 0.1) # lr is specified in lr_scheduler, not here

        model.compile(loss='categorical_crossentropy',
                      optimizer= opt,
                      metrics=['accuracy'])
        
        # data augmentation: random horizontal/vertical shifts (range 0.125,
        # borders filled by reflection) and random horizontal flips
        datagen = ImageDataGenerator(width_shift_range=0.125,
                     height_shift_range=0.125,
                     fill_mode='reflect',
                     horizontal_flip=True)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore") # removes warning from keras for slow callback
            history = model.fit_generator(datagen.flow(x_train, y_train,batch_size=batch_size),
                                          steps_per_epoch=x_train.shape[0] // batch_size,
                                          epochs = epochs,
                                          verbose = verbose,
                                          validation_data = (x_val, y_val),
                                          callbacks = callbacks)

        # final evaluation on the test set
        # NOTE(review): presumably returns [loss, accuracy] given the compile() call above - confirm
        test_performance = model.evaluate(x_test,y_test, verbose = verbose)

        # replace the placeholder entry with the actual results of this run
        if save_results:
            update_results([task],{optimizer:{'history':history_todict(history),'ladc':ladc.memory,
                                              'test_performance':test_performance}})
        else:
            results[task].update({optimizer:{'history':history_todict(history),'ladc':ladc.memory,
                                             'test_performance':test_performance}})

        # append one progress line per finished run to the monitor file
        with open(monitor_file,'a') as file:
            file.write(task + ', '+optimizer+': done in '+str(time.time()-start)+' seconds.\n')

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200
Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200
Epoch 55/200
Epoch 56/200
Epoch 57/200
Epoch 58/200
Epoch 59/200
Epoch 60/200
Epoch 61/200
Epoch 62/200
Epoch 63/200
Epoch 64/200
Epoch 65/200
Epoch 66/200
Epoch 67/200
Epoch 68/200
Epoch 69/200
Epoch 70/200
Epoch 71/200
Epoch 72/200
Epoch 73/200
Epoch 74/200
Epoch 75/200
Epoch 76/200
Epoch 77/200
Epoch 78