### The main file for training SAGE graph for link predictions

Before running this file, be sure to run the data preparation step shown below if the prepared data does not already exist.

This main file should be similar to the training main file; the only differences are:
* we will give users the opportunity to filter the affiliations;
* we should aim to return only the top 10 recommendations (not too many);
* each recommended collaborator should have their article titles listed together with a link: `pub_details['link'] = 'https://www.ncbi.nlm.nih.gov/pubmed/?term=' + pmid`

Next come dataset preparation (graph construction), the train/validation/test split, model training, and saving the results.

In [None]:
# imports
import argparse
import itertools
import numpy as np
import pandas as pd 
import scipy.sparse as sp
from sklearn.metrics import roc_auc_score
# import pickle

# torch 
import torch
import torch.nn as nn
import torch.nn.functional as F

# dgl 
import dgl
from dgl.data.utils import save_graphs, load_graphs

# local
from prepare_data import yearly_authors
from prepare_dataset import CollabDataset
from models import GraphSAGE, MLPPredictor, DotPredictor
import utils as ut 

In [None]:
def main(args):
    """Train a GraphSAGE link-prediction model and run it on the test edges.

    Loads the graph saved at ``<save_path>/data/sageGraph.bin``, splits its
    edges with ``ut.split_edges`` / ``ut.construct_wEdges``, trains a
    ``GraphSAGE`` encoder together with an ``MLPPredictor`` through
    ``ut.train_epochs`` and finally hands the result to ``ut.pred``.

    Args:
        args: namespace providing ``save_path``, ``h_feats``, ``sage_pool``,
            ``lr`` and ``epochs``. ``authfile`` and ``yrs`` are only used by
            the first-time data preparation shown in the comment below.
    """
    # Local import so this function does not depend on the file-level imports.
    import os

    # Prefer the second GPU as before, but fall back to the first one when the
    # machine has a single GPU ('cuda:1' would be an invalid device there).
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device('cuda', gpu_index)
    else:
        device = torch.device('cpu')

    # First-time setup (run once, then keep commented out):
    #   yearly_authors(authfile = args.authfile, years = args.yrs, savepath = args.save_path + 'data/')
    #   dataset = CollabDataset()
    #   graph = dataset[0].to(device)
    #   save_graphs(args.save_path + 'data/sageGraph.bin', [graph])
    # Later runs just load the saved graph. os.path.join keeps the path valid
    # whether or not save_path ends with a separator.
    graph_file = os.path.join(args.save_path, 'data', 'sageGraph.bin')
    graph = load_graphs(graph_file)[0][0].to(device)
    print('graph info at a glance:')
    print(graph)

    # Split into train, valid and test graphs; newNodeMask = True is expected
    # to yield the extra "new_*" graphs unpacked below (11 graphs in total).
    outputs = ut.split_edges(graph, newNodeMask = True)
    outgs = ut.construct_wEdges(graph, outputs)
    train_g, train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g, \
                 new_val_pos_g, new_val_neg_g, new_test_pos_g,  new_test_neg_g = outgs

    # Encoder and edge-scoring model. The input width is read from the node
    # features; the same aggregator type is passed twice via `pool`.
    in_feats = train_g.ndata['feat'].shape[1]
    model = GraphSAGE(in_feats, h_feats = args.h_feats, pool = [args.sage_pool, args.sage_pool]).to(device)
    predmodel = MLPPredictor(h_feats = args.h_feats).to(device)
    # Alternative predictor:
    # predmodel = DotPredictor()

    # One optimizer updates the parameters of both models.
    optimizer = torch.optim.Adam(itertools.chain(model.parameters(), predmodel.parameters()), lr=args.lr)
    h, model, predmodel = ut.train_epochs(model, predmodel, train_g, train_pos_g, train_neg_g, optimizer, device,
                                          val_pos_g, val_neg_g,
                                          epochs = args.epochs, every = 5, path = args.save_path,
                                          new_val_pos_g = new_val_pos_g, new_val_neg_g = new_val_neg_g)

    # Run the trained predictor on the held-out test graphs.
    ut.pred(h, predmodel, test_pos_g, test_neg_g, new_test_pos_g = new_test_pos_g, new_test_neg_g = new_test_neg_g,
            save = args.save_path)


In [None]:
# keep the original data statistics, and more epochs
def build_parser():
    """Build the argument parser for the GraphSAGE link-prediction run.

    Every option has a default, so parsing an empty argument list yields a
    complete configuration. Numeric options declare their type so values
    supplied as strings are converted instead of being passed on as ``str``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description='SAGE graph for link prediction')
    # nargs='+' lets several years be given (e.g. --yrs 2018 2019); without it
    # the option could only ever produce a single int, unlike its list default.
    parser.add_argument('--yrs', default=[2018, 2019, 2020], type=int, nargs='+',
                        help='years to build the GraphSAGE')
    parser.add_argument('--authfile', default='../DLrec/newdata/processed_pubs.pickle',
                        help='original crawled pubmed pickle that contains info about publication and authors')
    parser.add_argument('--h_feats', default=200, type=int,
                        help='embedding dimension of GraphSAGE output')
    parser.add_argument('--sage_pool', default='gcn',
                        help='aggregation type for GraphSAGE, choose from : mean, pool, gcn')
    parser.add_argument('--lr', default=0.001, type=float,
                        help='learning rate of the optimizer')
    parser.add_argument('--epochs', default=100, type=int,
                        help='training epochs')
    parser.add_argument('--save_path', default='sage/',
                        help='main path to store model and prediction results for GraphSAGE')
    return parser


if __name__ == "__main__":
    # An empty argument list is parsed on purpose, so the defaults are always
    # used. NOTE(review): presumably because this runs as a notebook cell,
    # where sys.argv holds the kernel's own arguments - confirm before
    # switching to parse_args().
    args = build_parser().parse_args([])
    main(args = args)