This notebook is a copy of the eponymous notebook in the parent folder, with additional options used for our analysis (e.g. storing the norms of the weights and gradients during training).

In [1]:
%matplotlib inline
import mpld3
mpld3.enable_notebook()

%load_ext autoreload

In [2]:
%autoreload
import sys
sys.path.insert(0, "../")

import warnings
import os
import time

import math as m
import numpy as np
np.random.seed(1)

import matplotlib
import matplotlib.pyplot as plt
import pickle

from experiment_utils import history_todict, get_val_split
from rotation_rate_utils import LayerwiseAngleDeviationCurves, get_kernel_layer_names
from training_monitoring import trainingMemories, LayerRotationRateCurves, Adam_2nd_moment_memory

from import_task import import_task
from get_training_utils import get_training_schedule, get_stopping_criteria, get_optimizer, get_learning_rate_multipliers
from LARS import LARS

from keras.callbacks import ModelCheckpoint

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
def load_results():
    """Return the results dictionary pickled in 'results.p'.

    Returns an empty dict when 'results.p' does not exist yet (first run).
    NOTE: the file is read with pickle, so it must come from a trusted source
    (i.e. written by this notebook).
    """
    results_path = 'results.p'
    if os.path.isfile(results_path):
        with open(results_path, 'rb') as handle:
            return pickle.load(handle)
    return {}

def dump_results(results):
    """Persist `results` to 'results.p' (converted to a plain dict before pickling).

    The pickle is first written to 'results.p.tmp' and then moved over 'results.p'
    with os.replace, which is atomic. Interrupting the write (e.g. killing the kernel
    during a long sweep) therefore cannot leave a truncated 'results.p' that would
    lose every previously stored run.

    Args:
        results: mapping of accumulated results (anything accepted by dict()).
    """
    tmp_path = 'results.p.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(dict(results), f)
    # Only replace the real file once the new content is fully written.
    os.replace(tmp_path, 'results.p')

def update_results(path, new_data):
    """Merge `new_data` into a nested dict of the on-disk results and save it back.

    Args:
        path: iterable of keys leading from the top-level results dict down to the
            nested dict to update; an empty path updates the top level itself.
        new_data: dictionary of (key, value) pairs merged into that nested dict.

    Raises KeyError if a key along `path` does not exist. Each call reloads and
    rewrites the whole results file.
    """
    stored = load_results()
    node = stored
    for key in path:
        node = node[key]
    node.update(new_data)
    dump_results(stored)

In [4]:
# Persist every run to results.p (True) or keep results only in memory (False).
save_results = False
monitor_file = 'monitor_experiment.txt'

# Record the norms of the weights and weight gradients during training and save them
# afterwards (only used for the SGD analysis).
monitor_norms = True
monitor_layer_rotation_rates = True
monitor_adam_moment = False

# In-memory mode starts from an empty results dict; on-disk mode loads it in the sweep cell.
if not save_results:
    results = {}

In [5]:
# Experiment grid. Optimizer names are composed as <optimizer>[_layca][_weight_decay].
tasks = ['C10-CNN1']
optimizers = ['SGD', 'SGD_weight_decay']
alphas = [0.]  # full grid: [-0.8, -0.6, -0.4, -0.3, -0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.6, 0.8]
lrs = [3**-1]  # full grid: [3.**(-i) for i in range(-2, 8)]

# The monitoring memories are pickled into results/ after each run; create the directory
# up front so the dump does not crash only after a full (long) training run.
if monitor_norms or monitor_layer_rotation_rates or monitor_adam_moment:
    os.makedirs('results', exist_ok=True)

for task in tasks:
    x_train, y_train, x_test, y_test, get_model = import_task(task)
    
    # A validation set is needed for early stopping and for learning rate/alpha selection.
    [x_train, y_train], [x_val, y_val] = get_val_split(x_train, y_train, 0.1)
    
    # Create an empty entry the first time the task is seen.
    if save_results:
        results = load_results()
        if task not in results.keys():
            update_results([], {task: {}})
    elif task not in results.keys():
        results.update({task: {}})
    
    for optimizer in optimizers:
        
        weight_decay = 0. if 'weight_decay' not in optimizer else 1e-3
        
        # Seed each (task, optimizer) entry with a deliberately bad validation accuracy (-1)
        # so that the first completed run always becomes the current best.
        if save_results:
            results = load_results()
            if optimizer not in results[task].keys():
                update_results([task], {optimizer: {'history': {'history': {'val_acc': [-1]}}}})
        elif optimizer not in results[task].keys():
            results[task].update({optimizer: {'history': {'history': {'val_acc': [-1]}}}})
     
        for alpha in alphas:
            # Layer-wise learning rate multipliers parametrized by alpha are not implemented for
            # the adaptive gradient methods: changing alpha has no effect on their training.
            
            if save_results:
                results = load_results()
                if alpha not in results[task][optimizer].keys():
                    update_results([task, optimizer], {alpha: {}})
            elif alpha not in results[task][optimizer].keys():
                results[task][optimizer].update({alpha: {}})
            
            for lr in lrs:
                start = time.time()  # only used by the (disabled) progress log at the end
                model = get_model(weight_decay)
                
                # LARS makes the weights grow a lot during training, so start from smaller kernels.
                if optimizer == 'LARS':
                    layer_names = get_kernel_layer_names(model)
                    factor = 1/4.
                    for layer_name in layer_names:
                        w = model.get_layer(layer_name).get_weights()
                        w[0] = factor*w[0]
                        model.get_layer(layer_name).set_weights(w)

                batch_size = 128
                epochs, lr_scheduler = get_training_schedule(task, lr)
                stop_callback = get_stopping_criteria(task)  # computed but intentionally not passed to fit
                verbose = 0

                # Intended to record once per epoch: the value exceeds the number of batches per epoch.
                batch_frequency = int((x_train.shape[0]/batch_size))+5
                ladc = LayerwiseAngleDeviationCurves(batch_frequency = batch_frequency)

                callbacks = [lr_scheduler, ladc]  # stop_callback is deliberately left out
                if monitor_norms:
                    # Up to 2000 *distinct* training samples: np.random.choice samples with
                    # replacement by default, which could repeat examples.
                    n_samples = min(x_train.shape[0], 2000)
                    sample_indices = np.random.choice(x_train.shape[0], n_samples, replace=False)
                    norm_memory = trainingMemories(x_train[sample_indices], y_train[sample_indices], batch_frequency)
                    callbacks.append(norm_memory)
                if monitor_layer_rotation_rates:
                    lrrc = LayerRotationRateCurves()
                    callbacks.append(lrrc)
                if monitor_adam_moment:
                    adamMoment = Adam_2nd_moment_memory(batch_frequency = batch_frequency)
                    callbacks.append(adamMoment)

                multipliers = get_learning_rate_multipliers(model, alpha = alpha)
                metrics = ['accuracy', 'top_k_categorical_accuracy'] if 'tiny' in task else ['accuracy']
                opt = get_optimizer(optimizer, lr, multipliers) if optimizer != 'LARS' else LARS(model, lr, multipliers=multipliers)
                model.compile(loss='categorical_crossentropy',
                              optimizer= opt,
                              metrics=metrics)

                # CIFAR-100 and Tiny ImageNet need early stopping: checkpoint the weights with the
                # best validation accuracy and restore them before the test evaluation.
                use_checkpoint = 'C100' in task or 'tiny' in task
                if use_checkpoint:
                    os.makedirs('saved_weights', exist_ok=True)
                    weights_file = 'saved_weights/best_weights_'+str(np.random.randint(10**6))+'.h5'
                    callbacks += [ModelCheckpoint(weights_file, monitor='val_acc', save_best_only=True, save_weights_only = True)]

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore") # removes warning from keras for slow callback
                    history = model.fit(x_train,y_train,
                                        epochs = epochs,
                                        batch_size = batch_size,
                                        verbose = verbose,
                                        validation_data = (x_val, y_val),
                                        callbacks = callbacks)

                if use_checkpoint:
                    model.load_weights(weights_file)

                test_performance = model.evaluate(x_test,y_test, verbose = verbose)

                # Build the per-run record once and store it under (task, optimizer, alpha, lr).
                run_record = {'history': history_todict(history), 'ladc': ladc.memory,
                              'test_performance': test_performance}
                if save_results:
                    update_results([task, optimizer, alpha], {lr: run_record})
                else:
                    results[task][optimizer][alpha].update({lr: run_record})
                
                # If this run beats the current best validation accuracy of the (task, optimizer)
                # pair, it also becomes that pair's summary entry.
                if save_results:
                    results = load_results()
                best_val_acc = max(results[task][optimizer]['history']['history']['val_acc'])
                if max(history.history['val_acc']) > best_val_acc:
                    best_record = dict(run_record, best_alpha=alpha, best_lr=lr)
                    if save_results:
                        update_results([task, optimizer], best_record)
                    else:
                        results[task][optimizer].update(best_record)
                
                # NOTE(review): these file names depend only on task (and optimizer), so with
                # several alphas/lrs each run overwrites the previous run's memory.
                if monitor_norms:
                    with open('results/norm_memory_'+task+'_'+optimizer+'.p','wb') as f:
                        pickle.dump(norm_memory.memory,f)
                if monitor_layer_rotation_rates:
                    with open('results/layer_rotation_rates_'+task+'_'+optimizer+'.p','wb') as f:
                        pickle.dump(lrrc.memory,f)
                if monitor_adam_moment:
                    with open('results/Adam_Moments_'+task+'.p','wb') as f:
                        pickle.dump(adamMoment.memory,f)
                    
#                 with open(monitor_file,'a') as file:
#                     file.write(task + ', '+optimizer+', '+str(alpha)+ ', '+str(lr)+': done in '+str(time.time()-start)+' seconds.\n')