<a href="https://colab.research.google.com/github/Rajitha-SL/My-AI-Projects/blob/AI-and-ML-learning/Transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Define a ResNet block in PyTorch

In [1]:
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
  """Basic residual block computing ``H(x) = relu(F(x) + x)``.

  ``F`` is a mini network made of two 3x3 convolutions with a ReLU in
  between. Both convolutions use ``padding=1`` so the spatial size of the
  feature maps is preserved; without it every 3x3 convolution would trim
  2 pixels per dimension and ``F(x)`` could no longer be added to ``x``.

  Args:
    inp: Number of channels of the input ``x``.
    out1: Number of channels produced by the first convolution.
      Defaults to ``inp``.
    out2: Number of channels produced by the second convolution.
      Defaults to ``inp``. It must be equal to ``inp`` because the skip
      connection adds the input to ``F(x)`` element-wise.

  Raises:
    ValueError: If ``out2`` differs from ``inp``.
  """

  def __init__(self, inp=64, out1=None, out2=None):
    super().__init__()

    # Unspecified widths fall back to the input width
    out1 = inp if out1 is None else out1
    out2 = inp if out2 is None else out2

    # Fail early with a clear message instead of a shape error in forward()
    if out2 != inp:
      raise ValueError(
          f"out2 ({out2}) must equal inp ({inp}) for the skip connection"
      )

    # We are defining a mini network
    # that is made of two standard convolutional layers
    # with the relu in-between
    self.conv_block = nn.Sequential(
        nn.Conv2d(inp, out1, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(out1, out2, 3, padding=1),
        # Note that after the second layer
        # there is no activation
    )

    self.relu = nn.ReLU()

  def forward(self, x):
    """Return ``relu(F(x) + x)``; the output has the same shape as ``x``."""
    # F(x)
    residual = self.conv_block(x)
    # Before we apply the second activation
    # we add back the input x
    # This is the implementation of the skip connection: H(x) = F(x) + x
    return self.relu(residual + x)


We can see above that if the optimizer drives all the convolutional weights and biases to 0, then the residual branch F(x) will be 0 and H will be equal to x.

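The block then reduces to relu(x), which is the identity function whenever x is non-negative (for example, when x is itself the output of a previous ReLU).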

# The following shows how to implement a Squeeze-and-Excitation block in PyTorch


In [None]:
import torch
import torch.nn as nn

class SqueezeExcitation(nn.Module):
  """Squeeze-and-Excitation block that re-weights the input feature maps.

  Each channel is first summarised by global average pooling (squeeze).
  The pooled values then go through two linear layers with a ReLU in
  between, followed by a sigmoid (excitation), which yields one scaling
  factor in (0, 1) per channel. The input is multiplied by these factors.

  Args:
    input_channels: Number of channels of the input feature maps.
    squeeze_channels: Width of the bottleneck between the two linear layers.
  """

  def __init__(self, input_channels: int, squeeze_channels: int):
    super().__init__()

    # Squeeze: Global Average Pooling (GAP) down to a 1x1 map per channel
    self.squeeze = nn.AdaptiveAvgPool2d(1)

    # Excitation: Linear -> ReLU -> Linear, then a sigmoid gate
    gate_layers = [
        nn.Flatten(),
        nn.Linear(input_channels, squeeze_channels),
        nn.ReLU(),
        nn.Linear(squeeze_channels, input_channels),
        nn.Sigmoid(),  # Keeps every scaling factor between 0 and 1
    ]
    self.excitation = nn.Sequential(*gate_layers)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Scale ``x`` by the per-channel factors computed from ``x`` itself."""
    # Squeeze step
    pooled = self.squeeze(x)
    # Excitation step; two trailing singleton axes are appended so the
    # factors broadcast over the last two dimensions of x
    weights = self.excitation(pooled)[..., None, None]
    # Boost or reduce the importance of each input feature map
    return weights * x
