## Acknowledgements

This project uses the [FastResNet](https://github.com/pytorch/ignite) architecture.  
Credit goes to the original authors of this implementation.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class GhostBatchNorm(nn.BatchNorm2d):
    """
    From : https://github.com/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb

    Batch norm seems to work best with batch size of around 32. The reasons presumably have to do 
    with noise in the batch statistics and specifically a balance between a beneficial regularising effect 
    at intermediate batch sizes and an excess of noise at small batches.
    
    Our batches are of size 512 and we can't afford to reduce them without taking a serious hit on training times, 
    but we can apply batch norm separately to subsets of a training batch. This technique, known as 'ghost' batch 
    norm, is usually used in a distributed setting but is just as useful when using large batches on a single node. 
    It isn't supported directly in PyTorch but we can roll our own easily enough.

    Implementation notes:
        In training, an (N, C, H, W) input is reshaped to
        (N // num_splits, C * num_splits, H, W) and normalized with
        ``F.batch_norm``. Sample ``n`` is therefore normalized together with
        every other sample that has the same ``n % num_splits``.
        Each ghost batch keeps its own running statistics, stored as
        ``num_splits`` consecutive chunks of length ``num_features`` in
        ``running_mean`` / ``running_var``. In eval mode the chunks are
        averaged.

    Args:
        num_features: number of channels C of the (N, C, H, W) input.
        num_splits: number of ghost batches per training batch. The training
            batch size must be divisible by it.
        eps: forwarded to ``F.batch_norm``.
        momentum: forwarded to ``F.batch_norm`` (running-stat update rate).
        weight: if False, the affine scale (initialized to 1) gets
            ``requires_grad=False``.
        bias: if False, the affine shift (initialized to 0) gets
            ``requires_grad=False``.

    Raises:
        ValueError: in training, if the batch size is not divisible by
            ``num_splits``.
    """
    def __init__(self, num_features, num_splits, eps=1e-05, momentum=0.1, weight=True, bias=True):
        super(GhostBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum)
        self.weight.data.fill_(1.0)
        self.bias.data.fill_(0.0)
        self.weight.requires_grad = weight
        self.bias.requires_grad = bias
        self.num_splits = num_splits
        # Replace the parent's length-C buffers with one chunk per ghost batch.
        self.register_buffer('running_mean', torch.zeros(num_features * self.num_splits))
        self.register_buffer('running_var', torch.ones(num_features * self.num_splits))

    def train(self, mode=True):
        # On the training -> eval transition, collapse the per-split statistics
        # into their mean. The mean is replicated num_splits times so the
        # buffer shapes (and state_dict layout) stay unchanged.
        if (self.training is True) and (mode is False):
            self.running_mean = torch.mean(
                self.running_mean.view(self.num_splits, self.num_features), dim=0
            ).repeat(self.num_splits)
            self.running_var = torch.mean(
                self.running_var.view(self.num_splits, self.num_features), dim=0
            ).repeat(self.num_splits)
        return super(GhostBatchNorm, self).train(mode)

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            if N % self.num_splits != 0:
                raise ValueError(
                    f"GhostBatchNorm: batch size {N} is not divisible by "
                    f"num_splits={self.num_splits}"
                )
            # reshape (not view) so non-contiguous inputs, e.g. channels_last,
            # are handled. The ghost-batch index becomes part of the channel axis.
            out = F.batch_norm(
                input.reshape(-1, C * self.num_splits, H, W),
                self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps)
            return out.reshape(N, C, H, W)
        # Average the per-split statistics here rather than taking the first
        # chunk. The result matches the old behaviour once train(False) has
        # averaged the buffers, and is still correct if un-averaged buffers
        # were loaded while the module was already in eval mode.
        mean = self.running_mean.view(self.num_splits, self.num_features).mean(dim=0)
        var = self.running_var.view(self.num_splits, self.num_features).mean(dim=0)
        return F.batch_norm(
            input, mean, var,
            self.weight, self.bias, False, self.momentum, self.eps)

In [3]:
class IdentityResidualBlock(nn.Module):
    """Two conv -> GhostBatchNorm -> CELU stages wrapped by an identity skip.

    The forward pass returns ``res2(res1(x)) + x``. Both stages map
    ``num_channels`` to ``num_channels``, so the channel count is preserved.

    Args:
        num_channels: input and output channel count.
        conv_ksize: kernel size of both convolutions.
        conv_pad: padding of both convolutions.
        gbn_num_splits: ``num_splits`` for each GhostBatchNorm.
    """

    def __init__(self, num_channels, 
                 conv_ksize=3, conv_pad=1,
                 gbn_num_splits=16):
        super(IdentityResidualBlock, self).__init__()

        def conv_bn_celu():
            # Stride-1, bias-free conv; the batch-norm scale is frozen (weight=False).
            return nn.Sequential(
                Conv2d(num_channels, num_channels,
                       kernel_size=conv_ksize, padding=conv_pad,
                       stride=1, bias=False),
                GhostBatchNorm(num_channels, num_splits=gbn_num_splits, weight=False),
                nn.CELU(alpha=0.3),
            )

        # Built in the same order as before so parameter initialization and
        # state_dict keys (res1.*, res2.*) are unchanged.
        self.res1 = conv_bn_celu()
        self.res2 = conv_bn_celu()

    def forward(self, x):
        return self.res2(self.res1(x)) + x

In [4]:
# We override conv2d to get proper padding for kernel size = 2   
class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` with one-sided padding for 2x2 kernels.

    With an even kernel, symmetric padding cannot keep the spatial size
    (out = in + 2p - k + 1). For ``kernel_size == (2, 2)`` with numeric
    padding ``(pad_h, pad_w)``, the input is therefore padded only on the
    bottom (by ``pad_h``) and on the right (by ``pad_w``) with ``F.pad``, and
    the built-in conv padding is set to zero. String paddings such as
    ``'same'`` are left to PyTorch, which handles even kernels itself. All
    other kernel sizes behave exactly like ``nn.Conv2d``.
    """

    def __init__(self, *args, **kwargs):
        super(Conv2d, self).__init__(*args, **kwargs)
        # F.pad tuple (W_left, W_right, H_top, H_bottom), or None for plain nn.Conv2d.
        self.ksize_2_padding = None
        if self.kernel_size == (2, 2) and not isinstance(self.padding, str):
            pad_h, pad_w = self.padding
            # F.pad lists pairs starting from the LAST dim (W), then H.
            self.ksize_2_padding = (0, pad_w, 0, pad_h)
            self.padding = (0, 0)

    def forward(self, x):
        if self.ksize_2_padding is not None:
            return self.ksize_2_forward(x)
        return super(Conv2d, self).forward(x)

    def ksize_2_forward(self, x):
        x = F.pad(x, pad=self.ksize_2_padding)
        return super(Conv2d, self).forward(x)

In [5]:
class FastResNet(nn.Module):
    """Compact residual CNN built from Conv2d / GhostBatchNorm / CELU stages.

    Pipeline:
        prep:   3 -> fmap_factor conv stage
        layer1: conv to 2*fmap_factor, 2x max-pool, GBN, CELU, IdentityResidualBlock
        layer2: conv to 4*fmap_factor, 2x max-pool, GBN, CELU
        layer3: conv to 8*fmap_factor, 2x max-pool, GBN, CELU, IdentityResidualBlock
        pool:   4x4 max-pool
        classifier: Flatten + Linear(8*fmap_factor, num_classes), then the
            logits are multiplied by ``classif_scale``.

    The Linear layer expects exactly 8*fmap_factor features, so the map must
    be 1x1 after the final 4x4 pool. With three 2x down-samplings this holds
    for 32x32 inputs. Presumably the model targets CIFAR-sized images; confirm
    before using other resolutions.

    Args:
        num_classes: number of output logits.
        fmap_factor: channel width of the first stage (later stages use 2x/4x/8x).
        conv_ksize: kernel size of every convolution.
        conv_pad: padding of every convolution.
        gbn_num_splits: ghost batches per training batch. The default 512 // 32
            gives ghost batches of 32 for a 512-sample batch. The training
            batch size must be divisible by it.
        classif_scale: multiplier applied to the classifier output.
    """

    def __init__(self, num_classes=10, 
                 fmap_factor=64, conv_ksize=3, conv_pad=1, 
                 gbn_num_splits=512 // 32,                  
                 classif_scale=0.0625):
        super(FastResNet, self).__init__()

        self.prep = nn.Sequential(
            Conv2d(3, fmap_factor, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),
            GhostBatchNorm(fmap_factor, num_splits=gbn_num_splits, weight=False),
            nn.CELU(alpha=0.3)
        )

        self.layer1 = nn.Sequential(
            Conv2d(fmap_factor, fmap_factor * 2, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),
            nn.MaxPool2d(kernel_size=2),
            GhostBatchNorm(fmap_factor * 2, num_splits=gbn_num_splits, weight=False),
            nn.CELU(alpha=0.3),
            IdentityResidualBlock(fmap_factor * 2,
                                  conv_ksize=conv_ksize, conv_pad=conv_pad, 
                                  gbn_num_splits=gbn_num_splits)
        )

        self.layer2 = nn.Sequential(
            Conv2d(fmap_factor * 2, fmap_factor * 4, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),
            nn.MaxPool2d(kernel_size=2),
            GhostBatchNorm(fmap_factor * 4, num_splits=gbn_num_splits, weight=False),
            nn.CELU(alpha=0.3),
        )

        self.layer3 = nn.Sequential(
            Conv2d(fmap_factor * 4, fmap_factor * 8, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),
            nn.MaxPool2d(kernel_size=2),
            GhostBatchNorm(fmap_factor * 8, num_splits=gbn_num_splits, weight=False),
            nn.CELU(alpha=0.3),
            IdentityResidualBlock(fmap_factor * 8, 
                                  conv_ksize=conv_ksize, conv_pad=conv_pad, 
                                  gbn_num_splits=gbn_num_splits)
        )

        self.pool = nn.MaxPool2d(kernel_size=4)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(fmap_factor * 8, num_classes)
        )
        # Previously hard-coded to 0.0625, which silently ignored classif_scale.
        # The non-persistent buffer follows .to()/.half() but stays out of
        # state_dict, so existing checkpoints still load.
        self.register_buffer('scale', torch.tensor(float(classif_scale)), persistent=False)

    def forward(self, x):
        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        y = self.classifier(x)
        return y * self.scale

In [6]:
# Build a 10-class network with fmap_factor=64 (the constructor's default width).
model = FastResNet(num_classes=10, fmap_factor=64)