In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Conv3d layer input shape:
# N -> number of sequences (mini batch)
# Cin -> number of channels (3 for rgb)
# D -> Number of images in a sequence
# H -> Height of one image in the sequence
# W -> Width of one image in the sequence

class MnistExampleModel(nn.Module):
    
    def __init__(self):
        """
        Basic Architecture of CNN
        Attributes:
            num_filters: Number of filters, out channel for 1st and 2nd conv layers,
            kernel_size: Kernel size of convolution,
            dense_layer: Dense layer units,
            img_rows: Height of input image,
            img_cols: Width of input image,
            maxpool: Max pooling size
        """
        super(MnistExampleModel, self).__init__()
        self.conv1 = self._conv_layer_set(3, 3)
        self.conv2 = self._conv_layer_set(3, 3)
        self.conv3 = self._conv_layer_set(3, 3)
        self.fc1 = nn.Linear(3*8*8*8, 128)
        self.fc2 = nn.Linear(128, 3)
        self.relu = nn.ReLU()
        self.skip = nn.Identity()
        self.max_pool = nn.MaxPool3d(kernel_size=(2, 2, 2))
        self.fc1_bn = nn.BatchNorm1d(128)
        self.drop = nn.Dropout(p=0.25)
        self.avgpool = nn.AdaptiveAvgPool3d(8)# 256 x 1 x 1

    def _conv_layer_set(self, in_channels, out_channels):
        conv_layer = nn.Sequential(
            nn.Conv3d(
                in_channels, 
                out_channels, 
                kernel_size=(3, 3, 3), 
                stride=1,
                padding=1,
                ),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            )
        return conv_layer

    def forward(self, x):
        print('input shape:', x.shape)
        x = self.conv1(x)
        x = self.max_pool(x)


        x = self.conv2(x)
        x = self.max_pool(x)

        x = self.conv3(x)
        x = self.max_pool(x)
        # x = self.conv3_bn(x)
        print('before flatten shape:', x.shape)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # print('after flatten shape:', x.shape)
        x = self.fc1(x)
        x = self.relu(x)
        # x = self.fc1_bn(x)
        x = self.drop(x)
        x = self.fc2(x)
        # print('output shape:', x.shape)

        return x

In [30]:
from torchsummary import summary
model = MnistExampleModel()
model.cuda()
summary(model, (3, 10, 256, 256))

input shape: torch.Size([2, 3, 10, 256, 256])
before flatten shape: torch.Size([2, 3, 1, 32, 32])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [-1, 3, 10, 256, 256]             246
       BatchNorm3d-2      [-1, 3, 10, 256, 256]               6
              ReLU-3      [-1, 3, 10, 256, 256]               0
         MaxPool3d-4       [-1, 3, 5, 128, 128]               0
            Conv3d-5       [-1, 3, 5, 128, 128]             246
       BatchNorm3d-6       [-1, 3, 5, 128, 128]               6
              ReLU-7       [-1, 3, 5, 128, 128]               0
         MaxPool3d-8         [-1, 3, 2, 64, 64]               0
            Conv3d-9         [-1, 3, 2, 64, 64]             246
      BatchNorm3d-10         [-1, 3, 2, 64, 64]               6
             ReLU-11         [-1, 3, 2, 64, 64]               0
        MaxPool3d-12         [-1, 3, 1, 32, 32]               0
Adapt