In [None]:
from torchsummary import summary

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Bidirectional GRU (BGRU) models

In [None]:
class StackedBiGRUNet(nn.Module):
    """Stack of single-layer bidirectional GRUs followed by a linear head.

    Each layer is a separate ``nn.GRU`` so the layers may use different hidden
    sizes. Layer ``i > 0`` consumes the concatenated forward/backward output of
    layer ``i - 1``, hence ``2 * hidden_sizes[i - 1]`` input features.

    Args:
        input_size: Number of features per time step in the input.
        hidden_sizes: Hidden size of each GRU layer, in order (must be non-empty).
        output_size: Number of outputs produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super(StackedBiGRUNet, self).__init__()
        self.num_layers = len(hidden_sizes)
        self.hidden_sizes = hidden_sizes
        self.gru_layers = nn.ModuleList()

        # Create bidirectional GRU layers; each one outputs 2 * hidden features.
        for i in range(self.num_layers):
            input_dim = input_size if i == 0 else self.hidden_sizes[i - 1] * 2
            self.gru_layers.append(
                nn.GRU(input_dim, self.hidden_sizes[i], batch_first=True, bidirectional=True)
            )

        # Final fully connected layer over the 2 * hidden_sizes[-1] features.
        self.final_fc = nn.Linear(self.hidden_sizes[-1] * 2, output_size)

    def forward(self, x):
        """Run the GRU stack and map the last time step to the outputs.

        Args:
            x: Batch-first input of shape (N, L, input_size).

        Returns:
            Tensor of shape (N, output_size): raw linear outputs, no activation.
        """
        out = x
        for gru, hidden_size in zip(self.gru_layers, self.hidden_sizes):
            # Zero initial state, one slice per direction. new_zeros matches the
            # input's dtype as well as its device, so the model still works after
            # e.g. .double(); a float32 torch.zeros state would mismatch there.
            h0 = out.new_zeros(2, out.size(0), hidden_size)
            out, _ = gru(out, h0)

        # Features at the last time step of the last layer.
        # NOTE: for the backward direction (second half of the features) this
        # step has only seen the final input element.
        out = out[:, -1, :]

        return self.final_fc(out)

In [None]:
class TheActualGRU(nn.Module):
  """Three stacked bidirectional GRUs followed by a two-layer classifier head.

  The defaults reproduce the original fixed architecture (300 input features
  per step, 3 classes, dropout 0.2); the arguments only make those constants
  configurable. The GRU hidden sizes (300, 200, 50) are unchanged.

  Args:
    input_size: features per time step of the batch-first input (N, L, input_size).
    num_classes: number of output classes.
    dropout: drop probability used by both dropout layers.
  """

  def __init__(self, input_size=300, num_classes=3, dropout=0.2):
    super(TheActualGRU, self).__init__()

    # A bidirectional GRU returns output (N, L, 2 * hidden_size) and h_n
    # (2 * num_layers, N, hidden_size), so each layer's input size is twice
    # the previous layer's hidden size.
    self.bgru_1 = nn.GRU(input_size, 300, batch_first=True, bidirectional=True)
    self.bgru_2 = nn.GRU(600, 200, batch_first=True, bidirectional=True)
    self.dropout_1 = nn.Dropout(p=dropout)
    self.bgru_3 = nn.GRU(400, 50, batch_first=True, bidirectional=True)
    self.fc1 = nn.Linear(100, 50)
    self.dropout_2 = nn.Dropout(p=dropout)
    self.fc2 = nn.Linear(50, num_classes)

  def forward(self, x):
    """Return class probabilities of shape (N, num_classes) for input (N, L, input_size).

    NOTE(review): the output is already softmax-normalised. If training uses
    nn.CrossEntropyLoss (which applies log-softmax itself), return logits
    instead; confirm against the training loop.
    """
    # No initial hidden state is passed, so nn.GRU starts from zeros.
    x, _ = self.bgru_1(x)
    x, _ = self.bgru_2(x)
    x = self.dropout_1(x)
    x, _ = self.bgru_3(x)
    x = x[:, -1, :]  # last time step: (N, 100)
    # NOTE(review): there is no activation between fc1 and fc2, so (dropout
    # aside) they compose into a single linear map. Confirm this is intended.
    x = self.fc1(x)
    x = self.dropout_2(x)
    x = self.fc2(x)

    return F.softmax(x, dim=1)

In [None]:
# torchsummary's default device="cuda" builds its dummy input on the GPU
# whenever one is available, while this model stays on the CPU; pin
# device="cpu" so the cell runs the same on any machine.
gru_model = TheActualGRU()
summary(gru_model, (5, 300), device="cpu")

# torchsummary only counts parameters stored in a module's `weight`/`bias`
# attributes. nn.GRU keeps its parameters under other names (weight_ih_l0, ...),
# so the GRU rows report 0 params and the printed total is too low.
print(f"Actual trainable params: {sum(p.numel() for p in gru_model.parameters() if p.requires_grad):,}")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
               GRU-1  [[-1, 5, 600], [-1, 2, 300]]               0
               GRU-2  [[-1, 5, 400], [-1, 2, 200]]               0
           Dropout-3               [-1, 5, 400]               0
               GRU-4  [[-1, 5, 100], [-1, 2, 50]]               0
            Linear-5                   [-1, 50]           5,050
           Dropout-6                   [-1, 50]               0
            Linear-7                    [-1, 3]             153
Total params: 5,203
Trainable params: 5,203
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 20.20
Params size (MB): 0.02
Estimated Total Size (MB): 20.23
----------------------------------------------------------------


# 3D CNN models

In [None]:
class ConvBlock(nn.Module):
  """3D convolutional feature extractor with concatenation skip connections.

  Stages (channel counts taken from the layer definitions below):
    * stem: 3x3x3 'same' conv (1 -> 15) + BN + ReLU, then a stride-2 2x2x2
      conv that halves every spatial dimension;
    * block 1: two 3x3x3 'same' convs whose output is concatenated with the
      stem output (merge_1: 15 + 15 = 30 channels);
    * blocks 2 and 3: BN-ReLU -> stride-2 conv -> dropout -> BN-ReLU -> 3x3x3
      'same' conv, concatenated with a stride-2 "cut" conv of the block input
      (merge_2: 25 + 15 = 40 channels, merge_3: 35 + 25 = 60 channels);
    * head: BN-ReLU followed by two unpadded 3x3x3 convs (60 -> 30 -> 30).

  Input is (N, 1, D, H, W). For a (1, 50, 42, 42) sample the output is
  (N, 30, 2, 1, 1). Each spatial dimension must be at least 40, or the
  unpadded head convs have nothing left to convolve.
  """

  def __init__(self):
    super(ConvBlock, self).__init__()

    self.conv_0_1 = nn.Conv3d(1, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding='same')
    self.bn_0 = nn.BatchNorm3d(15)
    self.conv_0_p = nn.Conv3d(15, 15, kernel_size=(2, 2, 2), stride=(2, 2, 2))

    self.conv_1_1 = nn.Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding='same')
    self.bn_1 = nn.BatchNorm3d(15)
    self.conv_1_2 = nn.Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding='same')

    # merge1: layer 0 and layer 1

    self.bn_2_1 = nn.BatchNorm3d(30)
    self.conv_2_1 = nn.Conv3d(30, 25, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    self.bn_2_2 = nn.BatchNorm3d(25)
    self.conv_2_2 = nn.Conv3d(25, 25, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding='same')
    self.conv_2_cut = nn.Conv3d(30, 15, kernel_size=(2, 2, 2), stride=(2, 2, 2)) # input from merge 1

    # merge2: layer2 and the cut2

    self.bn_3_1 = nn.BatchNorm3d(40)
    self.conv_3_1 = nn.Conv3d(40, 35, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    self.bn_3_2 = nn.BatchNorm3d(35)
    self.conv_3_2 = nn.Conv3d(35, 35, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding='same')
    self.conv_3_cut = nn.Conv3d(40, 25, kernel_size=(2, 2, 2), stride=(2, 2, 2)) # input from merge 2

    # merge3: layer3 and the cut3

    self.bn_4 = nn.BatchNorm3d(60)
    self.conv_4_1 = nn.Conv3d(60, 30, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding='valid')
    self.conv_4_2 = nn.Conv3d(30, 30, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding='valid')

  def _drop(self, x):
    """Dropout with p=0.2 that is active only in training mode.

    F.dropout defaults to training=True, so without passing self.training
    activations would still be dropped after .eval().
    """
    return F.dropout(x, p=0.2, training=self.training)

  def forward(self, x):
    """Map an (N, 1, D, H, W) volume to (N, 30, D', H', W') features."""
    # Stem: full-resolution conv, then strided conv that halves D, H, W.
    x = self.conv_0_1(x)
    x = F.relu(self.bn_0(x))
    x1 = self.conv_0_p(x)

    # Block 1 (same resolution as the stem output).
    x = self.conv_1_1(x1)
    x = self._drop(F.relu(self.bn_1(x)))
    x = self.conv_1_2(x)

    x2 = torch.cat((x1, x), dim=1)  # merge_1: 15 + 15 = 30 channels

    # Block 2: strided main path plus strided shortcut ("cut") on merge_1.
    x = F.relu(self.bn_2_1(x2))
    x = self.conv_2_1(x)
    x = F.relu(self.bn_2_2(self._drop(x)))
    x = self.conv_2_2(x)

    xc2 = self.conv_2_cut(x2)
    x3 = torch.cat((x, xc2), dim=1)  # merge_2: 25 + 15 = 40 channels

    # Block 3: same pattern on merge_2.
    x = F.relu(self.bn_3_1(x3))
    x = self.conv_3_1(x)
    x = F.relu(self.bn_3_2(self._drop(x)))
    x = self.conv_3_2(x)

    xc3 = self.conv_3_cut(x3)
    x4 = torch.cat((x, xc3), dim=1)  # merge_3: 35 + 25 = 60 channels

    # Head: two unpadded 3x3x3 convs, each shrinking every spatial dim by 2.
    x = F.relu(self.bn_4(x4))
    x = self.conv_4_1(x)
    x = self.conv_4_2(x)

    return x


In [None]:
# torchsummary's default device="cuda" builds its dummy input on the GPU
# whenever one is available, while the model stays on the CPU; pin
# device="cpu" so the cell runs the same on any machine.
summary(ConvBlock(), (1, 50, 42, 42), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 15, 50, 42, 42]             420
       BatchNorm3d-2       [-1, 15, 50, 42, 42]              30
            Conv3d-3       [-1, 15, 25, 21, 21]           1,815
            Conv3d-4       [-1, 15, 25, 21, 21]           6,090
       BatchNorm3d-5       [-1, 15, 25, 21, 21]              30
            Conv3d-6       [-1, 15, 25, 21, 21]           6,090
       BatchNorm3d-7       [-1, 30, 25, 21, 21]              60
            Conv3d-8       [-1, 25, 12, 10, 10]           6,025
       BatchNorm3d-9       [-1, 25, 12, 10, 10]              50
           Conv3d-10       [-1, 25, 12, 10, 10]          16,900
           Conv3d-11       [-1, 15, 12, 10, 10]           3,615
      BatchNorm3d-12       [-1, 40, 12, 10, 10]              80
           Conv3d-13          [-1, 35, 6, 5, 5]          11,235
      BatchNorm3d-14          [-1, 35, 

In [None]:
class ActualCNN(nn.Module):
  """3D CNN classifier: ConvBlock features -> flatten -> 3-layer MLP -> softmax.

  fc1 expects 60 flattened features. That matches ConvBlock's output of
  (30, 2, 1, 1) for a (1, 50, 42, 42) input; other input sizes need a
  different fc1 in_features.

  The output is class probabilities (softmax over dim 1), not logits.
  NOTE(review): if this is trained with nn.CrossEntropyLoss (which applies
  log-softmax itself), return logits instead; confirm against the training loop.
  """

  def __init__(self):
    super(ActualCNN, self).__init__()

    self.conv_block = ConvBlock()

    self.flatten = nn.Flatten()

    self.fc1 = nn.Linear(60, 300)
    self.fc2 = nn.Linear(300, 50)
    self.fc3 = nn.Linear(50, 3)

  def forward(self, x):
    """Return (N, 3) class probabilities for an (N, 1, D, H, W) input."""
    x = self.conv_block(x)

    x = self.flatten(x)

    # F.dropout defaults to training=True; tie it to the module's mode so
    # dropout is disabled after .eval().
    x = F.relu(F.dropout(self.fc1(x), p=0.2, training=self.training))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)

    return F.softmax(x, dim=1)

In [None]:
# torchsummary's default device="cuda" builds its dummy input on the GPU
# whenever one is available, while the model stays on the CPU; pin
# device="cpu" so the cell runs the same on any machine.
summary(ActualCNN(), (1, 50, 42, 42), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 15, 50, 42, 42]             420
       BatchNorm3d-2       [-1, 15, 50, 42, 42]              30
            Conv3d-3       [-1, 15, 25, 21, 21]           1,815
            Conv3d-4       [-1, 15, 25, 21, 21]           6,090
       BatchNorm3d-5       [-1, 15, 25, 21, 21]              30
            Conv3d-6       [-1, 15, 25, 21, 21]           6,090
       BatchNorm3d-7       [-1, 30, 25, 21, 21]              60
            Conv3d-8       [-1, 25, 12, 10, 10]           6,025
       BatchNorm3d-9       [-1, 25, 12, 10, 10]              50
           Conv3d-10       [-1, 25, 12, 10, 10]          16,900
           Conv3d-11       [-1, 15, 12, 10, 10]           3,615
      BatchNorm3d-12       [-1, 40, 12, 10, 10]              80
           Conv3d-13          [-1, 35, 6, 5, 5]          11,235
      BatchNorm3d-14          [-1, 35, 