## Predicting patient risk of developing pancreatic cancer with a fine-tuned Med-BERT model

In [1]:
### Required Packages
from termcolor import colored
import math
from sklearn.model_selection import train_test_split
import pandas as pd
import random
import numpy as np
from datetime import datetime
import pickle as pkl
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch import optim
import tqdm
import time
import transformers

from sklearn.metrics import roc_auc_score  
from sklearn.metrics import roc_curve 

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
%matplotlib inline

use_cuda = torch.cuda.is_available()
import transformers
from transformers import BertForSequenceClassification


I0510 22:52:33.998448 140259224065792 file_utils.py:35] PyTorch version 1.2.0 available.


#### Data Preprocessing

In [2]:
### The version in this file is updated to dump pt_id for better visualization and results analysis

### Below are the key functions for data preparation, formatting input data into features, and model definition
class PaddingInputExample:
    """Placeholder example that pads the input to a multiple of the batch size.

    Following the original BERT code, a dedicated sentinel class is used rather
    than `None`, because treating `None` as a padding batch could hide real
    errors silently.
    """

class InputFeatures(object):
  """A single set of features of data.

  Attributes (all stored exactly as passed in):
    input_ids: token ids of the sequence.
    input_mask: mask aligned with `input_ids`.
    segment_ids: segment ids aligned with `input_ids`.
    label_id: label of the example.
    pt_id: patient identifier; defaults to `None` so that padding features,
      which have no patient, can be built without one.
    is_real_example: `False` for padding features, `True` otherwise.
  """

  def __init__(self,
               input_ids,
               input_mask,
               segment_ids,
               label_id,
               pt_id=None,
               is_real_example=True):
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.label_id = label_id
    self.is_real_example = is_real_example
    self.pt_id = pt_id
    
    
def convert_EHRexamples_to_features(examples,max_seq_length):
    """Build one feature record per example, preserving the input order.

    Every example is passed to `convert_singleEHR_example` together with its
    position in `examples` and `max_seq_length`; the results are collected
    into a list.
    """
    return [
        convert_singleEHR_example(position, record, max_seq_length)
        for position, record in enumerate(examples)
    ]

### This is the EHR version
def convert_singleEHR_example(ex_index, example, max_seq_length):
    """Turn one patient record into a fixed-length feature list.

    Args:
        ex_index: Position of the example in its dataset. Not used here; kept
            so the signature stays compatible with existing callers.
        example: Either a `PaddingInputExample`, or a sequence laid out as
            `[pt_id, label, input_ids, segment_ids]`.
        max_seq_length: Length of every sequence in the returned feature.

    Returns:
        A list `[input_ids, input_mask, segment_ids, label_id, pt_id,
        is_real_example]`. Sequences longer than `max_seq_length` keep only
        their last `max_seq_length` entries; shorter ones are right-padded
        with zeros. The mask is 1 for real tokens and 0 for padding. A
        `PaddingInputExample` yields an all-zero feature in the same layout,
        with `pt_id=None` and `is_real_example=False`.

    Raises:
        ValueError: If `max_seq_length` is negative.
        AssertionError: If the sequences do not end up with exactly
            `max_seq_length` entries (e.g. `segment_ids` and `input_ids`
            have different lengths).
    """
    if max_seq_length < 0:
        raise ValueError("max_seq_length must be non-negative, got %r" % (max_seq_length,))

    if isinstance(example, PaddingInputExample):
        # Same list layout as a real feature so downstream indexing
        # (feature[0] ... feature[4]) works; three independent zero lists.
        return [[0] * max_seq_length,
                [0] * max_seq_length,
                [0] * max_seq_length,
                0,
                None,
                False]

    pt_id = example[0]
    label_id = example[1]
    # Work on copies: padding below must not grow the caller's lists in place.
    input_ids = list(example[2])
    segment_ids = list(example[3])

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Left-truncate longer sequences, i.e. keep the most recent entries.
    # The explicit start index also handles max_seq_length == 0, where a
    # `[-0:]` slice would return the whole list.
    if len(input_ids) > max_seq_length:
        input_ids = input_ids[max(len(input_ids) - max_seq_length, 0):]
        input_mask = input_mask[max(len(input_mask) - max_seq_length, 0):]
        segment_ids = segment_ids[max(len(segment_ids) - max_seq_length, 0):]

    # Zero-pad up to the sequence length.
    pad_len = max_seq_length - len(input_ids)
    if pad_len > 0:
        input_ids.extend([0] * pad_len)
        input_mask.extend([0] * pad_len)
        segment_ids.extend([0] * pad_len)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    feature = [input_ids, input_mask, segment_ids, label_id, pt_id, True]
    return feature

### DataLoader

In [3]:
class BERTdataEHR(Dataset):
    """Thin `Dataset` wrapper around an already-built list of feature records."""

    def __init__(self, Features):
        # Keep a reference to the features as given; nothing is copied or transformed.
        self.data = Features

    def __len__(self):
        """Number of stored feature records."""
        return len(self.data)

    def __getitem__(self, idx, seeDescription = False):
        """Return the record stored at `idx`; `seeDescription` is accepted but not used."""
        return self.data[idx]

         
#customized parts for EHRdataloader
def my_collate(batch):
    """Regroup a batch of feature records field by field.

    Returns `[[input_ids, input_mask, segment_ids, label_ids], pt_ids]`, where
    each inner list holds the matching field (positions 0 to 4) of every
    record, in batch order. Any further fields of a record are ignored.
    """
    records = list(batch)

    def field(position):
        # Collect one field across all records, preserving their order.
        return [record[position] for record in records]

    model_inputs = [field(0), field(1), field(2), field(3)]
    return [model_inputs, field(4)]
            

class BERTdataEHRloader(DataLoader):
    """`DataLoader` for `BERTdataEHR` datasets that collates with `my_collate` by default.

    All constructor arguments have the same meaning as the corresponding
    `torch.utils.data.DataLoader` arguments and are forwarded to it unchanged.
    As with `DataLoader`, `batch_sampler` cannot be combined with a
    non-default `batch_size`, `shuffle`, `sampler` or `drop_last`.
    """

    def __init__(self, dataset, batch_size=128, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=my_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        # Forward every option; previously everything except `batch_size` was
        # replaced by a hard-coded default, so e.g. `shuffle=True` had no effect.
        DataLoader.__init__(self, dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
                            batch_sampler=batch_sampler, num_workers=num_workers,
                            collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last,
                            timeout=timeout, worker_init_fn=worker_init_fn)

 

##### Model Definition

In [4]:
class EHR_BERT_RNN(nn.Module):
    """Recurrent binary classifier over EHR code sequences.

    The input representation is selected by `emb`:
      * `'brt'`: token-level outputs of the BERT encoder loaded with
        `BertForSequenceClassification.from_pretrained("./")`;
      * any other string longer than 3 characters: embeddings loaded through
        `load_pretrain_w2vemb(emb, embed_dim)`, plus a learned segment embedding;
      * anything else: learned code and segment embeddings.

    The representation is fed to an RNN/GRU/LSTM. The final hidden state of
    the last layer (both directions when `bi` is true) goes through a linear
    layer and a sigmoid to give one score per sequence.

    Args:
        input_size: Vocabulary size of the learned code embedding; only used
            when `emb` selects neither BERT nor pretrained embeddings.
        embed_dim: Dimension of the learned code/segment embeddings.
        hidden_size: Hidden size of the recurrent cell.
        n_layers: Number of recurrent layers.
        dropout_r: Dropout rate passed to the recurrent cell.
        cell_type: One of `'GRU'`, `'RNN'` or `'LSTM'`.
        bi: Whether the recurrent cell is bidirectional.
        emb: Representation selector, see above.

    Raises:
        NotImplementedError: If `cell_type` is not one of the supported names.
    """

    def __init__(self, input_size,embed_dim, hidden_size, n_layers=1,dropout_r=0.1,cell_type='LSTM',bi=False,emb=''):
        super(EHR_BERT_RNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embed_dim = embed_dim
        self.dropout_r = dropout_r
        self.cell_type = cell_type
        if emb=='brt':
            self.brt=True
            self.pretrain=False
        elif len(emb)>3:
            self.brt=False
            self.pretrain=True
            # NOTE(review): load_pretrain_w2vemb is not defined in the visible
            # part of this notebook; it must be in scope before using this branch.
            self.pretrained_emb=load_pretrain_w2vemb(emb,self.embed_dim)
        else:
            self.brt=False
            self.pretrain=False

        # Number of directions; used to size the output layer and to pick the
        # hidden states in forward().
        if bi: self.bi=2
        else: self.bi=1

        # Tensor types used to place the inputs on the GPU when one is available.
        if use_cuda:
            self.flt_typ=torch.cuda.FloatTensor
            self.lnt_typ=torch.cuda.LongTensor
        else:
            self.lnt_typ=torch.LongTensor
            self.flt_typ=torch.FloatTensor

        if self.brt:
            # Loads the fine-tuned BERT model from the current working directory.
            self.PreBERTmodel=BertForSequenceClassification.from_pretrained("./")
            self.in_size= self.PreBERTmodel.bert.config.hidden_size

        elif self.pretrain: ## to add W2Vec embedding as example
            self.embed= nn.Embedding.from_pretrained(self.pretrained_emb)#,freeze=False)
            self.vembed= nn.Embedding(500, self.embed_dim,padding_idx=0)
            self.in_size= self.pretrained_emb.shape[1]
        else:
            self.embed= nn.Embedding(input_size, self.embed_dim,padding_idx=0)#,scale_grad_by_freq=True)
            # Segment embedding; supports segment ids 0..499, with 0 as padding.
            self.vembed= nn.Embedding(500, self.embed_dim,padding_idx=0)
            self.in_size= embed_dim
            #if self.time: self.in_size= self.in_size+1  ### place holder - no time information in here yet

        if self.cell_type == "GRU":
            cell = nn.GRU
        elif self.cell_type == "RNN":
            cell = nn.RNN
        elif self.cell_type == "LSTM":
            cell = nn.LSTM
        else:
            raise NotImplementedError

        # Kept as an attribute for compatibility with saved models; forward()
        # does not apply it.
        self.dropout = nn.Dropout(p=self.dropout_r)
        self.rnn_c = cell(self.in_size, hidden_size,num_layers=n_layers, dropout= dropout_r , bidirectional=bi , batch_first=True)
        self.out = nn.Linear(self.hidden_size*self.bi,1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, sequence):
        """Score a mini-batch.

        Args:
            sequence: `[input_ids, input_mask, segment_ids, labels]`, each a
                per-sample list (the first element returned by `my_collate`).

        Returns:
            `(scores, labels)`: the sigmoid scores and the labels as float
            tensors, both with one entry per sample.
        """
        token_t=torch.from_numpy(np.asarray(sequence[0],dtype=int)).type(self.lnt_typ)
        seg_t=torch.from_numpy(np.asarray(sequence[2],dtype=int)).type(self.lnt_typ)
        Label_t=torch.from_numpy(np.asarray(sequence[3],dtype=int)).type(self.flt_typ)

        if self.brt:
            mask_t=torch.from_numpy(np.asarray(sequence[1],dtype=int)).type(self.lnt_typ)
            Bert_out=self.PreBERTmodel.bert(input_ids=token_t, attention_mask=mask_t,
                                  token_type_ids=seg_t)
            # Bert_out[0]: presumably the per-token hidden states of the encoder — TODO confirm
            output, hidden = self.rnn_c(Bert_out[0])#,h_0)
        else:
            embeddings = self.embed(token_t)+self.vembed(seg_t)
            #embeddings = self.LayerNorm(embeddings)
            #in1= self.dropout(embeddings)
            in1=embeddings
            output, hidden = self.rnn_c(in1)#,h_0)

        # An LSTM returns (h_n, c_n); only the hidden state h_n is used.
        if self.cell_type == "LSTM":
            hidden=hidden[0]
        if self.bi==2:
            # Concatenate the last layer's forward and backward final states.
            output = self.sigmoid(self.out(torch.cat((hidden[-2],hidden[-1]),1)))
        else:
            output = self.sigmoid(self.out(hidden[-1]))
        # Drop only the trailing size-1 dimension so a batch of one sample still
        # yields a 1-D tensor matching Label_t (a bare squeeze() would return a
        # 0-dim tensor in that case).
        return output.squeeze(-1),Label_t


 

In [5]:
#### Print Predictions function
def print_preds(model, mbs_list, shuffle = True):
    """Run the model over a list of mini-batches and compute the ROC AUC.

    Args:
        model: Callable model with an `eval()` method; called with the first
            element of each mini-batch and expected to return
            `(scores, labels)` tensors.
        mbs_list: Mini-batches of the form `[model_inputs, patient_ids]`.
        shuffle: If true, the mini-batches are visited in random order. The
            caller's list itself is left untouched.

    Returns:
        `(auc, pts_sk, y_real, y_hat)`: the ROC AUC, followed by the patient
        ids, true labels and predicted scores in the order they were visited.
    """
    model.eval()
    y_real = []
    y_hat = []
    pts_sk = []

    # Shuffle a copy: shuffling `mbs_list` in place would silently reorder the
    # caller's mini-batches as a side effect of a read-only evaluation.
    batches = list(mbs_list)
    if shuffle:
        random.shuffle(batches)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch in batches:
            pt_skl = batch[1]
            output, label_tensor = model(batch[0])
            y_hat.extend(output.detach().cpu().reshape(-1).numpy())
            y_real.extend(label_tensor.detach().cpu().reshape(-1).numpy())
            pts_sk.extend(pt_skl)

    auc = roc_auc_score(y_real, y_hat)
    return auc, pts_sk, y_real, y_hat

#### Load Data from pickled list

The pickled list is a list of lists, where each sublist represents one patient record of the form
[pt_id, label, seq_list, segment_list]
where
    label: 1 = patient developed pancreatic cancer (case), 0 = control
    seq_list: list of all medical codes across all visits
    segment_list: the visit number corresponding to each code in seq_list
 

In [32]:
# Load the train/validation/test splits. Each file is opened in a `with` block
# so its handle is closed as soon as the data has been read.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open('pdata/lr_pc_cid_btexp.combined_BertFT.train', 'rb') as train_file:
    train_f = pkl.load(train_file, encoding='bytes')
with open('pdata/lr_pc_cid_btexp.combined_BertFT.valid', 'rb') as valid_file:
    valid_f = pkl.load(valid_file, encoding='bytes')
with open('pdata/lr_pc_cid_btexp.combined_BertFT.test', 'rb') as test_file:
    test_f = pkl.load(test_file, encoding='bytes')

In [7]:
# Number of records in the train, test and validation splits (in that order).
print(*(len(split) for split in (train_f, test_f, valid_f)))

20000 5000 2500


### Print Predictions

In [8]:
# Inference settings.
MAX_SEQ_LENGTH = 64
BATCH_SIZE = 100
bert_config_file= "sc3_config.json"

results=[]

# Restore the pickled model object, load its saved weights, and switch to inference mode.
loaded_model = torch.load('PC_BIGRU_BRT.pth')
loaded_model.load_state_dict(torch.load('PC_BIGRU_BRT.st'))
loaded_model.eval()

# Convert every split (train, test, valid) to fixed-length feature records.
train_features, test_features, valid_features = [
    convert_EHRexamples_to_features(split, MAX_SEQ_LENGTH)
    for split in (train_f, test_f, valid_f)
]

# Wrap the feature records as datasets.
train, test, valid = [
    BERTdataEHR(split_features)
    for split_features in (train_features, test_features, valid_features)
]

# Materialise the mini-batches of each split as plain lists.
print (' creating the list of training minibatches')
train_mbs = list(BERTdataEHRloader(train, batch_size = BATCH_SIZE))
print (' creating the list of test minibatches')
test_mbs = list(BERTdataEHRloader(test, batch_size = BATCH_SIZE))

print (' creating the list of valid minibatches')
valid_mbs = list(BERTdataEHRloader(valid, batch_size = BATCH_SIZE))


# Score the training split: AUC, patient ids, true labels and predicted scores.
auc_tr, pts_sk_tr,y_real_tr, y_hat_tr=print_preds(loaded_model, train_mbs)

 creating the list of training minibatches
 creating the list of test minibatches
 creating the list of valid minibatches


In [19]:
auc_tr   # ROC AUC on the training split; value recorded from the saved run: 0.9018740796391604

0.9018740796391604

In [21]:
# One row per training patient: id, true label and predicted score.
pred_df = pd.DataFrame({
    'Pt_sk': pts_sk_tr,
    'Label': y_real_tr,
    'Prediction': y_hat_tr,
})
# Show the cases (Label == 1) that the model scores above 0.9.
pred_df.loc[lambda frame: (frame['Prediction'] > 0.9) & (frame['Label'] == 1)]

In [None]:
# Reload the raw training records and the code dictionary. Files are opened in
# `with` blocks so their handles are closed once the data has been read.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open('pdata/lr_pc_cid_btexp.combined_BertFT.train', 'rb') as train_file:
    train_f = pkl.load(train_file, encoding='bytes')
with open('pdata/lr_pc_cid_btexp.types.valid', 'rb') as types_file:
    types_d = pkl.load(types_file, encoding='bytes')

# Invert the dictionary so the ids found in a record's code sequence can be
# mapped back to the DIAGNOSIS_ID values used in the diagnosis table.
types_d_rev = {mapped_id: source_id for source_id, mapped_id in types_d.items()}
diag_det = pd.read_table('/data/LR_test/panc/data/HF_D_DIAGNOSIS', sep='|')

# NOTE(review): pt1 ... pt10 are not defined in the visible notebook; they must
# be set to the patient ids of interest before this cell is run.
# Built once here, not once per record inside the loop.
patients_of_interest = [ pt1 ,pt2 ,pt3 ,pt4 ,pt5 ,pt6 ,pt7 ,pt8 ,pt9 ,pt10 ]

for record in train_f:
    if record[0] in patients_of_interest:
        print (record)
        # Print code and description for every medical code in the record.
        for code in record[2]:
            matches = diag_det['DIAGNOSIS_ID'] == types_d_rev[code]
            print(diag_det.loc[matches, ['DIAGNOSIS_CODE', 'DIAGNOSIS_DESCRIPTION']])

From the output of the cell above, we found that

the codes driving the high risk scores are:
    577.2 CYST AND PSEUDOCYST OF PANCREAS
    
    576.2  OBSTRUCTION OF BILE DUCT
    783.21        LOSS OF WEIGHT
    441.4  ABDOMINAL AORTIC ANEURYSM 
    
    