In [None]:
# Show the GPU attached to this runtime (model, driver/CUDA version, memory in use).
!nvidia-smi

Tue May 18 16:11:43 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   66C    P8    12W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!pip install -q wandb
!pip install -q transformers

[K     |████████████████████████████████| 1.8MB 23.0MB/s 
[K     |████████████████████████████████| 174kB 54.6MB/s 
[K     |████████████████████████████████| 133kB 54.3MB/s 
[K     |████████████████████████████████| 102kB 14.1MB/s 
[K     |████████████████████████████████| 71kB 10.7MB/s 
[?25h  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 2.3MB 24.0MB/s 
[K     |████████████████████████████████| 901kB 36.0MB/s 
[K     |████████████████████████████████| 3.3MB 47.8MB/s 
[?25h

In [None]:
import wandb
import torch
import sklearn
import numpy as np
import pandas as pd
from torch import nn
from sklearn import metrics
from tqdm.notebook import tqdm
from transformers import AdamW
from tokenizers import Tokenizer
from torch.nn import functional as F
from transformers import BertConfig, BertModel
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from transformers import get_linear_schedule_with_warmup
from sklearn.utils.class_weight import compute_class_weight

In [None]:
class BERTBPETokenizer:
    """Thin wrapper that makes a `tokenizers.Tokenizer` return PyTorch tensors.

    Args:
        tokenizer: Tokenizer object to wrap. It must provide `token_to_id`,
            `enable_padding`, `enable_truncation`, `encode` and `encode_batch`.
        padding: When true, padding with the '[PAD]' token is enabled on the
            wrapped tokenizer (the id is looked up with `token_to_id`).
        truncation: When true, truncation to `max_length` is enabled.
        max_length: Maximum sequence length used when truncation is enabled.
    """

    def __init__(self, tokenizer, padding=True, truncation=True, max_length=512):
        self.tokenizer = tokenizer
        # Honour the flags instead of unconditionally switching both features on.
        if padding:
            self.tokenizer.enable_padding(
                pad_id=tokenizer.token_to_id('[PAD]'),
                pad_type_id=0, pad_token='[PAD]')
        if truncation:
            self.tokenizer.enable_truncation(max_length)

    @classmethod
    def from_pretrained(cls, path, padding=True, truncation=True, max_length=512):
        """Build the wrapper from a tokenizer file readable by `Tokenizer.from_file`."""
        tokenizer = Tokenizer.from_file(path)
        return cls(tokenizer, padding=padding, truncation=truncation, max_length=max_length)

    def __call__(self, data):
        """Tokenize `data` and return `(input_ids, segment_ids, attention_mask, tokens)`.

        Args:
            data: One of
                * `str`   - a single sequence,
                * `tuple` - a sequence pair; only the first two items are used,
                * `list`  - a batch, forwarded unchanged to `encode_batch`.

        Returns:
            Three tensors (ids, type ids, attention mask) and the token strings.
            For a `list` input every return value has a leading batch dimension.

        Raises:
            TypeError: If `data` is not a `str`, `tuple` or `list`.
            ValueError: Propagated from `torch.tensor` when a batch contains
                encodings of different lengths (possible if padding is off).
        """
        if isinstance(data, str):
            encodings, batched = [self.tokenizer.encode(data)], False
        elif isinstance(data, tuple):
            encodings, batched = [self.tokenizer.encode(data[0], data[1])], False
        elif isinstance(data, list):
            encodings, batched = self.tokenizer.encode_batch(data), True
        else:
            # Previously this fell through and failed with an UnboundLocalError.
            raise TypeError(
                f"data must be a str, tuple or list, got {type(data).__name__}")

        input_ids = [enc.ids for enc in encodings]
        tokens = [enc.tokens for enc in encodings]
        attention_mask = [enc.attention_mask for enc in encodings]
        segment_id = [enc.type_ids for enc in encodings]

        if not batched:
            # Single inputs are returned without a batch dimension.
            input_ids, tokens = input_ids[0], tokens[0]
            attention_mask, segment_id = attention_mask[0], segment_id[0]

        return torch.tensor(input_ids), torch.tensor(segment_id), torch.tensor(attention_mask), tokens


class PatDataset(Dataset):
    """Dataset of (dx/px sentence pair, rx feature row, label) samples.

    Args:
        dx_px: Indexable collection of text inputs; this notebook passes
            `(dx_journey, px_journey)` tuples.
        rx: Indexable collection of numeric feature rows, one per sample.
        y: Indexable collection of labels; also defines the dataset length.
        tokenizer: Callable applied to the list of text inputs of a batch. The
            first three items of its return value are used as model inputs.
    """

    def __init__(self, dx_px, rx, y, tokenizer):
        self.dx_px = dx_px
        self.rx = rx
        self.y = y
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.dx_px[idx], self.rx[idx], self.y[idx]

    def dynamic_batching(self, batch):
        """Collate function: tokenize the whole batch at once.

        Args:
            batch: List of samples as produced by `__getitem__`.

        Returns:
            `(x0, x1, x2, rx_data, y)` where `x0..x2` are the first three
            outputs of the tokenizer, `rx_data` is a float32 tensor stacking
            the rx rows, and `y` is an int64 label tensor.
        """
        # Split the samples column-wise. Packing them into a single ndarray
        # would require a ragged object array, which recent NumPy refuses to build.
        sent_pairs, rx_rows, labels = zip(*batch)
        rx_data = torch.from_numpy(np.stack(rx_rows)).type(torch.float32)
        x = self.tokenizer(list(sent_pairs))
        y = torch.from_numpy(np.asarray(labels).astype(np.int64))
        return x[0], x[1], x[2], rx_data, y


class MyModel(nn.Module):
    """BERT encoder over the dx/px text fused with a dense rx feature vector.

    The defaults reproduce the original hard-coded architecture, and submodule
    names and ordering are unchanged, so existing state dicts still load.

    Args:
        bert_path: Path or model id passed to `BertModel.from_pretrained`.
        rx_dim: Width of the rx feature vector.
        rx_hidden: Output width of the rx projection.
        fusion_hidden: Width of the hidden layer applied to the concatenation.
        n_classes: Number of output logits.
        dropout: Dropout probability inside the classification head.
    """

    def __init__(self,
                 bert_path='/content/drive/MyDrive/ColabData/saved_models/PatientBERT/mimic-3-bert-base',
                 rx_dim=806, rx_hidden=384, fusion_hidden=384, n_classes=2, dropout=0.1):
        super().__init__()
        self.bert_base = BertModel.from_pretrained(bert_path)
        # Take the encoder width from the loaded config instead of assuming 768.
        bert_hidden = self.bert_base.config.hidden_size
        self.dx_px_seq = nn.Sequential(
            nn.LayerNorm(bert_hidden, eps=1e-12, elementwise_affine=True)
            )
        self.rx_seq = nn.Sequential(
            nn.Linear(rx_dim, rx_hidden),
            nn.LayerNorm(rx_hidden, eps=1e-12, elementwise_affine=True)
            )
        self.concat_seq = nn.Sequential(
            nn.Linear(bert_hidden + rx_hidden, fusion_hidden),
            nn.Dropout(p=dropout),
            nn.Linear(fusion_hidden, n_classes),
            )

    def forward(self, tkn_ids, sent_ids, attn_mask, rx_embed):
        """Return raw class logits.

        Args:
            tkn_ids: Token ids, passed to BERT as `input_ids`.
            sent_ids: Segment ids, passed to BERT as `token_type_ids`.
            attn_mask: Attention mask, passed to BERT as `attention_mask`.
            rx_embed: rx feature tensor fed to the rx projection.
        """
        # First BERT output averaged over axis 1.
        # NOTE(review): the mean is taken over every position, so padded
        # positions contribute too; `attn_mask` is not applied to the pooling.
        bert_features = self.bert_base(input_ids=tkn_ids,
                                       attention_mask=attn_mask,
                                       token_type_ids=sent_ids)[0].mean(axis=1)
        dx_px_features = self.dx_px_seq(bert_features)
        rx_features = self.rx_seq(rx_embed)
        concat = torch.cat((dx_px_features, rx_features), dim=1)
        return self.concat_seq(concat)


class Trainer:
    """Training loop with validation ROC-AUC tracking, checkpointing and early stopping.

    Args:
        train_data: Iterable with a length (e.g. a DataLoader) yielding batches.
            Each batch is a sequence of tensors whose last item is the label
            tensor; the preceding items are passed positionally to the model.
        val_data: Same structure as `train_data`, used for validation.
    """

    def __init__(self, train_data, val_data):
        self.train_data = train_data
        self.val_data = val_data
        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print("Using: ", self.dev)

    @staticmethod
    def _roc_auc(y_true, y_raw_logits, class_idx=1):
        """ROC-AUC of the softmax probability of `class_idx` against `y_true`."""
        y_score = F.softmax(y_raw_logits, dim=1).cpu().detach().numpy()[:, class_idx]
        return metrics.roc_auc_score(y_true.cpu().detach().numpy(), y_score)

    @staticmethod
    def _accuracy(true, pred_proba, class_idx=None):
        """Accuracy of the argmax prediction, optionally restricted to one true class."""
        preds = torch.argmax(pred_proba, dim=1)
        if class_idx is None:
            score = (true == preds).float().mean()
        else:
            score = (true == preds)[true == class_idx].float().mean()
        return score

    @staticmethod
    def _get_optimizer_with_decay(model, lr):
        """AdamW with weight decay 0.01, except for parameters matching `no_decay`.

        A parameter is exempt from decay when its name contains 'bias' or
        'LayerNorm.weight'.
        """
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01, 'lr': lr},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0, 'lr': lr}
             ]
        return AdamW(optimizer_grouped_parameters, lr=lr)

    def _get_loss_func(self, class_wts=None):
        """Cross-entropy loss, weighted per class when `class_wts` is given."""
        if class_wts is not None:
            loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_wts).type(torch.float).to(self.dev))
        else:
            loss_fn = torch.nn.CrossEntropyLoss()
        return loss_fn

    @staticmethod
    def get_class_wts(y):
        """'balanced' class weights for the labels `y`, ordered like `np.unique(y)`."""
        # Keyword arguments: newer scikit-learn releases reject positional ones here.
        return compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y)

    @staticmethod
    def _unwrap(model):
        """Return the underlying module when `model` is wrapped in DataParallel."""
        return model.module if isinstance(model, torch.nn.DataParallel) else model

    def train_model(self, model, n_epochs=1, lr=1e-5, class_wts=None, path_to_save_model='./model.tar'
                , logging_step=5, patience=3, max_val_accuracy=0, use_scheduler=True, wandb_project=None):
        """Train `model`, checkpointing whenever the validation ROC-AUC improves.

        Args:
            model: Module called as `model(*batch[:-1])` and returning logits.
            n_epochs: Maximum number of epochs.
            lr: Learning rate for AdamW.
            class_wts: Optional per-class weights for the cross-entropy loss.
            path_to_save_model: File the best checkpoint is written to.
            logging_step: Log the running training loss every this many batches.
            patience: Stop after this many consecutive epochs without a new best
                validation ROC-AUC.
            max_val_accuracy: Initial "best" ROC-AUC that must be exceeded before
                a checkpoint is written.
            use_scheduler: Use a linear schedule with warmup over 34% of the steps.
            wandb_project: Weights & Biases project name; None disables all
                wandb calls and prints the training loss instead.

        Returns:
            `(model.eval(), optimizer)`. The best weights are reloaded only when
            early stopping triggers; if all epochs run, the model keeps the
            weights of the last epoch.
        """
        use_wandb = wandb_project is not None
        if use_wandb:
            wandb.init(project=wandb_project, reinit=True)

        loss_fn = self._get_loss_func(class_wts=class_wts)
        optimizer = Trainer._get_optimizer_with_decay(model, lr)
        linear_scheduler = None
        if use_scheduler:
            total_steps = len(self.train_data)*n_epochs
            warmup_steps = int(total_steps*0.34)
            print(f"Using linear lr scheduler with {warmup_steps} warmup steps and {total_steps} total steps")
            linear_scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

        model.to(self.dev)
        trn_loss = []
        val_loss = []
        val_auc = []
        max_val_accu = max_val_accuracy
        count = 0  # consecutive epochs without a new best validation ROC-AUC
        for epoch in range(n_epochs):
            print(f"Epoch: {epoch}")
            trn_loss_per_epoch = []
            model.train()
            for i, batch in enumerate(tqdm(self.train_data)):
                batch = [x.to(self.dev) for x in batch]
                outputs = model(*batch[:-1])
                loss = loss_fn(outputs, batch[-1])
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if use_scheduler:
                    linear_scheduler.step()
                    # Only talk to wandb when a run was actually started.
                    if use_wandb:
                        wandb.log({"lr": linear_scheduler.get_last_lr()[0]})
                del batch
                torch.cuda.empty_cache()
                trn_loss_per_epoch.append(float(loss))
                if i % logging_step == 0:
                    if use_wandb:
                        wandb.log({"train_loss": np.mean(trn_loss_per_epoch)})
                    else:
                        print("train_loss: ", np.mean(trn_loss_per_epoch))
            trn_loss.append(np.mean(trn_loss_per_epoch))

            model.eval()
            valid_loss_per_epoch = []
            yhat_lst = []
            y_lst = []
            with torch.no_grad():
                for batch in tqdm(self.val_data):
                    batch = [x.to(self.dev) for x in batch]
                    outputs = model(*batch[:-1])
                    valid_loss_per_epoch.append(float(loss_fn(outputs, batch[-1])))
                    yhat_lst.append(outputs)
                    y_lst.append(batch[-1])
                epoch_auc = Trainer._roc_auc(torch.cat(y_lst, dim=0), torch.cat(yhat_lst, dim=0))
            val_loss.append(np.mean(valid_loss_per_epoch))
            val_auc.append(epoch_auc)

            print(f"Training Loss for epoch {epoch}: ", trn_loss[-1])
            print("Validation Loss: ", val_loss[-1], "| Validation roc_auc: ", val_auc[-1])
            if use_wandb:
                wandb.log({"trn_loss_epoch": trn_loss[-1]})
                wandb.log({"val_loss_epoch": val_loss[-1]})
                wandb.log({"val_roc_auc_epoch": val_auc[-1]})

            if val_auc[-1] > max_val_accu:
                count = 0
                max_val_accu = val_auc[-1]
                # Always store the weights of the bare module so the checkpoint
                # has the same keys with or without DataParallel.
                torch.save({
                            'epoch': epoch,
                            'model_state_dict': Trainer._unwrap(model).state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            # Stored as plain floats so the file holds no NumPy scalars.
                            'training_loss': float(trn_loss[-1]),
                            'val_loss': float(val_loss[-1]),
                            'val_accuracy': float(val_auc[-1])
                            }, path_to_save_model)
                print(f"Model & optimizer state dictionaries saved")
            else:
                count = count + 1

            if count >= patience:
                print("Stopping early, restoring best weights..")
                model_parameters = torch.load(path_to_save_model, map_location=self.dev)
                # Load into the bare module: that is what the checkpoint holds.
                Trainer._unwrap(model).load_state_dict(model_parameters['model_state_dict'])
                print("Best weights loaded!")
                break
            print("Max roc_auc till now: ", max_val_accu)
        return model.eval(), optimizer

In [None]:
# Load the tokenizer definition from Drive; max_length=300 is the truncation length
# handed to the wrapped tokenizer.
tokenizer = BERTBPETokenizer.from_pretrained(
    "/content/drive/MyDrive/ColabData/saved_models/PatientBERT/bert-bpe-icd.json", max_length=300)

In [None]:
# Columns kept from the journey table.
cols = ['pat_id', 'px_journey', 'dx_journey']
dfy = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/target.pkl")
dfx = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/pat_journey_data.pkl")

# rx column name -> index mapping, its inverse, and the column names in index order.
cols_idx_map = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/rx_col_idx_map.pkl")
idx_col_map = {idx: col for col, idx in cols_idx_map.items()}
col_lst = [idx_col_map[i] for i in range(len(cols_idx_map))]

# Patients whose rx_count equals the all-zero vector.
a = [0]*len(cols_idx_map)
no_rx_pats = [p for r, p in zip(dfx.rx_count, dfx.pat_id) if r == a]

# Keep patients that have both journeys and a non-zero rx_count.
has_both_journeys = (dfx.px_journey != '') & (dfx.dx_journey != '')
has_rx = ~dfx['pat_id'].isin(no_rx_pats)
df = dfx.loc[has_both_journeys & has_rx, cols]
targets = dfy[dfy.pat_id.isin(df.pat_id)]
# Features and targets must be in the same patient order before splitting them together.
assert df.pat_id.tolist() == targets.pat_id.tolist()
del dfx, dfy

# 80/20 train+val/test split, then 85/15 train/val, both stratified on switch_flag.
X1, X_test, y1, y_test = train_test_split(
    df, targets, test_size=0.20, random_state=42,
    stratify=targets['switch_flag'].tolist())
X_train, X_val, y_train, y_val = train_test_split(
    X1, y1, test_size=0.15, random_state=42,
    stratify=y1['switch_flag'].tolist())

In [None]:
# Pre-computed rx feature tables, one per split.
# NOTE(review): these are later paired row-by-row with X_train / X_val / X_test;
# nothing in this notebook checks that the rows are in the same order — confirm.
rx_trn = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/imp_rx_trn.pkl")
rx_val = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/imp_rx_val.pkl")
rx_test = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/imp_rx_test.pkl")

In [None]:
# Training split: (dx_journey, px_journey) pairs + rx features + switch_flag labels.
# NOTE(review): `shuffle` is left at its default, so batches come in the same
# order every epoch.
dx_px_lst = list(zip(X_train['dx_journey'].tolist(),
                     X_train['px_journey'].tolist()))
trn_ds = PatDataset(dx_px_lst, rx_trn.values, y_train['switch_flag'].tolist(), tokenizer)
trn_dl = DataLoader(
    trn_ds,
    batch_size=32,
    collate_fn=trn_ds.dynamic_batching,
    num_workers=0,
    pin_memory=False,
)

In [None]:
# Validation split: (dx_journey, px_journey) pairs + rx features + switch_flag labels.
dx_px_lst = list(zip(X_val['dx_journey'].tolist(), X_val['px_journey'].tolist()))
val_ds = PatDataset(dx_px_lst, rx_val.values, y_val['switch_flag'].tolist(), tokenizer)
# Collate with this dataset's own method rather than reaching into the training dataset.
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=32, collate_fn=val_ds.dynamic_batching, num_workers=0, pin_memory=False)

In [None]:
# Test split: (dx_journey, px_journey) pairs + rx features + switch_flag labels.
dx_px_lst = list(zip(X_test['dx_journey'].tolist(), X_test['px_journey'].tolist()))
test_ds = PatDataset(dx_px_lst, rx_test.values, y_test['switch_flag'].tolist(), tokenizer)
# Collate with this dataset's own method rather than reaching into the training dataset.
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=32, collate_fn=test_ds.dynamic_batching, num_workers=0, pin_memory=False)

In [None]:
# Build the classifier; this loads the pre-trained BERT weights from Drive.
model = MyModel()

Some weights of the model checkpoint at /content/drive/MyDrive/ColabData/saved_models/PatientBERT/mimic-3-bert-base were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Trainer over the train/validation loaders, plus 'balanced' class weights computed
# from the training labels (passed to the loss in the next cell).
trainer = Trainer(trn_dl, val_dl)
class_wts = trainer.get_class_wts(y_train['switch_flag'].tolist())
class_wts

Using:  cuda


array([ 0.51734004, 14.9175    ])

In [None]:
# Fine-tune for up to 20 epochs at a constant learning rate (no scheduler), logging to
# wandb every step; training stops after 3 consecutive epochs without a new best
# validation ROC-AUC.
model, optimizer = trainer.train_model(model, n_epochs=20, lr=2e-5, class_wts=class_wts,
                                       path_to_save_model='./model.tar',
                                       logging_step=1, patience=3, use_scheduler=False,
                                       wandb_project="sepsis-readmission-full-bert-without-static")

[34m[1mwandb[0m: Currently logged in as: [33mmeet14[0m (use `wandb login --relogin` to force relogin)


Epoch: 0


HBox(children=(FloatProgress(value=0.0, max=746.0), HTML(value='')))






HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))


Training Loss for epoch 0:  0.6717811963014085
Validation Loss:  0.6142274229363962 | Validation roc_auc:  0.6777570464677506
Model & optimizer state dictionaries saved
Max roc_auc till now:  0.6777570464677506
Epoch: 1


HBox(children=(FloatProgress(value=0.0, max=746.0), HTML(value='')))






HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))


Training Loss for epoch 1:  0.6116044130825485
Validation Loss:  0.6148574596101587 | Validation roc_auc:  0.689274247357629
Model & optimizer state dictionaries saved
Max roc_auc till now:  0.689274247357629
Epoch: 2


HBox(children=(FloatProgress(value=0.0, max=746.0), HTML(value='')))






HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))


Training Loss for epoch 2:  0.5675785834364213
Validation Loss:  0.6342719982970845 | Validation roc_auc:  0.6859433007381391
Max roc_auc till now:  0.689274247357629
Epoch: 3


HBox(children=(FloatProgress(value=0.0, max=746.0), HTML(value='')))






HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))


Training Loss for epoch 3:  0.5150040463911464
Validation Loss:  0.7061879284905664 | Validation roc_auc:  0.6631702179923381
Max roc_auc till now:  0.689274247357629
Epoch: 4


HBox(children=(FloatProgress(value=0.0, max=746.0), HTML(value='')))






HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))


Training Loss for epoch 4:  0.4701138379966126
Validation Loss:  0.7471615910304316 | Validation roc_auc:  0.6589786606876871
Stopping early, restoring best weights..
Best weights loaded!


In [None]:
# Persist only the model weights (no optimizer state) to Drive.
torch.save(
    {'model_state_dict': model.state_dict()},
    "/content/drive/MyDrive/ColabData/saved_models/PatientBERT/full_bert/sepsis-readmission-full.pt"
    )

In [None]:
# Rebuild the architecture and load the fine-tuned weights saved on Drive.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MyModel().to(dev)
model_parameters = torch.load(
    "/content/drive/MyDrive/ColabData/saved_models/PatientBERT/full_bert/sepsis-readmission-full.pt",
    map_location=dev,
)
model.load_state_dict(model_parameters['model_state_dict'])
# Switch to inference mode (assignment keeps the cell from echoing the model repr).
model = model.eval()

In [None]:
# ROC-AUC of the reloaded model on the training split.
dl = trn_dl
yhat_lst, y_lst = [], []
with torch.no_grad():
    for batch in tqdm(dl):
        batch = [x.to(dev) for x in batch]
        # Every item but the last is a model input; the last one holds the labels.
        yhat = model(*batch[:-1])
        yhat_lst.append(yhat)
        y_lst.extend(batch[-1].cpu().numpy())

print(Trainer._roc_auc(torch.tensor(y_lst), torch.cat(yhat_lst, dim=0)))

HBox(children=(FloatProgress(value=0.0, max=746.0), HTML(value='')))




0.7922010198109936


In [None]:
# ROC-AUC of the reloaded model on the validation split.
dl = val_dl
yhat_lst, y_lst = [], []
with torch.no_grad():
    for batch in tqdm(dl):
        batch = [x.to(dev) for x in batch]
        # Every item but the last is a model input; the last one holds the labels.
        yhat = model(*batch[:-1])
        yhat_lst.append(yhat)
        y_lst.extend(batch[-1].cpu().numpy())

print(Trainer._roc_auc(torch.tensor(y_lst), torch.cat(yhat_lst, dim=0)))

HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))




0.689274247357629


In [None]:
# ROC-AUC of the reloaded model on the held-out test split.
dl = test_dl
yhat_lst, y_lst = [], []
with torch.no_grad():
    for batch in tqdm(dl):
        batch = [x.to(dev) for x in batch]
        # Every item but the last is a model input; the last one holds the labels.
        yhat = model(*batch[:-1])
        yhat_lst.append(yhat)
        y_lst.extend(batch[-1].cpu().numpy())

print(Trainer._roc_auc(torch.tensor(y_lst), torch.cat(yhat_lst, dim=0)))

HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))




0.6933360510512865
