In [11]:
import torch
import json
import numpy as np
from evaluation import *
from NN_Models import *
import random
from sklearn.metrics import accuracy_score, recall_score, f1_score
from sklearn.model_selection import train_test_split

In [13]:
def _load_test_array(path):
    """Load a tensor saved with ``torch.save`` and return it as a NumPy array.

    ``map_location="cpu"`` forces the tensor onto the CPU: ``torch.load``
    otherwise restores it to the device it was saved from, and ``.numpy()``
    raises for non-CPU tensors.
    """
    return torch.load(path, map_location="cpu").detach().numpy()


# Per-source model outputs on the test set, consumed by `predict` below.
year_venue_test = _load_test_array("outputs/year_venue_test.pt")          # year & venue
abstract_title_test = _load_test_array("outputs/abstract_title_test.pt")  # abstract & title
author_test = _load_test_array("outputs/author_test.pt")                  # coauthor

# Ground-truth labels (passed as y_true to f1_score); kept as the loaded object,
# only moved to CPU so sklearn can consume it.
y_test = torch.load('data/y_test.pt', map_location="cpu")

## Test

#### Grid Search 
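Use a grid search over the weights of the three model outputs (co-author, year & venue, abstract & title) and the decision threshold to find the combination with the highest F1 score. Note that the search is run on the test set, so the best F1 found here is an optimistic estimate.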

In [14]:
def grid_search(weight1, weight2, threshold):
    """Lazily enumerate every (weight1, weight2, threshold) combination.

    The Cartesian product is walked with ``weight1`` varying slowest and
    ``threshold`` varying fastest.

    Args:
        weight1: candidate values for the first weight.
        weight2: candidate values for the second weight.
        threshold: candidate decision thresholds.

    Yields:
        ``(w1, w2, t)`` tuples, one per combination.
    """
    yield from (
        (first, second, cutoff)
        for first in weight1
        for second in weight2
        for cutoff in threshold
    )

In [15]:
# Grid-search the blend of the three output scores plus the decision threshold.
# The co-author and year/venue weights are searched explicitly; the sentence
# (abstract & title) weight takes the remainder so the three weights sum to 1.
#
# NOTE(review): the weights and threshold are selected by maximising F1 on the
# test split itself (y_test), so the best F1 reported here is an optimistic,
# tuned-on-test estimate. A separate validation split would give an unbiased score.
#
# Alternative feature variants referenced earlier in development:
# year_venue_nb_test (for year_venue) and abstract_title_doc2vec_test
# (for abstracts_title).

weight1 = np.linspace(0, 1, 20)
weight2 = np.linspace(0, 1, 20)

thresholds = [0.02, 0.05, 0.1, 0.2, 0.3, 0.5, 0.6]
total = len(weight1) * len(weight2) * len(thresholds)

# linspace values are not exact multiples of 1/19, so pairs that should sum to
# exactly 1 (e.g. 10/19 + 9/19) can come out as 1.0000000000000002. Without a
# tolerance, those boundary points (SENTENCE_WEIGHT == 0) would be skipped.
WEIGHT_SUM_TOL = 1e-9

# Start below any achievable score so the first evaluated combination is always
# recorded, even if every F1 turns out to be 0.
max_f1 = -np.inf
max_param = None

# NOTE(review): `tqdm` is not imported in this notebook's visible import cell;
# presumably it comes from one of the star imports — confirm.
for w1, w2, thred in tqdm(grid_search(weight1, weight2, thresholds), total=total):

    if w1 + w2 > 1 + WEIGHT_SUM_TOL:
        continue

    # Clamp tiny negative rounding residue to exactly 0.
    w3 = max(0.0, 1.0 - w1 - w2)

    y_pred = predict(
            author=author_test,
            COAUTHOR_WEIGHT=w1,
            year_venue=year_venue_test,
            YEAR_VENUE_WEIGHT=w2,
            abstracts_title=abstract_title_test,
            SENTENCE_WEIGHT=w3,
            THRESHOLD=thred
        )

    f1 = f1_score(y_test, y_pred, average='samples', zero_division=1)

    if f1 > max_f1:
        max_f1 = f1
        max_param = (w1, w2, w3, thred)

if max_param is None:
    raise RuntimeError("grid search evaluated no (weight, threshold) combinations")

print("Max f1 score       : ", round(max_f1, 4))
COAUTHOR_WEIGHT, YEAR_VENUE_WEIGHT, SENTENCE_WEIGHT, THRESHOLD = max_param
print("COAUTHOR_WEIGHT    : ", round(COAUTHOR_WEIGHT, 10))
print("YEAR_VENUE_WEIGHT  : ", round(YEAR_VENUE_WEIGHT, 10))
print("SENTENCE_WEIGHT    : ", round(SENTENCE_WEIGHT, 10))
print("THRESHOLD          : ", THRESHOLD)

100%|██████████| 2800/2800 [13:14<00:00,  3.52it/s]

Max f1 score       :  0.6605
COAUTHOR_WEIGHT    :  0.5789473684
YEAR_VENUE_WEIGHT  :  0.3157894737
SENTENCE_WEIGHT    :  0.1052631579
THRESHOLD          :  0.3



