In [6]:
import cbs
import torch
import torch.nn as nn
import torch.nn.functional as F
import ctraining_data

# basic RESNET convolution block
class SimpleResidualConv2D(nn.Module):
    def __init__(self, channels: int, H: int, W: int, kernel_size: int):
        super().__init__()
        pad = int((kernel_size-1)//2)
        self.conv1 = nn.LazyConv2d(channels, kernel_size, padding=pad, groups=1)
        self.bn1 = nn.LazyBatchNorm2d()
        self.conv2 = nn.LazyConv2d(channels, kernel_size, padding=pad, groups=1)
        self.bn2 = nn.LazyBatchNorm2d()
    
    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        Y += X
        return F.relu(Y)
    
class JoeyBlock(nn.Module):
    def __init__(self, channels_in, 
                 channels_out, 
                 Hin, Win, 
                 ckernel_size, 
                 pkernel_size,
                 ppadding,
                 pstride):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        self.Hout = int((Hin+2*ppadding-pkernel_size-2)/pstride + 1)
        self.Wout = int((Win+2*ppadding-pkernel_size-2)/pstride + 1)
        self.res = SimpleResidualConv2D(channels_in, Hin, Win, ckernel_size)
        self.conv1x1 = nn.Conv2d(channels_in, channels_out, 1)
        self.pool = nn.MaxPool2d(pkernel_size, stride=pstride, padding=ppadding)
    
    def forward(self, X):
        return self.pool(self.conv1x1(self.res(X)))
    
class JoeyNet(nn.Module):
    def __init__(self, channels_in, Hin, Win, features_out, num_blocks, pkernel_size):
        super().__init__()
        if (Hin - num_blocks*(pkernel_size - 2)) <= 0:
            raise ValueError("(Hin - num_blocks*(pkernel_size - 2)) <= 0")
        if (Win - num_blocks*(pkernel_size - 2)) <= 0:
            raise ValueError("(Win - num_blocks*(pkernel_size - 2)) <= 0")
        H = [Hin - i*(pkernel_size-1) for i in range(num_blocks+1)]
        print(H)
        W = [Win - i*(pkernel_size-1) for i in range(num_blocks+1)]
        print(W)
        self.blocks = [] 
        for i in range(num_blocks):
            self.blocks.append(
                JoeyBlock(channels_in, channels_in, H[i], W[i], 3, pkernel_size, 0, 1)
            )
        self.fc = []
        for i in range(1,4):
            self.fc.append(
                nn.LazyLinear(int(channels_in - i * (channels_in - features_out)//3))
            )
    
    def forward(self, X):
        Y = X
        for i in range(len(self.blocks)):
            Y = self.blocks[i](Y)
        Y = torch.flatten(Y)
        for i in range(len(self.fc)-1):
            Y = F.relu(self.fc[i](Y))
        return self.fc[len(self.fc)-1](Y)
        

In [8]:
foo = JoeyNet(31, 12, 12, 15, 5, 3)
zeros = torch.zeros((1,31,12,12))
bar = foo(zeros)
bar.size()
bar

[12, 10, 8, 6, 4, 2]
[12, 10, 8, 6, 4, 2]




tensor([ 0.1129, -0.1937,  0.1256, -0.1859,  0.1628, -0.3325,  0.2598,  0.2407,
        -0.1520,  0.0782, -0.0084,  0.2027, -0.2863,  0.2861,  0.0365],
       grad_fn=<ViewBackward0>)