In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path

from fastai import *
from fastai.vision import *
from ipyexperiments import *
from fastai.basic_train import *

Maybe use `loss(reduction="sum")` and average before `real_step()`

In [3]:
class AccumulateOptimWrapper(OptimWrapper):
    def step(self):          pass
    def zero_grad(self):      pass
    def real_step(self):      super().step()
    def real_zero_grad(self): super().zero_grad()
        
def acc_create_opt(self, lr:Floats, wd:Floats=0.):
        "Create optimizer with `lr` learning rate and `wd` weight decay."
        self.opt = AccumulateOptimWrapper.create(self.opt_func, lr, self.layer_groups,
                                         wd=wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
        
@dataclass
class AccumulateStep(LearnerCallback):
    """
    Does accumlated step every nth step by accumulating gradients
    """
    def __init__(self, learn:Learner, n_step:int = 1):
        super().__init__(learn)
        self.n_step = n_step
 
    def on_train_begin(self, **kwargs):
        "check if loss is reduction"
        if self.loss_func.reduction == "mean":
             print("For better gradients consider 'reduction=sum'")
        
    def on_epoch_begin(self, **kwargs):
        "init samples and batches, change optimizer"
        self.acc_samples = 0
        self.acc_batches = 0
        
    def on_batch_begin(self, last_input, last_target, **kwargs):
        "accumulate samples and batches"
        self.acc_samples += last_input.shape[0]
        self.acc_batches += 1
#         print(f"At batch {self.acc_batches}")
        
    def on_backward_end(self, **kwargs):
        "step if number of desired batches accumulated, reset samples"
        if (self.acc_batches % self.n_step) == 0:
            for p in (self.learn.model.parameters()):
                if p.requires_grad: p.grad.div_(self.acc_samples)
    
#             print(f"Stepping at batch: {self.acc_batches}")
            self.learn.opt.real_step()
            self.learn.opt.real_zero_grad()
            self.acc_samples = 0
    
    def on_epoch_end(self, **kwargs):
        "step the rest of the accumulated grads"
        self.learn.opt.real_step()
        self.learn.opt.real_zero_grad()

### MNIST batch size = 4, no accumulation, 4 epochs

`effective batch size = 4`

In [5]:
import sys
sys.path.append("../dev")
from data_utils import seed_everything
from fastai.basic_train import Learner; original_create_opt = Learner.create_opt

In [7]:
# original optimwrapper
Learner.create_opt = original_create_opt

In [8]:
# seed everything for reproducibility
seed_everything(42)

In [9]:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path, bs=4)

In [10]:
learn = create_cnn(data, models.resnet18, metrics=accuracy)

In [11]:
learn.loss_func = CrossEntropyFlat(reduction='mean')

In [12]:
learn.fit(4)

epoch,train_loss,valid_loss,accuracy
1,0.396730,0.495012,0.829735
2,0.329213,0.482487,0.856232
3,0.323470,0.328319,0.876349
4,0.319317,9.076464,0.852797


### MNIST batch size=4, accumulate every n_step=8, 4 epochs

`effective batch size = 32 (bs x n_step)`

In [25]:
# seed everything for reproducibility
seed_everything(42)

In [26]:
# monkey patch 
Learner.create_opt = acc_create_opt

In [27]:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path, bs=32)

In [28]:
learn = create_cnn(data, models.resnet18, metrics=accuracy,
                   callback_fns=[partial(AccumulateStep, n_step=8)])

In [29]:
learn.callback_fns

[fastai.basic_train.Recorder,
 functools.partial(<class '__main__.AccumulateStep'>, n_step=8)]

In [30]:
learn.loss_func = CrossEntropyFlat(reduction='sum')

In [31]:
learn.fit(4)

epoch,train_loss,valid_loss,accuracy
1,5.950532,4.001670,0.958783
2,5.227745,3.383670,0.962218
3,5.597690,4.019064,0.957311
4,5.486883,3.625966,0.963199
