# ISIBrnoAIMT Encoder with Attention Decoder

Encoder was taken from the winner of the [Will Two Do?](https://physionet.org/content/challenge-2021/1.0.3/sources/) challenge [ISIBrnoAIMT](https://www.cinc.org/archives/2021/pdf/CinC2021-014.pdf)
Decoder was taken from the [sequence to sequence tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) from Pytorch.

In [1]:
import pickle
import os
import sys
import torch
import pandas as pd
from sklearn.metrics import f1_score, jaccard_score, confusion_matrix, precision_score, recall_score, accuracy_score

from models.m04_EcgToText_ISIBrnoAIMT.dataset import *
from models.m04_EcgToText_ISIBrnoAIMT.model import *
from models.m04_EcgToText_ISIBrnoAIMT.train import *

In [2]:
os.chdir('..')

### Train

In [3]:
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

language, dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='train', device=device)
_, val_dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='val', device=device, _lang=language)

n_epochs=50
hidden_size = 256
criterion = nn.NLLLoss()

encoder = NN(num_leads=12, hidden_size=hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size=hidden_size, output_size=language.n_words, max_len=language.max_len).to(device)

encoder, decoder

(NN(
   (conv): Conv2d(12, 256, kernel_size=(1, 15), stride=(1, 2), padding=(0, 7), bias=False)
   (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (rb_0): MyResidualBlock(
     (conv1): Conv2d(256, 256, kernel_size=(1, 9), stride=(1, 2), padding=(0, 4), bias=False)
     (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv2): Conv2d(256, 256, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
     (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (idfunc_0): AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)
     (idfunc_1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
   )
   (rb_1): MyResidualBlock(
     (conv1): Conv2d(256, 256, kernel_size=(1, 9), stride=(1, 2), padding=(0, 4), bias=False)
     (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv2): Conv2d(256, 256, kernel_

In [4]:
torch.manual_seed(42)

train(dataloader, val_dataloader, encoder, decoder, criterion, language, n_epochs)

0m 29s (- 24m 17s) (1 2.0%) | Train Loss: 0.8271 | Val Loss: 1.9763 | Jaccard: 0.0018 | F1: 0.0025 | 
	Rouge-1 (p): 0.4161 | Rouge-1 (r): 0.5798 | Rouge-1 (f1): 0.4451
	Rouge-2 (p): 0.29 | Rouge-2 (r): 0.3708 | Rouge-2 (f1): 0.2938
	Rouge-L (p): 0.4156 | Rouge-L (r): 0.5788 | Rouge-L (f1): 0.4445
	METEOR: 0.3096
0m 57s (- 22m 52s) (2 4.0%) | Train Loss: 0.3249 | Val Loss: 2.109 | Jaccard: 0.0044 | F1: 0.0069 | 
	Rouge-1 (p): 0.4966 | Rouge-1 (r): 0.5372 | Rouge-1 (f1): 0.4875
	Rouge-2 (p): 0.3512 | Rouge-2 (r): 0.3601 | Rouge-2 (f1): 0.3326
	Rouge-L (p): 0.4942 | Rouge-L (r): 0.5342 | Rouge-L (f1): 0.4849
	METEOR: 0.3879
1m 24s (- 22m 8s) (3 6.0%) | Train Loss: 0.2533 | Val Loss: 2.3466 | Jaccard: 0.0057 | F1: 0.0091 | 
	Rouge-1 (p): 0.4881 | Rouge-1 (r): 0.544 | Rouge-1 (f1): 0.4856
	Rouge-2 (p): 0.3427 | Rouge-2 (r): 0.3757 | Rouge-2 (f1): 0.3359
	Rouge-L (p): 0.4861 | Rouge-L (r): 0.5416 | Rouge-L (f1): 0.4835
	METEOR: 0.3711
1m 52s (- 21m 31s) (4 8.0%) | Train Loss: 0.2213 | Val Lo

In [5]:
# save parameters
# torch.save(encoder.state_dict(), './models/m04_EcgToText_ISIBrnoAIMT/saved_models/Encoder.pth')
# torch.save(decoder.state_dict(), './models/m04_EcgToText_ISIBrnoAIMT/saved_models/Decoder.pth')

In [6]:
# load parameters
hidden_size = 256

encoder = NN(num_leads=12, hidden_size=hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size=hidden_size, output_size=language.n_words, max_len=language.max_len).to(device)

encoder.load_state_dict(torch.load('./models/m04_EcgToText_ISIBrnoAIMT/saved_models/Encoder.pth'))
decoder.load_state_dict(torch.load('./models/m04_EcgToText_ISIBrnoAIMT/saved_models/Decoder.pth'))

<All keys matched successfully>

### Test

In [7]:
_, test_dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='test', device=device, _lang=language)

total_loss, f1, jaccard, rouge, meteor = validate_epoch(dataloader, encoder, decoder, criterion, language)

print(f'Test Loss:    {round(total_loss, 4)}')
print(f'F1:           {round(f1, 4)}')
print(f'Jaccard:      {round(jaccard, 4)}')
print(f'Rouge-1 (p):  {round(rouge["rouge-1"]["p"], 3)}')
print(f'Rouge-1 (r):  {round(rouge["rouge-1"]["r"], 3)}')
print(f'Rouge-1 (f1): {round(rouge["rouge-1"]["f"], 3)}')
print(f'Rouge-2 (p):  {round(rouge["rouge-2"]["p"], 3)}')
print(f'Rouge-2 (r):  {round(rouge["rouge-2"]["r"], 3)}')
print(f'Rouge-2 (f1): {round(rouge["rouge-2"]["f"], 3)}')
print(f'Rouge-L (p):  {round(rouge["rouge-l"]["p"], 3)}')
print(f'Rouge-L (r):  {round(rouge["rouge-l"]["r"], 3)}')
print(f'Rouge-L (f1): {round(rouge["rouge-l"]["f"], 3)}')
print(f'METEOR:       {round(meteor, 4)}')

Test Loss:    2.7337
F1:           0.045
Jaccard:      0.0312
Rouge-1 (p):  0.705
Rouge-1 (r):  0.695
Rouge-1 (f1): 0.681
Rouge-2 (p):  0.58
Rouge-2 (r):  0.568
Rouge-2 (f1): 0.557
Rouge-L (p):  0.702
Rouge-L (r):  0.692
Rouge-L (f1): 0.678
METEOR:       0.6263


In [8]:
encoder.eval()
decoder.eval()
print_first_n_target_prediction(test_dataloader, encoder, decoder, language)

= sinus tachycardia left type qrs(t) abnormal lateral myocardial damage cannot be excluded
< sinus tachycardia position type normal st &amp; t abnormal, probably anterolateral ischemia or left strain

= sinus bradycardia <unk> pacemaker left type left bundle branch block ventricular hypertrophy pathological
< sinus bradycardia left type complete left bundle branch block st elevation in v1-4 t neg in ii,iii,avf,v5,6 pathological ecg

= sinus rhythm. left axis deviation. left anterior fascicular block. st segments are elevated in v2,3. t waves are inverted in i, ii, v4,5,6 with terminal inversion in v3. findings are likely to be due to ischaemic heart  disease and may represent <unk> edit: injas, injal, lafb, (injal 100, lafb)
< sinus rhythm. left axis deviation. left anterior fascicular block. t waves are inverted in i, avl, v3,4,5 and flat in v6. findings are likely to be due to ischaemic heart    disease. the age of the changes is uncertain.

= sinus rhythm position type normal amplit