In [1]:
import torch 
import json
from tqdm import tqdm

In [2]:
with open("train.json", "r") as f:
    train_data = f.readlines()

In [3]:
import ast

train_data = list(map(ast.literal_eval, train_data))

In [4]:
train_dic = dict()
unique_tokens = set()

for row in tqdm(train_data):
    train_dic[row[1]] = {
        "data": row[5:],
        "labels": row[2:5]
    }

    unique_tokens = unique_tokens.union(set(row[5:]))

100%|██████████| 345302/345302 [02:11<00:00, 2626.66it/s]


In [7]:
len(unique_tokens), len(train_dic)

(30398, 345302)

In [9]:
import numpy as np
import hashlib
import os

class Tokenizer():
    def __init__(self,
                 data_dir: str,
                 special_tokens: dict,
                 min_freq: int=5):
        """ Initialize word-level tokenizer.
            
        Parameters
        ----------
            special_tokens (list of str)
                tokens used in the model for unknown words, padding, masking,
                starting/closing sentences, etc.
            min_freq (int)
                minimum frequency at which a word should occur in the corpus to have
                its own token_id (else: token_id of '[UNK]')
                
        """
        self.data_dir = data_dir
        self.encoder = dict(special_tokens)
        self.special_tokens = dict(special_tokens)
        self.min_freq = min_freq
        self.unique_id = self.create_unique_id()
        self.path = os.path.join(self.data_dir, 'tokenizer', self.unique_id)
        
    def create_unique_id(self):
        unique_str = str(vars(self))
        return hashlib.sha256(unique_str.encode()).hexdigest()

    def fit(self, words: list):        
        # Compute and sort vocabulary
        word_vocab, word_counts = np.unique(words, return_counts=True)
        if self.min_freq > 0:  # remove rare words
            word_vocab = word_vocab[word_counts >= self.min_freq]
            word_counts = word_counts[word_counts >= self.min_freq]
        inds = word_counts.argsort()[::-1]
        word_vocab = word_vocab[inds]
        
        # Generate word level encoder
        self.encoder.update({i: (idx + len(self.special_tokens))
                             for idx, i in enumerate(word_vocab)})
        
        # Store word count for every word (useful for skipgram dataset)
        self.word_counts = {self.encoder[word]: count for word, count in \
                            zip(word_vocab, sorted(word_counts)[::-1])}
        self.word_counts.update({k: 1 for k in self.special_tokens.values()})
        
        # Build decoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        
        # Store and print tokenizer vocabulary information
        self.vocab_sizes = {'total': len(self.encoder),
                            'special': len(self.special_tokens),
                            'word': len(word_vocab)}
        
    def encode(self, word: str):
        try:
            return self.encoder[word]
        except:
            return self.encoder['[UNK]']
    
    def decode(self, token_id: int):
        return self.decoder[token_id]
    
    def get_vocab(self):
        return self.encoder.keys()

In [52]:
tok = Tokenizer("tok", {"[PAD]":0, "[UNK]":1, "[MSK]":2, "[CLS]":3})

raw_tokens = [ehr for e in tqdm(train_dic.values()) for ehr in list(e.values())[0]]
tok.fit(raw_tokens)
len(tok.encoder)

100%|██████████| 345302/345302 [00:00<00:00, 363632.37it/s]


15606

In [53]:
# list(train_dic.values())[0]['data']

In [690]:
from torch.utils.data import Dataset
from torch.distributions.categorical import Categorical

class EHRDataset(Dataset):
    def __init__(self):
        self.adm_index = list(train_dic.keys())
        self.labels = [torch.tensor([1]) if e['labels'][0] == "LBL_ALIVE" else torch.tensor([0])
                       for e in train_dic.values()]
        self.tokens = [["[CLS]"] + e['data'] for e in train_dic.values()]
        # self.tokens = [e['data'] for e in train_dic.values()]
        self.tokens = [torch.tensor(list(map(tok.encode, record))) for record in self.tokens]

        self.mlm_ratio = 0.15

    def __len__(self):
        return len(self.adm_index)
    
    def prob_mask_like(self, inputs, mask_ratio=0.15):
        return torch.zeros_like(inputs).float().uniform_(0, 1) < mask_ratio

    def __getitem__(self, index):
        tokens = self.tokens[index]
        # labels = self.labels[index]

        bert_mask = torch.zeros_like(tokens)
        #zero if we let the token unchanged
        #one if we mask it
        #two if we replace it with a random token

        total_mask = self.prob_mask_like(bert_mask)
        
        if total_mask[total_mask].numel() > 0:
            sample = Categorical(torch.tensor([0.1, 0.8, 0.1])).sample(total_mask[total_mask].shape)

            bert_mask[total_mask] = sample

            tokens_with_mask = torch.where(bert_mask==1, tok.encoder['[MSK]'] * torch.ones_like(tokens), tokens)
            tokens_with_mask = torch.where(bert_mask==2, torch.randint_like(tokens_with_mask, low=4, high=len(tok.encoder)), tokens_with_mask)
        else:
            tokens_with_mask = tokens

        sample = {
            "tokens": tokens_with_mask,
            # "target": labels
            "target_tokens": tokens#[bert_mask==1]
        }
        return sample

In [691]:
ds = EHRDataset()

In [692]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

max_len = 150

def collate_fn(batch):
    tokens = [item['tokens'][:max_len] for item in batch]#.copy()
    tokens = pad_sequence(tokens, batch_first=True, padding_value=0)

    # labels = [item['target'] for item in batch]
    # labels = torch.stack(labels)

    target_tokens = [item['target_tokens'][:max_len] for item in batch]
    target_tokens = pad_sequence(target_tokens, batch_first=True, padding_value=0)

    return {
        "tokens": tokens, 
        # "target": labels
        "target_tokens": target_tokens
        }

In [693]:
import torch.nn as nn
import math

class StayEmbedding(nn.Module):

    def __init__(self, d_model, max_len=max_len, tokenizer_codes=None, device='cpu'):
        super().__init__()

        # d_model = d_model//2

        #compute time encoding
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        pe = torch.squeeze(pe)
        self.time_encoding = pe.to(device)

        #compute code embedding
        self.embedding = nn.Embedding(1+len(tokenizer_codes), d_model, padding_idx=0, device=device)

    def forward(self, codes):
        batch_size, n_timesteps = codes.shape
        
        code_embedding = self.embedding(codes)

        time_embedding = self.time_encoding.repeat([batch_size, 1, 1])
        time_embedding = time_embedding[:, :n_timesteps, :]

        x = code_embedding + time_embedding
        return x

In [694]:
class BERT_MLM(nn.Module):

    def __init__(self, d_embedding, d_model, dropout=0.1, n_layers=2, nhead=4, tokenizer_codes=None, device='cpu'):
        super().__init__()

        #embedding
        self.embedding = StayEmbedding(d_embedding, tokenizer_codes=tokenizer_codes, device=device)

        #projection
        self.proj = nn.Linear(d_embedding, d_model).to(device)

        #transformer embedding
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=2*d_model, dropout=dropout, batch_first=True, device=device)
        self.bert = nn.TransformerEncoder(layer, n_layers).to(device)

        #classification
        self.cls = nn.Linear(d_model, len(tok.encoder), bias=True).to(device)

    def forward(self, tokens):
        x = self.embedding(tokens)

        x = nn.ReLU()(self.proj(x))

        x = self.bert(x)
        x = nn.ReLU()(x)

        x = self.cls(x)
        return x

In [695]:
d_model = 64
device = 'cpu'
device = 'mps'
n_layers = 2

In [696]:
model = BERT_MLM(d_model, 2*d_model, dropout=0.2, tokenizer_codes=tok.encoder, device=device)

In [697]:
test = next(iter(DataLoader(dataset=ds, batch_size=64, drop_last=False, collate_fn=collate_fn)))

tokens, target = test['tokens'], test['target_tokens']
tokens.shape, target.shape

(torch.Size([64, 150]), torch.Size([64, 150]))

In [700]:
model(tokens)

RuntimeError: Placeholder storage has not been allocated on MPS device!

In [656]:
predicted_tokens = model(tokens)
predicted_tokens.shape

torch.Size([64, 150, 15606])

In [682]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3
    )

epochs = 100

In [684]:
len(ds)/64

5395.34375

In [685]:
for epoch in range(epochs):
    epoch_loss = []
    epoch_acc = []

    model.train()
    for batch in tqdm(DataLoader(dataset=ds, batch_size=64, drop_last=False, collate_fn=collate_fn)):
        optimizer.zero_grad()
        
        tokens = batch["tokens"]#.to(device)
        targets = batch["target_tokens"]#.to(device)
        
        predicted_tokens = model(tokens)

        was_mlm_token = tokens == 2
        mlm_predicted_tokens = predicted_tokens[was_mlm_token]
        target_tokens = target[was_mlm_token]

        loss = criterion(mlm_predicted_tokens, target_tokens)
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.detach().item())
        epoch_acc.append((target_tokens == mlm_predicted_tokens.argmax(dim=1)).detach().numpy().mean().item())

    epoch_loss = np.array(epoch_loss).mean()
    epoch_acc = np.array(epoch_acc).mean()

    print(f"Epoch {1+epoch}:", "train: loss {:.2f}; accuracy {:.2f}".format(epoch_loss, epoch_acc))

  6%|▌         | 328/5396 [03:32<54:43,  1.54it/s]  


KeyboardInterrupt: 