## Environment Setup

Install the required packages and clone the GitHub repository.

In [None]:
!pip install wget

import os
import shutil
import wget
from urllib.parse import urlparse

%cd /content/

repo_path = "https://github.com/KyawYeThu-11/DyGLib.git"
repo_name = os.path.splitext(os.path.basename(urlparse(repo_path).path))[0]

if not os.path.exists(repo_name):
  !git clone $repo_path
  !pip install -r DyGLib/requirements.txt

Collecting wget
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9656 sha256=f4661507a08243e3e53c6644ee95ac2a396713f9f4c5b3f88fca2beb97278f36
  Stored in directory: /root/.cache/pip/wheels/8b/f1/7f/5c94f0a7a505ca1c81cd1d9208ae2064675d97582078e6c769
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2
/content
Cloning into 'DyGLib'...
remote: Enumerating objects: 435, done.[K
remote: Counting objects: 100% (378/378), done.[K
remote: Compressing objects: 100% (260/260), done.[K
remote: Total 435 (delta 137), reused 332 (delta 111), pack-reused 57[K
Receiving objects: 100% (435/435), 226.83 MiB | 15.72 MiB/s, done.
Resolving deltas: 100% (149/149), done.
Updating files: 100% (231/231), done.
Collecting torch-geometric (from -r DyGLib/requirem

In [None]:
# Make the cloned repository the working directory: the relative paths used below
# (./DG_data, ./logs, saved_models/..., preprocess_data) resolve against it, and presumably
# the project-local `models`/`utils` imports in the next cell rely on it as well.
%cd /content/DyGLib

/content/DyGLib


In [None]:
import logging
import time
import sys
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
import warnings
import shutil
import json
import torch
import torch.nn as nn

from models.DyGFormer import DyGFormer
from models.modules import MergeLayer
from utils.utils import set_random_seed, convert_to_gpu, get_parameter_sizes, create_optimizer
from utils.utils import get_neighbor_sampler, NegativeEdgeSampler
from utils.metrics import get_link_prediction_metrics
from utils.DataLoader import get_idx_data_loader, get_link_prediction_data
from utils.EarlyStopping import EarlyStopping
from utils.load_configs import get_link_prediction_args

# Data Preprocessing

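The `train_0` dataset represents the brain activations of a single subject. The next cell copies its CSV from `DG_data/connectome/5-percentile/Train_Data_csv/` into `DG_data/train_0/` so that it can be preprocessed below.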

In [None]:
dataset = 'train_0'
# Raw CSV for the chosen subject, and the per-dataset folder it is copied into.
source_csv = os.path.join('./DG_data/connectome/5-percentile/Train_Data_csv', f'{dataset}.csv')
target_dir = f'./DG_data/{dataset}'
os.makedirs(target_dir, exist_ok=True)
shutil.copy(source_csv, target_dir)

'./DG_data/train_0/train_0.csv'

Alternatively, to experiment with one of the existing DyGLib benchmark datasets (downloaded from Zenodo), uncomment and run the cell below.

In [None]:
# datasets = ['wikipedia', 'reddit', 'mooc', 'lastfm', 'myket', 'enron', 'SocialEvo', 'uci', 'Flights', 'CanParl', 'USLegis', 'UNtrade', 'UNvote', 'Contacts']
# dataset = 'CanParl'
# if not os.path.exists(dataset):
#   download_link = f"https://zenodo.org/records/7213796/files/{dataset}.zip"
#   wget.download(download_link, f"{dataset}.zip")
#   !unzip *.zip
#   !mv $dataset ./DG_data/

In [None]:
%cd preprocess_data
!python preprocess_data.py  --dataset_name $dataset
%cd ..

/content/DyGLib/preprocess_data
preprocess dataset train_0...
number of nodes  400
number of node features  172
number of edges  400000
number of edge features  1
train_0 is processed successfully.
/content/DyGLib


# Training

### Constants and Helper Functions

In [None]:
class Args:
    """Hyper-parameter container for a DyGFormer link-prediction experiment.

    Args:
        dataset_name: Name of the dataset (the folder created under ``./DG_data``).
        **overrides: Optional keyword overrides for any default attribute set below,
            e.g. ``Args('train_0', num_epochs=3)``. Unknown names raise ``TypeError``
            so typos are not silently ignored.

    Attributes may also be added after construction (the notebook sets ``device``,
    ``seed`` and ``save_model_name`` this way).
    """

    def __init__(self, dataset_name, **overrides):
        self.dataset_name = dataset_name
        self.batch_size = 1000
        self.model_name = 'DyGFormer'
        self.gpu = 0  # CUDA device index; the notebook falls back to CPU when CUDA is unavailable
        self.num_neighbors = 64  # not read by any cell in this notebook
        self.sample_neighbor_strategy = 'recent'
        self.time_scaling_factor = 1e-6
        self.num_walk_heads = 8  # not read by any cell in this notebook
        self.num_heads = 2
        self.num_layers = 2
        self.load_checkpoint = False  # resume from saved_models/... when the checkpoint exists
        self.walk_length = 1  # not read by any cell in this notebook
        self.time_gap = 2000  # not read by any cell in this notebook
        self.time_feat_dim = 100
        self.position_feat_dim = 172  # not read by any cell in this notebook
        self.edge_bank_memory_mode = 'unlimited_memory'  # not read by any cell in this notebook
        self.time_window_mode = 'fixed_proportion'  # not read by any cell in this notebook
        self.patch_size = 2
        self.channel_embedding_dim = 50
        self.max_input_sequence_length = 64
        self.learning_rate = 0.0005
        self.dropout = 0.2
        self.num_epochs = 10
        self.optimizer = 'Adam'
        self.weight_decay = 0.0
        # NOTE(review): patience exceeds num_epochs, so presumably early stopping never fires
        # with these defaults.
        self.patience = 20
        self.val_ratio = 0.1
        self.test_ratio = 0.1
        self.num_runs = 3  # independent training runs, seeded 0 .. num_runs - 1
        self.test_interval_epochs = 5  # evaluate on the test split every N epochs while training
        # NOTE(review): not passed to the NegativeEdgeSampler constructors in this notebook,
        # so changing it has no effect there.
        self.negative_sample_strategy = 'random'

        # Apply overrides after the defaults so attribute order (and thus __str__) is unchanged.
        for key, value in overrides.items():
            if key not in self.__dict__:
                raise TypeError(f"Args got an unexpected keyword argument '{key}'")
            setattr(self, key, value)

    def __str__(self):
        properties = [f"{key}={value}" for key, value in self.__dict__.items()]
        return f"Args({', '.join(properties)})"

    def __repr__(self):
        return self.__str__()

# Build the experiment configuration for the selected dataset, then pick the compute device:
# the configured GPU when CUDA is available and the index is non-negative, otherwise the CPU.
args = Args(dataset_name = dataset)
if torch.cuda.is_available() and args.gpu >= 0:
    args.device = f'cuda:{args.gpu}'
else:
    args.device = 'cpu'

In [None]:
def set_up_logger(args):
    """Attach a per-run file handler and a console handler to the root logger.

    The file log is written to
    ``./logs/{model_name}/{dataset_name}/{save_model_name}/{time.time()}.log`` and
    records DEBUG and above; the console handler only emits WARNING and above.

    Handlers attached by an earlier call are detached and closed first, so re-running
    the training cell does not duplicate output or keep old log files open.

    Args:
        args: Object exposing ``model_name``, ``dataset_name`` and ``save_model_name``.

    Returns:
        Tuple ``(logger, fh, ch)``: the root logger, the file handler and the console
        handler (the caller may detach the handlers when a run finishes).
    """
    # NOTE(review): if the root logger has no handlers yet, basicConfig installs its own
    # stderr handler with no level filter, so INFO messages may still reach the console.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # Drop handlers left behind by a previous call (the training loop keeps the last
    # run's handlers attached), closing them to release their log files.
    for handler in list(logger.handlers):
        if getattr(handler, '_set_up_logger_owned', False):
            logger.removeHandler(handler)
            handler.close()

    log_dir = f"./logs/{args.model_name}/{args.dataset_name}/{args.save_model_name}/"
    os.makedirs(log_dir, exist_ok=True)

    # File handler: full DEBUG-level record of the run.
    fh = logging.FileHandler(f"{log_dir}{str(time.time())}.log")
    fh.setLevel(logging.DEBUG)
    # Console handler: only warnings and errors.
    ch = logging.StreamHandler()
    ch.setLevel(logging.WARNING)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in (fh, ch):
        handler.setFormatter(formatter)
        handler._set_up_logger_owned = True  # lets the next call find and clean it up
        logger.addHandler(handler)

    return logger, fh, ch

In [None]:
def _notebook_global(name):
    """Return the notebook-level object `name`, raising NameError if it is not defined yet."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(f"name '{name}' is not defined; pass it to process_batch explicitly") from None


def process_batch(batch_idx, data_indices, neg_edge_sampler, data, mode, model=None, loss_func=None):
    """Score one batch of positive edges against sampled negative edges.

    Args:
        batch_idx: Index of the batch within the epoch (not used in the computation).
        data_indices: Tensor of row indices into ``data`` for this batch (``.numpy()`` is called on it).
        neg_edge_sampler: Negative edge sampler exposing ``sample(...)`` and
            ``negative_sample_strategy``.
        data: Split exposing ``src_node_ids``, ``dst_node_ids`` and ``node_interact_times`` arrays.
        mode: ``'train'`` (random negatives) or ``'val'`` (used for both validation and
            testing; honours the sampler's strategy). Any other value raises ``ValueError``.
        model: Indexable ``(backbone, link_predictor)`` pair such as the ``nn.Sequential``
            built in the training cell. Defaults to the notebook-level ``model``.
        loss_func: Loss applied to ``(predicts, labels)``. Defaults to the notebook-level
            ``loss_func``.

    Returns:
        ``(predicts, loss, labels)``: positive-then-negative link probabilities, the loss,
        and matching labels (1 for positives, 0 for negatives).
    """
    if mode not in ('train', 'val'):
        raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")

    # Fall back to notebook-level names so existing positional call sites keep working.
    if model is None:
        model = _notebook_global('model')
    if loss_func is None:
        loss_func = _notebook_global('loss_func')

    data_indices = data_indices.numpy()
    batch_src_node_ids = data.src_node_ids[data_indices]
    batch_dst_node_ids = data.dst_node_ids[data_indices]
    batch_node_interact_times = data.node_interact_times[data_indices]

    if mode == 'val' and neg_edge_sampler.negative_sample_strategy != 'random':
        # Non-random evaluation strategies sample (src, dst) negatives given the batch and its time span.
        batch_neg_src_node_ids, batch_neg_dst_node_ids = neg_edge_sampler.sample(
            size=len(batch_src_node_ids),
            batch_src_node_ids=batch_src_node_ids,
            batch_dst_node_ids=batch_dst_node_ids,
            current_batch_start_time=batch_node_interact_times[0],
            current_batch_end_time=batch_node_interact_times[-1])
    else:
        # Random negatives: keep the positive sources and replace only the destinations.
        _, batch_neg_dst_node_ids = neg_edge_sampler.sample(size=len(batch_src_node_ids))
        batch_neg_src_node_ids = batch_src_node_ids

    backbone, link_predictor = model[0], model[1]

    # Temporal embeddings of positive and negative pairs, each (batch_size, node_feat_dim).
    batch_src_node_embeddings, batch_dst_node_embeddings = \
        backbone.compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                          dst_node_ids=batch_dst_node_ids,
                                                          node_interact_times=batch_node_interact_times)
    batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
        backbone.compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
                                                          dst_node_ids=batch_neg_dst_node_ids,
                                                          node_interact_times=batch_node_interact_times)

    # Link probabilities, shape (batch_size,).
    positive_probabilities = link_predictor(input_1=batch_src_node_embeddings, input_2=batch_dst_node_embeddings).squeeze(dim=-1).sigmoid()
    negative_probabilities = link_predictor(input_1=batch_neg_src_node_embeddings, input_2=batch_neg_dst_node_embeddings).squeeze(dim=-1).sigmoid()

    predicts = torch.cat([positive_probabilities, negative_probabilities], dim=0)
    labels = torch.cat([torch.ones_like(positive_probabilities), torch.zeros_like(negative_probabilities)], dim=0)

    loss = loss_func(input=predicts, target=labels)

    return predicts, loss, labels

In [None]:
def evaluate_model_link_prediction(model, neighbor_sampler, evaluate_idx_data_loader, evaluate_data, neg_edge_sampler=None):
    """Compute per-batch loss and link-prediction metrics on an evaluation split.

    Args:
        model: ``nn.Sequential(backbone, link_predictor)``; it is switched to eval mode.
        neighbor_sampler: Neighbor sampler installed on the backbone before scoring.
        evaluate_idx_data_loader: Loader yielding index batches into ``evaluate_data``.
        evaluate_data: Split whose edges are scored.
        neg_edge_sampler: Negative edge sampler to use. Defaults to the notebook-level
            ``val_neg_edge_sampler``; pass ``test_neg_edge_sampler`` to score the test split
            with its own negatives.

    Returns:
        ``(evaluate_losses, evaluate_metrics)``: one loss float and one metrics dict per batch.
    """
    if neg_edge_sampler is None:
        neg_edge_sampler = val_neg_edge_sampler

    model.eval()
    model[0].set_neighbor_sampler(neighbor_sampler)

    # A seeded evaluation sampler must be rewound before each evaluation; otherwise its
    # random state keeps advancing, and different epochs are scored against different
    # negatives. The reset is guarded in case the sampler offers no reset method.
    reset_random_state = getattr(neg_edge_sampler, 'reset_random_state', None)
    if getattr(neg_edge_sampler, 'seed', None) is not None and callable(reset_random_state):
        reset_random_state()

    evaluate_losses, evaluate_metrics = [], []
    with torch.no_grad():
        evaluate_idx_data_loader_tqdm = tqdm(evaluate_idx_data_loader, ncols=120)
        for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
            # NOTE(review): process_batch scores with the notebook-level `model`, which is the
            # same object passed in as `model` at every call site in this notebook.
            predicts, loss, labels = process_batch(batch_idx, evaluate_data_indices, neg_edge_sampler, evaluate_data, 'val')

            evaluate_losses.append(loss.item())
            evaluate_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))

            evaluate_idx_data_loader_tqdm.set_description(f'evaluate for the {batch_idx + 1}-th batch, evaluate loss: {loss.item()}')

    return evaluate_losses, evaluate_metrics

### Main

In [None]:
 # get data for training, validation and testing
node_raw_features, edge_raw_features, full_data, train_data, val_data, test_data, _, _ = \
get_link_prediction_data(dataset_name=args.dataset_name, val_ratio=args.val_ratio, test_ratio=args.test_ratio)

# initialize training neighbor sampler to retrieve temporal graph
train_neighbor_sampler = get_neighbor_sampler(data=train_data, sample_neighbor_strategy=args.sample_neighbor_strategy,
                                                  time_scaling_factor=args.time_scaling_factor, seed=0)

# initialize validation and test neighbor sampler to retrieve temporal graph
full_neighbor_sampler = get_neighbor_sampler(data=full_data, sample_neighbor_strategy=args.sample_neighbor_strategy,
                                                 time_scaling_factor=args.time_scaling_factor, seed=1)

# initialize negative samplers, set seeds for validation and testing so negatives are the same across different runs
# in the inductive setting, negatives are sampled only amongst other new nodes
# train negative edge sampler does not need to specify the seed, but evaluation samplers need to do so
train_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=train_data.src_node_ids, dst_node_ids=train_data.dst_node_ids)
val_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=full_data.src_node_ids, dst_node_ids=full_data.dst_node_ids, seed=0)
test_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=full_data.src_node_ids, dst_node_ids=full_data.dst_node_ids, seed=2)

# get data loaders
train_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(train_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)
val_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(val_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)
test_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(test_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)

val_metric_all_runs, test_metric_all_runs  = [], []

The dataset has 400000 interactions, involving 400 different nodes
The training dataset has 251432 interactions, involving 360 different nodes
The validation dataset has 40000 interactions, involving 400 different nodes
The test dataset has 40000 interactions, involving 400 different nodes
The new node validation dataset has 7542 interactions, involving 304 different nodes
The new node test dataset has 8502 interactions, involving 288 different nodes
40 nodes were used for the inductive testing, i.e. are never seen during training


In [None]:

# Train `args.num_runs` independent runs (seeds 0 .. num_runs - 1). Each run keeps its best
# checkpoint (via early stopping) and evaluates it on the test split; a summary across all
# runs is then logged. `model` and `loss_func` must remain notebook-level names, because
# process_batch reads them as globals.
for run in range(args.num_runs):
    set_random_seed(seed=run)
    args.seed = run

    # Per-run artifact names: checkpoints and logs are keyed by model name and seed.
    args.save_model_name = f'{args.model_name}_seed{args.seed}'
    save_model_folder = f"saved_models/link_prediction/{args.dataset_name}"
    checkpoint_path = os.path.join(save_model_folder, f'{args.save_model_name}.pth')

    logger, fh, ch = set_up_logger(args)

    run_start_time = time.time()
    logger.info(f"********** Run {run + 1} starts. **********")

    logger.info(f'configuration is {args}')

    # Initialize the model: DyGFormer backbone plus a MergeLayer link predictor that scores
    # a pair of node embeddings (see process_batch).
    dynamic_backbone = DyGFormer(node_raw_features=node_raw_features, edge_raw_features=edge_raw_features, neighbor_sampler=train_neighbor_sampler,
                                 time_feat_dim=args.time_feat_dim, channel_embedding_dim=args.channel_embedding_dim, patch_size=args.patch_size,
                                 num_layers=args.num_layers, num_heads=args.num_heads, dropout=args.dropout,
                                 max_input_sequence_length=args.max_input_sequence_length, device=args.device)

    link_predictor = MergeLayer(input_dim1=node_raw_features.shape[1], input_dim2=node_raw_features.shape[1],
                                hidden_dim=node_raw_features.shape[1], output_dim=1)
    model = nn.Sequential(dynamic_backbone, link_predictor)

    # Log the model structure and its size (the * 4 presumably assumes 4-byte float32 parameters).
    logger.info(f'model -> {model}')
    logger.info(f'model name: {args.model_name}, #parameters: {get_parameter_sizes(model) * 4} B, '
                f'{get_parameter_sizes(model) * 4 / 1024} KB, {get_parameter_sizes(model) * 4 / 1024 / 1024} MB.')

    optimizer = create_optimizer(model=model, optimizer_name=args.optimizer, learning_rate=args.learning_rate, weight_decay=args.weight_decay)

    # Move the model to args.device.
    model = convert_to_gpu(model, device=args.device)

    os.makedirs(save_model_folder, exist_ok=True)

    # Early stopping tracks the validation metrics and saves the best model of this run.
    early_stopping = EarlyStopping(patience=args.patience, save_model_folder=save_model_folder,
                                   save_model_name=args.save_model_name, logger=logger, model_name=args.model_name)

    epoch_resumed = 0
    # Optionally resume from this seed's saved checkpoint.
    if args.load_checkpoint and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=args.device)
        logger.info(f"load model {checkpoint_path}")
        early_stopping.load_checkpoint(model, checkpoint)
        epoch_resumed = checkpoint['epoch'] + 1
        print(f"Epoch resumed: {epoch_resumed}")

    # Binary cross-entropy on sigmoid probabilities; process_batch reads it as a global.
    loss_func = nn.BCELoss()

    for epoch in range(epoch_resumed, args.num_epochs):

        # ---- train for one epoch on the training graph ----
        model.train()

        model[0].set_neighbor_sampler(train_neighbor_sampler)
        train_losses, train_metrics = [], []
        train_idx_data_loader_tqdm = tqdm(train_idx_data_loader, ncols=120)

        for batch_idx, train_data_indices in enumerate(train_idx_data_loader_tqdm):
            predicts, loss, labels = process_batch(batch_idx, train_data_indices, train_neg_edge_sampler, train_data, 'train')

            train_losses.append(loss.item())
            train_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_idx_data_loader_tqdm.set_description(f'Epoch: {epoch}, train for the {batch_idx + 1}-th batch, train loss: {loss.item()}')

        # ---- validate with the full-graph neighbor sampler ----
        val_losses, val_metrics = evaluate_model_link_prediction(model, full_neighbor_sampler, val_idx_data_loader, val_data)

        logger.info(f'Epoch: {epoch + 1}, learning rate: {optimizer.param_groups[0]["lr"]}, train loss: {np.mean(train_losses):.4f}')
        for metric_name in train_metrics[0].keys():
            logger.info(f'train {metric_name}, {np.mean([train_metric[metric_name] for train_metric in train_metrics]):.4f}')
        logger.info(f'validate loss: {np.mean(val_losses):.4f}')
        for metric_name in val_metrics[0].keys():
            logger.info(f'validate {metric_name}, {np.mean([val_metric[metric_name] for val_metric in val_metrics]):.4f}')

        # Periodic test evaluation every test_interval_epochs epochs (logged only, not used for
        # model selection). NOTE(review): these test negatives are drawn by the validation
        # negative sampler; test_neg_edge_sampler is never used in this notebook.
        if (epoch + 1) % args.test_interval_epochs == 0:
            test_losses, test_metrics = evaluate_model_link_prediction(model, full_neighbor_sampler, test_idx_data_loader, test_data)

            logger.info(f'test loss: {np.mean(test_losses):.4f}')
            for metric_name in test_metrics[0].keys():
                logger.info(f'test {metric_name}, {np.mean([test_metric[metric_name] for test_metric in test_metrics]):.4f}')

        # Model selection uses every validation metric averaged over batches; the trailing True
        # is presumably EarlyStopping's "higher is better" flag.
        val_metric_indicator = [(metric_name, np.mean([val_metric[metric_name] for val_metric in val_metrics]), True)
                                for metric_name in val_metrics[0].keys()]

        early_stop = early_stopping.step(val_metric_indicator, epoch, model)

        if early_stop:
            break

    # Reload the best checkpoint saved during this run and evaluate it on the test split.
    checkpoint = torch.load(checkpoint_path, map_location=args.device)
    early_stopping.load_checkpoint(model, checkpoint)

    logger.info(f'get final performance on dataset {args.dataset_name}...')
    test_losses, test_metrics = evaluate_model_link_prediction(model, full_neighbor_sampler, test_idx_data_loader, test_data)

    # Per-metric test averages for this run.
    test_metric_dict = {}

    logger.info(f'test loss: {np.mean(test_losses):.4f}')
    for metric_name in test_metrics[0].keys():
        average_test_metric = np.mean([test_metric[metric_name] for test_metric in test_metrics])
        logger.info(f'test {metric_name}, {average_test_metric:.4f}')
        test_metric_dict[metric_name] = average_test_metric

    single_run_time = time.time() - run_start_time
    logger.info(f'Run {run + 1} cost {single_run_time:.2f} seconds.')
    test_metric_all_runs.append(test_metric_dict)

    # Detach (and close) this run's handlers so the next run logs only to its own file. The
    # last run keeps them so the cross-run summary below lands in its log.
    if run < args.num_runs - 1:
        logger.removeHandler(fh)
        logger.removeHandler(ch)
        fh.close()

    # Save this run's test metrics (formatted to 4 decimals) as JSON.
    result_json = {"test metrics": {metric_name: f'{test_metric_dict[metric_name]:.4f}' for metric_name in test_metric_dict}}
    result_json = json.dumps(result_json, indent=4)

    save_result_folder = f"saved_results/link_prediction/{args.dataset_name}"
    os.makedirs(save_result_folder, exist_ok=True)
    save_result_path = os.path.join(save_result_folder, f"{args.save_model_name}.json")

    with open(save_result_path, 'w') as file:
        file.write(result_json)

# Summary over all runs, written to the last run's log.
logger.info(f'metrics over {args.num_runs} runs:')

for metric_name in test_metric_all_runs[0].keys():
    metric_values = [test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs]
    logger.info(f'test {metric_name}, {metric_values}')
    # The sample std (ddof=1) is undefined for a single run: report nan without numpy's warning.
    metric_std = np.std(metric_values, ddof=1) if len(metric_values) > 1 else float('nan')
    logger.info(f'average test {metric_name}, {np.mean(metric_values):.4f} '
                f'± {metric_std:.4f}')

# Detach and close the last run's handlers as well, so re-running this cell neither keeps
# writing to this run's log file nor duplicates console output.
logger.removeHandler(fh)
logger.removeHandler(ch)
fh.close()

# Testing & Inference

Choose which trained checkpoint to load (by default, the run trained with seed 2).

In [None]:
# Rebuild the architecture used in training, then load a saved checkpoint into it. The
# setup, data and helper cells above must have run (args, features and samplers are reused).
dynamic_backbone = DyGFormer(node_raw_features=node_raw_features, edge_raw_features=edge_raw_features, neighbor_sampler=train_neighbor_sampler,
                             time_feat_dim=args.time_feat_dim, channel_embedding_dim=args.channel_embedding_dim, patch_size=args.patch_size,
                             num_layers=args.num_layers, num_heads=args.num_heads, dropout=args.dropout,
                             max_input_sequence_length=args.max_input_sequence_length, device=args.device)

link_predictor = MergeLayer(input_dim1=node_raw_features.shape[1], input_dim2=node_raw_features.shape[1],
                            hidden_dim=node_raw_features.shape[1], output_dim=1)
model = nn.Sequential(dynamic_backbone, link_predictor)

model = convert_to_gpu(model, device=args.device)

# Which training run to load: the training cell saves one checkpoint per seed as
# saved_models/link_prediction/{dataset_name}/{model_name}_seed{seed}.pth (seeds 0 .. num_runs - 1).
load_seed = 2
load_model_path = os.path.join(f"saved_models/link_prediction/{args.dataset_name}",
                               f"{args.model_name}_seed{load_seed}.pth")
checkpoint = torch.load(load_model_path, map_location=args.device)
# This section only runs inference, so disable dropout up front.
model.eval()
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

Evaluate the loaded model on the test set, or skip this cell and try it on custom inputs below.

In [None]:
# process_batch reads the loss function from the notebook globals, so define it here too
# (the training cell may not have run in this session).
loss_func = nn.BCELoss()
test_losses, test_metrics = evaluate_model_link_prediction(model, full_neighbor_sampler, test_idx_data_loader, test_data)

print(f'test loss: {np.mean(test_losses):.4f}')

# Average each per-batch metric over the whole test split, then report it.
test_metric_dict = {
    metric_name: np.mean([batch_metrics[metric_name] for batch_metrics in test_metrics])
    for metric_name in test_metrics[0]
}
for metric_name, average_test_metric in test_metric_dict.items():
    print(f'test {metric_name}, {average_test_metric:.4f}')

Alternatively, see how the model scores custom node pairs.

In [None]:
# Inference parameters.
threshold = 0.31  # probability above which a pair is labelled as linked (should be around 0.3)
num_nodes = 400  # candidate node ids are 0 .. num_nodes - 1
each_occurrence = 100  # copies of each node id in the sampling pool
size = 10000  # number of node pairs to score; must be <= num_nodes * each_occurrence
query_time = 0.0  # timestamp at which every pair is scored
inference_seed = 0  # makes the sampled pairs reproducible

# NOTE(review): DyGLib's preprocessing normally shifts node ids to start at 1 (0 is padding);
# if that holds for this dataset, the valid ids are 1..400 rather than 0..399 -- confirm.
# NOTE(review): with query_time = 0 and the 'recent' neighbor sampler, there are presumably no
# earlier interactions to attend to, so the scores ignore the graph history -- confirm intended.

if size > num_nodes * each_occurrence:
    raise ValueError(f'size={size} exceeds the sampling pool of {num_nodes * each_occurrence} ids '
                     f'(num_nodes * each_occurrence); pairs are sampled without replacement')

rng = np.random.default_rng(inference_seed)
# Pool with each node id repeated `each_occurrence` times. Sampling from it without replacement
# caps how often any single node can appear on either side of a pair.
base_array = np.repeat(np.arange(0, num_nodes), each_occurrence)
first_node_ids = np.sort(rng.choice(base_array, size=size, replace=False))
second_node_ids = rng.choice(base_array, size=size, replace=False)
node_interact_times = np.full(size, query_time)

# Score in inference mode (no dropout, no autograd) with the full-graph neighbor sampler, as
# the evaluation cells do, so the result does not depend on which cells ran before this one.
model.eval()
model[0].set_neighbor_sampler(full_neighbor_sampler)
with torch.no_grad():
    first_node_embeddings, second_node_embeddings = model[0].compute_src_dst_node_temporal_embeddings(
        src_node_ids=first_node_ids, dst_node_ids=second_node_ids, node_interact_times=node_interact_times)
    probabilities = model[1](input_1=first_node_embeddings, input_2=second_node_embeddings).squeeze(dim=-1).sigmoid()

# 1 where the predicted link probability exceeds the threshold, else 0.
link = np.where(probabilities.cpu().numpy() > threshold, 1, 0)

inference_df = (
    pd.DataFrame({'first_nodes': first_node_ids, 'second_nodes': second_node_ids, 'link': link})
    .drop_duplicates()
    .sort_values(by=['first_nodes', 'link', 'second_nodes'])
)
inference_df

Unnamed: 0,first_nodes,second_nodes,link
15,0,30,1
20,0,34,1
5,0,52,1
7,0,133,1
6,0,143,1
...,...,...,...
9979,399,352,1
9977,399,363,1
9989,399,377,1
9991,399,393,1
