### The main file for training TGN for link prediction

Before running this file, run the data preparation steps below if the prepared data does not already exist.

1. Prepare the raw data: python utils/prepare_data.py --years [xxxxxxxxx]
2. Prepare the graph data: python utils/preprocess_data.py --data collab --bipartite 'FALSE'
3. Finally, run the link prediction (this file).


In [None]:
# libraries
import math
import logging
import time
import sys
import argparse
import pickle
from pathlib import Path
import numpy as np
import torch
import gc

# local
import utils.utils as ut
from utils.preprocessing_servicedata_tgn import run
from utils.utils import EarlyStopMonitor, RandEdgeSampler, get_neighbor_finder
from utils.data_servicepreprocessing_tgn import get_data, compute_time_statistics
from evaluation.evaluation import eval_edge_prediction
from model.tgn import TGN
from service_tgn import service

In [None]:
    def _str2bool(value):
        """Convert a command-line token such as 'FALSE' or 'true' into a bool.

        Without a converter argparse stores the raw string, and every
        non-empty string (including 'False') is truthy, so a flag could
        never be switched off from the command line.

        Raises:
            argparse.ArgumentTypeError: if the token is not a recognised
                boolean spelling.
        """
        if isinstance(value, bool):
            return value
        token = str(value).strip().lower()
        if token in ('true', 't', 'yes', 'y', '1'):
            return True
        if token in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    # Command-line interface of the script. All options have defaults, so the
    # parser can also be used with an empty argument list.
    parser = argparse.ArgumentParser('TGN temporal link predictions')
    # data
    parser.add_argument('--data', type=str, help='collab for our own experiments, others include benchmark wikipedia',
                        default='collab')
    parser.add_argument('--bipartite', default=False, type=_str2bool, help='Whether the graph is bipartite')
    # nargs='+' keeps the parsed value a list (like the default) and lets
    # several years be passed as "--yrs 2019 2020".
    parser.add_argument('--yrs', default=[2019, 2020], type=int, nargs='+', help='years to work on ')
    parser.add_argument('--authfile', default='data/20192020/[2019, 2020]_.pickle',
                        help='crawed pubmed database')
    # The value is used as the root of the per-user result path.
    parser.add_argument('--savepath', type=str, help='root directory for the user related result files',
                        default='service/')
    parser.add_argument('--different_new_nodes', default=False, type=_str2bool,
                        help='Whether to use disjoint set of new nodes for train and val')
    parser.add_argument('--bs', type=int, default=10, help='Batch_size')
    parser.add_argument('--prefix', type=str, default='', help='Prefix to name the checkpoints')
    parser.add_argument('--val_ratio', default=0.2, type=float,
                        help='the valiation data split')
    parser.add_argument('--f_name', default='Vahed',
                        help='first name of the user')
    parser.add_argument('--l_name', default='Maroufy',
                        help='last name of the user')
    parser.add_argument('--m_name', default='',
                        help='middle name of the user')
    parser.add_argument('--exclude', default='', type=str,
                        help='a list of names(string) to exclude from the collaborator recommendations')

    # model
    parser.add_argument('--n_degree', type=int, default=5, help='Number of neighbors to sample')  # 25/10 for SAGE
    parser.add_argument('--uniform', action='store_true',
                        help='take uniform sampling from temporal neighbors')
    parser.add_argument('--n_layer', type=int, default=1, help='Number of network layers')
    parser.add_argument('--n_head', type=int, default=2, help='Number of heads used in attention layer')
    parser.add_argument('--drop_out', type=float, default=0.1, help='Dropout probability')
    parser.add_argument('--node_dim', type=int, default=200, help='Dimensions of the node embedding')
    parser.add_argument('--time_dim', type=int, default=5, help='Dimensions of the time embedding')
    parser.add_argument('--backprop_every', type=int, default=100, help='Every how many batches to '
                                                                        'backprop')
    ## memory
    parser.add_argument('--use_memory', default=False, type=_str2bool,
                        help='Whether to augment the model with a node memory')
    parser.add_argument('--memory_updater', type=str, default="gru",
                        choices=["gru", "rnn"], help='Type of memory updater')
    parser.add_argument('--memory_update_at_end', default=False, type=_str2bool,
                        help='Whether to update memory at the end or at the start of the batch')
    parser.add_argument('--memory_dim', type=int, default=172, help='Dimensions of the memory for '
                                                                    'each user')  # 172
    ## message
    parser.add_argument('--message_function', type=str, default="identity",
                        choices=["mlp", "identity"], help='Type of message function')
    parser.add_argument('--use_source_embedding_in_message', default=True, type=_str2bool,
                        help='Whether to use the embedding of the source node as part of the message')
    parser.add_argument('--use_destination_embedding_in_message', default=True, type=_str2bool,
                        help='Whether to use the embedding of the destination node as part of the message')
    parser.add_argument('--aggregator', type=str, default="last", help='Type of message '
                                                                       'aggregator')
    parser.add_argument('--message_dim', type=int, default=50, help='Dimensions of the messages')  # 100
    ## embedding
    parser.add_argument('--embedding_module', type=str, default="graph_attention",
                        choices=["graph_attention", "graph_sum", "identity", "time"], help='Type of embedding module')
    parser.add_argument('--randomize_features', default=False, type=_str2bool,
                        help='Whether to randomize node features')
    parser.add_argument('--node_options', default='pubs', help='whether use mesh/pubs for the node features')
    parser.add_argument('--dyrep', default=False, type=_str2bool,
                        help='Whether to run the dyrep model')
    # training
    parser.add_argument('--n_epoch', type=int, default=100, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--patience', type=int, default=20, help='Patience for early stopping')
    parser.add_argument('--n_runs', type=int, default=1, help='Number of runs')
    parser.add_argument('--gpu', type=int, default=0, help='Idx for the gpu to use')
    parser.add_argument('--seed', type=int, default=2021, help='One seed that rules them all')

    # service
    parser.add_argument('--firstk', type=int, default=30, help='number of collaborators to recommend')

In [None]:
# Parse an explicit empty argument list, so every option takes its default
# and sys.argv is ignored (presumably because this runs inside a notebook
# kernel -- confirm before reusing as a CLI script).
try:
  args = parser.parse_args([])
except SystemExit:
  # argparse reports invalid arguments by raising SystemExit; catch only
  # that instead of a bare `except`, which would also hide unrelated
  # errors and keyboard interrupts.
  parser.print_help()
  sys.exit(0)

# Seed numpy and torch with the same user-supplied seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Per-user folder suffix: "<first>_<last>/" or, when a non-blank middle name
# is given, "<first>_<middle>_<last>/".
name_suff = (
    f'{args.f_name}_{args.l_name}/'
    if args.m_name.strip() == ''
    else f'{args.f_name}_{args.m_name}_{args.l_name}/'
)
# we only save user related files here
res_path = f'{args.savepath}{name_suff}{args.node_options}/'

# Short aliases for the parsed options used throughout the script.
DATA = args.data
GPU = args.gpu
BATCH_SIZE = args.bs
NUM_EPOCH = args.n_epoch
LEARNING_RATE = args.lr
NUM_NEG = 1
NUM_NEIGHBORS = args.n_degree
NUM_LAYER = args.n_layer
NUM_HEADS = args.n_head
DROP_OUT = args.drop_out
NODE_DIM = args.node_dim
TIME_DIM = args.time_dim
USE_MEMORY = args.use_memory
MEMORY_DIM = args.memory_dim
MESSAGE_DIM = args.message_dim

# Output folders for the final model and the per-epoch checkpoints.
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
MODEL_SAVE_PATH = './saved_models/{}-{}.pth'.format(args.node_options, args.data)


def get_checkpoint_path(epoch):
  """Return the checkpoint file path used for the given epoch."""
  return './saved_checkpoints/{}-{}-{}.pth'.format(args.node_options, args.data, epoch)


### set up logger
logger = ut.make_log(args)

In [None]:
# Data processing: build the `service` object for the requested user, then
# run the graph preprocessing that writes under `res_path`.
# NOTE(review): `path` is hard-coded to '../service/' while `res_path` is
# derived from --savepath (default 'service/') -- confirm this is intended.
serv = service(
    f_name=args.f_name,
    l_name=args.l_name,
    m_name=args.m_name,
    path='../service/',
    years=args.yrs,
    pubfile=args.authfile,
    exclude_users=args.exclude,
    options=args.node_options,
    val_ratio=args.val_ratio,
)
run(
    args.data,
    rand_node_feat=args.randomize_features,
    bipartite=args.bipartite,
    path=res_path,
)

In [None]:
### Load features and the full / train / validation / test splits
(node_features, edge_features, full_data,
 train_data, val_data, test_data) = get_data(
    DATA,
    different_new_nodes_between_val_and_test=args.different_new_nodes,
    randomize_features=args.randomize_features,
    path=res_path,
)

# Neighbor finder restricted to the training interactions
train_ngh_finder = get_neighbor_finder(train_data, args.uniform)

# Neighbor finder over every interaction, used for validation and testing
full_ngh_finder = get_neighbor_finder(full_data, args.uniform)

# Negative edge samplers. The validation and test samplers get fixed seeds so
# that their negatives are the same across different runs; the training
# sampler is left unseeded.
# NB: in the inductive setting, negatives are sampled only amongst other new nodes
train_rand_sampler = RandEdgeSampler(train_data.sources, train_data.destinations)
val_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations,
                                   seed=args.seed)
test_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations,
                                    seed=args.seed + 1)

# Use the requested GPU when CUDA is available, otherwise the CPU
device_string = f'cuda:{GPU}' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_string)

# Time statistics computed over the full data set
(mean_time_shift_src, std_time_shift_src,
 mean_time_shift_dst, std_time_shift_dst) = compute_time_statistics(
    full_data.sources, full_data.destinations, full_data.timestamps)

In [None]:
# Run the whole pipeline once per requested run: build a fresh TGN, train it
# with early stopping on validation AP, score the test edges, write the
# collaborator recommendations and save the metrics and the model.
for i in range(args.n_runs):
  # The first run is stored under the node-option name, later runs under
  # "<prefix>_<run index>".
  results_path = "results/{}_{}.pkl".format(args.prefix, i) if i > 0 else "results/{}.pkl".format(args.node_options)
  Path("results/").mkdir(parents=True, exist_ok=True)

  # Initialize Model: check this
  tgn = TGN(neighbor_finder=train_ngh_finder, node_features=node_features,
            edge_features=edge_features, device=device,
            n_layers=NUM_LAYER,
            n_heads=NUM_HEADS, dropout=DROP_OUT, use_memory=USE_MEMORY,
            message_dimension=MESSAGE_DIM, memory_dimension=MEMORY_DIM,
            memory_update_at_start=not args.memory_update_at_end,
            embedding_module_type=args.embedding_module,  # let's see how it goes
            message_function=args.message_function,
            aggregator_type=args.aggregator,
            memory_updater_type=args.memory_updater,
            n_neighbors=NUM_NEIGHBORS,
            mean_time_shift_src=mean_time_shift_src, std_time_shift_src=std_time_shift_src,
            mean_time_shift_dst=mean_time_shift_dst, std_time_shift_dst=std_time_shift_dst,
            use_destination_embedding_in_message=args.use_destination_embedding_in_message,
            use_source_embedding_in_message=args.use_source_embedding_in_message,
            dyrep=args.dyrep)
  criterion = torch.nn.BCELoss()
  optimizer = torch.optim.Adam(tgn.parameters(), lr=LEARNING_RATE)
  tgn = tgn.to(device)

  num_instance = len(train_data.sources)
  num_batch = math.ceil(num_instance / BATCH_SIZE)

  logger.info('num of training instances: {}'.format(num_instance))
  logger.info('num of batches per epoch: {}'.format(num_batch))
  idx_list = np.arange(num_instance)

  # Per-epoch metrics that are pickled to `results_path`.
  val_aps = []
  epoch_times = []
  total_epoch_times = []
  train_losses = []

  early_stopper = EarlyStopMonitor(max_round=args.patience)
  for epoch in range(NUM_EPOCH):
    start_epoch = time.time()
    ### Training

    # Reinitialize memory of the model at the start of each epoch
    if USE_MEMORY:
      tgn.memory.__init_memory__()

    # Train using only training graph
    tgn.set_neighbor_finder(train_ngh_finder)
    m_loss = []

    logger.info('start {} epoch'.format(epoch))
    for k in range(0, num_batch, args.backprop_every):
      loss = 0
      optimizer.zero_grad()

      # Custom loop to allow to perform backpropagation only every a certain number of batches
      for j in range(args.backprop_every):
        batch_idx = k + j

        # Batch indices only grow, so once past the last batch nothing is
        # left to accumulate in this group.
        if batch_idx >= num_batch:
          break

        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(num_instance, start_idx + BATCH_SIZE)
        sources_batch, destinations_batch = train_data.sources[start_idx:end_idx], \
                                            train_data.destinations[start_idx:end_idx]
        edge_idxs_batch = train_data.edge_idxs[start_idx: end_idx]
        timestamps_batch = train_data.timestamps[start_idx:end_idx]

        # One sampled negative destination per positive edge in the batch.
        size = len(sources_batch)
        _, negatives_batch = train_rand_sampler.sample(size)

        with torch.no_grad():
          pos_label = torch.ones(size, dtype=torch.float, device=device)
          neg_label = torch.zeros(size, dtype=torch.float, device=device)

        tgn = tgn.train()
        # check if the device are consistant
        pos_prob, neg_prob = tgn.compute_edge_probabilities(sources_batch, destinations_batch, negatives_batch,
                                                            timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS)

        loss += criterion(pos_prob.squeeze(), pos_label) + criterion(neg_prob.squeeze(), neg_label)

      # The accumulated loss is always divided by `backprop_every`, even when
      # the last group holds fewer batches.
      loss /= args.backprop_every

      loss.backward()
      optimizer.step()
      m_loss.append(loss.item())

      # Detach memory after 'args.backprop_every' number of batches so we don't backpropagate to
      # the start of time
      if USE_MEMORY:
        tgn.memory.detach_memory()

    epoch_time = time.time() - start_epoch
    epoch_times.append(epoch_time)

    ### Validation
    # Validation uses the full graph
    tgn.set_neighbor_finder(full_ngh_finder)

    if USE_MEMORY:
      # Backup memory at the end of training
      train_memory_backup = tgn.memory.backup_memory()

    val_ap, val_auc = eval_edge_prediction(model=tgn, negative_edge_sampler=val_rand_sampler,
                                           data=val_data,
                                           n_neighbors=NUM_NEIGHBORS, batch_size=BATCH_SIZE)
    if USE_MEMORY:
      # Backup memory after validation, restore the end-of-training memory,
      # then restore the end-of-validation memory again.
      # NOTE(review): no evaluation runs between the two restores, so the
      # round trip looks like a leftover of a removed new-node validation
      # step; the call order is kept unchanged.
      val_memory_backup = tgn.memory.backup_memory()
      tgn.memory.restore_memory(train_memory_backup)
      tgn.memory.restore_memory(val_memory_backup)

    val_aps.append(val_ap)
    train_losses.append(np.mean(m_loss))

    # Save temporary results to disk; the context manager closes the file
    # even if pickling fails. `total_epoch_times` does not yet contain the
    # current epoch at this point.
    with open(results_path, "wb") as results_file:
      pickle.dump({
        "val_aps": val_aps,
        "train_losses": train_losses,
        "epoch_times": epoch_times,
        "total_epoch_times": total_epoch_times
      }, results_file)

    total_epoch_time = time.time() - start_epoch
    total_epoch_times.append(total_epoch_time)

    logger.info('epoch: {} took {:.2f}s'.format(epoch, total_epoch_time))
    logger.info('Epoch mean loss: {}'.format(np.mean(m_loss)))
    logger.info(
      'val auc: {}'.format(val_auc))
    logger.info(
      'val ap: {}'.format(val_ap))

    # Early stopping: on a stop signal reload the best checkpoint, otherwise
    # checkpoint the current epoch.
    if early_stopper.early_stop_check(val_ap):
      logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
      logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
      best_model_path = get_checkpoint_path(early_stopper.best_epoch)
      tgn.load_state_dict(torch.load(best_model_path))
      logger.info(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference')
      tgn.eval()
      break
    else:
      torch.save(tgn.state_dict(), get_checkpoint_path(epoch))

  # Training has finished; backup the current memory (which has seen
  # validation edges) so it can be restored after scoring the test edges.
  # NOTE(review): the best checkpoint is only reloaded when early stopping
  # fired; otherwise the model of the last epoch is used.
  if USE_MEMORY:
    val_memory_backup = tgn.memory.backup_memory()

  ### Test
  tgn.embedding_module.neighbor_finder = full_ngh_finder
  # Make sure layers such as dropout are in inference mode even when the
  # early-stopping branch (the only other place calling eval()) was not taken.
  tgn.eval()

  ## define the test batches
  num_test = len(test_data.sources)
  test_batches = math.ceil(num_test / BATCH_SIZE)
  all_probs = []
  with torch.no_grad():
    # A dedicated index name avoids clobbering the run index `i`.
    for test_batch_idx in range(test_batches):
      # define start_idx and end_idx
      start_idx = test_batch_idx * BATCH_SIZE
      end_idx = min(num_test, start_idx + BATCH_SIZE)

      sources_batch, destinations_batch = test_data.sources[start_idx:end_idx], \
                                          test_data.destinations[start_idx:end_idx]
      edge_idxs_batch = test_data.edge_idxs[start_idx: end_idx]
      timestamps_batch = test_data.timestamps[start_idx:end_idx]
      size = len(sources_batch)
      _, negatives_batch = test_rand_sampler.sample(size)

      # need to modify this negative_batch, cuz we won't have one
      probs = ut.compute_edge_probabilities(tgn, sources_batch, destinations_batch, negatives_batch,
                                            timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS)
      all_probs.extend(probs)

  # Give every collected probability the shape of the first one.
  probs_flat = [prob.reshape(all_probs[0].shape) for prob in all_probs]
  # let's write the prediction results
  # NOTE: pickle must only be used on trusted files; this one is expected to
  # come from the project's own preprocessing step.
  with open('data/' + res_path + 'author_refs.pickle', 'rb') as author_file:
    author_dict = pickle.load(author_file)
  ut.recommend(logger=logger, test_dst=test_data.destinations, probs=probs_flat,
               device=device, author_dict=author_dict, firstk=args.firstk, path=res_path,
               f_name=args.f_name, l_name=args.l_name, m_name=args.m_name)

  if USE_MEMORY:
    tgn.memory.restore_memory(val_memory_backup)

  # Save results for this run
  with open(results_path, "wb") as results_file:
    pickle.dump({
      "val_aps": val_aps,
      "epoch_times": epoch_times,
      "train_losses": train_losses,
      "total_epoch_times": total_epoch_times
    }, results_file)

  logger.info('Saving TGN model')
  if USE_MEMORY:
    # Restore memory at the end of validation (save a model which is ready for testing)
    tgn.memory.restore_memory(val_memory_backup)
  torch.save(tgn.state_dict(), MODEL_SAVE_PATH)
  logger.info('TGN model saved')