In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
from exp.nb_04 import *

# Initial setup

In [None]:
# Load the four data splits and wrap each (inputs, targets) pair in a Dataset.
x_train, y_train, x_valid, y_valid = get_data()
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)

# nh is presumably the hidden-layer width and bs the batch size — confirm against get_model / get_dls.
nh = 50
bs = 512

# Class count: largest training label plus one (assumes labels start at 0 — TODO confirm).
c = y_train.max().item() + 1
loss_func = F.cross_entropy

In [None]:
# Bundle whatever get_dls returns for the two datasets (presumably the train/valid
# DataLoaders built with batch size bs — confirm) together with the class count c.
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [None]:
def create_learner(model_func, loss_func, data):
    """Build a Learner from a model factory, a loss function and a data bundle.

    model_func is called with data; whatever it returns is unpacked as the
    leading positional arguments of Learner, followed by loss_func and data.
    """
    learner_parts = model_func(data)
    return Learner(*learner_parts, loss_func, data)

In [None]:
# Baseline: learner built from get_model with its defaults, accuracy reported via AvgStatsCallback.
learn = create_learner(get_model, loss_func, data)
run = Runner(cbs=[AvgStatsCallback([accuracy])])

# Train; the first argument is presumably the number of epochs — confirm against Runner.fit.
run.fit(3, learn)

In [None]:
def get_model_func(lr=0.1):
    """Return get_model with its lr keyword argument pre-bound to the given value."""
    return partial(get_model, lr=lr)

In [None]:
# Same run as before, but the model factory now comes from get_model_func with an explicit lr of 0.1.
learn = create_learner(get_model_func(0.1), loss_func, data)
run = Runner(cbs=[AvgStatsCallback([accuracy])])

# Train; the first argument is presumably the number of epochs — confirm against Runner.fit.
run.fit(3, learn)

# Annealing

In [None]:
# torch.optim.Adam().param_groups

In [None]:
class Recorder(Callback):
    """Callback that logs the learning rate and the loss of every training batch."""

    def begin_fit(self):
        """Start each fit with empty learning-rate and loss histories."""
        self.lrs = []
        self.losses = []

    def after_batch(self):
        """Append the current lr and loss; batches outside training are skipped."""
        if self.in_train:
            # Only the last parameter group's learning rate is tracked.
            last_group = self.opt.param_groups[-1]
            self.lrs.append(last_group['lr'])
            # Store a detached CPU copy of the loss.
            self.losses.append(self.loss.detach().cpu())

    def plot_lr(self):
        """Plot the recorded learning rates in recording order."""
        plt.plot(self.lrs)

    def plot_loss(self):
        """Plot the recorded losses in recording order."""
        plt.plot(self.losses)
        
class ParamScheduler(Callback):
    """Callback that overwrites one optimizer hyper-parameter from a schedule while training."""
    _order = 1

    def __init__(self, pname, sched_func):
        """Store the hyper-parameter key (pname) and the schedule callable (sched_func)."""
        self.pname = pname
        self.sched_func = sched_func

    def begin_batch(self):
        """Refresh the scheduled hyper-parameter before each training batch."""
        if not self.in_train:
            return
        self.set_params()

    def set_params(self):
        """Write the scheduled value into every optimizer parameter group.

        The schedule is evaluated at n_epochs / epochs; nothing happens outside training.
        """
        if self.in_train:
            for group in self.opt.param_groups:
                # The schedule is called once per parameter group.
                group[self.pname] = self.sched_func(self.n_epochs / self.epochs)

In [None]:
def sched_lin(start, end):
    """Return a linear schedule: a callable mapping pos to start + pos * (end - start)."""
    def _interpolate(lo, hi, pos):
        return lo + pos * (hi - lo)

    # Bind both endpoints now; the returned callable only needs pos.
    return partial(_interpolate, start, end)

In [None]:
# Linear schedule going from 1 (pos=0) to 2 (pos=1).
sched = sched_lin(1,2)

In [None]:
# Value 30% of the way through: 1 + 0.3 * (2 - 1) = 1.3
sched(0.3)

In [None]:
def annealer(f):
    """Decorator turning f(start, end, pos) into a factory f(start, end) -> schedule(pos).

    The factory binds start and end with partial and leaves pos for the caller.
    """
    def _bind_endpoints(start, end):
        return partial(f, start, end)

    return _bind_endpoints

In [None]:
@annealer
def sched_lin(start, end, pos):
    """Linear interpolation: start at pos=0, end at pos=1."""
    return start + pos * (end - start)

In [None]:
# Decorated version used the same way: bind (start, end) first, then evaluate at pos=0.3 -> 1.3
sched_lin(1,2)(0.3)

In [None]:
@annealer
def sched_no(start, end, pos):
    """Constant schedule: always returns start; end and pos are ignored."""
    return start
@annealer
def sched_cos(start, end, pos):
    """Cosine-shaped transition: returns start at pos=0 and end at pos=1."""
    # The ramp rises from 0 (pos=0, cos(pi) = -1) to 2 (pos=1, cos(0) = 1).
    ramp = 1 + math.cos(math.pi * (1 - pos))
    return start + ramp * (end - start) / 2
@annealer
def sched_exp(start, end, pos):
    """Exponential interpolation: start * (end / start) ** pos."""
    # NOTE: start is the divisor, so a zero start fails for plain Python numbers.
    ratio = end / start
    return start * ratio ** pos

In [None]:
def cos_1cycle_anneal(start, high, end):
    """Return two cosine schedules in order: start -> high, then high -> end."""
    rising = sched_cos(start, high)
    falling = sched_cos(high, end)
    return [rising, falling]

In [None]:
# Monkey-patch: define Tensor.ndim as the number of dimensions (length of shape).
# NOTE(review): presumably here so matplotlib can plot tensors directly — confirm;
# recent torch versions may already provide ndim, making this redundant.
torch.Tensor.ndim = property(lambda x: len(x.shape))

In [None]:
# Compare the four annealing functions, each going from 2 down to 1e-2.
annealings = 'NO LINEAR COS EXP'.split()
fncs = [sched_no, sched_lin, sched_cos, sched_exp]

# x: 100 integer plot positions; y: 100 schedule positions from 0.01 to 1.
# Both are reused by a later plotting cell.
x = torch.arange(0, 100)
y = torch.linspace(0.01, 1, 100)

for fnc, name in  zip(fncs, annealings):
    sched = fnc(2, 1e-2)
    plt.plot(x, [sched(o) for o in y], label=name)
    
plt.legend();

In [None]:
def combine_scheds(pcts, scheds):
    """Chain several schedules, each covering a fraction of the total progress.

    Example:
        combine_scheds([0.3, 0.7], [sched_cos(0.3, 0.6), sched_cos(0.6, 0.2)])
        uses the first schedule for positions in the first 30% and the second
        schedule for the remaining 70%.

    Args:
        pcts: non-negative fractions, one per schedule, summing to 1
            (within floating-point tolerance).
        scheds: callables taking a position in [0, 1]; scheds[i] is used
            inside the i-th segment. Any number of segments is supported.

    Returns:
        A callable mapping a global position in [0, 1] to the value of the
        schedule owning that position, evaluated at the position rescaled to
        [0, 1] within its segment.

    Raises:
        AssertionError: if pcts does not sum to 1 or contains a negative entry.
    """
    pcts = list(pcts)
    # Compare with a tolerance: sum([0.1] * 10) is 0.9999999999999999, not 1.
    assert abs(sum(pcts) - 1.0) < 1e-6
    n_segments = len(pcts)
    pcts = torch.tensor([0] + pcts)
    assert torch.all(pcts >= 0)
    # Cumulative boundaries: segment i spans pcts[i] .. pcts[i + 1].
    pcts = torch.cumsum(pcts, 0)

    def _inner(pos):
        # Index of the last boundary at or below pos.
        idx = int((pos >= pcts).nonzero().max())
        # pos on the final boundary (pos == 1) still belongs to the last segment;
        # clamping also keeps pcts[idx + 1] in range for any number of segments.
        idx = min(idx, n_segments - 1)
        actual_pos = (pos - pcts[idx]) / (pcts[idx + 1] - pcts[idx])
        return scheds[idx](actual_pos)

    return _inner

In [None]:
# Two cosine phases: 0.3 -> 0.6 over the first 30% of progress, then 0.6 -> 0.2 over the remaining 70%.
sched = combine_scheds([0.3, 0.7], [sched_cos(0.3, 0.6), sched_cos(0.6, 0.2)]) 

In [None]:
# Plot the combined schedule at the positions in y (x and y come from the earlier plotting cell).
plt.plot(x, [sched(o) for o in y]);

In [None]:
# Callback factories handed to the Runner: lr/loss recording, accuracy stats,
# and scheduling of the 'lr' hyper-parameter with the combined schedule above.
cbfs = [Recorder,
       partial(AvgStatsCallback, accuracy),
       partial(ParamScheduler, 'lr', sched)]

In [None]:
# Fresh learner with lr=0.02 passed to get_model.
# NOTE(review): ParamScheduler rewrites 'lr' in every param group before each training
# batch, so this initial value is presumably overridden once training starts — confirm.
learn = create_learner(get_model_func(lr=0.02), loss_func, data)

In [None]:
# Runner built from callback factories (cb_funcs) rather than ready-made callback instances (cbs).
run = Runner(cb_funcs=cbfs)

In [None]:
# Train with the scheduled learning rate; the first argument is presumably the epoch count — confirm.
run.fit(3, learn)

In [None]:
# Learning rate recorded after each training batch.
# run.recorder is presumably how Runner exposes the Recorder instance — confirm.
run.recorder.plot_lr()

In [None]:
# Loss recorded after each training batch.
run.recorder.plot_loss()