In [2]:

import torch
from torch.nn import(
    Module,
    Linear,
    ReLU,
    Conv2d,
    Sequential,
    MaxPool2d,
    Flatten,
    Dropout,
    BatchNorm2d
)
class Residual(Module):
    """ResNet-style basic residual block: two convolutions plus a skip connection.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``, where the
    shortcut is either the identity or a strided 1x1 projection convolution.

    Args:
        input_channels: number of channels of the input tensor.
        output_channels: number of channels produced by the block.
        kernel: kernel size of both main-path convolutions (int or per-dim
            tuple). Padding of ``kernel // 2`` keeps the spatial size unchanged
            for odd kernels, so only ``strides`` downsamples. Defaults to 3.
        strides: stride of the first convolution and of the 1x1 projection.
            Defaults to 1.
        use_lxlconv: if True, project the shortcut with a 1x1 convolution so it
            matches the main path's channel count and spatial size. Required
            whenever ``input_channels != output_channels`` or ``strides != 1``.

    Raises:
        ValueError: if the identity shortcut is requested but cannot match the
            main path's output shape.
    """

    def __init__(self, input_channels, output_channels, kernel=3, strides=1,
                 use_lxlconv=False):
        super().__init__()
        unit_stride = strides in (1, (1, 1))
        if not use_lxlconv and (input_channels != output_channels or not unit_stride):
            # x + main_path(x) would fail (or silently broadcast a 1-channel
            # input) in forward(); fail early with a clear message instead.
            raise ValueError(
                "identity shortcut requires input_channels == output_channels "
                "and strides == 1; pass use_lxlconv=True to project the input"
            )
        padding = self._same_padding(kernel)
        self.conv1 = Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=kernel,
            padding=padding,
            stride=strides,
        )
        # Stride stays 1 here: conv1 already downsampled and the projection
        # shortcut downsamples only once, so striding again would make the two
        # branches disagree in spatial size.
        self.conv2 = Conv2d(
            in_channels=output_channels,
            out_channels=output_channels,
            kernel_size=kernel,
            padding=padding,
            stride=1,
        )
        if use_lxlconv:
            self.conv3 = Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=1,
                stride=strides,
            )
        else:
            self.conv3 = None
        self.bn1 = BatchNorm2d(output_channels)
        self.bn2 = BatchNorm2d(output_channels)
        # In-place is safe here: it is only applied to freshly created tensors
        # (BatchNorm output and the residual sum), never to the block input.
        self.rlu = ReLU(inplace=True)

    @staticmethod
    def _same_padding(kernel):
        """Return padding that preserves spatial size for odd kernel sizes."""
        if isinstance(kernel, int):
            return kernel // 2
        return tuple(k // 2 for k in kernel)

    def forward(self, x):
        """Apply the main path, add the (possibly projected) shortcut, then ReLU."""
        out = self.rlu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.conv3 is None else self.conv3(x)
        return self.rlu(out + shortcut)


In [None]:
# Sanity check on a small random batch: with matching channels and stride 1 the
# identity shortcut is used, so the output shape should equal the input shape.
blk = Residual(3, 3, kernel=3, strides=1)
X = torch.rand(4, 3, 6, 6)
blk(X).shape