In [None]:
!pip uninstall -y torch
!pip install torch==1.8.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install --no-index --no-cache-dir torch-scatter  torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.1+cu101.html
!pip install --no-cache-dir torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.1+cu101.html
!pip install --no-cache-dir torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.1+cu101.html
!pip install --no-cache-dir torch-geometric
!pip install wandb

In [None]:
from tqdm import tqdm
from collections import Counter

import pandas as pd
import numpy as np
import wandb
import networkx as nx

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import ModuleList, Embedding
from torch.autograd import Variable

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv,GATv2Conv
from torch_geometric.nn import PNAConv, CGConv, BatchNorm, global_add_pool

from torch_geometric.data import Data
from torch_geometric.utils import negative_sampling
from torch_geometric.utils import degree
from torch_geometric.utils import erdos_renyi_graph, to_networkx, from_networkx
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import HeteroData

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score

In [None]:
# Pick the compute device: the default CUDA device when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

General SSORC CS 2010–2021 data (from `processed_data/`):

In [None]:
# Project root relative to this notebook; every data path below is built from it.
path_2_root = '../'

In [None]:
# General co-authorship tables. The "papers_indices" column holds the string form of a list
# (apparently like "['12', '34']"); the converter turns it back into a list of str ids.
authors_edges_papers_general = pd.read_csv(
    f"{path_2_root}processed_data/SSORC_CS_2010_2021_authors_edges_papers_indices.csv",
    index_col=0,
    converters={"papers_indices": lambda cell: cell.strip("[]").replace("'", "").split(", ")},
)
authors_edges_general = pd.read_csv(
    f"{path_2_root}processed_data/SSORC_CS_2010_2021_authors_edge_list.csv",
    index_col=0,
)

In [None]:
# Per-paper feature vectors (the file name suggests a 32-dim compressed vectorisation)
# and per-author feature vectors.
papers_features_general = pd.read_csv(
    f"{path_2_root}processed_data/SSORC_CS_2010_2021_papers_features_vectorized_compressed_32.csv",
    index_col=0,
)
authors_features_general = pd.read_csv(
    f"{path_2_root}processed_data/SSORC_CS_2010_2021_authors_features.csv",
    index_col=0,
)

In [None]:
aev = authors_edges_general.values
# (author_a, author_b) -> row position in the general edge list; used below to find the
# papers of an edge in the papers-indices tables.
edge_to_index = {(row[0], row[1]): pos for pos, row in enumerate(tqdm(aev))}

Local data (the dataset selected by `dataset_name`, read from `datasets/`):

In [None]:
!ls datasets

In [None]:
# Sub-directory of {path_2_root}datasets/ from which all local files below are read.
dataset_name = "SSORC_CS_10_21_1437_3164_unfiltered"

In [None]:
# Papers attached to each co-authorship edge of this dataset; "papers_indices" is parsed
# from its list-as-string form into a list of str ids.
authors_edges_papers = pd.read_csv(
    f"{path_2_root}datasets/{dataset_name}/{dataset_name}_authors_edges_papers_indices.csv",
    index_col=0,
    converters={"papers_indices": lambda s: s.strip("[]").replace("'", "").split(", ")},
)

In [None]:
# Directed co-authorship graph; read_edgelist keeps node labels as strings.
authors_graph = nx.read_edgelist(
    f"{path_2_root}datasets/{dataset_name}/{dataset_name}_authors.edgelist",
    create_using=nx.DiGraph,
)

In [None]:
# Directed citation graph between papers; node labels are strings.
citation_graph = nx.read_edgelist(
    f"{path_2_root}datasets/{dataset_name}/{dataset_name}_papers.edgelist",
    create_using=nx.DiGraph,
)

In [None]:
# Per-paper target table used to build the auxiliary targets further down.
papers_targets = pd.read_csv(
    f"{path_2_root}datasets/{dataset_name}/{dataset_name}_papers_targets.csv",
    index_col=0,
)

### Mandatory check: every paper of a co-authorship edge must be a node of the citation graph

In [None]:
# Co-authorship edges as (int, int) author-id pairs (graph node labels are strings).
sAe = [(int(src), int(dst)) for src, dst in authors_graph.edges]

In [None]:
# Papers of every co-authorship edge (looked up via the edge's row in the general edge list),
# flattened to ints, and the set of distinct papers.
authors_edges_papers_sub_2 = [
    authors_edges_papers["papers_indices"][edge_to_index[edge]] for edge in tqdm(sAe)
]
authors_edges_papers_sub_flat_2 = [
    int(paper) for edge_papers in authors_edges_papers_sub_2 for paper in edge_papers
]
unique_papers_2 = list(set(authors_edges_papers_sub_flat_2))

In [None]:
# Citation-graph node labels as ints, for comparison with unique_papers_2.
cgn = [int(node) for node in citation_graph.nodes()]

In [None]:
# Every paper attached to a co-authorship edge must be a citation-graph node: the
# auxiliary-target cell below looks each of them up among the citation-graph nodes.
# Enforce it instead of only displaying the counts.
_n_shared_papers = len(set(unique_papers_2).intersection(cgn))
assert _n_shared_papers == len(unique_papers_2), (
    f"{len(unique_papers_2) - _n_shared_papers} papers of co-authorship edges "
    "are missing from the citation graph"
)
_n_shared_papers, len(unique_papers_2), len(cgn)

### Data preparation

In [None]:
# Pre-computed link-prediction splits (presumably PyG Data objects) of split set "10", sample 0.
splits = "10"
path = f"{path_2_root}datasets/{dataset_name}/split_{splits}/"

split = 0
# NOTE(review): torch.load unpickles these files; only load split files you trust.
train_data_a, val_data_a, test_data_a = (
    torch.load(f"{path}{dataset_name}_{part}_sample_{split}.data")
    for part in ("train", "val", "test")
)

In [None]:
# Attach each node's feature row as the "x" attribute, so that from_networkx stacks them
# into data.x in graph node order. Paper features are looked up by label (.loc), like the
# author features here and the paper features in the citation-subgraph cell. The original
# .iloc was positional and only agreed with those when ids equal row positions. One bulk
# lookup replaces a DataFrame lookup per node.
papers_nodes = [int(node) for node in citation_graph.nodes]
papers_node_features = papers_features_general.loc[papers_nodes]
assert len(papers_node_features) == len(papers_nodes), "duplicate paper ids in the feature table"
for node, row in zip(tqdm(citation_graph.nodes), papers_node_features.values):
    citation_graph.nodes[node]['x'] = list(row)

authors_nodes = [int(node) for node in authors_graph.nodes]
authors_node_features = authors_features_general.loc[authors_nodes]
assert len(authors_node_features) == len(authors_nodes), "duplicate author ids in the feature table"
for node, row in zip(tqdm(authors_graph.nodes), authors_node_features.values):
    authors_graph.nodes[node]['x'] = list(row)

data_author = from_networkx(authors_graph)
data_citation = from_networkx(citation_graph)

# Histogram of target-node degrees in the training graph (deg[k] = number of nodes whose
# degree over edge_index[1] is k; at least 2 bins).
# NOTE(review): presumably meant for PNAConv (imported above); the GATv2 model below never uses it.
d = degree(train_data_a.edge_index[1], num_nodes=train_data_a.num_nodes, dtype=torch.long)
deg = torch.bincount(d, minlength=2)

# All three splits use the same author feature matrix; cast features to float32.
train_data_a.x = data_author.x.float()
val_data_a.x = data_author.x.float()
test_data_a.x = data_author.x.float()
data_citation.x = data_citation.x.float()

# Assumes from_networkx numbered the author nodes 0..n-1 in authors_graph.nodes order;
# map those PyG ids back to the original (int) author ids. Computed once; the mapping
# was previously built twice with identical results.
original_a_nodes = list(authors_graph.nodes)
pyg_id_2_original_id = {pyg_id: int(node) for pyg_id, node in enumerate(original_a_nodes)}

# Training co-authorship edges in original author-id space, the papers behind each edge,
# and the distinct papers (as str, matching the citation-graph node labels).
sAe_t = [
    (pyg_id_2_original_id[int(src)], pyg_id_2_original_id[int(dst)])
    for src, dst in train_data_a.edge_index.cpu().numpy().T
]
authors_edges_papers_sub_2t = [
    authors_edges_papers["papers_indices"][edge_to_index[edge]] for edge in tqdm(sAe_t)
]
authors_edges_papers_sub_flat_2t = [
    str(paper) for edge_papers in authors_edges_papers_sub_2t for paper in edge_papers
]
unique_papers_2t = list(set(authors_edges_papers_sub_flat_2t))

In [None]:
# Citation subgraph induced by the papers of the training edges, plus label -> local-index
# maps (local index = position in each graph's node order).
citation_graph_sub = citation_graph.subgraph(unique_papers_2t)
citation_graph_sub_nodes = list(citation_graph_sub.nodes())
global_to_local_id_citation = {node: idx for idx, node in enumerate(citation_graph_sub_nodes)}
authors_graph_sub_nodes = list(authors_graph.nodes())
global_to_local_id_authors = {node: idx for idx, node in enumerate(authors_graph_sub_nodes)}

In [None]:
# Local author index -> set of citation-subgraph node indices of all papers on that author's
# training co-authorship edges (read by the model's forward pass).
# The dict is keyed AND membership-tested by the local index. Testing the original author
# id against local-index keys (as before) almost never matched, so an author's set was
# reset on each new edge and only the papers of its last edge survived.
authors_to_papers = dict()
for (author_1, author_2), papers in zip(tqdm(sAe_t), authors_edges_papers_sub_2t):
    paper_ids = {global_to_local_id_citation[paper] for paper in papers}
    for author in (author_1, author_2):
        local_author = global_to_local_id_authors[str(author)]
        authors_to_papers.setdefault(local_author, set()).update(paper_ids)


### Final preparation of the graphs

Attach paper feature vectors to the citation subgraph:

In [None]:
# Attach paper features (label lookup in papers_features_general) to the citation-subgraph
# nodes with one bulk .loc call instead of one DataFrame lookup per node, then convert
# the subgraph to a PyG Data object with float32 features.
_sub_paper_ids = [int(node) for node in citation_graph_sub.nodes]
_sub_paper_features = papers_features_general.loc[_sub_paper_ids]
assert len(_sub_paper_features) == len(_sub_paper_ids), "duplicate paper ids in the feature table"
for node, row in zip(tqdm(citation_graph_sub.nodes), _sub_paper_features.values):
    citation_graph_sub.nodes[node]['x'] = list(row)
data_citation = from_networkx(citation_graph_sub)
data_citation.x = data_citation.x.float()

Select the author feature vectors for the co-authorship graph:

In [None]:
# Author ids (ints) and their feature rows, in authors_graph node order.
# NOTE(review): the same values were already computed in the data-preparation cell, and
# nothing later in this notebook appears to read them; confirm before removing.
authors_nodes = [int(node) for node in authors_graph.nodes]
authors_node_features = authors_features_general.loc[authors_nodes]

### Computing auxiliary targets

In [None]:
# Group co-authorship edges by their number of papers. For edge position i (in
# data_author.edge_index order, presumably the authors_graph.edges order that sAe and
# authors_edges_papers_sub_2 also follow):
#   inds_dict[i]        -> positions of its papers in list(citation_graph.nodes)
#   batch_list_x[g]     -> paper-position lists of all edges in group g
#   batch_list_owner[g] -> matching (src, dst) PyG node ids
#   batch_list_ind[g]   -> matching edge positions i
edges_ordered = [tuple(edge) for edge in data_author.edge_index.t().tolist()]
index_to_edge = dict(enumerate(edges_ordered))
authors_edges_papers_sample = authors_edges_papers_sub_2
citation_nodes = list(citation_graph.nodes)
# O(1) dict lookup; citation_nodes.index(...) scanned the whole node list for every paper.
# A paper missing from the citation graph raises KeyError here.
citation_node_to_index = {node: idx for idx, node in enumerate(citation_nodes)}

ownership_dict = {}
inds_dict = {}
for i in tqdm(range(len(authors_edges_papers_sample))):
    ownership_dict[i] = i
    inds_dict[i] = [citation_node_to_index[paper] for paper in authors_edges_papers_sample[i]]
embs_dict = inds_dict

batch_dict_x, batch_dict_owner, batch_dict_ind = {}, {}, {}
for i in tqdm(range(len(embs_dict))):
    n_papers = len(embs_dict[i])
    batch_dict_x.setdefault(n_papers, []).append(embs_dict[i])
    batch_dict_owner.setdefault(n_papers, []).append(index_to_edge[ownership_dict[i]])
    batch_dict_ind.setdefault(n_papers, []).append(i)

batch_list_x = list(batch_dict_x.values())
batch_list_owner = list(batch_dict_owner.values())
batch_list_ind = list(batch_dict_ind.values())

In [None]:
# Per-paper targets as a plain ndarray, indexed by row position in the next cell.
# np.asarray keeps this cell idempotent: a re-run no longer fails with
# "'numpy.ndarray' object has no attribute 'values'".
papers_targets = np.asarray(papers_targets)

In [None]:
# One auxiliary-target row per co-authorship edge, in batch_list_x order, computed over the
# papers_targets rows of the edge's papers:
#   [max of column 0, mean of column 1, mean of column 2, mean of column 3, number of papers]
aux_targets = []
for batch in tqdm(batch_list_x):
    for paper_rows in batch:
        columns = np.array([papers_targets[row] for row in paper_rows]).T
        aux_targets.append([
            max(columns[0]),
            sum(columns[1]) / len(columns[1]),
            sum(columns[2]) / len(columns[2]),
            sum(columns[3]) / len(columns[3]),
            len(columns[0]),
        ])

# batch_list_owner is aligned with batch_list_x, so its flattening pairs edges with rows.
batch_list_owner_flat = [edge for owners in batch_list_owner for edge in owners]
aux_target_dict = {batch_list_owner_flat[i]: row for i, row in enumerate(tqdm(aux_targets))}

# Supervision edges (edge_label_index) of each split as (src, dst) tuples.
train_edges_aux_t = [tuple(edge) for edge in train_data_a.edge_label_index.cpu().numpy().T]
val_edges_aux_t = [tuple(edge) for edge in val_data_a.edge_label_index.cpu().numpy().T]
test_edges_aux_t = [tuple(edge) for edge in test_data_a.edge_label_index.cpu().numpy().T]

def get_aux_targets(train_edges_aux_t: list, target_dict: dict = None) -> list:
    """Look up the auxiliary-target row of every edge.

    Args:
        train_edges_aux_t: (src, dst) tuples. Any split's supervision edges can be passed;
            the parameter name is historical.
        target_dict: edge -> 5-element target row. Defaults to the notebook-level
            ``aux_target_dict``. Passing it explicitly avoids the hidden global dependency.

    Returns:
        One row per input edge, in input order. Edges without an entry (presumably the
        negative samples) get a fresh ``[0, 0, 0, 0, 0]``.
    """
    lookup = aux_target_dict if target_dict is None else target_dict
    return [lookup[edge] if edge in lookup else [0, 0, 0, 0, 0] for edge in train_edges_aux_t]

def task_split(aux_train_targets, device=None):
    """Split 5-column auxiliary targets into five float32 1-D tensors.

    Args:
        aux_train_targets: sequence of 5-element rows, as produced by ``get_aux_targets``.
        device: target device. Defaults to the default CUDA device when one is available
            (the same device a bare ``.cuda()`` uses), otherwise the CPU, so CPU-only
            machines no longer crash here.

    Returns:
        Tuple ``(y_q, y_sjr, y_h, y_if, y_n)`` holding columns 0..4 of the input.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    y_q, y_sjr, y_h, y_if, y_n = np.array(aux_train_targets).T
    return tuple(
        torch.as_tensor(column, dtype=torch.float32, device=device)
        for column in (y_q, y_sjr, y_h, y_if, y_n)
    )

# Auxiliary targets per split; only train and test get tensors attached (as .aux).
aux_train_targets = get_aux_targets(train_edges_aux_t)
aux_val_targets = get_aux_targets(val_edges_aux_t)
aux_test_targets = get_aux_targets(test_edges_aux_t)

train_aux_y = task_split(aux_train_targets)
test_aux_y = task_split(aux_test_targets)

train_data_a.aux = train_aux_y
test_data_a.aux = test_aux_y

### Model

In [None]:
class ResLinearBlock(nn.Module):
    """Two Linear -> BatchNorm -> ReLU stages with a skip connection, then a projection.

    forward(x) = linear_3(x + relu(bn_2(linear_2(relu(bn_1(linear_1(x)))))))
    Input width is ``size``; output width is ``link_size``.
    """

    def __init__(self, size, link_size):
        super().__init__()
        # Creation order is kept so parameter initialisation consumes the RNG identically.
        self.linear_1 = nn.Linear(size, size)
        self.linear_2 = nn.Linear(size, size)
        self.linear_3 = nn.Linear(size, link_size)
        self.batch_norm_1 = nn.BatchNorm1d(size)
        self.batch_norm_2 = nn.BatchNorm1d(size)

    def forward(self, x):
        hidden = F.relu(self.batch_norm_1(self.linear_1(x)))
        hidden = F.relu(self.batch_norm_2(self.linear_2(hidden)))
        # Residual: the block input is added back before the final projection.
        return self.linear_3(x + hidden)
    
class gs_sum_concatenation_gs(nn.Module):
    """Link-prediction model on a co-authorship graph enriched with citation-graph embeddings.

    As implemented in ``forward``:
    1. Three GATv2 layers embed the citation graph ``data_c``.
    2. Each author's features are concatenated with the sum of the embeddings of its papers.
       The papers come from the module-level ``authors_to_papers`` dict (local author index
       -> citation-node indices); authors not in the dict get zeros.
    3. A Linear layer projects to 75 dims, and three GATv2 layers run on the training
       co-authorship edges.
    4. An LSTM reads each author's sequence of the 4 representations (projection + 3 convs).
    5. Task heads (SJR, h-index, impact factor) map the LSTM state to 75-dim embeddings.
    6. An edge operator combines the two endpoint embeddings.

    ``parameters["conv_size"]`` holds the three citation-conv widths.
    ``parameters.get("author_feat_dim", 19)`` is the raw author feature width; 19 was
    previously hard-coded.
    """

    def __init__(self, data_c,
                 parameters,
                 train_data_a, val_data_a, test_data_a):
        super(gs_sum_concatenation_gs, self).__init__()

        self.data_c = data_c
        self.train_data_a, self.val_data_a, self.test_data_a = train_data_a, val_data_a, test_data_a
        self.params = parameters

        conv_size = self.params["conv_size"]
        author_dim = self.params.get("author_feat_dim", 19)
        hidden = 75

        # Convolutions on the citation graph.
        self.conv_c_1 = GATv2Conv(data_c.x.shape[1], conv_size[0])
        self.conv_c_2 = GATv2Conv(conv_size[0], conv_size[1])
        self.conv_c_3 = GATv2Conv(conv_size[1], conv_size[2])

        # Projection of [author features || summed paper embeddings] to the hidden size.
        self.pre_conv = nn.Linear(conv_size[2] + author_dim, hidden)

        self.input_size = conv_size[2] + author_dim
        self.num_layers = 1
        self.hidden_size = hidden

        # Reads, per author, the sequence of its representations after each stage.
        self.lstm = nn.LSTM(input_size=hidden, hidden_size=hidden,
                            num_layers=1, batch_first=True)

        # Convolutions on the co-authorship graph.
        self.convs_a = ModuleList()
        self.batch_norms = ModuleList()
        for _ in range(3):
            self.convs_a.append(GATv2Conv(hidden, hidden))
            # NOTE(review): these BatchNorm layers are never applied in forward(); confirm intent.
            self.batch_norms.append(BatchNorm(hidden))

        # One Linear(hidden, 1) per returned output: link, sjr, h-index, impact factor.
        self.post_lp_layers = ModuleList([nn.Linear(hidden, 1) for _ in range(4)])

        # Multitask heads (creation order kept so initialisation is unchanged).
        self.hidden_q1 = ResLinearBlock(hidden, 128)
        self.hidden_q2 = nn.Linear(128, hidden)

        self.hidden_if1 = ResLinearBlock(hidden, 128)
        self.hidden_if2 = nn.Linear(128, hidden)

        self.hidden_hi1 = ResLinearBlock(hidden, 128)
        self.hidden_hi2 = nn.Linear(128, hidden)

        self.hidden_sjr1 = ResLinearBlock(hidden, 128)
        self.hidden_sjr2 = nn.Linear(128, hidden)

    def forward(self, sample, batch_list_x, batch_list_owner, operator = "l2"):
        """Score the supervision edges of ``sample``.

        Args:
            sample: split whose ``edge_label_index`` (2 x E) selects the edges to score.
                Node embeddings are always computed from ``self.train_data_a``.
            batch_list_x, batch_list_owner: accepted for interface compatibility; unused.
            operator: how endpoint embeddings are combined. One of "cp" (dot product),
                "l1" (absolute difference), "l2" (squared difference), "hadamard"
                (elementwise product) or "summ" (sum).

        Returns:
            ``[link, sjr, hi, if]``: one score per edge in ``sample.edge_label_index``. For
            every operator except "cp", each output passes through its Linear(75, 1) head.
        """
        def cp(z, edge_index):
            return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

        def l1(z, edge_index):
            return torch.abs(z[edge_index[0]] - z[edge_index[1]])

        def l2(z, edge_index):
            return torch.pow(z[edge_index[0]] - z[edge_index[1]], 2)

        def hadamard(z, edge_index):
            return z[edge_index[0]] * z[edge_index[1]]

        def summ(z, edge_index):
            return z[edge_index[0]] + z[edge_index[1]]

        x_a = self.train_data_a.x
        edge_index_c = self.data_c.edge_index
        x_c = F.relu(self.conv_c_1(self.data_c.x, edge_index_c, None))
        x_c = F.relu(self.conv_c_2(x_c, edge_index_c, None))
        x_c = F.relu(self.conv_c_3(x_c, edge_index_c, None))

        # Summed citation embeddings of each author's papers; zero rows for authors absent
        # from authors_to_papers. Allocated on x_c's device instead of the global `device`.
        collab = x_c.new_zeros((x_a.shape[0], x_c.shape[1]))
        for author, papers in authors_to_papers.items():
            if 0 <= author < x_a.shape[0]:
                collab[author] = sum(x_c[list(papers)])
        x = torch.cat((x_a, collab), dim=1)

        convolutions = []
        x = self.pre_conv(x)
        convolutions.append(x)
        for conv in self.convs_a:
            x = F.relu(conv(x, self.train_data_a.edge_index))
            convolutions.append(x)

        # (num_authors, 4, hidden): row j is author j's per-stage sequence. This equals the
        # former per-node re-stacking loop, done in one call.
        emb_seqs_t = torch.stack(convolutions, dim=1)

        h_0 = emb_seqs_t.new_zeros((self.num_layers, emb_seqs_t.size(0), self.hidden_size))
        c_0 = emb_seqs_t.new_zeros((self.num_layers, emb_seqs_t.size(0), self.hidden_size))
        _, (h_out, _) = self.lstm(emb_seqs_t, (h_0, c_0))

        x = h_out.view(-1, self.hidden_size)

        # NOTE(review): q is computed (and its BatchNorm statistics updated in training) but
        # is not part of the returned outputs.
        q = self.hidden_q1(x)
        q = self.hidden_q2(q)

        ifact = self.hidden_if1(x)
        ifact = self.hidden_if2(ifact)

        hi = self.hidden_hi1(x)
        hi = self.hidden_hi2(hi)

        sjr = self.hidden_sjr1(x)
        sjr = self.hidden_sjr2(sjr)

        operator_dict = {"cp": cp, "l1": l1, "l2": l2, "hadamard": hadamard, "summ": summ}
        embedding_operator = operator_dict[operator]

        edge_index = sample.edge_label_index

        embeddings = [embedding_operator(x, edge_index),
                      embedding_operator(sjr, edge_index),
                      embedding_operator(hi, edge_index),
                      embedding_operator(ifact, edge_index)]
        # cp already yields one scalar per edge; the other operators yield 75-dim vectors.
        if embedding_operator != cp:
            for i in range(len(embeddings)):
                embeddings[i] = self.post_lp_layers[i](embeddings[i]).squeeze(-1)

        return embeddings

In [None]:
# Edge operator passed to the model by train() and test(): "cp", "l1", "l2", "hadamard" or "summ".
operator = "hadamard"


def train(model, optimizer, criterion, mt_coeffs = [0, 0, 0]):
    """Run one full-batch optimisation step on ``model.train_data_a``.

    loss = BCE-with-logits(link scores, edge_label)
           + mt_coeffs[0] * criterion(sjr output, aux[1])
           + mt_coeffs[1] * criterion(h-index output, aux[2])
           + mt_coeffs[2] * criterion(impact-factor output, aux[3])

    Returns:
        The loss tensor of this step.
    """
    model.train()
    optimizer.zero_grad()
    split = model.train_data_a
    link_logits, sjr_pred, hi_pred, if_pred = model(split, batch_list_x, batch_list_owner, operator)
    # Terms are added in the same left-to-right order as a single chained sum.
    loss = F.binary_cross_entropy_with_logits(link_logits, split.edge_label)
    loss = loss + mt_coeffs[0] * criterion(sjr_pred, split.aux[1])
    loss = loss + mt_coeffs[1] * criterion(hi_pred, split.aux[2])
    loss = loss + mt_coeffs[2] * criterion(if_pred, split.aux[3])
    loss.backward()
    optimizer.step()
    return loss
    
@torch.no_grad()
def test(model, optimizer, criterion):
    """Evaluate ``model`` on its training and test splits without tracking gradients.

    ``optimizer`` and ``criterion`` are accepted but unused.

    Returns:
        ``(perfs, aux)``: for k = 0 (train) and 1 (test), ``perfs[k]`` is
        ``[accuracy, F1, ROC-AUC]`` of the link predictions and ``aux[k]`` is
        ``[MAE sjr, MAE h-index, MAE impact factor]``.
    """
    model.eval()
    perfs, aux = [], []
    for split in (model.train_data_a, model.test_data_a):
        logits, sjr_pred, hi_pred, if_pred = model(split, batch_list_x, batch_list_owner, operator)
        probs = logits.sigmoid().cpu()
        labels = split.edge_label.cpu()
        aux.append([
            mean_absolute_error(split.aux[1].cpu(), sjr_pred.cpu()),
            mean_absolute_error(split.aux[2].cpu(), hi_pred.cpu()),
            mean_absolute_error(split.aux[3].cpu(), if_pred.cpu()),
        ])
        # Hard predictions use a 0.5 threshold on the sigmoid probabilities.
        hard = probs.round()
        perfs.append([
            accuracy_score(labels, hard),
            f1_score(labels, hard),
            roc_auc_score(labels, probs),
        ])
    return perfs, aux

In [None]:
def run(project_name, group, entity, mt_weights, model, optimizer, criterion, i,
        epochs=None, eval_every=10):
    """Train ``model`` inside one W&B run, logging link-prediction and auxiliary metrics.

    Args:
        project_name, group, entity: W&B project / group / entity. The run is named
            ``group + "_" + str(i)``.
        mt_weights: weights of the SJR, h-index and impact-factor losses (``train``'s mt_coeffs).
        model, optimizer, criterion: forwarded to ``train`` / ``test``.
        i: repetition index, used in the run name.
        epochs: number of training steps. Defaults to the notebook-level ``epochs_per_launch``.
        eval_every: evaluate and log every this many epochs (default 10).

    Returns:
        Dict of the best test accuracy / F1 / ROC-AUC and the lowest test MAEs observed.
    """
    n_epochs = epochs_per_launch if epochs is None else epochs

    wandb.init(project=project_name, entity=entity, group=group)
    wandb.run.name = group + "_" + str(i)
    wandb.run.save()

    max_acc_test = max_f1_test = max_roc_auc_test = 0.0
    min_mae_sjr_test = min_mae_h_index_test = min_mae_impact_factor_test = float("inf")

    try:
        for epoch in tqdm(range(n_epochs)):
            train_loss = train(model, optimizer, criterion, mt_weights)

            # Metrics are only recomputed every `eval_every` epochs. Logging in between
            # would just repeat stale values, so all logging happens here with step=epoch.
            if epoch % eval_every != 0:
                continue

            metrics, metrics_aux = test(model, optimizer, criterion)
            print("Loss:", float(train_loss), "\nTrain:", metrics[0], "\nTest:", metrics[1])
            print("Aux:", metrics_aux)

            # metrics[k] = [acc, f1, roc_auc]; metrics_aux[k] = [mae_sjr, mae_h, mae_if];
            # k = 0 is train, k = 1 is test.
            max_acc_test = max(max_acc_test, metrics[1][0])
            max_f1_test = max(max_f1_test, metrics[1][1])
            max_roc_auc_test = max(max_roc_auc_test, metrics[1][2])
            min_mae_sjr_test = min(min_mae_sjr_test, metrics_aux[1][0])
            min_mae_h_index_test = min(min_mae_h_index_test, metrics_aux[1][1])
            min_mae_impact_factor_test = min(min_mae_impact_factor_test, metrics_aux[1][2])

            wandb.log({
                "main_train/loss": float(train_loss),
                "main_train/train_acc": metrics[0][0], "main_test/test_acc": metrics[1][0],
                "main_train/train_f1": metrics[0][1], "main_test/test_f1": metrics[1][1],
                "main_train/train_roc_auc": metrics[0][2], "main_test/test_roc_auc": metrics[1][2],
                "main_max/test_max_acc": max_acc_test,
                "main_max/test_max_f1": max_f1_test,
                "main_max/test_max_roc_auc": max_roc_auc_test,
                "aux_train/train_mae_sjr": metrics_aux[0][0],
                "aux_train/train_mae_h_index": metrics_aux[0][1],
                "aux_train/train_mae_impact_factor": metrics_aux[0][2],
                "aux_test/test_mae_sjr": metrics_aux[1][0],
                "aux_test/test_mae_h_index": metrics_aux[1][1],
                "aux_test/test_mae_impact_factor": metrics_aux[1][2],
                "aux_min/test_min_mae_sjr": min_mae_sjr_test,
                "aux_min/test_min_mae_h_index": min_mae_h_index_test,
                "aux_min/test_min_mae_impact_factor": min_mae_impact_factor_test,
            }, step=epoch)
    finally:
        # Close the run even if training is interrupted, so the next run starts cleanly.
        wandb.finish()

    return {
        "test_max_acc": max_acc_test,
        "test_max_f1": max_f1_test,
        "test_max_roc_auc": max_roc_auc_test,
        "test_min_mae_sjr": min_mae_sjr_test,
        "test_min_mae_h_index": min_mae_h_index_test,
        "test_min_mae_impact_factor": min_mae_impact_factor_test,
    }

In [None]:
project_name = 'GAT_0.1'
epochs_per_launch = 15000
lr = 0.0005
mt_weights = [[0.3, 0.3, 0.3]]
entity = "sbergraphs"
parameters = {"conv_size": [128, 128, 128]}

# The experiments ran on the second GPU. Fall back to the default GPU or the CPU on
# machines that do not have one, instead of failing on a hard-coded 'cuda:1'.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move every graph the model and the metrics touch to the training device once. The test
# split (whose .aux targets are used in evaluation) was previously never moved.
data_citation, data_author = data_citation.to(device), data_author.to(device)
train_data_a, val_data_a, test_data_a = train_data_a.to(device), val_data_a.to(device), test_data_a.to(device)

for i in range(10):
    for mt_weight in mt_weights:
        group = (f"rmgnn(a_original_rgat_res)_split_n0_{mt_weight[0]}_{mt_weight[1]}_{mt_weight[2]}_"
                 f"{operator}_{lr}_no_wd")
        # Seed each repetition: run i is reproducible, while repetitions still differ.
        torch.manual_seed(i)
        model = gs_sum_concatenation_gs(data_citation, parameters,
                                        train_data_a, val_data_a, test_data_a).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.L1Loss()
        run(project_name, group, entity, mt_weight, model, optimizer, criterion, i)