In [1]:
# import all packages needed
import string, math
import numpy as np
import pandas as pd
from matplotlib import pyplot
from base64 import b64decode as decode
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config


import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

## Data Processing / Cleaning

In [2]:
# use class base64 to decode waveform data
def to_array(wf):
    barr = bytearray(decode(wf))
    vals = np.array(barr)
    return vals.view(np.int16).astype(np.float32)

# read in data
exam_data = pd.read_csv("data/d_exam.csv").drop(columns = ["site_num", "patient_id_edit"])
waveform_data = pd.read_csv("data/d_waveform.csv")
lead_data = pd.read_csv("data/d_lead_data.csv").drop(columns = ["exam_id"])
diagnosis_data = pd.read_csv("data/d_diagnosis.csv").drop(columns = ["user_input"])

# add decoded data as a column to lead data
waveforms = list(lead_data['waveform_data'])
lead_data['decoded_waveform'] = [to_array(i) for i in waveforms]

# merge waveform data and lead data
waveform_lead = lead_data.merge(waveform_data, how = "left", left_on = "waveform_id", right_on = "waveform_id", suffixes = (None, None))

#  sort by exam id and lead id
waveform_lead.sort_values(by = ["waveform_id", "lead_id"], inplace = True)

waveform_lead.loc[:, ['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']]
waveform_lead

Unnamed: 0,lead_data_id,waveform_id,WavfmType,lead_id,lead_byte_count_total,lead_time_offset,waveform_data,lead_sample_count_total,lead_amplitude,lead_units,...,exam_id,waveform_type,number_of_leads,Waveform_Start_Time,Sample_Type,Sample_Base,Sample_Exponent,High_Pass_Filter,Low_Pass_Filter,AC_Filter
10,9078054,1095618,,I,5000,0,+P/4//j/+P/4//j/+P/5//r/+//8//z//P/7//r/+f/4/...,2500,4.88,MICROVOLTS,...,549871,Rhythm,8,0,CONTINUOUS_SAMPLES,250,0,5,150,NONE
15,9081703,1095618,,II,5000,0,9v/2//b/8//w//D/8P/x//L/8//0//T/9P/z//L/8f/w/...,2500,4.88,MICROVOLTS,...,549871,Rhythm,8,0,CONTINUOUS_SAMPLES,250,0,5,150,NONE
8,9074278,1095618,,V1,5000,0,/v/+//7//v/+////AAAAAAAAAQACAAIAAgACAAIAAgACA...,2500,4.88,MICROVOLTS,...,549871,Rhythm,8,0,CONTINUOUS_SAMPLES,250,0,5,150,NONE
1,9066887,1095618,,V2,5000,0,9v/1//T/9P/0//T/9P/0//T/9f/2//b/9v/2//b/9v/2/...,2500,4.88,MICROVOLTS,...,549871,Rhythm,8,0,CONTINUOUS_SAMPLES,250,0,5,150,NONE
18,9082771,1095618,,V3,5000,0,7v/u/+7/7f/s/+z/7P/t/+7/7v/u/+7/7v/u/+7/7v/u/...,2500,4.88,MICROVOLTS,...,549871,Rhythm,8,0,CONTINUOUS_SAMPLES,250,0,5,150,NONE
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
150,9187141,1109067,,V4,1200,0,KAApACoAKwAsACwALQAtAC4ALgAuAC4ALgAuAC4ALgAvA...,600,4.88,MICROVOLTS,...,554080,Rhythm,8,0,CONTINUOUS_SAMPLES,500,0,16,150,NONE
152,9190675,1109067,,V5,1200,0,FgAXABkAGQAbABsAGwAbABsAGwAbABwAHQAeAB4AHgAfA...,600,4.88,MICROVOLTS,...,554080,Rhythm,8,0,CONTINUOUS_SAMPLES,500,0,16,150,NONE
155,9177603,1109067,,V5,10000,0,+v/6//r/+v/7//z//f/+//z//P/8//z//v/+//7//v/+/...,5000,4.88,MICROVOLTS,...,554080,Rhythm,8,0,CONTINUOUS_SAMPLES,500,0,16,150,NONE
140,9172851,1109067,,V6,10000,0,7v/u/+7/7v/x//L/8//0//T/9P/0//T/9P/0//T/9P/0/...,5000,4.88,MICROVOLTS,...,554080,Rhythm,8,0,CONTINUOUS_SAMPLES,500,0,16,150,NONE


In [3]:
# concatenate all leads into a single array
waveform_lead_concat = waveform_lead.groupby(["exam_id", "waveform_type"])['decoded_waveform'].apply(lambda x: tuple(x)).reset_index()
waveform_lead_concat
# remove irregular observations, concat tuple into numpy array
waveform_lead_concat = waveform_lead_concat.drop([12,17], axis = 0)
waveform_lead_concat['decoded_waveform'] = waveform_lead_concat['decoded_waveform'].apply(lambda x: np.vstack(x))#.apply(lambda x: np.transpose(x))

waveform_lead_rhythm = waveform_lead_concat[waveform_lead_concat['waveform_type'] == "Rhythm"]
#waveform_lead_median = waveform_lead_concat[waveform_lead_concat['waveform_type'] == "Median"]

for value in waveform_lead_rhythm["decoded_waveform"]:
    value /= 1024
    value += .5
    print(np.max(value))
    print(np.min(value))

0.63183594
0.3154297
0.75097656
0.31347656
0.7832031
0.08496094
0.8984375
0.29296875
0.78222656
0.31152344
0.8417969
0.19335938
0.6904297
0.33691406
0.6972656
0.40234375


In [4]:
# Adding the labels/sentences
exams = diagnosis_data["exam_id"].unique()

# Let's look over this tomorrow
diagnosis_data = diagnosis_data[diagnosis_data['Original_Diag'] == 1].dropna()
searchfor = ['previous', 'unconfirmed', 'compared', 'interpretation', 'significant']
diagnosis_data = diagnosis_data.loc[diagnosis_data['Full_text'].str.contains('|'.join(searchfor)) != 1]
#

diagnosis_data.sort_values(by=["exam_id", "statement_order"], inplace=True)
diagnoses = []
curr_id = 0
curr_string = ""
for i, row in diagnosis_data.iterrows():
    if row["statement_order"] == 1 and curr_string != "":
        curr_string = curr_string.lower().translate(str.maketrans('', '', string.punctuation))
        val = [curr_id, curr_string[1:]]
        diagnoses.append(val)
        curr_string = ""
        curr_id = row["exam_id"]

    if curr_id == 0:
        curr_id = row["exam_id"]
    
    curr_string += " " + row["Full_text"]

diagnosis_df = pd.DataFrame(diagnoses, columns = ['exam_id', 'diagnosis'])
waveform_lead_rhythm_diag = pd.merge(left=waveform_lead_rhythm, right=diagnosis_df, left_on='exam_id', right_on='exam_id')

#waveform_lead_rhythm_diag
for i in waveform_lead_rhythm_diag["diagnosis"]:
    print(i)

normal sinus rhythm low voltage qrs borderline ecg
sinus bradycardia otherwise normal ecg
sinus tachycardia otherwise normal ecg
normal sinus rhythm normal ecg
normal sinus rhythm normal ecg
normal sinus rhythm with sinus arrhythmia minimal voltage criteria for lvh may be normal variant borderline ecg
atrial fibrillation abnormal ecg normal sinus rhythm with sinus arrhythmia normal ecg


In [5]:
unique_words = set()
for num, sentence in diagnoses:
    for word in sentence.split():
        unique_words.add(word)
print(unique_words)
unique_words = list(unique_words)
word_map = dict()
for i, word in enumerate(unique_words):
    word_map[word] = i+1
word_map[""] = 0
print(word_map)

{'variant', 'arrhythmia', 'qrs', 'rhythm', 'ecg', 't', 'low', 'abnormality', 'sinus', 'for', 'tachycardia', 'be', 'inferior', 'otherwise', 'may', 'minimal', 'with', 'wave', 'fibrillation', 'bradycardia', 'ischemia', 'lvh', 'normal', 'voltage', 'atrial', 'borderline', 'consider', 'abnormal', 'criteria'}
{'variant': 1, 'arrhythmia': 2, 'qrs': 3, 'rhythm': 4, 'ecg': 5, 't': 6, 'low': 7, 'abnormality': 8, 'sinus': 9, 'for': 10, 'tachycardia': 11, 'be': 12, 'inferior': 13, 'otherwise': 14, 'may': 15, 'minimal': 16, 'with': 17, 'wave': 18, 'fibrillation': 19, 'bradycardia': 20, 'ischemia': 21, 'lvh': 22, 'normal': 23, 'voltage': 24, 'atrial': 25, 'borderline': 26, 'consider': 27, 'abnormal': 28, 'criteria': 29, '': 0}


In [11]:
# split data into training and testing datasets
# y not included for now
def one_hot(x, dict_words):
    x = x.split(" ")
    array = []
    for i in x:
        array.append([0] + [1 if y == i else 0 for y in dict_words] + [0,0])
    for i in range(17-len(x)):
        array.append([1 if i == 30 else 0 for i in range(32)])
    return array

dict_words = list(unique_words)
#waveform_lead_rhythm_diag['diagnosis'] = waveform_lead_rhythm_diag['diagnosis'].apply(lambda x: one_hot(x, dict_words))

len(waveform_lead_rhythm_diag["diagnosis"][5])
train_x, test_x, train_y, test_y = train_test_split(waveform_lead_rhythm_diag['decoded_waveform'], waveform_lead_rhythm_diag['diagnosis'], test_size = 0.1, random_state = 2021)
train_x = torch.tensor(list(train_x)).float()
train_x.shape
train_x = torch.tensor(list(waveform_lead_rhythm_diag['decoded_waveform'])).float()
train_x.shape

torch.Size([7, 8, 2500])

## Model 1 - Conv1D Encoder w/ LSTM Decoder

In [25]:
# HYPERPARAMETERS
J = 8 # max number of filters per class
LR = 1e-3

# define global max pooling
class global_max_pooling_1d(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x, _ = torch.max(x, dim = 2)
        return(x)

# 1D grouped encoder model
encoder_conv = nn.Sequential()
encoder_conv.add_module('initial_norm', nn.BatchNorm1d(8))
encoder_conv.add_module('conv_1', nn.Conv1d(in_channels = 8, out_channels = 8, groups = 8, kernel_size = 5, padding = 2))
for i in range(2, (J+2), 2):
    if (i-2) == 0: 
        prev = 8
    else:
        prev = (i-2)*8
    encoder_conv.add_module('conv_{num}'.format(num = int(i / 2 + 1)), nn.Conv1d(in_channels = prev, out_channels = i*8, groups = 8, kernel_size = 5, padding = 2, stride = 3))
    encoder_conv.add_module('activation_{num}'.format(num = int(i / 2 + 1)), nn.ELU())
    encoder_conv.add_module('batch_norm_{num}'.format(num = int(i / 2 + 1)), nn.BatchNorm1d(i*8))
    
#encoder_conv.add_module('final_conv', nn.Conv1d(in_channels = J * 8, out_channels = 8, groups = 8, kernel_size = 5, padding = 2))
#encoder_conv.add_module('max_pool', nn.MaxPool1d(kernel_size = 5, padding = 2, stride = 1))
encoder_conv.add_module('reshape', nn.MaxPool1d(kernel_size = 5, padding = 2, stride = 1))


# summarize model, verify output is of desired shape
print(train_x[0].shape)
print(encoder_conv(torch.unsqueeze(train_x[0], 0)).shape)

torch.Size([8, 2500])
torch.Size([1, 64, 31])


In [7]:
# ResConv

# HYPERPARAMETERS
J = 10 # max number of filters per class
LR = 1e-3
KER_SIZE = 11
PADDING = 5
# define global max pooling
class global_max_pooling_1d(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x, _ = torch.max(x, dim = 2)
        return(x)

# define resblock for neural nets
class ResBlock1D(nn.Module):
    def __init__(self, num_filters, kernel_size, padding, groups = 1, stride = 1):
        super(ResBlock1D, self).__init__()
        self.act = nn.ReLU()
        self.conv1d_1 = nn.Conv1d(num_filters, num_filters, kernel_size = kernel_size, padding = padding, groups = groups, stride = 1)
        self.conv1d_2 = nn.Conv1d(num_filters, num_filters, kernel_size = kernel_size, padding = padding, groups = groups, stride = 1)
        self.batch_norm_1 = nn.BatchNorm1d(num_filters)
        self.batch_norm_2 = nn.BatchNorm1d(num_filters)

    def forward(self, x):
        res = x
        x = self.batch_norm_1(self.act(self.conv1d_1(x)))
        x = self.batch_norm_2(self.act(self.conv1d_2(x)))
        return x + res

conv_model = nn.Sequential()
init_channels = 8
for i in range(5):
    next_channels = 2 * init_channels
    conv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = KER_SIZE, padding = PADDING, stride = 1))
    conv_model.add_module('act_{num}'.format(num = i), nn.ReLU())
    conv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    conv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = KER_SIZE, padding = PADDING))
    conv_model.add_module('act_res_{num}'.format(num = i), nn.ReLU())
    init_channels = next_channels
conv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 768, kernel_size = KER_SIZE, padding = PADDING))
conv_model.add_module('act_fin', nn.ReLU())
conv_model.add_module('batch_fin', nn.BatchNorm1d(768))
print(conv_model)
#print(conv_model(train_x).shape)
conv_embedder = conv_model





Sequential(
  (conv_0): Conv1d(8, 16, kernel_size=(11,), stride=(1,), padding=(5,))
  (act_0): ReLU()
  (batch_norm_0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res_0): ResBlock1D(
    (act): ReLU()
    (conv1d_1): Conv1d(16, 16, kernel_size=(11,), stride=(1,), padding=(5,))
    (conv1d_2): Conv1d(16, 16, kernel_size=(11,), stride=(1,), padding=(5,))
    (batch_norm_1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (batch_norm_2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (act_res_0): ReLU()
  (conv_1): Conv1d(16, 32, kernel_size=(11,), stride=(1,), padding=(5,))
  (act_1): ReLU()
  (batch_norm_1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res_1): ResBlock1D(
    (act): ReLU()
    (conv1d_1): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(5,))
    (conv1d_2): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(

In [30]:
deconv_model = nn.Sequential()
init_channels = 768
for i in range(5):
    next_channels = init_channels // 2
    deconv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = KER_SIZE, padding = PADDING, stride = 1))
    deconv_model.add_module('act_{num}'.format(num = i), nn.ReLU())
    deconv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    deconv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = KER_SIZE, padding = PADDING))
    deconv_model.add_module('act_res_{num}'.format(num = i), nn.ReLU())
    init_channels = next_channels
deconv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 8, kernel_size = KER_SIZE, padding = PADDING))
deconv_model.add_module('act_fin', nn.ReLU())
deconv_model.add_module('batch_fin', nn.BatchNorm1d(8))

print(train_x.shape)
print(data.shape)
print(conv_model(train_x).shape)
print(deconv_model(conv_model(train_x)).shape)

torch.Size([7, 8, 2500])
torch.Size([7, 2500, 8])
torch.Size([7, 768, 2500])
torch.Size([7, 8, 2500])


In [36]:
class ConvAutoEncoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConvAutoEncoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, x):
        return self.decoder(self.encoder(x))
    
    def make_encoder(self):
        return self.encoder
    
    def make_decoder(self):
        return self.decoder
    
    


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model to set to
auto_model = ConvAutoEncoder(conv_model, deconv_model)
auto_optimizer = torch.optim.Adam(auto_model.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)

In [46]:
# Training params
loss_function = nn.MSELoss()

#print(train_x[0])
for i in range(180):
    #print(train_x.shape)
    auto_optimizer.zero_grad()
    outputs = auto_model(train_x)
    #print(outputs.shape)
    losses = loss_function(outputs, train_x)
    losses.backward(retain_graph=True)
    auto_optimizer.step()
    print(losses)
    if losses < .001:
        break

tensor(0.0028, grad_fn=<MseLossBackward>)
tensor(0.0027, grad_fn=<MseLossBackward>)
tensor(0.0027, grad_fn=<MseLossBackward>)
tensor(0.0027, grad_fn=<MseLossBackward>)
tensor(0.0027, grad_fn=<MseLossBackward>)
tensor(0.0027, grad_fn=<MseLossBackward>)
tensor(0.0027, grad_fn=<MseLossBackward>)
tensor(0.0027, grad_fn=<MseLossBackward>)
tensor(0.0027, grad_fn=<MseLossBackward>)
tensor(0.0026, grad_fn=<MseLossBackward>)
tensor(0.0026, grad_fn=<MseLossBackward>)
tensor(0.0026, grad_fn=<MseLossBackward>)
tensor(0.0026, grad_fn=<MseLossBackward>)
tensor(0.0026, grad_fn=<MseLossBackward>)
tensor(0.0026, grad_fn=<MseLossBackward>)
tensor(0.0026, grad_fn=<MseLossBackward>)
tensor(0.0026, grad_fn=<MseLossBackward>)
tensor(0.0026, grad_fn=<MseLossBackward>)
tensor(0.0025, grad_fn=<MseLossBackward>)
tensor(0.0025, grad_fn=<MseLossBackward>)
tensor(0.0025, grad_fn=<MseLossBackward>)
tensor(0.0025, grad_fn=<MseLossBackward>)
tensor(0.0025, grad_fn=<MseLossBackward>)
tensor(0.0025, grad_fn=<MseLossBac

In [54]:
torch.save(auto_model.state_dict(), 'model/autoencoder.pt')

conv_embedder = auto_model.make_encoder()


torch.save(conv_embedder.state_dict(), "model/embedder.pt")

## Model 2 - LSTM Encoder w/ Huggingface Decoder

In [1]:
# define hyperparameters 
hidden_layers = 512
embedding_dim = 8
num_words = len(unique_words)

class ECG_LSTM(nn.Module):
    def __init__(self, encoder, h_dim, e_dim, word_list_length):
        super(ECG_LSTM, self).__init__()
        self.encoder = encoder
        self.lstm = nn.LSTM(e_dim, h_dim)
        self.linear = nn.Linear(h_dim, word_list_length)
        
    def forward(self, seq):
        seq_embedded = self.encoder(seq)
        final_hidd, _ = self.lstm(seq_embedded)
        dec_seq = self.linear(final_hidd)
        return F.log_softmax(dec_seq)
    
lstm_dec = ECG_LSTM(encoder_conv, hidden_layers, embedding_dim, num_words)
lstm_dec(train_x).shape

NameError: name 'unique_words' is not defined

## Model 3 - Basic Transformer Architecture with Multi-Head Attention

In [49]:
print(data.shape)
new_data = conv_embedder(train_x)
print(new_data.shape)

torch.Size([7, 2500, 8])
torch.Size([7, 768, 2500])


In [15]:
new_data = new_data.detach()
print(new_data)

tensor([[[ 2.0995, -0.5898, -0.4277,  ...,  3.1599, -0.5898,  7.7721],
         [-0.5547,  0.6758, -0.5547,  ..., -0.5547, -0.5547,  3.1045],
         [-0.5680, -0.5680,  2.9091,  ...,  0.8775, -0.5680,  1.5757],
         ...,
         [-0.5168,  1.1706,  0.0231,  ..., -0.5168, -0.5168,  3.2656],
         [-0.4377,  2.8713,  0.2086,  ..., -0.4377,  2.6868,  0.6346],
         [ 0.8298, -0.3843, -0.3843,  ..., -0.3843, -0.3843, -0.3843]],

        [[-0.5898, -0.5898, -0.5898,  ..., 11.0307, -0.5898,  6.4339],
         [-0.5547,  1.8430, -0.5547,  ..., -0.5547, -0.5547,  0.7982],
         [-0.5680, -0.4991,  1.5885,  ..., -0.5680,  3.4990, -0.5680],
         ...,
         [-0.5168, -0.5168, -0.5168,  ..., 12.3154, -0.0480, -0.5168],
         [-0.4377, -0.4377, -0.4377,  ..., -0.4377,  4.6131, -0.4377],
         [ 1.8815, -0.3843, -0.3843,  ..., -0.3843, -0.3843, -0.3843]],

        [[-0.5767, -0.5898, -0.5898,  ..., -0.5898, -0.5898,  2.2787],
         [ 1.3027, -0.5547, -0.0715,  ..., -0

# Transformer testing

In [50]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class ECGTransformerEncoder(nn.Module):
    # Takes the ECG discrete signals sequence and maps into a probability distribution of diagnosis
    # For working/verification purposes
    def __init__(self, vector_size, embed_dim, n_heads, hidden_linear_dim, n_layers, dropout):
        super(ECGTransformerEncoder, self).__init__()
        self.model_type = "Transformer"
        self.positional_encoder = PositionalEncoder(embed_dim, dropout)
    
        #Since our data is already discrete numbers, might need some tweaking for this
        self.embedder = conv_embedder
                        #64 31              #39        64
        
        
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(embed_dim, n_heads, hidden_linear_dim, dropout),
            n_layers)
        
        self.n_inputs = embed_dim
        self.n_layers = n_layers
        
        # Simple linear decoder
        self.decoder = nn.Sequential(
                        nn.Linear(768, 17),
                        Transpose(17, 2500),
                        nn.Linear(2500, 30),
                        nn.LogSoftmax()
                        )
        self.init_weights()
        
    def init_weights(self):
        #self.embedder.weight.data.uniform_(-.1, .1)
        #self.decoder.bias.data.zero_()
        #self.decoder.weight.data.uniform_(-.1, .1)
        pass
        
    def forward(self, x):
        #x = self.embedder(x) # * math.sqrt(self.n_inputs)
        x = x.squeeze(0)
        #x = x.view(2500, 8)
        x = x.unsqueeze(1)
        x = self.positional_encoder(x)
        x = self.encoder(x)
        x = x.squeeze(1) 
        #x = self.decoder(x)
        return x

class Transpose(nn.Module):
    def __init__(self, *args):
        super(Transpose, self).__init__()
        self.shape = args

    def forward(self, x):
        # If the number of the last batch sample in the data set is smaller than the defined batch_batch size, mismatch problems will occur. You can modify it yourself, for example, just pass in the shape behind, and then enter it through x.szie(0).
        return x.view(self.shape)

class SignalEmbedder(nn.Module):
    # Necessary to convert the signal into "word" vectors for transformer processing.
    # Currently a simple group and slice method, but will modify later for multi-channel inputs
    
    def __init__(self, num_slices, size_of_slice):
        super(SignalEmbedder, self).__init__()
        self.num_slices = num_slices
        self.size_of_slice = size_of_slice
        
    def forward(self, x):
        x = x[: self.num_slices * self.size_of_slice]
        x = x.reshape((self.num_slices, self.size_of_slice))
        return x
'''
class OneHotConverter(nn.Module):
    # Converts the sigmoid output into one-hots
    
    def __init__(self, size, sentence_length):
        super(OneHotConverter, self).__init__()
        self.arr_length = size
        self.num_words = sentence_length
        
    def forward(self, x):
        output = []
        for num in x:
            num = num.item()
            num *= self.arr_length
            val = np.zeros(self.arr_length)
            val[int(round(num))] = 1
        
            output.append(val)
        output = torch.as_tensor(output)
        output.requires_grad_()
        return output
'''    

class PositionalEncoder(nn.Module):
    # Necessary to store positional data about the input data
    def __init__(self, embed_dim, dropout=0.1, max_len=2500, batch_size = 1):
        super(PositionalEncoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pos_encoding = torch.zeros(max_len, 1, embed_dim)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        
        divisor = torch.exp(torch.arange(0, embed_dim, 2).float() * (- math.log(10000.0) / embed_dim))
        
        pos_encoding[:, 0, 0::2] = torch.sin(position * divisor)
        pos_encoding[:, 0, 1::2] = torch.cos(position * divisor)
        pos_encoding = pos_encoding.repeat(1, batch_size, 1)
        self.register_buffer("pos_encoding", pos_encoding)

        
    def forward(self, x):
        x = x + self.pos_encoding[:x.size(0), :]
        return self.dropout(x)

In [51]:
# Training pipeline
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model to set to
model = ECGTransformerEncoder(vector_size=5, embed_dim=768, n_heads=16, hidden_linear_dim=2048, n_layers=2, dropout=0.3).to(device)

# Training params
loss_function = nn.L1Loss()

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)
data = []
for arr in waveform_lead_rhythm_diag["decoded_waveform"]:
    #print(arr)
    arr = arr.transpose()
    data.append(arr)

labels = []
for sentence in waveform_lead_rhythm_diag["diagnosis"]:
    #label = one_hot(sentence, dict_words)
    label = []
    for word in sentence.split():
        label.append(word_map[word])
    
    while len(label) < 17:
        label.append(0)
    labels.append(np.array(label))
data = torch.from_numpy(np.array(data, dtype=np.float64)).type(torch.FloatTensor)
print(labels[1])
labels = torch.from_numpy(np.array(labels))
print(data)
model.train()

[ 9 20 14 23  5  0  0  0  0  0  0  0  0  0  0  0  0]
tensor([[[0.5039, 0.5059, 0.4980,  ..., 0.5117, 0.5078, 0.5059],
         [0.5029, 0.5049, 0.4980,  ..., 0.5088, 0.5049, 0.5039],
         [0.5020, 0.5039, 0.4980,  ..., 0.5059, 0.5020, 0.5020],
         ...,
         [0.4932, 0.4941, 0.5020,  ..., 0.4814, 0.4824, 0.4844],
         [0.4902, 0.4922, 0.5059,  ..., 0.4805, 0.4805, 0.4805],
         [0.4902, 0.4922, 0.5059,  ..., 0.4805, 0.4805, 0.4805]],

        [[0.4922, 0.4902, 0.4980,  ..., 0.4746, 0.4746, 0.4844],
         [0.4922, 0.4902, 0.4980,  ..., 0.4756, 0.4756, 0.4854],
         [0.4922, 0.4902, 0.4980,  ..., 0.4766, 0.4766, 0.4863],
         ...,
         [0.5098, 0.5088, 0.4980,  ..., 0.5088, 0.5107, 0.5068],
         [0.5059, 0.5020, 0.5000,  ..., 0.4990, 0.5020, 0.5000],
         [0.5059, 0.5020, 0.5000,  ..., 0.4990, 0.5020, 0.5000]],

        [[0.4785, 0.4961, 0.5078,  ..., 0.5195, 0.5176, 0.5137],
         [0.4805, 0.4961, 0.5088,  ..., 0.5215, 0.5195, 0.5146],
     

ECGTransformerEncoder(
  (positional_encoder): PositionalEncoder(
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (embedder): Sequential(
    (conv_0): Conv1d(8, 16, kernel_size=(11,), stride=(1,), padding=(5,))
    (act_0): ReLU()
    (batch_norm_0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (res_0): ResBlock1D(
      (act): ReLU()
      (conv1d_1): Conv1d(16, 16, kernel_size=(11,), stride=(1,), padding=(5,))
      (conv1d_2): Conv1d(16, 16, kernel_size=(11,), stride=(1,), padding=(5,))
      (batch_norm_1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (batch_norm_2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (act_res_0): ReLU()
    (conv_1): Conv1d(16, 32, kernel_size=(11,), stride=(1,), padding=(5,))
    (act_1): ReLU()
    (batch_norm_1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (res_1): ResBlock1D(
      (a

In [70]:
loss_function = nn.MSELoss()

new_data = new_data.detach().transpose(1,2)
for i in range(500):
    losses = 0
    for x, y in zip(new_data, labels):
        optimizer.zero_grad()
        outputs = model(x.unsqueeze(0))
        loss = loss_function(outputs, x)
        print("loss: " + str(loss.item()))
        losses += loss
    losses.backward(retain_graph=True)
    optimizer.step()
    print("epoch loss: ", str(losses.item()))
    if losses < .001:
        break

        '''        
for x, y in zip(new_data, labels):
    print(np.argmax(model(x.view(2500,768)).detach().numpy(), axis=1))
    print(y.detach().numpy())
'''

loss: 0.7370614409446716
loss: 0.6791639924049377
loss: 0.3094591200351715
loss: 0.5278668403625488
loss: 0.34741342067718506
loss: 0.21746216714382172
loss: 0.4615210294723511
epoch loss:  3.2799482345581055
loss: 0.734946072101593
loss: 0.6778768301010132
loss: 0.3100247383117676
loss: 0.5258252024650574
loss: 0.34639695286750793
loss: 0.2173265665769577
loss: 0.4615813195705414
epoch loss:  3.273977756500244
loss: 0.7365527153015137
loss: 0.6790151000022888
loss: 0.310395210981369
loss: 0.5262953639030457
loss: 0.3462901711463928
loss: 0.21628332138061523
loss: 0.46113672852516174
epoch loss:  3.2759687900543213
loss: 0.7375586032867432
loss: 0.6762019395828247
loss: 0.3095795214176178
loss: 0.5260034799575806
loss: 0.3452601432800293
loss: 0.2161376029253006
loss: 0.4588436782360077
epoch loss:  3.269584894180298
loss: 0.7332217693328857
loss: 0.6752980351448059
loss: 0.3084912896156311
loss: 0.5254095792770386
loss: 0.34439337253570557
loss: 0.21576423943042755
loss: 0.45846849679

loss: 0.3173438012599945
loss: 0.19620873034000397
loss: 0.42660072445869446
epoch loss:  3.0821802616119385
loss: 0.7072585821151733
loss: 0.648779571056366
loss: 0.28226491808891296
loss: 0.49447309970855713
loss: 0.31689944863319397
loss: 0.19547352194786072
loss: 0.42521652579307556
epoch loss:  3.0703656673431396
loss: 0.7105720043182373
loss: 0.64705890417099
loss: 0.2804194986820221
loss: 0.495645672082901
loss: 0.3166711926460266
loss: 0.19561223685741425
loss: 0.42468318343162537
epoch loss:  3.070662498474121
loss: 0.7088840007781982
loss: 0.6469820141792297
loss: 0.28018665313720703
loss: 0.49448731541633606
loss: 0.3157186210155487
loss: 0.19518138468265533
loss: 0.42388829588890076
epoch loss:  3.0653281211853027
loss: 0.7050816416740417
loss: 0.645703136920929
loss: 0.2805134654045105
loss: 0.4920767545700073
loss: 0.314157634973526
loss: 0.19419549405574799
loss: 0.4245525002479553
epoch loss:  3.0562806129455566
loss: 0.7109255194664001
loss: 0.642879843711853
loss: 0.2

epoch loss:  2.8985064029693604
loss: 0.6829267144203186
loss: 0.6212787628173828
loss: 0.25892961025238037
loss: 0.4668465256690979
loss: 0.29011425375938416
loss: 0.17959058284759521
loss: 0.3936387896537781
epoch loss:  2.893325090408325
loss: 0.6824503540992737
loss: 0.6210655570030212
loss: 0.2562723755836487
loss: 0.4659837484359741
loss: 0.28984832763671875
loss: 0.17923122644424438
loss: 0.3936094343662262
epoch loss:  2.888461112976074
loss: 0.6848329305648804
loss: 0.6227961182594299
loss: 0.25555479526519775
loss: 0.4659068286418915
loss: 0.28926920890808105
loss: 0.17833280563354492
loss: 0.39109566807746887
epoch loss:  2.8877882957458496
loss: 0.6808115839958191
loss: 0.6218788027763367
loss: 0.25577518343925476
loss: 0.4650154709815979
loss: 0.2890644669532776
loss: 0.17762236297130585
loss: 0.3914775252342224
epoch loss:  2.881645441055298
loss: 0.680741012096405
loss: 0.6193974018096924
loss: 0.2545902132987976
loss: 0.46395719051361084
loss: 0.2876981198787689
loss: 0

loss: 0.6602972745895386
loss: 0.5981490612030029
loss: 0.2358904480934143
loss: 0.44127821922302246
loss: 0.2670598030090332
loss: 0.166484996676445
loss: 0.3645620048046112
epoch loss:  2.733721971511841
loss: 0.6586454510688782
loss: 0.5953205823898315
loss: 0.2342836558818817
loss: 0.4397212266921997
loss: 0.26630598306655884
loss: 0.16499045491218567
loss: 0.36368805170059204
epoch loss:  2.7229554653167725
loss: 0.6592363119125366
loss: 0.5958686470985413
loss: 0.23520660400390625
loss: 0.44125130772590637
loss: 0.2672121226787567
loss: 0.16569747030735016
loss: 0.3643198311328888
epoch loss:  2.728792428970337
loss: 0.6595476269721985
loss: 0.5957673192024231
loss: 0.23404939472675323
loss: 0.4399004280567169
loss: 0.2661578357219696
loss: 0.16504575312137604
loss: 0.363160103559494
epoch loss:  2.723628520965576
loss: 0.6561885476112366
loss: 0.5952978134155273
loss: 0.23309111595153809
loss: 0.4368353486061096
loss: 0.26460567116737366
loss: 0.16389529407024384
loss: 0.3618593

loss: 0.21380984783172607
loss: 0.4138464629650116
loss: 0.2440020740032196
loss: 0.153212770819664
loss: 0.33399754762649536
epoch loss:  2.5632846355438232
loss: 0.6343693733215332
loss: 0.5705179572105408
loss: 0.21248801052570343
loss: 0.4147000312805176
loss: 0.24355854094028473
loss: 0.15354491770267487
loss: 0.33218780159950256
epoch loss:  2.5613667964935303
loss: 0.6331907510757446
loss: 0.5680468678474426
loss: 0.21136678755283356
loss: 0.41170457005500793
loss: 0.2419542819261551
loss: 0.153114452958107
loss: 0.3318275213241577
epoch loss:  2.5512051582336426
loss: 0.632131040096283
loss: 0.568535327911377
loss: 0.21034877002239227
loss: 0.41116365790367126
loss: 0.24175767600536346
loss: 0.15264101326465607
loss: 0.3313818871974945
epoch loss:  2.547959327697754
loss: 0.6309769749641418
loss: 0.5706782341003418
loss: 0.2103147953748703
loss: 0.41034746170043945
loss: 0.24051469564437866
loss: 0.1521732062101364
loss: 0.32931405305862427
epoch loss:  2.5443193912506104
loss:

loss: 0.38520756363868713
loss: 0.21928168833255768
loss: 0.14209173619747162
loss: 0.30424371361732483
epoch loss:  2.394357442855835
loss: 0.6118438839912415
loss: 0.5430666208267212
loss: 0.18979372084140778
loss: 0.3856427073478699
loss: 0.21884067356586456
loss: 0.1412000209093094
loss: 0.30355140566825867
epoch loss:  2.3939390182495117
loss: 0.6102650165557861
loss: 0.5445623993873596
loss: 0.18943020701408386
loss: 0.3857511579990387
loss: 0.21842695772647858
loss: 0.1407075822353363
loss: 0.30232277512550354
epoch loss:  2.3914661407470703
loss: 0.6087973117828369
loss: 0.5422974824905396
loss: 0.18909095227718353
loss: 0.3827359676361084
loss: 0.21778571605682373
loss: 0.14097535610198975
loss: 0.30219119787216187
epoch loss:  2.38387393951416
loss: 0.6071667671203613
loss: 0.5432260036468506
loss: 0.18872994184494019
loss: 0.3838512897491455
loss: 0.21701332926750183
loss: 0.14049983024597168
loss: 0.30089065432548523
epoch loss:  2.381377935409546
loss: 0.610371470451355
lo

loss: 0.13227878510951996
loss: 0.2811209261417389
epoch loss:  2.270521879196167
loss: 0.5893149971961975
loss: 0.5253823399543762
loss: 0.1741001158952713
loss: 0.3630402684211731
loss: 0.20086583495140076
loss: 0.13211049139499664
loss: 0.2801170349121094
epoch loss:  2.2649312019348145
loss: 0.5904955267906189
loss: 0.522892951965332
loss: 0.17399823665618896
loss: 0.3619498610496521
loss: 0.20038099586963654
loss: 0.13189175724983215
loss: 0.28027814626693726
epoch loss:  2.261887550354004
loss: 0.5896156430244446
loss: 0.523681640625
loss: 0.1725614368915558
loss: 0.36161008477211
loss: 0.199580118060112
loss: 0.13152305781841278
loss: 0.27874070405960083
epoch loss:  2.257312536239624
loss: 0.5932161211967468
loss: 0.5233979821205139
loss: 0.17201846837997437
loss: 0.3612832725048065
loss: 0.19943709671497345
loss: 0.1309831738471985
loss: 0.2790580093860626
epoch loss:  2.2593941688537598
loss: 0.5871614813804626
loss: 0.522607684135437
loss: 0.17236408591270447
loss: 0.3607567

epoch loss:  2.164316177368164
loss: 0.5743978023529053
loss: 0.5073972940444946
loss: 0.16144664585590363
loss: 0.34283578395843506
loss: 0.18707358837127686
loss: 0.12446510046720505
loss: 0.26312723755836487
epoch loss:  2.160743474960327
loss: 0.5733863711357117
loss: 0.5073127746582031
loss: 0.1616535633802414
loss: 0.34242725372314453
loss: 0.18626654148101807
loss: 0.12392211705446243
loss: 0.26139792799949646
epoch loss:  2.1563665866851807
loss: 0.5739397406578064
loss: 0.507036030292511
loss: 0.1611405611038208
loss: 0.3415936827659607
loss: 0.18535707890987396
loss: 0.1237591877579689
loss: 0.26182854175567627
epoch loss:  2.1546547412872314
loss: 0.5740092396736145
loss: 0.5063887238502502
loss: 0.15998122096061707
loss: 0.34301769733428955
loss: 0.18569643795490265
loss: 0.12409556657075882
loss: 0.2607095539569855
epoch loss:  2.1538984775543213
loss: 0.5726830363273621
loss: 0.5046035051345825
loss: 0.1605639010667801
loss: 0.3412225842475891
loss: 0.18532349169254303
lo

loss: 0.5607642531394958
loss: 0.4952254891395569
loss: 0.1537228673696518
loss: 0.32862481474876404
loss: 0.17680121958255768
loss: 0.11909927427768707
loss: 0.24885429441928864
epoch loss:  2.083092212677002
loss: 0.5627892017364502
loss: 0.49395641684532166
loss: 0.1535744071006775
loss: 0.3286449611186981
loss: 0.17692017555236816
loss: 0.1194380670785904
loss: 0.24827203154563904
epoch loss:  2.0835952758789062
loss: 0.5602575540542603
loss: 0.49572235345840454
loss: 0.15410108864307404
loss: 0.3281877338886261
loss: 0.17614516615867615
loss: 0.1195257157087326
loss: 0.2492993324995041
epoch loss:  2.0832390785217285
loss: 0.559352457523346
loss: 0.4945198595523834
loss: 0.1521582454442978
loss: 0.32771801948547363
loss: 0.17614729702472687
loss: 0.11934905499219894
loss: 0.24821989238262177
epoch loss:  2.0774648189544678
loss: 0.5627308487892151
loss: 0.4954156279563904
loss: 0.15314015746116638
loss: 0.3268163502216339
loss: 0.17615292966365814
loss: 0.11831577867269516
loss: 0

loss: 0.4833579659461975
loss: 0.1467922031879425
loss: 0.3144601583480835
loss: 0.16827717423439026
loss: 0.11459264159202576
loss: 0.2379608303308487
epoch loss:  2.0204360485076904
loss: 0.5511054396629333
loss: 0.48417648673057556
loss: 0.14701546728610992
loss: 0.3142998516559601
loss: 0.16769616305828094
loss: 0.1140153780579567
loss: 0.23736168444156647
epoch loss:  2.0156702995300293
loss: 0.5507424473762512
loss: 0.4825613796710968
loss: 0.14602702856063843
loss: 0.3146819770336151
loss: 0.16854910552501678
loss: 0.11423654109239578
loss: 0.2366643249988556
epoch loss:  2.013462781906128
loss: 0.5525040626525879
loss: 0.48430079221725464
loss: 0.1465781033039093
loss: 0.3131634294986725
loss: 0.16791567206382751
loss: 0.11449184268712997
loss: 0.23672005534172058
epoch loss:  2.015673875808716
loss: 0.5503508448600769
loss: 0.482476145029068
loss: 0.14598038792610168
loss: 0.3138553202152252
loss: 0.1678025871515274
loss: 0.1134597510099411
loss: 0.23666232824325562
epoch loss

loss: 0.473579466342926
loss: 0.14094235002994537
loss: 0.3021235167980194
loss: 0.16078901290893555
loss: 0.11056685447692871
loss: 0.2269517481327057
epoch loss:  1.9548527002334595
loss: 0.542399525642395
loss: 0.4732648432254791
loss: 0.1400357037782669
loss: 0.30183395743370056
loss: 0.16107015311717987
loss: 0.11056899279356003
loss: 0.22587136924266815
epoch loss:  1.9550445079803467
loss: 0.5439355373382568
loss: 0.47389498353004456
loss: 0.13965582847595215
loss: 0.3011467456817627
loss: 0.16022071242332458
loss: 0.11009621620178223
loss: 0.2267804890871048
epoch loss:  1.9557305574417114
loss: 0.5435824394226074
loss: 0.4724372327327728
loss: 0.14056077599525452
loss: 0.30120593309402466
loss: 0.1610337346792221
loss: 0.1098509356379509
loss: 0.22633413970470428
epoch loss:  1.9550050497055054
loss: 0.5388643741607666
loss: 0.47180232405662537
loss: 0.13909777998924255
loss: 0.300655335187912
loss: 0.1607562154531479
loss: 0.10984116047620773
loss: 0.22554078698158264
epoch l

loss: 0.4650704562664032
loss: 0.13633038103580475
loss: 0.2904833257198334
loss: 0.15488216280937195
loss: 0.10741567611694336
loss: 0.21837854385375977
epoch loss:  1.90736985206604
loss: 0.534217119216919
loss: 0.46385717391967773
loss: 0.13484002649784088
loss: 0.2917906939983368
loss: 0.1548757553100586
loss: 0.10643579065799713
loss: 0.21769143640995026
epoch loss:  1.9037079811096191
loss: 0.5345015525817871
loss: 0.4633285701274872
loss: 0.13624247908592224
loss: 0.29008492827415466
loss: 0.15404626727104187
loss: 0.1072128638625145
loss: 0.21721157431602478
epoch loss:  1.9026284217834473
loss: 0.5323505401611328
loss: 0.4649980366230011
loss: 0.1346813589334488
loss: 0.2895457148551941
loss: 0.1547623872756958
loss: 0.10704650729894638
loss: 0.21718449890613556
epoch loss:  1.900568962097168
loss: 0.534225583076477
loss: 0.46395188570022583
loss: 0.13519519567489624
loss: 0.28953778743743896
loss: 0.15458105504512787
loss: 0.10723628103733063
loss: 0.21746936440467834
epoch l

loss: 0.4568825662136078
loss: 0.13041217625141144
loss: 0.2801719903945923
loss: 0.14867718517780304
loss: 0.10400459170341492
loss: 0.20912981033325195
epoch loss:  1.8538992404937744
loss: 0.5246441960334778
loss: 0.45662403106689453
loss: 0.13032843172550201
loss: 0.27985191345214844
loss: 0.14917799830436707
loss: 0.10442136973142624
loss: 0.20931361615657806
epoch loss:  1.854361653327942
loss: 0.5284889936447144
loss: 0.4586355984210968
loss: 0.1302054226398468
loss: 0.281062513589859
loss: 0.14953386783599854
loss: 0.10394299775362015
loss: 0.2100769728422165
epoch loss:  1.8619462251663208
loss: 0.5268982648849487
loss: 0.4567541778087616
loss: 0.130122572183609
loss: 0.28115570545196533
loss: 0.15042632818222046
loss: 0.1044456884264946
loss: 0.2097441405057907
epoch loss:  1.8595467805862427
loss: 0.5254129767417908
loss: 0.4560297727584839
loss: 0.13114707171916962
loss: 0.27973511815071106
loss: 0.14971676468849182
loss: 0.1040036678314209
loss: 0.209462970495224
epoch los

In [71]:
torch.save(model.state_dict(), 'model/transformer_768.pt')

In [None]:
from transformers import BertTokenizer, BertForNextSentencePrediction
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')

prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors='pt')

outputs = model(**encoding, labels=torch.LongTensor([1]))
logits = outputs.logits
#assert logits[0, 0] < logits[0, 1] # next sentence was random
print(logits)

In [81]:
epochs = 1

# define tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# preprocess training labels and tokenize
train_labels = list(waveform_lead_rhythm_diag['diagnosis'])
inputs = tokenizer(train_labels, padding = True, pad_token = tokenizer.add_special_tokens({'pad_token': '[PAD]'}), verbose = False, return_tensors="pt")

# adjust model parameters to account for padding token
model.resize_token_embeddings(len(tokenizer))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)

for i in range(epochs):
#model_gpt2DoubleHeadsModel.resize_token_embeddings(len(gpt2_tokenizer))
    optimizer.zero_grad()
    outputs = model(**inputs, labels = inputs["input_ids"])
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    
    print(loss)
    
logits = model.logits
print(np.argmax(logits[0].detach().numpy(), axis = 1))

Keyword arguments {'pad_token': 1} not recognized.
Keyword arguments {'pad_token': 1} not recognized.
Keyword arguments {'pad_token': 1} not recognized.
Keyword arguments {'pad_token': 1} not recognized.
Keyword arguments {'pad_token': 1} not recognized.
Keyword arguments {'pad_token': 1} not recognized.
Keyword arguments {'pad_token': 1} not recognized.


tensor(42.4789, grad_fn=<NllLossBackward>)


ModuleAttributeError: 'GPT2LMHeadModel' object has no attribute 'logits'

## Model 4 - Cohesive 1 Wrapper Transformer Architecture

## Model 5 - FNET/Basic Mixup Architecture 