In [39]:
from __future__ import print_function
import argparse
import shutil

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from torchvision import models

import numpy as np
from numpy import mean, sqrt, square, arange
import pandas as pd

import logging
import csv

import os
from os.path import exists
import datetime
import time

import shutil


from wfdb import processing


In [16]:
class Sequence(nn.Module):
    """Five stacked ``LSTMCell`` layers followed by a scalar linear read-out.

    The input is consumed one time step at a time: each step goes through
    ``lstm1`` -> ``lstm5``, and the top hidden state is projected to one value
    by ``linear``. Per-step outputs are concatenated along dim 1.

    The attribute names ``lstm1`` ... ``lstm5`` and ``linear`` are the
    state-dict keys of saved checkpoints, so they must not be renamed.
    """

    def __init__(self, hidden_sizes=(256, 128, 64, 32, 16)):
        """Build the network.

        Args:
            hidden_sizes: Hidden width of each of the five stacked cells, from
                ``lstm1`` to ``lstm5``. The default matches the original
                architecture (and therefore existing checkpoints).

        Raises:
            ValueError: If ``hidden_sizes`` does not have exactly five entries.
        """
        super(Sequence, self).__init__()
        hidden_sizes = tuple(hidden_sizes)
        if len(hidden_sizes) != 5:
            raise ValueError("hidden_sizes must have exactly 5 entries (lstm1..lstm5)")

        # One feature per time step: forward() feeds split(1, dim=1) slices.
        self.input_size = 1

        # Kept as individual attributes for backward compatibility with code
        # that reads e.g. seq.hidden_layers1 (these are widths, not layer counts).
        (self.hidden_layers1, self.hidden_layers2, self.hidden_layers3,
         self.hidden_layers4, self.hidden_layers5) = hidden_sizes

        self.lstm1 = nn.LSTMCell(self.input_size, self.hidden_layers1)
        self.lstm2 = nn.LSTMCell(self.hidden_layers1, self.hidden_layers2)
        self.lstm3 = nn.LSTMCell(self.hidden_layers2, self.hidden_layers3)
        self.lstm4 = nn.LSTMCell(self.hidden_layers3, self.hidden_layers4)
        self.lstm5 = nn.LSTMCell(self.hidden_layers4, self.hidden_layers5)

        self.linear = nn.Linear(self.hidden_layers5, 1)

    def forward(self, inputData):
        """Run the LSTM stack over every time step of ``inputData``.

        Args:
            inputData: Tensor that is split along dim 1 into one-column slices,
                each fed to ``lstm1``. With ``input_size == 1`` this presumably
                means shape (batch, seq_len) -- TODO confirm against callers.
                Its dtype and device must match the module's parameters.

        Returns:
            The per-step ``linear`` outputs (each (batch, 1)) concatenated
            along dim 1.
        """
        batch = inputData.size(0)
        cells = (self.lstm1, self.lstm2, self.lstm3, self.lstm4, self.lstm5)

        # Zero initial (h, c) for every cell. new_zeros inherits the input's
        # dtype and device; a hard-coded dtype/device here would break float32
        # or GPU inputs with a mismatch inside LSTMCell.
        states = [
            (inputData.new_zeros(batch, cell.hidden_size),
             inputData.new_zeros(batch, cell.hidden_size))
            for cell in cells
        ]

        outputs = []
        for input_t in inputData.split(1, dim=1):
            layer_in = input_t
            for i, cell in enumerate(cells):
                # Each cell consumes the hidden state of the cell below it.
                states[i] = cell(layer_in, states[i])
                layer_in = states[i][0]
            outputs.append(self.linear(layer_in))

        return torch.cat(outputs, dim=1)

In [17]:
# Checkpoint to evaluate, relative to this notebook's working directory.
model_path = "/".join(("..", "..", "LSTM", "Losses", "SPECGmodelNEO_i100.pt"))

In [20]:
# Convert to float64 *before* loading: load_state_dict copies values into the
# existing parameters' dtype, so a default float32 module would silently
# downcast the checkpoint. The checkpoint is presumably float64 (the model was
# trained in double precision) -- TODO confirm. Feed float64 inputs to `seq`.
seq = Sequence().double()
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (newer torch versions accept weights_only=True for plain state dicts).
seq.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
seq.eval()

Sequence(
  (lstm1): LSTMCell(1, 256)
  (lstm2): LSTMCell(256, 128)
  (lstm3): LSTMCell(128, 64)
  (lstm4): LSTMCell(64, 32)
  (lstm5): LSTMCell(32, 16)
  (linear): Linear(in_features=16, out_features=1, bias=True)
)

In [58]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameter counts and return the total.

    Only parameters with ``requires_grad=True`` are listed, in
    ``model.named_parameters()`` order.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        int: Sum of ``numel()`` over the listed parameters.
    """
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])

    total_parameters = sum(count for _, count in trainable)
    print(table)
    print(f"Total Trainable Parameters: {total_parameters}")
    return total_parameters

In [59]:
# Per-tensor breakdown of trainable parameters. The trailing ';' suppresses the
# Out[] echo of the return value, which count_parameters already prints.
count_parameters(seq);

+-----------------+------------+
|     Modules     | Parameters |
+-----------------+------------+
| lstm1.weight_ih |    1024    |
| lstm1.weight_hh |   262144   |
|  lstm1.bias_ih  |    1024    |
|  lstm1.bias_hh  |    1024    |
| lstm2.weight_ih |   131072   |
| lstm2.weight_hh |   65536    |
|  lstm2.bias_ih  |    512     |
|  lstm2.bias_hh  |    512     |
| lstm3.weight_ih |   32768    |
| lstm3.weight_hh |   16384    |
|  lstm3.bias_ih  |    256     |
|  lstm3.bias_hh  |    256     |
| lstm4.weight_ih |    8192    |
| lstm4.weight_hh |    4096    |
|  lstm4.bias_ih  |    128     |
|  lstm4.bias_hh  |    128     |
| lstm5.weight_ih |    2048    |
| lstm5.weight_hh |    1024    |
|  lstm5.bias_ih  |     64     |
|  lstm5.bias_hh  |     64     |
|  linear.weight  |     16     |
|   linear.bias   |     1      |
+-----------------+------------+
Total Trainable Parameters: 528273


528273