In [None]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import json
import time
import os
from src.data_preparation.graph_construction import construct_graph
from src.metrics.supervised_metrics import get_rank_metrics, overlap
import pickle
from collections import Counter
from src.baselines.measure_mapping import measure_funcs
from src.data_preparation.prepare_heterogeneous_graph import prepare_heterogeneous_data, split_train_val_test, load_split_data

# FB15K

In [None]:
# Locations of the FB15K inputs, all relative to the shared data root.
data_path = "../heterogeneous_data"
graph_data = "FB15K/fb15k_rel.pk"
semantic_data = "FB15K/fb_lang.pk"
split_data = "idx_1000"

# Full paths handed to the data-preparation helpers in the following cells.
graph_data_path = f"{data_path}/{graph_data}"
semantic_data_path = f"{data_path}/{semantic_data}"
split_data_path = f"{data_path}/FB15K/datasets_split/{split_data}"

# Both values are forwarded unchanged to the split helpers further down.
# NOTE(review): num_split_idx presumably matches the "idx_1000" split file — confirm.
num_split_idx = 1000
train_num = 1


In [None]:
# Load the heterogeneous graph: edge array, node features and per-edge types.
# The helper returns a fourth value that this notebook does not use.
(full_graph_edges, feat, edge_types, _) = prepare_heterogeneous_data(
    graph_data_path,
    semantic_data_path,
    split_data_path,
    num_split_idx,
)

In [None]:
# Read the stored split description and the associated label indices.
# The spelling `dataset_spilt` is kept because the evaluation cell reads it.
(dataset_spilt, labels_idx) = load_split_data(
    split_data_path,
    num_split_idx,
)

In [None]:
# Show how many edges there are of each edge type (flattened so every entry is counted).
print(Counter(edge_types.reshape(-1)))

In [None]:
# Evaluate every baseline centrality measure on the subgraph induced by one
# edge type, over several train/val/test splits, and collect ranking metrics.
cross_num = 5  # number of splits drawn via split_train_val_test
et = 13        # edge type whose edges form the evaluated subgraph
top_k = 100    # third argument of get_rank_metrics / overlap (presumably a top-k cut-off — confirm)
# NOTE(review): the next two lists are never filled in this cell; kept in case
# other cells still reference them.
outdegree_metrics = []
pine_metrics = []
measure_names = ['out_degree', 'ilc', 'h_index', 'pagerank', 'voterank', 'enrenew', 'betweenness', 'eddc', 'pine']
metric_names = ['ndcg', 'spearman', 'overlap']
# superv_res[measure][metric] is a list holding one score per split.
superv_res = {measure: {metric: [] for metric in metric_names} for measure in measure_names}
for cross_id in range(cross_num):
    _, val_idx, test_idx, _, val_labels, test_labels, _, _ = split_train_val_test(dataset_spilt, labels_idx, train_num, num_split_idx)
    val_node_labels = dict(zip(val_idx, val_labels))
    test_node_labels = dict(zip(test_idx, test_labels))
    print('Val size', len(val_idx))
    print('Test size', len(test_idx))

    # Keep only the edges of type `et` and the nodes they touch.
    graph_edges = full_graph_edges[:, edge_types==et]
    subnodes = np.unique(np.concatenate([graph_edges[0, :], graph_edges[1, :]], axis=0))
    subfeat = feat[subnodes, :]

    # Relabel the subgraph nodes to the contiguous range 0..len(subnodes)-1.
    origin2new_dict = {subnodes[i]: i for i in range(len(subnodes))}
    new2origin_dict = {value: key for key, value in origin2new_dict.items()}
    f = lambda x: origin2new_dict[x]
    graph_edges = np.vectorize(f)(graph_edges)

    # NOTE(review): the graph inputs do not depend on the split; if
    # construct_graph and the measure functions are deterministic, this work
    # could be done once before the loop.
    G, tw, aw = construct_graph(subfeat, graph_edges)

    # Subgraph nodes that carry a test label. Membership is checked against the
    # label dict (keyed by test_idx) so each lookup is O(1) instead of a scan
    # over test_idx for every node.
    test_measure_nodes = [node for node in subnodes if node in test_node_labels]
    gt_labels = torch.tensor([test_node_labels[node] for node in test_measure_nodes])

    # For each test node, the number of edges in the FULL graph whose second-row
    # endpoint is that node. This does not depend on the measure, so it is
    # computed once per split rather than once per measure.
    second_row_endpoints = full_graph_edges[1, :]
    node_edge_counts = [np.sum(second_row_endpoints == node) for node in test_measure_nodes]

    for measure in measure_names:
        print(measure)
        measure_dict = measure_funcs[measure](G)
        # Score = measure value on the subgraph, weighted by the full-graph edge
        # count; nodes the measure did not score get 0.
        measure_labels = torch.tensor([
            measure_dict[origin2new_dict[node]] * edge_count if origin2new_dict[node] in measure_dict else 0
            for node, edge_count in zip(test_measure_nodes, node_edge_counts)
        ])

        # NOTE(review): the two helpers take predictions and ground truth in
        # opposite order here — confirm this matches their signatures.
        ndcg_value, spearman_value = get_rank_metrics(measure_labels, gt_labels, top_k, True)
        overlap_value = overlap(gt_labels, measure_labels, top_k)

        superv_res[measure]['ndcg'].append(ndcg_value)
        superv_res[measure]['spearman'].append(spearman_value)
        superv_res[measure]['overlap'].append(overlap_value)
        

In [None]:
# Print, for every (measure, metric) pair, the mean and standard deviation of
# the scores collected across the splits.
for measure in measure_names:
    per_metric_scores = superv_res[measure]
    for metric in metric_names:
        scores = per_metric_scores[metric]
        print(measure, metric, np.mean(scores), np.std(scores))