### Residual Network

An Example: ResNet for Images.

<div style="display:flex">
    <img src="https://miro.medium.com/v2/resize:fit:640/format:webp/0*sGlmENAXIZhSqyFZ"/>
    <pre>
        With a residual network (ResNet), the model is organized as a series of blocks rather
        than a plain series of layers. Each block contains several layers, and the block's input
        is added to the output of the last layer in the block. This is called a skip connection.
        With a skip connection, the layers only have to learn the residual F(x) = H(x) - x, which
        is the difference between the block's desired output H(x) and its input x. This makes it
        easy for the block to represent the identity function, which maps the input to itself:
        the layers only need to drive F(x) toward zero. The skip connection also helps with the
        vanishing gradient problem, in which the gradient of the loss function becomes very small
        as it is propagated back through many layers. Because the addition passes the gradient
        through unchanged, the gradient can bypass a block's layers and reach earlier blocks
        directly. This makes it practical to train very deep networks that can learn very complex
        functions, and ResNets have been very successful on tasks that benefit from such depth.
    </pre>
</div>
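
The two claims above (a residual block can easily represent the identity, and gradients pass straight through the skip connection) can be checked on a toy block. This is a minimal sketch, not the block used later in this notebook: `TinyResidual` is a hypothetical name, and zeroing the last layer is an assumption made only so that F(x) is exactly zero.

```python
import torch
import torch.nn as nn

class TinyResidual(nn.Module):
    """Hypothetical minimal block: out = x + F(x), with F a two-layer MLP."""
    def __init__(self, dim: int):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        # Assumption for the demo: zero the last layer so that F(x) = 0.
        nn.init.zeros_(self.f[-1].weight)
        nn.init.zeros_(self.f[-1].bias)

    def forward(self, x):
        return x + self.f(x)  # skip connection

torch.manual_seed(0)
block = TinyResidual(dim=4)
x = torch.randn(2, 4, requires_grad=True)

out = block(x)
out.sum().backward()

# With F(x) = 0 the block is exactly the identity ...
is_identity = torch.equal(out, x)
# ... and d(out)/dx = I + dF/dx = I, so the gradient reaches x undiminished.
grad_is_one = torch.equal(x.grad, torch.ones_like(x))
print(is_identity, grad_is_one)
```

Both flags come out `True` here. In a trained network F(x) is not zero, but the identity term in the gradient is still present, which is what keeps gradients from vanishing across many stacked blocks.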

Note: ResNets don't necessarily have to be used for images. They can be used for any type of data, such as text or audio.
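
As a rough illustration of the note above, the same idea carries over to sequence data such as audio by swapping the 2D layers for their 1D counterparts. The class name `ResidualBlock1d` and the input shape are assumptions chosen for this sketch; they do not come from the rest of this notebook.

```python
import torch
import torch.nn as nn

class ResidualBlock1d(nn.Module):
    """Hypothetical 1D residual block for (batch, channel, length) inputs."""
    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Skip connection: channels and length are unchanged, so x can be added directly.
        return self.relu(out + x)

block = ResidualBlock1d(channels=16).eval()
signal = torch.randn(4, 16, 1000)  # e.g. 4 clips, 16 feature channels, 1000 time steps
with torch.no_grad():
    out = block(signal)
print(out.shape)
```

Because kernel size 3 with padding 1 preserves the sequence length, the output has the same shape as the input and no projection is needed on the skip path.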

In [2]:
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_chnl:int, out_chnl:int, downsample:nn.Module=None, dropout:float=0.5):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_chnl, out_chnl, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_chnl)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_chnl, out_chnl, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_chnl)
        
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.downsample = downsample
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            out = self.maxpool(out)              # main path: 3x3 max-pool, stride 2, halves H and W (rounding up)
            residual = self.downsample(residual) # skip path: must produce the same shape as `out` so the two can be added

        out = out + residual
        out = self.dropout(out)
        out = self.relu(out)
        return out
    
class ResNet(nn.Module):
    def __init__(self, input_chnl:int):
        super(ResNet, self).__init__()

        in_chnls = [input_chnl, 32, 64, 128, 256]
        out_chnls = [32, 64, 128, 256, 512]

        downsamplers = [nn.Conv2d(i, o, kernel_size=1, stride=2, bias=False) for i, o in zip(in_chnls, out_chnls)]
        blocks = [ResidualBlock(i, o, d) for i, o, d in zip(in_chnls, out_chnls, downsamplers)]

        self.blocks = nn.Sequential(*blocks)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(512, 1)

    def forward(self, x):
        print(f"input   - {x.shape}")
        
        x = self.blocks(x);   print(f"blocks  - {x.shape}")
        x = self.avgpool(x);  print(f"avgpool - {x.shape}")
        x = self.flatten(x);  print(f"flatten - {x.shape}")
        x = self.linear(x);   print(f"linear  - {x.shape}\n")
        return x
    
model = ResNet(input_chnl=3)

# (batch, channel, height, width)
x = torch.randn(1, 3, 180, 180); y = torch.tensor([1.])

with torch.no_grad(): print(f"Pred: {model(x)}") # test forward pass

input   - torch.Size([1, 3, 180, 180])
blocks  - torch.Size([1, 512, 6, 6])
avgpool - torch.Size([1, 512, 1, 1])
flatten - torch.Size([1, 512])
linear  - torch.Size([1, 1])

Pred: tensor([[-0.7718]])
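
The `6 x 6` spatial size reported after the blocks follows from the standard output-size formula for convolution and pooling layers, `floor((n + 2p - k) / s) + 1`. The sketch below uses a hypothetical helper, `out_size`, and assumes dilation 1. It also checks that the max-pool path and the 1x1 stride-2 skip path give the same size in every block, which the addition requires. The 3x3 convolutions with stride 1 and padding 1 leave the size unchanged, so they are omitted.

```python
def out_size(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a conv/pool layer along one spatial dimension (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1

size = 180
sizes = []
for _ in range(5):  # five residual blocks
    main = out_size(size, kernel=3, stride=2, padding=1)  # MaxPool2d(kernel_size=3, stride=2, padding=1)
    skip = out_size(size, kernel=1, stride=2, padding=0)  # Conv2d(kernel_size=1, stride=2)
    assert main == skip, "both paths must agree before they are added"
    size = main
    sizes.append(size)

print(sizes)  # [90, 45, 23, 12, 6]
```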
