In [1]:
# Based on the wide resnet implementation by xternalz: https://github.com/xternalz/WideResNet-pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy
import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

Construct the basic residual block of the ResNet, based on the figure.

In [3]:
class BasicBlock(nn.Module):
    """Residual block without normalization layers, using learnable scalar biases and a gain.

    Residual branch::

        x + b0 -> conv3x3(stride) + b1 -> activation + b2 -> gain * conv3x3 + b3

    Shortcut branch: identity when input and output shapes match, otherwise a
    1x1 (strided) projection of activation(x). The output is their sum.

    The scalar-bias / scalar-gain layout resembles the Fixup scheme.
    NOTE(review): confirm against the figure/paper this block follows.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the first 3x3 conv and of the 1x1 projection shortcut.
        layer_index: stored on the instance but not used inside this class
            (presumably read by network-level initialization code -- verify).
    """

    def __init__(self, in_planes, out_planes, stride, layer_index):
        super(BasicBlock, self).__init__()

        # In-place is safe on the residual path: it only ever overwrites the
        # freshly allocated `conv1(...) + bias` tensor, never the block input.
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.equalInOut = (in_planes == out_planes)
        # A projection is needed whenever the residual output shape differs from
        # the input shape: a channel change OR a spatial downsample. conv1.stride
        # is Conv2d's normalized tuple, so int and tuple strides compare alike.
        needs_projection = (not self.equalInOut) or self.conv1.stride != (1, 1)
        self.convShortcut = (
            nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
            if needs_projection
            else None
        )
        # Scalar multiplier on the second conv plus four scalar biases (block
        # input, after conv1, after activation, block output). Shape (1, 1, 1, 1)
        # broadcasts over the 4-D conv activations.
        self.gain = nn.Parameter(torch.ones(1, 1, 1, 1))
        self.biases = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, 1, 1)) for _ in range(4)])
        self.layer_index = layer_index

    def shortcut(self, x):
        """Return x unchanged, or the 1x1 projection of relu(x) when shapes differ."""
        if self.convShortcut is None:
            return x
        # Out-of-place ReLU: the in-place `self.activation` would overwrite the
        # caller's tensor `x` (failing outright on leaf tensors requiring grad).
        return self.convShortcut(F.relu(x))

    def residual(self, x):
        """Compute the residual branch with its scalar biases and output gain."""
        out = x + self.biases[0]
        out = self.conv1(out) + self.biases[1]
        out = self.activation(out) + self.biases[2]
        out = self.gain * self.conv2(out) + self.biases[3]
        return out

    def forward(self, x):
        """Sum of the residual branch and the shortcut branch."""
        return self.residual(x) + self.shortcut(x)