In [21]:
# List the project directory (the local Optimizers package and the models/ folder are used below).
!ls

activate_environment.sh  nnfl.egg-info	start_off-Copy1.ipynb  wandb
environment.yml		 Optimizers	start_off.ipynb
imagenette.ipynb	 README.md	tests
models			 README.rst	tmp.ipynb


In [22]:
# Inline plots; %autoreload 2 re-imports edited local modules (e.g. Optimizers.lookahead) before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
from fastai.vision import *
from torch.optim import Optimizer
from Optimizers.lookahead import *
from torchcontrib.optim import SWA
import wandb
from wandb.fastai import WandbCallback

In [24]:
# Project root is the notebook's working directory; trained weights are saved under <root>/models.
PROJECT_PATH = Path.cwd()
MODELS = PROJECT_PATH.joinpath('models')
MODELS.mkdir(exist_ok=True)

In [25]:
def print_all(*args):
    """Print every positional argument on its own line, in order.

    Nothing is printed when no arguments are given.
    """
    if args:
        print(*args, sep='\n')

In [26]:
def get_mnist_data(bs=16):
    """Build a fastai DataBunch for MNIST.

    The dataset root comes from ``untar_data(URLs.MNIST)``. Images from the
    'training' folder form the training set, and images from 'testing' form
    the validation set. Each image is labelled by its parent folder name.

    Args:
        bs: batch size forwarded to ``databunch``.

    Returns:
        The resulting DataBunch (no augmentation is applied).
    """
    mnist_root = untar_data(URLs.MNIST)
    items = ImageList.from_folder(path=mnist_root)
    split_items = items.split_by_folder(train='training', valid='testing')
    labelled = split_items.label_from_folder()
    return labelled.databunch(bs=bs)

In [27]:
def get_cifar100_data(bs=16):
    """Build a fastai DataBunch for CIFAR-100 with the usual CIFAR augmentation.

    Training transforms, in order:
    - zero-padding by 4 px (always applied);
    - a 32x32 crop positioned via row_pct/col_pct ranges of (0.4, 0.6);
    - a horizontal flip with p=0.5.

    The validation set gets no transforms (the empty list). The split comes
    from the 'train'/'test' folders, and labels from parent folder names.

    Args:
        bs: batch size forwarded to ``databunch``.
    """
    train_tfms = [
        RandTransform(tfm=pad, kwargs={'padding':4, 'mode':'zeros'},
                      is_random=False, p=1.0, use_on_y=False),
        RandTransform(tfm=crop, kwargs={'row_pct':(0.4,0.6), 'col_pct':(0.4,0.6), 'size':32},
                      p=1.0, use_on_y=False),
        RandTransform(tfm=flip_lr, kwargs={}, p=0.5),
    ]
    cifar_root = untar_data(URLs.CIFAR_100)
    labelled = (ImageList.from_folder(path=cifar_root)
                .split_by_folder(train='train', valid='test')
                .label_from_folder())
    return labelled.transform((train_tfms, [])).databunch(bs=bs)

In [28]:
def get_cifar10_data(bs=16):
    """Build a fastai DataBunch for CIFAR-10 with the usual CIFAR augmentation.

    Training transforms, in order:
    - zero-padding by 4 px (always applied);
    - a 32x32 crop positioned via row_pct/col_pct ranges of (0.4, 0.6);
    - a horizontal flip with p=0.5.

    Validation images are not transformed. The split comes from the
    'train'/'test' folders, and labels from parent folder names.

    Args:
        bs: batch size forwarded to ``databunch``.
    """
    pad_tfm = RandTransform(tfm=pad, kwargs={'padding':4, 'mode':'zeros'},
                            is_random=False, p=1.0, use_on_y=False)
    crop_tfm = RandTransform(tfm=crop, kwargs={'row_pct':(0.4,0.6), 'col_pct':(0.4,0.6), 'size':32},
                             p=1.0, use_on_y=False)
    flip_tfm = RandTransform(tfm=flip_lr, kwargs={}, p=0.5)
    cifar_root = untar_data(URLs.CIFAR)
    items = ImageList.from_folder(path=cifar_root)
    labelled = items.split_by_folder(train='train', valid='test').label_from_folder()
    augmented = labelled.transform(([pad_tfm, crop_tfm, flip_tfm], []))
    return augmented.databunch(bs=bs)

### Sanity check: tiny CNN on MNIST with LookaheadSGD (cell kept commented out)

In [29]:
# data = get_mnist_data()
# # data.show_batch()

# def get_simple_cnn(pretrained=False):
#     return simple_cnn([3, 4, 2])

# learn = cnn_learner(data=data, base_arch=get_simple_cnn, opt_func=LookaheadSGD)
# learn.metrics.append(accuracy)
# learn.fit_one_cycle(1, 0.003)
# learn.save(MODELS/'mnist_sample.pkl')
# learn.load(MODELS/'mnist_sample.pkl')

### Running the experiments: configuration and training helper

In [30]:
# WandbCallback pre-bound with input_type='images'. NOTE(review): not referenced by the visible cells
# (fit_and_record passes the bare WandbCallback) — confirm whether it is still needed.
wandbRecorder = partial(WandbCallback, **{'input_type': 'images'})

In [31]:
# Constructor kwargs for each optimizer under comparison (fit_and_record binds them with partial).
lookahead_sgd_params = dict(momentum=0.9, alpha=0.5, k=5)
sgd_params = dict(momentum=0.9, weight_decay=0.001)
adamW_params = dict(weight_decay=0.3)
rmsprop_params = dict(weight_decay=0.001)
lookahead_adamw_params = dict(weight_decay=0.3, k=5, alpha=0.5)

# One row per optimizer: (optimizer, run name, kwargs, learning rate, epochs).
# SWA is not listed here; it needs to be dealt with separately.
# Every lookup structure below is derived from this table so they cannot drift apart.
_optimizer_table = [
    (LookaheadSGD,   'Lookahead_SGD',  lookahead_sgd_params,   0.1,   200),
    (optim.SGD,      'SGD',            sgd_params,             0.05,  200),
    (optim.AdamW,    'AdamW',          adamW_params,           0.001, 200),
    (optim.RMSprop,  'RMSProp',        rmsprop_params,         0.01,  200),
    (LookaheadAdamW, 'LookaheadAdamW', lookahead_adamw_params, 0.1,   200),
]

optimizers = [row[0] for row in _optimizer_table]
experiment_names = {row[0]: row[1] for row in _optimizer_table}
params_list = [row[2] for row in _optimizer_table]
params_dict = {row[0]: row[2] for row in _optimizer_table}
lr_dict = {row[0]: row[3] for row in _optimizer_table}
epochs_dict = {row[0]: row[4] for row in _optimizer_table}
del _optimizer_table

In [32]:
# Starts a W&B run with default settings. NOTE(review): fit_and_record calls wandb.init again with
# project='lookahead' and a per-experiment name, so this run may be redundant — confirm.
wandb.init()

W&B Run: https://app.wandb.ai/akashpalrecha/lookahead/runs/reta6zhs

In [33]:
class LRDecayCallback(LearnerCallback):
    """Step learning-rate schedule: divide the LR by a constant at selected epochs.

    Args:
        learn: the Learner this callback is attached to.
        decay_on_epochs: epoch values (as received by ``on_epoch_end``) at which
            the LR is divided by ``decay_factor``.
        decay_factor: positive divisor applied to the LR at each listed epoch
            (may be fractional, e.g. 5.0).

    Raises:
        ValueError: if ``decay_factor`` is not positive. Checking this up front
            avoids a ZeroDivisionError (or a sign flip) many epochs into training.
    """
    def __init__(self, learn:Learner, decay_on_epochs:list, decay_factor:float):
        # Validate before super().__init__ so an invalid callback is never registered on `learn`.
        if decay_factor <= 0:
            raise ValueError(f"decay_factor must be positive, got {decay_factor}")
        super().__init__(learn)
        self.decay_on_epochs = decay_on_epochs
        self.decay_factor = decay_factor
    def on_epoch_end(self, epoch, **kwargs):
        # NOTE(review): `epoch` is presumably the 0-based index of the epoch that just ended, so the
        # new LR applies from the following epoch — confirm against fastai's callback docs.
        if epoch in self.decay_on_epochs:
            # `self.opt` is looked up through the Learner (presumably its optimizer wrapper).
            self.opt.lr = self.opt.lr / self.decay_factor
            print("LR changed to: " + str(self.opt.lr))

In [34]:
# Step schedule for the plain-`fit` runs: LRDecayCallback divides the LR by 5 when on_epoch_end sees epoch 60, 120 or 160.
lr_decay = partial(LRDecayCallback, decay_factor=5.0, decay_on_epochs=[60, 120, 160])

In [36]:
def fit_and_record(epochs=None, lr=None, opt_func=LookaheadSGD, experiment_name=None, 
                   data=None, base_arch=models.resnet18, pretrained=False, one_cycle=False, exp_postfix="",
                   seed=18):
    """Train a CNN with one optimizer, log the run to W&B and save the weights to MODELS.

    Args:
        epochs: number of epochs; defaults to ``epochs_dict[opt_func]``.
        lr: learning rate; defaults to ``lr_dict[opt_func]``.
        opt_func: optimizer class. Extra constructor kwargs come from
            ``params_dict`` (none if the class has no entry).
        experiment_name: W&B run name and checkpoint name. It defaults to
            ``experiment_names[opt_func]``, falling back to the class name.
            When the dataset is chosen here, '_cifar10' or '_cifar100' is
            appended.
        data: a ready DataBunch, or None / 'cifar10' (CIFAR-10) / 'cifar100'
            (CIFAR-100). Datasets built here use batch size 128.
        base_arch: backbone passed to ``cnn_learner``.
        pretrained: passed to ``cnn_learner``.
        one_cycle: if True, train with ``fit_one_cycle``. Otherwise use plain
            ``fit`` plus the ``lr_decay`` step schedule.
        exp_postfix: appended verbatim (no separator) to the run name,
            e.g. 'Lookahead_SGD_cifar100' + '120'.
        seed: seed for python ``random``, numpy and torch.

    Raises:
        ValueError: if ``data`` is a string other than 'cifar10' or 'cifar100'.
        KeyError: if ``epochs``/``lr`` are omitted and ``opt_func`` has no
            entry in ``epochs_dict``/``lr_dict``.
    """
    import random
    import torch

    # Seed every RNG involved. numpy alone (the previous behaviour) does not control
    # PyTorch weight init or DataLoader shuffling, so runs were not repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Initialize Training
    if experiment_name is None:
        experiment_name = experiment_names.get(opt_func, getattr(opt_func, '__name__', str(opt_func)))
    if data is None or isinstance(data, str):
        # Compare with `==`: `is` tests object identity, which is not guaranteed for equal strings.
        if data == 'cifar100':
            data = get_cifar100_data(128)
            experiment_name += '_cifar100'
        elif data is None or data == 'cifar10':
            data = get_cifar10_data(128)
            experiment_name += '_cifar10'
        else:
            raise ValueError(f"Unknown dataset {data!r}; expected 'cifar10' or 'cifar100'")
    if epochs is None:
        epochs = epochs_dict[opt_func]
    if lr is None:
        lr = lr_dict[opt_func]
    experiment_name += exp_postfix
    # Bind the optimizer hyper-parameters; the lookup uses the raw class, before rebinding the name.
    opt_func = partial(opt_func, **params_dict.get(opt_func, {}))

    print_all(data, opt_func, epochs, lr, experiment_name, base_arch)

    wandb.init(project='lookahead', name=experiment_name)

    # One-cycle manages the LR itself; plain `fit` runs get the step-decay schedule.
    callbacks = [WandbCallback] if one_cycle else [WandbCallback, lr_decay]
    learn = cnn_learner(data=data, base_arch=base_arch, opt_func=opt_func,
                        pretrained=pretrained, callback_fns=callbacks)
    learn.metrics.append(accuracy)
    if one_cycle:
        learn.fit_one_cycle(epochs, lr)
    else:
        learn.fit(epochs, lr)
    learn.save(MODELS/experiment_name)

# Experiments on CIFAR-10 (200 epochs, step LR decay)

In [None]:
# CIFAR-10 with LookaheadSGD; lr, epochs and optimizer kwargs come from the config dicts.
fit_and_record(opt_func=LookaheadSGD, data='cifar10')

ImageDataBunch;

Train: LabelList (50000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
horse,horse,horse,horse,horse
Path: /home/ubuntu/.fastai/data/cifar10;

Valid: LabelList (10000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
horse,horse,horse,horse,horse
Path: /home/ubuntu/.fastai/data/cifar10;

Test: None
functools.partial(<class 'Optimizers.lookahead.LookaheadSGD'>, momentum=0.9, alpha=0.5, k=5)
200
0.1
Lookahead_SGD_cifar10
<function resnet18 at 0x7f12c0174290>


epoch,train_loss,valid_loss,accuracy,time
0,1.859745,4.138368,0.2701,00:12
1,1.509201,1.32573,0.5187,00:12
2,1.248699,1.483238,0.4875,00:11
3,1.122516,1.193955,0.5766,00:11
4,1.02513,1.279949,0.5643,00:11
5,0.985551,0.967895,0.6569,00:12
6,0.949791,1.103185,0.6164,00:11
7,0.905246,1.02094,0.6503,00:12
8,0.870605,1.712785,0.4579,00:12
9,0.868018,0.893622,0.6974,00:12


Better model found at epoch 0 with valid_loss value: 4.138368129730225.
Better model found at epoch 1 with valid_loss value: 1.325730323791504.
Better model found at epoch 3 with valid_loss value: 1.1939547061920166.
Better model found at epoch 5 with valid_loss value: 0.9678947329521179.
Better model found at epoch 9 with valid_loss value: 0.8936222791671753.
Better model found at epoch 11 with valid_loss value: 0.8396416902542114.
Better model found at epoch 15 with valid_loss value: 0.8119654059410095.
Better model found at epoch 22 with valid_loss value: 0.8069820404052734.
Better model found at epoch 23 with valid_loss value: 0.7375544309616089.
Better model found at epoch 43 with valid_loss value: 0.7374391555786133.
LR changed to: 0.02
Better model found at epoch 61 with valid_loss value: 0.48965540528297424.
LR changed to: 0.004
LR changed to: 0.0008


In [None]:
# CIFAR-10 baseline: SGD (momentum=0.9, weight_decay=0.001 per sgd_params).
fit_and_record(opt_func=optim.SGD, data='cifar10')

In [None]:
# CIFAR-10 baseline: AdamW (weight_decay=0.3 per adamW_params).
fit_and_record(opt_func=optim.AdamW, data='cifar10')

In [None]:
# CIFAR-10 baseline: RMSprop (weight_decay=0.001 per rmsprop_params).
fit_and_record(opt_func=optim.RMSprop, data='cifar10')

In [None]:
# CIFAR-10 with LookaheadAdamW (k=5, alpha=0.5, weight_decay=0.3 per lookahead_adamw_params).
fit_and_record(opt_func=LookaheadAdamW, data='cifar10')

# Experiments on CIFAR-100 (120 epochs, one-cycle LR schedule)

In [None]:
# CIFAR-100, 120 epochs with fit_one_cycle. The "120" postfix is appended with no separator
# (run/checkpoint name "Lookahead_SGD_cifar100120").
fit_and_record(opt_func=LookaheadSGD, data='cifar100', epochs=120, one_cycle=True, exp_postfix="120")

ImageDataBunch;

Train: LabelList (50000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
apple,apple,apple,apple,apple
Path: /home/ubuntu/.fastai/data/cifar100;

Valid: LabelList (10000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
apple,apple,apple,apple,apple
Path: /home/ubuntu/.fastai/data/cifar100;

Test: None
functools.partial(<class 'Optimizers.lookahead.LookaheadSGD'>, momentum=0.9, alpha=0.5, k=5)
120
0.1
Lookahead_SGD_cifar100120
<function resnet18 at 0x7f12c0174290>


epoch,train_loss,valid_loss,accuracy,time
0,4.988226,4.193522,0.0934,00:18
1,4.52575,3.864068,0.1388,00:11
2,4.222537,3.701214,0.1508,00:12
3,4.010345,3.592666,0.1748,00:12
4,3.823682,3.462116,0.1902,00:12
5,3.654565,3.441228,0.2118,00:12
6,3.506438,3.39933,0.2233,00:11
7,3.318072,3.106065,0.2464,00:11
8,3.181965,3.29915,0.2225,00:11
9,3.063131,2.867387,0.291,00:12


Better model found at epoch 0 with valid_loss value: 4.193521976470947.
Better model found at epoch 1 with valid_loss value: 3.8640682697296143.
Better model found at epoch 2 with valid_loss value: 3.701214075088501.
Better model found at epoch 3 with valid_loss value: 3.592665910720825.
Better model found at epoch 4 with valid_loss value: 3.462115526199341.
Better model found at epoch 5 with valid_loss value: 3.441227674484253.
Better model found at epoch 6 with valid_loss value: 3.3993301391601562.
Better model found at epoch 7 with valid_loss value: 3.106064558029175.
Better model found at epoch 9 with valid_loss value: 2.867387056350708.
Better model found at epoch 10 with valid_loss value: 2.855327844619751.
Better model found at epoch 13 with valid_loss value: 2.5248360633850098.
Better model found at epoch 18 with valid_loss value: 2.4313759803771973.


In [None]:
# CIFAR-100 baseline: SGD, 120 epochs with fit_one_cycle.
fit_and_record(opt_func=optim.SGD, data='cifar100', epochs=120, one_cycle=True, exp_postfix="120")

In [None]:
# CIFAR-100 baseline: AdamW, 120 epochs with fit_one_cycle.
fit_and_record(opt_func=optim.AdamW, data='cifar100', epochs=120, one_cycle=True, exp_postfix="120")

In [None]:
# CIFAR-100 baseline: RMSprop, 120 epochs with fit_one_cycle.
fit_and_record(opt_func=optim.RMSprop, data='cifar100', epochs=120, one_cycle=True, exp_postfix="120")

In [None]:
# CIFAR-100 with LookaheadAdamW, 120 epochs with fit_one_cycle.
fit_and_record(opt_func=LookaheadAdamW, data='cifar100', epochs=120, one_cycle=True, exp_postfix="120")