In [1]:
import warnings

import torch
from torch import nn

# NOTE(review): this silences *every* warning for the rest of the kernel
# session (including PyTorch deprecation/UserWarnings). Consider a targeted
# `warnings.filterwarnings("ignore", category=..., message=...)` instead.
warnings.filterwarnings('ignore')

In [4]:
class ConvBlock(nn.Module):
    """
    Normal convolution block: two stacked convolutions, each followed by a
    ReLU and, optionally, BatchNorm.

    Both convolutions use a ``(filter_width, 1)`` kernel, ``(dilation, 1)``
    dilation and PyTorch's default zero padding. Each one therefore shortens
    dimension 2 of a 4-D input by ``dilation * (filter_width - 1)`` and leaves
    dimension 3 unchanged.

    Args:
        filter_width: kernel size along dimension 2.
        input_filters: number of channels of the incoming tensor.
        nb_filters: number of output channels of both convolutions.
        dilation: dilation along dimension 2.
        batch_norm: if truthy, apply ``BatchNorm2d`` after each ReLU.
    """
    def __init__(self, filter_width, input_filters, nb_filters, dilation, batch_norm):
        super().__init__()
        # keep the hyper-parameters around as public attributes
        self.filter_width = filter_width
        self.input_filters = input_filters
        self.nb_filters = nb_filters
        self.dilation = dilation
        self.batch_norm = batch_norm

        kernel = (filter_width, 1)
        dil = (dilation, 1)
        # registration order (conv1, relu, conv2, norm1, norm2) is kept stable
        self.conv1 = nn.Conv2d(input_filters, nb_filters, kernel, dilation=dil)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(nb_filters, nb_filters, kernel, dilation=dil)
        if batch_norm:
            self.norm1 = nn.BatchNorm2d(nb_filters)
            self.norm2 = nn.BatchNorm2d(nb_filters)

    def forward(self, x):
        """Apply conv -> ReLU -> (BatchNorm) twice and return the result."""
        for conv_name, norm_name in (("conv1", "norm1"), ("conv2", "norm2")):
            x = self.relu(getattr(self, conv_name)(x))
            if self.batch_norm:
                x = getattr(self, norm_name)(x)
        return x

In [7]:
class DeepConvLSTM(nn.Module):
    """
    DeepConvLSTM classifier for fixed-size windows of multichannel time series.

    The network has three stages:
    - A stack of ``ConvBlock``s convolves along the time axis only (kernels of
      shape ``(filter_width, 1)``).
    - The resulting feature maps are flattened per time step and passed through
      a stack of single-layer LSTMs.
    - A linear layer classifies the LSTM output of the last time step.

    Every argument is optional and defaults to the value this notebook
    originally hard-coded. ``DeepConvLSTM()`` therefore builds the same network
    as before, with 753,990 trainable parameters.

    Args:
        window_size: time steps per input window.
        nb_channels: sensor channels per time step.
        nb_classes: number of output classes.
        nb_conv_blocks: number of ConvBlocks (two convolutions each), >= 1.
        nb_filters: output filters of every convolution.
        filter_width: temporal kernel size of every convolution.
        dilation: temporal dilation of every convolution.
        batch_norm: passed through to every ConvBlock.
        nb_units_lstm: hidden size of every LSTM layer.
        nb_layers_lstm: number of stacked LSTM layers, >= 1.
        drop_prob: dropout probability applied before the classifier.

    Raises:
        ValueError: if a layer count is < 1, or if the unpadded convolutions
            would shrink the window to zero or fewer time steps.
    """

    def __init__(self, window_size=100, nb_channels=9, nb_classes=6,
                 nb_conv_blocks=2, nb_filters=64, filter_width=21, dilation=1,
                 batch_norm=False, nb_units_lstm=128, nb_layers_lstm=2,
                 drop_prob=0.5):

        super(DeepConvLSTM, self).__init__()

        # NOTE(review): these settings are stored but never read by this class.
        # No pooling, reduce layer, alternative conv block, fixup or custom
        # weight initialisation is implemented, and `seed` is never applied.
        # They are kept only so existing attribute access keeps working.
        self.pooling               = False
        self.reduce_layer          = False
        self.reduce_layer_output   = 8
        self.pool_type             = 'max'
        self.pool_kernel_width     = 2
        self.weights_init          = "xavier_normal"
        self.seed                  = 1
        self.conv_block_type       = "normal"
        self.use_fixup             = False

        # data / classifier settings
        self.window_size           = window_size
        self.nb_channels           = nb_channels
        self.nb_classes            = nb_classes
        self.drop_prob             = drop_prob
        # convolution settings
        self.nb_filters            = nb_filters
        self.nb_conv_blocks        = nb_conv_blocks
        self.filter_width          = filter_width
        self.dilation              = dilation
        self.batch_norm            = batch_norm
        # lstm settings
        self.nb_units_lstm         = nb_units_lstm
        self.nb_layers_lstm        = nb_layers_lstm

        if self.nb_conv_blocks < 1 or self.nb_layers_lstm < 1:
            raise ValueError("nb_conv_blocks and nb_layers_lstm must both be >= 1")

        # Each unpadded convolution shortens the time axis by
        # dilation * (filter_width - 1), and every ConvBlock holds two of them.
        self.final_seq_len = self.window_size - self.dilation * (self.filter_width - 1) * (self.nb_conv_blocks * 2)
        if self.final_seq_len <= 0:
            raise ValueError(
                f"window_size={self.window_size} is too short for {self.nb_conv_blocks} conv blocks "
                f"with filter_width={self.filter_width} and dilation={self.dilation}"
            )

        # Define conv layers. Only the first block sees the single input plane.
        # Module creation order (convs, LSTMs, dropout, fc) matches the
        # original, so a given seed yields the same initial weights.
        self.conv_blocks = nn.ModuleList(
            ConvBlock(self.filter_width, 1 if i == 0 else self.nb_filters, self.nb_filters,
                      self.dilation, self.batch_norm)
            for i in range(self.nb_conv_blocks)
        )

        # Define LSTM layers. batch_first=True is required because forward()
        # feeds (batch, seq, features). Without it, nn.LSTM treats dim 0 as
        # time and carries its hidden state from one window to the next.
        self.lstm_layers = nn.ModuleList(
            nn.LSTM(self.nb_channels * self.nb_filters if i == 0 else self.nb_units_lstm,
                    self.nb_units_lstm, batch_first=True)
            for i in range(self.nb_layers_lstm)
        )
        # define dropout layer
        self.dropout = nn.Dropout(self.drop_prob)
        # define classifier
        self.fc = nn.Linear(self.nb_units_lstm, self.nb_classes)

    def forward(self, x):
        """
        Classify a batch of windows.

        Args:
            x: tensor whose elements can be reshaped to
                ``(batch, window_size, nb_channels)``.

        Returns:
            Tensor of shape ``(batch, nb_classes)`` holding the linear layer's
            raw outputs for the last time step of each window.
        """
        # (batch, 1, time, channels): the (filter_width, 1) kernels slide over
        # time only, so channels are never mixed by the convolutions.
        x = x.reshape(-1, 1, self.window_size, self.nb_channels)
        for conv_block in self.conv_blocks:
            x = conv_block(x)

        # (batch, filters, time', channels) -> (batch, time', filters * channels)
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(x.size(0), self.final_seq_len, self.nb_filters * self.nb_channels)

        for lstm_layer in self.lstm_layers:
            x, _ = lstm_layer(x)

        # The last time step determines the window's label. Dropout and the
        # linear layer act row-wise, so applying them to that step alone gives
        # the same result (in eval mode) as classifying every step and keeping
        # the last. It also avoids .view() on the non-contiguous
        # batch_first LSTM output.
        x = x[:, -1, :]
        x = self.dropout(x)
        return self.fc(x)

    def number_of_parameters(self):
        """Return the number of trainable (requires_grad) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [8]:
# Seed PyTorch's global RNG so the randomly initialised weights are identical on
# every Restart & Run All (DeepConvLSTM stores `seed = 1` but never applies it).
torch.manual_seed(1)
model = DeepConvLSTM()

In [10]:
# Trainable parameter count of the default configuration: 753,990 =
#   conv blocks 1,408 + 3 x 86,080 = 259,648
#   + LSTM layers 361,472 + 132,096
#   + linear classifier 774.
model.number_of_parameters()

753990