In [None]:
from tqdm import tqdm
from collections import Counter

import pandas as pd
import numpy as np
import wandb
import networkx as nx

import torch
import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.utils import to_networkx, from_networkx


from sklearn.metrics import mean_absolute_error
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

General data:

In [None]:
# Every data path in this notebook is built relative to this root.
path_2_root = "../"

# The "papers_indices" column is stored as the text of a list of quoted ids;
# the converter strips the brackets and quotes and splits it back into a list of id strings.
authors_edges_papers_general = pd.read_csv(
    path_2_root + "processed_data/SSORC_CS_2010_2021_authors_edges_papers_indices.csv",
    index_col=0,
    converters={"papers_indices": lambda cell: cell.strip("[]").replace("'", "").split(", ")},
)
authors_edges_general = pd.read_csv(
    path_2_root + "processed_data/SSORC_CS_2010_2021_authors_edge_list.csv",
    index_col=0,
)
papers_features_general = pd.read_csv(
    path_2_root + "processed_data/SSORC_CS_2010_2021_papers_features_vectorized_compressed_32.csv",
    index_col=0,
)
authors_features_general = pd.read_csv(
    path_2_root + "processed_data/SSORC_CS_2010_2021_authors_features.csv",
    index_col=0,
)

# Map each (first column, second column) pair of the general edge list to its row position.
aev = authors_edges_general.values
edge_to_index = {}
for row_position in tqdm(range(len(aev))):
    edge_to_index[(aev[row_position][0], aev[row_position][1])] = row_position

Local data:

In [None]:
dataset_name = 'SSORC_CS_10_21_1437_3164_unfiltered'

# All files of one dataset live in datasets/<name>/ and are prefixed with "<name>_".
_dataset_stem = path_2_root + "datasets/" + dataset_name + "/" + dataset_name + "_"

# Same list-as-text decoding of "papers_indices" as for the general table.
authors_edges_papers = pd.read_csv(
    _dataset_stem + "authors_edges_papers_indices.csv",
    index_col=0,
    converters={"papers_indices": lambda cell: cell.strip("[]").replace("'", "").split(", ")},
)
# Both graphs are read as directed; no node type is given, so node ids are strings.
authors_graph = nx.read_edgelist(_dataset_stem + "authors.edgelist", create_using=nx.DiGraph)
citation_graph = nx.read_edgelist(_dataset_stem + "papers.edgelist", create_using=nx.DiGraph)
papers_targets = pd.read_csv(_dataset_stem + "papers_targets.csv", index_col=0)

### Data preparation

In [None]:
# Identifier of the split family (used only to build the directory name).
splits = "5_0.1"
path = path_2_root + "datasets/" + dataset_name + "/split_" + str(splits) + "/"

# Which pre-computed split of that family to load.
split = 4

# NOTE(review): these files are loaded as whole objects (attributes such as
# `.edge_index` are read from them later). Recent PyTorch releases changed the
# default of `torch.load` to `weights_only=True` -- confirm they still load in your environment.
train_data_a, val_data_a, test_data_a = (
    torch.load(path + dataset_name + '_' + part + '_sample_' + str(split) + '.data')
    for part in ('train', 'val', 'test')
)

In [None]:
# Graph nodes come back from the edge lists as strings, while the feature tables
# are looked up with integer ids -- hence the int() conversions in this cell.
papers_nodes = [int(node) for node in citation_graph.nodes]
# Select rows by label: the per-node lookups below (and the later subgraph cell)
# all use `.loc`, so a positional `.iloc` here would only agree with them when
# the index happens to be 0..N-1.
papers_node_features = papers_features_general.loc[papers_nodes]
for node in tqdm(citation_graph.nodes):
    citation_graph.nodes[node]['x'] = list(papers_node_features.loc[[int(node)]].values[0])

authors_nodes = [int(node) for node in authors_graph.nodes]
authors_node_features = authors_features_general.loc[authors_nodes]
for node in tqdm(authors_graph.nodes):
    authors_graph.nodes[node]['x'] = list(authors_node_features.loc[[int(node)]].values[0])

# Convert both attributed graphs to PyG data objects.
data_author = from_networkx(authors_graph)
data_citation = from_networkx(citation_graph)

# The three author splits all receive the full author feature matrix.
train_data_a.x, val_data_a.x, test_data_a.x = data_author.x.float(), data_author.x.float(), data_author.x.float()
data_citation.x = data_citation.x.float()

# Position in the author graph's node iteration order -> original integer author id.
# assumes from_networkx numbers nodes in that same iteration order -- TODO confirm
original_a_nodes = list(authors_graph.nodes)
pyg_id_2_original_id = {i: int(original_a_nodes[i]) for i in range(len(original_a_nodes))}

# Edges of the training split, re-expressed as pairs of original author ids.
sAe_t = train_data_a.edge_index.cpu().numpy().T
sAe_t = [(pyg_id_2_original_id[int(src)], pyg_id_2_original_id[int(dst)]) for src, dst in sAe_t]

# Paper lists attached to each training edge.
# NOTE(review): `edge_to_index` holds row positions of the *general* edge list, but
# is used here as a label into the *local* `authors_edges_papers` table -- confirm
# the local table keeps the general row numbers as its index.
authors_edges_papers_sub_2t = [authors_edges_papers["papers_indices"][edge_to_index[edge]] for edge in tqdm(sAe_t)]
authors_edges_papers_sub_flat_2t = [str(item) for subarray in authors_edges_papers_sub_2t for item in subarray]
unique_papers_2t = list(set(authors_edges_papers_sub_flat_2t))

In [None]:
# Restrict the citation graph to the papers that appear on training edges.
citation_graph_sub = citation_graph.subgraph(unique_papers_2t)
citation_graph_sub_nodes = list(citation_graph_sub.nodes())
# Node id (string) -> position in the corresponding node list.
global_to_local_id_citation = {node: position for position, node in enumerate(citation_graph_sub_nodes)}
authors_graph_sub_nodes = list(authors_graph.nodes())
global_to_local_id_authors = {node: position for position, node in enumerate(authors_graph_sub_nodes)}

# Local author position -> set of local paper positions, accumulated over every
# training edge the author takes part in.
authors_to_papers = dict()
for i in tqdm(range(len(sAe_t))):
    local_papers = [global_to_local_id_citation[paper] for paper in authors_edges_papers_sub_2t[i]]
    for author in sAe_t[i]:
        # The dict is keyed by the *local* author id, so the existence check must use
        # that key as well. Testing the original id instead would re-create the set on
        # (almost) every edge and throw away the papers collected from earlier edges.
        local_author = global_to_local_id_authors[str(author)]
        authors_to_papers.setdefault(local_author, set()).update(local_papers)

In [None]:
# Attach the feature row of every paper in the training subgraph, then rebuild
# `data_citation` from that subgraph only.
for paper_id in tqdm(citation_graph_sub.nodes):
    feature_row = papers_features_general.loc[[int(paper_id)]].values[0]
    citation_graph_sub.nodes[paper_id]['x'] = list(feature_row)
data_citation = from_networkx(citation_graph_sub)
data_citation.x = data_citation.x.float()

# Author ids as integers and the matching rows of the author feature table.
authors_nodes = [int(author_id) for author_id in list(authors_graph.nodes)]
authors_node_features = authors_features_general.loc[authors_nodes]

### Computing auxiliary targets

In [None]:
data_author = from_networkx(authors_graph)
# Edges of the author graph as (source, target) tuples of plain ints, in edge_index order.
edges_ordered = [(int(src), int(dst)) for src, dst in data_author.edge_index.t().tolist()]
index_to_edge = dict(enumerate(edges_ordered))
authors_edges_papers_sample = authors_edges_papers_sub_2t
citation_nodes = list(citation_graph_sub.nodes)
# One dict lookup per paper instead of a linear `list.index` scan per paper.
citation_position = {node: position for position, node in enumerate(citation_nodes)}

# For every sample i: the positions (within `citation_nodes`) of its papers, and its owner index.
# NOTE(review): the owner of sample i is taken to be index i, which is later resolved
# through `index_to_edge` (edge order of `data_author`), whereas the samples themselves
# follow the edge order of the training split -- confirm the two orderings agree.
ownership_dict = {}
inds_dict = {}
for i in tqdm(range(len(authors_edges_papers_sample))):
    ownership_dict[i] = i
    inds_dict[i] = [citation_position[paper] for paper in authors_edges_papers_sample[i]]

embs_dict = inds_dict
lens = set(len(paper_positions) for paper_positions in embs_dict.values())

# Group samples by the number of papers they contain, so every batch is rectangular.
batch_dict_x = {}
batch_dict_owner = {}
batch_dict_ind = {}
for i in tqdm(range(len(embs_dict))):
    size = len(embs_dict[i])
    batch_dict_x.setdefault(size, []).append(embs_dict[i])
    batch_dict_owner.setdefault(size, []).append(ownership_dict[i])
    batch_dict_ind.setdefault(size, []).append(i)

# Replace owner indices by the corresponding (source, target) edge tuples.
for length in batch_dict_owner:
    batch_dict_owner[length] = [index_to_edge[owner] for owner in batch_dict_owner[length]]

batch_list_x = list(batch_dict_x.values())
batch_list_owner = list(batch_dict_owner.values())
batch_list_ind = list(batch_dict_ind.values())

# np.asarray accepts both the DataFrame loaded earlier and an array left over from a
# previous execution of this cell, so re-running the cell no longer fails on `.values`.
papers_targets = np.asarray(papers_targets)

# One 5-element target per sample, in batch order:
#   [max of column 0, mean of column 1, mean of column 2, mean of column 3, number of papers]
# NOTE(review): rows of `papers_targets` are addressed with the positions stored in
# `batch_list_x` -- confirm the target table is ordered accordingly.
aux_targets = []
for i in tqdm(range(len(batch_list_x))):
    for paper_positions in batch_list_x[i]:
        values = np.array([papers_targets[position] for position in paper_positions]).T
        aux_targets.append([
            max(values[0]),
            sum(values[1]) / len(values[1]),
            sum(values[2]) / len(values[2]),
            sum(values[3]) / len(values[3]),
            len(values[0]),
        ])

# Owners flattened in the same batch order, then paired with their targets.
batch_list_owner_flat = [edge for batch in batch_list_owner for edge in batch]
aux_target_dict = {batch_list_owner_flat[k]: target for k, target in enumerate(tqdm(aux_targets))}

# Supervision edges of each split as a list of (source, target) tuples.
train_edges_aux_t, val_edges_aux_t, test_edges_aux_t = (
    [(pair[0], pair[1]) for pair in split_data.edge_label_index.cpu().numpy().T]
    for split_data in (train_data_a, val_data_a, test_data_a)
)

def get_aux_targets(train_edges_aux_t: list) -> list:
    """Look up the auxiliary target of every edge in `train_edges_aux_t`.

    Edges present in the module-level `aux_target_dict` get their stored
    target; all other edges get a fresh `[0, 0, 0, 0, 0]`. The result has one
    entry per input edge, in input order.
    """
    return [
        aux_target_dict[edge] if edge in aux_target_dict else [0, 0, 0, 0, 0]
        for edge in train_edges_aux_t
    ]

def task_split(aux_train_targets, device=None):
    """Split per-edge auxiliary targets into one float32 tensor per task.

    Parameters
    ----------
    aux_train_targets : sequence of 5-element sequences
        One `[y_q, y_sjr, y_h, y_if, y_n]` row per edge.
    device : torch.device or str, optional
        Where to place the tensors. Defaults to CUDA when it is available and
        to the CPU otherwise, so the function no longer crashes on hosts
        without a GPU.

    Returns
    -------
    tuple of five 1-D float32 tensors `(y_q, y_sjr, y_h, y_if, y_n)`, each with
    one value per input row (empty tensors for empty input).

    Raises
    ------
    ValueError
        If the rows do not have exactly five entries.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    targets = np.asarray(aux_train_targets, dtype=np.float32)
    if targets.size == 0:
        # No edges at all: keep the five-task layout with zero rows.
        targets = targets.reshape(0, 5)
    if targets.ndim != 2 or targets.shape[1] != 5:
        raise ValueError("expected rows of 5 auxiliary targets, got array of shape " + str(targets.shape))

    # torch.tensor copies, so the result never aliases the caller's data.
    stacked = torch.tensor(targets, dtype=torch.float32)
    return tuple(column.contiguous().to(device) for column in stacked.t())

# Auxiliary targets for the supervision edges of each split.
aux_train_targets = get_aux_targets(train_edges_aux_t)
aux_val_targets = get_aux_targets(val_edges_aux_t)
aux_test_targets = get_aux_targets(test_edges_aux_t)

# Only the train and test splits get their per-task tensors attached;
# the validation targets are computed above but not attached here.
train_aux_y = task_split(aux_train_targets)
test_aux_y = task_split(aux_test_targets)
train_data_a.aux = train_aux_y
test_data_a.aux = test_aux_y

### Model

In [None]:
%load_ext autoreload
%autoreload 2
from REIGNN import REIGNN

In [None]:
# W&B parameters
project_name = 'GAT_0.3'
prefix = "standard_aux_st_convs"   # becomes part of the W&B group name
entity = "sbergraphs"
wandb_output = False               # set to True to log runs to Weights & Biases

# Global
epochs_per_launch = 15000
lr = 0.001
# Keep the CPU fallback instead of unconditionally requesting the first GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Local
heads = 1

# Citation-graph encoder: number of layers and latent size
c_conv_num = 2
c_latent_size = 128

# Author-graph encoder: number of layers and latent size
a_conv_num = 3
a_latent_size = 384

operator = "hadamard"
link_size = 128

# Multitask weights: one entry per experiment; each entry holds the four
# coefficients of the auxiliary losses, consumed by `train` as `mt_coeffs`.
mt_weights = [[0.05, 0.05, 0.05, 0.05]]

def train(model, optimizer, criterion, mt_coeffs = [0, 0, 0, 0], main_coeff = 0.1):
    """Run one full-batch optimisation step on `model.train_data_a`.

    The loss is `main_coeff` times the binary cross-entropy of the link logits
    plus, for the edges the model currently scores as positive (logit > 0),
    the four auxiliary regression losses weighted by `mt_coeffs`.

    Parameters
    ----------
    model : callable module exposing `train_data_a`; called as
        `model(sample, batch_list_x, batch_list_owner, operator)` and expected
        to return `(z, z_sjr, z_hi, z_ifact, z_numb)`.
    optimizer : optimizer over the model parameters.
    criterion : loss used for the auxiliary regressions.
    mt_coeffs : four weights, applied in order to the `z_sjr`, `z_hi`,
        `z_ifact` and `z_numb` heads against `sample.aux[1..4]`.
    main_coeff : weight of the link-prediction loss.

    Returns
    -------
    The total loss tensor of this step.

    Reads the module-level `batch_list_x`, `batch_list_owner` and `operator`.
    """
    model.train()
    optimizer.zero_grad()
    sample = model.train_data_a
    z, z_sjr, z_hi, z_ifact, z_numb = model(sample, batch_list_x, batch_list_owner, operator)
    link_labels = sample.edge_label
    loss = main_coeff * F.binary_cross_entropy_with_logits(z, link_labels)

    # Auxiliary heads are only supervised on edges predicted as links.
    mask = z > 0
    # With no predicted links the masked tensors are empty and a mean-reduced
    # criterion would yield NaN, poisoning the whole step -- skip the terms instead.
    if mask.any():
        # Each head is paired with its own target; the paper-count target
        # (aux[4]) belongs to z_numb, not to z_ifact.
        aux_pairs = (
            (z_sjr, sample.aux[1]),
            (z_hi, sample.aux[2]),
            (z_ifact, sample.aux[3]),
            (z_numb, sample.aux[4]),
        )
        for k, (prediction, target) in enumerate(aux_pairs):
            loss = loss + mt_coeffs[k] * criterion(prediction[mask], target[mask])

    loss.backward()
    optimizer.step()
    return loss
    
@torch.no_grad()
def test(model, optimizer, criterion):
    """Evaluate `model` on its train and test samples without tracking gradients.

    Returns `(perfs, aux)`; both lists hold one entry for the train sample
    followed by one for the test sample:
      * `perfs[k]` is `[accuracy, F1, ROC AUC]` of the link predictions
        (hard predictions are the sigmoid probabilities rounded),
      * `aux[k]` is the mean absolute error of the four auxiliary outputs
        against `sample.aux[1]` .. `sample.aux[4]`.

    `optimizer` and `criterion` are accepted but not used. Reads the
    module-level `batch_list_x`, `batch_list_owner` and `operator`.
    """
    model.eval()
    perfs, aux = [], []

    for sample in (model.train_data_a, model.test_data_a):
        z, z_sjr, z_hi, z_ifact, z_numb = model(sample, batch_list_x, batch_list_owner, operator)

        # Regression quality of each auxiliary output against its own target.
        aux_outputs = (z_sjr, z_hi, z_ifact, z_numb)
        aux.append([
            mean_absolute_error(sample.aux[position + 1].cpu(), output.cpu())
            for position, output in enumerate(aux_outputs)
        ])

        # Link-prediction quality.
        labels_cpu = sample.edge_label.cpu()
        probs_cpu = z.sigmoid().cpu()
        hard_cpu = probs_cpu.round()
        perfs.append([
            accuracy_score(labels_cpu, hard_cpu),
            f1_score(labels_cpu, hard_cpu),
            roc_auc_score(labels_cpu, probs_cpu),
        ])
    return perfs, aux

In [None]:
def run(project_name, group, entity, mt_weights, model, optimizer, criterion, i):
    """Train `model` for `epochs_per_launch` epochs, evaluating every 10th epoch.

    Parameters
    ----------
    project_name, group, entity : W&B project, run group and entity (only
        used when the module-level `wandb_output` is true).
    mt_weights : the four auxiliary-loss coefficients forwarded to `train`.
    model, optimizer, criterion : forwarded to `train` and `test`.
    i : launch number; the W&B run is named `group + "_" + str(i)`.

    Returns
    -------
    dict with the best test metrics seen over all evaluations: the maxima of
    accuracy / F1 / ROC AUC and the minima of the four auxiliary MAEs. The
    caller may ignore it; it makes the tracked bests available even when
    W&B logging is switched off.

    Metrics are tracked and logged only on epochs where they are actually
    recomputed, in a single `wandb.log` call that also records the epoch.
    Reads the module-level `wandb_output`, `epochs_per_launch`, `train` and `test`.
    """
    if wandb_output:
        # Name the run at creation time instead of renaming it afterwards.
        wandb.init(project=project_name, entity=entity, group=group, name=group + "_" + str(i))

    max_acc_test, max_f1_test, max_roc_auc_test = 0, 0, 0
    min_mae_sjr_test = min_mae_h_index_test = float("inf")
    min_mae_impact_factor_test = min_mae_number_test = float("inf")

    for epoch in tqdm(range(epochs_per_launch)):
        train(model, optimizer, criterion, mt_weights)

        if epoch % 10 != 0:
            continue

        # metrics[0] / metrics[1] are the train / test entries; same for metrics_aux.
        metrics, metrics_aux = test(model, optimizer, criterion)
        print("Train main:", metrics[0])
        print("Test main:", metrics[1])
        print("Train auxiliary:", metrics_aux[0])
        print("Test auxiliary:", metrics_aux[1])

        max_acc_test = max(max_acc_test, metrics[1][0])
        max_f1_test = max(max_f1_test, metrics[1][1])
        max_roc_auc_test = max(max_roc_auc_test, metrics[1][2])

        min_mae_sjr_test = min(min_mae_sjr_test, metrics_aux[1][0])
        min_mae_h_index_test = min(min_mae_h_index_test, metrics_aux[1][1])
        min_mae_impact_factor_test = min(min_mae_impact_factor_test, metrics_aux[1][2])
        min_mae_number_test = min(min_mae_number_test, metrics_aux[1][3])

        if wandb_output:
            wandb.log({
                "epoch": epoch,
                "main_train/train_acc": metrics[0][0], "main_test/test_acc": metrics[1][0],
                "main_train/train_f1": metrics[0][1], "main_test/test_f1": metrics[1][1],
                "main_train/train_roc_auc": metrics[0][2], "main_test/test_roc_auc": metrics[1][2],
                "main_max/test_max_acc": max_acc_test,
                "main_max/test_max_f1": max_f1_test,
                "main_max/test_max_roc_auc": max_roc_auc_test,
                "aux_train/train_mae_sjr": metrics_aux[0][0],
                "aux_train/train_mae_h_index": metrics_aux[0][1],
                "aux_train/train_mae_impact_factor": metrics_aux[0][2],
                "aux_train/train_number": metrics_aux[0][3],
                "aux_test/test_mae_sjr": metrics_aux[1][0],
                "aux_test/test_mae_h_index": metrics_aux[1][1],
                "aux_test/test_mae_impact_factor": metrics_aux[1][2],
                "aux_test/test_number": metrics_aux[1][3],
                "aux_min/test_min_mae_sjr": min_mae_sjr_test,
                "aux_min/test_min_mae_h_index": min_mae_h_index_test,
                "aux_min/test_min_mae_impact_factor": min_mae_impact_factor_test,
                "aux_min/test_min_number_factor": min_mae_number_test,
            })

    if wandb_output:
        # Close this run so the next launch starts a separate one.
        wandb.finish()

    return {
        "max_acc_test": max_acc_test,
        "max_f1_test": max_f1_test,
        "max_roc_auc_test": max_roc_auc_test,
        "min_mae_sjr_test": min_mae_sjr_test,
        "min_mae_h_index_test": min_mae_h_index_test,
        "min_mae_impact_factor_test": min_mae_impact_factor_test,
        "min_mae_number_test": min_mae_number_test,
    }
            
# Ten launches for every set of multitask weights; each launch builds a fresh
# model and optimizer and hands them to `run`.
for i in range(10):
    for mt_weight in mt_weights:
        # The group name encodes the architecture, the four task weights, the
        # link operator and the learning rate.
        architecture_tag = "_".join(
            str(part) for part in (prefix, link_size, a_latent_size, a_conv_num, c_conv_num, c_latent_size)
        )
        weights_tag = "_".join(str(mt_weight[k]) for k in range(4))
        group = (
            "rmgnn(a_rgat(" + architecture_tag + ")_v2)_split_n0_"
            + weights_tag
            + operator + "_" + str(lr) + "_no_wd"
        )

        model = REIGNN(
            data_citation, heads, device,
            train_data_a, val_data_a, test_data_a,
            authors_to_papers,
            cit_layers=c_conv_num, latent_size_cit=c_latent_size,
            auth_layers=a_conv_num, latent_size_auth=a_latent_size,
            link_size=link_size,
        ).to(device)

        # Only these three objects are moved explicitly here.
        # NOTE(review): the validation/test samples are not moved in this cell --
        # presumably REIGNN handles them itself; confirm.
        data_citation = data_citation.to(device)
        data_author = data_author.to(device)
        train_data_a = train_data_a.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.L1Loss()
        run(project_name, group, entity, mt_weight, model, optimizer, criterion, i)