In [14]:
import torch
from torch import nn


In [22]:
class BasicConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, hasRelu= True) -> None:
        super(BasicConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.batchNorm = nn.BatchNorm2d(out_channels)
        self.hasRelu = hasRelu
        if hasRelu:
            self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.batchNorm(x)
        if self.hasRelu:
            x = self.relu(x)
        return x

BasicConvBlock(256, 64, 1, 1, 1)

BasicConvBlock(
  (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
  (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
)

In [19]:
class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, factor=4) -> None:
        super(BottleneckBlock, self).__init__()
        hidden_channels = in_channels // factor
        self.layers = nn.Sequential(
            BasicConvBlock(in_channels, hidden_channels, 1),
            BasicConvBlock(hidden_channels, hidden_channels, 3, padding=1),
            BasicConvBlock(hidden_channels, in_channels, 1)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = x
        x1 = self.layers(x)
        x = x0 + x1
        return self.relu(x)
    
print(BottleneckBlock(256))
print(BottleneckBlock(256)(torch.zeros((3, 256, 32, 32))).shape)

BottleneckBlock(
  (layers): Sequential(
    (0): BasicConvBlock(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicConvBlock(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): BasicConvBlock(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (relu): ReLU(inplace=True)
)
torch.Size([3, 256, 32, 32])


In [None]:
class ResidualBlock(nn.Module):
    def __init__(self) -> None:
        super(ResidualBlock, self).__init__()
        self.layer1 = BasicConvBlock()
        self.layer2 = BasicConvBlock()
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        return x

In [None]:
class HrStage1(nn.Module):
    def __init__(self) -> None:
        super(HrStage1, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d()
        )