In [1]:
import sys
sys.path.insert(0, '../')

from utils import *
from dataset  import *
from models import *
import wandb
from collections import defaultdict
# Pick the compute device once for the whole notebook: GPU when PyTorch can see
# one, CPU otherwise.
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")

# DataLoader worker processes: 8 background workers on a GPU machine, in-process
# loading (0) on CPU. An unsuitable value here can make the loaders error out.
num_workers = 0
if cuda:
    num_workers = 8

print("Cuda = ", str(cuda), " with num_workers = ", str(num_workers),  " system version = ", sys.version)


  from .autonotebook import tqdm as notebook_tqdm


Cuda =  True  with num_workers =  8  system version =  3.7.13 (default, Oct 18 2022, 18:57:03) 
[GCC 11.2.0]


In [2]:
# Shared hyperparameters used by every dataset and model cell below.
pretrained_model = "recobo/chemical-bert-uncased-pharmaceutical-chemical-classifier"  # passed to BertEmbedding
para_seq_len = 16  # number of paragraphs encoded and decoded together as one sequence
batch_size = 4
max_para_length = 128  # presumably the max tokens per paragraph — TODO confirm where it is consumed


In [3]:
test_file = "/home/anjadhav/Chemical-Patent-Reaction-Extraction/data/test_data_iob.csv"
test_data = CRFEmbeddingDataset(test_file, para_seq_len=para_seq_len, pretrained_model=pretrained_model, stride=para_seq_len)

# Evaluation order must be deterministic, so no shuffling. Keep the last, partial batch.
test_args = dict(shuffle=False, batch_size=batch_size, drop_last=False)
if cuda:
    # Reuse the worker count chosen in the setup cell instead of a second hard-coded 8.
    test_args.update(num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(test_data, **test_args)

128


In [5]:
val_file = "/home/anjadhav/Chemical-Patent-Reaction-Extraction/data/val_data_iob.csv"
val_data = CRFEmbeddingDataset(val_file, para_seq_len=para_seq_len, pretrained_model=pretrained_model, stride=para_seq_len)

# Evaluation order must be deterministic, so no shuffling. Keep the last, partial batch.
val_args = dict(shuffle=False, batch_size=batch_size, drop_last=False)
if cuda:
    # Reuse the worker count chosen in the setup cell instead of a second hard-coded 8.
    val_args.update(num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_data, **val_args)

128


In [6]:
gen_file = "/home/anjadhav/Chemical-Patent-Reaction-Extraction/data/organic_chem_patents.csv"
gen_data = CRFEmbeddingDataset(gen_file, para_seq_len=para_seq_len, pretrained_model=pretrained_model, stride=para_seq_len)

# Evaluation order must be deterministic, so no shuffling. Keep the last, partial batch.
gen_args = dict(shuffle=False, batch_size=batch_size, drop_last=False)
if cuda:
    # Reuse the worker count chosen in the setup cell instead of a second hard-coded 8.
    gen_args.update(num_workers=num_workers, pin_memory=True)
gen_loader = DataLoader(gen_data, **gen_args)

128


In [12]:
class Evaluate:
    """Load one EncoderDecoderBiLstmCRF checkpoint and score it on several data splits.

    Each instance opens its own wandb run (named after the checkpoint path, in
    project "ChemIR3"). It also keeps a wandb.Table that gains one row per
    ``wandb_update`` call.
    """

    # Table layout: the run name, then every "<split prefix><metric>" key in sorted
    # order. The prefixes are the ``wandb_name`` values passed to ``evaluate`` in
    # this notebook ("gen_", "test_", "validate_").
    WANDB_COLUMNS = (
        'run',
        'gen_eval_accuracy', 'gen_eval_f-1', 'gen_eval_precision', 'gen_eval_recall',
        'gen_fuzzy_f1', 'gen_fuzzy_match_count', 'gen_fuzzy_precision', 'gen_fuzzy_recall',
        'gen_misaligned_begin', 'gen_misaligned_begin_end_count', 'gen_misaligned_end',
        'gen_missed_span_count', 'gen_strict_f1', 'gen_strict_precision', 'gen_strict_recall',
        'test_eval_accuracy', 'test_eval_f-1', 'test_eval_precision', 'test_eval_recall',
        'test_fuzzy_f1', 'test_fuzzy_match_count', 'test_fuzzy_precision', 'test_fuzzy_recall',
        'test_misaligned_begin', 'test_misaligned_begin_end_count', 'test_misaligned_end',
        'test_missed_span_count', 'test_strict_f1', 'test_strict_precision', 'test_strict_recall',
        'validate_eval_accuracy', 'validate_eval_f-1', 'validate_eval_precision', 'validate_eval_recall',
        'validate_fuzzy_f1', 'validate_fuzzy_match_count', 'validate_fuzzy_precision', 'validate_fuzzy_recall',
        'validate_misaligned_begin', 'validate_misaligned_begin_end_count', 'validate_misaligned_end',
        'validate_missed_span_count', 'validate_strict_f1', 'validate_strict_precision', 'validate_strict_recall',
    )

    def __init__(self, model_load_path, pretrained_model, description=""):
        """Build the model, load its weights, move it to ``device`` and start a wandb run.

        Args:
            model_load_path: path of the state dict to load. It is also used as the wandb run name.
            pretrained_model: Hugging Face model id handed to ``BertEmbedding``.
            description: free text printed before loading, e.g. the run's hyperparameters.
        """
        print(description)

        self.model_load_path = model_load_path
        self.pretrained_model = pretrained_model

        self.model = EncoderDecoderBiLstmCRF(embed_model=BertEmbedding(pretrained_model), num_tags=3, freeze_bert=False)
        # The freshly built model still lives on the CPU here, so load the tensors there.
        # Without map_location, a checkpoint saved from a GPU fails to load on a CPU-only machine.
        state_dict = torch.load(model_load_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(device)

        self.wandb = wandb.init(name=model_load_path, project="ChemIR3")
        self.wandb_table = wandb.Table(columns=list(self.WANDB_COLUMNS))

    def wandb_update(self, val_scores, test_scores, gen_scores):
        """Append one row (run name plus all split metrics) to the table and log it.

        The three dicts are merged into a new dict, so the caller's ``val_scores``
        is left untouched. Values are placed by column name, not by position.

        Raises:
            ValueError: if the merged score keys differ from ``WANDB_COLUMNS[1:]``.
                Without this check, a missing or extra metric would shift values
                into the wrong columns.
        """
        merged = {**val_scores, **test_scores, **gen_scores}
        metric_columns = self.WANDB_COLUMNS[1:]

        missing = [column for column in metric_columns if column not in merged]
        unexpected = sorted(set(merged) - set(metric_columns))
        if missing or unexpected:
            raise ValueError(
                f"score keys do not match wandb table columns; missing={missing}, unexpected={unexpected}"
            )

        row = [self.model_load_path]  # "run" column
        row.extend(merged[column] for column in metric_columns)
        self.wandb_table.add_data(*row)
        self.wandb.log({"gen": self.wandb_table})

    def evaluate(self, test_loader, test_file, wandb_name="test_"):
        """Score the model on one split and return the merged metric dict.

        Combines the metrics from ``validate`` with the span metrics from
        ``get_span_perf``. ``wandb_name`` is forwarded to both as their ``wandb``
        argument. The returned keys are expected to carry that prefix so they match
        ``WANDB_COLUMNS`` — confirm against ``validate`` / ``get_span_perf``.
        """
        test_df = pd.read_csv(test_file)
        _, test_predictions, test_scores_validation = validate(self.model, test_loader, device, None, None, wandb=wandb_name)
        _, test_scores_span_perf = get_span_perf(test_df, test_predictions, wandb=wandb_name)

        test_scores_validation.update(test_scores_span_perf)
        print(test_scores_validation)
        return test_scores_validation

    def extract_eval(self, test_loader, test_file, save_name, output_dir="../outputs/"):
        """Predict on one split, print the predicted spans, and write the annotated CSV.

        The CSV returned by ``get_spans`` is written to ``output_dir``/``save_name``.
        The default directory is unchanged from before.
        """
        import os

        test_df = pd.read_csv(test_file)
        test_predictions = extract(self.model, test_loader, device)
        test_df, pred_spans = get_spans(test_df, test_predictions)
        print(pred_spans)
        test_df.to_csv(os.path.join(output_dir, save_name))


def load_pretrained_weights(model, pretrained_path, skip_prefix="crf_model."):
    """Warm-start ``model`` from a checkpoint, keeping its own weights for one sub-module.

    Every entry in the checkpoint whose key starts with ``skip_prefix`` is
    dropped. The model keeps its current values for those parameters and takes
    everything else from the checkpoint.

    Args:
        model: torch module to update in place.
        pretrained_path: path of a saved state dict.
        skip_prefix: state-dict key prefix to leave untouched (default "crf_model.").

    Returns:
        The same ``model`` object, for chaining.
    """
    # Load onto the CPU so a GPU-saved checkpoint also works on CPU-only machines;
    # load_state_dict then copies the values onto the model's own devices.
    pretrained_dict = torch.load(pretrained_path, map_location="cpu")
    # Use a real prefix test. A fixed-width slice such as k[:12] can never equal the
    # 10-character "crf_model.", which would silently skip nothing.
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(skip_prefix)}

    model_dict = model.state_dict()
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model

## Model 1: BiLSTM-CRF with fine-tuned ChemBERT (`chem_bert_iob_bilstm_crf_bert_finetune`)


In [13]:
model1_path = '../models/chem_bert_iob_bilstm_crf_bert_finetune/enc_dec_model_model_params_0.9525667530577686.pth'

# Build the printed description from the real checkpoint path and the live config,
# so it always matches what is being evaluated. The hand-copied text had drifted:
# it named ./chem_bert... instead of ../models/chem_bert...
# The last line keeps the reference to the older './model_model_params_0.9428...'
# checkpoint from the original notes. Its role is not recorded here, so confirm it.
description = f"""{model1_path}
pretrained_model = "{pretrained_model}"
batch_size = {batch_size}
max_para_length = {max_para_length}
para_seq_len = {para_seq_len}  #number of paras to be encoded and decoded together (hyperparameter)
'./model_model_params_0.9428545098368426.pth'"""

eval_model1 = Evaluate(model_load_path=model1_path, pretrained_model=pretrained_model, description=description)

./chem_bert_iob_bilstm_crf_bert_finetune/enc_dec_model_model_params_0.9525667530577686.pth
pretrained_model = "recobo/chemical-bert-uncased-pharmaceutical-chemical-classifier"
batch_size = 4
max_para_length = 128
para_seq_len = 16  #number of paras to be encoded and decoded together (hyperparameter)
'./model_model_params_0.9428545098368426.pth'


Some weights of the model checkpoint at recobo/chemical-bert-uncased-pharmaceutical-chemical-classifier were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
%%capture
print("==========================Validation Data===================================")
val_scores = eval_model1.evaluate(val_loader, val_file, wandb_name = "validate_")

print("===================================Test Data===================================")
test_scores =  eval_model1.evaluate(test_loader, test_file, wandb_name = "test_")

print("===================================Gen Data===================================")
gen_scores = eval_model1.evaluate(gen_loader, gen_file, wandb_name = "gen_")

eval_model1.wandb_update(val_scores, test_scores, gen_scores)

In [9]:
eval_model1.extract_eval(test_loader=gen_loader, test_file=gen_file, save_name="gen_model1.csv")  # predicted spans for the gen split

                                                      

Time: 2.7918248176574707
Index(['para', 'label', 'document', 'predictions'], dtype='object')
{(726, 727), (555, 555), (557, 557), (710, 710), (715, 715), (717, 717), (723, 723), (556, 556), (716, 716), (711, 711), (724, 724), (712, 713), (558, 559), (479, 480), (923, 925), (720, 721), (241, 242), (709, 709), (708, 708), (926, 928), (473, 474), (921, 922), (562, 562), (718, 718), (483, 484), (475, 476), (719, 719), (239, 240), (730, 730), (477, 478), (729, 729), (728, 728), (706, 707), (467, 468), (471, 472), (243, 244), (247, 248), (481, 482), (560, 561), (731, 731), (733, 733), (714, 714), (469, 470), (237, 238), (245, 246), (732, 732), (725, 725), (722, 722), (223, 223)}


## Model 2: BiLSTM-CRF without BERT fine-tuning (`chem_bert_iob_bilstm_crf_no_ft`)

In [15]:
model2_path = '../models/chem_bert_iob_bilstm_crf_no_ft/model_model_params_0.9231656973050348.pth'

# Build the printed description from the real checkpoint path and the live config,
# so it cannot drift from what is actually evaluated.
description = f"""{model2_path}
pretrained_model = "{pretrained_model}"
batch_size = {batch_size}
max_para_length = {max_para_length}
para_seq_len = {para_seq_len}  #number of paras to be encoded and decoded together (hyperparameter)
"""

eval_model2 = Evaluate(model_load_path=model2_path, pretrained_model=pretrained_model, description=description)

../models/chem_bert_iob_bilstm_crf_no_ft/model_model_params_0.9231656973050348.pth
pretrained_model = "recobo/chemical-bert-uncased-pharmaceutical-chemical-classifier"
batch_size = 4
max_para_length = 128
para_seq_len = 16  #number of paras to be encoded and decoded together (hyperparameter)



Some weights of the model checkpoint at recobo/chemical-bert-uncased-pharmaceutical-chemical-classifier were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
%%capture
print("==========================Validation Data===================================")
val_scores = eval_model2.evaluate(val_loader, val_file, wandb_name = "validate_")

print("===================================Test Data===================================")
test_scores =  eval_model2.evaluate(test_loader, test_file, wandb_name = "test_")

print("===================================Gen Data===================================")
gen_scores = eval_model2.evaluate(gen_loader, gen_file, wandb_name = "gen_")

eval_model2.wandb_update(val_scores, test_scores, gen_scores)


In [11]:
eval_model2.extract_eval(test_loader=gen_loader, test_file=gen_file, save_name="gen_model2.csv")  # predicted spans for the gen split

                                                      

Time: 2.763430118560791
Index(['para', 'label', 'document', 'predictions'], dtype='object')
{(726, 727), (557, 557), (715, 715), (710, 710), (717, 717), (723, 723), (725, 725), (556, 556), (716, 716), (711, 711), (724, 724), (558, 559), (926, 927), (479, 480), (923, 925), (241, 242), (720, 721), (707, 707), (709, 709), (708, 708), (473, 474), (921, 922), (562, 562), (706, 706), (718, 718), (475, 476), (239, 240), (719, 719), (483, 484), (560, 560), (730, 730), (477, 478), (729, 729), (728, 728), (920, 920), (471, 472), (467, 468), (243, 244), (247, 248), (481, 482), (235, 236), (731, 731), (733, 733), (714, 714), (237, 238), (245, 246), (469, 470), (732, 732), (713, 713), (712, 712), (722, 722)}


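## Model 3

**FIXME:** this section loads the same checkpoint as Model 2 (`chem_bert_iob_bilstm_crf_no_ft/model_model_params_0.9231656973050348.pth`), and its extracted spans are identical to Model 2's. Confirm which checkpoint Model 3 was meant to evaluate.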

In [17]:
# FIXME(review): this is byte-for-byte the same checkpoint as Model 2, and the
# extracted spans below match Model 2's exactly. Point this at the checkpoint
# Model 3 is meant to evaluate.
model3_path = '../models/chem_bert_iob_bilstm_crf_no_ft/model_model_params_0.9231656973050348.pth'

# Build the printed description from the real checkpoint path and the live config,
# so it cannot drift from what is actually evaluated.
description = f"""{model3_path}
pretrained_model = "{pretrained_model}"
batch_size = {batch_size}
max_para_length = {max_para_length}
para_seq_len = {para_seq_len}  #number of paras to be encoded and decoded together (hyperparameter)
"""

eval_model3 = Evaluate(model_load_path=model3_path, pretrained_model=pretrained_model, description=description)

../models/chem_bert_iob_bilstm_crf_no_ft/model_model_params_0.9231656973050348.pth
pretrained_model = "recobo/chemical-bert-uncased-pharmaceutical-chemical-classifier"
batch_size = 4
max_para_length = 128
para_seq_len = 16  #number of paras to be encoded and decoded together (hyperparameter)



Some weights of the model checkpoint at recobo/chemical-bert-uncased-pharmaceutical-chemical-classifier were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [18]:
%%capture
print("==========================Validation Data===================================")
val_scores = eval_model3.evaluate(val_loader, val_file, wandb_name = "validate_")

print("===================================Test Data===================================")
test_scores =  eval_model3.evaluate(test_loader, test_file, wandb_name = "test_")

print("===================================Gen Data===================================")
gen_scores = eval_model3.evaluate(gen_loader, gen_file, wandb_name = "gen_")

eval_model3.wandb_update(val_scores, test_scores, gen_scores)


In [19]:
eval_model3.extract_eval(test_loader=gen_loader, test_file=gen_file, save_name="gen_model3.csv")  # predicted spans for the gen split

                                                      

Time: 2.953927516937256
Index(['para', 'label', 'document', 'predictions'], dtype='object')
{(726, 727), (557, 557), (715, 715), (710, 710), (717, 717), (723, 723), (725, 725), (556, 556), (716, 716), (711, 711), (724, 724), (558, 559), (926, 927), (479, 480), (923, 925), (241, 242), (720, 721), (707, 707), (709, 709), (708, 708), (473, 474), (921, 922), (562, 562), (706, 706), (718, 718), (475, 476), (239, 240), (719, 719), (483, 484), (560, 560), (730, 730), (477, 478), (729, 729), (728, 728), (920, 920), (471, 472), (467, 468), (243, 244), (247, 248), (481, 482), (235, 236), (731, 731), (733, 733), (714, 714), (237, 238), (245, 246), (469, 470), (732, 732), (713, 713), (712, 712), (722, 722)}
