In [33]:
# import all packages needed
import string 
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from base64 import b64decode as decode
import math
import torch
import torch.nn as nn 
from torch.nn import CrossEntropyLoss, MSELoss
from torch.utils.tensorboard import SummaryWriter
import torch.fft as fft
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2Config, GPT2PreTrainedModel, GPT2Model
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from typing import Optional, Tuple
import transformers

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# The only time we need to define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Data Processing / Cleaning

In [169]:
# use class base64 to decode waveform data
def to_array(wf):
    barr = bytearray(decode(wf))
    vals = np.array(barr)
    return vals.view(np.int16).astype(np.float32)

# read in data
exam_data = pd.read_csv("data/combined_exam.csv").drop(columns = ["site_num", "patient_id_edit"])
waveform_data = pd.read_csv("data/combined_waveform.csv")
lead_data = pd.read_csv("data/combined_lead_data.csv").drop(columns = ["exam_id"])
diagnosis_data = pd.read_csv("data/combined_diagnosis.csv").drop(columns = ["user_input"])

# add decoded data as a column to lead dataz
waveforms = list(lead_data['waveform_data'])
lead_data['decoded_waveform'] = [to_array(i) for i in waveforms]

# merge waveform data and lead data
waveform_lead = lead_data.merge(waveform_data, how = "left", left_on = "waveform_id", right_on = "waveform_id", suffixes = (None, None))

#  sort by exam id and lead id
waveform_lead.sort_values(by = ["waveform_id", "lead_id"], inplace = True)

waveform_lead.loc[:, ['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']]


# adding the diagnosis and labels
waveform_and_diag = pd.merge(waveform_lead[['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']], diagnosis_data[["exam_id", "Full_text", "Original_Diag"]], left_on= "exam_id", right_on="exam_id")


# concatenate all leads into a single array
waveform_lead_concat = waveform_lead.groupby(["exam_id", "waveform_type"])['decoded_waveform'].apply(lambda x: tuple(x)).reset_index()


# remove irregular observations, concat tuple into numpy array
#waveform_lead_concat = waveform_lead_concat.drop([12,17], axis = 0)
waveform_lead_concat = waveform_lead_concat[waveform_lead_concat["decoded_waveform"].apply(lambda x: len(x[0]) == 2500)]
waveform_lead_concat = waveform_lead_concat[waveform_lead_concat["decoded_waveform"].apply(lambda x: len(x) == 8)]
   

waveform_lead_concat['decoded_waveform'] = waveform_lead_concat['decoded_waveform'].apply(lambda x: np.vstack(x))
waveform_lead_rhythm = waveform_lead_concat[waveform_lead_concat['waveform_type'] == "Rhythm"]

waveform_lead_rhythm["decoded_waveform"] = waveform_lead_rhythm["decoded_waveform"].apply(lambda value: MinMaxScaler().fit_transform(value))



exams = diagnosis_data["exam_id"].unique()

diagnosis_data = diagnosis_data[diagnosis_data['Original_Diag'] == 1].dropna()

searchfor = ['previous', 'unconfirmed', 'compared', 'interpretation', 'significant']
diagnosis_data = diagnosis_data.loc[diagnosis_data['Full_text'].str.contains('|'.join(searchfor)) != 1]

diagnosis_data.sort_values(by=["exam_id", "statement_order"], inplace=True)
diagnoses = []
curr_id = 0
curr_string = ""


# making the tokens
tokens = set()


prefixed_phrase = ""
for i, row in diagnosis_data.iterrows():
    if curr_id == 0:
        curr_id = row["exam_id"]
    
    if row["exam_id"] != curr_id and curr_string != "":
        curr_string = curr_string.lower()
        curr_string = curr_string.replace("     ", "").replace(" ,", "")
        val = [curr_id, curr_string[2:]]
        print(val)
        diagnoses.append(val)
        curr_string = ""
        curr_id = row["exam_id"]

    
    if "*" in row["Full_text"] or "(" in row["Full_text"]:
        continue
    
    
    if row["Full_text"][-3:] == "for" or row["Full_text"][-4:] == "with" or row["Full_text"][-1] == "&":
        prefixed_phrase = row["Full_text"].lower() + " "
        curr_string += "@"
        continue
    
    if curr_string and curr_string[-1] == "@":
        curr_string = curr_string[:-1]
        curr_string += " " + row["Full_text"]
    else:
        curr_string += "; " + row["Full_text"]
    
    tokens.add(prefixed_phrase + row["Full_text"].lower())
    prefixed_phrase = ""
    
    
diagnosis_df = pd.DataFrame(diagnoses, columns = ['exam_id', 'diagnosis'])
waveform_lead_rhythm_diag = pd.merge(left=waveform_lead_rhythm, right=diagnosis_df, left_on='exam_id', right_on='exam_id')

full_x = torch.tensor(waveform_lead_rhythm_diag['decoded_waveform']).float()
full_y = waveform_lead_rhythm_diag['diagnosis']
for i in full_y:
    print(i)

[548561, 'sinus bradycardia; otherwise normal ecg']
[548592, 'normal sinus rhythm; normal ecg']
[548593, 'normal sinus rhythm; normal ecg']
[548609, 'normal sinus rhythm; normal ecg']
[548759, 'normal sinus rhythm; normal sinus rhythm; low voltage qrs; low voltage qrs; borderline ecg; borderline ecg']
[549810, 'normal sinus rhythm; left axis deviation; right bundle branch block; abnormal ecg']
[549871, 'sinus bradycardia; sinus bradycardia; otherwise normal ecg; otherwise normal ecg']
[549964, 'sinus rhythm; with marked sinus arrhythmia; otherwise normal ecg']
[550065, 'normal sinus rhythm; with sinus arrhythmia; cannot rule out; inferior infarct; age undetermined; abnormal ecg']
[550307, 'normal sinus rhythm t wave abnormality, consider anterior ischemia; prolonged qt; abnormal ecg; t wave inversion now evident in; anterior leads; qt has lengthened']
[550391, 'sinus bradycardia premature atrial complexes; in a pattern of bigeminy; minimal voltage criteria for lvh, may be normal varian

In [170]:
tokens = list(tokens)
from tokenizers import Tokenizer
from tokenizers.models import BPE, Unigram, WordPiece, WordLevel
from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordPieceTrainer, WordLevelTrainer


class custom_tokenizer(nn.Module):
    def __init__(self, vocab, max_len=22):
        super(custom_tokenizer, self).__init__()
        self.tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        self.trainer = WordLevelTrainer(special_tokens=["[PAD]", "[UNK]"])
        self.tokenizer.train_from_iterator(np.array(tokens), trainer)
        self.max_len = max_len
        
    def forward(self, x):
        input_ids = []
        attention_masks = []
        for sentence in x:
            tokenized = self.tokenizer.encode(sentence.split("; "), is_pretokenized=True)
            input_id = [0] + tokenized.ids + [0 for i in range(self.max_len - len(tokenized.ids))]
            attention_mask = [1] + tokenized.attention_mask + [0 for i in range(self.max_len - len(tokenized.attention_mask))]
            input_ids.append(input_id)
            attention_masks.append(attention_mask)
        return {"input_ids": torch.tensor(input_ids).detach(), "attention_mask": torch.tensor(attention_masks).detach()}
    
    def decode(self, x):
        return self.tokenizer.decode(list(x))

    
    
    
    
# BPE, Unigram, WordPiece, WordLevel,  unk_token="[UNK]"
transformer_tokenizer = custom_tokenizer(vocab=tokens)
print(len(transformer_tokenizer.tokenizer.get_vocab()))
output = transformer_tokenizer(full_y)["input_ids"]

99


## Embedder: Conv1D

In [5]:
# This is where we define components to be used for both the Conv1D encoder, and a Conv1D pre-embedder into a Transformer Encoder.

LR = 1e-3
KER_SIZE = 11
PADDING = 5

# define global max pooling
class global_max_pooling_1d(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x, _ = torch.max(x, dim = 2)
        return(x)

# define resblock for neural nets
class ResBlock1D(nn.Module):
    def __init__(self, num_filters, kernel_size, padding, groups = 1, stride = 1):
        super(ResBlock1D, self).__init__()
        self.act = nn.ReLU()
        self.conv1d_1 = nn.Conv1d(num_filters, num_filters, kernel_size = kernel_size, padding = padding, groups = groups, stride = 1)
        self.conv1d_2 = nn.Conv1d(num_filters, num_filters, kernel_size = kernel_size, padding = padding, groups = groups, stride = 1)
        self.batch_norm_1 = nn.BatchNorm1d(num_filters)
        self.batch_norm_2 = nn.BatchNorm1d(num_filters)

    def forward(self, x):
        res = x
        x = self.batch_norm_1(self.act(self.conv1d_1(x)))
        x = self.batch_norm_2(self.act(self.conv1d_2(x)))
        return x + res

  

## Encoder 1: ResNet Encoder

In [190]:
def init_weights(x):
    if isinstance(x, nn.Conv1d):
        nn.init.kaiming_uniform_(x.weight, mode='fan_in', nonlinearity='relu')
        x.bias.data.fill_(0.01)

# HYPERPARAMETERS
J = 10 # max number of filters per class
LR = 1e-3

# build resent model and display the shape of feed through
conv_model = nn.Sequential()
init_channels = 8
for i in range(4):
    next_channels = 2 * init_channels
    conv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = 249, padding = 124, stride = 2))
    conv_model.add_module('act_{num}'.format(num = i), nn.ELU())
    conv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    conv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = 249, padding = 124))
    conv_model.add_module('act_res_{num}'.format(num = i), nn.ELU())
    init_channels = next_channels
    
conv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 256, kernel_size = 249, padding = 124))
conv_model.add_module('act_fin', nn.ELU())
conv_model.add_module('batch_fin', nn.BatchNorm1d(256))

conv_model.apply(init_weights)

print(conv_model)
conv_model(full_x).shape

Sequential(
  (conv_0): Conv1d(8, 16, kernel_size=(249,), stride=(2,), padding=(124,))
  (act_0): ELU(alpha=1.0)
  (batch_norm_0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res_0): ResBlock1D(
    (act): ReLU()
    (conv1d_1): Conv1d(16, 16, kernel_size=(249,), stride=(1,), padding=(124,))
    (conv1d_2): Conv1d(16, 16, kernel_size=(249,), stride=(1,), padding=(124,))
    (batch_norm_1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (batch_norm_2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (act_res_0): ELU(alpha=1.0)
  (conv_1): Conv1d(16, 32, kernel_size=(249,), stride=(2,), padding=(124,))
  (act_1): ELU(alpha=1.0)
  (batch_norm_1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res_1): ResBlock1D(
    (act): ReLU()
    (conv1d_1): Conv1d(32, 32, kernel_size=(249,), stride=(1,), padding=(124,))
    (conv1d_2): Conv1d(32, 32, ke

torch.Size([52, 256, 157])

## LSTM Decoder

In [199]:
# define hyperparameters 
hidden_layers = 256
embedding_dim = 256
word_list_length = 98
start_token = end_token = 0

class LSTM_Encoder(nn.Module):
    def __init__(self, h_dim, e_dim):
        super(LSTM_Encoder, self).__init__()
        self.lstm = nn.LSTM(e_dim, h_dim, num_layers = 4, bidirectional = True)
        for name, param in self.lstm.named_parameters():
            if 'bias' in name:
                nn.init.constant(param, 0.01)
            elif 'weight' in name:
                torch.nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain('tanh'))

        
    def forward(self, x, hidden, cell_state):
        if hidden is None and cell_state is None:
            final, comp = self.lstm(x)
        else:
            final, comp = self.lstm(x, (hidden, cell_state))
        hid, cell = comp
        return final, hid, cell
    
    def initial_hidden_cell(self):
        return torch.zeros(8, 1, 256), torch.zeros(8, 1, 256)
    
class LSTM_Decoder(nn.Module):
    def __init__(self, h_dim, e_dim, word_list_length, max_length = 157):
        super(LSTM_Decoder, self).__init__()
        self.emb = nn.Embedding(word_list_length, e_dim)
        
        self.attention = nn.Linear(h_dim*2, max_length)
        torch.nn.init.xavier_uniform_(self.attention.weight, gain=nn.init.calculate_gain('linear'))
        self.attention.bias.data.fill_(0.01)
        
        self.attention_combined = nn.Linear(h_dim * 3, h_dim)
        torch.nn.init.xavier_uniform_(self.attention_combined.weight, gain=nn.init.calculate_gain('linear'))
        self.attention_combined.bias.data.fill_(0.01)
        
        self.lstm = nn.LSTM(e_dim, h_dim)
        for name, param in self.lstm.named_parameters():
            if 'bias' in name:
                nn.init.constant(param, 0.01)
            elif 'weight' in name:
                torch.nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain('tanh'))

        self.out = nn.Linear(h_dim, word_list_length)
        torch.nn.init.xavier_uniform_(self.out.weight, gain=nn.init.calculate_gain('linear'))
        self.out.bias.data.fill_(0.01)
        
    def forward(self, x, hidden, cell_state, encoder_outputs):
        seq_embedded = self.emb(x).view(1, 1, -1)
        
        attention_weights = F.softmax(self.attention(torch.cat((seq_embedded[0], hidden[0]), 1)), 1)
        attention_applied = torch.bmm(attention_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))
        
        output = F.relu(self.attention_combined(torch.cat((seq_embedded[0], attention_applied[0]), 1)))
        
        final, states  = self.lstm(output.unsqueeze(0), (hidden, cell_state))
        hidden, cell_state = states
        dec_seq = self.out(final)
        return F.log_softmax(dec_seq[0], dim = -1), hidden, cell_state
    
    def initial_hidden_cell(self):
        return torch.zeros(1, 1, 256), torch.zeros(1, 1, 256)

def train(x, y, embedder, encoder, decoder, emb_optimizer, enc_optimizer, dec_optimizer, teacher_ratio = 0.5):
    hidden_enc, cell_enc = encoder.initial_hidden_cell()
    enc_outputs = torch.zeros(157, hidden_layers * 2)
    
    loss_fn = nn.NLLLoss()
    loss = 0
    
    emb_optimizer.zero_grad()
    enc_optimizer.zero_grad()
    dec_optimizer.zero_grad()
    
    emb_x = embedder(x.unsqueeze(0)).permute(2, 0, 1)
    
    for i in range(len(emb_x)):
        seq = emb_x[i].unsqueeze(0)
        enc_out, hidden_enc, cell_enc = encoder(seq, hidden_enc, cell_enc)
        enc_outputs[i] = enc_out[0, 0]
    
    hidden_dec, cell_dec = decoder.initial_hidden_cell()
    target_lab = torch.tensor(y, dtype = torch.long)
    decoder_input = torch.tensor([[start_token]], dtype = torch.long)
    teacher_forcing = True if torch.rand(1) <= teacher_ratio else False
    fin_len = 0
    if teacher_forcing:
        for j in range(len(target_lab)):
            logit, hidden_dec, cell_dec = decoder(decoder_input, hidden_dec, cell_dec, enc_outputs)
            current_targ = target_lab[j].unsqueeze(0)
            loss = loss_fn(logit, current_targ) + loss
            decoder_input = current_targ
            fin_len = j + 1
            if decoder_input == end_token:
                break
    else:
        for j in range(len(target_lab)):
            logit, hidden_dec, cell_dec = decoder(decoder_input, hidden_dec, cell_dec, enc_outputs)
            _, val = logit.topk(1)
            current_targ = target_lab[j].unsqueeze(0)
            loss = loss_fn(logit, current_targ) + loss
            decoder_input = val.squeeze(0).detach()
            fin_len = j + 1
            if decoder_input == end_token:
                print("trained without teacher enforcing")
                break
    
    loss.backward()
        
    emb_optimizer.step()
    enc_optimizer.step()
    dec_optimizer.step()
    
    return (loss.item() / fin_len)

In [200]:
def append_end_token(i):
    i.append(end_token)
    return i

# define tokenizer
#tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
#tokenizer.pad_token = tokenizer.eos_token
labels = list(full_y)
for i, sentence in enumerate(labels):
    labels[i] = sentence.split("; ")

tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
trainer = WordLevelTrainer(special_tokens=["[PAD]", "[UNK]"])
tokenizer.train_from_iterator(np.array(tokens), trainer)
token_y = [tokenizer.encode(label, is_pretokenized=True).ids for label in labels]
for i in token_y:
    i.append(0)

#token_y = [torch.cat((i, torch.tensor([0]))) for i in token_y]
encoder = LSTM_Encoder(hidden_layers, embedding_dim)
decoder = LSTM_Decoder(hidden_layers, embedding_dim, word_list_length)

emb_optimizer = torch.optim.Adam(conv_model.parameters(), lr = 1e-3)
enc_optimizer = torch.optim.Adam(encoder.parameters(), lr = 1e-3)
dec_optimizer = torch.optim.Adam(decoder.parameters(), lr = 1e-3)

  del sys.path[0]


In [24]:
conv_model.load_state_dict(torch.load('model/lstm_embedder.pt'))
encoder.load_state_dict(torch.load('model/lstm_encoder.pt'))
decoder.load_state_dict(torch.load("model/lstm_decoder.pt"))

<All keys matched successfully>

In [None]:
writer = SummaryWriter('runs/lstm_enc_dec_part_3')
for epoch in range(30):
    tot = 0.0
    for j, k in zip(full_x, token_y):
        loss = train(j, k, conv_model, encoder, decoder, emb_optimizer, enc_optimizer, dec_optimizer, teacher_ratio = 0.8)
        tot = tot + loss
    avg_loss = tot / len(token_y)
    
    print(avg_loss)
    info_dict = {'train_loss': avg_loss } #, 'train_acc': train_accuracy, 'val_loss': avg_val_loss,'val_acc': val_accuracy}
           
    for tag, value in conv_model.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram(tag, value.data.cpu().numpy(), epoch)
        if value.grad is None:
            writer.add_histogram(tag+ '_emb' + '/grad', 0, epoch)           
        else:
            writer.add_histogram(tag+ '_emb' + '/grad', value.grad.data.cpu().numpy(), epoch)           
    
    for tag, value in encoder.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram(tag, value.data.cpu().numpy(), epoch)
        if value.grad is None:
            writer.add_histogram(tag+ '_enc' + '/grad', 0, epoch)           
        else:
            writer.add_histogram(tag+ '_enc' + '/grad', value.grad.data.cpu().numpy(), epoch)
    
    
    for tag, value in decoder.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram(tag, value.data.cpu().numpy(), epoch)
        if value.grad is None:
            writer.add_histogram(tag+'_dec' + '/grad', 0, epoch)           
        else:
            writer.add_histogram(tag+'_dec' + '/grad', value.grad.data.cpu().numpy(), epoch)
    

    for tag, value in info_dict.items():
        writer.add_scalar(tag, value, epoch)

trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
2.9374151184227006
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing
2.186834273994042
trained without teacher enforcing
trained without teacher enforcing
trained without teacher enforcing


In [28]:
torch.save(conv_model.state_dict(), 'model/lstm_embedder.pt')
torch.save(encoder.state_dict(), 'model/lstm_encoder.pt')
torch.save(decoder.state_dict(), 'model/lstm_decoder.pt')

In [None]:
for x, y in zip(full_x, token_y):
    #print(x)
    print("ground truth: ", " ".join([tokenizer.id_to_token(i) for i in y]))
    emb_x = conv_model(x.unsqueeze(0)).permute(2, 0, 1)
    hidden_enc, cell_enc = encoder.initial_hidden_cell()
    enc_outputs = torch.zeros(79, hidden_layers * 2)
    
    for i in range(len(emb_x)):
        seq = emb_x[i].unsqueeze(0)
        enc_out, hidden_enc, cell_enc = encoder(seq, hidden_enc, cell_enc)
        enc_outputs[i] = enc_out[0, 0]       
    
    hidden_dec, cell_dec = decoder.initial_hidden_cell()
    decoder_input = torch.tensor([[start_token]], dtype = torch.long)
    for i in range(len(y)):
        logit, hidden_dec, cell_dec = decoder(decoder_input, hidden_dec, cell_dec, enc_outputs)
        _, val = logit.topk(1)
        decoder_input = val.squeeze(0).detach()
        print("predicted: ", tokenizer.id_to_token(decoder_input))
        if decoder_input == end_token:
            break
    print("\n\n\n\n\n")

## Encoder 3 - Multi-Head Attention Transformer Encoder

In [None]:
# Work in progress, will clean later


from torch.nn import TransformerEncoder, TransformerEncoderLayer

class ECGTransformerEncoder(nn.Module):
    # Takes the ECG discrete signals sequence and maps into a probability distribution of diagnosis
    # For working/verification purposes
    def __init__(self, vector_size, embed_dim, n_heads, hidden_linear_dim, n_layers, dropout):
        super(ECGTransformerEncoder, self).__init__()
        self.model_type = "Transformer"
        self.positional_encoder = PositionalEncoder(embed_dim, dropout)
    
        #Since our data is already discrete numbers, might need some tweaking for this
        self.embedder = conv_embedder
                        #64 31              #39        64
        
        
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(embed_dim, n_heads, hidden_linear_dim, dropout),
            n_layers)
        
        self.n_inputs = embed_dim
        self.n_layers = n_layers
        
        # Simple linear decoder
        self.decoder = nn.Sequential(
                        nn.Linear(768, 17),
                        Transpose(17, 2500),
                        nn.Linear(2500, 30),
                        nn.LogSoftmax()
                        )
        self.init_weights()
        
    def init_weights(self):
        #self.embedder.weight.data.uniform_(-.1, .1)
        #self.decoder.bias.data.zero_()
        #self.decoder.weight.data.uniform_(-.1, .1)
        pass
        
    def forward(self, x):
        #x = self.embedder(x) # * math.sqrt(self.n_inputs)
        x = x.squeeze(0)
        #x = x.view(2500, 8)
        x = x.unsqueeze(1)
        x = self.positional_encoder(x)
        x = self.encoder(x)
        x = x.squeeze(1) 
        #x = self.decoder(x)
        return x


## Encoder 4 - FNET Transformer Architecture

In [None]:
class FeedForwardNet(nn.Module):
    def __init__(self, features, expansion, dropout):
        super(FeedForwardNet, self).__init__()
        self.linear_1 = nn.Linear(features, features * expansion)
        self.linear_2 = nn.Linear(features * expansion, features)
        self.dropout_1 = nn.Dropout(dropout)
        #self.dropout_2 = nn.Dropout(dropout)
        self.norm_1 = nn.LayerNorm(features)

    def forward(self, x):
        res = x
        x = F.relu(self.linear_1(x))
        x = self.dropout_1(x)
        x = self.linear_2(x)
        x = self.norm_1(x + res)
        return x

    
class FNETLayer(nn.Module):
    def __init__(self, features, expansion, dropout):
        super(FNETLayer, self).__init__()
        self.feed_forward = FeedForwardNet(features, expansion, dropout)
        self.norm_1 = nn.LayerNorm(features)
        
        # additions to add
        self.attention_layer = nn.TransformerEncoderLayer(256, 16, 512, 0.1)
    
    def forward(self, x):
        res = x
        x = fft.fftn(x, dim = (-2, -1)).real
        x = self.norm_1(x + res)
        x = self.feed_forward(x)
        x = self.attention_layer(x)
        return x
    
class FNETEncoder(nn.TransformerEncoder):
    def __init__(self, features, expansion=2, dropout=0.5, num_layers=6):
        encoder_layer = FNETLayer(features, expansion, dropout)
        super().__init__(encoder_layer=encoder_layer, num_layers=num_layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
 

## Misc Encoder Helper Functions/Components

In [174]:
class PositionalEncoder(nn.Module):
    # Necessary to store positional data about the input data
    def __init__(self, embed_dim, dropout=0.1, max_len=2500):
        super(PositionalEncoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pos_encoding = torch.zeros(max_len, 1, embed_dim)
        position = torch.arange(max_len).unsqueeze(1)
        
        divisor = torch.exp(torch.arange(0, embed_dim, 2).float() * (- math.log(10000.0) / embed_dim))
        
        pos_encoding[:, 0, 0::2] = torch.sin(position * divisor)
        pos_encoding[:, 0, 1::2] = torch.cos(position * divisor)
        self.register_buffer("pos_encoding", pos_encoding)

        
    def forward(self, x):
        pos_encoding = self.pos_encoding.repeat(1, x.shape[1], 1)
        x = x + pos_encoding[:x.size(0), :]
        return self.dropout(x)
   

class Transpose(nn.Module):
    def __init__(self, *args):
        super(Transpose, self).__init__()
        self.shape = args

    def forward(self, x):
        # If the number of the last batch sample in the data set is smaller than the defined batch_batch size, mismatch problems will occur. You can modify it yourself, for example, just pass in the shape behind, and then enter it through x.szie(0).
        return x.view(self.shape)

    
    

class WindowEmbedder(nn.Module):
    # Necessary to convert the signal into "word" vectors for transformer processing.
    # Currently a simple group and slice method, but will modify later for multi-channel inputs
    
    def __init__(self, num_slices, size_of_slice):
        super(SignalEmbedder, self).__init__()
        self.num_slices = num_slices
        self.size_of_slice = size_of_slice
        
    def forward(self, x):
        x = x[: self.num_slices * self.size_of_slice]
        x = x.reshape((self.num_slices, self.size_of_slice))
        return x  


## Decoder 1 - Huggingface GPT2 Decoder

In [51]:
# Replace with child class of GPT2LMHeadModel

class GPT2LMHeadModel(GPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        #self.transformer.forward = forward2.__get__(self.transformer, GPT2Model)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.transformer.h))
        self.transformer.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.transformer.first_device)
        self.model_parallel = True
        
    def deparallelize(self):
        self.transformer.deparallelize()
        self.transformer = self.transformer.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        torch.cuda.empty_cache()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "encoder_hidden_states": kwargs.get("encoder_hidden_states", None), # The one line changed hehe
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
    
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
    
    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )

In [53]:
# define tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# preprocess training labels and tokenize
train_labels = list(waveform_lead_rhythm_diag['diagnosis'])
inputs = tokenizer(train_labels, padding = True, verbose = False, return_tensors="pt")

#Necessary to add for generating first word
inputs["input_ids"] = torch.cat((torch.tensor([[50256] for i in range(len(inputs["input_ids"]))]), inputs["input_ids"]), dim=1)
inputs["attention_mask"] = torch.cat((torch.tensor([[1] for i in range(len(inputs["attention_mask"]))]), inputs["attention_mask"]), dim=1)

inputs = inputs.to(device)
inputs

{'input_ids': tensor([[50256, 31369,   385,  ..., 50256, 50256, 50256],
        [50256, 11265,  7813,  ..., 50256, 50256, 50256],
        [50256, 11265,  7813,  ..., 50256, 50256, 50256],
        ...,
        [50256, 11265,  7813,  ..., 50256, 50256, 50256],
        [50256,  9509,  4565,  ..., 50256, 50256, 50256],
        [50256, 31369,   385,  ..., 50256, 50256, 50256]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

## EncoderDecoder - FNET Encoder Huggingface Decoder

In [178]:
# create encoder decoder model with GPT2 
class CustEncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, embedder, tokenizer):
        super(CustEncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pos_enb = PositionalEncoder(embed_dim = 256)
        self.embedder = embedder
        self.tokenizer = tokenizer
    
    def forward(self, x):
        ecgs, labels = x
        x = self.embedder(ecgs).permute(2, 0, 1)
        x = self.pos_enb(x).permute(1, 0, 2)
        x = self.encoder(x)
        out = self.decoder(**labels, labels = labels["input_ids"], encoder_hidden_states = x.contiguous())
        return out
    
    # Should only take 1 input at a time
    def predict_single(self, x):
        ecgs = x
        x = self.embedder(ecgs).permute(2, 0, 1)
        print(x.shape)
        x = self.pos_enb(x).permute(1, 0, 2)
        print(x.shape)
        x = self.encoder(x)
        print(x.shape)
        return self.tokenizer.decode(self.decoder.generate(encoder_hidden_states = x.contiguous())[0])

    
    # Takes in multiple inputs
    def predict_batch(self, x):
        ecgs = x
        x = self.embedder(ecgs).permute(2, 0, 1)
        x = self.pos_enb(x).permute(1, 0, 2)
        x = self.encoder(x)
        output = []
        for ecg in x:
            output.append(self.tokenizer.decode(self.decoder.generate(encoder_hidden_states = ecg.unsqueeze(0).contiguous())[0]))
        return output
    
    def return_enc(self):
        return self.encoder

    
# Connect an embedder and de-embedder for training (we will then isolate the Encoder portion of this autoencoder as our embedder)
class ConvAutoEncoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConvAutoEncoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, x):
        return self.decoder(self.encoder(x))
    
    def make_encoder(self):
        return self.encoder
    
    def make_decoder(self):
        return self.decoder
    

In [55]:
# The following code is for the pre-embedder
    
# Make embedder
conv_model = nn.Sequential()
init_channels = 8
for i in range(2):
    next_channels = 2 * init_channels
    conv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = KER_SIZE, padding = PADDING, stride = 1))
    conv_model.add_module('act_{num}'.format(num = i), nn.ReLU())
    conv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    conv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = KER_SIZE, padding = PADDING))
    conv_model.add_module('act_res_{num}'.format(num = i), nn.ReLU())
    init_channels = next_channels
conv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 768, kernel_size = KER_SIZE, padding = PADDING))
conv_model.add_module('act_fin', nn.ReLU())
conv_model.add_module('batch_fin', nn.BatchNorm1d(768))


# Make de-embedder
deconv_model = nn.Sequential()
init_channels = 768
for i in range(2):
    next_channels = init_channels // 2
    deconv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = KER_SIZE, padding = PADDING, stride = 1))
    deconv_model.add_module('act_{num}'.format(num = i), nn.ReLU())
    deconv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    deconv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = KER_SIZE, padding = PADDING))
    deconv_model.add_module('act_res_{num}'.format(num = i), nn.ReLU())
    init_channels = next_channels
deconv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 8, kernel_size = KER_SIZE, padding = PADDING))
deconv_model.add_module('act_fin', nn.ReLU())
deconv_model.add_module('batch_fin', nn.BatchNorm1d(8))

auto_model = ConvAutoEncoder(conv_model, deconv_model).to(device)
auto_optimizer = torch.optim.Adam(auto_model.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)


# Training params
loss_function = nn.MSELoss()

epochs = 0

for i in range(epochs):
    auto_optimizer.zero_grad()
    outputs = auto_model(ecg_data)
    loss = loss_function(outputs, ecg_data)
    loss.backward(retain_graph=True)
    auto_optimizer.step()
    print(loss)
        
# Saving/loading weights
torch.save(auto_model.state_dict(), 'model/autoencoder.pt')
auto_model.load_state_dict(torch.load('model/autoencoder.pt'))
conv_embedder = auto_model.make_encoder()
torch.save(conv_embedder.state_dict(), "model/embedder.pt")

In [56]:
# Define encoder, we don't need to pretrain rn
encoder = FNETEncoder(768, expansion = 2, dropout=0.1, num_layers = 6)

In [179]:
# define and pretrain Decoder
config = GPT2Config(vocab_size = 99, n_embd = 256, n_head = 16, add_cross_attention = True, is_encoder_decoder = False, bos_token_id=0, eos_token_id= 0)
print(config)
decoder = GPT2LMHeadModel(config = config)
# pretrain decoder
optimizer = torch.optim.Adam(decoder.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)

inputs = transformer_tokenizer(full_y)


#decoder.load_state_dict(torch.load('model/gpt2.pt'))
# set number of epochs
epochs = 150
for i in range(epochs):
    optimizer.zero_grad()
    outputs = decoder(**inputs, labels = inputs["input_ids"])
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    
    print(loss)
    
    for tag, value in decoder.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram(tag+ "_gpt2", value.data.cpu().numpy(), epoch)
        if value.grad is None:
            writer.add_histogram(tag+ '_gpt2' + '/grad', 0, epoch)           
        else:
            writer.add_histogram(tag+ '_gpt2' + '/grad', value.grad.data.cpu().numpy(), epoch)
    
    
    
torch.save(decoder.state_dict(), 'model/gpt2.pt')

GPT2Config {
  "activation_function": "gelu_new",
  "add_cross_attention": true,
  "attn_pdrop": 0.1,
  "bos_token_id": 0,
  "embd_pdrop": 0.1,
  "eos_token_id": 0,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 256,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "transformers_version": "4.9.1",
  "use_cache": true,
  "vocab_size": 99
}



KeyboardInterrupt: 

In [None]:
# define component models

conv_embedder = conv_model #auto_model.make_encoder()

encoder = FNETEncoder(256, expansion = 2, dropout=0.1, num_layers = 6)

#decoder.load_state_dict(torch.load('model/gpt2.pt'))

enc_dec_model = CustEncoderDecoder(encoder, decoder, conv_embedder, transformer_tokenizer)

enc_dec_model.predict_single(ecg_data[0].unsqueeze(0))

In [183]:
print(inputs)

{'input_ids': tensor([[ 0, 94, 92,  ...,  0,  0,  0],
        [ 0, 32,  8,  ...,  0,  0,  0],
        [ 0, 32,  8,  ...,  0,  0,  0],
        ...,
        [ 0, 32,  8,  ...,  0,  0,  0],
        [ 0, 81, 50,  ...,  0,  0,  0],
        [ 0, 60, 88,  ...,  0,  0,  0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}


In [None]:
# train encoder decoder model!
optimizer = torch.optim.Adam(enc_dec_model.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)

# set number of epochs
epochs = 130

for i in range(epochs):
    optimizer.zero_grad()
    losses = 0
    j = 0
    for ecg in full_x:
        outputs = enc_dec_model((ecg.unsqueeze(0), {"input_ids": inputs["input_ids"][j], "attention_mask": inputs["attention_mask"][j]}))
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        losses += loss
        j += 1
    if i == 40:
        optimizer = torch.optim.Adam(enc_dec_model.parameters(), lr = 1e-4)
    if i == 60:
        optimizer = torch.optim.Adam(enc_dec_model.parameters(), lr = 1e-5)
    if i == 80:
        optimizer = torch.optim.Adam(enc_dec_model.parameters(), lr = 1e-6)
    if i == 105:
        optimizer = torch.optim.Adam(enc_dec_model.parameters(), lr = 1e-7)
    print(losses)
    for tag, value in decoder.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram(tag+ "_mixup2", value.data.cpu().numpy(), i)
        if value.grad is None:
            writer.add_histogram(tag+ '_mixup2' + '/grad', 0, i)           
        else:
            writer.add_histogram(tag+ '_mixup2' + '/grad', value.grad.data.cpu().numpy(), i)
 

   
torch.save(enc_dec_model.state_dict(), 'model/gpt2_enc_dec2.pt')

In [None]:
enc_dec_model.load_state_dict(torch.load('model/gpt2_enc_dec.pt'))

In [85]:
torch.save(model.state_dict(), 'model/gpt2.pt')


In [195]:
for inp, out in zip(ecg_data, inputs["input_ids"]):
    print("ground truth: ", transformer_tokenizer.decode(out))
    print("predicted label: ", enc_dec_model.predict_single(inp.unsqueeze(0)))
    print("\n\n\n\n")

ground truth:  sinus bradycardia otherwise normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block cannot rule out right bundle branch block abnormal ecg abnormal ecg





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg abnormal ecg





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg





ground truth:  sinus rhythm with marked sinus arrhythmia otherwise normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  sinus tachycardia inferior infarct abnormal ecg





ground truth:  normal sinus rhythm with sinus arrhythmia cannot rule out inferior infarct abnormal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  cannot rule out





ground truth:  prolonged qt abnormal ecg t wave inversion now evident in anterior leads qt has lengthened
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  abnormal ecg





ground truth:  sinus rhythm with marked sinus arrhythmia otherwise normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block present abnormal ecg





ground truth:  normal sinus rhythm possible left atrial enlargement incomplete left bundle branch block borderline ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg





ground truth:  normal sinus rhythm normal ecg incomplete left bundle branch block is no longer present
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block right bundle branch block abnormal ecg abnormal ecg abnormal ecg abnormal ecg abnormal ecg





ground truth:  sinus tachycardia left axis deviation pulmonary disease pattern inferior infarct abnormal ecg sinus rhythm has replaced atrial fibrillation inverted t waves have replaced nonspecific t wave abnormality in lateral leads
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  cannot rule out





ground truth:  atrial fibrillation with rapid ventricular response with premature ventricular or aberrantly conducted complexes right bundle branch block left posterior fascicular block abnormal ecg atrial fibrillation has replaced sinus rhythm qrs duration has increased
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  





ground truth:  undetermined rhythm right bundle branch block abnormal ecg current undetermined rhythm precludes rhythm comparison, needs review qrs duration has increased qt has lengthened
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block cannot rule out right bundle branch block present abnormal ecg





ground truth:  undetermined rhythm possible right ventricular hypertrophy nonspecific t wave abnormality abnormal ecg current undetermined rhythm precludes rhythm comparison, needs review is no longer present
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  sinus tachycardia right bundle branch block abnormal ecg abnormal ecg premature supraventricular complexes





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  normal ecg





ground truth:  electronic atrial pacemaker indeterminate axis pulmonary disease pattern st elevation consider anterolateral injury or acute infarct st elevation consider inferior injury or acute infarct abnormal ecg electronic atrial pacemaker has replaced electronic ventricular pacemaker
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  cannot rule out abnormal ecg





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg abnormal ecg abnormal ecg abnormal ecg present sinus rhythm atrial fibrillation sinus rhythm lateral leads lateral leads





ground truth:  atrial fibrillation left axis deviation pulmonary disease pattern septal infarct abnormal ecg atrial fibrillation has replaced sinus rhythm nonspecific t wave abnormality has replaced inverted t waves in lateral leads
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block cannot rule out





ground truth:  sinus tachycardia and fusion complexes right bundle branch block cannot rule out inferior infarct t wave abnormality, consider lateral ischemia abnormal ecg previous ecg has undetermined rhythm, needs review right bundle branch block is now present
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  left axis deviation





ground truth:  sinus tachycardia right bundle branch block t wave abnormality, consider lateral ischemia abnormal ecg premature supraventricular complexes are now present
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  cannot rule out right bundle branch block abnormal ecg





ground truth:  av sequential or dual chamber electronic pacemaker electronic ventricular pacemaker has replaced sinus rhythm vent. rate has decreased
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  sinus tachycardia right bundle branch block abnormal ecg





ground truth:  av sequential or dual chamber electronic pacemaker
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  left axis deviation pulmonary disease pattern t wave inversion now evident in qrs duration abnormal ecg atrial fibrillation has increased





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  cannot rule out





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  cannot rule out right bundle branch block premature supraventricular complexes





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg





ground truth:  normal sinus rhythm cannot rule out anterior infarct abnormal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  left axis deviation





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  sinus tachycardia right bundle branch block abnormal ecg





ground truth:  normal sinus rhythm possible left atrial enlargement incomplete right bundle branch block borderline ecg incomplete right bundle branch block is now present
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg abnormal ecg





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg t wave inversion now evident in anterior leads abnormal ecg





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  sinus tachycardia right bundle branch block abnormal ecg abnormal ecg is no longer present





ground truth:  sinus bradycardia
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg





ground truth:  sinus rhythm with marked sinus arrhythmia otherwise normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  left axis deviation





ground truth:  normal sinus rhythm minimal voltage criteria for lvh, may be normal variant borderline ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg premature supraventricular complexes





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block right bundle branch block abnormal ecg





ground truth:  sinus bradycardia otherwise normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block left axis deviation





ground truth:  sinus tachycardia cannot rule out anterior infarct abnormal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block cannot rule out right bundle branch block abnormal ecg





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  cannot rule out





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block cannot rule out





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  left axis deviation





ground truth:  normal sinus rhythm inferior infarct abnormal ecg t wave amplitude has decreased in lateral leads
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg abnormal ecg





ground truth:  electronic ventricular pacemaker previous ecg has undetermined rhythm, needs review
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  normal ecg pulmonary disease pattern is no longer present qrs duration





ground truth:  normal sinus rhythm inferior infarct abnormal ecg inverted t waves have replaced nonspecific t wave abnormality in inferior leads
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  normal ecg





ground truth:  normal sinus rhythm possible anterior infarct abnormal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block cannot rule out right bundle branch block abnormal ecg abnormal ecg





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  left axis deviation





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg





ground truth:  normal sinus rhythm inferior infarct prolonged qt abnormal ecg t wave inversion no longer evident in lateral leads
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  abnormal ecg





ground truth:  normal sinus rhythm are no longer present
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block abnormal ecg





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  left axis deviation pulmonary disease pattern





ground truth:  normal sinus rhythm inferior infarct prolonged qt abnormal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  cannot rule out





ground truth:  normal sinus rhythm st abnormality, possible digitalis effect abnormal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  cannot rule out right bundle branch block abnormal ecg abnormal ecg





ground truth:  normal sinus rhythm normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  cannot rule out





ground truth:  electronic ventricular pacemaker vent. rate has decreased
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  right bundle branch block right bundle branch block abnormal ecg abnormal ecg





ground truth:  sinus rhythm with marked sinus arrhythmia otherwise normal ecg
torch.Size([157, 1, 256])
torch.Size([1, 157, 256])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


torch.Size([1, 157, 256])
predicted label:  abnormal ecg abnormal ecg





