In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Подробности по ДЗ:
1. Нужно реализовать флаг make_downsample - увеличивает кол-во фильтров вдвое, размер уменьшается вдвое по высоте и ширине - например (64, 16, 16) -> (128, 8, 8)
2. Нужно реализовать флаг use_skip_connection - если он включен, то на выходе блока добавляется X со входа - иначе блок работает как обычная сеть

Особенности downsample
1. Уменьшать размер входного изображения надо посредством conv3x3 со stride=2
2. В Bottleneck версии - кол-во фильтров меняется первым bottleneck слоем

Общие рекомендации по построению ResNet сетей:
1. После каждой конволюции идет BatchNorm и Relu слои
2. В конце ResNet блока после суммирования идет Relu слой
3. Конволюциооные слои, включая слои Bottleneck не используют bias (bias=False) - опциональное

Блоки строятся на основании статьи https://arxiv.org/abs/1512.03385
Tutorial по Pytorch https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html

In [3]:
from operator import mul
from functools import reduce

# небольшой код по подсчету памяти и количества параметров в сети
MODULES_STAT=[]

def module_forward_hook(module, input, output):
    weight = module.weight.size() if not isinstance(module, torch.nn.modules.MaxPool2d) else (0, 0, 0, 0)
    MODULES_STAT.append((module, output.size(), weight))
    
def register_hook(module):
    for item in module.children(): 
        if type(item) in [nn.modules.conv.Conv2d, nn.modules.MaxPool2d, nn.modules.Linear]:
            print(item)
            item.register_forward_hook(module_forward_hook)
                 
def features_mem_and_params(input_tenzor):
    input_size = input_tenzor.size()
    total_param = 0
    total_mem =  reduce(mul,(input_size))
    print( "%02d" % 0,
          'INPUT',
          "memory",
          "%dx%dx%d=%d" % (input_size[1], input_size[2], input_size[3], reduce(mul,(input_size))),
          "parameters", "%dx%dx%d=%d"%(0, 0, 0 , 0)
         ) 
    for i, stat in enumerate(MODULES_STAT):
        module_name = str(stat).split('(')[1]
        total_param += reduce(mul,(stat[2]))
        total_mem   += reduce(mul,(stat[1]))
        
        if 'Linear' in module_name:
            print( "%02d"%(i+1),'FC',"memory", "%dx%d=%d"%(stat[1][0], stat[1][1], reduce(mul,(stat[1]))),
               "parameters", "%dx%d=%d"%(stat[2][0], stat[2][1] , reduce(mul,(stat[2]))))
        else:    
            print( "%02d"%(i+1),module_name,"memory", "%dx%dx%d=%d"%(stat[1][1], stat[1][2], stat[1][3], reduce(mul,(stat[1]))),
               "parameters", "%dx%dx%dx%d=%d"%(stat[2][0], stat[2][1], stat[2][2], stat[2][3] , reduce(mul,(stat[2]))))
    print()
    print ("Total_mem: %d * 4 = %d" % (total_mem, total_mem * 4))
    print ("Total params: %d" % total_param, "Total_mem: %d" % total_mem)          
    return (total_param, total_mem)

In [4]:
DOWNSAMPLE_COEF = 2

In [5]:
class CifarResidualBlock(nn.Module):
    def __init__(self, a_in_channels, make_downsample=False, use_skip_connection=True):
        super(CifarResidualBlock, self).__init__()
        self.use_skip_connection = use_skip_connection
        
        if make_downsample: coef = DOWNSAMPLE_COEF
        else: coef = 1  
        self.a_in_channels = a_in_channels
        self.a_out_channels = a_in_channels * coef
        self.coef = coef
        
        self.conv1 = nn.Conv2d(a_in_channels, self.a_out_channels, kernel_size=3, stride=coef, padding=1)
        self.conv2 = nn.Conv2d(self.a_out_channels, self.a_out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(self.a_out_channels)
            
    def add_padding(self, x):
        channel_diff = self.a_out_channels - self.a_in_channels
        x = F.pad(x, (0, 0, 0, 0, 0, channel_diff))
        return x[:,:,::self.coef,::self.coef]
    
    def forward(self, x):
        x_temp = x
        x = F.relu(self.bn(self.conv1(x)))
        x = F.relu(self.bn(self.conv2(x)))
        if self.use_skip_connection:
            return F.relu(x + self.add_padding(x_temp))
        return F.relu(x)

In [16]:
class CifarResidualBottleneckBlock(nn.Module):
    
    BOTTLENECK_COEF = 4
    
    def __init__(self, a_in_channels, make_downsample=False, use_skip_connection=True):
        super(CifarResidualBottleneckBlock, self).__init__()
        self.use_skip_connection = use_skip_connection
        self.make_downsample = make_downsample
        
        if make_downsample:
            coef = DOWNSAMPLE_COEF
        else: coef = 1
            
        self.a_in_channels = a_in_channels
        self.a_out_channels = a_in_channels * coef
        self.bottleneck_channels = self.a_out_channels // CifarResidualBottleneckBlock.BOTTLENECK_COEF
        self.coef = coef
        
        self.conv1 = nn.Conv2d(self.a_in_channels, self.bottleneck_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(self.bottleneck_channels, self.bottleneck_channels, kernel_size=3, stride=coef, padding=1)
        self.conv3 = nn.Conv2d(self.bottleneck_channels, self.a_out_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(self.bottleneck_channels)
        self.bn2 = nn.BatchNorm2d(self.a_out_channels)
        
        if make_downsample:
            self.conv0 = nn.Conv2d(self.a_in_channels, self.a_out_channels, kernel_size=1, stride=coef)
    
    def add_padding(self, x):
        channel_diff = self.a_out_channels - self.a_in_channels
        x = F.pad(x, (0, 0, 0, 0, 0, channel_diff))
        return x[:,:,::self.coef,::self.coef]
                
    def forward(self, x):
        x_temp = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn1(self.conv2(x)))
        x = F.relu(self.bn2(self.conv3(x)))
        if self.use_skip_connection:
            if self.make_downsample:
                x_temp = self.conv0(x_temp)
            x += x_temp
        return F.relu(x)

In [7]:
### Test 1
x = torch.ones(1, 3, 32, 32)*100
print("Input size :\t\t", x.size())

first_conv = nn.Conv2d(3, 16, 3, padding=1, bias=False)
x = first_conv(x)
print("After first layers:\t", x.size())

block = CifarResidualBlock(16, make_downsample=False, use_skip_connection=True)
x = block(x)
print("After ResBlock layers:\t", x.size())

assert(x.size() == torch.Size((1, 16, 32, 32)))

Input size :		 torch.Size([1, 3, 32, 32])
After first layers:	 torch.Size([1, 16, 32, 32])
After ResBlock layers:	 torch.Size([1, 16, 32, 32])


In [8]:
### Test 2
x = torch.ones(1, 3, 32, 32)*100
print("Input size :\t\t", x.size())

first_conv = nn.Conv2d(3, 16, 3, padding=1, bias=False)
x = first_conv(x)
print("After first layers:\t", x.size())

block = CifarResidualBlock(16, make_downsample=True, use_skip_connection=True)
x = block(x)
print("After ResBlock layers:\t", x.size())

assert(x.size() == torch.Size((1, 32, 16, 16)))

Input size :		 torch.Size([1, 3, 32, 32])
After first layers:	 torch.Size([1, 16, 32, 32])
After ResBlock layers:	 torch.Size([1, 32, 16, 16])


In [9]:
### Test 3
x = torch.ones(1, 16, 32, 32)
block = CifarResidualBlock(16, make_downsample=False, use_skip_connection=True)
x = block(x)
print(x.size()[1]*x.size()[2]*x.size()[3], x.sum())
assert(x.sum() > 10000)

16384 tensor(18199.1309, grad_fn=<SumBackward0>)


In [10]:
### Test 4
x = torch.ones(1, 16, 32, 32)
block = CifarResidualBlock(16, make_downsample=False, use_skip_connection=False)
x = block(x)
print(x.size()[1]*x.size()[2]*x.size()[3], x.sum())
assert(x.sum() < 5000)

16384 tensor(2218.2639, grad_fn=<SumBackward0>)


In [11]:
### Test 5

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBlock(256, make_downsample=False, use_skip_connection=True)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
assert((params, memory) == (1179648, 49152))

Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 256x8x8=16384 parameters 256x256x3x3=589824
02 Conv2d memory 256x8x8=16384 parameters 256x256x3x3=589824

Total_mem: 49152 * 4 = 196608
Total params: 1179648 Total_mem: 49152


In [12]:
### Test 6

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBlock(256, make_downsample=True, use_skip_connection=False)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (3538944, 32768))

Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 512x4x4=8192 parameters 512x256x3x3=1179648
02 Conv2d memory 512x4x4=8192 parameters 512x512x3x3=2359296

Total_mem: 32768 * 4 = 131072
Total params: 3538944 Total_mem: 32768
3538944 32768


In [13]:
### Test 7

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBottleneckBlock(256, use_skip_connection=True)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
assert((params, memory) == (69632, 40960))

Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 64x8x8=4096 parameters 64x256x1x1=16384
02 Conv2d memory 64x8x8=4096 parameters 64x64x3x3=36864
03 Conv2d memory 256x8x8=16384 parameters 256x64x1x1=16384

Total_mem: 40960 * 4 = 163840
Total params: 69632 Total_mem: 40960


In [17]:
### Test 8

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBottleneckBlock(256, make_downsample=True, use_skip_connection=True)
register_hook(block)
out = block(input)
params, memory = features_mem_and_params(input)
assert((params, memory) == (376832, 43008))

Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 128x8x8=8192 parameters 128x256x1x1=32768
02 Conv2d memory 128x4x4=2048 parameters 128x128x3x3=147456
03 Conv2d memory 512x4x4=8192 parameters 512x128x1x1=65536
04 Conv2d memory 512x4x4=8192 parameters 512x256x1x1=131072

Total_mem: 43008 * 4 = 172032
Total params: 376832 Total_mem: 43008


In [18]:
### Test 9

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBlock(256, make_downsample=True, use_skip_connection=True)
register_hook(block)
out = block(input)
params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (3670016, 40960))

Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 512x4x4=8192 parameters 512x256x3x3=1179648
02 Conv2d memory 512x4x4=8192 parameters 512x512x3x3=2359296

Total_mem: 32768 * 4 = 131072
Total params: 3538944 Total_mem: 32768
3538944 32768


AssertionError: 