# Evaluate BERT for OSV 

## Labels: 1 means the two institutions are the same; 0 means they are different

In [1]:
import pickle
from pprint import pprint

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import entropy
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_curve
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertModel
from transformers import BertTokenizer, BertTokenizerFast, BertForSequenceClassification
from transformers import Trainer, TrainingArguments

from utils.data_helper import read_mag_file
from utils.lazydataset import LazyTextMAG_OSV_Dataset

In [2]:
# Data directories: pickled lookup tables plus the train / test / open / dev splits.
save_pkl_root = '/home/datamerge/ACL/Data/210422/pkl/'
save_train_root = '/home/datamerge/ACL/Data/210422/train/'
save_test_root = '/home/datamerge/ACL/Data/210422/test/'
save_open_root = '/home/datamerge/ACL/Data/210422/open/'
save_dev_root = '/home/datamerge/ACL/Data/210422/dev/'

# NOTE(security): pickle.load can execute arbitrary code -- only load files you trust.
# The files are opened in `with` blocks so the handles are closed right after loading.
# Presumably afid2nor maps an affiliation id to its normalized name and nor2afid is
# the reverse lookup -- TODO confirm against the code that wrote these pickles.
with open(save_pkl_root + "afid2nor.pkl", "rb") as pkl_file:
    afid2nor = pickle.load(pkl_file)
with open(save_pkl_root + "nor2afid.pkl", "rb") as pkl_file:
    nor2afid = pickle.load(pkl_file)

In [3]:
# Number the values of afid2nor in iteration order (index -> label), then invert
# that table (label -> index).
# NOTE(review): both names are reassigned from pickle files in the following cell,
# so these in-memory versions are only live until then.
overall_mid2label_dict = dict(enumerate(afid2nor.values()))
overall_label2mid_dict = {
    label: mid for mid, label in overall_mid2label_dict.items()
}

In [4]:
def _load_pickle(filename):
    """Load one pickle file from ``save_pkl_root`` and close the handle afterwards.

    NOTE(security): pickle.load can execute arbitrary code -- only load trusted files.
    """
    with open(save_pkl_root + filename, 'rb') as pkl_file:
        return pickle.load(pkl_file)


nor2len_dict = _load_pickle('210422_nor2len_dict.pkl')

# Index <-> label tables for the training label space.
train_mid2label_dict = _load_pickle('train_mid2label_dict.pkl')
train_label2mid_dict = _load_pickle('train_label2mid_dict.pkl')

# Index <-> label tables for the overall label space.
overall_mid2label_dict = _load_pickle('overall_mid2label_dict.pkl')
overall_label2mid_dict = _load_pickle('overall_label2mid_dict.pkl')

# Translate each training index into the overall index that carries the same label.
train_mid2overall_mid = {
    train_id: overall_label2mid_dict[train_label]
    for train_id, train_label in train_mid2label_dict.items()
}

In [5]:
# Pair files for the "simple" OSV evaluation (dev / test).
dev_osv_filepath = f"{save_dev_root}dev_osv_simple.txt"
test_osv_filepath = f"{save_test_root}test_osv_simple.txt"

# Pair files for the "hard" OSV evaluation (dev / test).
dev_osv_hard_filepath = f"{save_dev_root}dev_osv_hard.txt"
test_osv_hard_filepath = f"{save_test_root}test_osv_hard.txt"

In [6]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')


def _build_osv_dataset(pair_filepath):
    """Wrap one OSV pair file in a lazily-read dataset sharing the tokenizer and label table."""
    return LazyTextMAG_OSV_Dataset(tokenizer, pair_filepath, overall_label2mid_dict)


# Same construction order as before: simple dev/test first, then hard dev/test.
dev_osv_dataset = _build_osv_dataset(dev_osv_filepath)
test_osv_dataset = _build_osv_dataset(test_osv_filepath)

dev_osv_hard_dataset = _build_osv_dataset(dev_osv_hard_filepath)
test_osv_hard_dataset = _build_osv_dataset(test_osv_hard_filepath)

In [7]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class BertForAffiliationNameNormalization(torch.nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    Constructing an instance needs ``BertModel`` (from ``transformers``) and the
    module-level ``device`` to be in scope. Attribute names are kept unchanged
    because checkpoints that pickle the whole module presumably restore them by
    name -- verify before renaming.
    """

    def __init__(self, num_of_classes):
        """Build the encoder and head; ``num_of_classes`` is the width of the output layer."""
        super().__init__()
        self.num_of_classes = num_of_classes
        # Each sub-module is moved to the module-level ``device`` on construction.
        self.bert = BertModel.from_pretrained('bert-base-uncased').to(device)
        self.dropout = nn.Dropout(p=0.1, inplace=False).to(device)
        self.classifier = nn.Linear(self.bert.config.hidden_size, self.num_of_classes, bias=True).to(device)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits for a batch of token ids.

        ``input_ids`` and ``attention_mask`` are forwarded unchanged to the encoder.
        """
        encoder_outputs = self.bert(input_ids, attention_mask=attention_mask)
        # Element 1 of the encoder output feeds the head (the pooled representation
        # in Hugging Face's BertModel -- confirm for the installed version).
        pooled = self.dropout(encoder_outputs[1])
        return self.classifier(pooled)

In [8]:
# Load the whole pickled module saved after epoch 16.
# map_location places the tensors on `device`, so the checkpoint also loads on a
# machine without CUDA instead of failing during deserialization.
# NOTE(security): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which
# rejects fully pickled modules; pass weights_only=False there -- confirm for your version.
model = torch.load('./checkpoint_rs3/epoch_16_bert.pkl', map_location=device)

# Unwrap a DataParallel container so the underlying module is called directly.
if isinstance(model, torch.nn.DataParallel):
    model = model.module

# Follow the CUDA/CPU choice made for `device` rather than a hard-coded "cuda:0".
model = model.to(device)

In [9]:
def js_divergence(p, q):
    """Jensen-Shannon divergence between ``p`` and ``q`` along the last axis.

    Computed as the mean of the two relative entropies of ``p`` and ``q`` against
    their midpoint, using ``scipy.stats.entropy`` with its default logarithm.
    """
    midpoint = (p + q) / 2
    kl_p = entropy(p, midpoint, axis=-1)
    kl_q = entropy(q, midpoint, axis=-1)
    return kl_p / 2 + kl_q / 2

In [10]:
def report_osv(true, pred):
    """Return the accuracy of the predictions ``pred`` against the labels ``true``."""
    osv_accuracy = accuracy_score(true, pred)
    return osv_accuracy

In [11]:
def _collect_osv_scores(model, dataset, batch_size=32):
    """Score every pair in ``dataset`` with the model; return ``(divergences, labels)``.

    For each pair, both inputs are run through ``model``, the logits are turned
    into softmax distributions, and the pair is scored with ``js_divergence``.
    Both return values are 1-D numpy arrays in dataset order. The scores do not
    depend on any threshold, so they can be computed once and reused.
    """
    model.eval()
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    def _as_batch(rows):
        # Flatten each item to a single row and stack the rows into one 2-D tensor.
        return torch.cat([row.reshape(1, -1) for row in rows], dim=0).to(device)

    divergences = []
    labels = []
    # Inference only: without no_grad every forward pass would keep an autograd graph.
    with torch.no_grad():
        for batch in tqdm(loader):
            first, second = batch['first_encodings'], batch['second_encodings']
            first_input_ids = _as_batch(first['input_ids'])
            first_attention_mask = _as_batch(first['attention_mask'])
            second_input_ids = _as_batch(second['input_ids'])
            second_attention_mask = _as_batch(second['attention_mask'])

            labels.append(batch['label'].cpu())

            first_logits = model(first_input_ids, attention_mask=first_attention_mask)
            second_logits = model(second_input_ids, attention_mask=second_attention_mask)
            first_probs = F.softmax(first_logits, dim=1).cpu().numpy()
            second_probs = F.softmax(second_logits, dim=1).cpu().numpy()

            # js_divergence reduces over the last axis, so a single call scores
            # the whole batch and yields one value per pair.
            divergences.extend(js_divergence(first_probs, second_probs))

    return np.array(divergences), torch.cat(labels).numpy()


def evaluate_osv(model, dataset, nor2len_dict, overall_mid2label_dict, threshold):
    """Return the accuracy of threshold-based pair verification on ``dataset``.

    A pair is judged "same" when the Jensen-Shannon divergence between the model's
    two predicted distributions is strictly below ``threshold``; the judgements
    are compared with the dataset labels through ``report_osv``.

    ``nor2len_dict`` and ``overall_mid2label_dict`` are not used by the
    computation; they are kept so existing call sites keep working.
    """
    divergences, labels = _collect_osv_scores(model, dataset)
    judgements = divergences < threshold
    return report_osv(labels, judgements)

In [12]:
# Coarse sweep of the decision threshold over 0.00, 0.05, ..., 1.00 on the dev set.
# Every evaluate_osv call runs a full pass over the dataset.
# NOTE: js_divergence uses scipy's entropy with its default (natural) logarithm, so
# the score cannot exceed ln 2 (about 0.693); from 0.7 upwards every pair is judged
# "same", which matches the constant 0.5 accuracy in the recorded output.
for num in range(0, 101, 5):
    threshold = num / 100.0
    results = evaluate_osv(
        model,
        dev_osv_dataset,
        nor2len_dict,
        overall_mid2label_dict,
        threshold=threshold,
    )
    print("threshold: ", threshold, "\t\tresults: ", results)

100%|██████████| 188/188 [00:28<00:00,  6.49it/s]
  1%|          | 1/188 [00:00<00:25,  7.44it/s]

threshold:  0.0 		results:  0.5


100%|██████████| 188/188 [00:26<00:00,  7.03it/s]
  1%|          | 1/188 [00:00<00:29,  6.35it/s]

threshold:  0.05 		results:  0.5443037974683544


100%|██████████| 188/188 [00:26<00:00,  7.18it/s]
  1%|          | 1/188 [00:00<00:24,  7.60it/s]

threshold:  0.1 		results:  0.5581279147235176


100%|██████████| 188/188 [00:26<00:00,  7.14it/s]
  1%|          | 1/188 [00:00<00:28,  6.67it/s]

threshold:  0.15 		results:  0.5691205862758161


100%|██████████| 188/188 [00:26<00:00,  7.09it/s]
  1%|          | 1/188 [00:00<00:24,  7.75it/s]

threshold:  0.2 		results:  0.5801132578281146


100%|██████████| 188/188 [00:26<00:00,  7.14it/s]
  1%|          | 1/188 [00:00<00:24,  7.57it/s]

threshold:  0.25 		results:  0.5919387075283145


100%|██████████| 188/188 [00:25<00:00,  7.31it/s]
  1%|          | 1/188 [00:00<00:23,  7.79it/s]

threshold:  0.3 		results:  0.600599600266489


100%|██████████| 188/188 [00:25<00:00,  7.24it/s]
  1%|          | 1/188 [00:00<00:24,  7.75it/s]

threshold:  0.35 		results:  0.6102598267821452


100%|██████████| 188/188 [00:27<00:00,  6.94it/s]
  1%|          | 1/188 [00:00<00:24,  7.66it/s]

threshold:  0.4 		results:  0.6209193870752832


100%|██████████| 188/188 [00:27<00:00,  6.88it/s]
  1%|          | 1/188 [00:00<00:27,  6.76it/s]

threshold:  0.45 		results:  0.6329113924050633


100%|██████████| 188/188 [00:27<00:00,  6.91it/s]
  0%|          | 0/188 [00:00<?, ?it/s]

threshold:  0.5 		results:  0.6435709526982012


100%|██████████| 188/188 [00:26<00:00,  7.14it/s]
  1%|          | 1/188 [00:00<00:24,  7.65it/s]

threshold:  0.55 		results:  0.6627248500999334


100%|██████████| 188/188 [00:24<00:00,  7.53it/s]
  1%|          | 1/188 [00:00<00:24,  7.71it/s]

threshold:  0.6 		results:  0.685709526982012


100%|██████████| 188/188 [00:25<00:00,  7.29it/s]
  1%|          | 1/188 [00:00<00:24,  7.53it/s]

threshold:  0.65 		results:  0.730346435709527


100%|██████████| 188/188 [00:25<00:00,  7.25it/s]
  1%|          | 1/188 [00:00<00:24,  7.57it/s]

threshold:  0.7 		results:  0.5


100%|██████████| 188/188 [00:26<00:00,  7.23it/s]
  1%|          | 1/188 [00:00<00:26,  7.05it/s]

threshold:  0.75 		results:  0.5


100%|██████████| 188/188 [00:25<00:00,  7.25it/s]
  1%|          | 1/188 [00:00<00:26,  7.16it/s]

threshold:  0.8 		results:  0.5


100%|██████████| 188/188 [00:25<00:00,  7.49it/s]
  1%|          | 1/188 [00:00<00:27,  6.89it/s]

threshold:  0.85 		results:  0.5


100%|██████████| 188/188 [00:28<00:00,  6.58it/s]
  1%|          | 1/188 [00:00<00:24,  7.73it/s]

threshold:  0.9 		results:  0.5


100%|██████████| 188/188 [00:25<00:00,  7.26it/s]
  1%|          | 1/188 [00:00<00:37,  5.03it/s]

threshold:  0.95 		results:  0.5


100%|██████████| 188/188 [00:26<00:00,  7.12it/s]

threshold:  1.0 		results:  0.5





In [13]:
# Fine sweep in steps of 0.01 over 0.60 ... 0.70 on the dev set.
for num in range(60, 71):
    threshold = num / 100.0
    results = evaluate_osv(
        model,
        dev_osv_dataset,
        nor2len_dict,
        overall_mid2label_dict,
        threshold=threshold,
    )
    print("threshold: ", threshold, "\t\tresults: ", results)

100%|██████████| 188/188 [00:25<00:00,  7.24it/s]
  0%|          | 0/188 [00:00<?, ?it/s]

threshold:  0.6 		results:  0.685709526982012


100%|██████████| 188/188 [00:27<00:00,  6.89it/s]
  1%|          | 1/188 [00:00<00:27,  6.82it/s]

threshold:  0.61 		results:  0.6930379746835443


100%|██████████| 188/188 [00:26<00:00,  6.98it/s]
  1%|          | 1/188 [00:00<00:29,  6.34it/s]

threshold:  0.62 		results:  0.6977015323117921


100%|██████████| 188/188 [00:25<00:00,  7.31it/s]
  1%|          | 1/188 [00:00<00:30,  6.23it/s]

threshold:  0.63 		results:  0.7060293137908061


100%|██████████| 188/188 [00:27<00:00,  6.81it/s]
  1%|          | 1/188 [00:00<00:27,  6.72it/s]

threshold:  0.64 		results:  0.7173550966022652


100%|██████████| 188/188 [00:25<00:00,  7.26it/s]
  1%|          | 1/188 [00:00<00:25,  7.24it/s]

threshold:  0.65 		results:  0.730346435709527


100%|██████████| 188/188 [00:25<00:00,  7.24it/s]
  1%|          | 1/188 [00:00<00:24,  7.73it/s]

threshold:  0.66 		results:  0.7383411059293804


100%|██████████| 188/188 [00:26<00:00,  7.13it/s]
  1%|          | 1/188 [00:00<00:24,  7.60it/s]

threshold:  0.67 		results:  0.7426715522984677


100%|██████████| 188/188 [00:26<00:00,  7.18it/s]
  1%|          | 1/188 [00:00<00:25,  7.36it/s]

threshold:  0.68 		results:  0.7380079946702198


100%|██████████| 188/188 [00:25<00:00,  7.27it/s]
  1%|          | 1/188 [00:00<00:28,  6.58it/s]

threshold:  0.69 		results:  0.6795469686875416


100%|██████████| 188/188 [00:26<00:00,  7.19it/s]

threshold:  0.7 		results:  0.5





In [14]:
# Score the simple test set with a fixed threshold of 0.67.
results = evaluate_osv(
    model,
    test_osv_dataset,
    nor2len_dict,
    overall_mid2label_dict,
    threshold=0.67,
)

100%|██████████| 147/147 [00:21<00:00,  6.84it/s]


In [15]:
# Show the accuracy returned by the preceding evaluate_osv call.
print(results)

0.7888888888888889
