In [5]:
import torch
import torch.nn as nn
import torchsummary
import math

In [6]:
class LDB(nn.Module):
    def __init__(self,inD,outD,dropout) -> None:
        super().__init__()
        inD = int(inD)
        outD = int(outD)
        self.seq = nn.Sequential(
            nn.Linear(inD,outD),
            nn.Dropout1d(p = dropout),
            nn.BatchNorm1d(outD)
        )
    def forward(self,x):
        return self.seq(x)


def conv_formula(d_in,kernel_size,padding = 0,dilation = 1,stride = 1):
    return math.floor(((d_in+(2*padding)-(dilation*(kernel_size-1))-1)/stride)+1)


class CDB(nn.Module):
    def __init__(self,x_in,y_in,in_channels:int,out_channels:int,kernel_size:int,stride = 1,padding = 0,dilation = 1,dropout = .2,) -> None:
        super().__init__()
        self.seq = nn.Sequential( # type: ignore
            nn.Conv2d(in_channels=in_channels,
                      out_channels= out_channels,
                      kernel_size=kernel_size,
                      stride = stride,
                      padding = padding,
                      dilation=dilation
                      ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout2d(dropout),
        )
        self.x_out = conv_formula(x_in,kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation)
        self.y_out = conv_formula(y_in,kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation)
        self.out_channels = out_channels
    def forward(self,x):
        return self.seq(x)

def new_CDB(x_in,y_in,in_channels:int = 1,out_channels:int = 1,kernel_size:int = 3,stride = 1,padding = 0,dilation = 1,dropout = .2,device = "cpu"):
    cdb = CDB(x_in,y_in,in_channels,out_channels,kernel_size,stride,padding,dilation,dropout).to(device)
    return cdb,cdb.x_out,cdb.y_out,cdb.out_channels

In [11]:
class hyper_cnn_ann_cog(nn.Module):
    def __init__(self, x=400, y=400, out_classes=19, device="cpu") -> None:
        super().__init__()
        self.name = "multi_cnn_ann_cog"
        self.norm1 = nn.BatchNorm2d(1).to(device)
        self.conv1, xout, yout, out_channels = new_CDB(
            x_in=400, y_in=400, in_channels=1, out_channels=8, kernel_size=3, stride=2, dropout=0, device=device)
        self.conv2, xout, yout, out_channels = new_CDB(
            x_in=xout, y_in=yout, in_channels=out_channels, out_channels=16, kernel_size=3, stride=2, dropout=0, device=device)
        self.conv3, xout, yout, out_channels = new_CDB(
            x_in=xout, y_in=yout, in_channels=out_channels, out_channels=32, kernel_size=3, stride=2, dropout=0, device=device)
        self.conv4, xout, yout, out_channels = new_CDB(
            x_in=xout, y_in=yout, in_channels=out_channels, out_channels=64, kernel_size=3, stride=2, dropout=0, device=device)
        self.conv5, xout, yout, out_channels = new_CDB(
            x_in=xout, y_in=yout, in_channels=out_channels, out_channels=128, kernel_size=3, stride=2, dropout=0, device=device)
        self.conv6, xout, yout, out_channels = new_CDB(
            x_in=xout, y_in=yout, in_channels=out_channels, out_channels=256, kernel_size=3, stride=2, dropout=0, device=device)
        print(xout, yout, out_channels, xout*yout*out_channels)
        self.flatten = nn.Flatten().to(device)
        self.l1 = LDB(xout*yout*out_channels, 2048, dropout=0).to(device)
        self.l2 = LDB(2048, 1024, 0).to(device)
        self.l3 = LDB(1024, out_classes, 0).to(device)
        self.seq = nn.Sequential(self.flatten, self.l1, self.l2, self.l3)
        self.softmax = nn.Softmax(-1).to(device)

    def forward(self, x: torch.Tensor):
        xs = x.shape
        print(xs)
        ix = x.view((xs[0], 1, xs[1], xs[2]))
        ix = self.norm1(ix)
        ix = self.conv1(ix)
        ix = self.conv2(ix)
        ix = self.conv3(ix)
        ix = self.conv4(ix)
        ix = self.conv5(ix)
        ix = self.conv6(ix)
        ix = self.seq(ix)
        return self.softmax(ix)


model = hyper_cnn_ann_cog(400, 400, 19, "cuda:0")
torchsummary.summary(model, (400, 400))


5 5 256 6400
torch.Size([2, 400, 400])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
       BatchNorm2d-1          [-1, 1, 400, 400]               2
            Conv2d-2          [-1, 8, 199, 199]              80
       BatchNorm2d-3          [-1, 8, 199, 199]              16
         Dropout2d-4          [-1, 8, 199, 199]               0
               CDB-5          [-1, 8, 199, 199]               0
            Conv2d-6           [-1, 16, 99, 99]           1,168
       BatchNorm2d-7           [-1, 16, 99, 99]              32
         Dropout2d-8           [-1, 16, 99, 99]               0
               CDB-9           [-1, 16, 99, 99]               0
           Conv2d-10           [-1, 32, 49, 49]           4,640
      BatchNorm2d-11           [-1, 32, 49, 49]              64
        Dropout2d-12           [-1, 32, 49, 49]               0
              CDB-13           [-1, 32, 49, 49]               0
