In [68]:
# import all packages needed
import string 
import numpy as np
import pandas as pd
from matplotlib import pyplot
from base64 import b64decode as decode
import math
import torch
import torch.nn as nn 
from torch.nn import CrossEntropyLoss, MSELoss
import torch.fft as fft
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2Config, GPT2PreTrainedModel, GPT2Model
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from typing import Optional, Tuple
import transformers

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# The only time we need to define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Data Processing / Cleaning

In [7]:
# use class base64 to decode waveform data
def to_array(wf):
    barr = bytearray(decode(wf))
    vals = np.array(barr)
    return vals.view(np.int16).astype(np.float32)

# read in data
exam_data = pd.read_csv("data/d_exam.csv").drop(columns = ["site_num", "patient_id_edit"])
waveform_data = pd.read_csv("data/d_waveform.csv")
lead_data = pd.read_csv("data/d_lead_data.csv").drop(columns = ["exam_id"])
diagnosis_data = pd.read_csv("data/d_diagnosis.csv").drop(columns = ["user_input"])

# add decoded data as a column to lead dataz
waveforms = list(lead_data['waveform_data'])
lead_data['decoded_waveform'] = [to_array(i) for i in waveforms]

# merge waveform data and lead data
waveform_lead = lead_data.merge(waveform_data, how = "left", left_on = "waveform_id", right_on = "waveform_id", suffixes = (None, None))

#  sort by exam id and lead id
waveform_lead.sort_values(by = ["waveform_id", "lead_id"], inplace = True)

waveform_lead.loc[:, ['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']]


# adding the diagnosis and labels
waveform_and_diag = pd.merge(waveform_lead[['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']], diagnosis_data[["exam_id", "Full_text", "Original_Diag"]], left_on= "exam_id", right_on="exam_id")


# concatenate all leads into a single array
waveform_lead_concat = waveform_lead.groupby(["exam_id", "waveform_type"])['decoded_waveform'].apply(lambda x: tuple(x)).reset_index()

# remove irregular observations, concat tuple into numpy array
waveform_lead_concat = waveform_lead_concat.drop([12,17], axis = 0)
waveform_lead_concat['decoded_waveform'] = waveform_lead_concat['decoded_waveform'].apply(lambda x: np.vstack(x))
waveform_lead_rhythm = waveform_lead_concat[waveform_lead_concat['waveform_type'] == "Rhythm"]

for value in waveform_lead_rhythm["decoded_waveform"]:
    value /= 1024
    value += .5

exams = diagnosis_data["exam_id"].unique()

diagnosis_data = diagnosis_data[diagnosis_data['Original_Diag'] == 1].dropna()
searchfor = ['previous', 'unconfirmed', 'compared', 'interpretation', 'significant']
diagnosis_data = diagnosis_data.loc[diagnosis_data['Full_text'].str.contains('|'.join(searchfor)) != 1]

diagnosis_data.sort_values(by=["exam_id", "statement_order"], inplace=True)
diagnoses = []
curr_id = 0
curr_string = ""
for i, row in diagnosis_data.iterrows():
    if row["statement_order"] == 1 and curr_string != "":
        curr_string = curr_string.lower().translate(str.maketrans('', '', string.punctuation))
        val = [curr_id, curr_string[1:]]
        diagnoses.append(val)
        curr_string = ""
        curr_id = row["exam_id"]

    if curr_id == 0:
        curr_id = row["exam_id"]
    
    curr_string += " " + row["Full_text"]

diagnosis_df = pd.DataFrame(diagnoses, columns = ['exam_id', 'diagnosis'])
waveform_lead_rhythm_diag = pd.merge(left=waveform_lead_rhythm, right=diagnosis_df, left_on='exam_id', right_on='exam_id')

waveform_lead_rhythm_diag

Unnamed: 0,exam_id,waveform_type,decoded_waveform,diagnosis
0,548759,Rhythm,"[[0.50390625, 0.5029297, 0.5019531, 0.49902344...",normal sinus rhythm low voltage qrs borderline...
1,549871,Rhythm,"[[0.4921875, 0.4921875, 0.4921875, 0.4921875, ...",sinus bradycardia otherwise normal ecg
2,550602,Rhythm,"[[0.47851562, 0.48046875, 0.48242188, 0.484375...",sinus tachycardia otherwise normal ecg
3,551485,Rhythm,"[[0.5449219, 0.5439453, 0.54296875, 0.5410156,...",normal sinus rhythm normal ecg
4,552077,Rhythm,"[[0.49316406, 0.49609375, 0.49902344, 0.494140...",normal sinus rhythm normal ecg
5,552856,Rhythm,"[[0.46875, 0.46875, 0.46875, 0.46777344, 0.466...",normal sinus rhythm with sinus arrhythmia mini...
6,553115,Rhythm,"[[0.4921875, 0.4951172, 0.49804688, 0.49804688...",atrial fibrillation abnormal ecg normal sinus ...


In [35]:
# define ecg_data, the dataset we will be using for training purposes
ecg_data = torch.tensor(list(waveform_lead_rhythm_diag['decoded_waveform'])).float().to(device)

## Embedder: Conv1D

In [None]:
# This is where we define components to be used for both the Conv1D encoder, and a Conv1D pre-embedder into a Transformer Encoder.

LR = 1e-3
KER_SIZE = 11
PADDING = 5

# define global max pooling
class global_max_pooling_1d(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x, _ = torch.max(x, dim = 2)
        return(x)

# define resblock for neural nets
class ResBlock1D(nn.Module):
    def __init__(self, num_filters, kernel_size, padding, groups = 1, stride = 1):
        super(ResBlock1D, self).__init__()
        self.act = nn.ReLU()
        self.conv1d_1 = nn.Conv1d(num_filters, num_filters, kernel_size = kernel_size, padding = padding, groups = groups, stride = 1)
        self.conv1d_2 = nn.Conv1d(num_filters, num_filters, kernel_size = kernel_size, padding = padding, groups = groups, stride = 1)
        self.batch_norm_1 = nn.BatchNorm1d(num_filters)
        self.batch_norm_2 = nn.BatchNorm1d(num_filters)

    def forward(self, x):
        res = x
        x = self.batch_norm_1(self.act(self.conv1d_1(x)))
        x = self.batch_norm_2(self.act(self.conv1d_2(x)))
        return x + res

  

tensor(1.3016, grad_fn=<MseLossBackward>)
tensor(1.2688, grad_fn=<MseLossBackward>)
tensor(1.1736, grad_fn=<MseLossBackward>)
tensor(1.1416, grad_fn=<MseLossBackward>)
tensor(1.0946, grad_fn=<MseLossBackward>)
tensor(1.0625, grad_fn=<MseLossBackward>)
tensor(1.0360, grad_fn=<MseLossBackward>)
tensor(0.9886, grad_fn=<MseLossBackward>)
tensor(0.9671, grad_fn=<MseLossBackward>)
tensor(0.9323, grad_fn=<MseLossBackward>)
tensor(0.9133, grad_fn=<MseLossBackward>)
tensor(0.8988, grad_fn=<MseLossBackward>)
tensor(0.8857, grad_fn=<MseLossBackward>)
tensor(0.8753, grad_fn=<MseLossBackward>)
tensor(0.8668, grad_fn=<MseLossBackward>)
tensor(0.8571, grad_fn=<MseLossBackward>)
tensor(0.8494, grad_fn=<MseLossBackward>)
tensor(0.8425, grad_fn=<MseLossBackward>)
tensor(0.8355, grad_fn=<MseLossBackward>)
tensor(0.8292, grad_fn=<MseLossBackward>)
tensor(0.8223, grad_fn=<MseLossBackward>)
tensor(0.8157, grad_fn=<MseLossBackward>)
tensor(0.8091, grad_fn=<MseLossBackward>)
tensor(0.8028, grad_fn=<MseLossBac

## Encoder 1: ResNet Encoder

In [13]:
# HYPERPARAMETERS
J = 10 # max number of filters per class
LR = 1e-3

# build resent model and display the shape of feed through
conv_model = nn.Sequential()
init_channels = 8
for i in range(5):
    next_channels = 2 * init_channels
    conv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = 249, padding = 124, stride = 1))
    conv_model.add_module('act_{num}'.format(num = i), nn.ReLU())
    conv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    conv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = 249, padding = 124))
    conv_model.add_module('act_res_{num}'.format(num = i), nn.ReLU())
    init_channels = next_channels
    
conv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 8, kernel_size = 249, padding = 124))
conv_model.add_module('act_fin', nn.ReLU())
conv_model.add_module('batch_fin', nn.BatchNorm1d(8))

## LSTM Decoder

In [17]:
# define hyperparameters 
hidden_layers = 250
embedding_dim = 8
num_words = len(dict_words)

class LSTM_EncoderDecoder(nn.Module):
    def __init__(self, h_dim, e_dim, word_list_length):
        super(ECG_LSTM, self).__init__()
        self.lstm = nn.LSTM(e_dim, h_dim, num_layers = 4, bidirectional = True)
        
    def forward(self, seq):
        seq_embedded = seq.view(len(seq), -1, embedding_dim)
        final_hidd, _ = self.lstm(seq_embedded)
        dec_seq = self.linear(final_hidd)
        return F.log_softmax(dec_seq, dim = 1)
    

## Encoder 3 - Multi-Head Attention Transformer Encoder

In [None]:
# Work in progress, will clean later


from torch.nn import TransformerEncoder, TransformerEncoderLayer

class ECGTransformerEncoder(nn.Module):
    # Takes the ECG discrete signals sequence and maps into a probability distribution of diagnosis
    # For working/verification purposes
    def __init__(self, vector_size, embed_dim, n_heads, hidden_linear_dim, n_layers, dropout):
        super(ECGTransformerEncoder, self).__init__()
        self.model_type = "Transformer"
        self.positional_encoder = PositionalEncoder(embed_dim, dropout)
    
        #Since our data is already discrete numbers, might need some tweaking for this
        self.embedder = conv_embedder
                        #64 31              #39        64
        
        
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(embed_dim, n_heads, hidden_linear_dim, dropout),
            n_layers)
        
        self.n_inputs = embed_dim
        self.n_layers = n_layers
        
        # Simple linear decoder
        self.decoder = nn.Sequential(
                        nn.Linear(768, 17),
                        Transpose(17, 2500),
                        nn.Linear(2500, 30),
                        nn.LogSoftmax()
                        )
        self.init_weights()
        
    def init_weights(self):
        #self.embedder.weight.data.uniform_(-.1, .1)
        #self.decoder.bias.data.zero_()
        #self.decoder.weight.data.uniform_(-.1, .1)
        pass
        
    def forward(self, x):
        #x = self.embedder(x) # * math.sqrt(self.n_inputs)
        x = x.squeeze(0)
        #x = x.view(2500, 8)
        x = x.unsqueeze(1)
        x = self.positional_encoder(x)
        x = self.encoder(x)
        x = x.squeeze(1) 
        #x = self.decoder(x)
        return x


## Encoder 4 - FNET Transformer Architecture

In [28]:
class FeedForwardNet(nn.Module):
    def __init__(self, features, expansion, dropout):
        super(FeedForwardNet, self).__init__()
        self.linear_1 = nn.Linear(features, features * expansion)
        self.linear_2 = nn.Linear(features * expansion, features)
        self.dropout_1 = nn.Dropout(dropout)
        #self.dropout_2 = nn.Dropout(dropout)
        self.norm_1 = nn.LayerNorm(features)

    def forward(self, x):
        res = x
        x = F.relu(self.linear_1(x))
        x = self.dropout_1(x)
        x = self.linear_2(x)
        x = self.norm_1(x + res)
        return x

    
class FNETLayer(nn.Module):
    def __init__(self, features, expansion, dropout):
        super(FNETLayer, self).__init__()
        self.feed_forward = FeedForwardNet(features, expansion, dropout)
        self.norm_1 = nn.LayerNorm(features)
    
    def forward(self, x):
        res = x
        x = fft.fftn(x, dim = (-2, -1)).real
        x = self.norm_1(x + res)
        x = self.feed_forward(x)
        return x
    
class FNETEncoder(nn.TransformerEncoder):
    def __init__(self, features, expansion=2, dropout=0.5, num_layers=6):
        encoder_layer = FNETLayer(features, expansion, dropout)
        super().__init__(encoder_layer=encoder_layer, num_layers=num_layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
 

## Misc Encoder Helper Functions/Components

In [30]:
class PositionalEncoder(nn.Module):
    # Necessary to store positional data about the input data
    def __init__(self, embed_dim, dropout=0.1, max_len=2500, batch_size = 1):
        super(PositionalEncoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pos_encoding = torch.zeros(max_len, 1, embed_dim)
        position = torch.arange(max_len).unsqueeze(1)
        
        divisor = torch.exp(torch.arange(0, embed_dim, 2).float() * (- math.log(10000.0) / embed_dim))
        
        pos_encoding[:, 0, 0::2] = torch.sin(position * divisor)
        pos_encoding[:, 0, 1::2] = torch.cos(position * divisor)
        pos_encoding = pos_encoding.repeat(1, batch_size, 1)
        self.register_buffer("pos_encoding", pos_encoding)

        
    def forward(self, x):
        x = x + self.pos_encoding[:x.size(0), :]
        return self.dropout(x)
   

class Transpose(nn.Module):
    def __init__(self, *args):
        super(Transpose, self).__init__()
        self.shape = args

    def forward(self, x):
        # If the number of the last batch sample in the data set is smaller than the defined batch_batch size, mismatch problems will occur. You can modify it yourself, for example, just pass in the shape behind, and then enter it through x.szie(0).
        return x.view(self.shape)

    
    

class WindowEmbedder(nn.Module):
    # Necessary to convert the signal into "word" vectors for transformer processing.
    # Currently a simple group and slice method, but will modify later for multi-channel inputs
    
    def __init__(self, num_slices, size_of_slice):
        super(SignalEmbedder, self).__init__()
        self.num_slices = num_slices
        self.size_of_slice = size_of_slice
        
    def forward(self, x):
        x = x[: self.num_slices * self.size_of_slice]
        x = x.reshape((self.num_slices, self.size_of_slice))
        return x  


## Decoder 1 - Huggingface GPT2 Decoder

In [63]:
# Replace with child class of GPT2LMHeadModel

class GPT2LMHeadModel(GPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        #self.transformer.forward = forward2.__get__(self.transformer, GPT2Model)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.transformer.h))
        self.transformer.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.transformer.first_device)
        self.model_parallel = True
        
    def deparallelize(self):
        self.transformer.deparallelize()
        self.transformer = self.transformer.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        torch.cuda.empty_cache()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "encoder_hidden_states": kwargs.get("encoder_hidden_states", None), # The one line changed hehe
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
    
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
    
    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )

In [64]:
# define tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# preprocess training labels and tokenize
train_labels = list(waveform_lead_rhythm_diag['diagnosis'])
inputs = tokenizer(train_labels, padding = True, verbose = False, return_tensors="pt")

#Necessary to add for generating first word
inputs["input_ids"] = torch.cat((torch.tensor([[50256] for i in range(len(inputs["input_ids"]))]), inputs["input_ids"]), dim=1)
inputs["attention_mask"] = torch.cat((torch.tensor([[1] for i in range(len(inputs["attention_mask"]))]), inputs["attention_mask"]), dim=1)

inputs = inputs.to(device)

## EncoderDecoder - FNET Encoder Huggingface Decoder

In [54]:
# create encoder decoder model with GPT2 
class CustEncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, embedder):
        super(CustEncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pos_enb = PositionalEncoder(embed_dim = 768)
        self.embedder = embedder
    
    def forward(self, x):
        ecgs, labels = x
        x = self.embedder(ecgs).permute(2, 0, 1)
        x = self.pos_enb(x).permute(1, 0, 2)
        x = self.encoder(x)
        out = self.decoder(**labels, labels = labels["input_ids"], encoder_hidden_states = x.contiguous())
        return out
    
    # Should only take 1 input at a time?
    def predict(self, x):
        ecgs = x
        x = self.embedder(ecgs).permute(2, 0, 1)
        x = self.pos_enb(x).permute(1, 0, 2)
        x = self.encoder(x)
        output = []
        for ecg in x:
            output.append(tokenizer.decode(self.decoder.generate(encoder_hidden_states = ecg.unsqueeze(0).contiguous())[0]))
        return output
    
    def return_enc(self):
        return self.encoder

    
# Connect an embedder and de-embedder for training (we will then isolate the Encoder portion of this autoencoder as our embedder)
class ConvAutoEncoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConvAutoEncoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, x):
        return self.decoder(self.encoder(x))
    
    def make_encoder(self):
        return self.encoder
    
    def make_decoder(self):
        return self.decoder
    

In [55]:
# The following code is for the pre-embedder
    
# Make embedder
conv_model = nn.Sequential()
init_channels = 8
for i in range(2):
    next_channels = 2 * init_channels
    conv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = KER_SIZE, padding = PADDING, stride = 1))
    conv_model.add_module('act_{num}'.format(num = i), nn.ReLU())
    conv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    conv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = KER_SIZE, padding = PADDING))
    conv_model.add_module('act_res_{num}'.format(num = i), nn.ReLU())
    init_channels = next_channels
conv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 768, kernel_size = KER_SIZE, padding = PADDING))
conv_model.add_module('act_fin', nn.ReLU())
conv_model.add_module('batch_fin', nn.BatchNorm1d(768))


# Make de-embedder
deconv_model = nn.Sequential()
init_channels = 768
for i in range(2):
    next_channels = init_channels // 2
    deconv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = KER_SIZE, padding = PADDING, stride = 1))
    deconv_model.add_module('act_{num}'.format(num = i), nn.ReLU())
    deconv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    deconv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = KER_SIZE, padding = PADDING))
    deconv_model.add_module('act_res_{num}'.format(num = i), nn.ReLU())
    init_channels = next_channels
deconv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 8, kernel_size = KER_SIZE, padding = PADDING))
deconv_model.add_module('act_fin', nn.ReLU())
deconv_model.add_module('batch_fin', nn.BatchNorm1d(8))

auto_model = ConvAutoEncoder(conv_model, deconv_model).to(device)
auto_optimizer = torch.optim.Adam(auto_model.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)


# Training params
loss_function = nn.MSELoss()

epochs = 0

for i in range(epochs):
    auto_optimizer.zero_grad()
    outputs = auto_model(ecg_data)
    loss = loss_function(outputs, ecg_data)
    loss.backward(retain_graph=True)
    auto_optimizer.step()
    print(loss)
        
# Saving/loading weights
torch.save(auto_model.state_dict(), 'model/autoencoder.pt')
auto_model.load_state_dict(torch.load('model/autoencoder.pt'))
conv_embedder = auto_model.make_encoder()
torch.save(conv_embedder.state_dict(), "model/embedder.pt")

In [56]:
# Define encoder, we don't need to pretrain rn
encoder = FNETEncoder(768, expansion = 2, dropout=0.1, num_layers = 6)

In [69]:
# define and pretrain Decoder
decoder = GPT2LMHeadModel.from_pretrained('gpt2', config = GPT2Config(add_cross_attention = True, is_encoder_decoder = True))

# pretrain decoder
optimizer = torch.optim.Adam(decoder.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)

# set number of epochs
epochs = 1
for i in range(epochs):
    optimizer.zero_grad()
    outputs = decoder(**inputs, labels = inputs["input_ids"])
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    
    print(loss)
    
torch.save(decoder.state_dict(), 'model/gpt2.pt')

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.1.crossattention.masked_bias', 'h.2.crossattention.c_attn.weight', 'h.7.crossattention.c_attn.weight', 'h.5.crossattention.c_attn.weight', 'h.8.crossattention.c_proj.weight', 'h.11.crossattention.c_proj.bias', 'h.3.crossattention.q_attn.weight', 'h.2.crossattention.c_proj.bias', 'h.4.crossattention.c_attn.weight', 'h.4.crossattention.c_proj.weight', 'h.5.crossattention.c_proj.weight', 'h.11.crossattention.q_attn.weight', 'h.4.crossattention.masked_bias', 'h.10.crossattention.q_attn.weight', 'h.0.ln_cross_attn.weight', 'h.3.ln_cross_attn.weight', 'h.9.crossattention.bias', 'h.1.crossattention.c_attn.weight', 'h.7.crossattention.q_attn.weight', 'h.3.crossattention.c_proj.bias', 'h.7.ln_cross_attn.weight', 'h.10.crossattention.c_attn.weight', 'h.6.crossattention.bias', 'h.10.crossattention.masked_bias', 'h.11.crossattention.masked_bias', 'h.5.crossattention.c_proj.bias', '

tensor(6.2856, grad_fn=<NllLossBackward>)


In [70]:
# define component models

conv_embedder = auto_model.make_encoder()

encoder = FNETEncoder(768, expansion = 2, dropout=0.1, num_layers = 6)

#decoder.load_state_dict(torch.load('model/gpt2.pt'))

enc_dec_model = CustEncoderDecoder(encoder, decoder, conv_embedder)

enc_dec_model.predict(ecg_data)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


ModuleAttributeError: 'GPT2LMHeadModel' object has no attribute 'get_encoder'

In [40]:
# train encoder decoder model!
optimizer = torch.optim.Adam(enc_dec_model.parameters(), lr = 1e-5)
torch.autograd.set_detect_anomaly(True)

# set number of epochs
epochs = 2

for i in range(epochs):
    optimizer.zero_grad()
    outputs = enc_dec_model((ecg_data, inputs))
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    
    print(loss)
    
torch.save(enc_dec_model.state_dict(), 'model/gpt2_enc_dec.pt')

  allow_unreachable=True)  # allow_unreachable flag


tensor(3.7518, grad_fn=<NllLossBackward>)
tensor(3.0280, grad_fn=<NllLossBackward>)


In [None]:
enc_dec_model.load_state_dict(torch.load('model/gpt2_enc_dec.pt'))

In [85]:
torch.save(model.state_dict(), 'model/gpt2.pt')


50256