<a href="https://colab.research.google.com/github/Idan-Alter/OU-22961-Deep-Learning/blob/main/22961_5_7_ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

In [None]:
class ResBlockMLP(nn.Module):
    """Fully connected residual block that keeps the feature width unchanged.

    The residual branch is Linear -> BatchNorm1d -> ReLU -> Linear. The branch
    output is added to the block input and the sum passes through a final ReLU.
    """

    def __init__(self, in_features):
        """Create the layers; both linear layers map in_features -> in_features."""
        super().__init__()
        # Square layers so the branch output has the same width as the input
        # and the two can be summed in forward().
        square = (in_features, in_features)
        self.Z1 = nn.Linear(*square)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(in_features)
        self.Z2 = nn.Linear(*square)

    def forward(self, X):
        """Return relu(branch(X) + X).

        NOTE(review): presumably X is (batch, in_features) -- confirm with callers.
        """
        hidden = self.relu(self.bn(self.Z1(X)))
        branch = self.Z2(hidden)
        # Skip connection: the untouched input is added back before the last ReLU.
        return self.relu(branch + X)

In [None]:
class ResBlockConv(nn.Module):
    """Convolutional residual block that preserves the channel count.

    The residual branch is Conv3x3 -> BatchNorm2d -> ReLU -> Conv3x3 ->
    BatchNorm2d. The branch output is added to the block input and the sum
    passes through a final ReLU.
    """

    def __init__(self, in_channels):
        """Create two bias-free 3x3 convolutions, each followed by BatchNorm2d."""
        super().__init__()
        # Both convolutions share these settings; "same" padding keeps the
        # spatial size so the branch output can be added to the input.
        conv_kwargs = dict(kernel_size=3, padding="same", bias=False)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(in_channels)

    def forward(self, X):
        """Return relu(branch(X) + X).

        NOTE(review): presumably X is (batch, in_channels, H, W) -- confirm.
        """
        branch = self.relu(self.bn1(self.conv1(X)))
        branch = self.bn2(self.conv2(branch))
        # Skip connection: add the unmodified input before the last ReLU.
        return self.relu(branch + X)

In [None]:
class ResBlockDownSamp(nn.Module):
    """Residual block that doubles the channels and applies stride 2.

    The residual branch is Conv3x3(stride 2) -> BatchNorm2d -> ReLU ->
    Conv3x3 -> BatchNorm2d. Because the branch output no longer matches the
    input in channels or spatial size, the input is projected by a 1x1
    stride-2 convolution before the two are summed and passed through ReLU.
    """

    def __init__(self, in_channels):
        """Create the layers; the block outputs 2 * in_channels channels."""
        super().__init__()
        out_channels = 2 * in_channels
        self.relu = nn.ReLU()

        # The first convolution changes both the channel count and the resolution.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # The second convolution keeps the shape produced by the first one.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding="same", bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Projection for the skip path so it matches the branch output shape.
        self.downsampX = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=2, bias=False)

    def forward(self, X):
        """Return relu(branch(X) + downsampX(X)).

        NOTE(review): presumably X is (batch, in_channels, H, W) -- confirm.
        """
        branch = self.relu(self.bn1(self.conv1(X)))
        branch = self.bn2(self.conv2(branch))
        # Skip connection through the 1x1 projection instead of the raw input.
        shortcut = self.downsampX(X)
        return self.relu(branch + shortcut)

In [None]:
# Build torchvision's ResNet-18 with its default arguments and print the
# module tree, so it can be compared with the hand-written blocks above.
from torchvision import models

resnet18 = models.resnet18()
print(resnet18)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
# Keep the final three top-level child modules of the model and print each
# one on its own line.
last_layers = [*resnet18.children()][-3:]
print("\n".join(str(layer) for layer in last_layers))

Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (downsample): Sequential(
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): BasicBlock(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1