In [16]:
!pip install datasets transformers nltk pytorch-crf torch seaborn sklearn matplotlib



In [17]:
from torchcrf import CRF
import os
import os.path as osp
import nltk
import random
# nltk.download('stopwords')
# nltk.download('punkt')
# from nltk.corpus import stopwords
# english_stopwords = stopwords.words("english")
import numpy as np
import re
import seaborn as sns
sns.set_theme(style="whitegrid")
from sklearn.metrics import precision_recall_fscore_support
import matplotlib.pyplot as plt
import pandas as pd
import string
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import datasets
from datasets import load_dataset
import pickle
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertTokenizer, BertModel
import multiprocessing
import time
from torch.utils.data import DataLoader, Dataset 
import sys
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

In [18]:
test_file = "test_data_iob.csv"
val_file = "val_data_iob.csv"
train_file = "train_data_iob.csv"

In [19]:
pretrained_model = "recobo/chemical-bert-uncased-pharmaceutical-chemical-classifier"
batch_size = 4
max_para_length = 128
para_seq_len = 16  #number of paras to be encoded and decoded together (hyperparameter)
# Check if cuda is available and set device
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Make sure you choose suitable num_worker, otherwise it will result in errors
num_workers = 8 if cuda else 0

print("Cuda = ", str(cuda), " with num_workers = ", str(num_workers),  " system version = ", sys.version)

Cuda =  True  with num_workers =  8  system version =  3.7.13 (default, Oct 18 2022, 18:57:03) 
[GCC 11.2.0]


In [20]:
class CRFEmbeddingDataset(Dataset):
    def __init__(self, csv_file, para_seq_len, pretrained_model, stride = 1):
      df = pd.read_csv(csv_file)

      self.para_seq_len = para_seq_len
      self.tokenizer = BertTokenizer.from_pretrained(pretrained_model, do_lower=True)    

      # Tokenize the paragraphs
      self.df = df["para"].apply(self.preprocess)
      self.y = df['label']
      # self.test = is_test
      self.stride = stride
  
     
    def preprocess(self, examples):
      return self.tokenizer(examples, truncation=True, 
                     padding="max_length", max_length=max_para_length,
                     return_token_type_ids=False)['input_ids']

    def __len__(self):
      # if(self.test):
      #   # print(math.ceil(len(self.y)/self.para_seq_len))
      #   return math.ceil(len(self.y)/self.para_seq_len)
      l = math.ceil((len(self.y) - self.para_seq_len + 1) / self.stride)
      # print(len(self.y))  
      # print(l)
      return l
    
    def __getitem__(self,index):
      return torch.LongTensor(list(self.df[index*self.stride: (index*self.stride + self.para_seq_len)])), torch.LongTensor(self.y[index*self.stride: (index*self.stride + self.para_seq_len)].tolist())
      

In [24]:
train_data = CRFEmbeddingDataset(train_file, para_seq_len = para_seq_len, pretrained_model = pretrained_model, stride = 2)
val_data = CRFEmbeddingDataset(val_file, para_seq_len = para_seq_len, pretrained_model = pretrained_model, stride = para_seq_len)
test_data = CRFEmbeddingDataset(test_file, para_seq_len = para_seq_len, pretrained_model = pretrained_model, stride = para_seq_len)

train_args = dict(shuffle=True, batch_size=batch_size, num_workers=8, pin_memory=True, drop_last=False) if cuda else dict(shuffle=True, batch_size=batch_size, drop_last=False)
train_loader = DataLoader(train_data, **train_args)

val_args = dict(shuffle=False, batch_size=batch_size, num_workers=8, pin_memory=True, drop_last=False) if cuda else dict(shuffle=False, batch_size=batch_size, drop_last=False)
val_loader = DataLoader(val_data, **val_args)


test_args = dict(shuffle=False, batch_size=batch_size, num_workers=8, pin_memory=True, drop_last=False) if cuda else dict(shuffle=False, batch_size=batch_size, drop_last=False)
test_loader = DataLoader(test_data, **test_args)

In [25]:
print(train_data.__len__())
print(val_data.__len__())
print(test_data.__len__())
print(len(train_loader))
print(len(test_loader))
print(len(val_loader))

23002
471
804
5751
201
118


## Span level analysis

In [38]:
def get_span_perf(test_df, predictions):
    # print(len(test_df))
    test_df = test_df[:][:len(predictions)]
    test_df['predictions'] = predictions
    # test_df.to_csv("bert_embed_iob_bilstm_crf_pred.csv")
    # print(len(test_df)) 
    test_df = test_df.reset_index(drop=False)
    test_df.columns = ['index', 'para', 'label', 'document', 'predictions']

    orig = set()
    i = 0
    while i < len(test_df):
        if(test_df['label'][i] == 2):
            st = test_df['index'][i]
            i +=1
            while(i < len(test_df) and test_df['label'][i] == 1):
                i+=1
            orig.add((st, i-1))
        else:
            i+=1

    pred = set()
    i = 0
    while i < len(test_df):
        if(test_df['predictions'][i] == 2):
            st = test_df['index'][i]
            i +=1
            while(i < len(test_df) and test_df['predictions'][i] == 1):
                i+=1
            pred.add((st, i-1))
        else:
            i+=1
            
    strict_match_spans = orig.intersection(pred)
    fuzzy_cnt = 0
    for o in orig:
        if ((o in pred) or ((o[0]+1,o[1]) in pred) or ((o[0]+1,o[1]-1) in pred) or ((o[0]+1,o[1]+1) in pred) 
            or ((o[0]-1,o[1]) in pred) or ((o[0]-1,o[1]+1) in pred) or ((o[0]-1,o[1]-1) in pred) or ((o[0],o[1]+1) in pred)
            or ((o[0],o[1]-1) in pred)):
            fuzzy_cnt+=1
    

    miss_start_end = 0
    miss_start = 0
    miss_end = 0

    for o in orig:
        if(o in pred):
            continue 
        elif(((o[0]-1,o[1]+1) in pred) or ((o[0]-1,o[1]-1) in pred) or ((o[0]+1,o[1]-1) in pred) or ((o[0]+1,o[1]+1) in pred)):
            miss_start_end += 1
        elif(((o[0]+1,o[1]) in pred) or ((o[0]-1,o[1]) in pred)):
            miss_start += 1
        elif(((o[0],o[1]+1) in pred) or ((o[0],o[1]-1) in pred)):
            miss_end+=1

    print("Total original spans: ", len(orig))
    print("Total predicted spans: ", len(pred))
    print("Total number of original spans correctly predicted acc to strict match: ", len(strict_match_spans))
    print("Percent of original spans correctly predicted acc to strict match: ", len(strict_match_spans)/len(orig)*100)

    print("Total number of original spans correctly predicted acc to fuzzy match: ", fuzzy_cnt)
    print("Percent of original spans correctly predicted acc to fuzzy match: ", fuzzy_cnt/len(orig)*100)

    fuzzy_matched_only = miss_start_end+miss_start+miss_end
    assert(fuzzy_matched_only == fuzzy_cnt - len(strict_match_spans))
    print("Count of fuzzy matched spans: ", miss_start_end+miss_start+miss_end)
    print("Count of spans with misaligned begin and end: {} ({:.2f}%) ".format(miss_start_end, miss_start_end/fuzzy_matched_only*100))
    print("Count of spans with misaligned begin: {} ({:.2f}%) ".format(miss_start, miss_start/fuzzy_matched_only*100))
    print("Count of spans with misaligned end: {} ({:.2f}%) ".format(miss_end, miss_end/fuzzy_matched_only*100))

    return test_df

## Fixed Bert word Embeddings, BiLSTM encoder, Triplet Decoder

In [27]:
class BertEmbedding(nn.Module):
    def __init__(self, pretrained_model):
        super().__init__()
        self.model = BertModel.from_pretrained(pretrained_model, output_hidden_states = True)

    def forward(self, x):
        # print("Input to BertEmbedding: ", x.shape)
        outputs = self.model(x)
        hidden_states = outputs[2]
        embedding = torch.cat((hidden_states[-1],hidden_states[-2],hidden_states[-3],hidden_states[-4]), dim = 2)
        # print("Output from BertEmbedding: ", embedding.shape)
        return embedding

class ParaEncoderForContext(nn.Module):
    def __init__(self, bilayers = 1, input_dim = 3072, hidden_size = 512):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_size
        self.lstm = nn.LSTM(
                input_size=input_dim, hidden_size=hidden_size,
                num_layers=1, batch_first=True, bidirectional=True)

        for name, param in self.lstm.named_parameters():
            if 'bias' in name:
                nn.init.constant(param, 0.0)
            elif 'weight' in name:
                nn.init.kaiming_normal_(param)
     

    def forward(self, x): # (B*T(T=1+2*context), tokens, input_dim)
        # print("Input to Encoder: ",x.shape)
        outputs, _ = self.lstm(x) # (B*T, tokens, 2*hidden_dim)
        # print("After LSTM: ", outputs.shape)
        first = outputs[:, 0, self.hidden_dim:]
        second = outputs[:, -1, :self.hidden_dim]
        para_embed = torch.cat((second,first), dim = 1) #(B*T, 2*hidden_dim)

        # print("Output from Encoder", para_embed.shape)
        return para_embed #(B*T, 2*hidden_dim)



class ParaDecoderBiLstmCRF(nn.Module):
    def __init__(self, input_dim, hidden_size, bilayers = 1):
        super().__init__()
        # self.input_dim = input_dim
        # self.hidden_dim = hidden_size
        self.lstm = nn.LSTM(
                input_size=input_dim, hidden_size=hidden_size,
                num_layers=1, batch_first=True, bidirectional=True)

        self.linear = nn.Linear(2*hidden_size, 3, bias= True)
        
        for name, param in self.lstm.named_parameters():
            if 'bias' in name:
                nn.init.constant(param, 0.0)
            elif 'weight' in name:
                nn.init.kaiming_normal_(param)

    
    def forward(self, x):  #(B, T, 2*encoder.hidden_dim)
    # print("Input to decoder: ", x.shape) 
        outputs, _ = self.lstm(x)   #out = (B, T, 2*decoder.hidden_dim)

        s0, s1, s2 = outputs.shape
        op = outputs.reshape(s0*s1, s2) # (B*T, 2*decoder.hidden_dim)

        op2 = self.linear(op)

        op3 = op2.view(s0, s1, -1)

        return op3 #(B,T,3) #emissions



class EncoderDecoderBiLstmCRF(nn.Module):
    def __init__(self, embed_model, num_tags, encoder_bilayers = 1, encoder_input_dim = 3072, encoder_hidden_size = 512, decoder_bilayers = 1, decoder_hidden_size = 512, freeze_bert = True):
        super().__init__()
        self.para_encoder = ParaEncoderForContext(bilayers = encoder_bilayers, input_dim = encoder_input_dim, hidden_size = encoder_hidden_size)
        self.para_decoder = ParaDecoderBiLstmCRF(input_dim = encoder_hidden_size*2, hidden_size = decoder_hidden_size, bilayers = decoder_bilayers)
        self.crf_model = CRF(num_tags = num_tags, batch_first = True)
        self.embed_model = embed_model
        
        if(freeze_bert):
            for param in self.embed_model.parameters():
                param.requires_grad = False

    def decode(self, emission):
        return self.crf_model.decode(emission)

    def forward(self, x, y): # (B, 2*context+1, tokens_per_para)
    # print("Input to model: ", x.shape)
        s0, s1, s2 = x.shape
        xv = x.view(s0*s1, s2)
        embeds = self.embed_model(xv)
        para_vec = self.para_encoder(embeds)
        pvv = para_vec.view(s0, s1, -1) #(B, T, 2*hidden_dim)
        # print("Input to decoder: ", pvv.shape)
        emission = self.para_decoder(pvv) #(B,T,3) #emissions
        log_likelihood = self.crf_model(emission, y, reduction='mean') 
        return -log_likelihood, emission

## Train and Validate Functions

In [41]:
def train(para_model, data_loader):
    para_model.train()
    # crf_model.train()

    avg_loss = []
    start = time.time()
    all_predictions = []
    all_targets = []
    
    for i, (x, y) in enumerate(tqdm(data_loader, desc="Epoch", leave=False)):
        optimizer.zero_grad()
        y  = y.to(device) 
        x = x.to(device)

        loss, emission = para_model(x, y)
        del x

        avg_loss.extend([loss.item()]*len(y))
        
        decoded_list = para_model.decode(emission)
        for l in decoded_list:
            all_predictions.extend(l)


        all_targets.extend(torch.flatten(y.detach().cpu()).tolist())
        
        loss.backward()
        optimizer.step()
        scheduler.step()

        del y
        del emission
        torch.cuda.empty_cache()
  
        
    
    end = time.time()
    avg_loss = np.mean(avg_loss)
    print('learning_rate: {}'.format(scheduler.get_last_lr()))
    print('Training loss: {:.2f}, Time: {}'.format(avg_loss, end-start))

    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)
    scores = precision_recall_fscore_support(all_targets, all_predictions, 
                                            average="weighted", zero_division=0.)

    test_scores={
      "eval_accuracy": (all_predictions == all_targets).sum() / len(all_predictions),
      "eval_precision": scores[0],
      "eval_recall": scores[1],
      "eval_f-1": scores[2]
    }
    print(test_scores)

In [42]:
def validate(para_model, data_loader):
    para_model.eval()
    # crf_model.eval()

    avg_loss = []
    all_predictions = []
    all_targets = []
    start = time.time()

    for i, (x, y) in enumerate(tqdm(data_loader, desc="Epoch", leave=False)):
        y = y.to(device)
        x = x.to(device)

        with torch.no_grad():
            loss, emission = para_model(x, y) 
            del x
            
            avg_loss.extend([loss.item()]*len(y))

            decoded_list = para_model.decode(emission)
            
            for l in decoded_list:
                all_predictions.extend(l)

            all_targets.extend(torch.flatten(y.detach().cpu()).tolist())
            del emission
            del y
            torch.cuda.empty_cache()
      
    end = time.time()
    avg_loss = np.mean(avg_loss)
    print('learning_rate: {}'.format(scheduler.get_last_lr()))
    print('Validation loss: {:.2f}, Time: {}'.format(avg_loss, end-start))

    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)
    scores = precision_recall_fscore_support(all_targets, all_predictions, 
                                            average="weighted", zero_division=0.)

    test_scores={
      "eval_accuracy": (all_predictions == all_targets).sum() / len(all_predictions),
      "eval_precision": scores[0],
      "eval_recall": scores[1],
      "eval_f-1": scores[2]
    }
    print(test_scores)
    return test_scores["eval_f-1"], all_predictions


In [43]:
def save(model, acc, best=""):
    if not os.path.exists('./chem_bert_iob_bilstm_crf_bert_finetune/'):
        os.mkdir('./chem_bert_iob_bilstm_crf_bert_finetune/')

    torch.save(model.state_dict(), './chem_bert_iob_bilstm_crf_bert_finetune/'+'/{}model_params_{}.pth'.format(best, acc))

def load_pretrained_weights(model, pretrained_path):
    pretrained_dict = torch.load(pretrained_path)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k[:12] != "embed_model."}
    # print(pretrained_dict.keys())
    model_dict = model.state_dict()
    model_dict.update(pretrained_dict) 
    model.load_state_dict(model_dict)
    return model    

## Main

In [44]:
model = EncoderDecoderBiLstmCRF(embed_model = BertEmbedding(pretrained_model), num_tags = 3, freeze_bert=False)
# model.load_state_dict(torch.load('./bert_iob_bilstm_crf/model_model_params_0.9428545098368426.pth'))
model = load_pretrained_weights(model, './model_model_params_0.9428545098368426.pth')

# if torch.cuda.device_count() > 1:
#   print("Let's use", torch.cuda.device_count(), "GPUs!")
#   model = nn.DataParallel(model)
model = model.to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
non_trainable_total_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print("Total params: ", total_params)
print("Trainable params: ", trainable_total_params)
print("Non Trainable params: ", non_trainable_total_params)


Some weights of the model checkpoint at recobo/chemical-bert-uncased-pharmaceutical-chemical-classifier were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Total params:  130909458
Trainable params:  130909458
Non Trainable params:  0


In [45]:
epochs = 10 #changed from 10
lamda = 1e-3  #L2 regularization (prev : 1e-4)
learning_rate = 5e-5 #changed from 1e-2   ## Greatly reduces LR for bert finetuning

# criterion = nn.CrossEntropyLoss()
# criterion = criterion.to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=lamda)
# optimizer.load_state_dict(torch.load('./bert_base_triplet/optimizer_model_params_0.9409211846833226.pth'))    

# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[i for i in range(4,20,4)], gamma=0.75)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(len(train_loader) * epochs))

In [46]:
torch.cuda.empty_cache()
best_val_f1 = 0
val_df = pd.read_csv(val_file)

for epoch in range(epochs):
    print('Epoch #{}'.format(epoch+1))

    train(model, train_loader)
    val_f1, val_preds = validate(model, val_loader)
    try:
        get_span_perf(val_df, val_preds)
        save(model, val_f1, best = "model_")
        save(optimizer, val_f1, best = "optimizer_")
    except Exception as error:
        print(error)


Epoch #1


                                                            

learning_rate: [4.877641290737887e-05]
Training loss: 1.32, Time: 4095.447846889496
{'eval_accuracy': 0.9722850186940266, 'eval_precision': 0.9720540062578936, 'eval_recall': 0.9722850186940266, 'eval_f-1': 0.9720996777603389}


                                                        

learning_rate: [4.877641290737887e-05]
Validation loss: 3.65, Time: 15.204040050506592
{'eval_accuracy': 0.9532908704883227, 'eval_precision': 0.9525908355479169, 'eval_recall': 0.9532908704883227, 'eval_f-1': 0.9525667530577686}
Total original spans:  911
Total predicted spans:  842
Total number of original spans correctly predicted acc to strict match:  699
Percent of original spans correctly predicted acc to strict match:  76.72886937431394
Total number of original spans correctly predicted acc to fuzzy match:  772
Percent of original spans correctly predicted acc to fuzzy match:  84.74204171240395
Count of fuzzy matched spans:  73
Count of spans with misaligned begin and end: 16 (21.92%) 
Count of spans with misaligned begin: 28 (38.36%) 
Count of spans with misaligned end: 29 (39.73%) 
Epoch #2


                                                          

learning_rate: [4.522542485937361e-05]
Training loss: 0.36, Time: 3230.8570413589478
{'eval_accuracy': 0.9924517433266673, 'eval_precision': 0.992454440132001, 'eval_recall': 0.9924517433266673, 'eval_f-1': 0.9924527337580042}


                                                        

learning_rate: [4.522542485937361e-05]
Validation loss: 2.80, Time: 14.896248817443848
{'eval_accuracy': 0.9538216560509554, 'eval_precision': 0.9534230759626194, 'eval_recall': 0.9538216560509554, 'eval_f-1': 0.953476708995007}
Total original spans:  911
Total predicted spans:  861
Total number of original spans correctly predicted acc to strict match:  706
Percent of original spans correctly predicted acc to strict match:  77.49725576289791
Total number of original spans correctly predicted acc to fuzzy match:  764
Percent of original spans correctly predicted acc to fuzzy match:  83.86388583973655
Count of fuzzy matched spans:  58
Count of spans with misaligned begin and end: 28 (48.28%) 
Count of spans with misaligned begin: 14 (24.14%) 
Count of spans with misaligned end: 16 (27.59%) 
Epoch #3


                                                            

learning_rate: [3.9694631307311735e-05]
Training loss: 0.24, Time: 4928.052469968796
{'eval_accuracy': 0.9950113033649248, 'eval_precision': 0.9950135157658397, 'eval_recall': 0.9950113033649248, 'eval_f-1': 0.9950122317299185}


                                                        

learning_rate: [3.9694631307311735e-05]
Validation loss: 4.23, Time: 49.10485291481018
{'eval_accuracy': 0.9211783439490446, 'eval_precision': 0.9200920328688463, 'eval_recall': 0.9211783439490446, 'eval_f-1': 0.9204843779941658}
Total original spans:  911
Total predicted spans:  876
Total number of original spans correctly predicted acc to strict match:  614
Percent of original spans correctly predicted acc to strict match:  67.39846322722283
Total number of original spans correctly predicted acc to fuzzy match:  688
Percent of original spans correctly predicted acc to fuzzy match:  75.52140504939628
Count of fuzzy matched spans:  74
Count of spans with misaligned begin and end: 20 (27.03%) 
Count of spans with misaligned begin: 20 (27.03%) 
Count of spans with misaligned end: 34 (45.95%) 
Epoch #4


Epoch:  20%|█▉        | 1145/5751 [26:03<1:39:13,  1.29s/it]

In [None]:
# Test on Test Set

In [None]:
_, predictions = validate(model, test_loader)

In [None]:
# Store predictions

In [None]:
test_df = pd.read_csv(test_file, predictions)
print(len(test_df))
test_df = get_span_perf(test_df)
print(len(test_df))
test_df.to_csv("chem_bert_embed_iob_bilstm_crf_pred.csv")

In [None]:
# Store error cases

In [None]:
error = test_df[test_df['label'] != test_df['predictions']]
print((len(test_df)- len(error)) / len(test_df))
print(len(error))
error.to_csv("errors_bert_embed_iob_bilstm_crf.csv")