### Main script for training a GraphSAGE model for link prediction


This main file mirrors the training main file; the only differences are:
* users are given the option to filter by affiliation;
* we aim to return only the top 30 recommended collaborators (not too many);
* each recommended collaborator is listed with a PubMed link: `['link'] = 'https://www.ncbi.nlm.nih.gov/pubmed/?term=' + pmid + author_in_concern`.

The cells below cover dataset preparation (graph construction), the train/validation/test split, model training, and saving the results.

In [None]:
# imports
from os.path import exists 
import argparse
import gc 
import itertools
import numpy as np
import pandas as pd 
import scipy.sparse as sp
from sklearn.metrics import roc_auc_score
import pickle

# torch 
import torch
import torch.nn as nn
import torch.nn.functional as F

# dgl 
import dgl
from dgl.data.utils import save_graphs, load_graphs

# local
from prepare_servicedata import yearly_authors
from prepare_servicedataset import CollabDataset
from service import service
from models import GraphSAGE, MLPPredictor, DotPredictor, Model
import utils as ut 

In [None]:
    # Command-line options for the collaborator-recommendation service.
    # NOTE: `parse_args([])` at the bottom ignores sys.argv, so the defaults
    # below are what actually run (presumably intentional for notebook use);
    # pass an explicit list, e.g. parser.parse_args(['--epochs', '50']), to override.
    # Every option declares `type` (and `nargs` for list-valued ones) so that
    # overridden values are converted instead of reaching the model as str.
    parser = argparse.ArgumentParser(description='SAGE graph for collab recommendation service')
    # processing data
    parser.add_argument('--yrs', default=[2019, 2020, 2021], type=int, nargs='+',
                        help='years of authors to build the GraphSAGE')
    parser.add_argument('--authfile', default='sage/data/20192020/[2019, 2020]_.pickle',
                        help='original crawled pubmed pickle that contains info about publication and authors')
    parser.add_argument('--k', default=1, type=int,
                        help='negative sample ratio (vs existing links)')
    parser.add_argument('--val_ratio', default=0.2, type=float,
                        help='the validation data split')
    parser.add_argument('--f_name', default='Melissa',
                        help='first name of the user')
    parser.add_argument('--l_name', default='Valerio-Shewmaker',
                        help='last name of the user')
    parser.add_argument('--m_name', default='A',
                        help='middle name of the user')
    parser.add_argument('--exclude', default='', type=str,
                        help='a list of names(string) to exclude from the collaborator recommendations')
    # model
    parser.add_argument('--in_feats', default=2000, type=int,
                        help='input dimension of GraphSAGE, tfidf vector size 2000')
    parser.add_argument('--node_options', default='pubs', type=str,
                        help='the info to build node features on, choose from: pubs, mesh')
    parser.add_argument('--h_feats', default=50, type=int,
                        help='embedding dimension of GraphSAGE output')
    parser.add_argument('--out_feats', default=2, type=int,
                        help='output dimension of model, link prediction of 2')
    # one aggregator per SAGE layer, e.g. `--sage_pool mean gcn`
    parser.add_argument('--sage_pool', default=['gcn', 'gcn'], type=str, nargs=2,
                        help='aggregation types for the 2-layer GraphSAGE, choose from : mean, pool, gcn')
    # training
    parser.add_argument('--lr', default=0.003, type=float,  # try another lr 0.001 to 0.005
                        help='learning rate of the optimizer')
    parser.add_argument('--wd', default=0.00001, type=float,
                        help='weight decay of the optimizer')
    parser.add_argument('--epochs', default=100, type=int,
                        help='training epochs')
    parser.add_argument('--GPU', default=0, type=int,
                        help='index for GPU')
    parser.add_argument('--save_path', default='service/',
                        help='main path to store model and prediction results for GraphSAGE')
    # recommend
    parser.add_argument('--firstk', default=30, type=int,
                        help='number of collaborators to show on the recommendation list')
    args = parser.parse_args([])

In [None]:
def _result_path(args):
    """Return the per-user results directory.

    Layout: ``<save_path><f_name>_[<m_name>_]<l_name>/<node_options>/``; the
    middle-name segment is omitted when ``m_name`` is missing or blank.
    """
    m_name = args.m_name or ''
    if m_name.strip() == '':
        name_suff = args.f_name + '_' + args.l_name + '/'
    else:
        name_suff = args.f_name + '_' + args.m_name + '_' + args.l_name + '/'
    return args.save_path + name_suff + args.node_options + '/'


def _add_edge_labels(graph):
    """Attach an all-ones ``torch.long`` tensor as ``graph.edata['label']`` if absent.

    Returns True when the labels were added (the graph was modified), else False.
    """
    if 'label' in graph.edata:
        return False
    graph.edata['label'] = torch.ones(graph.num_edges(), dtype=torch.long)
    return True


def _load_or_build_graph(args, res_path):
    """Load the cached graph from ``res_path`` or build it and cache it there.

    If ``sageGraph.bin`` exists it is loaded and re-saved only when edge labels
    had to be added. Otherwise ``service`` is run when ``collabs.csv`` or
    ``authors.csv`` is missing, the graph is built with ``CollabDataset`` and
    saved to ``sageGraph.bin``.
    """
    graph_file = res_path + 'sageGraph.bin'
    if exists(graph_file):
        graph = load_graphs(graph_file)[0][0]
        if _add_edge_labels(graph):
            save_graphs(graph_file, [graph])
        return graph

    if not (exists(res_path + 'collabs.csv') and exists(res_path + 'authors.csv')):
        # Called only for its side effects; presumably it writes the CSV files
        # that CollabDataset reads from res_path -- confirm in the service module.
        service(f_name=args.f_name, l_name=args.l_name, m_name=args.m_name,
                path=args.save_path, years=args.yrs, pubfile=args.authfile,
                exclude_users=args.exclude, options=args.node_options,
                val_ratio=args.val_ratio)
    # dataset processing (graph)
    graph = CollabDataset(raw_dir=res_path)[0]
    _add_edge_labels(graph)
    save_graphs(graph_file, [graph])
    return graph


def _release_cached_memory(device_string):
    """Run the garbage collector and, when CUDA is available, empty the CUDA cache.

    ``torch.cuda.device`` raises ValueError for a non-CUDA device such as
    ``'cpu'``, so the CUDA calls are skipped entirely on CPU-only machines.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        with torch.cuda.device(device_string):
            torch.cuda.empty_cache()


def main(args):
    """Build or load the user's collaboration graph, train the link-prediction
    model and write the collaborator recommendations.

    Parameters
    ----------
    args : argparse.Namespace
        Options from the argument-parser cell (user name, years, input pickle,
        model sizes, optimizer settings, ``save_path``, ``firstk``, ...).
    """
    # part 1. graph construction (or cached load)
    res_path = _result_path(args)
    graph = _load_or_build_graph(args, res_path)
    print(graph)

    # part 2. training
    # prepare negatives as well
    graph = ut.construct_negEdges(graph, k=args.k, newNode=False, service=True)
    # train, validation, and test split
    train_g, val_g, test_g = ut.inductive_edge_split(graph, newNode=False)
    train_feats, train_y = ut.feat_labels(train_g)
    val_feats, val_y = ut.feat_labels(val_g)
    test_feats, test_y = ut.feat_labels(test_g)

    # logs
    logger = ut.create_log(args)
    # device
    device_string = 'cuda:{}'.format(args.GPU) if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_string)
    _release_cached_memory(device_string)

    # model
    model = Model(in_features=args.in_feats, hidden_features=args.h_feats,
                  out_features=args.out_feats, pool=args.sage_pool)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    loss_fcn = nn.CrossEntropyLoss()

    # The regular validation split is also passed in the new_val_* slots;
    # presumably those are not used separately when newNode=False -- confirm in utils.
    model = ut.train_epochs(logger=logger, epochs=args.epochs, model=model,
                            train_g=train_g, train_feats=train_feats, train_y=train_y,
                            val_g=val_g, val_feats=val_feats, val_y=val_y,
                            new_val_g=val_g, new_val_feats=val_feats, new_val_y=val_y,
                            device=device, opt=opt, loss_fcn=loss_fcn, path=res_path,
                            every=5, newNode=False)

    # recommend
    # NOTE(review): pickle.load can execute arbitrary code; acceptable only
    # because author_refs.pickle is presumably produced locally by this pipeline.
    with open(res_path + 'author_refs.pickle', 'rb') as fh:
        author_dict = pickle.load(fh)
    ut.recommend(logger=logger, model=model,
                 test_g=test_g, test_feats=test_feats, test_y=test_y,
                 device=device, author_dict=author_dict, firstk=args.firstk, path=res_path,
                 f_name=args.f_name, l_name=args.l_name, m_name=args.m_name)


In [None]:
# Script entry point: run the whole pipeline with the module-level `args`
# produced by the argument-parser cell.
if __name__ == "__main__":
    main(args)