In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path

from fastai import *
from fastai.vision import *
from ipyexperiments import *

Maybe use `loss(reduction="sum")` and average `on_epoch_end`

In [3]:
class myOptimWrapper(OptimWrapper):
    def step(self):          pass
    def zero_grad(self):      pass
    def real_step(self):      super().step()
    def real_zero_grad(self): super().zero_grad()

In [4]:
@dataclass
class StepEpochEnd(Callback):
    learn:Learner
    def on_epoch_end(self, **kwargs):
        print("real step and zero grad")
        self.learn.opt.real_step()
        self.learn.opt.real_zero_grad()

In [5]:
@dataclass
class ShowGrads(Callback):
    learn:Learner
    def on_loss_begin(self, **kwargs):
        print("before batch loss:")
        last_layers = self.learn.layer_groups[-1]
        last_layer = last_layers[-1]
        print(last_layer.weight)
    
    def on_epoch_end(self, **kwargs):
        print("on epoch end:")
        last_layers = self.learn.layer_groups[-1]
        last_layer = last_layers[-1]
        print(last_layer.weight)

In [6]:
def my_create_opt(self, lr:Floats, wd:Floats=0.)->None:
    "Create optimizer with `lr` learning rate and `wd` weight decay."
    self.opt = myOptimWrapper.create(self.opt_func, lr, self.layer_groups,
                                     wd=wd, true_wd=self.true_wd, bn_wd=self.bn_wd)

In [7]:
Learner.create_opt = my_create_opt

In [8]:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)

In [15]:
learn = create_cnn(data, models.resnet18, metrics=accuracy,
                   callback_fns=[partial(StepEpochEnd)])

In [16]:
learn.callback_fns

[fastai.basic_train.Recorder,
 functools.partial(<class '__main__.StepEpochEnd'>)]

In [21]:
learn.opt.real_step??

In [17]:
learn.loss_func = CrossEntropyFlat(reduction='mean')

In [18]:
learn.fit_one_cycle(10)

epoch,train_loss,valid_loss,accuracy
1,1.011151,0.898423,0.408734
2,0.714787,0.592207,0.676153
3,0.417709,0.342188,0.866045
4,0.316321,0.267099,0.904318
5,0.286497,0.236401,0.911678
6,0.266473,0.214437,0.921492
7,0.237963,0.201486,0.923454
8,0.234175,0.192990,0.931305
9,0.245454,0.187826,0.931796
10,0.233982,0.197862,0.928361


real step and zero grad
real step and zero grad
real step and zero grad
real step and zero grad
real step and zero grad
real step and zero grad
real step and zero grad
real step and zero grad
real step and zero grad
real step and zero grad


In [47]:
for p in (learn.model.parameters()):
    if p.requires_grad: p.grad.div_(10) 

In [28]:
p = next(learn.model.parameters())

In [33]:
p.requires_grad

False