# Environmental Setup


> Because our dataset is too large to download in a reasonable amount of time, the implementation reads it directly from a shared folder that is accessed through the mounted Google Drive. Therefore, for reproducibility, it is **IMPORTANT** to:
1. Visit this [drive folder](https://drive.google.com/drive/folders/1VMn57KmlJ20DlviBlGufDC7vgdWIR9ni?usp=sharing).
2. Once you have visited it, the folder will appear in the 'Shared with me' section of your Google Drive; from there, add a shortcut to it in your own drive.
3. The shortcut should then be located at `drive/MyDrive/CS471 Project` (the code below expects `/content/drive/MyDrive/CS471 Project`).

Mount Google Drive on the Colab VM and install the necessary modules.

In [None]:
from google.colab import drive

# Mount Google Drive so the shared "CS471 Project" shortcut is reachable under /content/drive/MyDrive.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [None]:
!pip install wget

import os
import shutil
import wget
from urllib.parse import urlparse

%cd /content/

repo_path = "https://github.com/KyawYeThu-11/DyGLib.git"
repo_name = os.path.splitext(os.path.basename(urlparse(repo_path).path))[0]

if not os.path.exists(repo_name):
  !git clone $repo_path
  !pip install -r DyGLib/requirements.txt

Collecting wget
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9656 sha256=a82b4476e3bc0e97da5abf8e052d97a57447d6611987013e106c817807e4c90c
  Stored in directory: /root/.cache/pip/wheels/8b/f1/7f/5c94f0a7a505ca1c81cd1d9208ae2064675d97582078e6c769
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2
/content
Cloning into 'DyGLib'...
remote: Enumerating objects: 232, done.[K
remote: Counting objects: 100% (175/175), done.[K
remote: Compressing objects: 100% (93/93), done.[K
remote: Total 232 (delta 116), reused 119 (delta 81), pack-reused 57[K
Receiving objects: 100% (232/232), 16.50 MiB | 11.21 MiB/s, done.
Resolving deltas: 100% (128/128), done.
Collecting torch-geometric (from -r DyGLib/requirements.txt (line 6))
  Downloading torch_geo

In [None]:
# !python --version
# !pip install git+https://github.com/PyTorchLightning/pytorch-lightning

In [None]:
# Work from inside the cloned repository: the next cell imports `models.*` and `utils.*` from its root,
# and later cells use paths relative to it (e.g. ./logs, ./DG_data).
%cd /content/DyGLib

/content/DyGLib


In [None]:
import logging
import time
import sys
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
import warnings
import shutil
import json
import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool
# import lightning as L
from models.DyGFormerGraph import DyGFormerGraph
from models.DyGFormer import DyGFormer
from models.modules import MergeLayer, GraphRegressor
from utils.utils import set_random_seed, convert_to_gpu, get_parameter_sizes, create_optimizer
from utils.utils import get_neighbor_sampler, NegativeEdgeSampler
from utils.metrics import get_link_prediction_metrics
from utils.DataLoader import get_idx_data_loader, get_graph_regression_data
from utils.EarlyStopping import EarlyStopping
from utils.load_configs import get_link_prediction_args

# Training

In [None]:
class Args:
    """
    Plain attribute container holding the hyper-parameters used by the training and testing cells.

    Presumably stands in for DyGLib's argparse-based configuration inside the notebook.

    :param dataset_name: str, name of the dataset (used in log/model/result paths)
    :param overrides: optional keyword overrides for any of the defaults below, e.g. ``Args('x', num_epochs=3)``;
        unknown names raise TypeError so typos are not silently ignored
    """

    def __init__(self, dataset_name, **overrides):
        self.dataset_name = dataset_name
        self.batch_size = 1000
        self.model_name = 'DyGFormer'
        self.gpu = 0
        self.num_neighbors = 20
        self.sample_neighbor_strategy = 'recent'
        self.time_scaling_factor = 1e-6
        self.num_walk_heads = 8
        self.num_folds = 5
        self.num_heads = 2
        self.num_layers = 2
        self.walk_length = 1
        self.time_gap = 2000
        self.time_feat_dim = 100
        self.position_feat_dim = 172
        self.edge_bank_memory_mode = 'unlimited_memory'
        self.time_window_mode = 'fixed_proportion'
        self.patch_size = 2
        self.channel_embedding_dim = 50
        self.max_input_sequence_length = 64
        self.learning_rate = 0.005
        self.load_checkpoint = True
        self.dropout = 0.05
        self.scheduler_period = 2
        self.num_epochs = 10
        self.optimizer = 'Adam'
        self.weight_decay = 0.0
        self.patience = 20
        self.test_interval_epochs = 10
        self.negative_sample_strategy = 'random'
        self.load_best_configs = False
        # base name of the result JSON written by the training cell (it reads args.save_model_name)
        self.save_model_name = self.model_name

        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Args got unexpected hyper-parameter(s): {', '.join(unknown)}")
        vars(self).update(overrides)
        # keep the file name in sync with an overridden model_name unless it was set explicitly
        if 'save_model_name' not in overrides:
            self.save_model_name = self.model_name

    def __str__(self):
        properties = [f"{key}={value}" for key, value in self.__dict__.items()]
        return f"Args({', '.join(properties)})"

    def __repr__(self):
        return self.__str__()

# Build the run configuration for the connectome dataset and decide which device tensors go to.
args = Args(dataset_name='connectome')
use_gpu = torch.cuda.is_available() and args.gpu >= 0
args.device = f'cuda:{args.gpu}' if use_gpu else 'cpu'


### logger

In [None]:
def set_up_logger(args):
    """
    Configure the root logger to log DEBUG+ to a timestamped file and WARNING+ to the console.

    The log file is ./logs/<model_name>/<dataset_name>/<time.time()>.log. Handlers added by a previous
    call (e.g. when the cell is re-run) are closed and removed first, so messages are not duplicated.

    :param args: object with `model_name` and `dataset_name` attributes
    :return: tuple (root logger, file handler, console handler)
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # drop handlers installed by earlier calls of this function; other handlers are left untouched
    for handler in list(logger.handlers):
        if getattr(handler, '_set_up_logger_handler', False):
            logger.removeHandler(handler)
            handler.close()

    log_dir = f"./logs/{args.model_name}/{args.dataset_name}/"
    os.makedirs(log_dir, exist_ok=True)
    # file handler that logs debug and higher level messages
    fh = logging.FileHandler(f"{log_dir}{str(time.time())}.log")
    fh.setLevel(logging.DEBUG)
    # console handler with a higher log level
    ch = logging.StreamHandler()
    ch.setLevel(logging.WARNING)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in (fh, ch):
        handler.setFormatter(formatter)
        handler._set_up_logger_handler = True  # marker so the next call can find and replace it
        logger.addHandler(handler)

    return logger, fh, ch

### main

In [None]:
def step_forward(dataset_dir, index, subject_id, label, loss_func, epoch=None):
    """
    Run the full model on one subject's temporal graph and compute the regression loss.

    The subject's interactions are fed through the dynamic backbone (`model[0]`) in batches of
    `args.batch_size`. The per-batch graph embeddings are stacked and passed to the regression
    head (`model[1]`). Relies on the notebook globals `model` and `args`.

    :param dataset_dir: str, folder with the subject's preprocessed graph
    :param index: int, subject row index, forwarded to get_graph_regression_data
    :param subject_id: int, subject id (only shown in the progress bar)
    :param label: torch.Tensor, regression target for this subject
    :param loss_func: loss callable invoked as loss_func(input=..., target=...)
    :param epoch: int or None, 0-based epoch shown in the progress bar; if None, the notebook-global
        `epoch` is used when it exists (backward compatible with callers that do not pass it)
    :return: tuple (predict, loss)
    :raises ValueError: if the subject's graph has no interactions
    """
    node_raw_features, edge_raw_features, full_data = \
        get_graph_regression_data(dataset_dir=dataset_dir, index=index)

    idx_data_loader = get_idx_data_loader(indices_list=list(range(len(full_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)
    idx_data_loader_tqdm = tqdm(idx_data_loader, ncols=120)

    # the sampler covers this subject's whole graph; it is used for training, validation and testing alike
    full_neighbor_sampler = get_neighbor_sampler(data=full_data, sample_neighbor_strategy=args.sample_neighbor_strategy,
                                                 time_scaling_factor=args.time_scaling_factor, seed=1)
    model[0].set_neighbor_sampler(full_neighbor_sampler)

    if epoch is None:
        # older callers do not pass the epoch; fall back to the training loop's variable (absent when only testing)
        epoch = globals().get('epoch')
    epoch_label = epoch + 1 if epoch is not None else '-'

    graph_embedding_list = []
    for batch_idx, batch_indices in enumerate(idx_data_loader_tqdm):
        batch_indices = batch_indices.numpy()
        batch_src_node_ids = full_data.src_node_ids[batch_indices]
        batch_dst_node_ids = full_data.dst_node_ids[batch_indices]
        batch_node_interact_times = full_data.node_interact_times[batch_indices]

        graph_embedding = model[0](node_raw_features, edge_raw_features, batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times)
        # torch.as_tensor returns a tensor unchanged (keeping its autograd graph) and converts a numpy array
        graph_embedding_list.append(torch.as_tensor(graph_embedding))

        idx_data_loader_tqdm.set_description(f'Epoch: {epoch_label}, Subject: {subject_id}, training for the {batch_idx + 1}-th batch.')

    if not graph_embedding_list:
        raise ValueError(f'subject {subject_id} in {dataset_dir} has no interactions')

    # Stack with torch rather than numpy: a numpy round-trip detaches the embeddings (so the backbone
    # would get no gradient) and fails outright for CUDA tensors or tensors that require grad.
    graph_embeddings_tensor = torch.stack(graph_embedding_list).squeeze().to(args.device)

    predict = model[1](graph_embeddings_tensor)
    # NOTE(review): label has shape (1,); confirm predict's shape matches so MSELoss does not broadcast
    loss = loss_func(input=predict, target=label)

    return predict, loss

In [None]:
# All data, checkpoints and results live in the shared Drive folder (see "Environmental Setup").
project_dir = '/content/drive/MyDrive/CS471 Project'
percentile_folder = '5-percentile'
save_model_folder = f"{project_dir}/saved_models/graph_regression/{args.dataset_name}"
os.makedirs(save_model_folder, exist_ok=True)
checkpoint_path = os.path.join(save_model_folder, f'{args.model_name}.pth')

logger, fh, ch = set_up_logger(args)

start_time = time.time()

logger.info(f'configuration is {args}')

# Initialize the model: dynamic graph backbone followed by a regression head on its graph embedding
dynamic_backbone = DyGFormerGraph(node_feat_dim=args.position_feat_dim, edge_feat_dim=args.position_feat_dim, time_feat_dim=args.time_feat_dim,
                                 channel_embedding_dim=args.channel_embedding_dim, patch_size=args.patch_size,
                                 num_layers=args.num_layers, num_heads=args.num_heads, dropout=args.dropout,
                                 max_input_sequence_length=args.max_input_sequence_length, device=args.device)

graph_regressor = GraphRegressor(in_channels=args.position_feat_dim, hidden_channels=int(args.position_feat_dim/2), out_channels=int(args.position_feat_dim/4))
model = nn.Sequential(dynamic_backbone, graph_regressor)

# log the model structure (sizes count 4 bytes per parameter)
num_parameters = get_parameter_sizes(model)
logger.info(f'model -> {model}')
logger.info(f'model name: {args.model_name}, #parameters: {num_parameters * 4} B, '
            f'{num_parameters * 4 / 1024} KB, {num_parameters * 4 / 1024 / 1024} MB.')

optimizer = create_optimizer(model=model, optimizer_name=args.optimizer, learning_rate=args.learning_rate, weight_decay=args.weight_decay)
# NOTE(review): scheduler.step() runs after every subject, so T_max is measured in subjects, not epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.scheduler_period)

model = convert_to_gpu(model, device=args.device)

# Resume from the per-subject checkpoint if requested and present
epoch_resumed = 0
index_resumed = 0
if args.load_checkpoint and os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=args.device)
    logger.info(f"load model {checkpoint_path}")

    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])

    epoch_resumed = checkpoint['epoch']
    # the checkpoint stores the last finished subject; continue with the next one
    index_resumed = checkpoint['subject_index'] + 1
    print(f"Epoch resumed: {epoch_resumed}")
    print(f"Index resumed: {index_resumed}")

loss_func = nn.MSELoss()


def _iter_subjects(split):
    """Yield (row index, subject id, score tensor, preprocessed folder) for each subject of 'Train' / 'Val'."""
    score_df = pd.read_csv(f'{project_dir}/{percentile_folder}-data/{split}_Data_csv/Language_Task_Acc.csv')
    for index, row in score_df.iterrows():
        subject_id = int(row.iloc[2])
        score = torch.Tensor([row.iloc[3]]).to(args.device)
        subject_folder = os.path.join(f'{project_dir}/Preprocessed Data/{percentile_folder}/{split}', str(index))
        yield index, subject_id, score, subject_folder


for epoch in range(epoch_resumed, args.num_epochs):

    # Training for an epoch starts
    model.train()
    train_losses = []

    for index, subject_id, score, subject_folder in _iter_subjects('Train'):
        # only the resumed epoch skips already-trained subjects; every later epoch sees all of them
        if epoch == epoch_resumed and index < index_resumed:
            continue

        # process every batch for a single subject and compute predictions and loss
        predict, loss = step_forward(subject_folder, index, subject_id, score, loss_func)
        train_losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        print(f"For subject {subject_id}, label: {score.item():.4f}, predict: {predict.item():.4f}, and loss: {loss:.4f}")

        # checkpoint after every subject so an interrupted session can resume mid-epoch
        torch.save({
                  'subject_index': index,
                  'epoch': epoch,
                  'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'scheduler': scheduler.state_dict()},
                  checkpoint_path)

    # validation
    print("---------------Validation Starts--------------")
    model.eval()
    val_losses = []

    with torch.no_grad():
        for index, subject_id, score, subject_folder in _iter_subjects('Val'):
            predict, loss = step_forward(subject_folder, index, subject_id, score, loss_func)
            val_losses.append(loss.item())

            print(f"For subject {subject_id}, label: {score.item():.4f}, predict: {predict.item():.4f}, and loss: {loss:.4f}")

    print("-----------------------------------")
    logger.info(f'Epoch: {epoch}, learning rate: {optimizer.param_groups[0]["lr"]}, train loss: {np.mean(train_losses):.4f} and val loss: {np.mean(val_losses):.4f}')
    print("-----------------------------------")

    # Save the results of the current epoch to a JSON file
    result_json = json.dumps({
            "train losses": [f'{loss:.4f}' for loss in train_losses],
            'val losses': [f'{loss:.4f}' for loss in val_losses]
    }, indent=4)

    save_result_folder = f"{project_dir}/saved_results/graph_regression/{args.dataset_name}"
    os.makedirs(save_result_folder, exist_ok=True)
    # older Args instances have no save_model_name; fall back to the model name
    save_result_path = os.path.join(save_result_folder, f"{getattr(args, 'save_model_name', args.model_name)}.json")

    with open(save_result_path, 'w') as file:
        file.write(result_json)

torch.save(model.state_dict(), os.path.join(save_model_folder, f'{args.model_name}_final.pth'))
logger.info(f'training finished in {time.time() - start_time:.2f} seconds')

INFO:root:configuration is Args(dataset_name=connectome, batch_size=1000, model_name=DyGFormer, gpu=0, num_neighbors=20, sample_neighbor_strategy=recent, time_scaling_factor=1e-06, num_walk_heads=8, num_folds=5, num_heads=2, num_layers=2, walk_length=1, time_gap=2000, time_feat_dim=100, position_feat_dim=172, edge_bank_memory_mode=unlimited_memory, time_window_mode=fixed_proportion, patch_size=2, channel_embedding_dim=50, max_input_sequence_length=64, learning_rate=0.005, load_checkpoint=True, dropout=0.05, scheduler_period=2, num_epochs=10, optimizer=Adam, weight_decay=0.0, patience=20, test_interval_epochs=10, negative_sample_strategy=random, load_best_configs=False, device=cpu)
INFO:root:model -> Sequential(
  (0): DyGFormerGraph(
    (time_encoder): TimeEncoder(
      (w): Linear(in_features=1, out_features=100, bias=True)
    )
    (neighbor_co_occurrence_encoder): NeighborCooccurrenceEncoder(
      (neighbor_co_occurrence_encode_layer): Sequential(
        (0): Linear(in_features

KeyboardInterrupt: 

## Testing

Load the model.

In [None]:
# Rebuild the architecture used in training so the saved weights fit it.
dynamic_backbone = DyGFormerGraph(node_feat_dim=args.position_feat_dim, edge_feat_dim=args.position_feat_dim, time_feat_dim=args.time_feat_dim,
                                 channel_embedding_dim=args.channel_embedding_dim, patch_size=args.patch_size,
                                 num_layers=args.num_layers, num_heads=args.num_heads, dropout=args.dropout,
                                 max_input_sequence_length=args.max_input_sequence_length, device=args.device)

graph_regressor = GraphRegressor(in_channels=args.position_feat_dim, hidden_channels=int(args.position_feat_dim/2), out_channels=int(args.position_feat_dim/4))
model = nn.Sequential(dynamic_backbone, graph_regressor)

# The training cell saves its checkpoint under saved_models/graph_regression/<dataset_name>/<model_name>.pth;
# the dataset-less location is kept as a fallback for files copied there by hand.
models_root = "/content/drive/MyDrive/CS471 Project/saved_models/graph_regression"
candidate_paths = [os.path.join(models_root, args.dataset_name, f"{args.model_name}.pth"),
                   os.path.join(models_root, f"{args.model_name}.pth")]
load_model_path = next((path for path in candidate_paths if os.path.exists(path)), candidate_paths[0])
checkpoint = torch.load(load_model_path, map_location=args.device)

# Training checkpoints store the weights under 'model'; a bare state_dict (e.g. *_final.pth) is used as-is.
if isinstance(checkpoint, dict) and 'model' in checkpoint:
    state_dict = checkpoint['model']
elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)

# step_forward moves embeddings to args.device, so the model must live there too; eval() disables dropout.
model = convert_to_gpu(model, device=args.device)
model.eval()

Test the model with the testing set.

In [None]:
# Evaluate on the held-out test subjects.
# Relies on `model` (loaded above) and `percentile_folder` (set in the training cell).
model.eval()
loss_func = nn.MSELoss()
test_losses = []

with torch.no_grad():
    language_score_df = pd.read_csv(f'/content/drive/MyDrive/CS471 Project/{percentile_folder}-data/Test_Data_csv/Language_Task_Acc.csv')

    for index, row in language_score_df.iterrows():
        subject_id = int(row.iloc[2])
        score = torch.Tensor([row.iloc[3]]).to(args.device)
        subject_folder = os.path.join(f'/content/drive/MyDrive/CS471 Project/Preprocessed Data/{percentile_folder}/Test', str(index))

        # step_forward needs the subject's row index to locate its graph
        predict, loss = step_forward(subject_folder, index, subject_id, score, loss_func)

        test_losses.append(loss.item())

        print(f"For subject {subject_id}, label: {score.item():.4f}, predict: {predict.item():.4f}, and loss: {loss:.4f}")

    print(f'test loss: {np.mean(test_losses):.4f}')

# Training Pipeline (v2)

In [None]:
import pandas as pd
import cudf
import random

In [None]:
def csv2triplets(csvDf, dir, num, shape=(50, 120, 200)):
    """
    Save the (source, dest, time_interval) columns of an edge table as a tensor file.

    The rows are reshaped to `(*shape, 3)` (by default (50, 120, 200, 3)) and written with
    torch.save to `<dir>/triplets<num>.pt`.

    :param csvDf: cudf or pandas DataFrame with columns 'source', 'dest', 'time_interval'
        and exactly prod(shape) rows (otherwise reshape raises)
    :param dir: output directory (name kept for backward compatibility; shadows the builtin)
    :param num: int, file index used in the output file name
    :param shape: leading dimensions of the saved tensor; defaults to the previously hard-coded (50, 120, 200)
    """
    columns = csvDf[['source', 'dest', 'time_interval']]
    # .values is a numpy array for pandas and a cupy array for cudf; torch.tensor accepts both
    triplets = torch.tensor(columns.values)
    triplets = triplets.reshape(*shape, 3)
    torch.save(triplets, os.path.join(dir, 'triplets' + str(num) + '.pt'))

In [None]:

# Convert each fold's edge CSVs into fixed-size triplet tensors (50 * 120 * 200 edges per file).
num_folds = 5
num_train_files = 40
num_test_files = 10
rows_per_file = 50 * 120 * 200


def _pad_with_random_edges(csv_df, target_rows):
    """
    Append dummy edges so csv_df has exactly target_rows rows.

    Dummy edges join random nodes in [0, 399] at a random time interval in [0, 49] and carry
    edge_label = w = 0. When padding happens, only the first three columns are kept.

    :raises ValueError: if csv_df already has more than target_rows rows
    """
    num_missing = target_rows - csv_df.shape[0]
    if num_missing < 0:
        raise ValueError(f'expected at most {target_rows} rows, got {csv_df.shape[0]}')
    if num_missing == 0:
        return csv_df
    # same draw order as before: all sources, then all destinations, then all time intervals
    padding_df = cudf.DataFrame({
        'source': [random.randint(0, 399) for _ in range(num_missing)],
        'dest': [random.randint(0, 399) for _ in range(num_missing)],
        'time_interval': [random.randint(0, 49) for _ in range(num_missing)],
        'edge_label': [0] * num_missing,
        'w': [0] * num_missing,
    })
    padded_df = cudf.concat([csv_df, padding_df], axis=0, ignore_index=True)
    return padded_df.iloc[:, [0, 1, 2]]


def _convert_fold_file(fold_dir, csv_name, file_idx, verbose=False):
    """Read one CSV, pad it to rows_per_file rows, sort it and save it through csv2triplets."""
    csv_df = cudf.read_csv(os.path.join(fold_dir, csv_name))
    if csv_df.shape[0] != rows_per_file:
        csv_df = _pad_with_random_edges(csv_df, rows_per_file)
        if verbose:
            print(csv_df.shape[0])
    assert csv_df.shape[0] == rows_per_file
    # sort_values returns a new frame; the previous code discarded the result, so no sorting took place
    csv_df = csv_df.sort_values(by=['time_interval', 'source', 'dest'])
    csv2triplets(csv_df, fold_dir, file_idx)


for k in range(num_folds):
    print(f"Fold {k+1}")
    for i in range(num_train_files):
        print(f"Iteration {i}")
        _convert_fold_file('./DG_data/Train/Fold' + str(k), 'train_' + str(i) + '.csv', i, verbose=True)
    for i in range(num_test_files):
        _convert_fold_file('./DG_data/Test/Fold' + str(k), 'test_' + str(i) + '.csv', i)

In [None]:
class MLPRegressor(nn.Module):
    """
    Two-layer perceptron: Linear(input_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> output_dim).

    :param input_dim: size of the input's last dimension
    :param hidden_dim: width of the hidden layer
    :param output_dim: size of the output's last dimension
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # attribute names fc1 / relu / fc2 determine the state_dict keys, so they stay fixed
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
class NeighborSampler:
    """
    Samples historical (temporal) neighbors of nodes from a per-node adjacency list.
    Each node's neighbors are kept sorted by timestamp, so the neighbors seen strictly before a query time
    are located with np.searchsorted. Strategies: 'uniform', 'recent' and 'time_interval_aware'; any other
    value raises ValueError in get_historical_neighbors once a node that has neighbors is queried.
    """

    def __init__(self, adj_list: list, sample_neighbor_strategy: str = 'uniform', time_scaling_factor: float = 0.0, seed: int = None):
        """
        Neighbor sampler.
        :param adj_list: list, list of list, where each element is a list of triple tuple (node_id, edge_id, timestamp)
        :param sample_neighbor_strategy: str, how to sample historical neighbors, 'uniform', 'recent', or 'time_interval_aware'
        :param time_scaling_factor: float, a hyper-parameter that controls the sampling preference with time interval,
        a large time_scaling_factor tends to sample more on recent links, this parameter works when sample_neighbor_strategy == 'time_interval_aware'
        :param seed: int, random seed
        """
        self.sample_neighbor_strategy = sample_neighbor_strategy
        self.seed = seed

        # list of each node's neighbor ids, edge ids and interaction times, which are sorted by interaction times
        self.nodes_neighbor_ids = []
        self.nodes_edge_ids = []
        self.nodes_neighbor_times = []

        if self.sample_neighbor_strategy == 'time_interval_aware':
            self.nodes_neighbor_sampled_probabilities = []
            self.time_scaling_factor = time_scaling_factor

        # the list at the first position in adj_list is empty, hence, sorted() will return an empty list for the first position
        # its corresponding value in self.nodes_neighbor_ids, self.nodes_edge_ids, self.nodes_neighbor_times will also be empty with length 0
        # NOTE(review): that only holds when node ids start at 1; this notebook also generates node id 0, so position 0 may be non-empty
        for node_idx, per_node_neighbors in enumerate(adj_list):
            # per_node_neighbors is a list of tuples (neighbor_id, edge_id, timestamp)
            # sort the list based on timestamps, sorted() function is stable
            # Note that sort the list based on edge id is also correct, as the original data file ensures the interactions are chronological
            sorted_per_node_neighbors = sorted(per_node_neighbors, key=lambda x: x[2])
            self.nodes_neighbor_ids.append(np.array([x[0] for x in sorted_per_node_neighbors]))
            self.nodes_edge_ids.append(np.array([x[1] for x in sorted_per_node_neighbors]))
            self.nodes_neighbor_times.append(np.array([x[2] for x in sorted_per_node_neighbors]))

            # additional for time interval aware sampling strategy (proposed in CAWN paper)
            if self.sample_neighbor_strategy == 'time_interval_aware':
                self.nodes_neighbor_sampled_probabilities.append(self.compute_sampled_probabilities(np.array([x[2] for x in sorted_per_node_neighbors])))

        # without a seed no random_state is created and sampling falls back to the global np.random state
        if self.seed is not None:
            self.random_state = np.random.RandomState(self.seed)

    def compute_sampled_probabilities(self, node_neighbor_times: np.ndarray):
        """
        compute the sampled probabilities of historical neighbors based on their interaction times
        :param node_neighbor_times: ndarray, shape (num_historical_neighbors, )
        :return:
        """
        if len(node_neighbor_times) == 0:
            return np.array([])
        # compute the time delta with regard to the last time in node_neighbor_times
        node_neighbor_times = node_neighbor_times - np.max(node_neighbor_times)
        # compute the normalized sampled probabilities of historical neighbors
        exp_node_neighbor_times = np.exp(self.time_scaling_factor * node_neighbor_times)
        # NOTE(review): dividing by the cumulative sum (not the total) does not make these sum to 1;
        # get_historical_neighbors applies a softmax to the selected prefix before sampling
        sampled_probabilities = exp_node_neighbor_times / np.cumsum(exp_node_neighbor_times)
        # note that the first few values in exp_node_neighbor_times may be all zero, which make the corresponding values in sampled_probabilities
        # become nan (divided by zero), so we replace the nan by a very large negative number -1e10 to denote the sampled probabilities
        sampled_probabilities[np.isnan(sampled_probabilities)] = -1e10
        return sampled_probabilities

    def find_neighbors_before(self, node_id: int, interact_time: float, return_sampled_probabilities: bool = False):
        """
        extracts all the interactions happening before interact_time (less than interact_time) for node_id in the overall interaction graph
        the returned interactions are sorted by time.
        :param node_id: int, node id
        :param interact_time: float, interaction time
        :param return_sampled_probabilities: boolean, whether return the sampled probabilities of neighbors
        :return: neighbors, edge_ids, timestamps and sampled_probabilities (if return_sampled_probabilities is True) with shape (historical_nodes_num, )
        """
        # return index i, which satisfies list[i - 1] < v <= list[i]
        # return 0 for the first position in self.nodes_neighbor_times since the value at the first position is empty
        i = np.searchsorted(self.nodes_neighbor_times[node_id], interact_time)

        if return_sampled_probabilities:
            return self.nodes_neighbor_ids[node_id][:i], self.nodes_edge_ids[node_id][:i], self.nodes_neighbor_times[node_id][:i], \
                   self.nodes_neighbor_sampled_probabilities[node_id][:i]
        else:
            return self.nodes_neighbor_ids[node_id][:i], self.nodes_edge_ids[node_id][:i], self.nodes_neighbor_times[node_id][:i], None

    def get_historical_neighbors(self, node_ids: np.ndarray, node_interact_times: np.ndarray, num_neighbors: int = 20):
        """
        get historical neighbors of nodes in node_ids with interactions before the corresponding time in node_interact_times
        :param node_ids: ndarray, shape (batch_size, ) or (*, ), node ids
        :param node_interact_times: ndarray, shape (batch_size, ) or (*, ), node interaction times
        :param num_neighbors: int, number of neighbors to sample for each node
        :return:
        """
        assert num_neighbors > 0, 'Number of sampled neighbors for each node should be greater than 0!'
        # All interactions described in the following three matrices are sorted in each row by time
        # The matrices are zero-initialised: a node with no earlier neighbors keeps an all-zero row, and under
        # 'recent' a node with fewer than num_neighbors neighbors is left-padded with zeros.
        # each entry in position (i,j) represents the id of the j-th dst node of src node node_ids[i] with an interaction before node_interact_times[i]
        # ndarray, shape (batch_size, num_neighbors)
        nodes_neighbor_ids = np.zeros((len(node_ids), num_neighbors)).astype(np.longlong)
        # each entry in position (i,j) represents the id of the edge with src node node_ids[i] and dst node nodes_neighbor_ids[i][j] with an interaction before node_interact_times[i]
        # ndarray, shape (batch_size, num_neighbors)
        nodes_edge_ids = np.zeros((len(node_ids), num_neighbors)).astype(np.longlong)
        # each entry in position (i,j) represents the interaction time between src node node_ids[i] and dst node nodes_neighbor_ids[i][j], before node_interact_times[i]
        # ndarray, shape (batch_size, num_neighbors)
        nodes_neighbor_times = np.zeros((len(node_ids), num_neighbors)).astype(np.float32)

        # extracts all neighbors ids, edge ids and interaction times of nodes in node_ids, which happened before the corresponding time in node_interact_times
        for idx, (node_id, node_interact_time) in enumerate(zip(node_ids, node_interact_times)):
            # find neighbors that interacted with node_id before time node_interact_time
            node_neighbor_ids, node_edge_ids, node_neighbor_times, node_neighbor_sampled_probabilities = \
                self.find_neighbors_before(node_id=node_id, interact_time=node_interact_time, return_sampled_probabilities=self.sample_neighbor_strategy == 'time_interval_aware')

            if len(node_neighbor_ids) > 0:
                if self.sample_neighbor_strategy in ['uniform', 'time_interval_aware']:
                    # when self.sample_neighbor_strategy == 'uniform', we shuffle the data before sampling with node_neighbor_sampled_probabilities as None
                    # when self.sample_neighbor_strategy == 'time_interval_aware', we sample neighbors based on node_neighbor_sampled_probabilities
                    # for time_interval_aware sampling strategy, we additionally use softmax to make the sum of sampled probabilities be 1
                    if node_neighbor_sampled_probabilities is not None:
                        # for extreme case that node_neighbor_sampled_probabilities only contains -1e10, which will make the denominator of softmax be zero,
                        # torch.softmax() function can tackle this case
                        node_neighbor_sampled_probabilities = torch.softmax(torch.from_numpy(node_neighbor_sampled_probabilities).float(), dim=0).numpy()
                    # choice() samples with replacement by default, so the same neighbor can be drawn several times
                    if self.seed is None:
                        sampled_indices = np.random.choice(a=len(node_neighbor_ids), size=num_neighbors, p=node_neighbor_sampled_probabilities)
                    else:
                        sampled_indices = self.random_state.choice(a=len(node_neighbor_ids), size=num_neighbors, p=node_neighbor_sampled_probabilities)

                    nodes_neighbor_ids[idx, :] = node_neighbor_ids[sampled_indices]
                    nodes_edge_ids[idx, :] = node_edge_ids[sampled_indices]
                    nodes_neighbor_times[idx, :] = node_neighbor_times[sampled_indices]

                    # resort based on timestamps, return the ids in sorted increasing order, note this maybe unstable when multiple edges happen at the same time
                    # (we still do this though this is unnecessary for TGAT or CAWN to guarantee the order of nodes,
                    # since TGAT computes in an order-agnostic manner with relative time encoding, and CAWN computes for each walk while the sampled nodes are in different walks)
                    sorted_position = nodes_neighbor_times[idx, :].argsort()
                    nodes_neighbor_ids[idx, :] = nodes_neighbor_ids[idx, :][sorted_position]
                    nodes_edge_ids[idx, :] = nodes_edge_ids[idx, :][sorted_position]
                    nodes_neighbor_times[idx, :] = nodes_neighbor_times[idx, :][sorted_position]
                elif self.sample_neighbor_strategy == 'recent':
                    # Take most recent interactions with number num_neighbors
                    node_neighbor_ids = node_neighbor_ids[-num_neighbors:]
                    node_edge_ids = node_edge_ids[-num_neighbors:]
                    node_neighbor_times = node_neighbor_times[-num_neighbors:]

                    # put the neighbors' information at the back positions
                    nodes_neighbor_ids[idx, num_neighbors - len(node_neighbor_ids):] = node_neighbor_ids
                    nodes_edge_ids[idx, num_neighbors - len(node_edge_ids):] = node_edge_ids
                    nodes_neighbor_times[idx, num_neighbors - len(node_neighbor_times):] = node_neighbor_times
                else:
                    raise ValueError(f'Not implemented error for sample_neighbor_strategy {self.sample_neighbor_strategy}!')

        # three ndarrays, with shape (batch_size, num_neighbors)
        return nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times

    def get_multi_hop_neighbors(self, num_hops: int, node_ids: np.ndarray, node_interact_times: np.ndarray, num_neighbors: int = 20):
        """
        get historical neighbors of nodes in node_ids within num_hops hops
        :param num_hops: int, number of sampled hops
        :param node_ids: ndarray, shape (batch_size, ), node ids
        :param node_interact_times: ndarray, shape (batch_size, ), node interaction times
        :param num_neighbors: int, number of neighbors to sample for each node
        :return:
        """
        assert num_hops > 0, 'Number of sampled hops should be greater than 0!'

        # get the temporal neighbors at the first hop
        # nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times -> ndarray, shape (batch_size, num_neighbors)
        nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times = self.get_historical_neighbors(node_ids=node_ids,
                                                                                                 node_interact_times=node_interact_times,
                                                                                                 num_neighbors=num_neighbors)
        # three lists to store the neighbor ids, edge ids and interaction timestamp information
        nodes_neighbor_ids_list = [nodes_neighbor_ids]
        nodes_edge_ids_list = [nodes_edge_ids]
        nodes_neighbor_times_list = [nodes_neighbor_times]
        for hop in range(1, num_hops):
            # get information of neighbors sampled at the current hop
            # three ndarrays, with shape (batch_size * num_neighbors ** hop, num_neighbors)
            nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times = self.get_historical_neighbors(node_ids=nodes_neighbor_ids_list[-1].flatten(),
                                                                                                     node_interact_times=nodes_neighbor_times_list[-1].flatten(),
                                                                                                     num_neighbors=num_neighbors)
            # three ndarrays with shape (batch_size, num_neighbors ** (hop + 1))
            nodes_neighbor_ids = nodes_neighbor_ids.reshape(len(node_ids), -1)
            nodes_edge_ids = nodes_edge_ids.reshape(len(node_ids), -1)
            nodes_neighbor_times = nodes_neighbor_times.reshape(len(node_ids), -1)

            nodes_neighbor_ids_list.append(nodes_neighbor_ids)
            nodes_edge_ids_list.append(nodes_edge_ids)
            nodes_neighbor_times_list.append(nodes_neighbor_times)

        # tuple, each element in the tuple is a list of num_hops ndarrays, each with shape (batch_size, num_neighbors ** current_hop)
        return nodes_neighbor_ids_list, nodes_edge_ids_list, nodes_neighbor_times_list

    def get_all_first_hop_neighbors(self, node_ids: np.ndarray, node_interact_times: np.ndarray):
        """
        get historical neighbors of nodes in node_ids at the first hop with max_num_neighbors as the maximal number of neighbors (make the computation feasible)
        :param node_ids: ndarray, shape (batch_size, ), node ids
        :param node_interact_times: ndarray, shape (batch_size, ), node interaction times
        :return:
        """
        # three lists to store the first-hop neighbor ids, edge ids and interaction timestamp information, with batch_size as the list length
        nodes_neighbor_ids_list, nodes_edge_ids_list, nodes_neighbor_times_list = [], [], []
        # get the temporal neighbors at the first hop
        for idx, (node_id, node_interact_time) in enumerate(zip(node_ids, node_interact_times)):
            # find neighbors that interacted with node_id before time node_interact_time
            node_neighbor_ids, node_edge_ids, node_neighbor_times, _ = self.find_neighbors_before(node_id=node_id,
                                                                                                  interact_time=node_interact_time,
                                                                                                  return_sampled_probabilities=False)
            nodes_neighbor_ids_list.append(node_neighbor_ids)
            nodes_edge_ids_list.append(node_edge_ids)
            nodes_neighbor_times_list.append(node_neighbor_times)

        return nodes_neighbor_ids_list, nodes_edge_ids_list, nodes_neighbor_times_list

    def reset_random_state(self):
        """
        reset the random state by self.seed
        :return:
        """
        # note: also creates a random_state when self.seed is None (RandomState(None) is seeded from OS entropy)
        self.random_state = np.random.RandomState(self.seed)



In [None]:
def get_neighbor_sampler(data, sample_neighbor_strategy: str = 'uniform', time_scaling_factor: float = 0.0, seed: int = None):
    """
    Build a NeighborSampler over an undirected temporal adjacency list built from a raw edge array.
    :param data: array-like of shape (num_edges, 3) whose columns are (src_node_id, dst_node_id, timestamp)
    :param sample_neighbor_strategy: str, how to sample historical neighbors, 'uniform', 'recent', or 'time_interval_aware'
    :param time_scaling_factor: float, a hyper-parameter that controls the sampling preference with time interval,
    a large time_scaling_factor tends to sample more on recent links, this parameter works when sample_neighbor_strategy == 'time_interval_aware'
    :param seed: int, random seed
    :return: NeighborSampler whose adjacency list has max_node_id + 1 entries
    """
    edges = np.asarray(data)
    # node ids index a Python list, so they must be integers even if the edge array is float-typed
    src_node_ids = edges[:, 0].astype(np.int64)
    dst_node_ids = edges[:, 1].astype(np.int64)
    # timestamps are the third *column*; indexing data[2] (the third row) would truncate the zip to 3 edges
    node_interact_times = edges[:, 2]
    # every edge shares id 0, as before.
    # NOTE(review): in upstream DyGLib edge id 0 is usually the padding id; this is only harmless while the
    # edge raw features are all zeros -- confirm before using real edge features.
    edge_ids = np.zeros(edges.shape[0], dtype=np.int64)

    max_node_id = int(max(src_node_ids.max(), dst_node_ids.max())) if edges.shape[0] > 0 else 0
    # the adjacency vector stores edges for each node (source or destination), undirected
    # adj_list, list of list, where each element is a list of triple tuple (node_id, edge_id, timestamp)
    # the list at the first position in adj_list is empty unless node id 0 appears in the data
    adj_list = [[] for _ in range(max_node_id + 1)]
    for src_node_id, dst_node_id, edge_id, node_interact_time in zip(src_node_ids, dst_node_ids, edge_ids, node_interact_times):
        adj_list[src_node_id].append((dst_node_id, edge_id, node_interact_time))
        adj_list[dst_node_id].append((src_node_id, edge_id, node_interact_time))

    return NeighborSampler(adj_list=adj_list, sample_neighbor_strategy=sample_neighbor_strategy, time_scaling_factor=time_scaling_factor, seed=seed)


In [None]:
# Cross-validated training loop.
# For each fold, every subject's dynamic graph is embedded by DyGFormerGraph (run under torch.no_grad(),
# so the encoder receives no gradient), reduced per subject by gRegressor, and mapped to a score by
# mRegressor. The regressors are optimised with MSE against the fold's score file.
# A checkpoint is written after every epoch so an interrupted run can resume.
checkpoint_dir = './drive/MyDrive/CS471 Project'
checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')

if os.path.isfile(checkpoint_path):
  print('resuming checkpoint experiment')
  checkpoint = torch.load(checkpoint_path, map_location='cuda')
else:
  checkpoint = {
    'fold': 0,
    'epoch': 0,
    'subject': 0,
    'model': None,
    'gRegressor': None,
    'mRegressor': None,
    'mode': 'sum',
    'optimizer': None,
    'scheduler': None}

num_folds=5
num_epochs=100
num_subjects=40
num_timepoints=50
node_dim=172
edge_dim=172
learning_rate=0.001
bestLoss=np.zeros(num_folds)
for k_index in range(num_folds):
  # skip folds that a resumed checkpoint has already moved past
  if k_index < checkpoint['fold']:
    continue
  model=DyGFormerGraph(node_feat_dim=node_dim, edge_feat_dim=edge_dim, time_feat_dim=args.time_feat_dim,
                       channel_embedding_dim=args.channel_embedding_dim, device='cuda')
  model.node_raw_features=torch.zeros((400+1, node_dim), device='cuda')
  model.edge_raw_features=torch.zeros((400+1, edge_dim), device='cuda')
  # move once, before the optimizer state is restored (it is cast to the parameters' device)
  model.to('cuda')

  # float32 throughout: the graph embeddings are float32, so half-precision regressors would mismatch dtypes
  gRegressor=GraphRegressor(in_channels=1, hidden_channels=2, out_channels=1).to('cuda')
  mRegressor=MLPRegressor(input_dim=node_dim, hidden_dim=13, output_dim=1).to('cuda')

  # The regressors are the only modules that receive gradients (the encoder runs under no_grad),
  # so they must be in the optimizer; otherwise optimizer.step() never changes anything.
  optimizer = torch.optim.Adam(
      list(model.parameters()) + list(gRegressor.parameters()) + list(mRegressor.parameters()),
      lr=learning_rate)
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
  criterion=torch.nn.MSELoss()
  minLoss=np.inf  # np.Inf was removed in NumPy 2.0
  tolerance=0
  threshold=5

  if checkpoint.get('model') is not None: model.load_state_dict(checkpoint['model'])
  if checkpoint.get('gRegressor') is not None: gRegressor.load_state_dict(checkpoint['gRegressor'])
  if checkpoint.get('mRegressor') is not None: mRegressor.load_state_dict(checkpoint['mRegressor'])
  if checkpoint.get('optimizer') is not None:
    try:
      optimizer.load_state_dict(checkpoint['optimizer'])
    except ValueError as err:
      # checkpoints written before the regressors were optimised only cover model.parameters()
      print(f'optimizer state not restored ({err}); starting the optimizer from scratch')
  if checkpoint.get('scheduler') is not None: scheduler.load_state_dict(checkpoint['scheduler'])

  # the score file depends only on the fold, so load it once instead of once per subject
  scoresDf=cudf.read_csv(f'./DG_data/Train/Fold{k_index}/Language_Task_Acc.csv')
  scoresDf=scoresDf.iloc[:,[1,3]]
  # NOTE(review): two columns are selected and flattened row-major into one column; confirm this yields
  # exactly num_subjects targets in the same order as the subjects below.
  scores=torch.as_tensor(scoresDf.values).to('cuda').float().reshape(-1,1)

  for epoch in range(checkpoint['epoch'], num_epochs):
    embs=torch.zeros((node_dim, num_subjects), device='cuda')
    for i in range(checkpoint.get('subject', 0), num_subjects):
      triplets=torch.load(f'./DG_data/Train/Fold{k_index}/triplets{i}.pt')
      assert triplets.shape==(num_timepoints, 120, 200, 3)

      dataloader=torch.utils.data.DataLoader(triplets, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
      graphEmbs=torch.zeros((node_dim, num_timepoints), device='cuda')

      for t, data in enumerate(tqdm(dataloader)):
        data=data.squeeze(0)
        acrossB=torch.zeros((node_dim, data.shape[0]), device='cuda')
        for b, batch in enumerate(data):
          batch=batch.detach().cpu().numpy()
          srcNodes=batch[:,0]
          dstNodes=batch[:,1]
          timepoints=batch[:,2]
          train_neighbor_sampler = get_neighbor_sampler(data=batch, sample_neighbor_strategy=args.sample_neighbor_strategy,
                                                        time_scaling_factor=args.time_scaling_factor, seed=0)
          model.set_neighbor_sampler(train_neighbor_sampler)
          with torch.no_grad():
            batch_src_node_embeddings, batch_dst_node_embeddings = \
              model.compute_src_dst_node_temporal_embeddings(src_node_ids=srcNodes, dst_node_ids=dstNodes, node_interact_times=timepoints)
            all_node_embeddings=torch.cat([batch_src_node_embeddings, batch_dst_node_embeddings], dim=0)
            if checkpoint['mode']=='sum':
              batch_graph_embeddings=torch.sum(all_node_embeddings, dim=0)
            elif checkpoint['mode']=='mean':
              batch_graph_embeddings=torch.mean(all_node_embeddings, dim=0)
            else:
              raise ValueError(f"unsupported aggregation mode: {checkpoint['mode']!r}")
          acrossB[:,b]=batch_graph_embeddings.to('cuda').t()
        # average the batch embeddings of this timepoint
        graphEmbs[:,t]=torch.mean(acrossB, dim=1)
        del data, acrossB
        torch.cuda.empty_cache()
      # unsqueeze(1) -> (node_dim, 1, num_timepoints); permute(1,0,2) -> (1, node_dim, num_timepoints)
      # NOTE(review): an earlier comment claimed (172, 1, 50); confirm which layout GraphRegressor expects.
      graphEmbs=graphEmbs.unsqueeze(1).permute(1,0,2)

      regEmb=gRegressor(graphEmbs).squeeze(1)
      # in-place copy into embs keeps regEmb's autograd history, so the loss backpropagates into gRegressor
      embs[:,i]=regEmb
      del regEmb, graphEmbs

    embsMLP=mRegressor(embs.t())

    optimizer.zero_grad()
    loss=criterion(embsMLP, scores)
    loss.backward()
    optimizer.step()
    scheduler.step()

    # plain float: keeps no autograd graph / GPU memory alive across epochs
    loss_value=loss.item()
    print(f'Training - k:{k_index} e:{epoch} loss:{loss_value}')
    if loss_value<minLoss:
      minLoss=loss_value
      tolerance=0
    else:
      tolerance=tolerance+1

    torch.save({
                'fold': k_index,
                'epoch': epoch+1,
                'subject': 0,
                'model': model.state_dict(),
                'gRegressor': gRegressor.state_dict(),
                'mRegressor': mRegressor.state_dict(),
                'mode': checkpoint['mode'],
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()},
                checkpoint_path)

    if tolerance==threshold:
      break

  torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'model{k_index}.pth'))
  # the regressors hold the trained weights, so persist them per fold as well
  torch.save({'gRegressor': gRegressor.state_dict(), 'mRegressor': mRegressor.state_dict()},
             os.path.join(checkpoint_dir, f'regressors{k_index}.pth'))
  checkpoint.update({'epoch': 0, 'subject': 0, 'model': None, 'gRegressor': None, 'mRegressor': None,
                     'optimizer': None, 'scheduler': None})
  bestLoss[k_index]=minLoss

pd.DataFrame(bestLoss).to_csv(os.path.join(checkpoint_dir, 'bestLoss.csv'))

In [None]:
class DyGFormerGraphL(L.LightningModule):
  def __init__(self,node_dim, edge_dim, time_feat_dim, channel_embedding_dim):
    """
    Wrap a DyGFormerGraph encoder in a LightningModule and give it all-zero raw feature tables.

    :param node_dim: width of the node raw-feature table, also passed as node_feat_dim
    :param edge_dim: width of the edge raw-feature table, also passed as edge_feat_dim
    :param time_feat_dim: forwarded to DyGFormerGraph unchanged
    :param channel_embedding_dim: forwarded to DyGFormerGraph unchanged
    """
    super().__init__()
    encoder_kwargs = dict(node_feat_dim=node_dim, edge_feat_dim=edge_dim,
                          time_feat_dim=time_feat_dim, channel_embedding_dim=channel_embedding_dim)
    self.model = DyGFormerGraph(**encoder_kwargs)
    # 400 + 1 rows, presumably one per node id up to 400 -- TODO confirm against the data
    feature_rows = 400 + 1
    self.model.node_raw_features = torch.zeros((feature_rows, node_dim))
    self.model.edge_raw_features = torch.zeros((feature_rows, edge_dim))
  def forward(self, input4d):
    """
    Embed one subject's dynamic graph as a sequence of graph-level vectors, one per timepoint.

    :param input4d: iterable over timepoints (e.g. a DataLoader with batch_size=1); each item is squeezed
        on dim 0 and then iterated as edge batches. Each edge batch is converted to a NumPy array whose
        columns 0, 1, 2 are read as src node ids, dst node ids and timestamps.
    :return: tensor of shape (1, embedding_dim, num_timepoints), where embedding_dim is the trailing
        dimension of the node embeddings returned by compute_src_dst_node_temporal_embeddings
        (assumed to be 2-D (num_edges, embedding_dim) -- TODO confirm).
    :raises ValueError: if input4d yields no timepoints, a timepoint has no edge batches,
        or checkpoint['mode'] is neither 'sum' nor 'mean'.
    """
    # aggregation mode comes from the notebook-level `checkpoint` dict, as in the training cell
    mode = checkpoint['mode']
    timepoint_embeddings = []
    for data in tqdm(input4d):
      data = data.squeeze(0)
      batch_embeddings = []
      for batch in data:
        # .cpu() so this also works when the input has already been moved to the GPU
        batch = batch.detach().cpu().numpy()
        srcNodes = batch[:, 0]
        dstNodes = batch[:, 1]
        timepoints = batch[:, 2]
        train_neighbor_sampler = get_neighbor_sampler(data=batch, sample_neighbor_strategy=args.sample_neighbor_strategy,
                                                      time_scaling_factor=args.time_scaling_factor, seed=0)
        self.model.set_neighbor_sampler(train_neighbor_sampler)
        with torch.no_grad():
          batch_src_node_embeddings, batch_dst_node_embeddings = \
            self.model.compute_src_dst_node_temporal_embeddings(src_node_ids=srcNodes, dst_node_ids=dstNodes, node_interact_times=timepoints)
          all_node_embeddings = torch.cat([batch_src_node_embeddings, batch_dst_node_embeddings], dim=0)
          if mode == 'sum':
            batch_graph_embeddings = torch.sum(all_node_embeddings, dim=0)
          elif mode == 'mean':
            batch_graph_embeddings = torch.mean(all_node_embeddings, dim=0)
          else:
            raise ValueError(f"unsupported aggregation mode: {mode!r}")
        batch_embeddings.append(batch_graph_embeddings.t())
      if not batch_embeddings:
        raise ValueError('a timepoint contains no edge batches')
      # (embedding_dim, num_batches) averaged over this timepoint's batches; no fixed batch count assumed
      timepoint_embeddings.append(torch.mean(torch.stack(batch_embeddings, dim=1), dim=1))
    if not timepoint_embeddings:
      raise ValueError('input4d yielded no timepoints')
    # (embedding_dim, num_timepoints); built here instead of filling a never-initialised buffer
    graphEmbs = torch.stack(timepoint_embeddings, dim=1)
    # unsqueeze(1) -> (embedding_dim, 1, num_timepoints); permute(1, 0, 2) -> (1, embedding_dim, num_timepoints)
    # NOTE(review): an earlier comment claimed (172, 1, 50); confirm which layout GraphRegressor expects.
    graphEmbs = graphEmbs.unsqueeze(1).permute(1, 0, 2)

    return graphEmbs

  def training_step():
    """
    Work-in-progress training step: loads one subject's triplets, embeds them with self.forward
    and reduces the result with a freshly built GraphRegressor.

    NOTE(review): as written this cannot run as a method:
      - the signature has no `self`, yet the body calls `self.forward`;
      - `k`, `i` and `embs` are not defined here and are presumably expected as notebook globals;
      - as far as this body goes, no loss is computed or returned -- TODO confirm whether more code follows.
    It also builds new, untrained regressors and a new criterion on every call.
    """
    gRegressor=GraphRegressor(in_channels=1, hidden_channels=2, out_channels=1)
    mRegressor=MLPRegressor(input_dim=172, hidden_dim=13, output_dim=1)
    criterion=torch.nn.MSELoss()
    # NOTE(review): np.Inf was removed in NumPy 2.0; np.inf is the portable spelling
    minLoss=np.Inf
    tolerance=0
    threshold=5

    # `k` (fold) and `i` (subject) are not defined in this method
    triplets=torch.load(f'./DG_data/Train/Fold{k}/triplets{i}.pt')
    assert triplets.shape==(num_timepoints, 120, 200, 3)
    # NOTE(review): triplets3 is not used below
    triplets3=triplets.reshape(-1,3)

    scoresDf=cudf.read_csv(f'./DG_data/Train/Fold{k}/Language_Task_Acc.csv')
    # keeps only the columns at positions 1 and 3
    scoresDf=scoresDf.iloc[:,[1,3]]
    scores=torch.as_tensor(scoresDf.values).to('cuda')

    dataloader=torch.utils.data.DataLoader(triplets, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
    # NOTE(review): this allocation is immediately overwritten by the next line
    graphEmbs=torch.zeros((node_dim, num_timepoints))
    graphEmbs=self.forward(dataloader)
    regEmb=gRegressor(graphEmbs).squeeze(1)
    # leftover debugging output
    print(regEmb.shape)
    # writes into the global `embs` buffer at column `i`
    embs[:,i]=regEmb
