# ResNet
paper: https://arxiv.org/pdf/1512.03385.pdf

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
import sys
sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))

from stateful_optim import *

Residual block equation:

y = F(x, {Wi}) + x

\
Each res block has two "subModels", a F(x, {Wi}) and x

In [3]:
#export
def get_basic_block(i, o, s):
    '''Get basic ResNet block.
        i: channel in
        o: channel out
        s: stride size
    '''
    return Sequential(Conv(i, o, 3, s),
                      BatchNorm(o),
                      ReLU(),
                      Conv(o, o*4, 3, 1),
                      BatchNorm(o))

def get_bottleneck(i, o):
    '''Get bottleneck ResNet block.
        i: channel in
        o: channel out
    '''
    return Sequential(Conv(i, o, 1, 1),
                      BatchNorm(o),
                      ReLU(),
                      Conv(o, o, 3, 1),
                      BatchNorm(o),
                      ReLU(),
                      Conv(o, o*4, 1, 1),
                      BatchNorm(o))

In [4]:
get_basic_block(32, 64, 1)

(Model)
    Conv(32, 64, 3, 1)
    BatchNorm()
    ReLU()
    Conv(64, 256, 3, 1)
    BatchNorm()

In [5]:
get_bottleneck(32, 64)

(Model)
    Conv(32, 64, 1, 1)
    BatchNorm()
    ReLU()
    Conv(64, 64, 3, 1)
    BatchNorm()
    ReLU()
    Conv(64, 256, 1, 1)
    BatchNorm()

In [6]:
#export
class ResLayer(Module):
    def __init__(self, i, o, s, bottleneck):
        '''Get ResLayer (almost a ResBlock but not including the final activation).
            i: channel in
            o: channel out
            s: stride size
            bottleneck: boolean of whether the resblock is basic or bottleneck
        '''
        super().__init__()
        self.i, self.o, self.s, self.bottleneck = i, o, s, bottleneck
        self.x_layer = Identity()
        self.Fx_layer = get_bottleneck(i, o) if bottleneck else get_basic_block(self.i, self.o, self.s)
        
    def fwd(self, inp):
        self.fx = self.Fx_layer(inp)
        self.x = self.x_layer(inp)
        self.out = self.fx + self.x
        return self.out
    
    def bwd(self, out, inp):
        self.fx.g = out.g
        self.x.g = out.g
        self.Fx_layer.backward()
        self.x_layer.backward()
    
    def parameters(self):
        for sub_model in [self.Fx_layer, self.x_layer]:
            for param in sub_model.parameters():
                yield param

In [7]:
#export
class ResBlock(SubModel):
    def __init__(self, i, o, s, bottleneck):
        '''ResBlock (ResLayer + Activation).
            i: channel in
            o: channel out
            s: stride size
            bottleneck: boolean of whether the resblock is basic or bottleneck
        '''
        super().__init__()
        self.i, self.o, self.s, self.bottleneck = i, o, s, bottleneck
        self.sub_model = Sequential(ResLayer(i, o, s, bottleneck),
                                    ReLU())
    
    def __repr__(self, t=''):
        return f"{t+'    '}{'Bottleneck' if self.bottleneck else 'BasicBlock'}({self.i}, {self.o}, {self.s})"

In [8]:
ResBlock(32, 64, 1, True)

    Bottleneck(32, 64, 1)

In [9]:
ResBlock(32, 64, 1, False)

    BasicBlock(32, 64, 1)

In [10]:
#export
class ResBlockGroup(SubModel):
    def __init__(self, i, o, num_blocks, bottleneck):
        '''Group of ResBlocks.
            i: channel in
            o: channel out
            num_blocks: number of resblockss
            bottleneck: boolean of whether the resblocks are basic or bottleneck
        '''
        layers = [ResBlock(i, o, 2, bottleneck)]
        for _ in range(num_blocks-1):
            layers.append(ResBlock(o, o, 1, bottleneck))
        self.sub_model = Sequential(*layers)
        
    def __repr__(self, t=''):
        return f"{self.sub_model.__repr__(t+'    ')}"

In [11]:
#export
def get_res_head(in_shape, o):
    '''ResNet head (before ResBlocks).
        in_shape: input shape before conv bn relu and pool
        o: channel out
    '''
    return [Reshape(in_shape),
            Conv(in_shape[0], o, 7, 2),
            BatchNorm(o),
            ReLU(),
            MaxPool(3, 2, 1)]

def get_res_tail(h, o):
    '''ResNet tail (after ResBlocks).
        h: number of hidden cells
        o: channel out
    '''
    return [AvgPool(1, 1, 0),
            Flatten(),
            Linear(h, o, True)]

In [12]:
Sequential(get_res_head((1, 28, 28), 64))

(Model)
    Reshape(1, 28, 28)
    Conv(1, 64, 7, 2)
    BatchNorm()
    ReLU()
    MaxPool(3, 2)

In [13]:
Sequential(get_res_tail(1024, 64))

(Model)
    AvgPool(1, 1)
    Flatten()
    Linear(1024, 64)

In [14]:
#export
name2depths = {18:  [2, 2, 2,  2],
               34:  [3, 4, 6,  3],
               50:  [3, 4, 6,  3],
               101: [3, 4, 23, 3],
               152: [3, 8, 36, 3]}

class ResNet(SubModel):
    def __init__(self, n_layer, in_shape=(3,28,28), out=100):
        '''ResNet model that is able to create ResNets with different number of layers adaptively.
            n_layer: number of resnet layers (18, 34...)
            in_shape: input image shape
            out: output number of labels
        '''
        self.name = f'ResNet {n_layer}'
        bottleneck = n_layer > 34
        channels = [64, 128, 256, 512]
        depths = name2depths[n_layer]
        
        head = get_res_head(in_shape, channels[0])
        body = [ResBlockGroup(channels[0], channels[0], depths[0], bottleneck)]
        for i, o, depth in zip(channels, channels[1:], depths[1:]):
            body.append(ResBlockGroup(i, o, depth, bottleneck))
        tail = get_res_tail(channels[-1]*(4 if bottleneck else 1), out)
        self.sub_model = Sequential(head + body + tail)
    
    def __repr__(self, t=''):
        return f'{t}{self.sub_model}'

# Tests

In [15]:
ResNet(18, (1, 48, 48))

(Model)
    Reshape(1, 48, 48)
    Conv(1, 64, 7, 2)
    BatchNorm()
    ReLU()
    MaxPool(3, 2)
        BasicBlock(64, 64, 2)
        BasicBlock(64, 64, 1)
        BasicBlock(64, 128, 2)
        BasicBlock(128, 128, 1)
        BasicBlock(128, 256, 2)
        BasicBlock(256, 256, 1)
        BasicBlock(256, 512, 2)
        BasicBlock(512, 512, 1)
    AvgPool(1, 1)
    Flatten()
    Linear(512, 100)

In [16]:
ResNet(34, (3, 128, 128), 25)

(Model)
    Reshape(3, 128, 128)
    Conv(3, 64, 7, 2)
    BatchNorm()
    ReLU()
    MaxPool(3, 2)
        BasicBlock(64, 64, 2)
        BasicBlock(64, 64, 1)
        BasicBlock(64, 64, 1)
        BasicBlock(64, 128, 2)
        BasicBlock(128, 128, 1)
        BasicBlock(128, 128, 1)
        BasicBlock(128, 128, 1)
        BasicBlock(128, 256, 2)
        BasicBlock(256, 256, 1)
        BasicBlock(256, 256, 1)
        BasicBlock(256, 256, 1)
        BasicBlock(256, 256, 1)
        BasicBlock(256, 256, 1)
        BasicBlock(256, 512, 2)
        BasicBlock(512, 512, 1)
        BasicBlock(512, 512, 1)
    AvgPool(1, 1)
    Flatten()
    Linear(512, 25)

In [17]:
ResNet(50)

(Model)
    Reshape(3, 28, 28)
    Conv(3, 64, 7, 2)
    BatchNorm()
    ReLU()
    MaxPool(3, 2)
        Bottleneck(64, 64, 2)
        Bottleneck(64, 64, 1)
        Bottleneck(64, 64, 1)
        Bottleneck(64, 128, 2)
        Bottleneck(128, 128, 1)
        Bottleneck(128, 128, 1)
        Bottleneck(128, 128, 1)
        Bottleneck(128, 256, 2)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 512, 2)
        Bottleneck(512, 512, 1)
        Bottleneck(512, 512, 1)
    AvgPool(1, 1)
    Flatten()
    Linear(2048, 100)

In [18]:
ResNet(101)

(Model)
    Reshape(3, 28, 28)
    Conv(3, 64, 7, 2)
    BatchNorm()
    ReLU()
    MaxPool(3, 2)
        Bottleneck(64, 64, 2)
        Bottleneck(64, 64, 1)
        Bottleneck(64, 64, 1)
        Bottleneck(64, 128, 2)
        Bottleneck(128, 128, 1)
        Bottleneck(128, 128, 1)
        Bottleneck(128, 128, 1)
        Bottleneck(128, 256, 2)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottleneck(256, 256, 1)
        Bottl

In [19]:
hyper_params = {'learning_rate':0.001, 'weight_decay':1e-4}
schedule = combine_schedules([0.4, 0.6], one_cycle_cos(0.0001, 0.003, 0.0001))

data_bunch = get_data_bunch(*get_mnist_data(), batch_size=64)
model = Sequential(ResNet(18))
loss_fn = CrossEntropy()
optimizer = adam_opt(model, **hyper_params)
callbacks = [StatsLogging(), Recorder()]

learner = Learner(data_bunch, model, loss_fn, optimizer, callbacks)
print(learner)

(DataBunch) 
    (DataLoader) 
        (Dataset) x: (50000, 784), y: (50000,)
        (Sampler) total: 50000, batch_size: 64, shuffle: True
    (DataLoader) 
        (Dataset) x: (10000, 784), y: (10000,)
        (Sampler) total: 10000, batch_size: 128, shuffle: False
(Model)
(Model)
    Reshape(3, 28, 28)
    Conv(3, 64, 7, 2)
    BatchNorm()
    ReLU()
    MaxPool(3, 2)
        BasicBlock(64, 64, 2)
        BasicBlock(64, 64, 1)
        BasicBlock(64, 128, 2)
        BasicBlock(128, 128, 1)
        BasicBlock(128, 256, 2)
        BasicBlock(256, 256, 1)
        BasicBlock(256, 512, 2)
        BasicBlock(512, 512, 1)
    AvgPool(1, 1)
    Flatten()
    Linear(512, 100)
(CrossEntropy)
(StatefulOpt) steppers: ['adam', 'l2_reg'], stats: ['ExpWeightedGrad', 'ExpWeightedSqrGrad', 'StepCount']
(Callbacks) ['TrainEval', 'StatsLogging', 'Recorder']
