# Learning on Dynamic Text-attributed Graphs

While we haven't explicitly covered dynamic graphs in this course, we believe they make for an exciting application, especially when the edges encode textual information about the entities represented by the nodes.

In short, dynamic graphs are graphs whose nodes and edges are created, updated, or removed over time. Time can be either continuous (e.g., measured in fractions of a second) or discrete (i.e., integer time steps), and "actions" (such as a new interaction between two entities) occur at points along this continuous or discrete timeline.
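To make the continuous/discrete distinction concrete, here is a minimal sketch that stores a dynamic graph as a stream of timestamped events and buckets the continuous timestamps into discrete steps. The toy events and the bucket width of 15 are illustrative assumptions; the data-loading code later applies a similar `// 15` coarsening to GDELT's timestamps.

```python
import numpy as np

# A toy continuous-time dynamic graph: each event is (src, dst, t),
# with t in hypothetical time units since some origin.
events = np.array([
    (1, 2, 0.5),
    (2, 3, 14.9),
    (1, 3, 15.0),
    (3, 4, 31.2),
], dtype=[("src", "i8"), ("dst", "i8"), ("t", "f8")])

# Discretize: integer-divide timestamps into fixed-width buckets,
# turning the continuous event stream into discrete time steps.
bucket_width = 15
steps = (events["t"] // bucket_width).astype(np.int64)

# Group the events by discrete step to obtain graph "snapshots".
snapshots = {
    int(s): [(int(e["src"]), int(e["dst"])) for e in events[steps == s]]
    for s in np.unique(steps)
}
print(steps)      # [0 0 1 2]
print(snapshots)  # {0: [(1, 2), (2, 3)], 1: [(1, 3)], 2: [(3, 4)]}
```

Note that the two events at t=14.9 and t=15.0 are nearly simultaneous but land in different steps, a typical side effect of discretization.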

In this notebook, we focus on a recent collection of such graphs called the **D**ynamic **T**ext-attributed **G**raph **B**enchmark ([DTGB](https://arxiv.org/abs/2406.12072)). 

> Specifically, we'll be using **GDELT**, a temporal knowledge graph dataset that tracks political behaviour worldwide. Nodes represent entities such as the *United States* or *Barack Obama*, and edges represent relationships between these entities that capture some type of political behaviour (e.g., `Joe Biden <--Is President--> United States`). The edges are ordered by the date and time at which they occur.

The dataset originates from the [GDELT Project](https://www.gdeltproject.org/), which aims to capture real-time updates to the worldwide political landscape.

### The task
We focus on **(binary) link prediction**, i.e., predicting whether an edge should exist between two nodes. The graph is a dynamic knowledge graph whose nodes are political entities and whose edges denote relationships between them. The node and edge attributes are given as text, which must be embedded (we use BERT). The model is GraphMixer, an MLP-Mixer-based architecture that operates on these node and edge text embeddings.

We'll first transform the textual attributes into embeddings, then build the data loader, and finally construct the model architecture.

## 0. Imports

Run this cell to make the necessary imports.

In [None]:
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

import logging
import warnings
warnings.filterwarnings("ignore")

loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
for logger in loggers:
    if "transformers" in logger.name.lower():
        logger.setLevel(logging.ERROR)

## 1. Downloading GDELT

The GDELT dataset is hosted on [Google Drive](https://drive.google.com/drive/folders/1QFxHIjusLOFma30gF59_hcB19Ix3QZtk). Please download `GDELT.zip` and unzip it so that `dataset_root` below points to the extracted folder.

In [None]:
dataset_root = './GDELT/'

## 2. Feature extraction using BERT

The nodes and edges are in text format. To encode these pieces of text, we'll be using [BERT](https://arxiv.org/abs/1810.04805), a popular text encoder model built with the Transformer architecture covered in class.

We'll be using HuggingFace's [`transformers`](https://github.com/huggingface/transformers) library, which provides wrappers around popular Transformer models, including BERT.

In [None]:
from transformers import AutoTokenizer, AutoConfig, AutoModel

Here, we use the `bert-base-uncased` model. "Base" is the smaller of the two original BERT sizes (BERT-Base and BERT-Large), and "uncased" means the text is lowercased before tokenization, so the model ignores capitalization. The model maps input text to an embedding of dimension $d=768$.

In [None]:
# Prepare the BERT model and tokenizer
config = AutoConfig.from_pretrained('bert-base-uncased')
hidden_size = config.hidden_size # 768
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
BERT_model = AutoModel.from_pretrained('bert-base-uncased').cuda()

# NOTE: to run on CPU, remove the .cuda() call here and replace .to("cuda") with .to("cpu") in the cells below

Here's a short demo on using the tokenizer and language model to encode text into embeddings. You'll be using a similar strategy later on to embed the node and edge attributes! 

> You can ignore the parameters passed into the functions – keep them the same when you're using them on the graph data.

In [None]:
import torch
texts = [
    "you've heard of cat on the mat – what about giraffe on the graph?", 
    "what a wonderful day to learn on graphs!", 
    "this is a very long, uninformative, run-on sentence with no punchline!"
]

# tokenize the text (i.e., convert text -> token IDs), padding/truncating to a common length
tokens = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt").to("cuda")

with torch.no_grad():  # inference only: we do not want gradients to flow
    # index [1] of the output is the pooled [CLS] representation (`pooler_output`), one vector per sentence
    out = BERT_model(**tokens)[1]
    print(out.shape)  # torch.Size([3, 768])

# note how the first dimension of the output is the number of sentences in the `texts` list

## 3. Encoding text attributes along nodes and edges

Now, let's use BERT in a similar way to encode the text attributes in the graph. Remember, the edges capture political relationships or behaviour between entities. 

> **NOTE**: We do _NOT_ want gradients to flow through BERT during the forward pass, so make sure to wrap it in `torch.no_grad()`.

The GDELT dataset contains a few CSV files. Each record is a tuple describing an edge, an entity (node), or a relation; put together, they form the entire (temporal) knowledge graph. Here's a succinct description of the columns in each file:
- **edge_list.csv**: 
    - `u`: ID of the source entity
    - `i`: ID of the target/recipient entity
    - `r`: ID of the relation between them
    - `ts`: the timestamp at which the edge occurs
    - `label`: the label of the edge
- **entity_text.csv**:
    - `i`: ID of the entity
    - `text`: text description of entity (eg: "obama", "egypt", "romanian")
- **relation_text.csv**:
    - `i`: ID of the relation (matches the `r` column of edge_list.csv)
    - `text`: text content of the relationship (eg: "makes a visit", "engage in negotiation")
    
The first column of each file is usually the record's _index_, which is only there for enumeration; you can ignore it, since we don't access it.
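The three files are linked by integer IDs: `u` and `i` in edge_list.csv index into entity_text.csv, and `r` indexes into relation_text.csv. As a minimal sketch (hypothetical toy rows that assume the column names listed above), pandas `merge` can render each edge back into readable text:

```python
import pandas as pd

# Hypothetical toy rows that follow the column layout described above.
edge_list = pd.DataFrame({"u": [1, 2], "i": [2, 3], "r": [1, 2], "ts": [0, 15]})
entity_text = pd.DataFrame({"i": [1, 2, 3], "text": ["obama", "egypt", "romanian"]})
relation_text = pd.DataFrame({"i": [1, 2], "text": ["makes a visit", "engage in negotiation"]})

# Attach the source-entity text (u -> entity i), the relation text (r -> relation i),
# and the destination-entity text (i -> entity i); how="left" keeps the edge order.
readable = (
    edge_list
    .merge(entity_text.rename(columns={"i": "u", "text": "src_text"}), on="u", how="left")
    .merge(relation_text.rename(columns={"i": "r", "text": "rel_text"}), on="r", how="left")
    .merge(entity_text.rename(columns={"text": "dst_text"}), on="i", how="left")
)
sentences = (readable["src_text"] + " | " + readable["rel_text"] + " | " + readable["dst_text"]).tolist()
print(sentences)  # ['obama | makes a visit | egypt', 'egypt | engage in negotiation | romanian']
```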

In [None]:
edge_list = pd.read_csv(os.path.join(dataset_root, 'edge_list.csv'))
num_node = max(edge_list['u'].max(), edge_list['i'].max())
num_rel = edge_list['r'].max()

### Question 1 (5 points)

In [None]:
# embed the node text attributes
entity_embeddings = [np.zeros([1, hidden_size])]
entity_text_reader = pd.read_csv(os.path.join(dataset_root, 'entity_text.csv'), chunksize=1000)

for batch in entity_text_reader:
    id_batch = batch['i'].tolist()
    text_batch = batch['text'].tolist()
    if 0 in id_batch:  # ignore the first row
        id_batch = id_batch[1:]
        text_batch = text_batch[1:] # you need to pass this `list` into the tokenizer + BERT pipeline
    
    ## Question 1: Embed the text attributes from the entity nodes
    ############# Your code here ############
    ## (~4-5 lines of code)
    
    # A. tokenize the text and push it to the CUDA device
    
    # B. put it through the BERT model (`BERT_model`) and extract the tensors from the output at the [1] index
    
    # C. put the output tensors on CPU
    
    # D. remember to add the output (in numpy format) of the current batch to the `entity_embeddings` list
        
    #########################################
            
# concatenate the outputs from each batch into one large array
entity_embeddings = np.concatenate(entity_embeddings, axis=0)
print([entity_embeddings.shape, num_node])
assert len(entity_embeddings) == num_node + 1
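Why does `entity_embeddings` start with a row of zeros? Entity IDs appear to start at 1 (hence the `num_node + 1` rows in the assert above), and the `NeighborSampler` defined later fills empty neighbor slots with ID 0. Reserving row 0 as a zero vector makes the embedding table directly indexable by ID, with padded slots mapping to zeros. A minimal NumPy sketch with toy values:

```python
import numpy as np

hidden_size = 4  # toy size; bert-base-uncased uses 768
# Hypothetical embeddings for entity IDs 1..3 (one row per ID).
real_rows = np.arange(1, 13, dtype=np.float32).reshape(3, hidden_size)

# Prepend an all-zero row so that row k holds the embedding of entity ID k,
# and ID 0 (used as padding) maps to a zero vector.
table = np.concatenate([np.zeros((1, hidden_size), dtype=np.float32), real_rows], axis=0)

# A padded neighbor list: missing neighbors are filled with ID 0.
neighbor_ids = np.array([[0, 0, 2, 3]])
features = table[neighbor_ids]  # fancy indexing -> shape (1, 4, hidden_size)
print(features.shape)           # (1, 4, 4)
print(features[0, :2].sum())    # 0.0, so padded slots contribute nothing
```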

### Question 2 (5 points)

In [None]:
rel_embeddings = []
rel_text_reader = pd.read_csv(os.path.join(dataset_root, 'relation_text.csv'), chunksize=1000)

# embed the edge text attributes
for batch in rel_text_reader:
    id_batch = batch['i'].tolist()
    text_batch = batch['text'].tolist()
    if 0 in id_batch:  # ignore the first row
        id_batch = id_batch[1:]
        text_batch = text_batch[1:] # you need to pass this `list` into the tokenizer + BERT pipeline
        
    ## Question 2: Embed the text attributes from the relationships
    ############# Your code here ############
    ## (~4-5 lines of code)
    
    # A. tokenize the text and push it to the CUDA device
    
    # B. put it through the BERT model (`BERT_model`) and extract the tensors from the output at the [1] index
    
    # C. put the output tensors on CPU
    
    # D. remember to add the output (in numpy format) of the current batch to the `rel_embeddings` list
        
    #########################################

rel_embeddings = np.concatenate(rel_embeddings, axis=0)
assert len(rel_embeddings) == num_rel
print(rel_embeddings.shape)

## 4. Creating a dataset

To convert this disparate encoded data into a machine-learnable format, we wrap it in a `Data` object for easier downstream access. Each `Data` instance stores a set of interactions (edges) as parallel arrays of source node IDs, destination node IDs, timestamps, edge IDs, and labels.

In [None]:
class Data:
    def __init__(
            self, 
            src_node_ids: np.ndarray, 
            dst_node_ids: np.ndarray, 
            node_interact_times: np.ndarray, 
            edge_ids: np.ndarray, 
            labels: np.ndarray
        ):
        """
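        Data object that stores a set of temporal interactions as parallel arrays (one entry per interaction).
        
        :param src_node_ids: ndarray, IDs of the source nodes
        :param dst_node_ids: ndarray, IDs of the destination nodes
        :param node_interact_times: ndarray, timestamps of the interactions
        :param edge_ids: ndarray, IDs of the edges
        :param labels: ndarray, labels of the interactions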
        """
        self.src_node_ids = src_node_ids
        self.dst_node_ids = dst_node_ids
        self.node_interact_times = node_interact_times
        self.edge_ids = edge_ids
        self.labels = labels
        self.num_interactions = len(src_node_ids)
        self.unique_node_ids = set(src_node_ids) | set(dst_node_ids)
        self.num_unique_nodes = len(self.unique_node_ids)

Next, we split the interactions chronologically into training, validation, and test sets, and package each split as a `Data` object for the edge classification setting.

Implement the code that creates the validation and test data, using the construction of the training data as a template.

### Question 3 (10 points)

In [None]:
import random
random.seed(2020)

def get_edge_classification_data(dataset_name: str, val_ratio: float, test_ratio: float):
    """
    generate chronological train/validation/test splits for the link prediction (edge classification) task
    
    :param dataset_name: str, dataset name (unused here; the data is always read from `dataset_root`)
    :param val_ratio: float, validation data ratio
    :param test_ratio: float, test data ratio
    :return: full_data, train_data, val_data, test_data (Data objects),
            cat_num (int, number of edge label classes)
    # Load data and train val test split
    graph_df = pd.read_csv(os.path.join(dataset_root, 'edge_list.csv'))
    node_num = max(graph_df['u'].max(), graph_df['i'].max()) + 1

    graph_df.ts = graph_df.ts // 15  # coarsen the timestamps into buckets of 15 time units
    cat_num = graph_df['label'].max() + 1
    rel_num = cat_num

    src_node_ids = graph_df.u.values.astype(np.longlong)
    dst_node_ids = graph_df.i.values.astype(np.longlong)
    node_interact_times = graph_df.ts.values.astype(np.float64)
    edge_ids = graph_df.label.values.astype(np.longlong) #graph_df.r.values.astype(np.longlong)
    labels = graph_df.label.values

    full_data = Data(
                    src_node_ids=src_node_ids, 
                    dst_node_ids=dst_node_ids, 
                    node_interact_times=node_interact_times, 
                    edge_ids=edge_ids, 
                    labels=labels
                )

    # get the timestamp of validate and test set
    val_time, test_time = list(np.quantile(graph_df.ts, [(1 - val_ratio - test_ratio), (1 - test_ratio)]))
    
    # union to get node set
    node_set = set(src_node_ids) | set(dst_node_ids)
    num_total_unique_node_ids = len(node_set)

    # compute the nodes that appear after the validation cut-off (i.e., in the validation or test period)
    test_node_set = set(src_node_ids[node_interact_times > val_time]).union(set(dst_node_ids[node_interact_times > val_time]))
    # sample nodes which we keep as new nodes (to test inductiveness), so then we have to remove all their edges from training
    new_test_node_sample_size = min(len(test_node_set), int(0.1 * num_total_unique_node_ids))
    new_test_node_set = set(random.sample(list(test_node_set), new_test_node_sample_size))

    # mask for each source and destination to denote whether they are new test nodes
    new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values
    new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values

    # mask, which is true for edges with both destination and source not being new test nodes (because we want to remove all edges involving any new test node)
    observed_edges_mask = np.logical_and(~new_test_source_mask, ~new_test_destination_mask)

    # for train data, we keep edges happening before the validation time which do not involve any new node, used for inductiveness
    train_mask = np.logical_and(node_interact_times <= val_time, observed_edges_mask)

    train_data = Data(
                    src_node_ids=src_node_ids[train_mask], 
                    dst_node_ids=dst_node_ids[train_mask],
                    node_interact_times=node_interact_times[train_mask],
                    edge_ids=edge_ids[train_mask], 
                    labels=labels[train_mask]
                )

    # sanity check: none of the held-out new test nodes may appear in the training data
    train_node_set = set(train_data.src_node_ids).union(train_data.dst_node_ids)
    assert len(train_node_set & new_test_node_set) == 0

    val_mask = np.logical_and(node_interact_times <= test_time, node_interact_times > val_time)
    test_mask = node_interact_times > test_time

    ## Question 3: Create the test and validation data in a similar way as the training set using the masks
    ############# Your code here ############
    ## (~2 lines of code)
    
    val_data = None
    
    test_data = None
        
    #########################################    

    # output some graph statistics
    print("The dataset has {} interactions, involving {} different nodes".format(
        full_data.num_interactions, full_data.num_unique_nodes))
    print("The training dataset has {} interactions, involving {} different nodes".format(
        train_data.num_interactions, train_data.num_unique_nodes))
    print("The validation dataset has {} interactions, involving {} different nodes".format(
        val_data.num_interactions, val_data.num_unique_nodes))
    print("The test dataset has {} interactions, involving {} different nodes".format(
        test_data.num_interactions, test_data.num_unique_nodes))

    return full_data, train_data, val_data, test_data, cat_num
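The split above is chronological rather than random: the cut-off times are quantiles of the timestamp distribution, so validation and test interactions always occur after the training ones. Here is a toy sketch of just that step, with hypothetical uniform timestamps and the same 0.15/0.15 ratios used in the call below (the real function additionally removes edges that touch the held-out new test nodes from training):

```python
import numpy as np

# Hypothetical, already time-sorted interaction timestamps.
ts = np.arange(20, dtype=np.float64)
val_ratio, test_ratio = 0.15, 0.15

# Cut points: the (1 - val - test) and (1 - test) quantiles of the timestamps.
val_time, test_time = np.quantile(ts, [1 - val_ratio - test_ratio, 1 - test_ratio])

train_mask = ts <= val_time
val_mask = (ts > val_time) & (ts <= test_time)
test_mask = ts > test_time

print(train_mask.sum(), val_mask.sum(), test_mask.sum())  # 14 3 3
```

With real GDELT data, many interactions share the same coarsened timestamp, so the actual split sizes can deviate from the nominal ratios.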

In [None]:
# create data splits
full_data, train_data, val_data, test_data, cat_num = get_edge_classification_data("GDELT", 0.15, 0.15)

Here, we create a thin `Dataset` wrapper around the indices of the interactions (edges) in each split. The data loader then yields batches of these indices, which makes access easier downstream.

> You don't need to understand this code, feel free to run it and move on!

In [None]:
class CustomizedDataset(torch.utils.data.Dataset):
    def __init__(self, indices_list: list):
        """
        Customized dataset.
        
        :param indices_list: list, list of indices
        """
        super(CustomizedDataset, self).__init__()

        self.indices_list = indices_list

    def __getitem__(self, idx: int):
        """
        get item at the index in self.indices_list
        
        :param idx: int, the index
        :return:
        """
        return self.indices_list[idx]

    def __len__(self):
        return len(self.indices_list)


def get_idx_data_loader(indices_list: list, batch_size: int, shuffle: bool):
    """
    get data loader that iterates over indices
    
    :param indices_list: list, list of indices
    :param batch_size: int, batch size
    :param shuffle: boolean, whether to shuffle the data
    :return: data_loader, DataLoader
    """
    dataset = CustomizedDataset(indices_list=indices_list)

    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             drop_last=False,
                             num_workers=2)
    return data_loader

In [None]:
# create dataloaders

train_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(train_data.src_node_ids))), batch_size=128, shuffle=False)
val_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(val_data.src_node_ids))), batch_size=128, shuffle=False)
test_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(test_data.src_node_ids))), batch_size=128, shuffle=False)
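Each batch from these loaders is a tensor of interaction indices, not the interactions themselves; the indices are meant to slice the arrays of the corresponding `Data` split. A minimal sketch of that pattern with hypothetical arrays (a plain Python list stands in for `CustomizedDataset`, since any object with `__getitem__` and `__len__` works as a map-style dataset):

```python
import numpy as np
import torch

# Hypothetical interaction arrays, standing in for the fields of a `Data` split.
src_node_ids = np.array([5, 7, 9, 11, 13, 15, 17, 19, 21, 23])
dst_node_ids = src_node_ids + 1

# The loader collates integer indices into LongTensors of size <= batch_size.
loader = torch.utils.data.DataLoader(list(range(len(src_node_ids))), batch_size=4, shuffle=False)

batches = []
for batch_idx in loader:
    idx = batch_idx.numpy()
    # Slice every per-interaction array with the same indices.
    batches.append((src_node_ids[idx], dst_node_ids[idx]))

print([len(b[0]) for b in batches])  # [4, 4, 2]
```

Because `drop_last=False`, the final batch is simply smaller than the others.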

## 5. Creating the NeighborSampler

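Since the graph is very large and dense, training a GNN on the full graph would be very compute-intensive. Instead, we sample small temporal neighborhoods: for each node at a given query time, we keep only its most recent interactions before that time, which makes training far more tractable.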

> You do not need to understand this part of the code, just run it before moving on!

In [None]:
class NeighborSampler:
    def __init__(self, adj_list: list, seed: int = None):
        """
        Neighbor sampler.
        
        :param adj_list: list of lists, where adj_list[k] holds the (neighbor_id, edge_id, timestamp) triples of node k
        :param seed: int, random seed
        """
        self.seed = seed

        # list of each node's neighbor ids, edge ids and interaction times, which are sorted by interaction times
        self.nodes_neighbor_ids = []
        self.nodes_edge_ids = []
        self.nodes_neighbor_times = []

        # the list at the first position in adj_list is empty, hence, sorted() will return an empty list for the first position
        # its corresponding value in self.nodes_neighbor_ids, self.nodes_edge_ids, self.nodes_neighbor_times will also be empty with length 0
        for node_idx, per_node_neighbors in enumerate(adj_list):
            # per_node_neighbors is a list of tuples (neighbor_id, edge_id, timestamp)
            # sort the list based on timestamps; sorted() is stable, so ties keep their original order
            # NOTE: in this notebook, edge IDs are the edge labels (see get_edge_classification_data), so they do not encode chronological order
            sorted_per_node_neighbors = sorted(per_node_neighbors, key=lambda x: x[2])
            self.nodes_neighbor_ids.append(np.array([x[0] for x in sorted_per_node_neighbors]))
            self.nodes_edge_ids.append(np.array([x[1] for x in sorted_per_node_neighbors]))
            self.nodes_neighbor_times.append(np.array([x[2] for x in sorted_per_node_neighbors]))

        if self.seed is not None:
            self.random_state = np.random.RandomState(self.seed)

    def find_neighbors_before(self, node_id: int, interact_time: float):
        """
        extracts all the interactions happening before interact_time (less than interact_time) for node_id in the overall interaction graph
        the returned interactions are sorted by time.
        
        :param node_id: int, node id
        :param interact_time: float, interaction time
        :return: neighbor ids, edge ids and timestamps, each an ndarray with shape (historical_nodes_num, )
        """
        # return index i, which satisfies list[i - 1] < v <= list[i]
        # return 0 for the first position in self.nodes_neighbor_times since the value at the first position is empty
        i = np.searchsorted(self.nodes_neighbor_times[node_id], interact_time)
        return self.nodes_neighbor_ids[node_id][:i], self.nodes_edge_ids[node_id][:i], self.nodes_neighbor_times[node_id][:i]

    def get_historical_neighbors(self, node_ids: np.ndarray, node_interact_times: np.ndarray, num_neighbors: int = 20):
        """
        get historical neighbors of nodes in node_ids with interactions before the corresponding time in node_interact_times
        
        :param node_ids: ndarray, shape (batch_size, ) or (*, ), node ids
        :param node_interact_times: ndarray, shape (batch_size, ) or (*, ), node interaction times
        :param num_neighbors: int, number of neighbors to sample for each node
        :return:
        """
        assert num_neighbors > 0, 'Number of sampled neighbors for each node should be greater than 0!'
        # All interactions described in the following three matrices are sorted in each row by time
        # each entry in position (i,j) represents the id of the j-th dst node of src node node_ids[i] with an interaction before node_interact_times[i]
        # ndarray, shape (batch_size, num_neighbors)
        nodes_neighbor_ids = np.zeros((len(node_ids), num_neighbors)).astype(np.longlong)
        # each entry in position (i,j) represents the id of the edge with src node node_ids[i] and dst node nodes_neighbor_ids[i][j] with an interaction before node_interact_times[i]
        # ndarray, shape (batch_size, num_neighbors)
        nodes_edge_ids = np.zeros((len(node_ids), num_neighbors)).astype(np.longlong)
        # each entry in position (i,j) represents the interaction time between src node node_ids[i] and dst node nodes_neighbor_ids[i][j], before node_interact_times[i]
        # ndarray, shape (batch_size, num_neighbors)
        nodes_neighbor_times = np.zeros((len(node_ids), num_neighbors)).astype(np.float32)

        # extracts all neighbors ids, edge ids and interaction times of nodes in node_ids, which happened before the corresponding time in node_interact_times
        for idx, (node_id, node_interact_time) in enumerate(zip(node_ids, node_interact_times)):
            # find neighbors that interacted with node_id before time node_interact_time
            node_neighbor_ids, node_edge_ids, node_neighbor_times = \
                self.find_neighbors_before(node_id=node_id, interact_time=node_interact_time)

            if len(node_neighbor_ids) > 0:
                # Take most recent interactions with number num_neighbors
                node_neighbor_ids = node_neighbor_ids[-num_neighbors:]
                node_edge_ids = node_edge_ids[-num_neighbors:]
                node_neighbor_times = node_neighbor_times[-num_neighbors:]
                # put the neighbors' information at the back positions
                nodes_neighbor_ids[idx, num_neighbors - len(node_neighbor_ids):] = node_neighbor_ids
                nodes_edge_ids[idx, num_neighbors - len(node_edge_ids):] = node_edge_ids
                nodes_neighbor_times[idx, num_neighbors - len(node_neighbor_times):] = node_neighbor_times

        # three ndarrays, with shape (batch_size, num_neighbors)
        return nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times

    def get_multi_hop_neighbors(self, num_hops: int, node_ids: np.ndarray, node_interact_times: np.ndarray, num_neighbors: int = 20):
        """
        get historical neighbors of nodes in node_ids within num_hops hops
        
        :param num_hops: int, number of sampled hops
        :param node_ids: ndarray, shape (batch_size, ), node ids
        :param node_interact_times: ndarray, shape (batch_size, ), node interaction times
        :param num_neighbors: int, number of neighbors to sample for each node
        :return:
        """
        assert num_hops > 0, 'Number of sampled hops should be greater than 0!'

        # get the temporal neighbors at the first hop
        # nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times -> ndarray, shape (batch_size, num_neighbors)
        nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times = self.get_historical_neighbors(node_ids=node_ids,
                                                                                                 node_interact_times=node_interact_times,
                                                                                                 num_neighbors=num_neighbors)
        # three lists to store the neighbor ids, edge ids and interaction timestamp information
        nodes_neighbor_ids_list = [nodes_neighbor_ids]
        nodes_edge_ids_list = [nodes_edge_ids]
        nodes_neighbor_times_list = [nodes_neighbor_times]
        for hop in range(1, num_hops):
            # get information of neighbors sampled at the current hop
            # three ndarrays, with shape (batch_size * num_neighbors ** hop, num_neighbors)
            nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times = self.get_historical_neighbors(node_ids=nodes_neighbor_ids_list[-1].flatten(),
                                                                                                     node_interact_times=nodes_neighbor_times_list[-1].flatten(),
                                                                                                     num_neighbors=num_neighbors)
            # three ndarrays with shape (batch_size, num_neighbors ** (hop + 1))
            nodes_neighbor_ids = nodes_neighbor_ids.reshape(len(node_ids), -1)
            nodes_edge_ids = nodes_edge_ids.reshape(len(node_ids), -1)
            nodes_neighbor_times = nodes_neighbor_times.reshape(len(node_ids), -1)

            nodes_neighbor_ids_list.append(nodes_neighbor_ids)
            nodes_edge_ids_list.append(nodes_edge_ids)
            nodes_neighbor_times_list.append(nodes_neighbor_times)

        # tuple, each element in the tuple is a list of num_hops ndarrays, each with shape (batch_size, num_neighbors ** current_hop)
        return nodes_neighbor_ids_list, nodes_edge_ids_list, nodes_neighbor_times_list

    def get_all_first_hop_neighbors(self, node_ids: np.ndarray, node_interact_times: np.ndarray):
        """
        get all historical neighbors of nodes in node_ids at the first hop (the number of neighbors is not capped)

        :param node_ids: ndarray, shape (batch_size, ), node ids
        :param node_interact_times: ndarray, shape (batch_size, ), node interaction times
        :return:
        """
        # three lists to store the first-hop neighbor ids, edge ids and interaction timestamp information, with batch_size as the list length
        nodes_neighbor_ids_list, nodes_edge_ids_list, nodes_neighbor_times_list = [], [], []
        # get the temporal neighbors at the first hop
        for idx, (node_id, node_interact_time) in enumerate(zip(node_ids, node_interact_times)):
            # find neighbors that interacted with node_id before time node_interact_time
            node_neighbor_ids, node_edge_ids, node_neighbor_times = self.find_neighbors_before(node_id=node_id,
                                                                                               interact_time=node_interact_time)
            nodes_neighbor_ids_list.append(node_neighbor_ids)
            nodes_edge_ids_list.append(node_edge_ids)
            nodes_neighbor_times_list.append(node_neighbor_times)

        return nodes_neighbor_ids_list, nodes_edge_ids_list, nodes_neighbor_times_list

    def reset_random_state(self):
        """
        reset the random state by self.seed
        
        :return:
        """
        self.random_state = np.random.RandomState(self.seed)
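In short, `find_neighbors_before` and `get_historical_neighbors` return, for each query node, its `num_neighbors` most recent interactions strictly before the query time, right-aligned and left-padded with zeros. The standalone sketch below reproduces that logic for a single node; the toy history and the helper name `most_recent_before` are hypothetical:

```python
import numpy as np

# Hypothetical neighbor history of one node, already sorted by time.
neighbor_ids = np.array([4, 8, 15, 16, 23])
neighbor_times = np.array([1.0, 3.0, 3.0, 7.0, 9.0])

def most_recent_before(query_time: float, k: int) -> np.ndarray:
    # searchsorted (side='left') returns the first position whose time is >= query_time,
    # so the slice [:i] keeps only interactions strictly before query_time.
    i = np.searchsorted(neighbor_times, query_time)
    recent = neighbor_ids[:i][-k:]          # keep the k most recent
    padded = np.zeros(k, dtype=np.int64)    # ID 0 marks an empty slot
    padded[k - len(recent):] = recent       # right-align, pad on the left
    return padded

print(most_recent_before(3.0, k=3))   # [0 0 4], the two events at t=3 are excluded
print(most_recent_before(10.0, k=3))  # [15 16 23]
```

The strict inequality matters: including interactions at exactly the query time would leak the very event being predicted into its own input.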

### Question 4 (10 points)

In [None]:
def get_neighbor_sampler(data: Data, seed: int = None):
    """
    get neighbor sampler for the large knowledge graph
    
    :param data: Data
    :param seed: int, random seed
    :return:
    """
    max_node_id = max(data.src_node_ids.max(), data.dst_node_ids.max())
    # the adjacency list stores the edges of each node (as source or destination), i.e., it is undirected
    # adj_list, list of list, where each element is a list of triple tuple (node_id, edge_id, timestamp)
    # the list at the first position in adj_list is empty

    adj_list = [[] for _ in range(max_node_id + 1)]
    
    # iterate through the graph's edges
    for src_node_id, dst_node_id, edge_id, node_interact_time in zip(data.src_node_ids, data.dst_node_ids, data.edge_ids, data.node_interact_times):
        
        ## Question 4: Create the UNDIRECTED adjacency list by appending each node's neighbors and 
        ##             its related information based on the description of adj_list's contents above.
        ############# Your code here ############
        ## (~2 lines of code)
        
        

        #########################################    
        

    return NeighborSampler(adj_list=adj_list, seed=seed)

In [None]:
full_neighbor_sampler = get_neighbor_sampler(data=full_data, seed=1)

In [None]:
# initialize training neighbor sampler to retrieve temporal graph
train_neighbor_sampler = get_neighbor_sampler(data=train_data, seed=0)

## 6. Model implementation

Now that we've finished encoding the text attributes in the graph dataset, we can finally start building the main GNN that learns on this dynamic, encoded knowledge graph.

Our model of choice for learning on the knowledge graph is [GraphMixer](https://arxiv.org/abs/2302.11636). It builds on MLP-Mixer, an architecture that replaces the expensive self-attention layers of Vision-Transformer-style models with simple MLPs that mix information across tokens and across feature channels. We'll also build a module to encode the timestamps at which events occur, as well as a binary edge classifier to make the final link prediction.

### 6A: Building the TimeEncoder and Edge Classifier (Questions 5-6: 20 points)

Let's start with the time encoder and the binary edge classifier. The time encoder maps a timestamp to an embedding vector. The edge classifier is a straightforward MLP that takes the embeddings of an edge's two endpoint nodes, whereas the time encoder needs a specific initialization: for an event timestamp $t$, GraphMixer uses a **fixed** (non-learned) weight vector $\boldsymbol{\omega}$ that encodes the timestamp at a different frequency along each dimension of the output vector. In `TimeEncoder` below, passing `parameter_requires_grad=False` keeps $\boldsymbol{\omega}$ fixed.

$$
\begin{aligned}
\boldsymbol{\omega} &= \{\alpha^{-(i-1)/\beta}\}_{i=1}^{d} \in (0, 1]^{d} \\
\mathbf{p} &= \cos(t\boldsymbol{\omega}) \in [-1, +1]^{d}
\end{aligned}
$$

Specifically,
$$
\mathbf{p} = \begin{bmatrix}
    \cos{(t \cdot \alpha^{-0/\beta})} \\
    \cos{(t \cdot \alpha^{-1/\beta})} \\
    \cos{(t \cdot \alpha^{-2/\beta})} \\
    \vdots \\
    \cos{(t \cdot \alpha^{-(d-1)/\beta})} \\
\end{bmatrix} \in \mathbb{R}^{d}
$$

Here, $\alpha$ and $\beta$ are hyperparameters chosen based on the maximum timestamp $t_{\text{max}}$ we're encoding (we've already set these for you), and we use $d=100$ (called `time_dim` below).
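To see how the fixed frequencies relate to the formula, the NumPy sketch below rebuilds $\boldsymbol{\omega}$ the same way `TimeEncoder.__init__` initializes `self.w.weight`, and checks that it equals $\alpha^{-(i-1)/\beta}$ with $\alpha = 10$ and $\beta = (d-1)/9$. These two values follow from the `np.linspace(0, 9, time_dim)` initialization; they are not stated elsewhere in this notebook. The sketch ignores the linear layer's bias, which is initialized to zero, and the timestamp is hypothetical.

```python
import numpy as np

d = 100    # time_dim
t = 42.0   # a hypothetical (already coarsened) timestamp

# Frequencies as initialized in TimeEncoder: 1 / 10 ** linspace(0, 9, d)
omega = 1.0 / 10 ** np.linspace(0, 9, d)

# The same vector written as alpha^{-(i-1)/beta} with alpha = 10, beta = (d - 1) / 9
alpha, beta = 10.0, (d - 1) / 9
i = np.arange(1, d + 1)
omega_closed_form = alpha ** (-(i - 1) / beta)

p = np.cos(t * omega)  # time encoding, one entry per frequency
print(p.shape)                                # (100,)
print(np.allclose(omega, omega_closed_form))  # True
```

Low-index dimensions oscillate quickly and separate nearby timestamps, while high-index dimensions vary slowly and stay close to 1 unless $t$ is very large, so different dimensions are sensitive to different time scales.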

In [None]:
import torch.nn as nn

class TimeEncoder(nn.Module):

    def __init__(self, time_dim: int, parameter_requires_grad: bool = True):
        """
        Time encoder.
        
        :param time_dim: int, dimension of time encodings
        :param parameter_requires_grad: boolean, whether the parameter in TimeEncoder needs gradient
        """
        super(TimeEncoder, self).__init__()

        self.time_dim = time_dim
        # time-encoding layer: weights initialized to the frequencies 1 / 10 ** linspace(0, 9, time_dim), bias to zero
        self.w = nn.Linear(1, time_dim)
        self.w.weight = nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, time_dim, dtype=np.float32))).reshape(time_dim, -1))
        self.w.bias = nn.Parameter(torch.zeros(time_dim))

        if not parameter_requires_grad:
            self.w.weight.requires_grad = False
            self.w.bias.requires_grad = False

    def forward(self, timestamps: torch.Tensor):
        """
        compute time encodings of time in timestamps
        
        :param timestamps: Tensor, shape (batch_size, seq_len)
        :return:
        """

        ## Question 5: Encode the timestamps using the linear layer with pre-initialized weights
        ############# Your code here ############
        ## (~2 lines of code)
        
        # reshape the `timestamps` tensor to be of shape (batch_size, seq_len, 1)
        
        # project this output using layer `w` and pass it through a cosine – call it `output`
        
        #########################################

        return output # Tensor of shape (batch_size, seq_len, time_dim)

class MLPClassifier_edge(nn.Module):
    def __init__(self, input_dim: int, dropout: float = 0.1, cat_num: int = 0):
        """
        Multi-Layer Perceptron Classifier.
        
        :param input_dim: int, dimension of input
        :param dropout: float, dropout rate
        """
        super().__init__()
        self.fc1 = nn.Linear(2*input_dim, input_dim, bias = True)
        self.fc2 = nn.Linear(input_dim, input_dim, bias = True)
        self.fc3 = nn.Linear(input_dim, cat_num, bias = True)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout) # feel free to add this anywhere in your network (except the final layer)

    def forward(self, x_1: torch.Tensor, x_2: torch.Tensor, rel_embs: torch.Tensor):
        """
        multi-layer perceptron classifier forward process
        
        :param x: Tensor, shape (*, input_dim)
        :return:
        """
        ## Question 6: Pass the concatenated node embeddings `x_1` and `x_2` through the MLP
        ############# Your code here ############
        ## (~3 lines of code)
        
        # NOTE: be sure to concatenate along the correct dimensions!
        
        #########################################
        
        return output # Tensor of shape (batch_size, cat_num)

### 6B. Building GraphMixer

GraphMixer has a few components described below:
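- **node encoder**: builds a representation from the node's own text embedding plus the averaged text embeddings of its most recent neighbors
- **edge/link encoder**: encodes the node's most recent interactions (edge text embeddings concatenated with time encodings) with MLP-Mixer layers that mix information across tokens (the sampled neighbors) and channels (the feature dimensions), then mean-pools the result
- **link classifier**: predicts whether an edge exists between two nodes at some time $t_0$. It takes the source and destination node embeddings, each of which fuses the link-encoder and node-encoder outputs, and outputs class scores. We've already implemented it as `MLPClassifier_edge` above.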
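
To make the data flow concrete, here is a shape-only sketch of how GraphMixer combines the two encoders for a batch of nodes. The sizes are hypothetical, random tensors stand in for real features, and the MLP-Mixer layers are omitted because they keep the `(batch_size, num_neighbors, num_channels)` shape unchanged.

```python
import torch
import torch.nn as nn

# Hypothetical small sizes, chosen only for illustration
batch_size, num_neighbors = 4, 30
edge_feat_dim, time_feat_dim, node_feat_dim = 16, 8, 16
num_channels = edge_feat_dim  # GraphMixer sets num_channels = edge_feat_dim

# Link encoder inputs: text embeddings of each node's recent edges + their time encodings
edge_feats = torch.randn(batch_size, num_neighbors, edge_feat_dim)
time_feats = torch.randn(batch_size, num_neighbors, time_feat_dim)
projection = nn.Linear(edge_feat_dim + time_feat_dim, num_channels)
tokens = projection(torch.cat([edge_feats, time_feats], dim=-1))  # (B, N, C)
# ... MLPMixer layers would go here; they preserve the (B, N, C) shape ...
link_repr = tokens.mean(dim=1)  # (B, C): mean-pool over the neighbors

# Node encoder output: own text embedding plus aggregated neighbor embeddings
node_repr = torch.randn(batch_size, node_feat_dim)

# The output layer fuses both views into one node embedding
output_layer = nn.Linear(num_channels + node_feat_dim, node_feat_dim)
node_embedding = output_layer(torch.cat([link_repr, node_repr], dim=1))
print(node_embedding.shape)  # torch.Size([4, 16])
```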

We'll build the model top-down: first the high-level `GraphMixer`, then the modules it uses internally.

In [None]:
class GraphMixer(nn.Module):

    def __init__(
            self, 
            node_raw_features: np.ndarray, 
            edge_raw_features: np.ndarray, 
            neighbor_sampler: NeighborSampler,
            time_feat_dim: int, 
            num_tokens: int, 
            num_layers: int = 2, 
            token_dim_expansion_factor: float = 0.5,
            channel_dim_expansion_factor: float = 4.0, 
            dropout: float = 0.1, device: str = 'cpu'
        ):
        """
        GraphMixer model.
        
        :param node_raw_features: ndarray, shape (num_nodes + 1, node_feat_dim)
        :param edge_raw_features: ndarray, shape (num_edges + 1, edge_feat_dim)
        :param neighbor_sampler: neighbor sampler
        :param time_feat_dim: int, dimension of time features (encodings)
        :param num_tokens: int, number of tokens
        :param num_layers: int, number of MLPMixer layers
        :param token_dim_expansion_factor: float, dimension expansion factor for tokens
        :param channel_dim_expansion_factor: float, dimension expansion factor for channels
        :param dropout: float, dropout rate
        :param device: str, device
        """
        super(GraphMixer, self).__init__()

        self.node_raw_features = torch.from_numpy(node_raw_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(edge_raw_features.astype(np.float32)).to(device)

        self.neighbor_sampler = neighbor_sampler
        self.node_feat_dim = self.node_raw_features.shape[1]
        self.edge_feat_dim = self.edge_raw_features.shape[1]
        self.time_feat_dim = time_feat_dim
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.token_dim_expansion_factor = token_dim_expansion_factor
        self.channel_dim_expansion_factor = channel_dim_expansion_factor
        self.dropout = dropout
        self.device = device

        self.num_channels = self.edge_feat_dim
        # in GraphMixer, the time encoding function is not trainable
        self.time_encoder = TimeEncoder(time_dim=time_feat_dim, parameter_requires_grad=False)
        self.projection_layer = nn.Linear(self.edge_feat_dim + time_feat_dim, self.num_channels)

        self.mlp_mixers = nn.ModuleList([
            MLPMixer(num_tokens=self.num_tokens, num_channels=self.num_channels,
                     token_dim_expansion_factor=self.token_dim_expansion_factor,
                     channel_dim_expansion_factor=self.channel_dim_expansion_factor, dropout=self.dropout)
            for _ in range(self.num_layers)
        ])

        self.output_layer = nn.Linear(in_features=self.num_channels + self.node_feat_dim, out_features=self.node_feat_dim, bias=True)

    def compute_src_dst_node_temporal_embeddings(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray,
                                                 node_interact_times: np.ndarray, num_neighbors: int = 20, time_gap: int = 2000):
        """
        compute source and destination node temporal embeddings
        
        :param src_node_ids: ndarray, shape (batch_size, )
        :param dst_node_ids: ndarray, shape (batch_size, )
        :param node_interact_times: ndarray, shape (batch_size, )
        :param num_neighbors: int, number of neighbors to sample for each node (must equal num_tokens, since MLPMixer's token LayerNorm has size num_tokens)
        :param time_gap: int, time gap for neighbors to compute node features
        :return:
        """
        # Tensor, shape (batch_size, node_feat_dim)
        src_node_embeddings = self.compute_node_temporal_embeddings(node_ids=src_node_ids, node_interact_times=node_interact_times,
                                                                    num_neighbors=num_neighbors, time_gap=time_gap)
        # Tensor, shape (batch_size, node_feat_dim)
        dst_node_embeddings = self.compute_node_temporal_embeddings(node_ids=dst_node_ids, node_interact_times=node_interact_times,
                                                                    num_neighbors=num_neighbors, time_gap=time_gap)

        return src_node_embeddings, dst_node_embeddings

    def compute_node_temporal_embeddings(self, node_ids: np.ndarray, node_interact_times: np.ndarray,
                                         num_neighbors: int = 20, time_gap: int = 2000):
        """
        given node ids node_ids, and the corresponding time node_interact_times, return the temporal embeddings of nodes in node_ids

        :param node_ids: ndarray, shape (batch_size, ), node ids
        :param node_interact_times: ndarray, shape (batch_size, ), node interaction times
        :param num_neighbors: int, number of neighbors to sample for each node (must equal num_tokens, since MLPMixer's token LayerNorm has size num_tokens)
        :param time_gap: int, time gap for neighbors to compute node features
        :return:
        """
        # link encoder
        # get temporal neighbors, including neighbor ids, edge ids and time information
        # neighbor_node_ids, ndarray, shape (batch_size, num_neighbors)
        # neighbor_edge_ids, ndarray, shape (batch_size, num_neighbors)
        # neighbor_times, ndarray, shape (batch_size, num_neighbors)
        neighbor_node_ids, neighbor_edge_ids, neighbor_times = \
            self.neighbor_sampler.get_historical_neighbors(node_ids=node_ids,
                                                           node_interact_times=node_interact_times,
                                                           num_neighbors=num_neighbors)

        # Tensor, shape (batch_size, num_neighbors, edge_feat_dim)
        nodes_edge_raw_features = self.edge_raw_features[torch.from_numpy(neighbor_edge_ids)]
        # Tensor, shape (batch_size, num_neighbors, time_feat_dim)
        nodes_neighbor_time_features = self.time_encoder(timestamps=torch.from_numpy(node_interact_times[:, np.newaxis] - neighbor_times).float().to(self.device))

        # Tensor, set the time features to all zeros for padded neighbors (node id 0)
        nodes_neighbor_time_features[torch.from_numpy(neighbor_node_ids == 0)] = 0.0

        # Tensor, shape (batch_size, num_neighbors, edge_feat_dim + time_feat_dim)
        combined_features = torch.cat([nodes_edge_raw_features, nodes_neighbor_time_features], dim=-1)
        # Tensor, shape (batch_size, num_neighbors, num_channels)
        combined_features = self.projection_layer(combined_features)

        for mlp_mixer in self.mlp_mixers:
            # Tensor, shape (batch_size, num_neighbors, num_channels)
            combined_features = mlp_mixer(input_tensor=combined_features)

        # Tensor, shape (batch_size, num_channels)
        combined_features = torch.mean(combined_features, dim=1)

        # node encoder
        # get temporal neighbors of nodes, including neighbor ids
        # time_gap_neighbor_node_ids, ndarray, shape (batch_size, time_gap)
        time_gap_neighbor_node_ids, _, _ = self.neighbor_sampler.get_historical_neighbors(node_ids=node_ids,
                                                                                          node_interact_times=node_interact_times,
                                                                                          num_neighbors=time_gap)

        # Tensor, shape (batch_size, time_gap, node_feat_dim)
        nodes_time_gap_neighbor_node_raw_features = self.node_raw_features[torch.from_numpy(time_gap_neighbor_node_ids)]

        # Tensor, shape (batch_size, time_gap)
        valid_time_gap_neighbor_node_ids_mask = torch.from_numpy((time_gap_neighbor_node_ids > 0).astype(np.float32))
        # note that if a node has no valid neighbor (whose valid_time_gap_neighbor_node_ids_mask are all zero), directly set the mask to -np.inf will make the
        # scores after softmax be nan. Therefore, we choose a very large negative number (-1e10) instead of -np.inf to tackle this case
        # Tensor, shape (batch_size, time_gap)
        valid_time_gap_neighbor_node_ids_mask[valid_time_gap_neighbor_node_ids_mask == 0] = -1e10
        # Tensor, shape (batch_size, time_gap)
        scores = torch.softmax(valid_time_gap_neighbor_node_ids_mask, dim=1).to(self.device)

        # Tensor, shape (batch_size, node_feat_dim), average over the time_gap neighbors
        nodes_time_gap_neighbor_node_agg_features = torch.mean(nodes_time_gap_neighbor_node_raw_features * scores.unsqueeze(dim=-1), dim=1)

        # Tensor, shape (batch_size, node_feat_dim), add features of nodes in node_ids
        output_node_features = nodes_time_gap_neighbor_node_agg_features + self.node_raw_features[torch.from_numpy(node_ids)]

        # Tensor, shape (batch_size, node_feat_dim)
        node_embeddings = self.output_layer(torch.cat([combined_features, output_node_features], dim=1))

        return node_embeddings

    def set_neighbor_sampler(self, neighbor_sampler: NeighborSampler):
        """
        set neighbor sampler to neighbor_sampler and reset the random state (for reproducing the results for uniform and time_interval_aware sampling)

        :param neighbor_sampler: NeighborSampler, neighbor sampler
        :return:
        """
        self.neighbor_sampler = neighbor_sampler
        assert self.neighbor_sampler.seed is not None
        self.neighbor_sampler.reset_random_state()
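
The node encoder's masking trick (the `valid_time_gap_neighbor_node_ids_mask` lines above) is easiest to see on a toy batch. This sketch assumes, as the neighbor sampler does, that id 0 marks padding, and additionally that row 0 of the feature table is all zeros; the ids and features are made up. The node with two real neighbors puts weight 0.5 on each, while the node with no history gets uniform weights instead of NaNs. Note that `torch.mean` then also divides by the number of slots (4 here), exactly as in `compute_node_temporal_embeddings`.

```python
import numpy as np
import torch

# Hypothetical batch: 2 nodes, up to 4 recent neighbors each; id 0 = padding
neighbor_ids = np.array([[5, 7, 0, 0],
                         [0, 0, 0, 0]], dtype=np.int64)  # second node has no history
node_feats = torch.arange(10 * 3, dtype=torch.float32).reshape(10, 3)
node_feats[0] = 0.0  # assumed all-zero padding row

mask = torch.from_numpy((neighbor_ids > 0).astype(np.float32))
# A large negative number (not -inf) keeps softmax finite when every entry is padding
mask[mask == 0] = -1e10
scores = torch.softmax(mask, dim=1)  # (2, 4)

neigh = node_feats[torch.from_numpy(neighbor_ids)]         # (2, 4, 3)
agg = torch.mean(neigh * scores.unsqueeze(dim=-1), dim=1)  # (2, 3)
```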

### 6C. Implementing GraphMixer's link encoder MLP (Question 7: 10 points)

Follow the schematic below to implement the `ffn` module, the feed-forward network. For an input $\mathbf{x}$ of shape `(*, input_dim)`:

$$
\mathbf{z} = \operatorname{GELU}(\operatorname{Linear}(\mathbf{x}))
$$
$$
\mathbf{a} = \operatorname{Dropout}(\mathbf{z})
$$
$$
\mathbf{y} = \operatorname{Dropout}(\operatorname{Linear}(\mathbf{a}))
$$

**NOTE:** Make sure the dimensions of the two `Linear` layers match: the first maps `input_dim` to a hidden size of `dim_expansion_factor * input_dim` (cast to an integer), and the second maps back to `input_dim`. The "expansion" factor can be below 1; GraphMixer's token FFN uses 0.5, so it actually contracts. You can either create separate layers or combine them with `nn.Sequential(...)`.
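
One way to wire these equations, as a sketch: `sketch_ffn` is a hypothetical helper, separate from `FeedForwardNet`, and the hidden width `int(dim_expansion_factor * input_dim)` is an assumed rounding choice. With 30 tokens and factor 0.5, the hidden layer has 15 units.

```python
import torch
import torch.nn as nn

def sketch_ffn(input_dim: int, dim_expansion_factor: float, dropout: float) -> nn.Sequential:
    hidden_dim = int(dim_expansion_factor * input_dim)
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim),  # z = GELU(Linear(x)): expand (or contract)
        nn.GELU(),
        nn.Dropout(dropout),               # a = Dropout(z)
        nn.Linear(hidden_dim, input_dim),  # back to input_dim
        nn.Dropout(dropout),               # y = Dropout(Linear(a))
    )

token_ffn = sketch_ffn(input_dim=30, dim_expansion_factor=0.5, dropout=0.1)
x = torch.randn(4, 768, 30)  # (batch_size, num_channels, num_tokens) after a permute
print(token_ffn(x).shape)    # torch.Size([4, 768, 30])
```

`nn.Linear` acts on the last dimension only, which is why the same module works for any leading `*` dimensions.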

In [None]:
class FeedForwardNet(nn.Module):

    def __init__(self, input_dim: int, dim_expansion_factor: float, dropout: float = 0.0):
        """
        two-layered MLP with GELU activation function.
        :param input_dim: int, dimension of input
        :param dim_expansion_factor: float, dimension expansion factor
        :param dropout: float, dropout rate
        """
        super(FeedForwardNet, self).__init__()

        self.input_dim = input_dim
        self.dim_expansion_factor = dim_expansion_factor
        self.dropout = dropout
        
        ## Question 7: Implement the feed-forward network used in MLPMixer
        ############# Your code here ############
        ## (~4-6 lines of code)
        
        self.ffn = None
        
        #########################################

    def forward(self, x: torch.Tensor):
        """
        feed forward net forward process
        :param x: Tensor, shape (*, input_dim)
        :return:
        """
        return self.ffn(x)

### 6D. Implementing the MLPMixer (Question 8: 10 points)

The core of GraphMixer is the MLPMixer layer, which you'll implement below. Follow the equations below carefully when writing the `forward(...)` method of `MLPMixer`. The layer takes an input tensor $\mathbf{H}_{\text{input}}$ of shape `(batch_size, num_tokens, num_channels)` and returns a tensor of the same shape after "mixing" information first across tokens (the sampled neighbors) and then across channels (the feature dimensions), hence the name!

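$$
\mathbf{H}_{\text{token}} = \Big(\operatorname{FFN}_{\text{token}}\big(\operatorname{LayerNorm}_{\text{token}}(\mathbf{H}_{\text{input}}^{\top})\big)\Big)^{\top}
$$
$$
\mathbf{H}_{\text{token}}^{\prime} = \mathbf{H}_{\text{token}} + \mathbf{H}_{\text{input}}
$$
$$
\mathbf{H}_{\text{channel}} = \operatorname{FFN}_{\text{channel}}\big(\operatorname{LayerNorm}_{\text{channel}}(\mathbf{H}_{\text{token}}^{\prime})\big)
$$
$$
\mathbf{H}_{\text{output}} = \mathbf{H}_{\text{channel}} + \mathbf{H}_{\text{token}}^{\prime}
$$

Here $(\cdot)^{\top}$ swaps the token and channel axes (the last two dimensions), so the token LayerNorm and FFN operate along the `num_tokens` axis. In `MLPMixer.__init__`, `token_norm`/`token_feedforward` implement the token step and `channel_norm`/`channel_feedforward` implement the channel step.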


Remember that in PyTorch, `.permute(...)` reorders the dimensions of a tensor: you pass the new order of *all* dimensions. For a matrix `a`, `a.permute(1, 0)` swaps the first and second dimensions (a transpose); for a 3D tensor `x` of shape `(batch_size, num_tokens, num_channels)`, `x.permute(0, 2, 1)` returns shape `(batch_size, num_channels, num_tokens)`. Alternatively, `x.transpose(1, 2)` swaps just the two named dimensions.
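
A quick check of how `.permute` and `.transpose` reorder a `(batch_size, num_tokens, num_channels)` tensor; the sizes are arbitrary.

```python
import torch

x = torch.randn(2, 30, 768)  # (batch_size, num_tokens, num_channels)
y = x.permute(0, 2, 1)       # new order of ALL dims -> (batch_size, num_channels, num_tokens)
z = x.transpose(1, 2)        # names only the two dims to swap; same result
print(y.shape)               # torch.Size([2, 768, 30])

# Applying the same permutation again undoes the swap
assert torch.equal(y.permute(0, 2, 1), x)
assert torch.equal(y, z)
```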

Make sure to use the `FeedForwardNet` module you've implemented above. 
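
For reference, here is a standalone sketch of one token-then-channel mixing block. It assumes each FFN is applied after the corresponding LayerNorm (`token_norm`, `channel_norm`) that `MLPMixer.__init__` below creates; `SketchMixerBlock` and `ffn` are hypothetical names, and the FFN uses the same two-layer GELU layout as the schematic in 6C.

```python
import torch
import torch.nn as nn

def ffn(dim: int, factor: float, dropout: float) -> nn.Sequential:
    hidden = int(factor * dim)
    return nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Dropout(dropout),
                         nn.Linear(hidden, dim), nn.Dropout(dropout))

class SketchMixerBlock(nn.Module):
    """Hypothetical token-then-channel mixing block (illustration only)."""

    def __init__(self, num_tokens: int, num_channels: int,
                 token_factor: float = 0.5, channel_factor: float = 4.0, dropout: float = 0.0):
        super().__init__()
        self.token_norm = nn.LayerNorm(num_tokens)
        self.token_ffn = ffn(num_tokens, token_factor, dropout)
        self.channel_norm = nn.LayerNorm(num_channels)
        self.channel_ffn = ffn(num_channels, channel_factor, dropout)

    def forward(self, h_input: torch.Tensor) -> torch.Tensor:
        # Token mixing: put tokens last so the FFN mixes across neighbors
        h = h_input.permute(0, 2, 1)                                   # (B, C, T)
        h_token = self.token_ffn(self.token_norm(h)).permute(0, 2, 1)  # back to (B, T, C)
        h_token = h_token + h_input                                    # residual
        # Channel mixing: channels are already the last dimension
        h_channel = self.channel_ffn(self.channel_norm(h_token))
        return h_channel + h_token                                     # residual

block = SketchMixerBlock(num_tokens=30, num_channels=64)
out = block(torch.randn(4, 30, 64))
print(out.shape)  # torch.Size([4, 30, 64])
```

The two residual connections mean that if both FFNs output zeros, the block reduces to the identity, which is a handy sanity check when debugging your own implementation.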

In [None]:
class MLPMixer(nn.Module):

    def __init__(
            self, 
            num_tokens: int, 
            num_channels: int, 
            token_dim_expansion_factor: float = 0.5,
            channel_dim_expansion_factor: float = 4.0, 
            dropout: float = 0.0
        ):
        """
        MLP Mixer.
        
        :param num_tokens: int, number of tokens
        :param num_channels: int, number of channels
        :param token_dim_expansion_factor: float, dimension expansion factor for tokens
        :param channel_dim_expansion_factor: float, dimension expansion factor for channels
        :param dropout: float, dropout rate
        """
        super(MLPMixer, self).__init__()

        self.token_norm = nn.LayerNorm(num_tokens)
        self.token_feedforward = FeedForwardNet(input_dim=num_tokens, dim_expansion_factor=token_dim_expansion_factor,
                                                dropout=dropout)

        self.channel_norm = nn.LayerNorm(num_channels)
        self.channel_feedforward = FeedForwardNet(input_dim=num_channels, dim_expansion_factor=channel_dim_expansion_factor,
                                                  dropout=dropout)

    def forward(self, input_tensor: torch.Tensor):
        """
        mlp mixer to compute over tokens and channels
        
        :param input_tensor: Tensor, shape (batch_size, num_tokens, num_channels)
        :return:
        """
        # swap tokens and channels: shape becomes (batch_size, num_channels, num_tokens)
        input_tensor = input_tensor.permute(0, 2, 1)
        
        ## Question 8: Implement the MLPMixer forward pass that mixes tokens and channels
        ############# Your code here ############
        ## (~5-6 lines of code)
        
        #########################################

        return output_tensor # Tensor of shape (batch_size, num_tokens, num_channels)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dynamic_backbone = GraphMixer(
                        node_raw_features=entity_embeddings, 
                        edge_raw_features=rel_embeddings, 
                        neighbor_sampler=train_neighbor_sampler,
                        time_feat_dim=100, 
                        num_tokens=30, 
                        num_layers=2, 
                        dropout=0.2, 
                        device=device
                    )

In [None]:
edge_classifier = MLPClassifier_edge(input_dim=entity_embeddings.shape[1], dropout=0.1, cat_num=cat_num)

model = nn.Sequential(dynamic_backbone, edge_classifier).to(device)

## 7. Training

With the model built, we'll now set up the optimizer and loss function and train GraphMixer on our preprocessed temporal knowledge graph.

In [None]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-4)
loss_func = nn.CrossEntropyLoss()

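Our `model` is an `nn.Sequential` container with two parts: `model[0]` is the GraphMixer backbone and `model[1]` is the edge classifier. Because `GraphMixer` defines no `forward` method and the two parts take different inputs, we never call `model(...)` directly; instead, we access each part separately from here on.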
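
A minimal sketch of this container pattern with hypothetical stand-in modules (`Backbone`, `Head`): `nn.Sequential` registers both children, so `parameters()` and `.to(device)` cover everything, while the actual computation calls each child explicitly.

```python
import torch
import torch.nn as nn

class Backbone(nn.Module):  # hypothetical stand-in exposing a custom method
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)

    def embed(self, a: torch.Tensor, b: torch.Tensor):
        return self.lin(a), self.lin(b)

class Head(nn.Module):  # hypothetical stand-in taking two inputs
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 3)

    def forward(self, x_1: torch.Tensor, x_2: torch.Tensor) -> torch.Tensor:
        return self.fc(torch.cat([x_1, x_2], dim=-1))

demo_model = nn.Sequential(Backbone(), Head())
# Both children are registered, so one optimizer sees all parameters...
n_params = sum(p.numel() for p in demo_model.parameters())
# ...but we call each part explicitly instead of demo_model(x)
src, dst = demo_model[0].embed(torch.randn(5, 8), torch.randn(5, 8))
logits = demo_model[1](src, dst)
print(logits.shape)  # torch.Size([5, 3])
```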

In [None]:
model[0].set_neighbor_sampler(train_neighbor_sampler)

### Questions 9-10 (10 points)

In [None]:
for epoch in range(5):
    # make sure dropout is active (evaluate() switches the model to eval mode)
    model.train()
    
    # store train losses and metrics
    train_total_loss, train_y_trues, train_y_predicts = 0.0, [], []
    train_idx_data_loader_tqdm = tqdm(train_idx_data_loader, ncols=120, desc=f"epoch:{epoch}")
    
    for batch_idx, train_data_indices in enumerate(train_idx_data_loader_tqdm):
        train_data_indices = train_data_indices.numpy()
        batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids, batch_labels = \
            train_data.src_node_ids[train_data_indices], train_data.dst_node_ids[train_data_indices], \
            train_data.node_interact_times[train_data_indices], train_data.edge_ids[train_data_indices], train_data.labels[train_data_indices]
        
        """
        We need to compute for positive and negative edges respectively, because the new sampling strategy 
        (for evaluation) allows the negative source nodes to be different from the source nodes. 
        
        This is different from previous works that just replace destination nodes with negative destination nodes
        get temporal embedding of source and destination nodes.
        
        You'll then have two Tensors both with shape (batch_size, node_feat_dim)
        """
        
        ## Question 9: Get the temporal embeddings for src and dest nodes using the backbone model's 
        ##            `compute_src_dst_node_temporal_embeddings` method. Use a time gap of 2000 and number of neighbors 30
        ############# Your code here ############
        ## (~1-2 lines of code)
        
        batch_src_node_embeddings, batch_dst_node_embeddings = None
        
        #########################################
        
        
        ## Question 10: Get the predicted class scores (logits) from the Edge Classifier MLP
        ############# Your code here ############
        ## (~1-2 lines of code)
        ## Use `model[1]` and pass in the src and dest temporal node embeddings, as well as the raw edge features from the backbone

        # get predicted logits, shape (batch_size, cat_num)
        predicts = None
        
        #########################################
        
        pred_labels = torch.max(predicts, dim=1)[1]
        labels = torch.from_numpy(batch_labels).long().to(predicts.device)
        loss = loss_func(predicts, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_total_loss += loss.item()

    print("avg loss", train_total_loss / len(train_idx_data_loader))
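
The loop above treats `predicts` as unnormalized class scores (logits): `nn.CrossEntropyLoss` applies log-softmax itself, and `torch.max(..., dim=1)[1]` picks the highest-scoring class. A small check on made-up numbers:

```python
import torch
import torch.nn as nn

# Hypothetical classifier outputs for 3 edges and 4 classes
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0],
                       [0.1, 0.2, 3.0, 0.3],
                       [1.0, 1.0, 0.0, 4.0]])
labels = torch.tensor([0, 2, 1])  # integer class ids (int64), as CrossEntropyLoss expects

demo_loss = nn.CrossEntropyLoss()(logits, labels)  # log-softmax is applied internally
# Same value by hand: mean negative log-probability of the true class
manual = -torch.log_softmax(logits, dim=1)[torch.arange(3), labels].mean()
pred_labels = torch.max(logits, dim=1)[1]  # index of the largest logit in each row
print(pred_labels)  # tensor([0, 2, 3])
```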

## 8. Evaluation

We'll first write a helper function that computes precision, recall, and F1-score with `scikit-learn`'s built-in metrics. If you're unsure how to invoke them, see the documentation [here](https://scikit-learn.org/0.15/modules/generated/sklearn.metrics.precision_score.html); the `average` argument selects macro, micro, or weighted averaging.
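
The `average` argument is what distinguishes the three variants: `macro` averages per-class scores equally, `micro` pools all predictions before computing the score, and `weighted` averages per-class scores weighted by each class's support (number of true instances). A toy example with made-up labels:

```python
from sklearn.metrics import precision_score, recall_score, f1_score

# Toy 3-class example (made-up labels)
y_true = [0, 0, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2]

metrics = {}
for avg in ("macro", "micro", "weighted"):
    metrics[f"P_{avg}"] = precision_score(y_true, y_pred, average=avg)
    metrics[f"R_{avg}"] = recall_score(y_true, y_pred, average=avg)
    metrics[f"F_{avg}"] = f1_score(y_true, y_pred, average=avg)

print(round(metrics["P_macro"], 4))  # 0.8889 (mean of per-class precision 1, 2/3, 1)
print(metrics["P_micro"])            # 0.8 (single-label multiclass: equals accuracy)
```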

### Question 11 (10 points)

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import average_precision_score, roc_auc_score

def calculate_metrics(predicts: torch.Tensor, labels: torch.Tensor):
    """
    get metrics for the edge classification task
    
    :param predicts: Tensor, shape (num_samples, )
    :param labels: Tensor, shape (num_samples, )
    :return:
        dictionary of metrics {'metric_name_1': metric_1, ...}
    """
    
    predicts = predicts.cpu().detach().numpy()
    labels = labels.cpu().numpy()
    
    ## Question 11: Implement macro, micro, and weighted metrics for precision, recall, and F1-score
    ############# Your code here ############
    ## (9 lines of code)
    
    P_macro = None
    R_macro = None
    F_macro = None

    P_micro = None
    R_micro = None
    F_micro = None

    P_weight = None
    R_weight = None
    F_weight = None

    #########################################

    return {
            'P_macro': P_macro,
            'R_macro': R_macro,
            'F_macro': F_macro,
            'P_micro': P_micro,
            'R_micro': R_micro,
            'F_micro': F_micro,
            'P_weighted': P_weight,
            'R_weighted': R_weight,
            'F_weighted': F_weight
        }

### 8A. Evaluating on validation data (Questions 12-13: 10 points)

In [None]:
model[0].set_neighbor_sampler(full_neighbor_sampler)

def evaluate(model, evaluate_data, idx_dataloader):
    
    model.eval()
    
    with torch.no_grad():
        # store evaluate losses, trues and predicts
        evaluate_total_loss, evaluate_y_trues, evaluate_y_predicts = 0.0, [], []
        evaluate_idx_data_loader_tqdm = tqdm(idx_dataloader, ncols=120)
        
        for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
            evaluate_data_indices = evaluate_data_indices.numpy()
            batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids, batch_labels = \
                evaluate_data.src_node_ids[evaluate_data_indices], evaluate_data.dst_node_ids[evaluate_data_indices], \
                evaluate_data.node_interact_times[evaluate_data_indices], evaluate_data.edge_ids[evaluate_data_indices], \
                evaluate_data.labels[evaluate_data_indices]
            
            # get temporal embedding of source and destination nodes
            # two Tensors, with shape (batch_size, node_feat_dim)
            
            ## Question 12: Get the temporal embeddings for src and dest nodes using the backbone model's 
            ##            `compute_src_dst_node_temporal_embeddings` method. Use a time gap of 2000 and number of neighbors 30
            ############# Your code here ############
            ## (~1-2 lines of code)
            ## You already did this for Question 9

            batch_src_node_embeddings, batch_dst_node_embeddings = None

            #########################################
            
            
            ## Question 13: Get the predicted class scores (logits) from the Edge Classifier MLP
            ############# Your code here ############
            ## (~1-2 lines of code)
            ## Use `model[1]` and pass in the src and dest temporal node embeddings, as well as the raw edge features from the backbone
            ## You already did this for Question 10

            # get predicted logits, shape (batch_size, cat_num)
            predicts = None

            #########################################
            
            pred_labels = torch.max(predicts, dim=1)[1]
            labels = torch.from_numpy(batch_labels).long().to(predicts.device)

            loss = loss_func(input=predicts, target=labels)
            evaluate_total_loss += loss.item()
            
            evaluate_y_trues.append(labels)
            evaluate_y_predicts.append(pred_labels)
            evaluate_idx_data_loader_tqdm.set_description(f'evaluate for the {batch_idx + 1}-th batch, evaluate loss: {loss.item()}')
        
        evaluate_total_loss /= (batch_idx + 1)
        evaluate_y_trues = torch.cat(evaluate_y_trues, dim=0)
        evaluate_y_predicts = torch.cat(evaluate_y_predicts, dim=0)
        
        return calculate_metrics(predicts=evaluate_y_predicts, labels=evaluate_y_trues)

evaluate(model, val_data, val_idx_data_loader)
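
`evaluate` switches the model to eval mode with `model.eval()`, which turns dropout off. The model stays in that mode until `model.train()` is called, so any training that follows should switch back first. A small demonstration with a lone `nn.Dropout` layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.eval()         # what model.eval() does to every dropout layer
eval_out = drop(x)  # dropout disabled: output equals input

drop.train()        # what model.train() restores
train_out = drop(x) # roughly half the entries zeroed, the rest scaled by 1 / (1 - p) = 2
```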

### 8B. Evaluating on test data

In [None]:
evaluate(model, test_data, test_idx_data_loader)

We've trained this model for only 5 epochs. If you have a larger compute budget, try training longer to see whether validation and test performance improve further. Other levers include better neighbor samplers. More generally, large-scale graph learning tasks like this one call for a range of optimizations to keep compute costs manageable.

> This is the FINAL notebook-based assignment for this term. We hope you've enjoyed the coding exercises we arranged for you.
> 
> If you like what you learned, consider joining Rex Ying's [Graph and Geometric Learning Lab](https://graph-and-geometric-learning.github.io/)!! We'd be thrilled to have you :)