# Evaluate on OSV task

In [1]:
import pickle
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np
from scipy.stats import entropy
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_curve

import cs
from tools.utils import to_device, get_test_set_loader
from tasks.bert_classifier.utils import get_model, get_test_set, get_osv_set
from tasks.bert_classifier.train import train
from tasks.bert_classifier.predict_new import bert_classifier_predict, bert_classifier_validation, bert_classifier_predict_probs
from tasks.bert_classifier.predict_new import bert_classifier_test_detail, bert_classifier_test_detail_result

In [2]:
id_to_cls = pickle.load(open('/home/jxqi/ACL/experiment/ann/dataset/id_to_cls.pkl', 'rb'))
cls_to_id = { v:k for k,v in id_to_cls.items() }

## Load Model

In [3]:
model = get_model(last_training_time=1622731763, last_step='02971')

In [4]:
save_test_root = '/home/datamerge/ACL/Data/210422/test/'
save_dev_root = '/home/datamerge/ACL/Data/210422/dev/'
save_pkl_root = '/home/datamerge/ACL/Data/210422/pkl/'

dev_osv_filepath = save_dev_root+'dev_osv.txt'
test_osv_filepath = save_test_root+'test_osv.txt'
nor2len_dict = pickle.load(open(save_pkl_root+'210422_nor2len_dict.pkl', 'rb'))

In [5]:
dev_osv_dataset_first, dev_osv_dataset_second = get_osv_set(dev_osv_filepath)
test_osv_dataset_first, test_osv_dataset_second = get_osv_set(test_osv_filepath)

## Test

In [6]:
def js_divergence(p, q):
    m = (p + q) / 2
    js = entropy(p, m, axis=-1) / 2 + entropy(q, m, axis=-1) / 2
    return js

def report_osv(true, pred):
    return accuracy_score(true, pred)

In [9]:
def evaluate_osv(model, osv_dataset_first, osv_dataset_second, threshold):
    '''
    Input:
        model: model need to evaluate
        osv_dataset_first: the first colume of osv dataset
        osv_dataset_second:  the second colume of osv dataset
        threshold: the threshold to predict whether a pair of samples from one class
    Output:
        
    '''
    prob_first, _ = bert_classifier_predict_probs(model, dev_osv_dataset_first)
    prob_second, labels = bert_classifier_predict_probs(model, dev_osv_dataset_second)
    results = []
    first_probs, second_probs = prob_first, prob_second
    
    for i in range(first_probs.shape[0]):
        first_prob, second_prob = first_probs[i, :], second_probs[i, :]
        result = js_divergence(first_prob, second_prob)
        results.append(result)
        
    results = np.array(results)
    judgements = results < threshold
    
    acc = report_osv(labels, judgements)
        
    return acc

To get the perfomence of test set. We first evalute it on dev set to get a optimum threshold. And then use this threshold to predict test set perfomence.

In [10]:
acc = evaluate_osv(model, dev_osv_dataset_first, dev_osv_dataset_second, threshold=0.5)

100%|██████████| 24/24 [00:28<00:00,  1.17s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:28<00:00,  1.19s/it]


(6004, 25129) (6004,)


In [11]:
acc

0.7370086608927382

In [12]:
for num in range(0, 105, 5):
    threshold = num/100.0
    results = evaluate_osv(model, dev_osv_dataset_first, dev_osv_dataset_second, threshold)
    print("threshold: ", threshold, "\t\tresults: ", results)

100%|██████████| 24/24 [00:28<00:00,  1.19s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:28<00:00,  1.20s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.0 		results:  0.5


100%|██████████| 24/24 [00:28<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.05 		results:  0.5413057961359093


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.1 		results:  0.5659560293137909


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:28<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.15 		results:  0.5877748167888075


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.2 		results:  0.610093271152565


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.25 		results:  0.6292471685542972


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:28<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.3 		results:  0.6517321785476349


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.35 		results:  0.6698867421718854


100%|██████████| 24/24 [00:28<00:00,  1.20s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.4 		results:  0.6897068620919387


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.45 		results:  0.7133577614923384


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.5 		results:  0.7370086608927382


100%|██████████| 24/24 [00:28<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.22s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.55 		results:  0.7641572285143238


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.22s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.6 		results:  0.7776482345103265


100%|██████████| 24/24 [00:29<00:00,  1.22s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.22s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.65 		results:  0.7141905396402398


100%|██████████| 24/24 [00:28<00:00,  1.20s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:28<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.7 		results:  0.5


100%|██████████| 24/24 [00:28<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.22s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.75 		results:  0.5


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.8 		results:  0.5


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.85 		results:  0.5


100%|██████████| 24/24 [00:29<00:00,  1.22s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.9 		results:  0.5


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.22s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.95 		results:  0.5


100%|██████████| 24/24 [00:28<00:00,  1.21s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:29<00:00,  1.21s/it]


(6004, 25129) (6004,)
threshold:  1.0 		results:  0.5


In [None]:
for num in range(45, 65+1, 1):
    threshold = num/100.0
    results = evaluate_osv(model, dev_osv_dataset_first, dev_osv_dataset_second, threshold)
    print("threshold: ", threshold, "\t\tresults: ", results)

100%|██████████| 24/24 [00:28<00:00,  1.18s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:28<00:00,  1.20s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.45 		results:  0.7133577614923384


100%|██████████| 24/24 [00:28<00:00,  1.20s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:28<00:00,  1.20s/it]


(6004, 25129) (6004,)


  0%|          | 0/24 [00:00<?, ?it/s]

threshold:  0.46 		results:  0.7175216522318454


 92%|█████████▏| 22/24 [00:27<00:02,  1.25s/it]

In [14]:
results = evaluate_osv(model, test_osv_dataset_first, test_osv_dataset_second, threshold=0.59)
print("threshold: ", 0.59, "\t\tresults: ", results)

100%|██████████| 24/24 [00:28<00:00,  1.18s/it]
  0%|          | 0/24 [00:00<?, ?it/s]

(6004, 25129) (6004,)


100%|██████████| 24/24 [00:28<00:00,  1.19s/it]


(6004, 25129) (6004,)
threshold:  0.59 		results:  0.7786475682878081
