# Evaluate BERT for CSC

In [1]:
import torch
import pickle
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import Trainer, TrainingArguments
from transformers import BertTokenizer, BertTokenizerFast, BertForSequenceClassification
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from utils.data_helper import read_mag_file
from utils.lazydataset import LazyTextMAG_Dataset

In [2]:
save_pkl_root = '/home/datamerge/ACL/Data/210422/pkl/'
save_train_root = '/home/datamerge/ACL/Data/210422/train/'
save_test_root = '/home/datamerge/ACL/Data/210422/test/'
save_open_root = '/home/datamerge/ACL/Data/210422/open/'
save_dev_root = '/home/datamerge/ACL/Data/210422/dev/'

afid2nor = pickle.load(open(save_pkl_root+"afid2nor.pkl", "rb"))
nor2afid = pickle.load(open(save_pkl_root+"nor2afid.pkl", "rb"))

In [3]:
nor2len_dict = pickle.load(open(save_pkl_root+'210422_nor2len_dict.pkl', 'rb'))

train_mid2label_dict = pickle.load(open(save_pkl_root+'train_mid2label_dict.pkl', 'rb'))
train_label2mid_dict = pickle.load(open(save_pkl_root+'train_label2mid_dict.pkl', 'rb'))

In [4]:
train_filepath = save_train_root+'train_part.txt'
dev_filepath = save_dev_root+'dev.txt'
test_filepath = save_test_root+'test.txt'

In [5]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

dev_dataset = LazyTextMAG_Dataset(tokenizer, dev_filepath, train_label2mid_dict)
test_dataset = LazyTextMAG_Dataset(tokenizer, test_filepath, train_label2mid_dict)

# dev_dataset = LazyTextMAG_Dataset(tokenizer, dev_filepath, train_label2mid_dict, block_size=64)
# test_dataset = LazyTextMAG_Dataset(tokenizer, test_filepath, train_label2mid_dict, block_size=64)

In [6]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

class BertForAffiliationNameNormalization(torch.nn.Module):
    
    def __init__(self, num_of_classes):
        super(BertForAffiliationNameNormalization, self).__init__()
        self.num_of_classes = num_of_classes
        self.bert = BertModel.from_pretrained('bert-base-uncased').to(device)
        self.dropout = nn.Dropout(p=0.1, inplace=False).to(device)
        self.classifier = nn.Linear(self.bert.config.hidden_size, self.num_of_classes, bias=True).to(device)
        
        
    def forward(self, input_ids, attention_mask):
        pooled_out = self.bert(input_ids, attention_mask=attention_mask)
        pooled_out = self.dropout(pooled_out[1])
        logits = self.classifier(pooled_out)
        
        return logits

In [7]:
model = torch.load('./checkpoint0422/After_epoch_79_bert.pkl')

In [19]:
def report(true, pred):
    all_labels = list(set(true))
    a = accuracy_score(true, pred) * 100
    p = precision_score(true, pred, average="macro", labels=all_labels, zero_division=0) * 100
    r = recall_score(true, pred, average="macro", labels=all_labels) * 100
    f = f1_score(true, pred, average="macro", labels=all_labels, zero_division=0) * 100
    return a, p, r, f


def calc_split(low_margin, high_margin, true, pred, test_set_size):
    low = test_set_size < low_margin
    high = test_set_size > high_margin
    mid = np.logical_and(test_set_size >= low_margin, test_set_size <= high_margin)

    r1 = report(true[high], pred[high])
    r2 = report(true[mid], pred[mid])
    r3 = report(true[low], pred[low])
    
    return r1, r2, r3

In [20]:
def evaluate_test(model, dataset, nor2len_dict, train_mid2label_dict):
    model.eval()
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
#     model = nn.DataParallel(model, device_ids=[0,1,2,3])
#     model.to(torch.device('cuda:1'))
    
    test_dataset = dataset
    loader = torch.utils.data.DataLoader(test_dataset, batch_size=128)
    true = []
    pred = []
    test_set_size = []
    
    for batch in tqdm(loader):
        input_ids = torch.cat([i.reshape(1,-1) for i in batch['input_ids']], dim=0).to(device)
        attention_mask = torch.cat([i.reshape(1,-1) for i in batch['attention_mask']], dim=0).to(device)
        label = batch['label'].to(device)
        logits = model(input_ids, attention_mask=attention_mask)
        preds = logits.argmax(-1)
        
        tmp_test_size = [nor2len_dict[train_mid2label_dict[label_id.item()]]  for label_id in label]
        test_set_size = test_set_size + tmp_test_size
        
        true.append(label.to(torch.device('cpu')))
        pred.append(preds[:len(label)].to(torch.device('cpu')))

    pred = torch.cat(pred).numpy()
    true = torch.cat(true).numpy()    
    
    test_set_size = np.array(test_set_size)
    acc, precision, recall, f1= report(true, pred)
    overall = {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }
    high, middle, few = calc_split(5, 20, true, pred, test_set_size)
    part = {
        'high':{
            'accuracy': high[0],
            'precision': high[1],
            'recall': high[2],
            'f1': high[3],
        },
        'middle':{
            'accuracy': middle[0],
            'precision': middle[1],
            'recall': middle[2],
            'f1': middle[3],
        }, 
        'few':{
            'accuracy': few[0],
            'precision': few[1],
            'recall': few[2],
            'f1': few[3],
        },     
    }
    return (overall, part)

In [21]:
dev_overall, dev_part = evaluate(model, dev_dataset, nor2len_dict, train_mid2label_dict)

100%|██████████| 426/426 [01:13<00:00,  5.83it/s]


Accuracy: 85.308%
Macro Avg Precision: 66.874%
Macro Avg Recall: 69.666%


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Macro Avg F1_score: 67.547%


In [13]:
print(dev_overall)

{'accuracy': 85.30832675104246, 'f1': 68.01655447424996, 'precision': 67.33860369578946, 'recall': 70.15044163355083}


In [14]:
from pprint import pprint

pprint(dev_part)

{'few': {'accuracy': 42.969334330590875,
         'f1': 42.87317620650954,
         'precision': 42.83576505798728,
         'recall': 42.96670407781519},
 'high': {'accuracy': 89.91820306413918,
          'f1': 82.08666408399763,
          'precision': 82.57111870924956,
          'recall': 83.31117209166113},
 'middle': {'accuracy': 67.33297316765712,
            'f1': 67.00544742318522,
            'precision': 66.96229648671807,
            'recall': 67.31852123882973}}


In [16]:
test_overall, test_part = evaluate(model, test_dataset, nor2len_dict, train_mid2label_dict)

100%|██████████| 455/455 [01:20<00:00,  5.66it/s]


In [17]:
print(test_overall)

{'accuracy': 83.30295422498882, 'f1': 62.793747052625946, 'precision': 61.73278962939942, 'recall': 65.47052312326386}


In [18]:
from pprint import pprint

pprint(test_part)

{'few': {'accuracy': 40.99082568807339,
         'f1': 40.75298438934802,
         'precision': 40.615243342516074,
         'recall': 41.02846648301194},
 'high': {'accuracy': 90.01760227345027,
          'f1': 82.09937202259643,
          'precision': 82.33997433157705,
          'recall': 83.43299147360005},
 'middle': {'accuracy': 67.80760223383174,
            'f1': 67.55539233688333,
            'precision': 67.5174439955931,
            'recall': 67.83266005631044}}
