# Training BlueBERT for contradiction classification

This notebook uses transfer learning on BlueBERT to train a three-class (contradiction / entailment / neutral) classification model that determines whether a pair of drug-treatment sentences contains a contradiction.

Code based on [this](https://www.kaggle.com/xhlulu/jigsaw-tpu-xlm-roberta) and [this](https://github.com/CoronaWhy/drug-lit-contradictory-claims/blob/master/src/contradictory_claims/data/make_dataset.py).

### To Run:

Make sure you have the working directory set up as in [this](https://github.com/CoronaWhy/drug-lit-contradictory-claims/blob/master/src/contradictory_claims/data/make_dataset.py) notebook.

Install PyTorch BlueBert:

Clone [BlueBert repo](https://github.com/ncbi-nlp/bluebert) into current proj_path

Mac Run:
```
export NCBI_DIR=<directory_path_to_NCBI_BERT>
transformers-cli convert --model_type bert \
  --tf_checkpoint $NCBI_DIR/bert_model.ckpt \
  --config $NCBI_DIR/bert_config.json \
  --pytorch_dump_output $NCBI_DIR/pytorch_model.bin
```

Windows Run:


```
set NCBI_DIR=<directory_path_to_NCBI_BERT>
python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path=%NCBI_DIR%/bert_model.ckpt --bert_config_file=%NCBI_DIR%/bert_config.json --pytorch_dump_path=%NCBI_DIR%/pytorch_model.bin
```

Rename bert_config.json -> config.json

Alternatively, follow the instructions [here](https://medium.com/@manasmohanty/ncbi-bluebert-ncbi-bert-using-tensorflow-weights-with-huggingface-transformers-15a7ec27fc3d)

In [None]:
# Mount Google Drive so the project paths used later in the notebook resolve.
# Only necessary (and only possible) when running in Colab: elsewhere the
# `google.colab` package is missing, so skip the mount instead of crashing.
try:
    from google.colab import drive
except ImportError:
    print('google.colab is not available; skipping the Google Drive mount.')
else:
    drive.mount('/content/gdrive')

In [None]:
# %%capture
!pip install transformers
import os
import shutil
import json
import math
import time
import datetime
import numpy as np
import pandas as pd
from pathlib import Path
import xml.etree.ElementTree as et 
from itertools import permutations


from keras.utils import np_utils
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint
import transformers
from transformers import AutoModel
from transformers import AdamW, TFAutoModel, AutoTokenizer, AutoModelWithLMHead, BertTokenizer, BertModel, TFBertModel, get_linear_schedule_with_warmup
from tqdm.notebook import tqdm
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors
import tensorflow.keras.backend as K
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
from tensorflow.keras.callbacks import EarlyStopping
from torchsummary import summary
import pickle

In [None]:
# Pick the torch device used by the rest of the notebook: the GPU when CUDA is
# usable, otherwise the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print('No GPU available, using the CPU instead.')
else:
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

### Model definition and preprocessing utilities

In [None]:
def regular_encode(texts, tokenizer, max_len=512):
    """Encode a batch of texts and return their token ids as an np.array.

    The tokenizer is asked to pad to ``max_len`` and to truncate longer
    inputs; token type ids are not requested.
    """
    encode_options = {
        'return_token_type_ids': False,
        'padding': 'max_length',
        'max_length': max_len,
        'truncation': True,
    }
    encoded = tokenizer.batch_encode_plus(texts, **encode_options)
    token_ids = encoded['input_ids']
    return np.array(token_ids)

def batchify(data, batch_size=50, device=torch.device("cpu")):
    """Split a 2-D array into equally sized batches.

    Args:
        data: 2-D ``np.ndarray`` of shape ``(rows, width)``.
        batch_size: Number of rows per batch; must be at least 1.
        device: Torch device the resulting tensor is moved to.

    Returns:
        A float tensor of shape ``(rows // batch_size, batch_size, width)``.
        Trailing rows that do not fill a whole batch are dropped, so fewer
        than ``batch_size`` rows yields an empty ``(0, batch_size, width)``
        tensor.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1, got %r" % (batch_size,))

    tensor = torch.from_numpy(data).float()
    rows = tensor.size(0)
    width = tensor.size(1)
    num_batches = rows // batch_size
    # Trim off any extra rows that wouldn't cleanly fit (remainders).
    tensor = tensor.narrow(0, 0, num_batches * batch_size)
    # Use the explicit batch count rather than -1: an inferred dimension
    # cannot be resolved when the trimmed tensor is empty.
    tensor = tensor.view(num_batches, batch_size, width).contiguous()
    return tensor.to(device)


In [None]:
class ContraDataset(Dataset):
  """Dataset of tokenized claims with attention masks and labels.

  Args:
    claims: List of input strings, encoded once at construction time.
    labels: Per-example labels (e.g. one-hot rows). Any sized, indexable
      sequence works: a numpy array or a plain list.
    tokenizer: Object exposing ``batch_encode_plus`` (a Hugging Face
      tokenizer in this notebook).
    max_len: Length every encoded claim is padded/truncated to.

  Each item is a ``(token_ids, attention_mask, label)`` tuple of tensors.
  """

  def __init__(self, claims, labels, tokenizer, max_len=512):
    self.tokenizer = tokenizer
    self.claims = torch.tensor(self.regular_encode(claims, max_len=max_len))
    # Positions whose token id is <= 0 are masked out (0.0); all others are
    # attended to (1.0).
    # NOTE(review): assumes the tokenizer pads with id 0 -- confirm.
    self.att_mask = (self.claims > 0).float()
    self.labels = labels

  def __getitem__(self, index):
    # Raise IndexError (not AssertionError) past the end so that plain
    # iteration over the dataset terminates, and so the check survives -O.
    if index >= len(self):
      raise IndexError("index %d out of range for %d examples" % (index, len(self)))
    return (self.claims[index], self.att_mask[index], torch.tensor(self.labels[index]))

  def __len__(self):
    # len() instead of .shape[0] so list labels work, matching __getitem__.
    return len(self.labels)

  def regular_encode(self, texts, max_len=512):
    """Tokenize a batch of sentences and return the token ids as an np.array."""
    enc_di = self.tokenizer.batch_encode_plus(
        texts,
        return_token_type_ids=False,
        padding='max_length',
        max_length=max_len,
        truncation=True
    )
    return np.array(enc_di['input_ids'])

In [None]:
class TorchContraNet(nn.Module):
  """Transfer-learning classifier: a pretrained encoder plus a linear head.

  Args:
    transformer: Encoder module called as
      ``transformer(ids, token_type_ids=None, attention_mask=mask)`` whose
      first output is the per-token hidden states.
    num_labels: Number of output classes (defaults to the original 3).
    hidden_size: Width of the encoder's hidden states. When omitted it is
      read from ``transformer.config.hidden_size`` if present, falling back
      to the original 768.

  ``train()`` and ``eval()`` are inherited from ``nn.Module``: they take the
  optional ``mode`` argument, return ``self`` and switch the wrapped encoder
  as well, because it is a registered sub-module.
  """

  def __init__(self, transformer, num_labels=3, hidden_size=None):
    """Define and initialize the layers of the model."""
    super().__init__()
    self.transformer = transformer
    if hidden_size is None:
      encoder_config = getattr(transformer, 'config', None)
      hidden_size = getattr(encoder_config, 'hidden_size', 768)
    self.linear = nn.Linear(hidden_size, num_labels)
    # Normalise over the class dimension (last axis). dim=0 would normalise
    # across the examples of a batch, making each prediction depend on
    # whichever other examples happen to share its batch.
    self.out = nn.Softmax(dim=-1)

  def forward(self, claim, mask, label=None):
    """Return class probabilities of shape (batch, num_labels).

    ``label`` is accepted for interface compatibility and is not used.
    """
    encoder_outputs = self.transformer(claim,
                                       token_type_ids=None,
                                       attention_mask=mask)
    # Index instead of tuple-unpacking: this works both for a plain tuple
    # and for the dict-like ModelOutput returned by recent transformers
    # releases (iterating that object yields its keys, not its tensors).
    hidden_states = encoder_outputs[0]
    # The representation of the first token is the sentence-pair summary
    # fed to the classification head.
    first_token_state = hidden_states[:, 0, :]
    unnormalized_labels = self.linear(first_token_state)
    return self.out(unnormalized_labels)

#TODO implement TF version using BERT as a Service

### Load and Preprocess Data

In [None]:
# Project hyperparameters and filesystem layout

# Everything lives below this root folder
proj_path = Path('/content/gdrive/My Drive/Colab Notebooks/BlueBERT/')

# Pretrained model checkpoints
models_path = proj_path.joinpath('models')
bluebert_path = models_path.joinpath('BlueBERT')

# One sub-folder per input dataset
inputs_path = proj_path.joinpath('Input')
multinli_path = inputs_path.joinpath('multinli')
mednli_path = inputs_path.joinpath('mednli')
mancon_path = inputs_path.joinpath('manconcorpus-sent-pairs')
drug_path = inputs_path.joinpath('drugnames')
virus_path = inputs_path.joinpath('virus-words')

# MultiNLI files and the number of rows kept per label
MULTINLI_DATA_PER_CATEGORY = 1000
multinli_train_path = multinli_path.joinpath('multinli_1.0_train.txt')
multinli_test_path = multinli_path.joinpath('multinli_1.0_dev_matched.txt')

# Mancon sentence pairs and the number of rows kept per label
MANCON_DATA_PER_CATEGORY = 1000
mancon_sent_pairs = mancon_path.joinpath('manconcorpus_sent_pairs_200516.tsv')

# MedNLI (paths are only defined when the user has access to the data)
#TODO add in the rest of the MedNLI ingestion and preprocessing
USING_MEDNLI = False # Does the user have access/did the training for mednli?
if USING_MEDNLI:
  MEDNLI_DATA_PER_CATEGORY = 1000
  mednli_train_path = mednli_path.joinpath('mli_train_v1.jsonl')
  mednli_dev_path = mednli_path.joinpath('mli_dev_v1.jsonl')
  mednli_test_path = mednli_path.joinpath('mli_test_v1.jsonl')

# Word lists used to extend the tokenizer vocabulary
drug_names_path = drug_path.joinpath('DrugNames.txt')
virus_names_path = virus_path.joinpath('virus_words.txt')

In [None]:
# Load the pretrained BlueBERT tokenizer and encoder from the converted
# checkpoint directory.
PRETRAINED_PATH = bluebert_path

tokenizer = BertTokenizer.from_pretrained(str(PRETRAINED_PATH))
transformer = BertModel.from_pretrained(
    str(PRETRAINED_PATH),
    num_labels=3,
    output_attentions=False,
    output_hidden_states=False,
)

In [None]:
# Read in MultiNLI
# `on_bad_lines='warn'` skips malformed rows with a warning. It replaces
# `error_bad_lines=False`, which was removed from `read_csv` in pandas 2.0.
multinli_train_data = pd.read_csv(multinli_train_path, sep='\t', on_bad_lines='warn')
multinli_test_data = pd.read_csv(multinli_test_path, sep='\t', on_bad_lines='warn')

# Integer class id for every gold label the model is trained on.
_MULTINLI_LABEL_IDS = {'neutral': 0, 'entailment': 1, 'contradiction': 2}


def _prepare_multinli(frame):
  """Map gold labels to class ids and drop rows that cannot be used.

  Rows are dropped when either sentence is missing or when the gold label is
  not one of the three known classes. Mapping every unknown label to class 0
  (the previous behaviour) would silently count unlabeled pairs as neutral.
  NOTE(review): MultiNLI is understood to mark no-consensus pairs with '-';
  confirm against the data files.
  """
  frame = frame.assign(gold_label=frame['gold_label'].map(_MULTINLI_LABEL_IDS))
  frame = frame.dropna(subset=['sentence1', 'sentence2', 'gold_label'])
  return frame.assign(gold_label=frame['gold_label'].astype(int))


multinli_train_data = _prepare_multinli(multinli_train_data)
multinli_test_data = _prepare_multinli(multinli_test_data)

#TODO add range checks
# Keep at most MULTINLI_DATA_PER_CATEGORY rows per class, in the order
# contradiction, entailment, neutral. `pd.concat` replaces the
# `DataFrame.append` calls, which no longer exist in pandas 2.0.
balanced_multinli_train_data = pd.concat(
    [multinli_train_data[multinli_train_data['gold_label'] == label_id].head(MULTINLI_DATA_PER_CATEGORY)
     for label_id in (2, 1, 0)],
    ignore_index=True,
)

# Make data into the form that BERT expects. Only the separator between the
# two sentences is inserted by hand: the tokenizer adds the leading [CLS] and
# trailing [SEP] itself, so a manual '[CLS]' would be duplicated.
multinli_x_train = balanced_multinli_train_data.sentence1 + '[SEP]' + balanced_multinli_train_data.sentence2
multinli_x_test = multinli_test_data.sentence1 + '[SEP]' + multinli_test_data.sentence2
# Fix the one-hot width at 3 so both splits get the same number of columns
# even when a class is absent from one of them.
multinli_y_train = np_utils.to_categorical(balanced_multinli_train_data['gold_label'], num_classes=3, dtype='int')
multinli_y_test = np_utils.to_categorical(multinli_test_data['gold_label'], num_classes=3, dtype='int')

# Package data into a DataLoader
multinli_x_train_dataset = ContraDataset(multinli_x_train.to_list(), multinli_y_train, tokenizer, max_len=64)
multinli_x_train_sampler = RandomSampler(multinli_x_train_dataset)
multinli_x_train_dataloader = DataLoader(multinli_x_train_dataset, sampler=multinli_x_train_sampler, batch_size=4)

In [None]:
# Ingest Mancon data
raw_mancon_data = pd.read_csv(mancon_sent_pairs, sep='\t')

# Class ids: contradiction -> 2, entailment -> 1, every other label -> 0.
_mancon_label_ids = {'contradiction': 2, 'entailment': 1}
raw_mancon_data['label'] = [_mancon_label_ids.get(l, 0) for l in raw_mancon_data['label']]
raw_mancon_data['label'] = raw_mancon_data['label'].astype('float')

#TODO add range checks
# Keep at most MANCON_DATA_PER_CATEGORY rows per class, in the order 2, 1, 0.
# `pd.concat` replaces the `DataFrame.append` calls, which no longer exist in
# pandas 2.0.
balanced_mancon_data = pd.concat(
    [raw_mancon_data[raw_mancon_data['label'] == label_id].head(MANCON_DATA_PER_CATEGORY)
     for label_id in (2, 1, 0)],
    ignore_index=True,
)

# Only the separator between the two sentences is inserted by hand: the
# tokenizer adds the leading [CLS] and trailing [SEP] itself, so a manual
# '[CLS]' would be duplicated.
mancon_claims = balanced_mancon_data['text_a'] + '[SEP]' + balanced_mancon_data['text_b']

# A fixed random_state makes the train/test split reproducible between runs.
mancon_x_train, mancon_x_test, mancon_y_train, mancon_y_test = train_test_split(
    mancon_claims,
    balanced_mancon_data['label'],
    test_size=0.2,
    random_state=42,
)
# Fix the one-hot width at 3 so both splits get the same number of columns
# even when a class is absent from one of them.
mancon_y_train = np_utils.to_categorical(mancon_y_train, num_classes=3)
mancon_y_test = np_utils.to_categorical(mancon_y_test, num_classes=3)

In [None]:
# Read in drug names to help augment tokenizer
drug_names = pd.read_csv(drug_names_path, header=None)
# Keep only real, non-empty strings: a NaN cell would make the substring test
# below raise, and an empty string would match every text.
drug_names = [name for name in drug_names[0] if isinstance(name, str) and name]

# Only want the drugs mentioned in our datasets
#TODO add multinli and mednli mentions to filter drug names too
# Both sides of each sentence pair are searched, so a drug that only occurs
# in text_b is kept as well.
mancon_sentences = set(balanced_mancon_data.text_a) | set(balanced_mancon_data.text_b)
mancon_text = ' '.join(sentence for sentence in mancon_sentences if isinstance(sentence, str))
drug_names = [drug for drug in drug_names if drug in mancon_text]


# Read in virus names to help augment tokenizer
virus_names = pd.read_csv(virus_names_path, header=None)
virus_names = [name for name in virus_names[0] if isinstance(name, str) and name]

# Add drug and virus names to existing tokenizer
tokenizer.add_tokens(drug_names + virus_names)
# The new tokens get ids beyond the pretrained vocabulary, so the encoder's
# embedding matrix has to grow to match; otherwise looking them up fails.
transformer.resize_token_embeddings(len(tokenizer))
# NOTE(review): the MultiNLI dataset is tokenized in an earlier cell, before
# these tokens are added, so it does not benefit from them.

### Train Utilities

Functions adapted from [here](https://medium.com/@aniruddha.choudhury94/part-2-bert-fine-tuning-tutorial-with-pytorch-for-text-classification-on-the-corpus-of-linguistic-18057ce330e1)

In [None]:
# Accuracy of one-hot / probability predictions against one-hot labels
def flat_accuracy(preds, labels):
  """Return the fraction of rows whose arg-max class matches the label's."""
  # The class of a row is the column holding its largest value.
  predicted_classes = np.argmax(preds, axis=1).ravel()
  true_classes = np.argmax(labels, axis=1).ravel()
  hits = predicted_classes == true_classes
  return np.sum(hits) / len(true_classes)

def format_time(elapsed):
  '''
  Render a duration given in seconds as an h:mm:ss string
  (longer durations are prefixed with a day count by timedelta).
  '''
  # Whole seconds only, rounded to the nearest one.
  whole_seconds = int(round(elapsed))
  duration = datetime.timedelta(seconds=whole_seconds)
  return str(duration)

In [None]:
def train(model,
          dataloader,
          device,
          criterion=torch.nn.MSELoss(reduction='sum'),
          optimizer=None,
          epochs=3,
          seed=42):
  """Fine-tune ``model`` on ``dataloader`` and return the per-epoch losses.

  Args:
    model: Module called as ``model(claim, mask)``.
    dataloader: Yields ``(claim, mask, label)`` batches.
    device: Torch device every batch is moved to.
    criterion: Loss called as ``criterion(prediction, label)``.
    optimizer: Optimizer over ``model``'s parameters. When omitted, an
      ``AdamW(lr=2e-5, eps=1e-8)`` is built for the given model. It cannot
      be built in the signature: defaults are evaluated once, when the
      function is defined, and no model exists at that point.
    epochs: Number of passes over ``dataloader``.
    seed: Seed applied to numpy and torch (CPU and CUDA).

  Returns:
    List with the average training loss per batch of every epoch.
  """
  if optimizer is None:
    optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

  # Set the seed value all over the place to make this reproducible.
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)

  loss_values = []
  num_batches = len(dataloader)

  # Linear learning-rate decay over all optimisation steps, without warm-up.
  total_steps = num_batches * epochs
  scheduler = get_linear_schedule_with_warmup(optimizer,
                                              num_warmup_steps=0,
                                              num_training_steps=total_steps)

  # Make sure the model is in training mode, e.g. after an earlier
  # validation run switched it to evaluation mode.
  model.train()

  # Training loop
  for epoch in range(epochs):
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, epochs))
    print('Training...')

    # Measure how long the training epoch takes.
    t0 = time.time()

    # Reset the total loss for this epoch.
    total_loss = 0

    # Step through dataloader output
    for step, batch in enumerate(dataloader):
      # Progress update every 40 batches.
      if step % 40 == 0 and not step == 0:
        elapsed = format_time(time.time() - t0)
        print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, num_batches, elapsed))

      claim = batch[0].to(device)
      mask = batch[1].to(device)
      label = batch[2].to(device).float()

      # Clear gradients left over from the previous step.
      model.zero_grad()

      y = model(claim, mask)
      loss = criterion(y, label)
      total_loss += loss.item()
      loss.backward()

      # Clip the gradient norm to 1.0 before the parameter update.
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      optimizer.step()
      scheduler.step()

    # An empty dataloader has no batches to average over; report 0.0
    # instead of dividing by zero.
    avg_train_loss = total_loss / num_batches if num_batches else 0.0

    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)
    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epcoh took: {:}".format(format_time(time.time() - t0)))

  return loss_values

In [None]:
def validate(model,
             dataloader,
             device):
  """Evaluate ``model`` on ``dataloader``; print and return its accuracy.

  Args:
    model: Module called as ``model(claim, mask)``; it is left in evaluation
      mode afterwards.
    dataloader: Yields ``(claim, mask, label)`` batches with one-hot labels.
    device: Torch device every batch is moved to.

  Returns:
    Fraction of examples whose arg-max prediction matches the label, or 0.0
    when the dataloader yields no examples.
  """
  print("")
  print("Running Validation...")
  t0 = time.time()

  # Put the model in evaluation mode--the dropout layers behave differently during evaluation.
  model.eval()

  # Accuracy is accumulated per example rather than per batch, so a smaller
  # final batch does not get the same weight as a full one.
  weighted_accuracy, nb_eval_examples = 0.0, 0

  # Evaluate data for one epoch
  for batch in dataloader:

    claim = batch[0].to(device)
    mask = batch[1].to(device)
    label = batch[2].to(device).float()

    # No gradients are needed for evaluation.
    with torch.no_grad():
      pred_labels = model(claim, mask)

    pred_labels = pred_labels.detach().cpu().numpy()
    label = label.to('cpu').numpy()

    batch_size = len(label)
    weighted_accuracy += flat_accuracy(pred_labels, label) * batch_size
    nb_eval_examples += batch_size

  # Guard against an empty dataloader instead of dividing by zero.
  accuracy = weighted_accuracy / nb_eval_examples if nb_eval_examples else 0.0

  print("  Accuracy: {0:.2f}".format(accuracy))
  print("  Validation took: {:}".format(format_time(time.time() - t0)))
  return accuracy

In [None]:
# Build the classifier around the pretrained encoder, move it onto the
# selected device and switch it to training mode
model = TorchContraNet(transformer).to(device)
model.train()

# Fine-tune on the MultiNLI dataloader; `losses` holds one average loss per epoch
losses = train(model, multinli_x_train_dataloader, device)

In [None]:
# Report the accuracy of the trained model.
# NOTE(review): this passes the same dataloader that was used for training,
# so the figure is training accuracy, not held-out performance.
validate(model=model, dataloader=multinli_x_train_dataloader, device=device)