In [1]:
!ls

activate_environment.sh  models		README.md	 tests
environment.yml		 nnfl.egg-info	README.rst	 tmp.ipynb
imagenette.ipynb	 Optimizers	start_off.ipynb  wandb


In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [3]:
from fastai.vision import *
from torch.optim import Optimizer
from Optimizers.lookahead import *
from torchcontrib.optim import SWA
import wandb
from wandb.fastai import WandbCallback

In [4]:
# Project root is taken to be the current working directory of the kernel.
PROJECT_PATH = Path.cwd()
# Directory that fit_and_record saves trained models into.
MODELS = PROJECT_PATH/'models'

# Create the models directory if it is missing; no error when it already exists.
MODELS.mkdir(exist_ok=True)

In [5]:
def print_all(*args):
    """Print every positional argument on its own line.

    Nothing is written when called without arguments. Returns None.
    """
    # Guard the empty case: a bare print() would emit a blank line.
    if args:
        print(*args, sep='\n')

In [6]:
def get_mnist_data(bs=16):
    """Build the MNIST data bunch.

    The dataset location comes from `untar_data(URLs.MNIST)`. Images under the
    'training' folder form the training set and those under 'testing' the
    validation set; labels are taken from the folder structure.

    Args:
        bs: batch size forwarded to `databunch`.

    Returns:
        The object produced by `.databunch(bs=bs)`.
    """
    mnist_root = untar_data(URLs.MNIST)
    images = ImageList.from_folder(path=mnist_root)
    split_images = images.split_by_folder(train='training', valid='testing')
    labelled = split_images.label_from_folder()
    return labelled.databunch(bs=bs)

In [7]:
def get_cifar100_data(bs=16):
    """Build the CIFAR-100 data bunch with pad-and-crop training transforms.

    The dataset location comes from `untar_data(URLs.CIFAR_100)`. The 'train'
    folder is the training set and 'test' the validation set; labels are taken
    from the folder structure. The validation set gets no transforms.

    Args:
        bs: batch size forwarded to `databunch`.

    Returns:
        The object produced by `.databunch(bs=bs)`.
    """
    train_tfms = [
        # Zero-pad by 4 on every application (p=1.0, not random).
        RandTransform(tfm=pad, kwargs={'padding':4, 'mode':'zeros'},
                      is_random=False, p=1.0, use_on_y=False),
        # Crop back to size 32; row_pct/col_pct are given as the range (0.4, 0.6).
        RandTransform(tfm=crop, kwargs={'row_pct':(0.4,0.6), 'col_pct':(0.4,0.6), 'size':32},
                      p=1.0, use_on_y=False),
    ]
    valid_tfms = []
    dataset_root = untar_data(URLs.CIFAR_100)
    images = ImageList.from_folder(path=dataset_root)
    split_images = images.split_by_folder(train='train', valid='test')
    labelled = split_images.label_from_folder()
    transformed = labelled.transform((train_tfms, valid_tfms))
    return transformed.databunch(bs=bs)

In [8]:
def get_cifar10_data(bs=16):
    """Build the CIFAR-10 data bunch with pad-and-crop training transforms.

    The dataset location comes from `untar_data(URLs.CIFAR)`. The 'train'
    folder is the training set and 'test' the validation set; labels are taken
    from the folder structure. The validation set gets no transforms.

    Args:
        bs: batch size forwarded to `databunch`.

    Returns:
        The object produced by `.databunch(bs=bs)`.
    """
    train_tfms = [
        # Zero-pad by 4 on every application (p=1.0, not random).
        RandTransform(tfm=pad, kwargs={'padding':4, 'mode':'zeros'},
                      is_random=False, p=1.0, use_on_y=False),
        # Crop back to size 32; row_pct/col_pct are given as the range (0.4, 0.6).
        RandTransform(tfm=crop, kwargs={'row_pct':(0.4,0.6), 'col_pct':(0.4,0.6), 'size':32},
                      p=1.0, use_on_y=False),
    ]
    valid_tfms = []
    dataset_root = untar_data(URLs.CIFAR)
    images = ImageList.from_folder(path=dataset_root)
    split_images = images.split_by_folder(train='train', valid='test')
    labelled = split_images.label_from_folder()
    transformed = labelled.transform((train_tfms, valid_tfms))
    return transformed.databunch(bs=bs)

### Check if everything works

In [8]:
# data = get_mnist_data()
# # data.show_batch()

# def get_simple_cnn(pretrained=False):
#     return simple_cnn([3, 4, 2])

# learn = cnn_learner(data=data, base_arch=get_simple_cnn, opt_func=LookaheadSGD)
# learn.metrics.append(accuracy)
# learn.fit_one_cycle(1, 0.003)
# learn.save(MODELS/'mnist_sample.pkl')
# learn.load(MODELS/'mnist_sample.pkl')

KeyboardInterrupt: 

### Running the experiment

In [11]:
# WandbCallback factory pre-configured with input_type='images'.
# NOTE(review): not referenced by the visible cells; fit_and_record passes
# WandbCallback directly. Confirm whether this was meant to be used instead.
wandbRecorder = partial(WandbCallback, input_type='images')

In [31]:
# Optimizers compared in this experiment, in run order.
# SWA is also part of the comparison but needs to be dealt with separately.
optimizers = [LookaheadSGD, optim.SGD, optim.AdamW, optim.RMSprop, LookaheadAdamW]

# Constructor keyword arguments for each optimizer (the learning rate is kept in lr_dict).
lookahead_sgd_params = {'alpha': 0.5, 'k': 5}                          # lr = 0.1
sgd_params = {'momentum': 0.9, 'weight_decay': 0.001}                  # lr = 0.05
adamW_params = {'weight_decay': 0.3}                                   # lr = 0.001
rmsprop_params = {'weight_decay': 0.001}                               # lr = 0.01
lookahead_adamw_params = {'weight_decay': 0.3, 'k': 5, 'alpha': 0.5}   # lr = 0.1

# Same order as `optimizers`.
params_list = [lookahead_sgd_params, sgd_params, adamW_params, rmsprop_params, lookahead_adamw_params]

# Lookup tables keyed by optimizer class; each is built by pairing `optimizers`
# with a parallel list so the per-optimizer settings stay aligned.
experiment_names = dict(zip(optimizers, ['Lookahead_SGD', 'SGD', 'AdamW', 'RMSProp', 'LookaheadAdamW']))
params_dict = dict(zip(optimizers, params_list))
lr_dict = dict(zip(optimizers, [0.1, 0.05, 0.001, 0.01, 0.1]))
# Every optimizer is trained for the same number of epochs by default.
epochs_dict = dict.fromkeys(optimizers, 200)

In [13]:
wandb.init()

W&B Run: https://app.wandb.ai/akashpalrecha/lookahead/runs/hvjrnb9b

In [56]:
6 in [1,2,3,4]

False

In [57]:
class LRDecayCallback(LearnerCallback):
    """Divide the optimizer's learning rate by a fixed factor after chosen epochs.

    Args:
        learn: the Learner this callback is attached to.
        decay_on_epochs: epoch numbers, as passed to `on_epoch_end`, at whose
            end the learning rate is divided.
        decay_factor: positive number the current learning rate is divided by.

    Raises:
        ValueError: if `decay_factor` is not strictly positive.
    """
    def __init__(self, learn:Learner, decay_on_epochs:list, decay_factor:float):
        # Fail at construction time instead of at the first decay epoch, which
        # may be hours into training (ZeroDivisionError or a sign-flipped LR).
        if decay_factor <= 0:
            raise ValueError(f"decay_factor must be positive, got {decay_factor!r}")
        super().__init__(learn)
        self.decay_on_epochs = decay_on_epochs
        self.decay_factor = decay_factor
    def on_epoch_end(self, epoch, **kwargs):
        """Scale down `self.opt.lr` when `epoch` is one of the configured epochs."""
        if epoch in self.decay_on_epochs:
            self.opt.lr = self.opt.lr / self.decay_factor
            print(f"LR changed to: {self.opt.lr}")

In [58]:
# Callback factory: divide the learning rate by 5 at the end of epochs 60, 120 and 160.
# Left as a partial so the Learner can be supplied later (it is passed via callback_fns).
lr_decay = partial(LRDecayCallback, decay_on_epochs=[60, 120, 160], decay_factor=5.0)

In [70]:
def fit_and_record(epochs=None, lr=None, opt_func=LookaheadSGD, experiment_name=None,
                   data=None, base_arch=models.resnet18, pretrained=False, one_cycle=False):
    """Train a model with the given optimizer, log the run to wandb and save it.

    Args:
        epochs: number of epochs; defaults to `epochs_dict[opt_func]`.
        lr: learning rate; defaults to `lr_dict[opt_func]`.
        opt_func: optimizer class. It must be a key of `params_dict` (and of
            `experiment_names`, `epochs_dict`, `lr_dict` when the matching
            argument is left as None).
        experiment_name: run name; defaults to `experiment_names[opt_func]`.
            When `data` is None or a string, a '_cifar100' / '_cifar10' suffix
            is appended.
        data: 'cifar100' selects CIFAR-100; None or any other string selects
            CIFAR-10 (both with batch size 128). Any non-string object is
            passed to `cnn_learner` as-is.
        base_arch: architecture passed to `cnn_learner`.
        pretrained: forwarded to `cnn_learner`.
        one_cycle: use `fit_one_cycle` instead of `fit`.

    Returns:
        None. The trained model is saved to `MODELS/experiment_name`.
    """
    # Only NumPy's global RNG is seeded here.
    np.random.seed(42)
    # Initialize Training
    if experiment_name is None:
        experiment_name = experiment_names[opt_func]
    if data is None or isinstance(data, str):
        # Compare by value: `is` on strings only works when both sides happen
        # to be the same interned object, and would otherwise silently fall
        # through to CIFAR-10.
        if data == 'cifar100':
            data = get_cifar100_data(128)
            experiment_name += '_cifar100'
        else:
            data = get_cifar10_data(128)
            experiment_name += '_cifar10'
    if epochs is None:
        epochs = epochs_dict[opt_func]
    if lr is None:
        lr = lr_dict[opt_func]
    # Setting optimizer parameters correctly. This must come after the lookups
    # above, which are keyed by the bare optimizer class.
    opt_func = partial(opt_func, **params_dict[opt_func])

    print_all(data, opt_func, epochs, lr, experiment_name, base_arch)

    wandb.init(project='lookahead', name=experiment_name)
    # NOTE(review): lr_decay is attached even when one_cycle=True; confirm it
    # does not conflict with the one-cycle schedule.
    learn = cnn_learner(data=data, base_arch=base_arch, opt_func=opt_func,
                        pretrained=pretrained, callback_fns=[WandbCallback, lr_decay])
    learn.metrics.append(accuracy)
    if one_cycle:
        learn.fit_one_cycle(epochs, lr)
    else:
        learn.fit(epochs, lr)
    learn.save(MODELS/experiment_name)

# Experimentation for CIFAR10

In [60]:
fit_and_record(opt_func=LookaheadSGD, data='cifar10')

ImageDataBunch;

Train: LabelList (50000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
horse,horse,horse,horse,horse
Path: /home/ubuntu/.fastai/data/cifar10;

Valid: LabelList (10000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
horse,horse,horse,horse,horse
Path: /home/ubuntu/.fastai/data/cifar10;

Test: None
functools.partial(<class 'Optimizers.lookahead.LookaheadSGD'>, alpha=0.5, k=5)
200
0.1
Lookahead_SGD_cifar10
<function resnet18 at 0x7fa06f539200>


epoch,train_loss,valid_loss,accuracy,time
0,1.581931,1.408851,0.4892,00:10
1,1.304877,1.235686,0.5574,00:10
2,1.180615,1.166253,0.5834,00:10
3,1.098163,1.086026,0.6245,00:10
4,1.013424,1.033062,0.6388,00:10
5,0.958619,1.025127,0.6429,00:10
6,0.908821,0.927633,0.6796,00:10
7,0.880085,0.889115,0.7011,00:11
8,0.862087,0.923126,0.6883,00:10
9,0.842486,1.121552,0.6262,00:10


Better model found at epoch 0 with valid_loss value: 1.408851146697998.
Better model found at epoch 1 with valid_loss value: 1.2356864213943481.
Better model found at epoch 2 with valid_loss value: 1.1662530899047852.
Better model found at epoch 3 with valid_loss value: 1.0860264301300049.
Better model found at epoch 4 with valid_loss value: 1.03306245803833.
Better model found at epoch 5 with valid_loss value: 1.0251268148422241.
Better model found at epoch 6 with valid_loss value: 0.9276328086853027.
Better model found at epoch 7 with valid_loss value: 0.8891153335571289.
Better model found at epoch 11 with valid_loss value: 0.8779951930046082.
Better model found at epoch 15 with valid_loss value: 0.7623305916786194.
Better model found at epoch 31 with valid_loss value: 0.7152103781700134.
Better model found at epoch 54 with valid_loss value: 0.7111032009124756.
LR changed to: 0.02
Better model found at epoch 61 with valid_loss value: 0.5787047147750854.
LR changed to: 0.004
LR chang

In [61]:
fit_and_record(opt_func=optim.SGD, data='cifar10')

ImageDataBunch;

Train: LabelList (50000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
horse,horse,horse,horse,horse
Path: /home/ubuntu/.fastai/data/cifar10;

Valid: LabelList (10000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
horse,horse,horse,horse,horse
Path: /home/ubuntu/.fastai/data/cifar10;

Test: None
functools.partial(<class 'torch.optim.sgd.SGD'>, momentum=0.9, weight_decay=0.001)
200
0.05
SGD_cifar10
<function resnet18 at 0x7fa06f539200>


epoch,train_loss,valid_loss,accuracy,time
0,2.072237,1.960284,0.3769,00:11
1,1.506411,1.45838,0.4784,00:11
2,1.28429,1.374882,0.5039,00:11
3,1.109046,1.242516,0.5573,00:11
4,1.022108,1.22134,0.5762,00:11
5,0.942496,1.282917,0.5775,00:11
6,0.890733,1.081589,0.619,00:11
7,0.832625,1.075914,0.6347,00:11
8,0.818916,1.167694,0.6036,00:11
9,0.769443,0.946097,0.6758,00:11


Better model found at epoch 0 with valid_loss value: 1.960283637046814.
Better model found at epoch 1 with valid_loss value: 1.4583797454833984.
Better model found at epoch 2 with valid_loss value: 1.3748815059661865.
Better model found at epoch 3 with valid_loss value: 1.2425156831741333.
Better model found at epoch 4 with valid_loss value: 1.2213397026062012.
Better model found at epoch 6 with valid_loss value: 1.081588625907898.
Better model found at epoch 7 with valid_loss value: 1.0759135484695435.
Better model found at epoch 9 with valid_loss value: 0.9460968971252441.
Better model found at epoch 10 with valid_loss value: 0.8722596764564514.
Better model found at epoch 23 with valid_loss value: 0.8067433834075928.
Better model found at epoch 45 with valid_loss value: 0.7833822965621948.
LR changed to: 0.01
Better model found at epoch 61 with valid_loss value: 0.6396269798278809.
LR changed to: 0.002
LR changed to: 0.0004
Loaded best saved model from /home/ubuntu/personal/Lookahea

In [62]:
fit_and_record(opt_func=optim.AdamW, data='cifar10')

ImageDataBunch;

Train: LabelList (50000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
horse,horse,horse,horse,horse
Path: /home/ubuntu/.fastai/data/cifar10;

Valid: LabelList (10000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
horse,horse,horse,horse,horse
Path: /home/ubuntu/.fastai/data/cifar10;

Test: None
functools.partial(<class 'torch.optim.adamw.AdamW'>, weight_decay=0.3)
200
0.001
AdamW_cifar10
<function resnet18 at 0x7fa06f539200>


epoch,train_loss,valid_loss,accuracy,time
0,1.494286,1.566047,0.4377,00:14
1,1.17105,1.43077,0.5031,00:14
2,1.016847,1.266495,0.5598,00:14
3,0.892871,0.990961,0.6519,00:14
4,0.788466,0.898746,0.6947,00:14
5,0.748336,0.972483,0.6709,00:14
6,0.685964,1.059485,0.6514,00:14
7,0.617005,0.805646,0.7317,00:14
8,0.576487,0.834833,0.7225,00:14
9,0.51609,1.010158,0.6702,00:14


Better model found at epoch 0 with valid_loss value: 1.5660473108291626.
Better model found at epoch 1 with valid_loss value: 1.4307698011398315.
Better model found at epoch 2 with valid_loss value: 1.2664949893951416.
Better model found at epoch 3 with valid_loss value: 0.9909612536430359.
Better model found at epoch 4 with valid_loss value: 0.898745596408844.
Better model found at epoch 7 with valid_loss value: 0.8056460618972778.
Better model found at epoch 11 with valid_loss value: 0.7591457366943359.
LR changed to: 0.0002
LR changed to: 4e-05
LR changed to: 8.000000000000001e-06
Loaded best saved model from /home/ubuntu/personal/Lookahead/wandb/run-20191114_024850-z8skfprj/bestmodel.pth


In [None]:
fit_and_record(opt_func=optim.RMSprop, data='cifar10')

ImageDataBunch;

Train: LabelList (50000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
horse,horse,horse,horse,horse
Path: /home/ubuntu/.fastai/data/cifar10;

Valid: LabelList (10000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
horse,horse,horse,horse,horse
Path: /home/ubuntu/.fastai/data/cifar10;

Test: None
functools.partial(<class 'torch.optim.rmsprop.RMSprop'>, weight_decay=0.001)
200
0.01
RMSProp_cifar10
<function resnet18 at 0x7fa06f539200>


epoch,train_loss,valid_loss,accuracy,time
0,1.937513,1.793664,0.3191,00:12
1,1.51258,2.013301,0.4452,00:12
2,1.215958,1.773873,0.4076,00:12
3,1.01806,1.598398,0.4796,00:12
4,0.89625,1.201032,0.5952,00:12
5,0.816969,1.125259,0.6226,00:12
6,0.723698,1.175508,0.6343,00:12
7,0.675739,0.897589,0.7083,00:12
8,0.595575,0.87924,0.7208,00:12
9,0.560913,1.046785,0.6705,00:12


Better model found at epoch 0 with valid_loss value: 1.7936636209487915.
Better model found at epoch 2 with valid_loss value: 1.7738728523254395.
Better model found at epoch 3 with valid_loss value: 1.598398208618164.
Better model found at epoch 4 with valid_loss value: 1.201032280921936.
Better model found at epoch 5 with valid_loss value: 1.1252591609954834.
Better model found at epoch 7 with valid_loss value: 0.8975886702537537.
Better model found at epoch 8 with valid_loss value: 0.8792399168014526.


In [None]:
fit_and_record(opt_func=LookaheadAdamW, data='cifar10')

# Experimentation for CIFAR100

In [None]:
fit_and_record(opt_func=LookaheadSGD, data='cifar100', epochs=100, one_cycle=True)

ImageDataBunch;

Train: LabelList (50000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
apple,apple,apple,apple,apple
Path: /home/ubuntu/.fastai/data/cifar100;

Valid: LabelList (10000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
apple,apple,apple,apple,apple
Path: /home/ubuntu/.fastai/data/cifar100;

Test: None
functools.partial(<class 'Optimizers.lookahead.LookaheadSGD'>, alpha=0.5, k=5)
100
0.1
Lookahead_SGD_cifar100
<function resnet18 at 0x7fa06f539200>


epoch,train_loss,valid_loss,accuracy,time
0,4.319181,3.78808,0.1391,00:12
1,3.829179,3.468418,0.1938,00:11
2,3.498308,3.778581,0.2256,00:11
3,3.314722,3.041385,0.2561,00:11
4,3.047771,2.88672,0.2838,00:11
5,2.859583,2.732644,0.3178,00:11
6,2.723511,2.586912,0.3433,00:11
7,2.57626,2.569853,0.3458,00:11
8,2.433782,2.483126,0.3643,00:11
9,2.294419,2.395706,0.3851,00:12


Better model found at epoch 0 with valid_loss value: 3.7880799770355225.
Better model found at epoch 1 with valid_loss value: 3.4684176445007324.
Better model found at epoch 3 with valid_loss value: 3.0413854122161865.
Better model found at epoch 4 with valid_loss value: 2.8867201805114746.
Better model found at epoch 5 with valid_loss value: 2.7326436042785645.
Better model found at epoch 6 with valid_loss value: 2.586912155151367.
Better model found at epoch 7 with valid_loss value: 2.569852828979492.
Better model found at epoch 8 with valid_loss value: 2.483125925064087.
Better model found at epoch 9 with valid_loss value: 2.3957061767578125.
Better model found at epoch 10 with valid_loss value: 2.3372180461883545.
Better model found at epoch 11 with valid_loss value: 2.266594648361206.
Better model found at epoch 15 with valid_loss value: 2.2540066242218018.
Better model found at epoch 17 with valid_loss value: 2.162142515182495.
Better model found at epoch 32 with valid_loss value

In [None]:
fit_and_record(opt_func=optim.SGD, data='cifar100', epochs=100, one_cycle=True)

In [None]:
fit_and_record(opt_func=optim.AdamW, data='cifar100', epochs=100, one_cycle=True)

In [None]:
fit_and_record(opt_func=optim.RMSprop, data='cifar100', epochs=100, one_cycle=True)

In [None]:
fit_and_record(opt_func=LookaheadAdamW, data='cifar100', epochs=100, one_cycle=True)