In [None]:
from fastai.vision import *

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy and PyTorch (CPU and every CUDA device),
    exports `PYTHONHASHSEED`, and configures cuDNN for repeatable results.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only picked up by Python processes started after this point; the hash
    # seed of the already-running interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current one (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning can select different convolution algorithms from run to run,
    # which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

### No Accumulation (baseline)

In [None]:
# Baseline run: no `bs` is passed, so the DataBunch uses its default batch
# size, and no accumulation callback is attached to the learner.
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = create_cnn(data, models.vgg16_bn, metrics=accuracy)
learn.fit(1)  # train for a single epoch

In [None]:
# Display the batch size of the current `data` object, for reference when
# comparing against the bs=2 experiments.
data.batch_size

### Naive Accumulation

In [None]:
# Same architecture, but with mini-batches of 2 and the AccumulateStepper
# callback configured with n_step=32.
# NOTE(review): presumably the callback accumulates gradients over 32
# mini-batches before each optimizer step — confirm against its implementation.
data = ImageDataBunch.from_folder(path, bs=2)
learn = create_cnn(data, models.vgg16_bn, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=32)])
# Loss is summed rather than averaged per mini-batch.
# NOTE(review): presumably the callback is responsible for normalising it — TODO confirm.
learn.loss_func = CrossEntropyFlat(reduction='sum')
learn.fit(1)  # train for a single epoch

### Accumulation + BnFreeze

In [None]:
# Accumulation setup (bs=2, AccumulateStepper with n_step=32) plus the
# BnFreeze callback.
# NOTE(review): BnFreeze presumably stops batchnorm statistics from being
# updated on the tiny mini-batches — confirm against the callback's docs.
data = ImageDataBunch.from_folder(path, bs=2)
learn = create_cnn(data, models.vgg16_bn, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=32), BnFreeze])
# Summed (not averaged) loss, as in the other accumulation experiments.
learn.loss_func = CrossEntropyFlat(reduction='sum')
learn.fit(1)  # train for a single epoch

### Increase BatchNorm Momentum

In [None]:
def set_bn_mom(m:nn.Module, mom=0.9):
    "Set `momentum` to `mom` on every batchnorm layer in `m`, including `m` itself."
    # `modules()` yields `m` followed by all of its descendants, so a bare
    # batchnorm layer passed as `m` is updated as well as nested ones.
    # `bn_types` is not defined in this notebook; it is expected to come from
    # the fastai star import.
    for l in m.modules():
        if isinstance(l, bn_types):
            l.momentum = mom

In [None]:
# Accumulation setup (bs=2, AccumulateStepper with n_step=32) with the
# momentum of every batchnorm layer in the model set to 0.9 before training.
data = ImageDataBunch.from_folder(path, bs=2)
learn = create_cnn(data, models.vgg16_bn, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=32)])
# Summed (not averaged) loss, as in the other accumulation experiments.
learn.loss_func = CrossEntropyFlat(reduction='sum')
set_bn_mom(learn.model, mom=0.9)
learn.fit(1)  # train for a single epoch

### Decrease BatchNorm Momentum

In [None]:
# Accumulation setup (bs=2, AccumulateStepper with n_step=32) with the
# momentum of every batchnorm layer in the model set to 0.01 before training.
data = ImageDataBunch.from_folder(path, bs=2)
learn = create_cnn(data, models.vgg16_bn, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=32)])
# Summed (not averaged) loss, as in the other accumulation experiments.
learn.loss_func = CrossEntropyFlat(reduction='sum')
set_bn_mom(learn.model, mom=0.01)
learn.fit(1)  # train for a single epoch

### InstanceNorm

In [None]:
def bn2instance(bn):
    """Build an InstanceNorm layer that reuses the state of batchnorm layer `bn`.

    Args:
        bn: An `nn.BatchNorm1d`, `nn.BatchNorm2d` or `nn.BatchNorm3d` module.

    Returns:
        The affine `nn.InstanceNorm{1,2,3}d` of matching dimensionality, moved
        to the device of `bn.weight`. Its `weight` and `bias` are the same
        Parameter objects as `bn`'s, and `eps` / `momentum` are copied over.

    Raises:
        TypeError: If `bn` is not one of the three batchnorm classes
            (previously this surfaced as an UnboundLocalError).
    """
    # Checked in the same order as before: 1d, 2d, 3d.
    norm_pairs = (
        (nn.BatchNorm1d, nn.InstanceNorm1d),
        (nn.BatchNorm2d, nn.InstanceNorm2d),
        (nn.BatchNorm3d, nn.InstanceNorm3d),
    )
    for bn_cls, inst_cls in norm_pairs:
        if isinstance(bn, bn_cls):
            inst = inst_cls(bn.num_features, affine=True)
            break
    else:
        raise TypeError(
            f"bn2instance expects a BatchNorm1d/2d/3d layer, got {type(bn).__name__}")

    # Share (not copy) the learnable affine parameters with the source layer.
    inst.weight = bn.weight
    inst.bias = bn.bias
    # NOTE(review): the running statistics are attached as non-trainable
    # Parameters rather than buffers, and the layer is built without
    # `track_running_stats` — confirm whether the copied statistics are
    # actually meant to be used at inference time.
    inst.running_mean = nn.Parameter(bn.running_mean, requires_grad=False)
    inst.running_var = nn.Parameter(bn.running_var, requires_grad=False)
    inst.momentum = bn.momentum
    inst.eps = bn.eps
    return (inst).to(bn.weight.device)

In [None]:
def convert_bn(list_mods, func=bn2instance):
    """Replace every batchnorm layer reachable from `list_mods` with `func(layer)`.

    Batchnorm layers that are direct entries of `list_mods` are replaced in the
    list. Batchnorm layers nested inside any other module are replaced in place
    on their parent, so every container keeps its own class, child names and
    `forward`. Rebuilding containers as `nn.Sequential` (the previous approach
    for "Sequential"/"BasicBlock") would discard a custom `forward`, and
    containers of any other class were never visited at all.

    Args:
        list_mods: Mutable list of `nn.Module` objects; modified in place.
            Nested containers inside it are mutated in place as well.
        func: Callable taking a batchnorm layer and returning its replacement.

    Returns:
        The same `list_mods` object, after conversion.
    """
    for i, mod in enumerate(list_mods):
        if isinstance(mod, bn_types):
            list_mods[i] = func(mod)
        else:
            _convert_bn_children(mod, func)
    return list_mods


def _convert_bn_children(module, func):
    "Recursively swap batchnorm children of `module` for `func(child)`, in place."
    # Snapshot the children first so reassigning them cannot disturb iteration.
    for name, child in list(module.named_children()):
        if isinstance(child, bn_types):
            setattr(module, name, func(child))
        else:
            _convert_bn_children(child, func)

In [None]:
# Accumulation setup (bs=2, AccumulateStepper with n_step=32) for the
# InstanceNorm experiment; the model is modified in the following cells
# before any training happens.
data = ImageDataBunch.from_folder(path, bs=2)
learn = create_cnn(data, models.vgg16_bn, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=32)])
# Summed (not averaged) loss, as in the other accumulation experiments.
learn.loss_func = CrossEntropyFlat(reduction='sum')

In [None]:
# Rebuild the model from its top-level children with batchnorm layers
# converted by `bn2instance`.
learn.model = nn.Sequential(*convert_bn(list(learn.model.children()), bn2instance))

In [None]:
# Freeze the learner after swapping in the converted model.
# NOTE(review): which layers end up frozen depends on the learner's layer
# groups, which were set up before the model was replaced — confirm.
learn.freeze()

In [None]:
# Run the learning-rate finder on the converted model.
learn.lr_find()

In [None]:
# Train the InstanceNorm variant for a single epoch.
learn.fit(1)

### GroupNorm

In [172]:
def bn2group(bn, num_groups=4):
    """Build a GroupNorm layer that reuses the state of batchnorm layer `bn`.

    Args:
        bn: A batchnorm module exposing `num_features`, `weight`, `bias`,
            `running_mean`, `running_var`, `momentum` and `eps`.
        num_groups: Preferred number of groups (default 4, the value that was
            previously hard-coded). `nn.GroupNorm` requires the channel count
            to be divisible by the group count, so when `bn.num_features` is
            not divisible by `num_groups` the largest smaller divisor is used
            instead of raising.

    Returns:
        An affine `nn.GroupNorm` on the device of `bn.weight` whose `weight`
        and `bias` are the same Parameter objects as `bn`'s.

    Raises:
        ValueError: If `num_groups` is smaller than 1.
    """
    if num_groups < 1:
        raise ValueError(f"num_groups must be >= 1, got {num_groups}")

    # Largest group count <= num_groups that divides the channel count evenly;
    # always terminates because 1 divides everything.
    groups = min(num_groups, bn.num_features)
    while bn.num_features % groups:
        groups -= 1

    groupnorm = nn.GroupNorm(groups, bn.num_features, affine=True)
    # Share (not copy) the learnable affine parameters with the source layer.
    groupnorm.weight = bn.weight
    groupnorm.bias = bn.bias
    # Kept so the returned layer carries the same attributes as before.
    # NOTE(review): GroupNorm normalises with per-sample group statistics and
    # does not appear to read running stats or momentum — confirm these
    # attributes are still wanted.
    groupnorm.running_mean = nn.Parameter(bn.running_mean, requires_grad=False)
    groupnorm.running_var = nn.Parameter(bn.running_var, requires_grad=False)
    groupnorm.momentum = bn.momentum
    groupnorm.eps = bn.eps
    return (groupnorm).to(bn.weight.device)

In [173]:
# Accumulation setup (bs=2, AccumulateStepper with n_step=32) for the
# GroupNorm experiment; the model is modified in the following cells before
# any training happens.
data = ImageDataBunch.from_folder(path, bs=2)
learn = create_cnn(data, models.vgg16_bn, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=32)])
# Summed (not averaged) loss, as in the other accumulation experiments.
learn.loss_func = CrossEntropyFlat(reduction='sum')

In [174]:
# Rebuild the model from its top-level children with batchnorm layers
# converted by `bn2group`.
learn.model = nn.Sequential(*convert_bn(list(learn.model.children()), bn2group))

In [175]:
# Freeze the learner after swapping in the converted model.
# NOTE(review): which layers end up frozen depends on the learner's layer
# groups, which were set up before the model was replaced — confirm.
learn.freeze()

In [176]:
# Train the GroupNorm variant for a single epoch.
learn.fit(1)

epoch,train_loss,valid_loss,accuracy,time
1,0.070629,0.093760,0.985770,01:00
