In [87]:
# import all packages needed
import string 
import numpy as np
import pandas as pd
from matplotlib import pyplot
from base64 import b64decode as decode
import math
import torch
import torch.nn as nn 
import torch.fft as fft
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
import transformers

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

## Data Processing / Cleaning

In [19]:
# use class base64 to decode waveform data
def to_array(wf):
    barr = bytearray(decode(wf))
    vals = np.array(barr)
    return vals.view(np.int16)

# read in data
exam_data = pd.read_csv("data/d_exam.csv").drop(columns = ["site_num", "patient_id_edit"])
waveform_data = pd.read_csv("data/d_waveform.csv")
lead_data = pd.read_csv("data/d_lead_data.csv").drop(columns = ["exam_id"])
diagnosis_data = pd.read_csv("data/d_diagnosis.csv").drop(columns = ["user_input"])

# add decoded data as a column to lead dataz
waveforms = list(lead_data['waveform_data'])
lead_data['decoded_waveform'] = [to_array(i) for i in waveforms]

# merge waveform data and lead data
waveform_lead = lead_data.merge(waveform_data, how = "left", left_on = "waveform_id", right_on = "waveform_id", suffixes = (None, None))

#  sort by exam id and lead id
waveform_lead.sort_values(by = ["waveform_id", "lead_id"], inplace = True)

waveform_lead.loc[:, ['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']]


# adding the diagnosis and labels
waveform_and_diag = pd.merge(waveform_lead[['exam_id', 'lead_id', 'decoded_waveform', 'waveform_type']], diagnosis_data[["exam_id", "Full_text", "Original_Diag"]], left_on= "exam_id", right_on="exam_id")


In [20]:
# concatenate all leads into a single array
waveform_lead_concat = waveform_lead.groupby(["exam_id", "waveform_type"])['decoded_waveform'].apply(lambda x: tuple(x)).reset_index()

# remove irregular observations, concat tuple into numpy array
waveform_lead_concat = waveform_lead_concat.drop([12,17], axis = 0)
waveform_lead_concat['decoded_waveform'] = waveform_lead_concat['decoded_waveform'].apply(lambda x: MinMaxScaler().fit_transform(np.vstack(x)))
waveform_lead_rhythm = waveform_lead_concat[waveform_lead_concat['waveform_type'] == "Rhythm"]
waveform_lead_median = waveform_lead_concat[waveform_lead_concat['waveform_type'] == "Median"]

waveform_lead_rhythm['decoded_waveform'][1].shape

(8, 2500)

In [21]:
# Adding the labels/sentences
exams = diagnosis_data["exam_id"].unique()

# Let's look over this tomorrow
diagnosis_data = diagnosis_data[diagnosis_data['Original_Diag'] == 1].dropna()
searchfor = ['previous', 'unconfirmed', 'compared', 'interpretation', 'significant']
diagnosis_data = diagnosis_data.loc[diagnosis_data['Full_text'].str.contains('|'.join(searchfor)) != 1]
#

diagnosis_data.sort_values(by=["exam_id", "statement_order"], inplace=True)
diagnoses = []
curr_id = 0
curr_string = ""
for i, row in diagnosis_data.iterrows():
    if row["statement_order"] == 1 and curr_string != "":
        curr_string = curr_string.lower().translate(str.maketrans('', '', string.punctuation))
        val = [curr_id, curr_string[1:]]
        diagnoses.append(val)
        curr_string = ""
        curr_id = row["exam_id"]

    if curr_id == 0:
        curr_id = row["exam_id"]
    
    curr_string += " " + row["Full_text"]

diagnosis_df = pd.DataFrame(diagnoses, columns = ['exam_id', 'diagnosis'])
waveform_lead_rhythm_diag = pd.merge(left=waveform_lead_rhythm, right=diagnosis_df, left_on='exam_id', right_on='exam_id')

#waveform_lead_rhythm_diag
waveform_lead_rhythm_diag

# define full_x
full_x = torch.tensor(list(waveform_lead_rhythm_diag['decoded_waveform'])).float()

In [22]:
unique_words = set()
for num, sentence in diagnoses:
    for word in sentence.split():
        unique_words.add(word)
print(unique_words)

{'abnormality', 'normal', 'lvh', 'sinus', 'variant', 'tachycardia', 'ecg', 'low', 'may', 'otherwise', 'be', 'for', 'arrhythmia', 't', 'minimal', 'criteria', 'qrs', 'atrial', 'rhythm', 'wave', 'abnormal', 'bradycardia', 'voltage', 'fibrillation', 'with', 'inferior', 'ischemia', 'borderline', 'consider'}


## Embedder: Conv1D

In [12]:
LR = 1e-3
KER_SIZE = 11
PADDING = 5

# define global max pooling
class global_max_pooling_1d(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x, _ = torch.max(x, dim = 2)
        return(x)

# define resblock for neural nets
class ResBlock1D(nn.Module):
    def __init__(self, num_filters, kernel_size, padding, groups = 1, stride = 1):
        super(ResBlock1D, self).__init__()
        self.act = nn.ReLU()
        self.conv1d_1 = nn.Conv1d(num_filters, num_filters, kernel_size = kernel_size, padding = padding, groups = groups, stride = 1)
        self.conv1d_2 = nn.Conv1d(num_filters, num_filters, kernel_size = kernel_size, padding = padding, groups = groups, stride = 1)
        self.batch_norm_1 = nn.BatchNorm1d(num_filters)
        self.batch_norm_2 = nn.BatchNorm1d(num_filters)

    def forward(self, x):
        res = x
        x = self.batch_norm_1(self.act(self.conv1d_1(x)))
        x = self.batch_norm_2(self.act(self.conv1d_2(x)))
        return x + res

conv_model = nn.Sequential()
init_channels = 8
for i in range(5):
    next_channels = 2 * init_channels
    conv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = KER_SIZE, padding = PADDING, stride = 1))
    conv_model.add_module('act_{num}'.format(num = i), nn.ReLU())
    conv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    conv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = KER_SIZE, padding = PADDING))
    conv_model.add_module('act_res_{num}'.format(num = i), nn.ReLU())
    init_channels = next_channels
conv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 768, kernel_size = KER_SIZE, padding = PADDING))
conv_model.add_module('act_fin', nn.ReLU())
conv_model.add_module('batch_fin', nn.BatchNorm1d(768))
print(conv_model)


Sequential(
  (conv_0): Conv1d(8, 16, kernel_size=(11,), stride=(1,), padding=(5,))
  (act_0): ReLU()
  (batch_norm_0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res_0): ResBlock1D(
    (act): ReLU()
    (conv1d_1): Conv1d(16, 16, kernel_size=(11,), stride=(1,), padding=(5,))
    (conv1d_2): Conv1d(16, 16, kernel_size=(11,), stride=(1,), padding=(5,))
    (batch_norm_1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (batch_norm_2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (act_res_0): ReLU()
  (conv_1): Conv1d(16, 32, kernel_size=(11,), stride=(1,), padding=(5,))
  (act_1): ReLU()
  (batch_norm_1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res_1): ResBlock1D(
    (act): ReLU()
    (conv1d_1): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(5,))
    (conv1d_2): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(

In [14]:
deconv_model = nn.Sequential()
init_channels = 768
for i in range(5):
    next_channels = init_channels // 2
    deconv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = KER_SIZE, padding = PADDING, stride = 1))
    deconv_model.add_module('act_{num}'.format(num = i), nn.ReLU())
    deconv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    deconv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = KER_SIZE, padding = PADDING))
    deconv_model.add_module('act_res_{num}'.format(num = i), nn.ReLU())
    init_channels = next_channels
deconv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 8, kernel_size = KER_SIZE, padding = PADDING))
deconv_model.add_module('act_fin', nn.ReLU())
deconv_model.add_module('batch_fin', nn.BatchNorm1d(8))

print(full_x.shape)
print(conv_model(full_x).shape)
print(deconv_model(conv_model(full_x)).shape)

torch.Size([7, 8, 2500])
torch.Size([7, 768, 2500])
torch.Size([7, 8, 2500])


In [15]:
class ConvAutoEncoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConvAutoEncoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, x):
        return self.decoder(self.encoder(x))
    
    def make_encoder(self):
        return self.encoder
    
    def make_decoder(self):
        return self.decoder
    

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model to set to
auto_model = ConvAutoEncoder(conv_model, deconv_model)
auto_optimizer = torch.optim.Adam(auto_model.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)

# Training params
loss_function = nn.MSELoss()

#print(train_x[0])
for i in range(180):
    #print(train_x.shape)
    auto_optimizer.zero_grad()
    outputs = auto_model(full_x)
    #print(outputs.shape)
    losses = loss_function(outputs, full_x)
    losses.backward(retain_graph=True)
    auto_optimizer.step()
    print(losses)
    if losses < .001:
        break

tensor(1.3016, grad_fn=<MseLossBackward>)
tensor(1.2688, grad_fn=<MseLossBackward>)
tensor(1.1736, grad_fn=<MseLossBackward>)
tensor(1.1416, grad_fn=<MseLossBackward>)
tensor(1.0946, grad_fn=<MseLossBackward>)
tensor(1.0625, grad_fn=<MseLossBackward>)
tensor(1.0360, grad_fn=<MseLossBackward>)
tensor(0.9886, grad_fn=<MseLossBackward>)
tensor(0.9671, grad_fn=<MseLossBackward>)
tensor(0.9323, grad_fn=<MseLossBackward>)
tensor(0.9133, grad_fn=<MseLossBackward>)
tensor(0.8988, grad_fn=<MseLossBackward>)
tensor(0.8857, grad_fn=<MseLossBackward>)
tensor(0.8753, grad_fn=<MseLossBackward>)
tensor(0.8668, grad_fn=<MseLossBackward>)
tensor(0.8571, grad_fn=<MseLossBackward>)
tensor(0.8494, grad_fn=<MseLossBackward>)
tensor(0.8425, grad_fn=<MseLossBackward>)
tensor(0.8355, grad_fn=<MseLossBackward>)
tensor(0.8292, grad_fn=<MseLossBackward>)
tensor(0.8223, grad_fn=<MseLossBackward>)
tensor(0.8157, grad_fn=<MseLossBackward>)
tensor(0.8091, grad_fn=<MseLossBackward>)
tensor(0.8028, grad_fn=<MseLossBac

In [None]:
torch.save(auto_model.state_dict(), 'model/autoencoder.pt')

conv_embedder = auto_model.make_encoder()

torch.save(conv_embedder.state_dict(), "model/embedder.pt")

## Encoder 1: ResNet Encoder

In [7]:
# HYPERPARAMETERS
J = 10 # max number of filters per class
LR = 1e-3

# define global max pooling
class global_max_pooling_1d(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x, _ = torch.max(x, dim = 2)
        return(x)

# define resblock for neural nets
class ResBlock1D(nn.Module):
    def __init__(self, num_filters, kernel_size, padding, groups = 1, stride = 1):
        super(ResBlock1D, self).__init__()
        self.act = nn.ReLU()
        self.conv1d_1 = nn.Conv1d(num_filters, num_filters, kernel_size = kernel_size, padding = padding, groups = groups, stride = 1)
        self.conv1d_2 = nn.Conv1d(num_filters, num_filters, kernel_size = kernel_size, padding = padding, groups = groups, stride = 1)
        self.batch_norm_1 = nn.BatchNorm1d(num_filters)
        self.batch_norm_2 = nn.BatchNorm1d(num_filters)

    def forward(self, x):
        res = x
        x = self.batch_norm_1(self.act(self.conv1d_1(x)))
        x = self.batch_norm_2(self.act(self.conv1d_2(x)))
        return x + res

# build resent model and display the shape of feed through
conv_model = nn.Sequential()
init_channels = 8
for i in range(5):
    next_channels = 2 * init_channels
    conv_model.add_module('conv_{num}'.format(num = i), nn.Conv1d(in_channels = init_channels, out_channels = next_channels, kernel_size = 249, padding = 124, stride = 1))
    conv_model.add_module('act_{num}'.format(num = i), nn.ReLU())
    conv_model.add_module('batch_norm_{num}'.format(num = i), nn.BatchNorm1d(next_channels))
    conv_model.add_module('res_{num}'.format(num = i), ResBlock1D(num_filters = next_channels, kernel_size = 249, padding = 124))
    conv_model.add_module('act_res_{num}'.format(num = i), nn.ReLU())
    init_channels = next_channels
    
conv_model.add_module('conv_fin', nn.Conv1d(in_channels = init_channels, out_channels = 8, kernel_size = 249, padding = 124))
conv_model.add_module('act_fin', nn.ReLU())
conv_model.add_module('batch_fin', nn.BatchNorm1d(8))
print(conv_model)
print(conv_model(train_x).shape)

Sequential(
  (conv_0): Conv1d(8, 16, kernel_size=(249,), stride=(1,), padding=(124,))
  (act_0): ReLU()
  (batch_norm_0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res_0): ResBlock1D(
    (act): ReLU()
    (conv1d_1): Conv1d(16, 16, kernel_size=(249,), stride=(1,), padding=(124,))
    (conv1d_2): Conv1d(16, 16, kernel_size=(249,), stride=(1,), padding=(124,))
    (batch_norm_1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (batch_norm_2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (act_res_0): ReLU()
  (conv_1): Conv1d(16, 32, kernel_size=(249,), stride=(1,), padding=(124,))
  (act_1): ReLU()
  (batch_norm_1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res_1): ResBlock1D(
    (act): ReLU()
    (conv1d_1): Conv1d(32, 32, kernel_size=(249,), stride=(1,), padding=(124,))
    (conv1d_2): Conv1d(32, 32, kernel_size=(249,), stride

## Encoder 2 - LSTM Encoder

In [17]:
# define hyperparameters 
hidden_layers = 250
embedding_dim = 8
num_words = len(dict_words)

class LSTM_EncoderDecoder(nn.Module):
    def __init__(self, h_dim, e_dim, word_list_length):
        super(ECG_LSTM, self).__init__()
        self.lstm = nn.LSTM(e_dim, h_dim, num_layers = 4, bidirectional = True)
        
    def forward(self, seq):
        seq_embedded = seq.view(len(seq), -1, embedding_dim)
        final_hidd, _ = self.lstm(seq_embedded)
        dec_seq = self.linear(final_hidd)
        return F.log_softmax(dec_seq, dim = 1)
    

## Encoder 3 - Basic Transformer Architecture with Multi-Head Attention

In [None]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class ECGTransformerEncoder(nn.Module):
    # Takes the ECG discrete signals sequence and maps into a probability distribution of diagnosis
    # For working/verification purposes
    def __init__(self, vector_size, embed_dim, n_heads, hidden_linear_dim, n_layers, dropout):
        super(ECGTransformerEncoder, self).__init__()
        self.model_type = "Transformer"
        self.positional_encoder = PositionalEncoder(embed_dim, dropout)
    
        #Since our data is already discrete numbers, might need some tweaking for this
        self.embedder = conv_embedder
                        #64 31              #39        64
        
        
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(embed_dim, n_heads, hidden_linear_dim, dropout),
            n_layers)
        
        self.n_inputs = embed_dim
        self.n_layers = n_layers
        
        # Simple linear decoder
        self.decoder = nn.Sequential(
                        nn.Linear(768, 17),
                        Transpose(17, 2500),
                        nn.Linear(2500, 30),
                        nn.LogSoftmax()
                        )
        self.init_weights()
        
    def init_weights(self):
        #self.embedder.weight.data.uniform_(-.1, .1)
        #self.decoder.bias.data.zero_()
        #self.decoder.weight.data.uniform_(-.1, .1)
        pass
        
    def forward(self, x):
        #x = self.embedder(x) # * math.sqrt(self.n_inputs)
        x = x.squeeze(0)
        #x = x.view(2500, 8)
        x = x.unsqueeze(1)
        x = self.positional_encoder(x)
        x = self.encoder(x)
        x = x.squeeze(1) 
        #x = self.decoder(x)
        return x

class Transpose(nn.Module):
    def __init__(self, *args):
        super(Transpose, self).__init__()
        self.shape = args

    def forward(self, x):
        # If the number of the last batch sample in the data set is smaller than the defined batch_batch size, mismatch problems will occur. You can modify it yourself, for example, just pass in the shape behind, and then enter it through x.szie(0).
        return x.view(self.shape)

class SignalEmbedder(nn.Module):
    # Necessary to convert the signal into "word" vectors for transformer processing.
    # Currently a simple group and slice method, but will modify later for multi-channel inputs
    
    def __init__(self, num_slices, size_of_slice):
        super(SignalEmbedder, self).__init__()
        self.num_slices = num_slices
        self.size_of_slice = size_of_slice
        
    def forward(self, x):
        x = x[: self.num_slices * self.size_of_slice]
        x = x.reshape((self.num_slices, self.size_of_slice))
        return x  

class PositionalEncoder(nn.Module):
    # Necessary to store positional data about the input data
    def __init__(self, embed_dim, dropout=0.1, max_len=2500, batch_size = 1):
        super(PositionalEncoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pos_encoding = torch.zeros(max_len, 1, embed_dim)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        
        divisor = torch.exp(torch.arange(0, embed_dim, 2).float() * (- math.log(10000.0) / embed_dim))
        
        pos_encoding[:, 0, 0::2] = torch.sin(position * divisor)
        pos_encoding[:, 0, 1::2] = torch.cos(position * divisor)
        pos_encoding = pos_encoding.repeat(1, batch_size, 1)
        self.register_buffer("pos_encoding", pos_encoding)

        
    def forward(self, x):
        x = x + self.pos_encoding[:x.size(0), :]
        return self.dropout(x)

## Encoder 4 - FNET Transformer Architecture

In [26]:
class FeedForwardNet(nn.Module):
    def __init__(self, features, expansion, dropout):
        super(FeedForwardNet, self).__init__()
        self.linear_1 = nn.Linear(features, features * expansion)
        self.linear_2 = nn.Linear(features * expansion, features)
        self.dropout_1 = nn.Dropout(dropout)
        #self.dropout_2 = nn.Dropout(dropout)
        self.norm_1 = nn.LayerNorm(features)

    def forward(self, x):
        res = x
        x = F.relu(self.linear_1(x))
        x = self.dropout_1(x)
        x = self.linear_2(x)
        x = self.norm_1(x + res)
        return x
    
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    
    
class FNETLayer(nn.Module):
    def __init__(self, features, expansion, dropout):
        super(FNETLayer, self).__init__()
        self.feed_forward = FeedForwardNet(features, expansion, dropout)
        self.norm_1 = nn.LayerNorm(features)
    
    def forward(self, x):
        res = x
        x = fft.fftn(x, dim = (-2, -1)).real
        x = self.norm_1(x + res)
        x = self.feed_forward(x)
        return x
    
class FNETEncoder(nn.TransformerEncoder):
    def __init__(self, features, expansion=2, dropout=0.5, num_layers=6):
        encoder_layer = FNETLayer(features, expansion, dropout)
        super().__init__(encoder_layer=encoder_layer, num_layers=num_layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
 
class Transpose(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return x.transpose(1, 2)


## Decoder 1 - Huggingface GPT2 Decoder

In [83]:

def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
    token_type_ids = kwargs.get("token_type_ids", None)
    # only last token for inputs_ids if past is defined in kwargs
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past:
            position_ids = position_ids[:, -1].unsqueeze(-1)
    else:
        position_ids = None
    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "encoder_hidden_states": kwargs.get("encoder_hidden_states", None),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }

# define tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2', config = GPT2Config(add_cross_attention = True, is_encoder_decoder = True))
model.prepare_inputs_for_generation = prepare_inputs_for_generation.__get__(model, GPT2LMHeadModel)

# preprocess training labels and tokenize
train_labels = list(waveform_lead_rhythm_diag['diagnosis'])
inputs = tokenizer(train_labels, padding = True, verbose = False, return_tensors="pt")



Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.8.crossattention.c_attn.weight', 'h.7.crossattention.bias', 'h.6.crossattention.masked_bias', 'h.4.crossattention.c_proj.bias', 'h.0.crossattention.bias', 'h.0.crossattention.masked_bias', 'h.3.crossattention.bias', 'h.5.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.c_attn.weight', 'h.3.crossattention.c_proj.weight', 'h.7.ln_cross_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.3.ln_cross_attn.weight', 'h.2.crossattention.c_attn.weight', 'h.1.crossattention.c_attn.weight', 'h.9.ln_cross_attn.weight', 'h.7.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.6.crossattention.q_attn.weight', 'h.10.crossattention.c_attn.weight', 'h.7.crossattention.c_proj.bias', 'h.0.ln_cross_attn.weight', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.bias', 'h.1.crossattention.bias', 'h.8.crossattention.c_proj.weight', 'h.11.cro

In [84]:
# pretrain decoder
torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
torch.autograd.set_detect_anomaly(True)

# set number of epochs
epochs = 100

for i in range(epochs):
    optimizer.zero_grad()
    outputs = model(**inputs, labels = inputs["input_ids"])
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    
    print(loss)
    
torch.save(model.state_dict(), 'model/gpt2.pt')


tensor(8.0579, grad_fn=<NllLossBackward>)
tensor(5.5628, grad_fn=<NllLossBackward>)
tensor(3.0102, grad_fn=<NllLossBackward>)
tensor(2.9243, grad_fn=<NllLossBackward>)
tensor(2.1574, grad_fn=<NllLossBackward>)
tensor(1.6910, grad_fn=<NllLossBackward>)
tensor(1.4633, grad_fn=<NllLossBackward>)
tensor(1.3435, grad_fn=<NllLossBackward>)
tensor(1.1420, grad_fn=<NllLossBackward>)
tensor(0.9907, grad_fn=<NllLossBackward>)
tensor(0.8403, grad_fn=<NllLossBackward>)
tensor(0.7053, grad_fn=<NllLossBackward>)
tensor(0.5677, grad_fn=<NllLossBackward>)
tensor(0.4541, grad_fn=<NllLossBackward>)
tensor(0.3781, grad_fn=<NllLossBackward>)
tensor(0.3731, grad_fn=<NllLossBackward>)
tensor(0.3357, grad_fn=<NllLossBackward>)
tensor(0.2472, grad_fn=<NllLossBackward>)
tensor(0.1931, grad_fn=<NllLossBackward>)


KeyboardInterrupt: 

## EncoderDecoder - FNET Encoder Huggingface Decoder

In [94]:
# create encoder decoder model with GPT2 
class CustEncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, embedder):
        super(CustEncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pos_enb = PositionalEncoding(d_model = 768)
        self.embedder = embedder
    
    def forward(self, x):
        ecgs, labels = x
        x = self.embedder(ecgs).permute(2, 0, 1)
        x = self.pos_enb(x).permute(1, 0, 2)
        x = self.encoder(x)
        out = self.decoder(**labels, labels = labels["input_ids"], encoder_hidden_states = x.contiguous())
        return out
    
    def predict(self, x):
        ecgs = x
        x = self.embedder(ecgs).permute(2, 0, 1)
        x = self.pos_enb(x).permute(1, 0, 2)
        x = self.encoder(x)
        return self.decoder(input_ids = torch.tensor(self.decoder.config.bos_token_id), encoder_hidden_states = x.contiguous())
    
    def return_enc(self):
        return self.encoder

# define component models
conv_embedder = nn.Sequential(nn.Conv1d(in_channels = 8, out_channels = 350, kernel_size = 15, padding = 7, stride = 1),
                              nn.ReLU(),
                              nn.BatchNorm1d(350),
                              nn.Conv1d(in_channels = 350, out_channels = 768, kernel_size = 15, padding = 7, stride = 1),
                              nn.ReLU())
model.load_state_dict(torch.load('model/gpt2.pt'))
encoder = FNETEncoder(768, expansion = 2, dropout=0.1, num_layers = 6)
enc_dec_model = CustEncoderDecoder(encoder, model, conv_embedder)

# define all x, feed through to see if all is good
full_x = torch.tensor(list(waveform_lead_rhythm_diag['decoded_waveform'])).float()
#enc_dec_model((full_x, inputs))

enc_dec_model.predict(full_x)


IndexError: tuple index out of range

In [33]:
# train encoder decoder model!
torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(enc_dec_model.parameters(), lr = 1e-5)
torch.autograd.set_detect_anomaly(True)

# set number of epochs
epochs = 100

for i in range(epochs):
    optimizer.zero_grad()
    outputs = enc_dec_model((full_x, inputs))
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    
    
    print(loss)
    
torch.save(enc_dec_model.state_dict(), 'model/gpt2_enc_dec.pt')


KeyboardInterrupt: 

In [31]:
enc_dec_model.load_state_dict(torch.load('model/gpt2_enc_dec.pt'))


RuntimeError: Error(s) in loading state_dict for CustEncoderDecoder:
	Missing key(s) in state_dict: "encoder.layers.0.feed_forward.linear_1.weight", "encoder.layers.0.feed_forward.linear_1.bias", "encoder.layers.0.feed_forward.linear_2.weight", "encoder.layers.0.feed_forward.linear_2.bias", "encoder.layers.0.feed_forward.norm_1.weight", "encoder.layers.0.feed_forward.norm_1.bias", "encoder.layers.0.norm_1.weight", "encoder.layers.0.norm_1.bias", "encoder.layers.1.feed_forward.linear_1.weight", "encoder.layers.1.feed_forward.linear_1.bias", "encoder.layers.1.feed_forward.linear_2.weight", "encoder.layers.1.feed_forward.linear_2.bias", "encoder.layers.1.feed_forward.norm_1.weight", "encoder.layers.1.feed_forward.norm_1.bias", "encoder.layers.1.norm_1.weight", "encoder.layers.1.norm_1.bias", "encoder.layers.2.feed_forward.linear_1.weight", "encoder.layers.2.feed_forward.linear_1.bias", "encoder.layers.2.feed_forward.linear_2.weight", "encoder.layers.2.feed_forward.linear_2.bias", "encoder.layers.2.feed_forward.norm_1.weight", "encoder.layers.2.feed_forward.norm_1.bias", "encoder.layers.2.norm_1.weight", "encoder.layers.2.norm_1.bias", "encoder.layers.3.feed_forward.linear_1.weight", "encoder.layers.3.feed_forward.linear_1.bias", "encoder.layers.3.feed_forward.linear_2.weight", "encoder.layers.3.feed_forward.linear_2.bias", "encoder.layers.3.feed_forward.norm_1.weight", "encoder.layers.3.feed_forward.norm_1.bias", "encoder.layers.3.norm_1.weight", "encoder.layers.3.norm_1.bias", "encoder.layers.4.feed_forward.linear_1.weight", "encoder.layers.4.feed_forward.linear_1.bias", "encoder.layers.4.feed_forward.linear_2.weight", "encoder.layers.4.feed_forward.linear_2.bias", "encoder.layers.4.feed_forward.norm_1.weight", "encoder.layers.4.feed_forward.norm_1.bias", "encoder.layers.4.norm_1.weight", "encoder.layers.4.norm_1.bias", "encoder.layers.5.feed_forward.linear_1.weight", "encoder.layers.5.feed_forward.linear_1.bias", "encoder.layers.5.feed_forward.linear_2.weight", "encoder.layers.5.feed_forward.linear_2.bias", "encoder.layers.5.feed_forward.norm_1.weight", "encoder.layers.5.feed_forward.norm_1.bias", "encoder.layers.5.norm_1.weight", "encoder.layers.5.norm_1.bias", "decoder.transformer.wte.weight", "decoder.transformer.wpe.weight", "decoder.transformer.h.0.ln_1.weight", "decoder.transformer.h.0.ln_1.bias", "decoder.transformer.h.0.attn.bias", "decoder.transformer.h.0.attn.masked_bias", "decoder.transformer.h.0.attn.c_attn.weight", "decoder.transformer.h.0.attn.c_attn.bias", "decoder.transformer.h.0.attn.c_proj.weight", "decoder.transformer.h.0.attn.c_proj.bias", "decoder.transformer.h.0.ln_2.weight", "decoder.transformer.h.0.ln_2.bias", "decoder.transformer.h.0.crossattention.bias", "decoder.transformer.h.0.crossattention.masked_bias", "decoder.transformer.h.0.crossattention.c_attn.weight", "decoder.transformer.h.0.crossattention.c_attn.bias", "decoder.transformer.h.0.crossattention.q_attn.weight", "decoder.transformer.h.0.crossattention.q_attn.bias", "decoder.transformer.h.0.crossattention.c_proj.weight", "decoder.transformer.h.0.crossattention.c_proj.bias", "decoder.transformer.h.0.ln_cross_attn.weight", "decoder.transformer.h.0.ln_cross_attn.bias", "decoder.transformer.h.0.mlp.c_fc.weight", "decoder.transformer.h.0.mlp.c_fc.bias", "decoder.transformer.h.0.mlp.c_proj.weight", "decoder.transformer.h.0.mlp.c_proj.bias", "decoder.transformer.h.1.ln_1.weight", "decoder.transformer.h.1.ln_1.bias", "decoder.transformer.h.1.attn.bias", "decoder.transformer.h.1.attn.masked_bias", "decoder.transformer.h.1.attn.c_attn.weight", "decoder.transformer.h.1.attn.c_attn.bias", "decoder.transformer.h.1.attn.c_proj.weight", "decoder.transformer.h.1.attn.c_proj.bias", "decoder.transformer.h.1.ln_2.weight", "decoder.transformer.h.1.ln_2.bias", "decoder.transformer.h.1.crossattention.bias", "decoder.transformer.h.1.crossattention.masked_bias", "decoder.transformer.h.1.crossattention.c_attn.weight", "decoder.transformer.h.1.crossattention.c_attn.bias", "decoder.transformer.h.1.crossattention.q_attn.weight", "decoder.transformer.h.1.crossattention.q_attn.bias", "decoder.transformer.h.1.crossattention.c_proj.weight", "decoder.transformer.h.1.crossattention.c_proj.bias", "decoder.transformer.h.1.ln_cross_attn.weight", "decoder.transformer.h.1.ln_cross_attn.bias", "decoder.transformer.h.1.mlp.c_fc.weight", "decoder.transformer.h.1.mlp.c_fc.bias", "decoder.transformer.h.1.mlp.c_proj.weight", "decoder.transformer.h.1.mlp.c_proj.bias", "decoder.transformer.h.2.ln_1.weight", "decoder.transformer.h.2.ln_1.bias", "decoder.transformer.h.2.attn.bias", "decoder.transformer.h.2.attn.masked_bias", "decoder.transformer.h.2.attn.c_attn.weight", "decoder.transformer.h.2.attn.c_attn.bias", "decoder.transformer.h.2.attn.c_proj.weight", "decoder.transformer.h.2.attn.c_proj.bias", "decoder.transformer.h.2.ln_2.weight", "decoder.transformer.h.2.ln_2.bias", "decoder.transformer.h.2.crossattention.bias", "decoder.transformer.h.2.crossattention.masked_bias", "decoder.transformer.h.2.crossattention.c_attn.weight", "decoder.transformer.h.2.crossattention.c_attn.bias", "decoder.transformer.h.2.crossattention.q_attn.weight", "decoder.transformer.h.2.crossattention.q_attn.bias", "decoder.transformer.h.2.crossattention.c_proj.weight", "decoder.transformer.h.2.crossattention.c_proj.bias", "decoder.transformer.h.2.ln_cross_attn.weight", "decoder.transformer.h.2.ln_cross_attn.bias", "decoder.transformer.h.2.mlp.c_fc.weight", "decoder.transformer.h.2.mlp.c_fc.bias", "decoder.transformer.h.2.mlp.c_proj.weight", "decoder.transformer.h.2.mlp.c_proj.bias", "decoder.transformer.h.3.ln_1.weight", "decoder.transformer.h.3.ln_1.bias", "decoder.transformer.h.3.attn.bias", "decoder.transformer.h.3.attn.masked_bias", "decoder.transformer.h.3.attn.c_attn.weight", "decoder.transformer.h.3.attn.c_attn.bias", "decoder.transformer.h.3.attn.c_proj.weight", "decoder.transformer.h.3.attn.c_proj.bias", "decoder.transformer.h.3.ln_2.weight", "decoder.transformer.h.3.ln_2.bias", "decoder.transformer.h.3.crossattention.bias", "decoder.transformer.h.3.crossattention.masked_bias", "decoder.transformer.h.3.crossattention.c_attn.weight", "decoder.transformer.h.3.crossattention.c_attn.bias", "decoder.transformer.h.3.crossattention.q_attn.weight", "decoder.transformer.h.3.crossattention.q_attn.bias", "decoder.transformer.h.3.crossattention.c_proj.weight", "decoder.transformer.h.3.crossattention.c_proj.bias", "decoder.transformer.h.3.ln_cross_attn.weight", "decoder.transformer.h.3.ln_cross_attn.bias", "decoder.transformer.h.3.mlp.c_fc.weight", "decoder.transformer.h.3.mlp.c_fc.bias", "decoder.transformer.h.3.mlp.c_proj.weight", "decoder.transformer.h.3.mlp.c_proj.bias", "decoder.transformer.h.4.ln_1.weight", "decoder.transformer.h.4.ln_1.bias", "decoder.transformer.h.4.attn.bias", "decoder.transformer.h.4.attn.masked_bias", "decoder.transformer.h.4.attn.c_attn.weight", "decoder.transformer.h.4.attn.c_attn.bias", "decoder.transformer.h.4.attn.c_proj.weight", "decoder.transformer.h.4.attn.c_proj.bias", "decoder.transformer.h.4.ln_2.weight", "decoder.transformer.h.4.ln_2.bias", "decoder.transformer.h.4.crossattention.bias", "decoder.transformer.h.4.crossattention.masked_bias", "decoder.transformer.h.4.crossattention.c_attn.weight", "decoder.transformer.h.4.crossattention.c_attn.bias", "decoder.transformer.h.4.crossattention.q_attn.weight", "decoder.transformer.h.4.crossattention.q_attn.bias", "decoder.transformer.h.4.crossattention.c_proj.weight", "decoder.transformer.h.4.crossattention.c_proj.bias", "decoder.transformer.h.4.ln_cross_attn.weight", "decoder.transformer.h.4.ln_cross_attn.bias", "decoder.transformer.h.4.mlp.c_fc.weight", "decoder.transformer.h.4.mlp.c_fc.bias", "decoder.transformer.h.4.mlp.c_proj.weight", "decoder.transformer.h.4.mlp.c_proj.bias", "decoder.transformer.h.5.ln_1.weight", "decoder.transformer.h.5.ln_1.bias", "decoder.transformer.h.5.attn.bias", "decoder.transformer.h.5.attn.masked_bias", "decoder.transformer.h.5.attn.c_attn.weight", "decoder.transformer.h.5.attn.c_attn.bias", "decoder.transformer.h.5.attn.c_proj.weight", "decoder.transformer.h.5.attn.c_proj.bias", "decoder.transformer.h.5.ln_2.weight", "decoder.transformer.h.5.ln_2.bias", "decoder.transformer.h.5.crossattention.bias", "decoder.transformer.h.5.crossattention.masked_bias", "decoder.transformer.h.5.crossattention.c_attn.weight", "decoder.transformer.h.5.crossattention.c_attn.bias", "decoder.transformer.h.5.crossattention.q_attn.weight", "decoder.transformer.h.5.crossattention.q_attn.bias", "decoder.transformer.h.5.crossattention.c_proj.weight", "decoder.transformer.h.5.crossattention.c_proj.bias", "decoder.transformer.h.5.ln_cross_attn.weight", "decoder.transformer.h.5.ln_cross_attn.bias", "decoder.transformer.h.5.mlp.c_fc.weight", "decoder.transformer.h.5.mlp.c_fc.bias", "decoder.transformer.h.5.mlp.c_proj.weight", "decoder.transformer.h.5.mlp.c_proj.bias", "decoder.transformer.h.6.ln_1.weight", "decoder.transformer.h.6.ln_1.bias", "decoder.transformer.h.6.attn.bias", "decoder.transformer.h.6.attn.masked_bias", "decoder.transformer.h.6.attn.c_attn.weight", "decoder.transformer.h.6.attn.c_attn.bias", "decoder.transformer.h.6.attn.c_proj.weight", "decoder.transformer.h.6.attn.c_proj.bias", "decoder.transformer.h.6.ln_2.weight", "decoder.transformer.h.6.ln_2.bias", "decoder.transformer.h.6.crossattention.bias", "decoder.transformer.h.6.crossattention.masked_bias", "decoder.transformer.h.6.crossattention.c_attn.weight", "decoder.transformer.h.6.crossattention.c_attn.bias", "decoder.transformer.h.6.crossattention.q_attn.weight", "decoder.transformer.h.6.crossattention.q_attn.bias", "decoder.transformer.h.6.crossattention.c_proj.weight", "decoder.transformer.h.6.crossattention.c_proj.bias", "decoder.transformer.h.6.ln_cross_attn.weight", "decoder.transformer.h.6.ln_cross_attn.bias", "decoder.transformer.h.6.mlp.c_fc.weight", "decoder.transformer.h.6.mlp.c_fc.bias", "decoder.transformer.h.6.mlp.c_proj.weight", "decoder.transformer.h.6.mlp.c_proj.bias", "decoder.transformer.h.7.ln_1.weight", "decoder.transformer.h.7.ln_1.bias", "decoder.transformer.h.7.attn.bias", "decoder.transformer.h.7.attn.masked_bias", "decoder.transformer.h.7.attn.c_attn.weight", "decoder.transformer.h.7.attn.c_attn.bias", "decoder.transformer.h.7.attn.c_proj.weight", "decoder.transformer.h.7.attn.c_proj.bias", "decoder.transformer.h.7.ln_2.weight", "decoder.transformer.h.7.ln_2.bias", "decoder.transformer.h.7.crossattention.bias", "decoder.transformer.h.7.crossattention.masked_bias", "decoder.transformer.h.7.crossattention.c_attn.weight", "decoder.transformer.h.7.crossattention.c_attn.bias", "decoder.transformer.h.7.crossattention.q_attn.weight", "decoder.transformer.h.7.crossattention.q_attn.bias", "decoder.transformer.h.7.crossattention.c_proj.weight", "decoder.transformer.h.7.crossattention.c_proj.bias", "decoder.transformer.h.7.ln_cross_attn.weight", "decoder.transformer.h.7.ln_cross_attn.bias", "decoder.transformer.h.7.mlp.c_fc.weight", "decoder.transformer.h.7.mlp.c_fc.bias", "decoder.transformer.h.7.mlp.c_proj.weight", "decoder.transformer.h.7.mlp.c_proj.bias", "decoder.transformer.h.8.ln_1.weight", "decoder.transformer.h.8.ln_1.bias", "decoder.transformer.h.8.attn.bias", "decoder.transformer.h.8.attn.masked_bias", "decoder.transformer.h.8.attn.c_attn.weight", "decoder.transformer.h.8.attn.c_attn.bias", "decoder.transformer.h.8.attn.c_proj.weight", "decoder.transformer.h.8.attn.c_proj.bias", "decoder.transformer.h.8.ln_2.weight", "decoder.transformer.h.8.ln_2.bias", "decoder.transformer.h.8.crossattention.bias", "decoder.transformer.h.8.crossattention.masked_bias", "decoder.transformer.h.8.crossattention.c_attn.weight", "decoder.transformer.h.8.crossattention.c_attn.bias", "decoder.transformer.h.8.crossattention.q_attn.weight", "decoder.transformer.h.8.crossattention.q_attn.bias", "decoder.transformer.h.8.crossattention.c_proj.weight", "decoder.transformer.h.8.crossattention.c_proj.bias", "decoder.transformer.h.8.ln_cross_attn.weight", "decoder.transformer.h.8.ln_cross_attn.bias", "decoder.transformer.h.8.mlp.c_fc.weight", "decoder.transformer.h.8.mlp.c_fc.bias", "decoder.transformer.h.8.mlp.c_proj.weight", "decoder.transformer.h.8.mlp.c_proj.bias", "decoder.transformer.h.9.ln_1.weight", "decoder.transformer.h.9.ln_1.bias", "decoder.transformer.h.9.attn.bias", "decoder.transformer.h.9.attn.masked_bias", "decoder.transformer.h.9.attn.c_attn.weight", "decoder.transformer.h.9.attn.c_attn.bias", "decoder.transformer.h.9.attn.c_proj.weight", "decoder.transformer.h.9.attn.c_proj.bias", "decoder.transformer.h.9.ln_2.weight", "decoder.transformer.h.9.ln_2.bias", "decoder.transformer.h.9.crossattention.bias", "decoder.transformer.h.9.crossattention.masked_bias", "decoder.transformer.h.9.crossattention.c_attn.weight", "decoder.transformer.h.9.crossattention.c_attn.bias", "decoder.transformer.h.9.crossattention.q_attn.weight", "decoder.transformer.h.9.crossattention.q_attn.bias", "decoder.transformer.h.9.crossattention.c_proj.weight", "decoder.transformer.h.9.crossattention.c_proj.bias", "decoder.transformer.h.9.ln_cross_attn.weight", "decoder.transformer.h.9.ln_cross_attn.bias", "decoder.transformer.h.9.mlp.c_fc.weight", "decoder.transformer.h.9.mlp.c_fc.bias", "decoder.transformer.h.9.mlp.c_proj.weight", "decoder.transformer.h.9.mlp.c_proj.bias", "decoder.transformer.h.10.ln_1.weight", "decoder.transformer.h.10.ln_1.bias", "decoder.transformer.h.10.attn.bias", "decoder.transformer.h.10.attn.masked_bias", "decoder.transformer.h.10.attn.c_attn.weight", "decoder.transformer.h.10.attn.c_attn.bias", "decoder.transformer.h.10.attn.c_proj.weight", "decoder.transformer.h.10.attn.c_proj.bias", "decoder.transformer.h.10.ln_2.weight", "decoder.transformer.h.10.ln_2.bias", "decoder.transformer.h.10.crossattention.bias", "decoder.transformer.h.10.crossattention.masked_bias", "decoder.transformer.h.10.crossattention.c_attn.weight", "decoder.transformer.h.10.crossattention.c_attn.bias", "decoder.transformer.h.10.crossattention.q_attn.weight", "decoder.transformer.h.10.crossattention.q_attn.bias", "decoder.transformer.h.10.crossattention.c_proj.weight", "decoder.transformer.h.10.crossattention.c_proj.bias", "decoder.transformer.h.10.ln_cross_attn.weight", "decoder.transformer.h.10.ln_cross_attn.bias", "decoder.transformer.h.10.mlp.c_fc.weight", "decoder.transformer.h.10.mlp.c_fc.bias", "decoder.transformer.h.10.mlp.c_proj.weight", "decoder.transformer.h.10.mlp.c_proj.bias", "decoder.transformer.h.11.ln_1.weight", "decoder.transformer.h.11.ln_1.bias", "decoder.transformer.h.11.attn.bias", "decoder.transformer.h.11.attn.masked_bias", "decoder.transformer.h.11.attn.c_attn.weight", "decoder.transformer.h.11.attn.c_attn.bias", "decoder.transformer.h.11.attn.c_proj.weight", "decoder.transformer.h.11.attn.c_proj.bias", "decoder.transformer.h.11.ln_2.weight", "decoder.transformer.h.11.ln_2.bias", "decoder.transformer.h.11.crossattention.bias", "decoder.transformer.h.11.crossattention.masked_bias", "decoder.transformer.h.11.crossattention.c_attn.weight", "decoder.transformer.h.11.crossattention.c_attn.bias", "decoder.transformer.h.11.crossattention.q_attn.weight", "decoder.transformer.h.11.crossattention.q_attn.bias", "decoder.transformer.h.11.crossattention.c_proj.weight", "decoder.transformer.h.11.crossattention.c_proj.bias", "decoder.transformer.h.11.ln_cross_attn.weight", "decoder.transformer.h.11.ln_cross_attn.bias", "decoder.transformer.h.11.mlp.c_fc.weight", "decoder.transformer.h.11.mlp.c_fc.bias", "decoder.transformer.h.11.mlp.c_proj.weight", "decoder.transformer.h.11.mlp.c_proj.bias", "decoder.transformer.ln_f.weight", "decoder.transformer.ln_f.bias", "decoder.lm_head.weight", "pos_enb.pe", "embedder.0.weight", "embedder.0.bias", "embedder.2.weight", "embedder.2.bias", "embedder.2.running_mean", "embedder.2.running_var", "embedder.3.weight", "embedder.3.bias". 
	Unexpected key(s) in state_dict: "transformer.wte.weight", "transformer.wpe.weight", "transformer.h.0.ln_1.weight", "transformer.h.0.ln_1.bias", "transformer.h.0.attn.bias", "transformer.h.0.attn.masked_bias", "transformer.h.0.attn.c_attn.weight", "transformer.h.0.attn.c_attn.bias", "transformer.h.0.attn.c_proj.weight", "transformer.h.0.attn.c_proj.bias", "transformer.h.0.ln_2.weight", "transformer.h.0.ln_2.bias", "transformer.h.0.crossattention.bias", "transformer.h.0.crossattention.masked_bias", "transformer.h.0.crossattention.c_attn.weight", "transformer.h.0.crossattention.c_attn.bias", "transformer.h.0.crossattention.q_attn.weight", "transformer.h.0.crossattention.q_attn.bias", "transformer.h.0.crossattention.c_proj.weight", "transformer.h.0.crossattention.c_proj.bias", "transformer.h.0.ln_cross_attn.weight", "transformer.h.0.ln_cross_attn.bias", "transformer.h.0.mlp.c_fc.weight", "transformer.h.0.mlp.c_fc.bias", "transformer.h.0.mlp.c_proj.weight", "transformer.h.0.mlp.c_proj.bias", "transformer.h.1.ln_1.weight", "transformer.h.1.ln_1.bias", "transformer.h.1.attn.bias", "transformer.h.1.attn.masked_bias", "transformer.h.1.attn.c_attn.weight", "transformer.h.1.attn.c_attn.bias", "transformer.h.1.attn.c_proj.weight", "transformer.h.1.attn.c_proj.bias", "transformer.h.1.ln_2.weight", "transformer.h.1.ln_2.bias", "transformer.h.1.crossattention.bias", "transformer.h.1.crossattention.masked_bias", "transformer.h.1.crossattention.c_attn.weight", "transformer.h.1.crossattention.c_attn.bias", "transformer.h.1.crossattention.q_attn.weight", "transformer.h.1.crossattention.q_attn.bias", "transformer.h.1.crossattention.c_proj.weight", "transformer.h.1.crossattention.c_proj.bias", "transformer.h.1.ln_cross_attn.weight", "transformer.h.1.ln_cross_attn.bias", "transformer.h.1.mlp.c_fc.weight", "transformer.h.1.mlp.c_fc.bias", "transformer.h.1.mlp.c_proj.weight", "transformer.h.1.mlp.c_proj.bias", "transformer.h.2.ln_1.weight", "transformer.h.2.ln_1.bias", "transformer.h.2.attn.bias", "transformer.h.2.attn.masked_bias", "transformer.h.2.attn.c_attn.weight", "transformer.h.2.attn.c_attn.bias", "transformer.h.2.attn.c_proj.weight", "transformer.h.2.attn.c_proj.bias", "transformer.h.2.ln_2.weight", "transformer.h.2.ln_2.bias", "transformer.h.2.crossattention.bias", "transformer.h.2.crossattention.masked_bias", "transformer.h.2.crossattention.c_attn.weight", "transformer.h.2.crossattention.c_attn.bias", "transformer.h.2.crossattention.q_attn.weight", "transformer.h.2.crossattention.q_attn.bias", "transformer.h.2.crossattention.c_proj.weight", "transformer.h.2.crossattention.c_proj.bias", "transformer.h.2.ln_cross_attn.weight", "transformer.h.2.ln_cross_attn.bias", "transformer.h.2.mlp.c_fc.weight", "transformer.h.2.mlp.c_fc.bias", "transformer.h.2.mlp.c_proj.weight", "transformer.h.2.mlp.c_proj.bias", "transformer.h.3.ln_1.weight", "transformer.h.3.ln_1.bias", "transformer.h.3.attn.bias", "transformer.h.3.attn.masked_bias", "transformer.h.3.attn.c_attn.weight", "transformer.h.3.attn.c_attn.bias", "transformer.h.3.attn.c_proj.weight", "transformer.h.3.attn.c_proj.bias", "transformer.h.3.ln_2.weight", "transformer.h.3.ln_2.bias", "transformer.h.3.crossattention.bias", "transformer.h.3.crossattention.masked_bias", "transformer.h.3.crossattention.c_attn.weight", "transformer.h.3.crossattention.c_attn.bias", "transformer.h.3.crossattention.q_attn.weight", "transformer.h.3.crossattention.q_attn.bias", "transformer.h.3.crossattention.c_proj.weight", "transformer.h.3.crossattention.c_proj.bias", "transformer.h.3.ln_cross_attn.weight", "transformer.h.3.ln_cross_attn.bias", "transformer.h.3.mlp.c_fc.weight", "transformer.h.3.mlp.c_fc.bias", "transformer.h.3.mlp.c_proj.weight", "transformer.h.3.mlp.c_proj.bias", "transformer.h.4.ln_1.weight", "transformer.h.4.ln_1.bias", "transformer.h.4.attn.bias", "transformer.h.4.attn.masked_bias", "transformer.h.4.attn.c_attn.weight", "transformer.h.4.attn.c_attn.bias", "transformer.h.4.attn.c_proj.weight", "transformer.h.4.attn.c_proj.bias", "transformer.h.4.ln_2.weight", "transformer.h.4.ln_2.bias", "transformer.h.4.crossattention.bias", "transformer.h.4.crossattention.masked_bias", "transformer.h.4.crossattention.c_attn.weight", "transformer.h.4.crossattention.c_attn.bias", "transformer.h.4.crossattention.q_attn.weight", "transformer.h.4.crossattention.q_attn.bias", "transformer.h.4.crossattention.c_proj.weight", "transformer.h.4.crossattention.c_proj.bias", "transformer.h.4.ln_cross_attn.weight", "transformer.h.4.ln_cross_attn.bias", "transformer.h.4.mlp.c_fc.weight", "transformer.h.4.mlp.c_fc.bias", "transformer.h.4.mlp.c_proj.weight", "transformer.h.4.mlp.c_proj.bias", "transformer.h.5.ln_1.weight", "transformer.h.5.ln_1.bias", "transformer.h.5.attn.bias", "transformer.h.5.attn.masked_bias", "transformer.h.5.attn.c_attn.weight", "transformer.h.5.attn.c_attn.bias", "transformer.h.5.attn.c_proj.weight", "transformer.h.5.attn.c_proj.bias", "transformer.h.5.ln_2.weight", "transformer.h.5.ln_2.bias", "transformer.h.5.crossattention.bias", "transformer.h.5.crossattention.masked_bias", "transformer.h.5.crossattention.c_attn.weight", "transformer.h.5.crossattention.c_attn.bias", "transformer.h.5.crossattention.q_attn.weight", "transformer.h.5.crossattention.q_attn.bias", "transformer.h.5.crossattention.c_proj.weight", "transformer.h.5.crossattention.c_proj.bias", "transformer.h.5.ln_cross_attn.weight", "transformer.h.5.ln_cross_attn.bias", "transformer.h.5.mlp.c_fc.weight", "transformer.h.5.mlp.c_fc.bias", "transformer.h.5.mlp.c_proj.weight", "transformer.h.5.mlp.c_proj.bias", "transformer.h.6.ln_1.weight", "transformer.h.6.ln_1.bias", "transformer.h.6.attn.bias", "transformer.h.6.attn.masked_bias", "transformer.h.6.attn.c_attn.weight", "transformer.h.6.attn.c_attn.bias", "transformer.h.6.attn.c_proj.weight", "transformer.h.6.attn.c_proj.bias", "transformer.h.6.ln_2.weight", "transformer.h.6.ln_2.bias", "transformer.h.6.crossattention.bias", "transformer.h.6.crossattention.masked_bias", "transformer.h.6.crossattention.c_attn.weight", "transformer.h.6.crossattention.c_attn.bias", "transformer.h.6.crossattention.q_attn.weight", "transformer.h.6.crossattention.q_attn.bias", "transformer.h.6.crossattention.c_proj.weight", "transformer.h.6.crossattention.c_proj.bias", "transformer.h.6.ln_cross_attn.weight", "transformer.h.6.ln_cross_attn.bias", "transformer.h.6.mlp.c_fc.weight", "transformer.h.6.mlp.c_fc.bias", "transformer.h.6.mlp.c_proj.weight", "transformer.h.6.mlp.c_proj.bias", "transformer.h.7.ln_1.weight", "transformer.h.7.ln_1.bias", "transformer.h.7.attn.bias", "transformer.h.7.attn.masked_bias", "transformer.h.7.attn.c_attn.weight", "transformer.h.7.attn.c_attn.bias", "transformer.h.7.attn.c_proj.weight", "transformer.h.7.attn.c_proj.bias", "transformer.h.7.ln_2.weight", "transformer.h.7.ln_2.bias", "transformer.h.7.crossattention.bias", "transformer.h.7.crossattention.masked_bias", "transformer.h.7.crossattention.c_attn.weight", "transformer.h.7.crossattention.c_attn.bias", "transformer.h.7.crossattention.q_attn.weight", "transformer.h.7.crossattention.q_attn.bias", "transformer.h.7.crossattention.c_proj.weight", "transformer.h.7.crossattention.c_proj.bias", "transformer.h.7.ln_cross_attn.weight", "transformer.h.7.ln_cross_attn.bias", "transformer.h.7.mlp.c_fc.weight", "transformer.h.7.mlp.c_fc.bias", "transformer.h.7.mlp.c_proj.weight", "transformer.h.7.mlp.c_proj.bias", "transformer.h.8.ln_1.weight", "transformer.h.8.ln_1.bias", "transformer.h.8.attn.bias", "transformer.h.8.attn.masked_bias", "transformer.h.8.attn.c_attn.weight", "transformer.h.8.attn.c_attn.bias", "transformer.h.8.attn.c_proj.weight", "transformer.h.8.attn.c_proj.bias", "transformer.h.8.ln_2.weight", "transformer.h.8.ln_2.bias", "transformer.h.8.crossattention.bias", "transformer.h.8.crossattention.masked_bias", "transformer.h.8.crossattention.c_attn.weight", "transformer.h.8.crossattention.c_attn.bias", "transformer.h.8.crossattention.q_attn.weight", "transformer.h.8.crossattention.q_attn.bias", "transformer.h.8.crossattention.c_proj.weight", "transformer.h.8.crossattention.c_proj.bias", "transformer.h.8.ln_cross_attn.weight", "transformer.h.8.ln_cross_attn.bias", "transformer.h.8.mlp.c_fc.weight", "transformer.h.8.mlp.c_fc.bias", "transformer.h.8.mlp.c_proj.weight", "transformer.h.8.mlp.c_proj.bias", "transformer.h.9.ln_1.weight", "transformer.h.9.ln_1.bias", "transformer.h.9.attn.bias", "transformer.h.9.attn.masked_bias", "transformer.h.9.attn.c_attn.weight", "transformer.h.9.attn.c_attn.bias", "transformer.h.9.attn.c_proj.weight", "transformer.h.9.attn.c_proj.bias", "transformer.h.9.ln_2.weight", "transformer.h.9.ln_2.bias", "transformer.h.9.crossattention.bias", "transformer.h.9.crossattention.masked_bias", "transformer.h.9.crossattention.c_attn.weight", "transformer.h.9.crossattention.c_attn.bias", "transformer.h.9.crossattention.q_attn.weight", "transformer.h.9.crossattention.q_attn.bias", "transformer.h.9.crossattention.c_proj.weight", "transformer.h.9.crossattention.c_proj.bias", "transformer.h.9.ln_cross_attn.weight", "transformer.h.9.ln_cross_attn.bias", "transformer.h.9.mlp.c_fc.weight", "transformer.h.9.mlp.c_fc.bias", "transformer.h.9.mlp.c_proj.weight", "transformer.h.9.mlp.c_proj.bias", "transformer.h.10.ln_1.weight", "transformer.h.10.ln_1.bias", "transformer.h.10.attn.bias", "transformer.h.10.attn.masked_bias", "transformer.h.10.attn.c_attn.weight", "transformer.h.10.attn.c_attn.bias", "transformer.h.10.attn.c_proj.weight", "transformer.h.10.attn.c_proj.bias", "transformer.h.10.ln_2.weight", "transformer.h.10.ln_2.bias", "transformer.h.10.crossattention.bias", "transformer.h.10.crossattention.masked_bias", "transformer.h.10.crossattention.c_attn.weight", "transformer.h.10.crossattention.c_attn.bias", "transformer.h.10.crossattention.q_attn.weight", "transformer.h.10.crossattention.q_attn.bias", "transformer.h.10.crossattention.c_proj.weight", "transformer.h.10.crossattention.c_proj.bias", "transformer.h.10.ln_cross_attn.weight", "transformer.h.10.ln_cross_attn.bias", "transformer.h.10.mlp.c_fc.weight", "transformer.h.10.mlp.c_fc.bias", "transformer.h.10.mlp.c_proj.weight", "transformer.h.10.mlp.c_proj.bias", "transformer.h.11.ln_1.weight", "transformer.h.11.ln_1.bias", "transformer.h.11.attn.bias", "transformer.h.11.attn.masked_bias", "transformer.h.11.attn.c_attn.weight", "transformer.h.11.attn.c_attn.bias", "transformer.h.11.attn.c_proj.weight", "transformer.h.11.attn.c_proj.bias", "transformer.h.11.ln_2.weight", "transformer.h.11.ln_2.bias", "transformer.h.11.crossattention.bias", "transformer.h.11.crossattention.masked_bias", "transformer.h.11.crossattention.c_attn.weight", "transformer.h.11.crossattention.c_attn.bias", "transformer.h.11.crossattention.q_attn.weight", "transformer.h.11.crossattention.q_attn.bias", "transformer.h.11.crossattention.c_proj.weight", "transformer.h.11.crossattention.c_proj.bias", "transformer.h.11.ln_cross_attn.weight", "transformer.h.11.ln_cross_attn.bias", "transformer.h.11.mlp.c_fc.weight", "transformer.h.11.mlp.c_fc.bias", "transformer.h.11.mlp.c_proj.weight", "transformer.h.11.mlp.c_proj.bias", "transformer.ln_f.weight", "transformer.ln_f.bias", "lm_head.weight". 

In [85]:
torch.save(model.state_dict(), 'model/gpt2.pt')


In [91]:
GPT2Config()
model.config.bos_token_id

50256

In [95]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (crossattention): GPT2Attention(
          (c_attn): Conv1D()
          (q_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_cross_attn): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace