In [1]:
import torch
from torch import nn
from torchinfo import summary
from positional_encodings.torch_encodings import PositionalEncoding1D, PositionalEncoding2D, PositionalEncoding3D, Summer
import sys
sys.path.append('../../')
from Model import CNN as cnn

In [12]:
class ResidualBlock(nn.Module):
    """Residual unit: conv -> BN -> ReLU -> conv -> BN, added to a skip path, then ReLU.

    The skip path is the identity when the input already has the output shape
    (``stride == 1`` and ``in_channels == out_channels``). Otherwise it is a
    1x1 convolution with the same stride (no bias) followed by BatchNorm, so
    that the two paths can be summed.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both main-path convolutions.
        kernel_size: kernel size of the two main-path convolutions.
        stride: stride of the first main-path convolution and of the 1x1 projection.
        padding: padding of both main-path convolutions. NOTE(review): the
            residual sum only lines up when this padding keeps the main path's
            spatial size equal to the skip path's (e.g. padding=1 for
            kernel_size=3, as used in this notebook) — confirm for other configs.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # Keep the registration order conv1, bn1, relu, conv2, bn2, downsample:
        # it fixes state_dict key order and the order of parameter initialisation.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size,
                               stride=stride, padding=padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,
                               stride=1, padding=padding)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: plain identity (None) when shapes already match, else a 1x1 projection.
        shape_preserved = stride == 1 and in_channels == out_channels
        self.downsample = None if shape_preserved else nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        self.encoder_name = "CNN"

    def forward(self, x):
        """Apply the residual unit to ``x`` and return the activated sum of both paths."""
        # Main path; the in-place ReLU only touches the bn1 output, never `x`.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)

class ConvNetBlock_large_2s(nn.Module):
    """Residual CNN encoder that maps an image-like tensor to a 768-d feature vector.

    Layout:
      * stem: 5x5 conv (-> 32 channels, padding 2, stride 1) -> BatchNorm -> ReLU
      * five stages of (MaxPool2d -> ResidualBlock), widening the channels
        32 -> 64 -> 128 -> 256 -> 512 -> 768
      * global average pooling to 1x1, then flatten

    The first two max-pools use stride (2, 1), so they halve dim 2 (height)
    but keep dim 3 (width). The last three use stride 2 and halve both. For the
    input used in this notebook, (10, 22, 65, 9), the output is (10, 768).

    Args:
        in_channels: number of channels (dim 1) of the input. Defaults to 22,
            the value that was previously hard-coded, so
            ``ConvNetBlock_large_2s()`` builds exactly the same network as before.

    Input:  (N, in_channels, H, W)
    Output: (N, 768)
    """

    def __init__(self, in_channels=22):
        super().__init__()
        # Attribute names are kept unchanged so existing state_dict checkpoints still load.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        # Stride (2, 1): downsample height only, preserve the (narrow) width.
        self.mpv1 = nn.MaxPool2d(kernel_size=3, stride=(2, 1), padding=1)
        self.block1 = ResidualBlock(32, 64, kernel_size=3, stride=1, padding=1)
        self.mpv2 = nn.MaxPool2d(kernel_size=3, stride=(2, 1), padding=1)
        self.block2 = ResidualBlock(64, 128, kernel_size=3, stride=1, padding=1)
        self.mpv3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.block3 = ResidualBlock(128, 256, kernel_size=3, stride=1, padding=1)
        self.mpv4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.block4 = ResidualBlock(256, 512, kernel_size=3, stride=1, padding=1)
        self.mpv5 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.block5 = ResidualBlock(512, 768, kernel_size=3, stride=1, padding=1)
        # Despite the "mpv" prefix this is global *average* pooling to 1x1.
        self.mpv6 = nn.AdaptiveAvgPool2d((1, 1))
        self.Flatten = nn.Flatten()

        self.encoder_name = "CNN"

    def forward(self, x):
        """Encode ``x`` of shape (N, in_channels, H, W) into an (N, 768) tensor."""
        x = self.relu(self.bn1(self.conv1(x)))
        stages = (
            (self.mpv1, self.block1),
            (self.mpv2, self.block2),
            (self.mpv3, self.block3),
            (self.mpv4, self.block4),
            (self.mpv5, self.block5),
        )
        for pool, block in stages:
            x = block(pool(x))
        x = self.mpv6(x)
        return self.Flatten(x)

In [13]:
# Input layout is (batch, channels, height, width). The channel count must match
# the first Conv2d of ConvNetBlock_large_2s (22 input channels).
INPUT_SIZE = (10, 22, 65, 9)
summary(ConvNetBlock_large_2s(), input_size=INPUT_SIZE)

Layer (type:depth-idx)                   Output Shape              Param #
ConvNetBlock_large_2s                    [10, 768]                 --
├─Conv2d: 1-1                            [10, 32, 65, 9]           17,632
├─BatchNorm2d: 1-2                       [10, 32, 65, 9]           64
├─ReLU: 1-3                              [10, 32, 65, 9]           --
├─MaxPool2d: 1-4                         [10, 32, 33, 9]           --
├─ResidualBlock: 1-5                     [10, 64, 33, 9]           --
│    └─Conv2d: 2-1                       [10, 64, 33, 9]           18,496
│    └─BatchNorm2d: 2-2                  [10, 64, 33, 9]           128
│    └─ReLU: 2-3                         [10, 64, 33, 9]           --
│    └─Conv2d: 2-4                       [10, 64, 33, 9]           36,928
│    └─BatchNorm2d: 2-5                  [10, 64, 33, 9]           128
│    └─Sequential: 2-6                   [10, 64, 33, 9]           --
│    │    └─Conv2d: 3-1                  [10, 64, 33, 9]           2,04