# ISIBrnoAIMT Encoder with Attention Decoder

The encoder is taken from [ISIBrnoAIMT](https://www.cinc.org/archives/2021/pdf/CinC2021-014.pdf), the winning entry of the [Will Two Do?](https://physionet.org/content/challenge-2021/1.0.3/sources/) challenge.
The decoder is taken from the PyTorch [sequence-to-sequence tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html).

In [1]:
import pickle
import os
import sys
import torch
from torch import nn  # explicit import for nn.NLLLoss used below
import pandas as pd
from sklearn.metrics import f1_score, jaccard_score, confusion_matrix, precision_score, recall_score, accuracy_score

from models.m04_EcgToText_ISIBrnoAIMT.dataset import *
from models.m04_EcgToText_ISIBrnoAIMT.model import *
from models.m04_EcgToText_ISIBrnoAIMT.train import *

In [2]:
os.chdir('..')  # move to the repository root so the relative ./data_ptb-xl and ./models paths resolve

## Model Setup

In [3]:
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

language, dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='train', device=device)
_, val_dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='val', device=device, _lang=language)

n_epochs = 50
hidden_size = 256
criterion = nn.NLLLoss()  # expects log-probabilities from the decoder (e.g. log_softmax outputs)

encoder = NN(num_leads=12,
             hidden_size=hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size=hidden_size,
                         encoder_hidden_size=hidden_size,
                         output_size=language.n_words,
                         max_len=language.max_len).to(device)
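The decoder follows the PyTorch seq2seq tutorial, whose `BahdanauAttention` scores every encoder time step with additive attention, $\text{score}_j = V_a^\top \tanh(W_a q + U_a k_j)$, turns the scores into weights with a softmax, and returns the weighted sum of encoder features as a context vector. In the tutorial that context is concatenated with the token embedding before the GRU, which is consistent with the `GRU(512, 256)` input size printed below (256 + 256). The NumPy sketch below shows one decoder step under the assumption that the encoder output is a sequence of 256-dimensional feature vectors over time; the function name, weight names and the sequence length `T` are hypothetical, not this project's code.

```python
import numpy as np

def bahdanau_attention(query, keys, Wa, Ua, va):
    """Additive (Bahdanau) attention for a single decoder step (hypothetical helper).

    query:  (H,)   current decoder hidden state
    keys:   (T, H) encoder features, one row per time step
    Wa, Ua: (H, H) projection matrices (PyTorch layout: out x in)
    va:     (H,)   scoring vector; biases are omitted for brevity
    """
    # score_j = va . tanh(Wa q + Ua k_j), computed for all T steps at once
    scores = np.tanh(query @ Wa.T + keys @ Ua.T) @ va   # (T,)
    # numerically stable softmax over the time steps
    scores = scores - scores.max()
    weights = np.exp(scores) / np.exp(scores).sum()      # (T,)
    # context vector: attention-weighted sum of the encoder features
    context = weights @ keys                             # (H,)
    return context, weights

rng = np.random.default_rng(0)
H, T = 256, 40  # T is an arbitrary, assumed number of encoder time steps
query = rng.standard_normal(H)
keys = rng.standard_normal((T, H))
Wa = rng.standard_normal((H, H)) * 0.05
Ua = rng.standard_normal((H, H)) * 0.05
va = rng.standard_normal(H) * 0.05

context, weights = bahdanau_attention(query, keys, Wa, Ua, va)
embedded = rng.standard_normal(H)                 # stand-in for the token embedding
gru_input = np.concatenate([embedded, context])   # 256 + 256 = 512 features
print(gru_input.shape)  # (512,)
```

Because the weights sum to one, the context stays on the same scale as the encoder features regardless of how many time steps the encoder produces.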

In [4]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [5]:
n_param = count_parameters(encoder)
print(f"Number of parameters in Encoder: {n_param}")
encoder

Number of parameters in Encoder: 6545152


NN(
  (conv): Conv2d(12, 256, kernel_size=(1, 15), stride=(1, 2), padding=(0, 7), bias=False)
  (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (rb_0): MyResidualBlock(
    (conv1): Conv2d(256, 256, kernel_size=(1, 9), stride=(1, 2), padding=(0, 4), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(256, 256, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (idfunc_0): AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)
    (idfunc_1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (rb_1): MyResidualBlock(
    (conv1): Conv2d(256, 256, kernel_size=(1, 9), stride=(1, 2), padding=(0, 4), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(256, 256, kernel_size=(1, 9), st
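The encoder printout is cut off above, so the full 6,545,152 total cannot be reconstructed from it. The visible layers can still be checked by hand. Along the time axis (the `(1, k)` kernels act only on the last dimension), each convolution or pooling layer produces $L_{out} = \lfloor (L_{in} + 2p - k)/s \rfloor + 1$ samples, and a bias-free `Conv2d` holds $C_{in} C_{out} k_h k_w$ weights. The sketch applies both to the stem and to `rb_0`; the 5000-sample input is an assumption (a 10 s PTB-XL record at 500 Hz), since the sampling rate used by the dataloader is not shown, and the helper names are hypothetical.

```python
def conv_out_len(length, kernel, stride=1, padding=0):
    """Output length of a Conv/Pool layer along one axis (dilation = 1)."""
    return (length + 2 * padding - kernel) // stride + 1

def conv2d_params(c_in, c_out, kh, kw, bias=False):
    """Number of trainable parameters in a Conv2d layer."""
    return c_in * c_out * kh * kw + (c_out if bias else 0)

L = 5000  # assumed input length: 10 s at 500 Hz

# Time-axis lengths through the stem and the first residual block
stem = conv_out_len(L, kernel=15, stride=2, padding=7)            # conv
rb0_main = conv_out_len(stem, kernel=9, stride=2, padding=4)      # rb_0.conv1
rb0_main = conv_out_len(rb0_main, kernel=9, stride=1, padding=4)  # rb_0.conv2
rb0_skip = conv_out_len(stem, kernel=2, stride=2)                 # rb_0.idfunc_0 (AvgPool2d)
print(stem, rb0_main, rb0_skip)  # 2500 1250 1250

# Trainable parameters of one residual block as printed (all convs have bias=False).
# BatchNorm contributes weight + bias; running mean/var are buffers, not parameters.
block = (conv2d_params(256, 256, 1, 9)     # conv1
         + conv2d_params(256, 256, 1, 9)   # conv2
         + 2 * (2 * 256)                   # bn1 and bn2
         + conv2d_params(256, 256, 1, 1))  # idfunc_1 (1x1 projection on the skip path)
stem_params = conv2d_params(12, 256, 1, 15) + 2 * 256  # conv + bn
print(block, stem_params)  # 1246208 46592
```

The main path and the pooled skip path end up with the same length (1250), which is what adding them in a residual sum requires; each such block halves the time resolution.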

In [6]:
n_param = count_parameters(decoder)
print(f"Number of parameters in Decoder: {n_param}")
decoder

Number of parameters in Decoder: 2219237


AttnDecoderRNN(
  (hidden_transform): Linear(in_features=256, out_features=256, bias=True)
  (embedding): Embedding(2788, 256)
  (attention): BahdanauAttention(
    (Wa): Linear(in_features=256, out_features=256, bias=True)
    (Ua): Linear(in_features=256, out_features=256, bias=True)
    (Va): Linear(in_features=256, out_features=1, bias=True)
  )
  (gru): GRU(512, 256, batch_first=True)
  (out): Linear(in_features=256, out_features=2788, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
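The decoder total can be verified from the printed layers: a `Linear(in, out)` has `in*out + out` parameters, an `Embedding(n, d)` has `n*d`, `Dropout` has none, and a single-layer PyTorch `GRU(input, hidden)` has three gates, each with input weights, hidden weights and two bias vectors. The helper names below are illustrative only.

```python
def linear(n_in, n_out, bias=True):
    return n_in * n_out + (n_out if bias else 0)

def gru(n_in, hidden):
    # 3 gates (reset, update, new), each with W_ih, W_hh, b_ih and b_hh
    return 3 * (n_in * hidden + hidden * hidden + 2 * hidden)

vocab, h = 2788, 256
parts = {
    "hidden_transform": linear(h, h),
    "embedding": vocab * h,
    "attention.Wa": linear(h, h),
    "attention.Ua": linear(h, h),
    "attention.Va": linear(h, 1),
    "gru": gru(2 * h, h),          # input size 512, as printed
    "out": linear(h, vocab),
}
total = sum(parts.values())
print(total)  # 2219237

# Share of the parameters tied to the vocabulary size
vocab_share = (parts["embedding"] + parts["out"]) / total
print(round(vocab_share, 3))  # 0.644
```

The sum reproduces the reported 2,219,237, and the two vocabulary-sized layers (`embedding` and `out`) account for about 64% of it, so the decoder size is driven mainly by the 2788-word vocabulary rather than by the attention or GRU.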

## Train

In [7]:
torch.manual_seed(42)

train(dataloader, val_dataloader, encoder, decoder, criterion, language, n_epochs)

0m 50s (- 41m 22s) (1 2.0%) | Train Loss: 0.5817 | Val METEOR: 0.3364
1m 34s (- 37m 56s) (2 4.0%) | Train Loss: 0.2498 | Val METEOR: 0.3847
2m 20s (- 36m 33s) (3 6.0%) | Train Loss: 0.2119 | Val METEOR: 0.4025
3m 5s (- 35m 37s) (4 8.0%) | Train Loss: 0.1894 | Val METEOR: 0.4157
3m 51s (- 34m 47s) (5 10.0%) | Train Loss: 0.1756 | Val METEOR: 0.4164
4m 39s (- 34m 10s) (6 12.0%) | Train Loss: 0.1639 | Val METEOR: 0.4284
5m 26s (- 33m 22s) (7 14.0%) | Train Loss: 0.1537 | Val METEOR: 0.4231
6m 12s (- 32m 35s) (8 16.0%) | Train Loss: 0.1462 | Val METEOR: 0.4361
6m 58s (- 31m 47s) (9 18.0%) | Train Loss: 0.1396 | Val METEOR: 0.4311
7m 44s (- 30m 58s) (10 20.0%) | Train Loss: 0.136 | Val METEOR: 0.4577
8m 30s (- 30m 8s) (11 22.0%) | Train Loss: 0.1291 | Val METEOR: 0.4563
9m 16s (- 29m 22s) (12 24.0%) | Train Loss: 0.1244 | Val METEOR: 0.4563
10m 3s (- 28m 36s) (13 26.0%) | Train Loss: 0.1214 | Val METEOR: 0.4492
10m 48s (- 27m 48s) (14 28.0%) | Train Loss: 0.1175 | Val METEOR: 0.4537
11m 34s
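Each log line reads `elapsed (- estimated remaining) (epoch percent)`. This format matches the `timeSince` helper of the PyTorch seq2seq tutorial, which extrapolates the remaining time linearly from the fraction of epochs completed; that `train()` uses exactly this helper is an assumption. A sketch of the calculation (function names hypothetical):

```python
import math

def as_minutes(seconds):
    """Format seconds as 'Xm Ys', truncating fractional seconds."""
    m = math.floor(seconds / 60)
    return "%dm %ds" % (m, seconds - m * 60)

def progress(elapsed, fraction_done):
    """Elapsed time plus a linear estimate of the remaining time."""
    estimated_total = elapsed / fraction_done
    remaining = estimated_total - elapsed
    return "%s (- %s)" % (as_minutes(elapsed), as_minutes(remaining))

# e.g. a quarter of the epochs finished after 120 s
print(progress(120, 0.25))  # 2m 0s (- 6m 0s)
```

The estimate assumes every epoch takes about as long as the average so far, which fits the roughly 46 s per epoch visible in the log.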

## Test

In [8]:
# Rebuild the model architecture and load the saved weights
hidden_size = 256

encoder = NN(num_leads=12,
             hidden_size=hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size=hidden_size,
                         encoder_hidden_size=hidden_size,
                         output_size=language.n_words,
                         max_len=language.max_len).to(device)

# map_location lets weights saved on a GPU load on a CPU-only machine
encoder.load_state_dict(torch.load('./models/m04_EcgToText_ISIBrnoAIMT/saved_models/Encoder.pth', map_location=device))
decoder.load_state_dict(torch.load('./models/m04_EcgToText_ISIBrnoAIMT/saved_models/Decoder.pth', map_location=device))

<All keys matched successfully>

In [9]:
_, test_dataloader = get_dataloader(file_path='./data_ptb-xl', batch_size=64, mode='test', device=device, _lang=language)

total_loss, f1, jaccard, rouge, meteor = validate_epoch(test_dataloader, encoder, decoder, criterion, language)

print(f'Test Loss:    {round(total_loss, 4)}')
print(f'F1:           {round(f1, 4)}')
print(f'Jaccard:      {round(jaccard, 4)}')
print(f'Rouge-1 (p):  {round(rouge["rouge-1"]["p"], 3)}')
print(f'Rouge-1 (r):  {round(rouge["rouge-1"]["r"], 3)}')
print(f'Rouge-1 (f1): {round(rouge["rouge-1"]["f"], 3)}')
print(f'Rouge-2 (p):  {round(rouge["rouge-2"]["p"], 3)}')
print(f'Rouge-2 (r):  {round(rouge["rouge-2"]["r"], 3)}')
print(f'Rouge-2 (f1): {round(rouge["rouge-2"]["f"], 3)}')
print(f'Rouge-L (p):  {round(rouge["rouge-l"]["p"], 3)}')
print(f'Rouge-L (r):  {round(rouge["rouge-l"]["r"], 3)}')
print(f'Rouge-L (f1): {round(rouge["rouge-l"]["f"], 3)}')
print(f'METEOR:       {round(meteor, 3)}')

Test Loss:    3.8925
F1:           0.0245
Jaccard:      0.0159
Rouge-1 (p):  0.582
Rouge-1 (r):  0.622
Rouge-1 (f1): 0.577
Rouge-2 (p):  0.445
Rouge-2 (r):  0.472
Rouge-2 (f1): 0.439
Rouge-L (p):  0.578
Rouge-L (r):  0.618
Rouge-L (f1): 0.574
METEOR:       0.495
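ROUGE-N compares n-gram counts between a predicted and a reference report: precision is the fraction of predicted n-grams found in the reference, recall the fraction of reference n-grams found in the prediction, and F1 their harmonic mean. The `rouge` dictionary above comes from `validate_epoch`, whose scoring library and tokenisation are not shown, so this pure-Python sketch (lowercased word tokens, clipped counts) is illustrative and its numbers need not match the table. It uses the third reference/prediction pair from the sample output that follows.

```python
import re
from collections import Counter

def tokenize(text):
    # Assumed tokenisation: lowercase alphanumeric words, keeping <unk> intact
    return re.findall(r"[a-z0-9<>]+", text.lower())

def ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def rouge_n(prediction, reference, n=1):
    """ROUGE-N precision, recall and F1 with clipped n-gram counts."""
    pred, ref = ngrams(tokenize(prediction), n), ngrams(tokenize(reference), n)
    overlap = sum((pred & ref).values())  # Counter '&' keeps the minimum count
    p = overlap / max(sum(pred.values()), 1)
    r = overlap / max(sum(ref.values()), 1)
    f = 2 * p * r / (p + r) if p + r else 0.0
    return {"p": p, "r": r, "f": f}

reference = "sinus rhythm. normal ecg. edit: norm 100, <unk> bad quality"
prediction = "sinus rhythm. normal ecg."
print({k: round(v, 3) for k, v in rouge_n(prediction, reference).items()})
# {'p': 1.0, 'r': 0.4, 'f': 0.571}
```

Every predicted word appears in the reference (precision 1.0), but the prediction omits the quality note, so recall drops to 0.4.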


In [10]:
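encoder.eval()  # disable dropout and use BatchNorm running statistics
decoder.eval()
# Print the reference report ("=") and the model prediction ("<") for the first test examples
print_first_n_target_prediction(test_dataloader, encoder, decoder, language)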

= sinus rhythm and junctional rhythm. non-specific t wave flattening in i, avl, v5,6. slightly prolonged qt, this may be due to a drug    effect or an electrolyte disturbance.
< atrial fibrillation. st segments are depressed in i, ii, v5,6. t waves are low in i, v6 and inverted in avl. this may be due to lv strain or ischaemia.

= sinus rhythm extreme left electrical axis nonspecific leg block
< sinus rhythm. left axis deviation. left anterior fascicular block. voltages are high in limb leads suggesting lvh. qs complexes in v2 suggesting old anteroseptal infarct. st segments are depressed in i, avl, v5,6. t waves are inverted in avl, v5,6. this may be due to lv strain or ischaemia.

= sinus rhythm. normal ecg. edit: norm 100, <unk> bad quality
< sinus rhythm. normal ecg.

= <unk> <unk> sinus bradycardia suspected left ventricular hypertrophy pathological q-wave in v2, suggests suspicion of old anteroseptal myocardial damage t-changes anteriorly as in subendocardial myocardial affection
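Several reference reports contain `<unk>` tokens, and the validation and test loaders are built with `_lang=language`, i.e. they reuse the vocabulary built on the training split, while `output_size=language.n_words` ties the decoder's output layer to that vocabulary. One plausible reading, following the `Lang` class of the PyTorch tutorial, is that words missing from the training vocabulary are mapped to an unknown token. The sketch below is hypothetical: the class, its special tokens and its methods are assumptions, not this project's `dataset` module.

```python
class Vocabulary:
    """Hypothetical word-level vocabulary in the spirit of the tutorial's Lang class."""

    SPECIALS = ["<sos>", "<eos>", "<unk>"]  # assumed special tokens

    def __init__(self):
        self.word2index = {w: i for i, w in enumerate(self.SPECIALS)}
        self.index2word = list(self.SPECIALS)

    @property
    def n_words(self):
        return len(self.index2word)

    def add_sentence(self, sentence):
        # Called on training reports only, so the vocabulary is frozen afterwards
        for word in sentence.split():
            if word not in self.word2index:
                self.word2index[word] = self.n_words
                self.index2word.append(word)

    def encode(self, sentence):
        # Words never seen during training fall back to the <unk> index
        unk = self.word2index["<unk>"]
        return [self.word2index.get(w, unk) for w in sentence.split()]

    def decode(self, indices):
        return " ".join(self.index2word[i] for i in indices)


train_vocab = Vocabulary()
train_vocab.add_sentence("sinus rhythm normal ecg")
ids = train_vocab.encode("sinus bradycardia normal ecg")  # 'bradycardia' is unseen
print(train_vocab.decode(ids))  # sinus <unk> normal ecg
```

Reusing one vocabulary across splits keeps token indices consistent with the trained embedding and output layers; the cost is that unseen words in test reports can only be represented as `<unk>`.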