In [2]:
# import all packages needed
import string 
import numpy as np
import pandas as pd
from matplotlib import pyplot
from base64 import b64decode as decode

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

## Data Processing / Cleaning

In [3]:
# use class base64 to decode waveform data
def to_array(wf):
    barr = bytearray(decode(wf))
    vals = np.array(barr)
    return vals.view(np.int16)

# read in data
exam_data = pd.read_csv("data/d_exam.csv").drop(columns = ["site_num", "patient_id_edit"])
waveform_data = pd.read_csv("data/d_waveform.csv")
lead_data = pd.read_csv("data/d_lead_data.csv").drop(columns = ["exam_id"])
diagnosis_data = pd.read_csv("data/d_diagnosis.csv").drop(columns = ["user_input"])

# add decoded data as a column to lead data
waveforms = list(lead_data['waveform_data'])
lead_data['decoded_waveform'] = [to_array(i) for i in waveforms]

# merge waveform data and lead data
waveform_lead = lead_data.merge(waveform_data, how = "left", left_on = "waveform_id", right_on = "waveform_id", suffixes = (None, None))

#  sort by exam id and lead id
waveform_lead.sort_values(by = ["waveform_id", "lead_id"], inplace = True)

waveform_lead.loc[:, ['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']]


# adding the diagnosis and labels
waveform_and_diag = pd.merge(waveform_lead[['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']], diagnosis_data[["exam_id", "Full_text", "Original_Diag"]], left_on= "exam_id", right_on="exam_id")
waveform_and_diag

Unnamed: 0,exam_id,lead_id,decoded_waveform,waveform_type,Full_text,Original_Diag
0,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,No previous ECGs available,0
1,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,Otherwise normal ECG,0
2,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,Sinus bradycardia,0
3,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,,0
4,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,Sinus bradycardia,1
...,...,...,...,...,...,...
1419,554080,V6,"[10, 10, 10, 11, 12, 12, 12, 12, 12, 13, 14, 1...",Rhythm,Abnormal ECG,1
1420,554080,V6,"[10, 10, 10, 11, 12, 12, 12, 12, 12, 13, 14, 1...",Rhythm,No previous ECGs available,1
1421,554080,V6,"[10, 10, 10, 11, 12, 12, 12, 12, 12, 13, 14, 1...",Rhythm,No previous ECGs available,0
1422,554080,V6,"[10, 10, 10, 11, 12, 12, 12, 12, 12, 13, 14, 1...",Rhythm,Sinus rhythm,0


In [4]:
# concatenate all leads into a single array
waveform_lead_concat = waveform_lead.groupby(["exam_id", "waveform_type"])['decoded_waveform'].apply(lambda x: tuple(x)).reset_index()

# remove irregular observations, concat tuple into numpy array
waveform_lead_concat = waveform_lead_concat.drop([12,17], axis = 0)
waveform_lead_concat['decoded_waveform'] = waveform_lead_concat['decoded_waveform'].apply(lambda x: np.vstack(x))

waveform_lead_rhythm = waveform_lead_concat[waveform_lead_concat['waveform_type'] == "Rhythm"]
waveform_lead_median = waveform_lead_concat[waveform_lead_concat['waveform_type'] == "Median"]

waveform_lead_rhythm

Unnamed: 0,exam_id,waveform_type,decoded_waveform
1,548759,Rhythm,"[[4, 3, 2, -1, -4, -4, -4, -4, -4, -7, -10, -8..."
3,549871,Rhythm,"[[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, ..."
5,550602,Rhythm,"[[-22, -20, -18, -16, -14, -14, -14, -12, -10,..."
7,551485,Rhythm,"[[46, 45, 44, 42, 40, 35, 30, 26, 22, 18, 14, ..."
9,552077,Rhythm,"[[-7, -4, -1, -6, -10, -12, -14, -11, -11, -14..."
11,552856,Rhythm,"[[-32, -32, -32, -33, -34, -34, -34, -33, -32,..."
14,553115,Rhythm,"[[-8, -5, -2, -2, -2, -5, -8, -8, -8, -7, -6, ..."
16,553528,Rhythm,"[[-12, -12, -12, -12, -12, -12, -12, -12, -12,..."


In [5]:
# Adding the labels/sentences
exams = diagnosis_data["exam_id"].unique()

# Let's look over this tomorrow
diagnosis_data = diagnosis_data[diagnosis_data['Original_Diag'] == 1].dropna()
searchfor = ['previous', 'unconfirmed', 'compared', 'interpretation', 'significant']
diagnosis_data = diagnosis_data.loc[diagnosis_data['Full_text'].str.contains('|'.join(searchfor)) != 1]
#

diagnosis_data.sort_values(by=["exam_id", "statement_order"], inplace=True)
diagnoses = []
curr_id = 0
curr_string = ""
for i, row in diagnosis_data.iterrows():
    if row["statement_order"] == 1 and curr_string != "":
        curr_string = curr_string.lower().translate(str.maketrans('', '', string.punctuation))
        val = [curr_id, curr_string[1:]]
        diagnoses.append(val)
        curr_string = ""
        curr_id = row["exam_id"]

    if curr_id == 0:
        curr_id = row["exam_id"]
    
    curr_string += " " + row["Full_text"]

diagnosis_df = pd.DataFrame(diagnoses, columns = ['exam_id', 'diagnosis'])
waveform_lead_rhythm_diag = pd.merge(left=waveform_lead_rhythm, right=diagnosis_df, left_on='exam_id', right_on='exam_id')

#waveform_lead_rhythm_diag
waveform_lead_rhythm_diag

Unnamed: 0,exam_id,waveform_type,decoded_waveform,diagnosis
0,548759,Rhythm,"[[4, 3, 2, -1, -4, -4, -4, -4, -4, -7, -10, -8...",normal sinus rhythm low voltage qrs borderline...
1,549871,Rhythm,"[[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, ...",sinus bradycardia otherwise normal ecg
2,550602,Rhythm,"[[-22, -20, -18, -16, -14, -14, -14, -12, -10,...",sinus tachycardia otherwise normal ecg
3,551485,Rhythm,"[[46, 45, 44, 42, 40, 35, 30, 26, 22, 18, 14, ...",normal sinus rhythm normal ecg
4,552077,Rhythm,"[[-7, -4, -1, -6, -10, -12, -14, -11, -11, -14...",normal sinus rhythm normal ecg
5,552856,Rhythm,"[[-32, -32, -32, -33, -34, -34, -34, -33, -32,...",normal sinus rhythm with sinus arrhythmia mini...
6,553115,Rhythm,"[[-8, -5, -2, -2, -2, -5, -8, -8, -8, -7, -6, ...",atrial fibrillation abnormal ecg normal sinus ...


In [6]:
unique_words = set()
for num, sentence in diagnoses:
    for word in sentence.split():
        unique_words.add(word)
print(unique_words)

{'inferior', 'tachycardia', 'for', 'abnormal', 'qrs', 'ecg', 'criteria', 'arrhythmia', 'wave', 'consider', 'rhythm', 'sinus', 'low', 'otherwise', 'ischemia', 'bradycardia', 'with', 't', 'abnormality', 'be', 'minimal', 'normal', 'borderline', 'may', 'voltage', 'lvh', 'variant', 'fibrillation', 'atrial'}


In [18]:
# split data into training and testing datasets
# y not included for now
def one_hot(x, dict_words, max_length):
    x = x.split(" ")
    array = []
    for i in x:
        array.append(dict_words.index(i))
    while(len(array) < max_length):
        array.append(29)
    return array

dict_words = list(unique_words)
dict_words.append([" "])
print(len(dict_words))
Y = waveform_lead_rhythm_diag['diagnosis'].apply(lambda x: one_hot(x, dict_words, 20))

train_x, test_x, train_y, test_y = train_test_split(waveform_lead_rhythm_diag['decoded_waveform'], Y, test_size = 0.1, random_state = 2021)
train_x = torch.tensor(list(train_x)).float()
train_y = torch.tensor(list(train_y))

30


## Model 1 - Conv1D Encoder w/ LSTM Decoder

In [22]:
# HYPERPARAMETERS
J = 10 # max number of filters per class
LR = 1e-3

# define global max pooling
class global_max_pooling_1d(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x, _ = torch.max(x, dim = 2)
        return(x)

# 1D grouped encoder model
encoder_conv = nn.Sequential()
encoder_conv.add_module('initial_norm', nn.BatchNorm1d(8))
encoder_conv.add_module('conv_1', nn.Conv1d(in_channels = 8, out_channels = 8, groups = 8, kernel_size = 5, padding = 2))
for i in range(2, (J+2), 2):
    if (i-2) == 0: 
        prev = 8
    else:
        prev = (i-2)*8
    encoder_conv.add_module('conv_{num}'.format(num = int(i / 2 + 1)), nn.Conv1d(in_channels = prev, out_channels = i*8, groups = 8, kernel_size = 5, padding = 2, stride = 2))
    encoder_conv.add_module('activation_{num}'.format(num = int(i / 2 + 1)), nn.ELU())
    encoder_conv.add_module('batch_norm_{num}'.format(num = int(i / 2 + 1)), nn.BatchNorm1d(i*8))
    
encoder_conv.add_module('final_conv', nn.Conv1d(in_channels = J * 8, out_channels = 8, groups = 8, kernel_size = 5, padding = 2))
encoder_conv.add_module('max_pool', nn.MaxPool1d(kernel_size = 7, padding = 3, stride = 4))
encoder_conv.add_module('reshape', nn.MaxPool1d(kernel_size = 5, padding = 2, stride = 1))


# summarize model, verify output is of desired shape
print(encoder_conv)
print(encoder_conv(train_x).shape)

Sequential(
  (initial_norm): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_1): Conv1d(8, 8, kernel_size=(5,), stride=(1,), padding=(2,), groups=8)
  (conv_2): Conv1d(8, 16, kernel_size=(5,), stride=(2,), padding=(2,), groups=8)
  (activation_2): ELU(alpha=1.0)
  (batch_norm_2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_3): Conv1d(16, 32, kernel_size=(5,), stride=(2,), padding=(2,), groups=8)
  (activation_3): ELU(alpha=1.0)
  (batch_norm_3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_4): Conv1d(32, 48, kernel_size=(5,), stride=(2,), padding=(2,), groups=8)
  (activation_4): ELU(alpha=1.0)
  (batch_norm_4): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_5): Conv1d(48, 64, kernel_size=(5,), stride=(2,), padding=(2,), groups=8)
  (activation_5): ELU(alpha=1.0)
  (batch_norm_5): BatchNorm1d(64, eps=1e-05, momentum=0.1,

## Model 2 - LSTM Encoder w/ Huggingface Decoder

In [23]:
# define hyperparameters 
hidden_layers = 512
embedding_dim = 8
num_words = len(dict_words)

class ECG_LSTM(nn.Module):
    def __init__(self, encoder, h_dim, e_dim, word_list_length):
        super(ECG_LSTM, self).__init__()
        self.encoder = encoder
        self.lstm = nn.LSTM(e_dim, h_dim)
        self.linear = nn.Linear(h_dim, word_list_length)
        
    def forward(self, seq):
        seq_embedded = self.encoder(seq).view(len(seq), -1, embedding_dim)
        final_hidd, _ = self.lstm(seq_embedded)
        dec_seq = self.linear(final_hidd)
        return F.log_softmax(dec_seq, dim = 1)
    
lstm_mod = ECG_LSTM(encoder_conv, hidden_layers, embedding_dim, num_words)
lstm_mod(train_x).shape

torch.Size([6, 20, 30])

In [46]:
epoch = 10000
loss_fn = nn.NLLLoss()
optimizer = torch.optim.Adam(lstm_mod.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)
for i in range(epoch):
    for j, k in zip(train_x, train_y):
        optimizer.zero_grad()
        outputs = lstm_mod(j.unsqueeze(0)).squeeze(0)
        loss = loss_fn(outputs, k)
        loss.backward(retain_graph=True)
        optimizer.step()


KeyboardInterrupt: 

In [49]:
torch.save(lstm_mod.state_dict(), 'model/lstm.pt')

out = lstm_mod(train_x[1].unsqueeze(0))
print(out.squeeze(0).detach().numpy().shape)
out = np.argmax(out.squeeze(0).detach().numpy(), axis = 1)
print(out)
print(train_y[1])

(20, 30)
[28 27  3  5 22 11 10 16 11  6  2 26  4 23 23 23 23 25 25  4]
tensor([28, 27,  3,  5, 21, 11, 10, 16, 11,  7, 21,  5, 29, 29, 29, 29, 29, 29,
        29, 29])


## Model 3 - Basic Transformer Architecture with Multi-Head Attention

## Model 4 - FNET Transformer Architecture

In [None]:
class FeedForwardNet(nn.Module):
    def __init__(self, leads, expansion, dropout):
        super(FeedForwardNet, self).__init__()
        self.linear_1 = nn.Linear(leads, leads * expansion)
        self.linear_2 = nn.Linear(leads * expansion, leads)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

        def forward(self, x):
            res = x
            x = self.dropout_1(self.linear_1(x))
            x = self.dropout_2(self.linear_2(x))
            x = nn.LayerNorm(x + res)
            return x
        
class FNETEncoder(nn.Module):
    def __init__(self, leads, expansion, dropout):
        super(FNETEncoder, self).__init__()
        self.feed_forward = FeedForwardNet(leads, expansion, dropout)
        
    
    def forward(self, x):
        res = x
        x = torch.fft.fft(x, dim = 1).real
        x = nn.LayerNorm(x + res)
        x = self.feed_forward(x)
        return x

# class FNETTransformer(nn.Transformer):
#     def __init__(self): 
    
#     def forward(self, x):

## Model 5 - FNET/Basic Mixup Architecture 