In [28]:
# import all packages needed
import string
import numpy as np
import pandas as pd
from matplotlib import pyplot
from base64 import b64decode as decode

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

## Data Processing / Cleaning

In [3]:
# use class base64 to decode waveform data
def to_array(wf):
    barr = bytearray(decode(wf))
    vals = np.array(barr)
    return vals.view(np.int16)

# read in data
exam_data = pd.read_csv("data/d_exam.csv").drop(columns = ["site_num", "patient_id_edit"])
waveform_data = pd.read_csv("data/d_waveform.csv")
lead_data = pd.read_csv("data/d_lead_data.csv").drop(columns = ["exam_id"])
diagnosis_data = pd.read_csv("data/d_diagnosis.csv").drop(columns = ["user_input"])

# add decoded data as a column to lead data
waveforms = list(lead_data['waveform_data'])
lead_data['decoded_waveform'] = [to_array(i) for i in waveforms]

# merge waveform data and lead data
waveform_lead = lead_data.merge(waveform_data, how = "left", left_on = "waveform_id", right_on = "waveform_id", suffixes = (None, None))

#  sort by exam id and lead id
waveform_lead.sort_values(by = ["waveform_id", "lead_id"], inplace = True)

waveform_lead.loc[:, ['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']]


# adding the diagnosis and labels
waveform_and_diag = pd.merge(waveform_lead[['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']], diagnosis_data[["exam_id", "Full_text", "Original_Diag"]], left_on= "exam_id", right_on="exam_id")
waveform_and_diag

Unnamed: 0,exam_id,lead_id,decoded_waveform,waveform_type,Full_text,Original_Diag
0,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,No previous ECGs available,0
1,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,Otherwise normal ECG,0
2,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,Sinus bradycardia,0
3,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,,0
4,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,Sinus bradycardia,1
...,...,...,...,...,...,...
1419,554080,V6,"[10, 10, 10, 11, 12, 12, 12, 12, 12, 13, 14, 1...",Rhythm,Abnormal ECG,1
1420,554080,V6,"[10, 10, 10, 11, 12, 12, 12, 12, 12, 13, 14, 1...",Rhythm,No previous ECGs available,1
1421,554080,V6,"[10, 10, 10, 11, 12, 12, 12, 12, 12, 13, 14, 1...",Rhythm,No previous ECGs available,0
1422,554080,V6,"[10, 10, 10, 11, 12, 12, 12, 12, 12, 13, 14, 1...",Rhythm,Sinus rhythm,0


In [40]:
# concatenate all leads into a single array
waveform_lead_concat = waveform_lead.groupby(["exam_id", "waveform_type"])['decoded_waveform'].apply(lambda x: tuple(x)).reset_index()

# remove irregular observations, concat tuple into numpy array
waveform_lead_concat = waveform_lead_concat.drop([12,17], axis = 0)
waveform_lead_concat['decoded_waveform'] = waveform_lead_concat['decoded_waveform'].apply(lambda x: np.vstack(x))

waveform_lead_rhythm = waveform_lead_concat[waveform_lead_concat['waveform_type'] == "Rhythm"]
waveform_lead_median = waveform_lead_concat[waveform_lead_concat['waveform_type'] == "Median"]

waveform_lead_rhythm

Unnamed: 0,diagnosis_id,exam_id,statement_order,Original_Diag,Full_text
11,7599587,548759,1,1,Normal sinus rhythm
17,7609137,548759,2,1,Low voltage QRS
3,7585624,548759,3,1,Borderline ECG
12,7601558,548759,4,1,When compared with ECG of
16,7608373,548759,5,1,
10,7598485,548759,6,1,(unconfirmed)
15,7607110,548759,7,1,No significant change was found
6,7588835,549871,1,1,Sinus bradycardia
18,7610082,549871,2,1,Otherwise normal ECG
8,7596610,549871,3,1,No previous ECGs available


In [39]:
# Adding the labels/sentences
exams = diagnosis_data["exam_id"].unique()
diagnosis_data = diagnosis_data[diagnosis_data['Original_Diag'] == 1]
diagnosis_data[diagnosis_data['Full_text'].str.contains('previous', regex = False)
diagnosis_data.sort_values(by=["exam_id", "statement_order"], inplace=True)
diagnosis_data
diagnoses = []
curr_id = 0
curr_string = ""
for i, row in diagnosis_data.iterrows():
    if row["statement_order"] == 1 and curr_string != "":
        val = [curr_id, curr_string[2:] + "."]
        diagnoses.append(val)
        curr_string = ""
        curr_id = row["exam_id"]

    if curr_id == 0:
        curr_id = row["exam_id"]
    
    if type(row["Full_text"]) == str:
        if row["Full_text"][0].isupper():
            curr_string += "; " + row["Full_text"]
        else:
            curr_string += " " + row["Full_text"]

waveform_lead_rhythm_diag = pd.merge(left=waveform_lead_rhythm, right=diagnosis_df, left_on='exam_id', right_on='exam_id')
waveform_lead_median_diag = pd.merge(left=waveform_lead_median, right=diagnosis_df, left_on='exam_id', right_on='exam_id')

#waveform_lead_rhythm_diag
waveform_lead_rhythm_diag['diagnosis'] = waveform_lead_rhythm_diag['diagnosis'].apply(lambda x: x.lower().translate(str.maketrans('', '', string.punctuation)))

11    False
17    False
3     False
12    False
16      NaN
10    False
15    False
6     False
18    False
8      True
25    False
32    False
26     True
27    False
20    False
31     True
34    False
71    False
47     True
72    False
44    False
59    False
68    False
69     True
37    False
54    False
46    False
53     True
52    False
38    False
76     True
70    False
39    False
50    False
66    False
62     True
80    False
82    False
78    False
84    False
85     True
Name: Full_text, dtype: object


AttributeError: 'str' object has no attribute 'punctuation'

In [31]:
unique_words = set()
for num, string in diagnoses:
    
    string = string.lower()
    phrases = string.split("; ")[:-1]
    if phrases[-1][-3:] != "ecg":
        phrases = phrases[:-1]
    if "**" in phrases[0]:
        phrases = phrases[1:]
        
    for phrase in phrases:
        for i in phrase.split():
            unique_words.add(i)
print(unique_words)
    


['normal sinus rhythm', 'low voltage qrs', 'borderline ecg', 'when compared with ecg of (unconfirmed)']
['sinus bradycardia', 'otherwise normal ecg']
['sinus tachycardia', 'otherwise normal ecg']
['normal sinus rhythm', 'normal ecg']
['normal sinus rhythm', 'normal ecg']
['normal sinus rhythm with sinus arrhythmia', 'minimal voltage criteria for lvh, may be normal variant', 'borderline ecg']
['normal sinus rhythm', 't wave abnormality, consider inferior ischemia', 'abnormal ecg']
['atrial fibrillation', 'abnormal ecg']
['** poor data quality, interpretation may be adversely affected', 'normal sinus rhythm with sinus arrhythmia', 'normal ecg']
{'voltage', 'fibrillation', 'may', 'variant', 'for', 'arrhythmia', 'ecg', 't', 'inferior', 'rhythm', 'wave', 'borderline', 'bradycardia', 'low', 'tachycardia', 'criteria', 'be', 'sinus', 'lvh,', 'normal', 'otherwise', 'minimal', 'qrs', 'with', 'consider', 'abnormality,', 'ischemia', 'abnormal', 'atrial'}


In [24]:
# split data into training and testing datasets
# y not included for now
train_x, test_x, _, _ = train_test_split(waveform_lead_rhythm['decoded_waveform'], waveform_lead_rhythm['decoded_waveform'], test_size = 0.1, random_state = 2021)
train_x = torch.tensor(list(train_x)).float()
train_x.shape

torch.Size([7, 8, 2500])

## Model 1 - Conv1D Encoder w/ LSTM Decoder

In [16]:
# HYPERPARAMETERS
J = 10 # max number of filters per class
LR = 1e-3

# define global max pooling
class global_max_pooling_1d(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x, _ = torch.max(x, dim = 2)
        return(x)

# 1D grouped encoder model
encoder_conv = nn.Sequential()
encoder_conv.add_module('initial_norm', nn.BatchNorm1d(8))
encoder_conv.add_module('conv_1', nn.Conv1d(in_channels = 8, out_channels = 8, groups = 8, kernel_size = 5, padding = 2))
for i in range(2, (J+2), 2):
    if (i-2) == 0: 
        prev = 8
    else:
        prev = (i-2)*8
    encoder_conv.add_module('conv_{num}'.format(num = int(i / 2 + 1)), nn.Conv1d(in_channels = prev, out_channels = i*8, groups = 8, kernel_size = 5, padding = 2, stride = 3))
    encoder_conv.add_module('activation_{num}'.format(num = int(i / 2 + 1)), nn.ELU())
    encoder_conv.add_module('batch_norm_{num}'.format(num = int(i / 2 + 1)), nn.BatchNorm1d(i*8))
    
encoder_conv.add_module('final_conv', nn.Conv1d(in_channels = J * 8, out_channels = 8, groups = 8, kernel_size = 5, padding = 2))
encoder_conv.add_module('max_pool', nn.MaxPool1d(kernel_size = 5, padding = 2, stride = 1))

# summarize model, verify output is of desired shape
print(encoder_conv)
encoder_conv(train_x).shape

Sequential(
  (initial_norm): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_1): Conv1d(8, 8, kernel_size=(5,), stride=(1,), padding=(2,), groups=8)
  (conv_2): Conv1d(8, 16, kernel_size=(5,), stride=(3,), padding=(2,), groups=8)
  (activation_2): ELU(alpha=1.0)
  (batch_norm_2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_3): Conv1d(16, 32, kernel_size=(5,), stride=(3,), padding=(2,), groups=8)
  (activation_3): ELU(alpha=1.0)
  (batch_norm_3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_4): Conv1d(32, 48, kernel_size=(5,), stride=(3,), padding=(2,), groups=8)
  (activation_4): ELU(alpha=1.0)
  (batch_norm_4): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_5): Conv1d(48, 64, kernel_size=(5,), stride=(3,), padding=(2,), groups=8)
  (activation_5): ELU(alpha=1.0)
  (batch_norm_5): BatchNorm1d(64, eps=1e-05, momentum=0.1,

torch.Size([7, 8, 11])

## Model 2 - LSTM Encoder w/ Huggingface Decoder

In [30]:
# define hyperparameters 
hidden_layers = 512
embedding_dim = 8
num_words = len(unique_words)

class ECG_LSTM(nn.Module):
    def __init__(self, encoder, h_dim, e_dim, word_list_length):
        super(ECG_LSTM, self).__init__()
        self.encoder = encoder
        self.lstm = nn.LSTM(e_dim, h_dim)
        self.linear = nn.Linear(h_dim, word_list_length)
        
    def forward(self, seq):
        seq_embedded = self.encoder(seq)
        final_hidd, _ = self.lstm(seq_embedded)
        dec_seq = self.linear(final_hidd)
        return F.log_softmax(dec_seq)
    
lstm_dec = ECG_LSTM(encoder_conv, hidden_layers, embedding_dim, num_words)
lstm_dec(train_x).shape

  app.launch_new_instance()


torch.Size([7, 8, 500])

## Model 3 - Basic Transformer Architecture with Multi-Head Attention

## Model 4 - FNET Transformer Architecture

## Model 5 - FNET/Basic Mixup Architecture 