# <center>GeNNus CNN: Mel-Spectrogram Conv-to-Fully-Connected Size</center>

A set of utilities to compute the `in_features` of a fully-connected layer when the preceding layer is a convolutional layer.

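This is needed in all of our CNNs, since each one ends with a fully-connected (linear) layer. That layer's `in_features` must equal the flattened size of the last convolutional block's output (channels × height × width).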

In [1]:
import torch

from torch import nn

In [2]:
class CNN(nn.Module):
  """Stack of conv -> max-pool -> batch-norm -> ReLU -> dropout blocks.

  Used to find the output shape of the convolutional part of our CNNs, so
  that the ``in_features`` of the following fully-connected layer can be
  sized (see ``compute_fc_in_features``).

  Args:
    num_layers: number of conv blocks to build; the per-layer lists below
      are indexed ``0 .. num_layers - 1``.
    kernel_sizes: ``kernel_size`` of each ``nn.Conv2d``.
    strides: ``stride`` of each ``nn.Conv2d``.
    in_channels: number of channels of the input fed to the first conv.
    num_filters: ``out_channels`` of each conv. This is also the input
      channel count of the next conv and the ``num_features`` of each
      ``BatchNorm2d``.
    pool_sizes: ``kernel_size`` of each ``nn.MaxPool2d``.
    pool_strides: ``stride`` of each ``nn.MaxPool2d``.
    dropout_p_conv: dropout probability applied at the end of every block.
    dropout_p_linear: dropout probability intended for the linear layer(s).
      This class does not use it; it is only stored and reported by
      ``get_model_setup``. Defaults to ``None``.
  """

  def __init__(
    self,
    num_layers,
    kernel_sizes, strides,
    in_channels, num_filters,
    pool_sizes, pool_strides,
    dropout_p_conv,
    dropout_p_linear=None,
  ):
    super().__init__()

    self.num_layers = num_layers
    self.kernel_sizes = kernel_sizes
    self.strides = strides
    self.in_channels = in_channels
    self.num_filters = num_filters
    self.pool_sizes = pool_sizes
    self.pool_strides = pool_strides

    self.dropout_p_conv = dropout_p_conv
    # Must be assigned so that get_model_setup() can report it.
    self.dropout_p_linear = dropout_p_linear

    self.conv_layers = nn.Sequential()
    # Not populated or used inside this class; kept so that code reading or
    # extending ``fc_layers`` keeps working.
    self.fc_layers = nn.Sequential()

    # Input channels of the conv being built: starts at the model's input
    # channels, then becomes the previous conv's out_channels.
    block_in_channels = in_channels
    for i in range(num_layers):
      conv_layer = nn.Conv2d(
        kernel_size=self.kernel_sizes[i],
        stride=self.strides[i],
        in_channels=block_in_channels,
        out_channels=self.num_filters[i],
      )
      torch.nn.init.xavier_uniform_(conv_layer.weight)

      pooling_layer = nn.MaxPool2d(
        kernel_size=self.pool_sizes[i],
        stride=self.pool_strides[i],
      )

      block_in_channels = self.num_filters[i]

      # Submodule names become state_dict key prefixes, so they stay stable.
      self.conv_layers.add_module(name=f"conv_{i}", module=conv_layer)
      self.conv_layers.add_module(name=f"pool_{i}", module=pooling_layer)
      self.conv_layers.add_module(
        name=f"batchnorm_{i}",
        module=nn.BatchNorm2d(num_features=self.num_filters[i]),
      )
      self.conv_layers.add_module(name=f"activ_{i}", module=nn.ReLU())
      self.conv_layers.add_module(
        name=f"dropout_{i}", module=nn.Dropout(p=self.dropout_p_conv)
      )

  def forward(self, x):
    """Run ``x`` through the conv stack and return the un-flattened output.

    For an ``(N, C, H, W)`` input the result is ``(N, C', H', W')``, where
    ``C' = num_filters[-1]`` (when ``num_layers > 0``).
    """
    x = self.conv_layers(x)

    return x

  def compute_fc_in_features(self, height, width):
    """Return the flattened per-example size of the conv stack's output.

    This is the ``in_features`` a linear layer placed after the conv stack
    needs for inputs of shape ``(N, self.in_channels, height, width)``.

    A single all-zeros example is run through the conv stack in eval mode
    and without autograd. In eval mode BatchNorm does not update its running
    statistics and Dropout is a no-op. The previous train/eval mode is
    restored afterwards.

    Args:
      height: input height (dimension 2 of the input tensor).
      width: input width (dimension 3 of the input tensor).

    Returns:
      ``C' * H' * W'`` of the conv-stack output, as an int.
    """
    # Match the parameters' device/dtype so the probe works on GPU or
    # half-precision models; with no layers, fall back to torch defaults.
    param = next(self.parameters(), None)
    device = param.device if param is not None else None
    dtype = param.dtype if param is not None else None
    dummy = torch.zeros(
      1, self.in_channels, height, width, device=device, dtype=dtype
    )

    was_training = self.training
    self.eval()
    try:
      with torch.no_grad():
        out = self.conv_layers(dummy)
    finally:
      self.train(was_training)

    return out[0].numel()

  def get_model_setup(self):
    """Return the constructor hyper-parameters as a plain dict."""
    return {
      "num_layers": self.num_layers,
      "kernel_sizes": self.kernel_sizes,
      "strides": self.strides,
      "in_channels": self.in_channels,
      "num_filters": self.num_filters,
      "pool_sizes": self.pool_sizes,
      "pool_strides": self.pool_strides,
      "dropout_p_conv": self.dropout_p_conv,
      "dropout_p_linear": self.dropout_p_linear,
    }

In [13]:
# Per-layer hyper-parameters: entry i configures conv/pool block i.
k_fold_cv_kernel_sizes = [3] * 3
k_fold_cv_pool_sizes = [3] * 3
k_fold_cv_strides = [2, 2, 1]
k_fold_cv_pool_strides = [2, 2, 1]
k_fold_cv_num_filters = [1, 2, 4]

# One conv block per entry in the filter list.
k_fold_cv_num_layers = len(k_fold_cv_num_filters)

In [14]:
# Single-channel input, dropout disabled: this instance is only used to
# inspect the conv stack's output shape in the following cells.
cnn = CNN(
  num_layers=k_fold_cv_num_layers,
  in_channels=1,
  num_filters=k_fold_cv_num_filters,
  kernel_sizes=k_fold_cv_kernel_sizes,
  strides=k_fold_cv_strides,
  pool_sizes=k_fold_cv_pool_sizes,
  pool_strides=k_fold_cv_pool_strides,
  dropout_p_conv=0.0,
)

In [15]:
# Shape probe: only out.shape is used below, so run without autograd to
# avoid building and keeping a graph for this large random batch.
# Input layout is (batch, channels, height, width); presumably 128 mel bins
# x 1860 frames -- confirm against the data pipeline.
with torch.no_grad():
  out = cnn(torch.rand((64, 1, 128, 1860)))

In [16]:
# Factors of the linear layer's in_features: channels * height * width.
print(" * ".join(str(out.shape[dim]) for dim in (1, 2, 3)))

4 * 3 * 111
