# RNN Encoder with Attention Decoder
Encoder and Decoder was taken from the [sequence to sequence tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) from Pytorch.

In [1]:
import pickle
import os
import sys
import torch
import pandas as pd
from sklearn.metrics import f1_score, jaccard_score, confusion_matrix, precision_score, recall_score, accuracy_score

from models.m03_EcgToText_RNN.dataset import *
from models.m03_EcgToText_RNN.model import *
from models.m03_EcgToText_RNN.train import *

In [2]:
os.chdir('..')

## Simple RNN Encoder

### Train

In [3]:
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

language, dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='train', device=device)
_, val_dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='val', device=device, _lang=language)

n_epochs=50
hidden_size = 256
criterion = nn.NLLLoss()

encoder = SimpleEncoderRNN(input_size=12, hidden_size=hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size=hidden_size, output_size=language.n_words, max_len=language.max_len).to(device)

encoder, decoder

(SimpleEncoderRNN(
   (gru): GRU(12, 256, batch_first=True)
   (dropout): Dropout(p=0.1, inplace=False)
 ),
 AttnDecoderRNN(
   (embedding): Embedding(2787, 256)
   (attention): BahdanauAttention(
     (Wa): Linear(in_features=256, out_features=256, bias=True)
     (Ua): Linear(in_features=256, out_features=256, bias=True)
     (Va): Linear(in_features=256, out_features=1, bias=True)
   )
   (gru): GRU(512, 256, batch_first=True)
   (out): Linear(in_features=256, out_features=2787, bias=True)
   (dropout): Dropout(p=0.1, inplace=False)
 ))

In [4]:
torch.manual_seed(42)

train(dataloader, val_dataloader, encoder, decoder, criterion, language, n_epochs)

2m 30s (- 122m 43s) (1 2.0%) | Train Loss: 0.6629 | Val Loss: 2.0492 | Jaccard: 0.1379 | F1: 0.1698 | 
	Rouge-1 (p): 0.4138 | Rouge-1 (r): 0.5537 | Rouge-1 (f1): 0.4404
	Rouge-2 (p): 0.2808 | Rouge-2 (r): 0.3325 | Rouge-2 (f1): 0.2805
	Rouge-L (p): 0.4133 | Rouge-L (r): 0.5528 | Rouge-L (f1): 0.4398
	METEOR: 0.3039
5m 0s (- 120m 3s) (2 4.0%) | Train Loss: 0.2866 | Val Loss: 2.3245 | Jaccard: 0.1735 | F1: 0.2245 | 
	Rouge-1 (p): 0.4599 | Rouge-1 (r): 0.3804 | Rouge-1 (f1): 0.3882
	Rouge-2 (p): 0.3212 | Rouge-2 (r): 0.2418 | Rouge-2 (f1): 0.2549
	Rouge-L (p): 0.4584 | Rouge-L (r): 0.3784 | Rouge-L (f1): 0.3866
	METEOR: 0.333
7m 28s (- 116m 59s) (3 6.0%) | Train Loss: 0.2448 | Val Loss: 2.3634 | Jaccard: 0.1887 | F1: 0.2445 | 
	Rouge-1 (p): 0.4644 | Rouge-1 (r): 0.496 | Rouge-1 (f1): 0.4481
	Rouge-2 (p): 0.3323 | Rouge-2 (r): 0.3322 | Rouge-2 (f1): 0.3077
	Rouge-L (p): 0.4629 | Rouge-L (r): 0.4938 | Rouge-L (f1): 0.4464
	METEOR: 0.3527
9m 56s (- 114m 17s) (4 8.0%) | Train Loss: 0.2185 | V

In [5]:
# save parameters
# torch.save(encoder.state_dict(), './models/m03_EcgToText_RNN/saved_models/SimpleEncoderRNN.pth')
# torch.save(decoder.state_dict(), './models/m03_EcgToText_RNN/saved_models/SimpleDecoderRNN.pth')

In [7]:
# load parameters
hidden_size = 256

encoder = SimpleEncoderRNN(input_size=12, hidden_size=hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size=hidden_size, output_size=language.n_words, max_len=language.max_len).to(device)

encoder.load_state_dict(torch.load('./models/m03_EcgToText_RNN/saved_models/SimpleEncoderRNN.pth'))
decoder.load_state_dict(torch.load('./models/m03_EcgToText_RNN/saved_models/SimpleDecoderRNN.pth'))

<All keys matched successfully>

### Test

In [8]:
_, test_dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='test', device=device, _lang=language)

total_loss, f1, jaccard, rouge, meteor = validate_epoch(dataloader, encoder, decoder, criterion, language)

print(f'Test Loss:    {round(total_loss, 5)}')
print(f'F1:           {round(f1, 5)}')
print(f'Jaccard:      {round(jaccard, 5)}')
print(f'Rouge-1 (p):  {round(rouge["rouge-1"]["p"], 3)}')
print(f'Rouge-1 (r):  {round(rouge["rouge-1"]["r"], 3)}')
print(f'Rouge-1 (f1): {round(rouge["rouge-1"]["f"], 3)}')
print(f'Rouge-2 (p):  {round(rouge["rouge-2"]["p"], 3)}')
print(f'Rouge-2 (r):  {round(rouge["rouge-2"]["r"], 3)}')
print(f'Rouge-2 (f1): {round(rouge["rouge-2"]["f"], 3)}')
print(f'Rouge-L (p):  {round(rouge["rouge-l"]["p"], 3)}')
print(f'Rouge-L (r):  {round(rouge["rouge-l"]["r"], 3)}')
print(f'Rouge-L (f1): {round(rouge["rouge-l"]["f"], 3)}')
print(f'METEOR:       {round(meteor, 4)}')

Test Loss:    2.37621
F1:           0.57021
Jaccard:      0.45567
Rouge-1 (p):  0.677
Rouge-1 (r):  0.718
Rouge-1 (f1): 0.677
Rouge-2 (p):  0.556
Rouge-2 (r):  0.583
Rouge-2 (f1): 0.552
Rouge-L (p):  0.674
Rouge-L (r):  0.715
Rouge-L (f1): 0.674
METEOR:       0.5851


In [9]:
encoder.eval()
decoder.eval()
print_first_n_target_prediction(test_dataloader, encoder, decoder, language)

= sinus rhythm suspected p-sinistrocardiale position type normal qrs(t) abnormal inferior myocardial damage possible t abnormal in anterior leads lateral leads
< sinus rhythm position type normal st &amp; t abnormal, probably anterolateral ischemia or left strain

= sinus rhythm. incomplete right bundle branch block. otherwise normal ecg.
< sinus rhythm. normal ecg.

= sinus rhythm. normal ecg. edit: norm 100, (norm 100)
< sinus rhythm position type normal normal ecg

= sinus rhythm position type normal normal ecg
< sinus rhythm position type normal normal ecg

= sinus rhythm. prolonged pr interval. non-specific intraventricular delay. st segments are depressed in i, ii, iii, avf, v4,5,6. t waves are slightly inverted in these leads. this may be due to lv strain or ischaemia.
< sinus rhythm av block i position type normal st &amp; t abnormal, probably anterolateral ischemia or left strain inferolateral ischemia or left strain

= sinus rhythm position type normal normal ecg 4.46 unconfi

## RNN Encoder

### Train

In [10]:
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_lang, dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='train', device=device)
_, val_dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='val', device=device, _lang=_lang)

n_epochs=50
hidden_size = 256
criterion = nn.NLLLoss()

encoder = EncoderRNN(num_leads=12, hidden_size=hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size=hidden_size, output_size=language.n_words, max_len=language.max_len).to(device)

encoder, decoder

(EncoderRNN(
   (conv1): Conv1d(12, 24, kernel_size=(5,), stride=(1,), padding=(2,))
   (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
   (conv2): Conv1d(24, 48, kernel_size=(5,), stride=(1,), padding=(2,))
   (pool2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
   (dropout): Dropout(p=0.1, inplace=False)
   (gru): GRU(48, 256, batch_first=True)
 ),
 AttnDecoderRNN(
   (embedding): Embedding(2787, 256)
   (attention): BahdanauAttention(
     (Wa): Linear(in_features=256, out_features=256, bias=True)
     (Ua): Linear(in_features=256, out_features=256, bias=True)
     (Va): Linear(in_features=256, out_features=1, bias=True)
   )
   (gru): GRU(512, 256, batch_first=True)
   (out): Linear(in_features=256, out_features=2787, bias=True)
   (dropout): Dropout(p=0.1, inplace=False)
 ))

In [11]:
torch.manual_seed(42)

train(dataloader, val_dataloader, encoder, decoder, criterion, language, n_epochs)

0m 27s (- 22m 45s) (1 2.0%) | Train Loss: 0.6671 | Val Loss: 2.1255 | Jaccard: 0.1402 | F1: 0.1741 | 
	Rouge-1 (p): 0.442 | Rouge-1 (r): 0.3723 | Rouge-1 (f1): 0.3742
	Rouge-2 (p): 0.3055 | Rouge-2 (r): 0.232 | Rouge-2 (f1): 0.2422
	Rouge-L (p): 0.4415 | Rouge-L (r): 0.3718 | Rouge-L (f1): 0.3737
	METEOR: 0.3147
0m 55s (- 22m 9s) (2 4.0%) | Train Loss: 0.2876 | Val Loss: 2.3191 | Jaccard: 0.1676 | F1: 0.2125 | 
	Rouge-1 (p): 0.4672 | Rouge-1 (r): 0.3889 | Rouge-1 (f1): 0.3958
	Rouge-2 (p): 0.3323 | Rouge-2 (r): 0.2528 | Rouge-2 (f1): 0.2653
	Rouge-L (p): 0.4657 | Rouge-L (r): 0.3869 | Rouge-L (f1): 0.3942
	METEOR: 0.3402
1m 23s (- 21m 41s) (3 6.0%) | Train Loss: 0.2437 | Val Loss: 2.4262 | Jaccard: 0.1907 | F1: 0.2478 | 
	Rouge-1 (p): 0.4764 | Rouge-1 (r): 0.4396 | Rouge-1 (f1): 0.4265
	Rouge-2 (p): 0.34 | Rouge-2 (r): 0.2965 | Rouge-2 (f1): 0.2929
	Rouge-L (p): 0.4749 | Rouge-L (r): 0.4379 | Rouge-L (f1): 0.425
	METEOR: 0.3583
1m 50s (- 21m 14s) (4 8.0%) | Train Loss: 0.2177 | Val Los

In [12]:
# save parameters
# torch.save(encoder.state_dict(), './models/m03_EcgToText_RNN/saved_models/EncoderRNN.pth')
# torch.save(decoder.state_dict(), './models/m03_EcgToText_RNN/saved_models/DecoderRNN.pth')

In [13]:
# load parameters
hidden_size = 256

encoder = EncoderRNN(num_leads=12, hidden_size=hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size=hidden_size, output_size=language.n_words, max_len=language.max_len).to(device)

encoder.load_state_dict(torch.load('./models/m03_EcgToText_RNN/saved_models/EncoderRNN.pth'))
decoder.load_state_dict(torch.load('./models/m03_EcgToText_RNN/saved_models/DecoderRNN.pth'))

<All keys matched successfully>

### Test

In [14]:
_, test_dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='test', device=device, _lang=language)

total_loss, f1, jaccard, rouge, meteor = validate_epoch(dataloader, encoder, decoder, criterion, language)

print(f'Test Loss:    {round(total_loss, 5)}')
print(f'F1:           {round(f1, 5)}')
print(f'Jaccard:      {round(jaccard, 5)}')
print(f'Rouge-1 (p):  {round(rouge["rouge-1"]["p"], 3)}')
print(f'Rouge-1 (r):  {round(rouge["rouge-1"]["r"], 3)}')
print(f'Rouge-1 (f1): {round(rouge["rouge-1"]["f"], 3)}')
print(f'Rouge-2 (p):  {round(rouge["rouge-2"]["p"], 3)}')
print(f'Rouge-2 (r):  {round(rouge["rouge-2"]["r"], 3)}')
print(f'Rouge-2 (f1): {round(rouge["rouge-2"]["f"], 3)}')
print(f'Rouge-L (p):  {round(rouge["rouge-l"]["p"], 3)}')
print(f'Rouge-L (r):  {round(rouge["rouge-l"]["r"], 3)}')
print(f'Rouge-L (f1): {round(rouge["rouge-l"]["f"], 3)}')
print(f'METEOR:       {round(meteor, 4)}')

Test Loss:    2.41486
F1:           0.56926
Jaccard:      0.45522
Rouge-1 (p):  0.683
Rouge-1 (r):  0.709
Rouge-1 (f1): 0.675
Rouge-2 (p):  0.559
Rouge-2 (r):  0.573
Rouge-2 (f1): 0.548
Rouge-L (p):  0.68
Rouge-L (r):  0.706
Rouge-L (f1): 0.672
METEOR:       0.5891


In [15]:
encoder.eval()
decoder.eval()
print_first_n_target_prediction(test_dataloader, encoder, decoder, language)

= regular rhythm, no p-wave detected, left-sided nonspecific intraventricular block, left hypertrophy, qrs(t) abnormal, anteroseptal infarction, possible acute inferior infarction, age undetermined.
< sinus tachycardia hyperexcited left type left bundle branch block left hypertrophy possible 4.46 unconfirmed report

= sinus rhythm position type normal t abnormal in high lateral leads pathological
< sinus rhythm position type normal normal ecg

= supraventricular extrasystole(s) sinus rhythm av block i hyperrotated left type qrs(t) abnormal anterior myocardial damage inferior myocardial damage st &amp; t abnormal, probably high lateral ischemia or left strain
< sinus rhythm hyperexcited left type left anterior hemiblock right bundle branch block bifascicular block qrs(t) abnormal anteroseptal myocardial damage possible

= sinus rhythm p-widening position type normal 4.46 unconfirmed report
< sinus rhythm position type normal normal ecg 4.46 unconfirmed report

= sinus rhythm av block i 