In [1]:
import torch
import pandas
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn, Tensor

%matplotlib inline

In [None]:
class ResBlock(nn.Module):
    """ResNet residual block: two 3x3 conv + BatchNorm layers plus a skip connection.

    forward(x) computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``,
    where ``shortcut`` is the identity, or a 1x1 convolution (``conv3``) when
    ``isUseConvolution`` is True.
    """

    def __init__(
        self,
        inputChannel: int,
        outputChannel: int,
        stride: int = 1,
        isUseConvolution: bool = False,
    ) -> None:
        """
        Arguments:
            inputChannel:       number of channels of the input tensor.
            outputChannel:      number of channels produced by the block.
            stride:             stride of the first 3x3 convolution and of the 1x1
                                shortcut convolution; stride == 2 halves the spatial
                                height/width of the input.
            isUseConvolution:   if True, pass the input through a 1x1 convolution so
                                the shortcut matches ``outputChannel`` and ``stride``;
                                if False, the shortcut is the identity.

        Raises:
            ValueError: if ``isUseConvolution`` is False but an identity shortcut
                cannot match the main branch's output shape (``stride != 1`` or
                ``inputChannel != outputChannel``).
        """
        super(ResBlock, self).__init__()

        # An identity shortcut only works when the main branch preserves shape;
        # fail here with a clear message instead of a shape error in forward().
        if not isUseConvolution and (stride != 1 or inputChannel != outputChannel):
            raise ValueError(
                "identity shortcut requires stride == 1 and "
                "inputChannel == outputChannel; pass isUseConvolution=True"
            )

        self.__type_name__ = 'Resnet Block'

        # padding=1 keeps the 3x3 convs size-preserving (for stride 1), so the
        # main branch and the shortcut produce the same spatial shape.
        self.conv1 = nn.Conv2d(
            inputChannel,
            outputChannel,
            kernel_size=3,
            stride=stride,
            padding=1,
        )

        self.conv2 = nn.Conv2d(
            outputChannel,
            outputChannel,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # Always define conv3 so forward() can test it: nn.Module raises
        # AttributeError for attributes that were never assigned.
        if isUseConvolution:
            self.conv3 = nn.Conv2d(
                inputChannel,
                outputChannel,
                kernel_size=1,
                stride=stride,
            )
        else:
            self.conv3 = None

        # Both BatchNorms follow a conv that outputs `outputChannel` channels.
        self.bn1 = nn.BatchNorm2d(outputChannel)
        self.bn2 = nn.BatchNorm2d(outputChannel)

    def forward(self, inputs: Tensor) -> Tensor:
        """Apply the residual block.

        Arguments:
            inputs: input feature map; assumed (batch, inputChannel, H, W) as
                    required by nn.Conv2d.

        Returns:
            Tensor with ``outputChannel`` channels; spatial size is divided by
            ``stride`` (rounded up, given padding=1 / kernel 1 shortcut).
        """
        Y = F.relu(self.bn1(self.conv1(inputs)))
        # No ReLU before the sum: the activation is applied after adding the shortcut.
        Y = self.bn2(self.conv2(Y))

        # Without the 1x1 convolution, F(x) = x + g(x).
        shortcut = inputs if self.conv3 is None else self.conv3(inputs)

        # Out-of-place add, so no tensor that autograd saved gets mutated.
        return F.relu(Y + shortcut)