In [3]:
!ls

activate_environment.sh  imagenette.ipynb  Optimizers	   README.rst
conda_env.yml		 models		   poetry.lock	   start_off.ipynb
environment.yml		 nnfl		   pyproject.toml  tests
First_notebook.ipynb	 nnfl.egg-info	   README.md	   wandb


In [4]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [5]:
from fastai.vision import *
from torch.optim import Optimizer
from Optimizers.lookahead import *
from torchcontrib.optim import SWA
import wandb
from wandb.fastai import WandbCallback

In [6]:
PROJECT_PATH = Path.cwd()
MODELS = PROJECT_PATH/'models'

MODELS.mkdir(exist_ok=True)

In [7]:
def print_all(*args):
    for i in args:
        print(i)

In [8]:
def get_imagenette_data(bs=128):
    PATH = untar_data(URLs.IMAGENETTE_320)
    data = (ImageList.from_folder(path=PATH)
            .split_by_folder(train='train', valid='val')
            .label_from_folder()
            .transform(get_transforms(), size=224)
            .databunch(bs=bs))
    return data

### Check if everything works

In [11]:
# data = get_mnist_data()
# # data.show_batch()

# def get_simple_cnn(pretrained=False):
#     return simple_cnn([3, 4, 2])

# learn = cnn_learner(data=data, base_arch=get_simple_cnn, opt_func=LookaheadSGD)
# learn.metrics.append(accuracy)
# learn.fit_one_cycle(1, 0.003)
# learn.save(MODELS/'mnist_sample.pkl')
# learn.load(MODELS/'mnist_sample.pkl')

### Running the experiment

In [12]:
wandbRecorder = partial(WandbCallback, input_type='images')

In [13]:
optimizers = [LookaheadSGD, optim.SGD, optim.AdamW] # Along with SWA which needs to be separately dealt with
experiment_names = {LookaheadSGD:'Lookahead_SGD', optim.SGD:'SGD', optim.AdamW:'AdamW'}

lookahead_sgd_params = dict(alpha=0.5, k=5) # lr = 0.1
sgd_params = dict(momentum=0.9, weight_decay=0.001) # lr = 0.05
adamW_params = dict(weight_decay=0.03) # lr = 0.01

params_list = [lookahead_sgd_params, sgd_params, adamW_params]

params_dict = {LookaheadSGD:lookahead_sgd_params, optim.SGD:sgd_params, optim.AdamW:adamW_params}
lr_dict = {LookaheadSGD:0.1, optim.SGD:0.05, optim.AdamW:0.01,}
epochs_dict = {LookaheadSGD:40, optim.SGD:40, optim.AdamW:40}

In [14]:
wandb.init()

W&B Run: https://app.wandb.ai/akashpalrecha/lookahead/runs/fxij98tl

In [15]:
class LRDecayCallback(LearnerCallback):
    def __init__(self, learn:Learner, decay_on_epochs:list, decay_factor:int):
        super().__init__(learn)
        self.decay_on_epochs = decay_on_epochs
        self.decay_factor = decay_factor
    def on_epoch_end(self, epoch, **kwargs):
        if epoch in self.decay_on_epochs:
            self.opt.lr = self.opt.lr / self.decay_factor
            print("LR changed to: " + str(self.opt.lr))

In [16]:
lr_decay = partial(LRDecayCallback, decay_on_epochs=[60, 120, 160], decay_factor=5.0)

In [19]:
def fit_and_record(epochs=None, lr=None, opt_func=LookaheadSGD, experiment_name=None, 
                   data=None, base_arch=models.resnet50, pretrained=False):
    np.random.seed(42)
    # Initialize Training
    if experiment_name is None:
        experiment_name = experiment_names[opt_func]
        experiment_name += "_imagenette_320"
    if data is None or type(data) is str:
        data = get_imagenette_data(128)
    if epochs is None:
        epochs = epochs_dict[opt_func]
    if lr is None:
        lr = lr_dict[opt_func]
    # Setting optimizer parameters correctly
    opt_func = partial(opt_func, **params_dict[opt_func])
    
    print_all(data, opt_func, epochs, lr, experiment_name, base_arch)
    
    wandb.init(project='lookahead', name=experiment_name)
    learn = cnn_learner(data=data, base_arch=base_arch, opt_func=opt_func, 
                        pretrained=pretrained, callback_fns=WandbCallback).mixup()
    learn.metrics.append(accuracy)
    learn.metrics.append(partial(top_k_accuracy, k=5))
    learn.fit_one_cycle(epochs, lr)
    learn.save(MODELS/experiment_name)

# Experimentation for CIFAR10

In [20]:
fit_and_record(opt_func=LookaheadSGD)

ImageDataBunch;

Train: LabelList (12894 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
n03888257,n03888257,n03888257,n03888257,n03888257
Path: /home/ubuntu/.fastai/data/imagenette-320;

Valid: LabelList (500 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
n03888257,n03888257,n03888257,n03888257,n03888257
Path: /home/ubuntu/.fastai/data/imagenette-320;

Test: None
functools.partial(<class 'Optimizers.lookahead.LookaheadSGD'>, alpha=0.5, k=5)
40
0.1
Lookahead_SGD_imagenette_320
<function resnet50 at 0x7f94dd29c3b0>


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,2.567675,1.962278,0.388,0.796,00:42
1,2.211404,1.703623,0.482,0.9,00:42
2,2.074142,1.698804,0.57,0.912,00:42
3,2.080024,1.672964,0.5,0.89,00:42
4,2.143681,1.857412,0.38,0.808,00:42
5,2.012544,1.151218,0.624,0.952,00:42
6,1.907582,1.203249,0.63,0.946,00:42
7,1.835333,0.952235,0.708,0.962,00:42
8,1.710193,0.946083,0.704,0.97,00:42
9,1.670971,1.093194,0.654,0.974,00:42


Better model found at epoch 0 with valid_loss value: 1.9622784852981567.
Better model found at epoch 1 with valid_loss value: 1.703622817993164.
Better model found at epoch 2 with valid_loss value: 1.6988040208816528.
Better model found at epoch 3 with valid_loss value: 1.6729638576507568.
Better model found at epoch 5 with valid_loss value: 1.151218295097351.
Better model found at epoch 7 with valid_loss value: 0.9522354602813721.
Better model found at epoch 8 with valid_loss value: 0.9460827708244324.
Better model found at epoch 10 with valid_loss value: 0.855219841003418.
Better model found at epoch 11 with valid_loss value: 0.7718964219093323.
Better model found at epoch 14 with valid_loss value: 0.6835370659828186.
Better model found at epoch 16 with valid_loss value: 0.641135573387146.
Better model found at epoch 17 with valid_loss value: 0.5793488621711731.
Better model found at epoch 19 with valid_loss value: 0.5408705472946167.
Better model found at epoch 23 with valid_loss va

In [21]:
fit_and_record(opt_func=optim.AdamW)

ImageDataBunch;

Train: LabelList (12894 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
n03888257,n03888257,n03888257,n03888257,n03888257
Path: /home/ubuntu/.fastai/data/imagenette-320;

Valid: LabelList (500 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
n03888257,n03888257,n03888257,n03888257,n03888257
Path: /home/ubuntu/.fastai/data/imagenette-320;

Test: None
functools.partial(<class 'torch.optim.adamw.AdamW'>, weight_decay=0.03)
40
0.01
AdamW_imagenette_320
<function resnet50 at 0x7f94dd29c3b0>


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,2.567176,1.882721,0.368,0.806,00:44
1,2.214505,1.918799,0.388,0.84,00:43
2,1.962503,18.036596,0.382,0.806,00:44
3,1.710548,3.929991,0.37,0.788,00:44
4,1.692597,155.174973,0.098,0.524,00:44
5,1.88837,9.719261,0.448,0.866,00:43
6,1.811385,1.747995,0.604,0.94,00:44
7,1.660436,1.043073,0.654,0.938,00:44
8,1.917861,1.629917,0.446,0.932,00:44
9,1.769187,4.657066,0.614,0.918,00:44


Better model found at epoch 0 with valid_loss value: 1.882720708847046.
Better model found at epoch 6 with valid_loss value: 1.7479950189590454.
Better model found at epoch 7 with valid_loss value: 1.0430727005004883.
Better model found at epoch 12 with valid_loss value: 0.7316505312919617.
Better model found at epoch 14 with valid_loss value: 0.6461852192878723.
Better model found at epoch 15 with valid_loss value: 0.5928518176078796.
Better model found at epoch 22 with valid_loss value: 0.5700778961181641.
Better model found at epoch 23 with valid_loss value: 0.5290701985359192.
Better model found at epoch 25 with valid_loss value: 0.4418453276157379.
Better model found at epoch 32 with valid_loss value: 0.37320706248283386.
Better model found at epoch 33 with valid_loss value: 0.33699649572372437.
Better model found at epoch 34 with valid_loss value: 0.3197595179080963.
Better model found at epoch 35 with valid_loss value: 0.3082115650177002.
Better model found at epoch 36 with vali

In [22]:
fit_and_record(opt_func=optim.SGD)

ImageDataBunch;

Train: LabelList (12894 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
n03888257,n03888257,n03888257,n03888257,n03888257
Path: /home/ubuntu/.fastai/data/imagenette-320;

Valid: LabelList (500 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
n03888257,n03888257,n03888257,n03888257,n03888257
Path: /home/ubuntu/.fastai/data/imagenette-320;

Test: None
functools.partial(<class 'torch.optim.sgd.SGD'>, momentum=0.9, weight_decay=0.001)
40
0.05
SGD_imagenette_320
<function resnet50 at 0x7f94dd29c3b0>


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,2.577726,1.977674,0.356,0.808,00:42
1,2.272291,1.678074,0.442,0.896,00:42
2,2.204104,1.834952,0.5,0.898,00:42
3,2.127513,3.182383,0.516,0.888,00:42
4,2.206284,5.721944,0.506,0.916,00:42
5,2.157249,1.533208,0.49,0.894,00:42
6,2.018359,2.436231,0.566,0.922,00:42
7,2.269503,10.531339,0.368,0.79,00:42
8,2.230512,1.165079,0.614,0.946,00:42
9,2.045107,1.294158,0.562,0.928,00:43


wandb: Network error resolved after 0:00:11.474116, resuming normal operation.


Better model found at epoch 0 with valid_loss value: 1.9776735305786133.
Better model found at epoch 1 with valid_loss value: 1.6780741214752197.
Better model found at epoch 5 with valid_loss value: 1.5332080125808716.
Better model found at epoch 8 with valid_loss value: 1.1650793552398682.
Better model found at epoch 10 with valid_loss value: 1.1061084270477295.
Better model found at epoch 11 with valid_loss value: 0.8907231092453003.
Better model found at epoch 14 with valid_loss value: 0.7777143120765686.
Better model found at epoch 16 with valid_loss value: 0.6707106828689575.
Better model found at epoch 18 with valid_loss value: 0.670626699924469.
Better model found at epoch 20 with valid_loss value: 0.6270689368247986.
Better model found at epoch 24 with valid_loss value: 0.5811817646026611.
Better model found at epoch 29 with valid_loss value: 0.460629403591156.
Better model found at epoch 32 with valid_loss value: 0.4599528908729553.
Better model found at epoch 37 with valid_lo

wandb: Network error resolved after 0:00:11.274634, resuming normal operation.
