In [1]:
from transformers import AutoModelWithLMHead,BertForSequenceClassification, AutoTokenizer, AutoModel,AutoModelForMaskedLM,AutoModelForSequenceClassification
import torch
from torch import nn
import json
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split,StratifiedShuffleSplit
from torch.utils.data import DataLoader,TensorDataset
from transformers import Trainer, TrainingArguments
import pickle
from sklearn.metrics import confusion_matrix,classification_report
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score,roc_curve
import pandas as pd
import matplotlib.pyplot as plt


# Fine-tune Masked Language Model

In [4]:
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
new_tokens = ["interstitial", "fibrosis", "tubular", "atrophy","antibody","T-cell"]
tokenizer.add_tokens(new_tokens)

6

In [5]:
tokenizer.tokenize("interstitial fibrosis and tubular atrophy. T-cell mediated rejection. antibody ")

['interstitial',
 'fibrosis',
 'and',
 'tubular',
 'atrophy',
 '.',
 't-cell',
 'mediated',
 'rejection',
 '.',
 'antibody']

In [3]:
tokenizer.tokenize("interstitial fibrosis and tubular atrophy. T-cell mediated rejection. antibody ")

['inter',
 '##st',
 '##iti',
 '##al',
 'fi',
 '##bro',
 '##sis',
 'and',
 'tub',
 '##ular',
 'at',
 '##rop',
 '##hy',
 '.',
 't',
 '-',
 'cell',
 'mediated',
 'rejection',
 '.',
 'anti',
 '##body']

In [4]:
model = AutoModelForMaskedLM.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model.resize_token_embeddings(len(tokenizer))

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Embedding(29002, 768)

In [5]:
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(29002, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
           

In [6]:
data = pd.read_csv("data.csv")

In [7]:
inputs = data["Raw Case Text"].tolist()

In [8]:
input_encoding = tokenizer(inputs,padding="max_length", truncation=True, 
                            return_tensors="pt",max_length=512)

In [9]:
input_encoding['labels'] = input_encoding.input_ids.detach().clone()
input_encoding.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

In [10]:
rand = torch.rand(input_encoding.input_ids.shape)
mask_arr = (rand < 0.15) * (input_encoding.input_ids != 101) * \
           (input_encoding.input_ids != 102) * (input_encoding.input_ids != 0)
mask_arr

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False],
        ...,
        [False, False, False,  ...,  True, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False]])

In [11]:
len(mask_arr)

3429

In [12]:
mask_pos = [torch.flatten(mask_arr[i].nonzero()).tolist() for i in range(input_encoding.input_ids.shape[0])]

In [13]:
for i in range(input_encoding.input_ids.shape[0]):
    input_encoding.input_ids[i, mask_pos[i]] = 103

In [14]:
class MaskedDataset(torch.utils.data.Dataset):
    def __init__(self, encoding):
        self.encoding = encoding
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encoding.items()}
    def __len__(self):
        return len(self.encoding.input_ids)

masked_dataset = MaskedDataset(input_encoding)

In [15]:

training_args = TrainingArguments(
    output_dir='./mlm_results_largeData_extended_tokenizer',          
    num_train_epochs=5,              
    per_device_train_batch_size=8,  
    #per_device_eval_batch_size=64,   
    #warmup_steps=50,                
    #weight_decay=0.01,                          
    logging_steps=100,
    #evaluation_strateg="steps",
    #eval_steps=100,
    #load_best_model_at_end=True,
    save_steps = 100,
    save_total_limit = 10,
    seed = 0
)


trainer = Trainer(
    model=model,                         
    args=training_args,                 
    train_dataset=masked_dataset,         
)

trainer.train()


***** Running training *****
  Num examples = 3429
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2145
  return {key: torch.tensor(val[idx]) for key, val in self.encoding.items()}


Step,Training Loss
100,0.1817
200,0.0727
300,0.0542
400,0.0473
500,0.0332
600,0.0298
700,0.0291
800,0.0305
900,0.0239
1000,0.0189


Saving model checkpoint to ./mlm_results_largeData_extended_tokenizer\checkpoint-100
Configuration saved in ./mlm_results_largeData_extended_tokenizer\checkpoint-100\config.json
Model weights saved in ./mlm_results_largeData_extended_tokenizer\checkpoint-100\pytorch_model.bin
  return {key: torch.tensor(val[idx]) for key, val in self.encoding.items()}
Saving model checkpoint to ./mlm_results_largeData_extended_tokenizer\checkpoint-200
Configuration saved in ./mlm_results_largeData_extended_tokenizer\checkpoint-200\config.json
Model weights saved in ./mlm_results_largeData_extended_tokenizer\checkpoint-200\pytorch_model.bin
  return {key: torch.tensor(val[idx]) for key, val in self.encoding.items()}
Saving model checkpoint to ./mlm_results_largeData_extended_tokenizer\checkpoint-300
Configuration saved in ./mlm_results_largeData_extended_tokenizer\checkpoint-300\config.json
Model weights saved in ./mlm_results_largeData_extended_tokenizer\checkpoint-300\pytorch_model.bin
  return {key: 

Deleting older checkpoint [mlm_results_largeData_extended_tokenizer\checkpoint-400] due to args.save_total_limit
  return {key: torch.tensor(val[idx]) for key, val in self.encoding.items()}
Saving model checkpoint to ./mlm_results_largeData_extended_tokenizer\checkpoint-1500
Configuration saved in ./mlm_results_largeData_extended_tokenizer\checkpoint-1500\config.json
Model weights saved in ./mlm_results_largeData_extended_tokenizer\checkpoint-1500\pytorch_model.bin
Deleting older checkpoint [mlm_results_largeData_extended_tokenizer\checkpoint-500] due to args.save_total_limit
  return {key: torch.tensor(val[idx]) for key, val in self.encoding.items()}
Saving model checkpoint to ./mlm_results_largeData_extended_tokenizer\checkpoint-1600
Configuration saved in ./mlm_results_largeData_extended_tokenizer\checkpoint-1600\config.json
Model weights saved in ./mlm_results_largeData_extended_tokenizer\checkpoint-1600\pytorch_model.bin
Deleting older checkpoint [mlm_results_largeData_extended_to

KeyboardInterrupt: 

# Test Trained MLM Model

In [3]:
model_renal = AutoModelForMaskedLM.from_pretrained("./mlm_results_largeData_extended_tokenizer/checkpoint-1100")

In [6]:
original_sent = f"Comment: The biopsy shows severe interstitial inflammation and tubulitis (i3/t3), \
which are diagnostic for acute T-cell-mediated rejection, type IB."

masked_sent = f"Comment: The {tokenizer.mask_token} shows {tokenizer.mask_token} interstitial {tokenizer.mask_token} and tubulitis (i3/t3), \
which are {tokenizer.mask_token} for {tokenizer.mask_token} T-cell-mediated {tokenizer.mask_token}, {tokenizer.mask_token} IB."


In [7]:
tokenized_sent = tokenizer.encode(masked_sent, return_tensors="pt")
mask_token_index = torch.where(tokenized_sent == tokenizer.mask_token_id)[1]

mask_token_index

tensor([ 4,  6,  8, 23, 25, 29, 31])

In [9]:
# clinical model
token_logits = model_renal(tokenized_sent).logits
mask_token_logits = token_logits[0, mask_token_index, :]
top_3_tokens = torch.topk(mask_token_logits, 3, dim=1).indices.tolist()
print("Original Sentence: \n", original_sent, "\n")
for words in zip(*top_3_tokens):
    new_sent = masked_sent
    for i,token in enumerate(words):
        new_sent = new_sent.replace(tokenizer.mask_token,f'{i}*{tokenizer.decode([token])}*',1)
    print(new_sent,"\n")
    
# token_logits = model(tokenized_sent).logits
# mask_token_logits = token_logits[0, mask_token_index, :]
# top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
# for token in top_5_tokens:
#     print(masked_sent.replace(tokenizer.mask_token,tokenizer.decode([token])))

Original Sentence: 
 Comment: The biopsy shows severe interstitial inflammation and tubulitis (i3/t3), which are diagnostic for acute T-cell-mediated rejection, type IB. 

Comment: The 0*patient* shows 1*moderate* interstitial 2*inflammation* and tubulitis (i3/t3), which are 3*typical* for 4*acute* T-cell-mediated 5*rejection*, 6*and* IB. 

Comment: The 0*specimen* shows 1*mild* interstitial 2*fibrosis* and tubulitis (i3/t3), which are 3*diagnostic* for 4*a* T-cell-mediated 5*diabetes*, 6*including* IB. 

Comment: The 0*abdomen* shows 1*severe* interstitial 2*congestion* and tubulitis (i3/t3), which are 3*suspicious* for 4*chronic* T-cell-mediated 5*injury*, 6*but* IB. 



In [20]:
# model_renal
token_logits = model_renal(tokenized_sent).logits
mask_token_logits = token_logits[0, mask_token_index, :]
top_3_tokens = torch.topk(mask_token_logits, 3, dim=1).indices.tolist()
print("Original Sentence: \n", original_sent, "\n")
for words in zip(*top_3_tokens):
    new_sent = masked_sent
    for i,token in enumerate(words):
        new_sent = new_sent.replace(tokenizer.mask_token,f'{i}*{tokenizer.decode([token])}*',1)
    print(new_sent,"\n")


Original Sentence: 
 Comment: The biopsy shows severe interstitial inflammation and tubulitis (i3/t3), which are diagnostic for acute T-cell-mediated rejection, type IB. 

Comment: The 0*patient* shows 1*moderate* interstitial 2*inflammation* and tubulitis (i3/t3), which are 3*typical* for 4*acute* T-cell-mediated 5*rejection*, 6*and* IB. 

Comment: The 0*specimen* shows 1*mild* interstitial 2*fibrosis* and tubulitis (i3/t3), which are 3*diagnostic* for 4*a* T-cell-mediated 5*diabetes*, 6*including* IB. 

Comment: The 0*abdomen* shows 1*severe* interstitial 2*congestion* and tubulitis (i3/t3), which are 3*suspicious* for 4*chronic* T-cell-mediated 5*injury*, 6*but* IB. 

