# Model Exploration

This Jupyter notebook is for validating candidate neural network architecture designs.

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from datasets import PhysioNet2020Dataset


In [2]:
# Small exploration dataset built from four training records. The trailing
# numbers are the per-record lengths noted by the author, presumably in
# samples. NOTE(review): the printed output of the last cell shows 5 items
# from these 4 records — presumably records longer than max_seq_len are
# split into chunks; confirm against PhysioNet2020Dataset.
ds = PhysioNet2020Dataset(
    "Training_WFDB",
    records=("A0001", "A0002", "A0003", "A0004"),  # (7500, 5000, 5000, 5974)
    max_seq_len=6000,
    ensure_equal_len=False,
    derive_fft=True,
    proc=0,
)

# Batches are assembled by the dataset's own collate function; num_workers=0
# keeps loading in the main process, which is convenient for debugging.
dl = DataLoader(
    dataset=ds,
    collate_fn=PhysioNet2020Dataset.collate_fn,
    batch_size=8,
    num_workers=0,
)


In [3]:
class ExplorationModel(nn.Module):
    """Two-branch LSTM classifier over a signal and its FFT.

    ``batch["signal"]`` is encoded by ``lstm_sig`` and ``batch["fft"]`` by a
    separate ``lstm_fft``. Each branch produces the last layer's final hidden
    state; when bidirectional, that state is the forward and backward halves
    concatenated. The two branch states are joined and passed through
    BatchNorm -> LeakyReLU -> Dropout -> Linear, producing ``num_classes``
    logits per sequence.

    Args:
        in_channels: Features per time step; both branches expect the same
            width.
        num_classes: Number of output logits.
        num_layers: Stacked LSTM layers in each branch.
        dropout: Dropout probability between LSTM layers and in the
            classifier head.
        hidden_size: Hidden units per LSTM direction.
        bidirectional: Run both LSTMs in both directions.
    """

    def __init__(
        self,
        in_channels=12,
        num_classes=9,
        num_layers=2,
        dropout=0.1,
        hidden_size=200,
        bidirectional=True
    ):
        super().__init__()

        self.bidirectional = bidirectional
        self.lstm_sig = nn.LSTM(
            input_size=in_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional
        )
        self.lstm_fft = nn.LSTM(
            input_size=in_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional
        )

        # Classifier input width: two branches (signal + fft), each
        # contributing hidden_size per direction.
        num_directions = 2 if bidirectional else 1
        lstm_hidden_size = 2 * num_directions * hidden_size

        self.classify = nn.Sequential(
            nn.BatchNorm1d(lstm_hidden_size),
            nn.LeakyReLU(),
            nn.Dropout(dropout, inplace=True),
            nn.Linear(lstm_hidden_size, num_classes)
        )

    def forward(self, batch):
        """Return class logits for a padded batch.

        Args:
            batch: Mapping with ``"signal"``, ``"fft"`` and ``"len"`` keys.
                The tensors are packed with PyTorch's default
                ``batch_first=False``, i.e. read as
                (seq_len, batch, in_channels). ``"len"`` holds the unpadded
                signal length of each example.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        sig_lens = batch["len"]
        # NOTE(review): assumes each FFT sequence has n // 2 steps for a
        # signal of length n — confirm against the dataset's FFT derivation.
        fft_lens = tuple(n // 2 for n in sig_lens)

        sig_hidden = self._encode(self.lstm_sig, batch["signal"], sig_lens)
        # The FFT branch has its own weights: encode with lstm_fft, not
        # lstm_sig.
        fft_hidden = self._encode(self.lstm_fft, batch["fft"], fft_lens)

        hidden = torch.cat((sig_hidden, fft_hidden), dim=1)
        return self.classify(hidden)

    def _encode(self, lstm, seq, lengths):
        """Run ``lstm`` over padded ``seq`` and return its final hidden state.

        ``enforce_sorted=False`` accepts examples in any length order; the
        returned hidden state is in the original batch order. The result is
        the last layer's state, shaped (batch, num_directions * hidden_size).
        """
        packed = pack_padded_sequence(seq, lengths, enforce_sorted=False)
        _, (h_n, _) = lstm(packed)
        if self.bidirectional:
            # h_n is (num_layers * num_directions, batch, hidden_size); the
            # last two rows are the top layer's forward and backward states.
            return torch.cat((h_n[-2], h_n[-1]), dim=1)
        return h_n[-1]

In [4]:
# Smoke test: push one batch through the bidirectional (default) and the
# unidirectional variant and print the logits' shape, expected to be
# (batch, num_classes). The generator builds each model lazily, just before
# its batch is drawn.
for m in (ExplorationModel(bidirectional=flag) for flag in (True, False)):
    batch = next(iter(dl))
    out = m(batch)
    print(out.shape)

torch.Size([5, 9])
torch.Size([5, 9])
