In [None]:
# basics 
from os.path import exists
import math
import logging
import time
import sys
import argparse
import pickle
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch 
from torch_sparse import SparseTensor
from sklearn.metrics import average_precision_score, roc_auc_score

# pytorch geometric 
from torch_geometric.utils import structured_negative_sampling
from torch_geometric.data import TemporalData

# local
from prepare_data_bsl import yearly_authors
from prepare_dataset_bsl import CollabDataset
import utils_bsl as ut 
from myLightGCN import LightGCN

In [None]:
    # Experiment configuration: every tunable value of the run is declared here
    # as an argparse option so that `main(args)` receives a single namespace.
    parser = argparse.ArgumentParser('baseline link predictions--lightGCN')

    # ---- data ----
    parser.add_argument('--data', type=str, default='collab',
                        help='collab for our own experiments')
    # nargs='+' makes a command-line value parse into a list of ints, matching
    # the list-valued default (without it `--yrs 2010` would yield a bare int).
    parser.add_argument('--yrs', type=int, nargs='+', default=[2010, 2011],
                        help='years to work on')
    parser.add_argument('--authfile',
                        default='../../part0_GrantRec/newdata/processed_pubs.pickle',
                        help='crawed pubmed database')
    parser.add_argument('--inpath', default='../sage/mesh/20102011/',
                        help="since we are using the same dataset as the SAGE, we can reuse the processed dataset")
    parser.add_argument('--node_options', default='mesh',
                        help="node feature options, choose from mesh/pubs")
    parser.add_argument('--savepath', type=str, default='20102011_mesh/',
                        help='path to save the data')
    parser.add_argument('--bs', type=int, default=1024, help='Batch_size')

    # ---- model ----
    parser.add_argument('--embedding_dim', type=int, default=200, help='embedding dimensions')
    parser.add_argument('--num_layers', type=int, default=2, help='number of lightgcn layers')
    parser.add_argument('--lr', type=float, default=0.005, help='Learning rate')
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to use if built trees on GPU')

    # ---- training ----
    parser.add_argument('--n_epoch', type=int, default=100, help='Number of epochs')
    parser.add_argument('--seed', type=int, default=2021, help='One seed that rules them all')

    # An empty argv is parsed on purpose, so every option takes its default;
    # change the defaults above (or pass a list of strings here) to alter a run.
    args = parser.parse_args([])
    

In [None]:
def _run_batch(model, criterion, node_feats, edge_index, batch_size, batch_idx, device):
    """Score one mini-batch of positive edges plus sampled negative edges.

    Args:
        model: model exposing ``predict_link(X=..., edge_index=...)``.
        criterion: loss callable applied to ``(preds, labels)``.
        node_feats: node feature tensor handed to the model as ``X``.
        edge_index: edges of the split the batch is drawn from.
        batch_size: value forwarded to ``ut.sample_mini_batch``.
        batch_idx: index of the batch within the split.
        device: device the batch edges and labels are moved to.

    Returns:
        ``(preds, labels, loss, n_src)`` where ``n_src`` is ``src.shape[0]`` of
        the sampled batch; callers use it to weight the batch loss.
    """
    src, pos_dst, neg_dst = ut.sample_mini_batch(batch_size=batch_size, i=batch_idx,
                                                 edge_index=edge_index)
    pos = torch.cat([src.reshape(1, -1), pos_dst.reshape(1, -1)])
    neg = torch.cat([src.reshape(1, -1), neg_dst.reshape(1, -1)])
    # Positive edges occupy the first columns, negative edges the remaining ones.
    batch_edge_index = torch.cat([pos, neg], dim=1).to(device)

    preds = model.predict_link(X=node_feats, edge_index=batch_edge_index)

    # Labels follow the same ordering as the edges: ones for positives, zeros for
    # negatives, with as many columns as the model's prediction.
    labels = torch.cat([
        torch.ones(pos.shape[1], preds.shape[1], dtype=torch.float),
        torch.zeros(neg.shape[1], preds.shape[1], dtype=torch.float),
    ]).to(device)
    loss = criterion(preds, labels)
    return preds, labels, loss, src.shape[0]


def main(args):
    """Train, validate and test the LightGCN link-prediction baseline.

    Reads ``collabs_masks.csv`` and ``node_feats.npy`` from ``args.inpath``
    (generating them first when either file is missing), trains the model with
    binary cross-entropy on positive and sampled negative edges, plots the
    per-epoch training and validation losses, and prints test AUC and AP.

    Args:
        args: namespace with the options declared by the notebook's argument
            parser (``inpath``, ``authfile``, ``yrs``, ``node_options``, ``bs``,
            ``embedding_dim``, ``num_layers``, ``lr``, ``gpu``, ``n_epoch``,
            ``seed``; the whole namespace is also passed to ``ut.create_log``).

    Returns:
        None. Results are shown as plots and printed to stdout.
    """
    import random  # local import: only needed here for seeding

    # Apply the configured seed; the original code declared `--seed` but never used it.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not (exists(args.inpath + 'collabs_masks.csv') and exists(args.inpath + 'node_feats.npy')):
        yearly_authors(authfile=args.authfile, years=args.yrs, savepath=args.inpath,
                       options=args.node_options)
        # dataset processing (graph). The instance itself is not used below; the
        # call is kept in case construction writes processed files -- TODO confirm.
        CollabDataset(raw_dir=args.inpath)

    # original data
    device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')
    print(f"Using device {device}.")
    df = pd.read_csv(args.inpath + 'collabs_masks.csv')
    node_feats = np.load(args.inpath + 'node_feats.npy')
    node_feats = torch.tensor(node_feats, dtype=torch.float).to(device)

    # prepare data: one temporal event per row of the CSV, split by its masks
    data = TemporalData(
        src=torch.tensor(df.new_author.to_list()),
        dst=torch.tensor(df.new_coauthor.to_list()),
        t=torch.tensor(df.timestamp.to_list()))
    data.train_mask = torch.tensor(df.train_mask.to_list(), dtype=torch.bool)
    data.val_mask = torch.tensor(df.val_mask.to_list(), dtype=torch.bool)
    data.test_mask = torch.tensor(df.test_mask.to_list(), dtype=torch.bool)
    data.edge_index = torch.stack([data.src, data.dst])
    train, val, test = data[data.train_mask], data[data.val_mask], data[data.test_mask]

    logger = ut.create_log(args)
    model = LightGCN(num_node=node_feats.shape[0], num_feat=node_feats.shape[1],
                     embedding_dim=args.embedding_dim, num_layers=args.num_layers)
    model = model.to(device)

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # training and validation
    train_losses = []
    val_losses = []

    train_iter = math.ceil(train.num_events / args.bs)
    val_iter = math.ceil(val.num_events / args.bs)

    for epoch in range(args.n_epoch):
        model.train()
        epoch_train_loss = 0.
        for i in range(train_iter):
            optimizer.zero_grad()
            _, _, loss, n_src = _run_batch(model, criterion, node_feats,
                                           train.edge_index, args.bs, i, device)
            loss.backward()
            optimizer.step()
            # Weight the batch loss by the number of sampled source nodes.
            epoch_train_loss += float(loss) * n_src
        train_losses.append(epoch_train_loss)

        model.eval()
        epoch_val_loss = 0.
        with torch.no_grad():
            for j in range(val_iter):
                _, _, loss, n_src = _run_batch(model, criterion, node_feats,
                                               val.edge_index, args.bs, j, device)
                epoch_val_loss += float(loss) * n_src
        val_losses.append(epoch_val_loss)

    epochs = range(1, args.n_epoch + 1)
    plt.plot(epochs, train_losses, 'g', label='Training loss')
    plt.title('Training loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    plt.plot(epochs, val_losses, 'b', label='validation loss')
    plt.title('validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # test
    predictions = []
    labels = []
    test_iter = math.ceil(test.num_events / args.bs)

    model.eval()
    with torch.no_grad():
        for j in range(test_iter):
            # The test loss is not accumulated: the original added it to the
            # validation accumulator, which was both wrong and unused afterwards.
            preds, batch_labels, _, _ = _run_batch(model, criterion, node_feats,
                                                   test.edge_index, args.bs, j, device)
            predictions.append(preds)
            labels.append(batch_labels)

    # Convert once and reuse for both metrics; logits become probabilities via sigmoid.
    y_true = torch.cat(labels).cpu().numpy()
    y_score = torch.cat(predictions).sigmoid().detach().cpu().numpy()
    ap = average_precision_score(y_true, y_score)
    auc = roc_auc_score(y_true, y_score)
    print('test auc = {}, test ap = {}'.format(auc, ap))

In [None]:
# Entry point: run the whole pipeline with the namespace parsed in the
# configuration cell, but only when this code is executed as the main module.
if __name__ == "__main__":
    main(args=args)