In [28]:
# import all packages needed
import numpy as np
import pandas as pd
from matplotlib import pyplot
from base64 import b64decode as decode

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

## Data Processing / Cleaning

In [3]:
# use the base64 module to decode waveform data
def to_array(wf):
    """Decode one base64-encoded lead waveform into a 1-D int16 numpy array.

    Parameters
    ----------
    wf : str or bytes
        Base64 text as stored in the ``waveform_data`` column.

    Returns
    -------
    numpy.ndarray
        Writable, native-order int16 array with ``len(decoded_bytes) // 2``
        samples.

    Raises
    ------
    binascii.Error
        If ``wf`` is not valid base64.
    ValueError
        If the decoded byte count is odd (not a whole number of int16 samples).
    """
    # Read the bytes directly as 2-byte signed ints with an explicit byte order,
    # so the result no longer depends on the host's native endianness.
    # NOTE(review): little-endian is assumed because native-order decoding on a
    # (presumably little-endian) machine produced the plausible sample values
    # shown in this notebook; confirm against the data specification.
    samples = np.frombuffer(decode(wf), dtype = "<i2")
    # frombuffer returns a read-only view of the bytes object; astype copies it
    # into a writable array in native int16 order.
    return samples.astype(np.int16)

# read in data
exam_data = pd.read_csv("data/d_exam.csv").drop(columns = ["site_num", "patient_id_edit"])
waveform_data = pd.read_csv("data/d_waveform.csv")
# exam_id is dropped from the lead table; after the merge below it is taken
# from waveform_data instead
lead_data = pd.read_csv("data/d_lead_data.csv").drop(columns = ["exam_id"])
diagnosis_data = pd.read_csv("data/d_diagnosis.csv").drop(columns = ["user_input"])

# add decoded data as a column to lead data (one int16 array per lead row)
lead_data['decoded_waveform'] = lead_data['waveform_data'].map(to_array)

# merge waveform data onto every lead row, then sort by waveform id and lead id
# so each waveform's leads appear in a fixed order (the grouping step later
# stacks them in this order)
waveform_lead = (
    lead_data
    .merge(waveform_data, how = "left", on = "waveform_id")
    .sort_values(by = ["waveform_id", "lead_id"])
)

# adding the diagnosis and labels: an inner join on exam_id, so every lead row
# of an exam is repeated once per diagnosis row of that exam
waveform_and_diag = pd.merge(
    waveform_lead[['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']],
    diagnosis_data[["exam_id", "Full_text", "Original_Diag"]],
    on = "exam_id",
)
waveform_and_diag.head()

Unnamed: 0,exam_id,lead_id,decoded_waveform,waveform_type,Full_text,Original_Diag
0,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,No previous ECGs available,0
1,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,Otherwise normal ECG,0
2,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,Sinus bradycardia,0
3,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,,0
4,549871,I,"[-8, -8, -8, -8, -8, -8, -8, -7, -6, -5, -4, -...",Rhythm,Sinus bradycardia,1


In [4]:
# Gather each (exam_id, waveform_type) group's lead arrays into one tuple
# (groupby keeps the row order from the earlier sort inside each group), then
# drop the rows flagged as irregular.
# FIXME: rows 12 and 17 are removed by hard-coded index label, which only
# targets the intended rows for this exact dataset and group ordering.
waveform_lead_concat = (
    waveform_lead
    .groupby(["exam_id", "waveform_type"])['decoded_waveform']
    .apply(tuple)
    .reset_index()
    .drop(index = [12, 17])
)
# stack each tuple of 1-D lead arrays into a single 2-D (leads, samples) array
waveform_lead_concat['decoded_waveform'] = waveform_lead_concat['decoded_waveform'].apply(np.vstack)

# keep only the Rhythm recordings
waveform_lead_rhythm = waveform_lead_concat.loc[waveform_lead_concat['waveform_type'] == "Rhythm"]

In [13]:
# split data into training and testing datasets
# labels (y) are not included yet, so only the waveform arrays are split
train_x, test_x = train_test_split(waveform_lead_rhythm['decoded_waveform'], test_size = 0.1, random_state = 2021)
# Stack the per-exam (leads, samples) int16 arrays into one ndarray before
# handing it to torch: torch.tensor on a Python list of ndarrays is very slow
# and emits a UserWarning. Values are identical; .float() gives float32.
train_x = torch.from_numpy(np.stack(train_x.tolist())).float()
train_x.shape

torch.Size([7, 8, 2500])

## Model 1 - Conv1D Encoder w/ LSTM Decoder

In [16]:
# HYPERPARAMETERS
# J  - filters per lead group in the last strided conv layer; with groups = 8
#      the encoder below widens to J * 8 channels (for even J) before final_conv
# LR - presumably the optimiser learning rate
J, LR = 10, 0.001

# define global max pooling
class global_max_pooling_1d(nn.Module):
    """Reduce dimension 2 of the input by taking its maximum.

    For a 3-D input (presumably (batch, channels, length)) the result keeps
    the first two dimensions and drops the third.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.max with dim returns (values, indices); only the values are kept
        pooled = torch.max(x, dim = 2).values
        return pooled

# 1D grouped encoder model.
# Every conv uses groups = 8 (and BatchNorm/ELU/MaxPool act per channel), so the
# 8 input leads are processed independently all the way through. Each strided
# block widens the channel count by 16 (16, 32, ..., J * 8 for even J) while
# stride = 3 shrinks the sequence length.
encoder_conv = nn.Sequential()
encoder_conv.add_module('initial_norm', nn.BatchNorm1d(8))
encoder_conv.add_module('conv_1', nn.Conv1d(in_channels = 8, out_channels = 8, groups = 8, kernel_size = 5, padding = 2))

# Track the running channel count so final_conv always matches the last
# strided block. Hard-coding J * 8 breaks for odd J (the last block then emits
# (J + 1) * 8 channels) and for J == 0 (no strided block at all).
in_channels = 8
for i in range(2, (J + 2), 2):
    block_num = i // 2 + 1
    out_channels = i * 8
    encoder_conv.add_module('conv_{num}'.format(num = block_num), nn.Conv1d(in_channels = in_channels, out_channels = out_channels, groups = 8, kernel_size = 5, padding = 2, stride = 3))
    encoder_conv.add_module('activation_{num}'.format(num = block_num), nn.ELU())
    encoder_conv.add_module('batch_norm_{num}'.format(num = block_num), nn.BatchNorm1d(out_channels))
    in_channels = out_channels

encoder_conv.add_module('final_conv', nn.Conv1d(in_channels = in_channels, out_channels = 8, groups = 8, kernel_size = 5, padding = 2))
# kernel 5 with padding 2 and stride 1 leaves the sequence length unchanged
encoder_conv.add_module('max_pool', nn.MaxPool1d(kernel_size = 5, padding = 2, stride = 1))

# summarize model, verify output is of desired shape
print(encoder_conv)
encoder_conv(train_x).shape

Sequential(
  (initial_norm): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_1): Conv1d(8, 8, kernel_size=(5,), stride=(1,), padding=(2,), groups=8)
  (conv_2): Conv1d(8, 16, kernel_size=(5,), stride=(3,), padding=(2,), groups=8)
  (activation_2): ELU(alpha=1.0)
  (batch_norm_2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_3): Conv1d(16, 32, kernel_size=(5,), stride=(3,), padding=(2,), groups=8)
  (activation_3): ELU(alpha=1.0)
  (batch_norm_3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_4): Conv1d(32, 48, kernel_size=(5,), stride=(3,), padding=(2,), groups=8)
  (activation_4): ELU(alpha=1.0)
  (batch_norm_4): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_5): Conv1d(48, 64, kernel_size=(5,), stride=(3,), padding=(2,), groups=8)
  (activation_5): ELU(alpha=1.0)
  (batch_norm_5): BatchNorm1d(64, eps=1e-05, momentum=0.1,

torch.Size([7, 8, 11])

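## Model 2 - LSTM Encoder w/ Hugging Face Decoder

*Note: the cell below currently attaches an LSTM decoder to the Conv1D encoder (`encoder_conv`) from Model 1; the Hugging Face decoder is not implemented yet.*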

In [30]:
# define hyperparameters
# hidden_layers - LSTM hidden-state size (a width, not a count of layers)
# embedding_dim - must equal the length axis of encoder_conv's output
#                 (torch.Size([7, 8, 11]) above), which the LSTM reads as features
hidden_layers, embedding_dim = 512, 11

class ECG_LSTM(nn.Module):
    """Encoder followed by an LSTM and a per-step linear projection to word scores.

    Parameters
    ----------
    encoder : nn.Module
        Applied to the raw input. Its output must be (batch, seq, e_dim);
        for encoder_conv in this notebook that is (batch, 8, 11).
    h_dim : int
        LSTM hidden-state size.
    e_dim : int
        Feature size of each encoder output step (the LSTM's input_size).
    word_list_length : int
        Number of output classes (vocabulary size).

    forward returns log-probabilities of shape (batch, seq, word_list_length),
    normalised over the last (vocabulary) axis.
    """

    def __init__(self, encoder, h_dim, e_dim, word_list_length):
        super().__init__()
        self.encoder = encoder
        # batch_first=True: the encoder emits (batch, seq, features). With the
        # default (seq, batch, features) layout the LSTM would treat the batch
        # axis as time and mix information across different exams.
        self.lstm = nn.LSTM(e_dim, h_dim, batch_first = True)
        self.linear = nn.Linear(h_dim, word_list_length)

    def forward(self, seq):
        seq_embedded = self.encoder(seq)
        # first return value holds the hidden state at every step, not only the last
        lstm_out, _ = self.lstm(seq_embedded)
        dec_seq = self.linear(lstm_out)
        # normalise over the vocabulary axis; without an explicit dim PyTorch
        # picks dim 0 for 3-D input, i.e. the batch axis
        return F.log_softmax(dec_seq, dim = -1)
    
# build the decoder on top of the conv encoder; 500 is the output vocabulary size
lstm_dec = ECG_LSTM(
    encoder = encoder_conv,
    h_dim = hidden_layers,
    e_dim = embedding_dim,
    word_list_length = 500,
)
lstm_dec(train_x).shape

  app.launch_new_instance()


torch.Size([7, 8, 500])

## Model 3 - Basic Transformer Architecture with Multi-Head Attention

## Model 4 - FNet Transformer Architecture

## Model 5 - FNet / Basic Transformer Mixup Architecture