In [19]:
from typing import List, Optional
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torch import nn

In [37]:

class ShortcutProjection(nn.Module):

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int,
                 apply_bn = True) -> None:

        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=1, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels) \
            if apply_bn \
            else nn.Identity()
    
    def forward(self, x: torch.Tensor):
        return self.bn(self.conv(x))

In [73]:


class ResidualBlock(nn.Module):

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 n_convs: int = 2,
                 stride: int = 1) -> None:
        super().__init__()

        dims = zip(
            [in_channels] + [out_channels] * (n_convs - 1),
            [out_channels] * n_convs)

        self.backbone = nn.ModuleList([
            nn.Sequential(
                OrderedDict([
                    (f"conv_{i}", nn.Conv2d(ch_in, ch_out, 3, stride = stride, padding=1)),
                    (f"bn_{i}", nn.BatchNorm2d(ch_out))])
                )
            for i, (ch_in, ch_out) in enumerate(dims, start = 1)])

        self.shortcut = ShortcutProjection(in_channels, out_channels, stride) \
            if stride != 1 or in_channels != out_channels \
            else nn.Identity()
        
    def forward(self, x: torch.Tensor):
        f_x = x
        for conv_bn in self.backbone[:-1]:
            #print(i)
            f_x = F.relu(conv_bn(f_x))
        
        return F.relu(
            self.backbone[-1](f_x)) + self.shortcut(x)


In [74]:
resblock = ResidualBlock(3,5, n_convs=30)
print(len(resblock.backbone))
for i, conv_bn in enumerate(resblock.backbone[:-1], start=1):
    print(i, conv_bn, end='\n\n')

30
1 Sequential(
  (conv_1): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn_1): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

2 Sequential(
  (conv_2): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn_2): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

3 Sequential(
  (conv_3): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn_3): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

4 Sequential(
  (conv_4): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn_4): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

5 Sequential(
  (conv_5): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn_5): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

6 Sequential(
  (conv_6): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1

In [76]:
class BottleneckResidualBlock(nn.Module):

    def __init__(self, in_channels: int, bottleneck_channels: int, out_channels: int, stride: int):
        """
        * `in_channels` is the number of channels in $x$
        * `bottleneck_channels` is the number of channels for the $3 \times 3$ convlution
        * `out_channels` is the number of output channels
        * `stride` is the stride length in the $3 \times 3$ convolution operation.
        """
        super().__init__()

        # First $1 \times 1$ convolution layer, this maps to `bottleneck_channels`
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, stride=1)
        # Batch normalization after the first convolution
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        # First activation function (ReLU)
        self.act1 = nn.ReLU()

        # Second $3 \times 3$ convolution layer
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, padding=1)
        # Batch normalization after the second convolution
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        # Second activation function (ReLU)
        self.act2 = nn.ReLU()

        # Third $1 \times 1$ convolution layer, this maps to `out_channels`.
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, stride=1)
        # Batch normalization after the second convolution
        self.bn3 = nn.BatchNorm2d(out_channels)

        # Shortcut connection should be a projection if the stride length is not $1$
        # of if the number of channels change
        if stride != 1 or in_channels != out_channels:
            # Projection $W_s x$
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
            # Identity $x$
            self.shortcut = nn.Identity()

        # Second activation function (ReLU) (after adding the shortcut)
        self.act3 = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """
        * `x` is the input of shape `[batch_size, in_channels, height, width]`
        """
        # Get the shortcut connection
        shortcut = self.shortcut(x)
        # First convolution and activation
        x = self.act1(self.bn1(self.conv1(x)))
        # Second convolution and activation
        x = self.act2(self.bn2(self.conv2(x)))
        # Third convolution
        x = self.bn3(self.conv3(x))
        # Activation function after adding the shortcut
        return self.act3(x + shortcut)

In [None]:
class ResNetBase(nn.Module):

    def __init__(self,
                 n_blocks: List[int],
                 n_channels: List[int],
                 bottlenecks: Optional[List[int]] = None,
                 img_channels: int = 3,
                 first_kernel_size: int = 7
                 ) -> None:

        """
        n_blocks is a list of of number of blocks for each feature map size.
        n_channels is the number of channels for each feature map size.
        bottlenecks is the number of channels the bottlenecks. If this is None , residual blocks are used.
        img_channels is the number of channels in the input.
        first_kernel_size is the kernel size of the initial convolution layer
        """

        super().__init__()


In [75]:
x = torch.rand(64, 3, 128, 128)
resblock(x).shape

torch.Size([64, 5, 128, 128])

In [77]:
BottleneckResidualBlock?

[0;31mInit signature:[0m
[0mBottleneckResidualBlock[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0min_channels[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbottleneck_channels[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mout_channels[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mstride[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in
a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x)