In [22]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


In [23]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [24]:
from fastai.imports import *
from fastai.sgdr import Callback

from fastai.core import SimpleNet
from fastai.conv_learner import *

In [25]:
PATH = "data/cifar10/"

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
stats = (np.array([ 0.4914 ,  0.48216,  0.44653]), np.array([ 0.24703,  0.24349,  0.26159]))

def get_data(sz,bs):
    tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)
    return ImageClassifierData.from_paths(PATH, trn_name='train_', val_name='test_', tfms=tfms, bs=bs)

bs=128
data = get_data(32,bs)

In [26]:
net = SimpleNet([32*32*3, 40, 10])

In [27]:
learn = ConvLearner.from_model_data(net, data)

In [28]:
lr = 2e-2

In [29]:
learn.fit(lr, 1, use_swa=True)

epoch      trn_loss   val_loss   accuracy   swa_loss   swa_accuracy 
    0      1.774172   1.649104   0.413074   1.649104   0.413074  



[1.6491035, 0.4130735759493671, 1.6491035, 0.4130735759493671]

In [30]:
learn.swa_model

SimpleNet(
  (layers): ModuleList(
    (0): Linear(in_features=3072, out_features=40, bias=True)
    (1): Linear(in_features=40, out_features=10, bias=True)
  )
)

In [31]:
# verifies that it's equal to the first model's parameters after 1 epoch
for p1, p2 in zip(learn.model.parameters(), learn.swa_model.parameters()):
    print(p1 == p2)

Variable containing:
    1     1     1  ...      1     1     1
    1     1     1  ...      1     1     1
    1     1     1  ...      1     1     1
       ...          ⋱          ...       
    1     1     1  ...      1     1     1
    1     1     1  ...      1     1     1
    1     1     1  ...      1     1     1
[torch.cuda.ByteTensor of size 40x3072 (GPU 0)]

Variable containing:
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
[torch.cuda.ByteTensor of size 40 (GPU 0)]

Variable containing:

Columns 0 to 12 
    1     1     1     1     1     1     1     1     1     1     1     1     1
    1     1     1     1     1     1     1     1     1     1     1     1     1
    1     1     1     1     1     1     1     1     1     1     1     1     1
    1     1     1     1     1     1     1     1     1     1     1     1     1
    1     1     1     1     1     1     1     1     1     1     1     1     1
    1     1     1     

In [49]:
params = []

class SaveModelParams(Callback):
    def __init__(self, model):
        self.model = model
        
    def on_epoch_end(self, metrics):
        params.append([p.data.cpu().numpy() for p in self.model.parameters()])

In [50]:
net2 = SimpleNet([32*32*3, 40, 10])
learn2 = ConvLearner.from_model_data(net2, data)
lr = 2e-2
learn2.fit(lr, 3, use_swa=True, callbacks=[SaveModelParams(learn2.model)])

epoch      trn_loss   val_loss   accuracy   swa_loss   swa_accuracy 
    0      1.773514   1.691156   0.404371   1.691156   0.404371  
    1      1.737232   1.603997   0.432259                   
    2      1.68089    1.644307   0.417227   1.513425   0.463212  



[1.6443068, 0.41722705696202533, 1.513425, 0.4632120253164557]

In [51]:
print(len(params))

3


In [52]:
swa_model_params = [p.data.cpu().numpy() for p in learn2.swa_model.parameters()]

In [53]:
for p_model1, p_model2, p_model3, p_swa_model in zip(*params, swa_model_params):
    # check for equality up to a certain tolerance
    print(np.isclose(p_swa_model, np.mean(np.stack([p_model1, p_model2, p_model3]), axis=0)))

[[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]]
[ True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True]
[[ True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True

In [56]:
learn.save('test')
learn.load('test')

In [57]:
preds = learn2.predict()
preds_swa = learn2.predict(use_swa=True)
print(preds, preds_swa)

[[-1.70664 -1.05492 -3.87068 ... -3.25529 -2.24639 -1.74912]
 [-1.96794 -4.14807 -2.40735 ... -1.167   -4.95536 -3.72812]
 [-0.61625 -4.61556 -1.19078 ... -6.56123 -2.47553 -4.84968]
 ...
 [-3.75604 -2.18386 -3.12696 ... -2.74888 -4.26681 -0.53816]
 [-2.22326 -1.58195 -4.6265  ... -2.21887 -4.00163 -1.10925]
 [-2.19951 -0.71748 -5.376   ... -5.23997 -1.85072 -1.54509]] [[-1.28266 -1.90767 -3.08345 ... -2.65925 -1.99473 -1.40246]
 [-1.29219 -5.05158 -2.04013 ... -1.32583 -4.00925 -3.2551 ]
 [-0.4586  -5.46795 -1.64899 ... -5.80076 -2.18757 -5.69979]
 ...
 [-3.23602 -1.76301 -3.36458 ... -2.50504 -3.53825 -0.70707]
 [-1.65576 -2.66782 -3.46807 ... -1.68105 -3.06506 -1.33282]
 [-1.79727 -1.17396 -4.66499 ... -4.85271 -1.85708 -1.09175]]


In [37]:
params = []

class SaveModelParams(Callback):
    def __init__(self, model):
        self.model = model
        
    def on_epoch_end(self, metrics):
        params.append([p.data.cpu().numpy() for p in self.model.parameters()])

In [38]:
net = SimpleNet([32*32*3, 40, 10])
learn = ConvLearner.from_model_data(net, data)
lr = 2e-2
learn.fit(lr, 6, use_swa=True, swa_start=3, callbacks=[SaveModelParams(learn.model)])

epoch      trn_loss   val_loss   accuracy   swa_loss   swa_accuracy 
    0      1.771523   1.649534   0.413074  
    1      1.722752   1.678871   0.400218                   
    2      1.698014   1.648443   0.416337   1.648443   0.416337  
    3      1.691304   1.582978   0.437302                   
    4      1.706535   1.555867   0.439082                   
    5      1.656273   1.628426   0.435423   1.449969   0.484474  



[1.6284263, 0.4354232594936709, 1.4499689, 0.4844738924050633]

In [39]:
swa_model_params = [p.data.cpu().numpy() for p in learn.swa_model.parameters()]

In [40]:
print(len(params))

6


In [41]:
for *p_models, p_swa_model in zip(*params[2:], swa_model_params):
    # check for equality up to a certain tolerance
    print(np.isclose(p_swa_model, np.mean(np.stack(p_models), axis=0)))

[[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]]
[ True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True]
[[ True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True

In [43]:
net = SimpleNet([32*32*3, 40, 10])
net

SimpleNet(
  (layers): ModuleList(
    (0): Linear(in_features=3072, out_features=40, bias=True)
    (1): Linear(in_features=40, out_features=10, bias=True)
  )
)

In [44]:
from fastai.swa import collect_bn_modules

net_bn = []
net.apply(lambda m: collect_bn_modules(m, net_bn))
print(len(net_bn))

resnet_bn = []
resnet34().apply(lambda m: collect_bn_modules(m, resnet_bn))
print(len(resnet_bn))

0
36


In [46]:
from __future__ import absolute_import

'''Resnet for cifar dataset.
Ported form
https://github.com/facebook/fb.resnet.torch
and
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
(c) YANG, Wei
'''
import torch.nn as nn
import math


__all__ = ['preresnet']

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class PreResNet(nn.Module):

    def __init__(self, depth, num_classes=1000):
        super(PreResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
        n = (depth - 2) // 6

        block = Bottleneck if depth >=44 else BasicBlock

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                               bias=False)
        self.layer1 = self._make_layer(block, 16, n)
        self.layer2 = self._make_layer(block, 32, n, stride=2)
        self.layer3 = self._make_layer(block, 64, n, stride=2)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)  # 32x32
        x = self.layer2(x)  # 16x16
        x = self.layer3(x)  # 8x8
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def preresnet(**kwargs):
    """
    Constructs a ResNet model.
    """
    return PreResNet(**kwargs)

In [47]:
from fastai.swa import fix_batchnorm, collect_bn_modules

def test_momentum_preserved(model):
    bn_modules = []
    model.apply(lambda module: collect_bn_modules(module, bn_modules))
    momenta_before = [m.momentum for m in bn_modules]
    fix_batchnorm(preresnet110, data.trn_dl)
    
    for module, momentum_before in zip(bn_modules, momenta_before):
        assert module.momentum == momentum_before

In [48]:
model = preresnet110 = preresnet(depth=110, num_classes=10).cuda()
test_momentum_preserved(model)