# RNN Encoder with Attention Decoder
The encoder and decoder were adapted from PyTorch's [sequence-to-sequence translation tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html).

In [1]:
import pickle
import os
import sys
import torch
import pandas as pd
from sklearn.metrics import f1_score, jaccard_score, confusion_matrix, precision_score, recall_score, accuracy_score

from models.m03_EcgToText_RNN.dataset import *
from models.m03_EcgToText_RNN.model import *
from models.m03_EcgToText_RNN.train import *

In [2]:
# Move one directory up (presumably from the notebook folder to the project root)
# so the relative paths used below ('./data_ptb-xl', './models/...') resolve.
# Guarded so that re-running this cell does not keep climbing the directory tree.
if not os.path.isdir('./data_ptb-xl'):
    os.chdir('..')

## Setup Model

In [3]:
# Seed torch's RNG before building the dataloaders and models so that weight
# initialisation (and any torch-driven shuffling) is reproducible.
torch.manual_seed(42)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The training call returns the vocabulary object (`language`) alongside the loader;
# the validation call is handed that same object via `_lang` and its own returned
# vocabulary is discarded -- presumably so both splits share token ids (TODO confirm
# in dataset.py).
language, dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='train', device=device)
_, val_dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='val', device=device, _lang=language)

# Training hyper-parameters; hidden_size is shared by the encoder and the decoder.
n_epochs=50
hidden_size = 256
# NLLLoss expects log-probabilities, so this assumes the decoder emits log_softmax
# outputs -- TODO confirm in model.py. (`nn` is presumably re-exported by the star imports.)
criterion = nn.NLLLoss()

# Encoder over 12 input leads; its hidden size must match the decoder's
# encoder_hidden_size below.
encoder = EncoderRNN(num_leads=12,
                     hidden_size=hidden_size).to(device)
# Decoder output vocabulary size and maximum generated length are taken from the
# training-split vocabulary object.
decoder = AttnDecoderRNN(hidden_size=hidden_size,
                         encoder_hidden_size=hidden_size,
                         output_size=language.n_words,
                         max_len=language.max_len).to(device)

In [4]:
def count_parameters(model):
    """Count the scalar parameters of ``model`` that will receive gradients.

    Args:
        model: an object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        int: the sum of ``numel()`` over every parameter with ``requires_grad=True``
        (0 when there are none).
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen weights are not counted
        total += param.numel()
    return total

In [5]:
# Report how many trainable parameters the encoder has, then show its module layout.
n_param = count_parameters(model=encoder)
print("Number of parameters in Encoder: " + str(n_param))
encoder

Number of parameters in Encoder: 242280


EncoderRNN(
  (conv1): Conv1d(12, 24, kernel_size=(5,), stride=(1,), padding=(2,))
  (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(24, 48, kernel_size=(5,), stride=(1,), padding=(2,))
  (pool2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout(p=0.1, inplace=False)
  (gru): GRU(48, 256, batch_first=True)
)

In [6]:
# Report how many trainable parameters the decoder has, then show its module layout.
n_param = count_parameters(model=decoder)
print("Number of parameters in Decoder: " + str(n_param))
decoder

Number of parameters in Decoder: 2218724


AttnDecoderRNN(
  (hidden_transform): Linear(in_features=256, out_features=256, bias=True)
  (embedding): Embedding(2787, 256)
  (attention): BahdanauAttention(
    (Wa): Linear(in_features=256, out_features=256, bias=True)
    (Ua): Linear(in_features=256, out_features=256, bias=True)
    (Va): Linear(in_features=256, out_features=1, bias=True)
  )
  (gru): GRU(512, 256, batch_first=True)
  (out): Linear(in_features=256, out_features=2787, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

## Train

In [7]:
# Re-seed immediately before training so the run is reproducible regardless of
# how many random draws the setup cells consumed.
torch.manual_seed(42)

# `train` comes from the star-imported project modules (presumably train.py). Per the
# captured output it logs train loss and validation METEOR once per epoch.
# NOTE(review): the checkpoint-loading cell further down re-creates encoder/decoder
# and loads saved weights, replacing whatever this cell trained.
train(dataloader, val_dataloader, encoder, decoder, criterion, language, n_epochs)

1m 4s (- 52m 46s) (1 2.0%) | Train Loss: 0.5825 | Val METEOR: 0.3122
2m 6s (- 50m 33s) (2 4.0%) | Train Loss: 0.266 | Val METEOR: 0.3364
3m 6s (- 48m 46s) (3 6.0%) | Train Loss: 0.2304 | Val METEOR: 0.3507
4m 10s (- 47m 55s) (4 8.0%) | Train Loss: 0.2077 | Val METEOR: 0.3676
5m 13s (- 47m 2s) (5 10.0%) | Train Loss: 0.1883 | Val METEOR: 0.3728
6m 14s (- 45m 47s) (6 12.0%) | Train Loss: 0.1719 | Val METEOR: 0.4039
7m 17s (- 44m 45s) (7 14.0%) | Train Loss: 0.1586 | Val METEOR: 0.4032
8m 19s (- 43m 41s) (8 16.0%) | Train Loss: 0.1484 | Val METEOR: 0.443
9m 20s (- 42m 33s) (9 18.0%) | Train Loss: 0.1393 | Val METEOR: 0.4168
10m 24s (- 41m 36s) (10 20.0%) | Train Loss: 0.1315 | Val METEOR: 0.4289
11m 26s (- 40m 33s) (11 22.0%) | Train Loss: 0.1255 | Val METEOR: 0.4491
12m 27s (- 39m 25s) (12 24.0%) | Train Loss: 0.1191 | Val METEOR: 0.4282
13m 29s (- 38m 25s) (13 26.0%) | Train Loss: 0.1141 | Val METEOR: 0.4398
14m 32s (- 37m 23s) (14 28.0%) | Train Loss: 0.109 | Val METEOR: 0.4538
15m 33s

In [8]:
# Save the trained weights. This is off by default, so a "Restart & Run All" does
# not overwrite the checkpoint that the loading cell reads back in.
SAVE_MODELS = False

if SAVE_MODELS:
    save_dir = './models/m03_EcgToText_RNN/saved_models'
    # torch.save does not create missing parent directories.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(encoder.state_dict(), os.path.join(save_dir, 'EncoderRNN.pth'))
    torch.save(decoder.state_dict(), os.path.join(save_dir, 'DecoderRNN.pth'))

In [9]:
# load parameters: rebuild the models with the same hyper-parameters used in the
# setup cell, then load the saved checkpoint into them (this replaces any weights
# trained earlier in the session).
hidden_size = 256

encoder = EncoderRNN(num_leads=12,
                     hidden_size=hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size=hidden_size,
                         encoder_hidden_size=hidden_size,
                         output_size=language.n_words,
                         max_len=language.max_len).to(device)

# map_location remaps stored tensors onto `device`, so a checkpoint written on a GPU
# machine also loads on a CPU-only one (and vice versa).
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
encoder.load_state_dict(torch.load('./models/m03_EcgToText_RNN/saved_models/EncoderRNN.pth', map_location=device))
decoder.load_state_dict(torch.load('./models/m03_EcgToText_RNN/saved_models/DecoderRNN.pth', map_location=device))

<All keys matched successfully>

## Test

In [10]:
_, test_dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='test', device=device, _lang=language)

# Both models contain Dropout layers. Switch them to eval mode so dropout is disabled
# while scoring the test split. eval() is idempotent, so this is harmless if
# validate_epoch already switches modes internally.
encoder.eval()
decoder.eval()

total_loss, f1, jaccard, rouge, meteor = validate_epoch(test_dataloader, encoder, decoder, criterion, language)

# Every label is padded to 14 characters so the values line up in one column.
print(f'Test Loss:    {round(total_loss, 5)}')
print(f'F1:           {round(f1, 5)}')
print(f'Jaccard:      {round(jaccard, 5)}')
for variant in ('1', '2', 'L'):
    scores = rouge[f'rouge-{variant.lower()}']
    for key, label in (('p', 'p'), ('r', 'r'), ('f', 'f1')):
        print(f'Rouge-{variant} ({label}):'.ljust(14) + str(round(scores[key], 3)))
print(f'METEOR:       {round(meteor, 3)}')

Test Loss:    3.18871
F1:           0.43462
Jaccard:      0.32197
Rouge-1 (p):  0.566
Rouge-1 (r):  0.606
Rouge-1 (f1): 0.561
Rouge-2 (p):  0.424
Rouge-2 (r):  0.451
Rouge-2 (f1): 0.416
Rouge-L (p):  0.562
Rouge-L (r):  0.602
Rouge-L (f1): 0.556
METEOR:       0.464


In [11]:
# Put both networks into inference mode (disables dropout), then print a few
# reference ("=") vs. generated ("<") reports from the test split.
for net in (encoder, decoder):
    net.eval()
print_first_n_target_prediction(test_dataloader, encoder, decoder, language)

= sinus rhythm. incomplete right bundle branch block.
< sinus rhythm. normal ecg.

= sinus rhythm. normal ecg.
< sinus rhythm position type normal normal ecg

= sinus rhythm. normal ecg.
< sinus rhythm position type normal incomplete right bundle branch block otherwise normal ecg

= sinus rhythm. q waves in ii, iii, avf consistent with old inferior    myocardial infarction. otherwise no definite pathology.
< sinus rhythm. q waves in ii, iii, avf consistent with old inferior    myocardial infarction.

= sinus rhythm. normal ecg.
< sinus rhythm position type normal normal ecg

= sinus rhythm qrs(t) abnormal inferior infarction probably old
< sinus rhythm position type normal normal ecg

= sinus rhythm av block i left type chronic inferior infarction possible intraventricular conduction disturbance path. ecg
< sinus rhythm left type qrs(t) abnormal anteroseptal myocardial damage cannot be excluded inferior infarction probable old 4.46 unconfirmed report

= sinus rhythm. q waves in ii, iii