In [1]:
# Notebook setup: pick up edits to imported modules without restarting the kernel,
# and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
!git clone https://github.com/fastai/course-v3
%cd /content/course-v3/nbs/dl2
from exp.nb_08 import *

Cloning into 'course-v3'...
remote: Enumerating objects: 5893, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 5893 (delta 0), reused 2 (delta 0), pack-reused 5890[K
Receiving objects: 100% (5893/5893), 263.10 MiB | 34.79 MiB/s, done.
Resolving deltas: 100% (3251/3251), done.
/content/course-v3/nbs/dl2


In [3]:
# Display the torch.optim module object; its repr shows where it is installed.
torch.optim

<module 'torch.optim' from '/usr/local/lib/python3.7/dist-packages/torch/optim/__init__.py'>

## Load dataset and vanilla model

In [4]:
# Fetch the Imagenette-160 archive; `path` is what ImageList.from_files reads from below.
path = datasets.untar_data(datasets.URLs.IMAGENETTE_160)

Downloading https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz


In [5]:
# Item transforms handed to the ImageList (helpers presumably come from exp.nb_08).
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]
# Passed as the first argument of to_databunch below; presumably the batch size.
bs = 128
# Build the item list from files under `path`.
il = ImageList.from_files(path, tfms=tfms)
# Split with grandparent_splitter, using 'val' as the validation name.
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))
# Label each item with parent_labeler; targets go through a CategoryProcessor.
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())
# c_in=3 / c_out=10: presumably input channels and number of classes (TODO confirm).
data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=2)

In [6]:
# First argument of get_learn_run below; presumably the filter count of each conv layer (TODO confirm).
nfs = [32, 64, 128, 256]

In [7]:
# Callback factories passed to get_learn_run via `cbs=`:
# stats reporting with `accuracy`, the CUDA callback, and a batch transform using norm_imagenette.
cbfs = [partial(AvgStatsCallback, accuracy),
        CudaCallback,
        partial(BatchTransformXCallback, norm_imagenette)]

In [8]:
# Build the learner and runner; 0.4 is presumably the learning rate (TODO confirm against get_learn_run).
learn, run= get_learn_run(nfs, data, 0.4, conv_layer, cbs=cbfs)

In [9]:
# Baseline run before changing the optimizer; the first argument is presumably the number of epochs.
run.fit(1, learn)

train: [1.770394263517795, tensor(0.3820, device='cuda:0')]
valid: [1.6519826084792995, tensor(0.4517, device='cuda:0')]


---

## Refining the optimizer

In [10]:
# `inspect` is used here and in later cells to print the source of imported helpers.
import inspect
# Show how compose() chains functions; Optimizer.step below relies on it.
print(inspect.getsource(compose))

def compose(x, funcs, *args, order_key='_order', **kwargs):
    key = lambda o: getattr(o, order_key, 0)
    for f in sorted(listify(funcs), key=key): x = f(x, **kwargs)
    return x



In [22]:
class Optimizer():
    """Minimal optimizer that applies a chain of stepper functions to each parameter.

    Args:
        params: iterable of parameters (treated as a single group) or a list of
            lists of parameters (one inner list per group). May be a generator.
        steppers: one stepper or a list of steppers; `step` hands them to
            `compose`, which calls each one as `stepper(p, **hyper)`.
        **defaults: hyper-parameters copied into every group's dict in `self.hypers`.
    """
    def __init__(self, params, steppers, **defaults):
        # materialise first: params might be a generator (e.g. model.parameters())
        self.param_groups = list(params)
        # ensure params is a list of lists
        if not isinstance(self.param_groups[0], list): self.param_groups = [self.param_groups]
        # one separate dict per group, so a group's hypers can be changed independently
        self.hypers = [{**defaults} for _ in self.param_groups]
        self.steppers = listify(steppers)

    def grad_params(self):
        """Return `(param, hyper)` pairs for every parameter whose `.grad` is not None."""
        return [(p, hyper)
                for pg, hyper in zip(self.param_groups, self.hypers)
                for p in pg if p.grad is not None]

    def zero_grad(self):
        """Detach and zero, in place, the gradient of every parameter that has one."""
        for p, _ in self.grad_params():
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        """Run all steppers on every parameter that has a gradient."""
        # grad_params is a method: it has to be looked up on self and called
        for p, hyper in self.grad_params():
            compose(p, self.steppers, **hyper)

In [24]:
# No keyword defaults are given, so the single param group gets an empty hyper dict.
# NOTE: sgd_step is defined in a later cell (In [20]); on a fresh top-to-bottom run this cell fails.
Optimizer(learn.model.parameters(), sgd_step).hypers

[{}]

In [20]:
def sgd_step(p, lr, **kwargs):
    """Plain SGD update, applied in place: p <- p - lr * p.grad.

    Args:
        p: parameter to update; its `.grad` must be set.
        lr: learning rate.
        **kwargs: other hyper-parameters, accepted and ignored so this stepper
            can be called with a full hyper dict.

    Returns:
        The same parameter `p`, so steppers can be chained.
    """
    # In-place scaled add: p.data += (-lr) * p.grad.data.
    # The `alpha=` keyword replaces the deprecated positional add_(scalar, tensor) overload.
    p.data.add_(p.grad.data, alpha=-lr)
    return p
# Optimizer factory with the stepper list pre-bound; Optimizer.step passes these
# steppers to compose(), which applies them one after another to each parameter.
opt_func = partial(Optimizer, steppers=[sgd_step])

In [None]:
# List the names available in torch.optim.
dir(torch.optim)

- Q3

In [25]:
# Print the source of the Recorder callback currently in scope, for reference.
print(inspect.getsource(Recorder))

class Recorder(Callback):
    def begin_fit(self):
        self.lrs = [[] for _ in self.opt.param_groups]
        self.losses = []

    def after_batch(self):
        if not self.in_train: return
        for pg,lr in zip(self.opt.param_groups,self.lrs): lr.append(pg['lr'])
        self.losses.append(self.loss.detach().cpu())

    def plot_lr  (self, pgid=-1): plt.plot(self.lrs[pgid])
    def plot_loss(self, skip_last=0): plt.plot(self.losses[:len(self.losses)-skip_last])

    def plot(self, skip_last=0, pgid=-1):
        losses = [o.item() for o in self.losses]
        lrs    = self.lrs[pgid]
        n = len(losses)-skip_last
        plt.xscale('log')
        plt.plot(lrs[:n], losses[:n])



In [28]:
# NOTE: despite the name, this is an SGD optimizer *instance*, not a factory like opt_func.
torch_opt_func = torch.optim.SGD(learn.model.parameters(), lr=0.3)

In [31]:
# Display the optimizer's repr: one parameter group and its hyper-parameters.
torch_opt_func

SGD (
Parameter Group 0
    dampening: 0
    lr: 0.3
    momentum: 0
    nesterov: False
    weight_decay: 0
)

In [30]:
# torch keeps each param group as a dict: hyper-parameters plus a 'params' list.
# NOTE: this prints every parameter tensor, so the output is very long.
torch_opt_func.param_groups

[{'dampening': 0,
  'lr': 0.3,
  'momentum': 0,
  'nesterov': False,
  'params': [Parameter containing:
   tensor([[[[ 1.7439e-01, -1.0964e-01, -2.5532e-01],
             [ 3.4949e-01,  5.9283e-02, -6.7358e-01],
             [-7.5286e-02,  5.4168e-02, -1.5314e-01]],
   
            [[-2.9575e-01,  1.1946e-01, -1.6396e-01],
             [-2.8440e-02, -4.9511e-01,  1.1969e-01],
             [ 6.1617e-01, -1.9601e-03, -1.3831e-01]],
   
            [[ 5.3499e-01, -8.7488e-02, -3.1547e-02],
             [-2.1036e-01,  1.3039e-01, -2.3309e-01],
             [ 6.1347e-02, -2.3782e-01, -4.0898e-01]]],
   
   
           [[[ 1.0472e-01,  6.5960e-01, -1.7434e-01],
             [-3.0749e-02,  2.3055e-01,  3.7322e-01],
             [-6.1702e-03, -2.5747e-01,  7.7779e-02]],
   
            [[ 1.4046e-01, -8.8347e-02,  1.9498e-01],
             [-2.0334e-01, -1.1188e-01,  3.1159e-01],
             [ 3.2762e-02, -7.9771e-02, -2.3705e-01]],
   
            [[-2.5879e-01,  2.2828e-01, -6.8593e-02],
  

In [32]:
# original one (i.e., dependency on pytorch)
class Recorder(Callback):
    """Callback that logs learning rates and losses during training.

    Reads `self.opt.param_groups` as dicts with an 'lr' key (the layout used by
    torch optimizers) and keeps one lr history per param group.
    """
    def begin_fit(self):
        # one empty lr history per param group, plus a single loss history
        self.lrs, self.losses = [[] for _group in self.opt.param_groups], []

    def after_batch(self):
        # only training batches are recorded
        if self.in_train:
            for group, history in zip(self.opt.param_groups, self.lrs):
                history.append(group['lr'])
            self.losses.append(self.loss.detach().cpu())

    def plot_lr(self, pgid=-1):
        """Plot the lr history of param group `pgid`."""
        plt.plot(self.lrs[pgid])

    def plot_loss(self, skip_last=0):
        """Plot the recorded losses, leaving out the final `skip_last` entries."""
        kept = len(self.losses) - skip_last
        plt.plot(self.losses[:kept])

    def plot(self, skip_last=0, pgid=-1):
        """Plot loss against the lr of param group `pgid`, with a log-scaled x axis."""
        loss_vals = [t.item() for t in self.losses]
        group_lrs = self.lrs[pgid]
        kept = len(loss_vals) - skip_last
        plt.xscale('log')
        plt.plot(group_lrs[:kept], loss_vals[:kept])

In [None]:
# adjusted one
class Recorder(Callback):
    """Callback that logs the learning rate and loss of every training batch.

    The learning rate is read from the last entry of `self.opt.hypers`, so
    `self.lrs` is one flat list with a single value per batch (not one list per
    param group).
    """
    def begin_fit(self): self.lrs, self.losses = [], []

    def after_batch(self):
        if not self.in_train: return
        self.lrs.append(self.opt.hypers[-1]['lr'])
        self.losses.append(self.loss.detach().cpu())

    # `pgid` is accepted only so existing calls keep working. self.lrs is flat,
    # so indexing it with pgid would select one float instead of a history.
    def plot_lr  (self, pgid=-1): plt.plot(self.lrs)
    def plot_loss(self, skip_last=0): plt.plot(self.losses[:len(self.losses)-skip_last])

    def plot(self, skip_last=0, pgid=-1):
        """Plot loss against learning rate (log-scaled x axis), dropping the last `skip_last` points."""
        losses = [o.item() for o in self.losses]
        n = len(losses)-skip_last
        plt.xscale('log')
        # slice the flat lr history directly; `pgid` is unused (see note above)
        plt.plot(self.lrs[:n], losses[:n])

In [33]:
# Print the source of the ParamScheduler callback currently in scope, for reference.
print(inspect.getsource(ParamScheduler))

class ParamScheduler(Callback):
    _order=1
    def __init__(self, pname, sched_funcs): self.pname,self.sched_funcs = pname,sched_funcs

    def begin_fit(self):
        if not isinstance(self.sched_funcs, (list,tuple)):
            self.sched_funcs = [self.sched_funcs] * len(self.opt.param_groups)

    def set_param(self):
        assert len(self.opt.param_groups)==len(self.sched_funcs)
        for pg,f in zip(self.opt.param_groups,self.sched_funcs):
            pg[self.pname] = f(self.n_epochs/self.epochs)

    def begin_batch(self):
        if self.in_train: self.set_param()



In [None]:
# original one
class ParamScheduler(Callback):
    """Callback that sets one hyper-parameter on every param group before each training batch.

    `pname` is the key written into each entry of `self.opt.param_groups`;
    `sched_funcs` is either one schedule function (shared by all groups) or a
    list/tuple with one function per group. Each function is called with
    `self.n_epochs/self.epochs`.
    """
    _order=1

    def __init__(self, pname, sched_funcs):
        self.pname = pname
        self.sched_funcs = sched_funcs

    def begin_fit(self):
        # a list/tuple is taken as-is; a single function is repeated for every group
        if isinstance(self.sched_funcs, (list,tuple)): return
        self.sched_funcs = [self.sched_funcs] * len(self.opt.param_groups)

    def set_param(self):
        """Write the scheduled value of `pname` into every param group."""
        assert len(self.opt.param_groups)==len(self.sched_funcs)
        for group, sched in zip(self.opt.param_groups, self.sched_funcs):
            group[self.pname] = sched(self.n_epochs/self.epochs)

    def begin_batch(self):
        # nothing is scheduled outside training
        if not self.in_train: return
        self.set_param()