# Lab on Graph Neural Networks

Teacher: Prof. Gianluca Moro

Teaching Assistant: Dr. Giacomo Frisoni

Acknowledgments: Dr. Lorenzo Molfetta for baselines and resource-intensive server-side training.

**Contact.** For any doubts, questions, or issues, you can always contact us at the following email addresses: {gianluca.moro, giacomo.frisoni}@unibo.it.

**Keywords.** Graph Neural Networks, Graph Representation Learning, Relational Deep Learning.

## 📜 Outline

Relational databases are the cornerstone of modern data management, forming the backbone of the digital economy. Their widespread adoption stems from their intuitive table-based structure, which simplifies data organization and maintenance, coupled with powerful query capabilities provided by languages like SQL. Given their ubiquity, relational databases serve as the foundation for AI systems across diverse domains such as e-commerce, social media, banking, healthcare, manufacturing, and open-source scientific research repositories. These databases naturally align with graph representations, where relationships between items in different tables are pivotal for advanced tasks like building recommender systems.

We will delve into the cutting-edge field of **Relational Deep Learning**.
Using a **real-world dataset from H&M**, a globally renowned fashion retailer, we will model their e-commerce relational database as a temporal heterogeneous graph. Leveraging this structured representation, we will train a state-of-the-art **Graph Neural Network to power product recommendations**.
Through this hands-on exercise, we will witness how explicitly harnessing the relationships between customers, products, and transactions can achieve remarkable performance gains.
Our approach will outperform even strong baselines such as **Large Language Models** and **gradient-boosted tree models** like XGBoost and LightGBM.

# Relational Deep Learning

## ⚙️ Install Dependencies ($<$ 1 min)

In [None]:
%%capture

import os
import torch

os.environ['TORCH'] = torch.__version__

!pip install torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install relbench[full]

RELBENCH is a public benchmark for solving predictive tasks over relational databases with GNNs.

First submitted to arXiv on 29 July 2024 [[paper link](https://arxiv.org/abs/2407.20060)], it was published by **Stanford University, Kumo.AI and the Max Planck Institute for Informatics** at **NeurIPS 2024**, Track on Datasets and Benchmarks.

<img src="https://relbench.stanford.edu/img/logo.png" alt="Stanford RelBench Logo" width="400">

---

RELBENCH enables training and evaluation of deep learning models on relational databases. RELBENCH supports framework agnostic data loading, task specification, standardized data splitting, standardized evaluation metrics, and a leaderboard for tracking progress.

<img src="https://camo.githubusercontent.com/a858a15c33d8aebde0cdae555260be83db968a4112816a1149a971fe367503f2/68747470733a2f2f72656c62656e63682e7374616e666f72642e6564752f696d672f72656c62656e63682d6669672e706e67" alt="Stanford RelBench Pipeline" width="750">

In [None]:
import relbench
relbench.__version__

'1.1.0'

## 📂 Dataset

RELBENCH contains **7 real-world datasets**, each with a *relational database* and a *set of realistic predictive tasks*.

*   A **relational database** consists of a set of tables connected via primary-foreign key relationships. Each table has columns storing diverse information about each entity. Some tables also come with time columns, indicating the time at which the entity is created (e.g., transaction date).
*   A **predictive task** is defined by a training table with columns for Entity ID, seed time, and target labels. The seed time indicates *at which time* the target is to be predicted, filtering out future data. Zooming out, tasks are grouped into **three task types**:
  *   **Entity classification.** Predict binary labels of a given entity at a given seed time. Note that here only information from the single entity table is used.
  *   **Entity regression.** Predict numerical labels of an entity at a given seed time.
  *   **Link prediction.** Recommendation on pairs of entities. Predict a list of top-$K$ target entities given a source entity at a seed time. For this task, it is important to ensure a certain density of links in the training data in order for there to be sufficient predictive signal.



### Select a Dataset

Check the databases currently available in RelBench:

In [None]:
from relbench.datasets import get_dataset_names

get_dataset_names()

['rel-amazon',
 'rel-avito',
 'rel-event',
 'rel-f1',
 'rel-hm',
 'rel-stack',
 'rel-trial']

**rel-hm.** The H&M relational database hosts extensive customer and product data for online shopping experiences across its extensive network of brands and stores. This database includes detailed customer purchase histories and a rich set of metadata, encompassing everything from basic demographic information to extensive details about each product available.

<br>
<img src="https://drive.google.com/uc?export=view&id=1i9nxFYu7v26SeZPBlimX9ZVAgigXuZFC" width="200">

In [None]:
from relbench.datasets import get_dataset

# Download the dataset from Stanford website and unzip it
dataset = get_dataset(name="rel-hm", download=True)

Downloading file 'rel-hm/db.zip' from 'https://relbench.stanford.edu/download/rel-hm/db.zip' to '/root/.cache/relbench'.
100%|████████████████████████████████████████| 143M/143M [00:00<00:00, 124GB/s]
Unzipping contents of '/root/.cache/relbench/rel-hm/db.zip' to '/root/.cache/relbench/rel-hm/.'


In [None]:
dataset.url

'https://www.kaggle.com/competitions/h-and-m-personalized-fashion-recommendations'

Now that we have loaded the database, let's start poking around to see what's inside.

### Check Temporal Splitting

Data is split temporally: models are trained on rows up to `val_timestamp`, validated on rows between `val_timestamp` and `test_timestamp`, and tested on rows after `test_timestamp`.

We must **avoid temporal leakage of information** during training and validation through temporal neighbor sampling.
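
To make the idea of temporal neighbor sampling concrete, here is a minimal sketch in plain Python. It is not the sampler used by RelBench or PyTorch Geometric: the toy rows, the helper name `temporal_neighbors`, and the choice to treat rows dated exactly at the seed time as visible are all assumptions made for illustration.

```python
from datetime import date

# Toy transactions: (transaction_id, customer_id, purchase_date).
# All values are made up for illustration.
transactions = [
    (0, 155, date(2020, 9, 1)),
    (1, 155, date(2020, 9, 5)),
    (2, 155, date(2020, 9, 9)),  # happens after the seed time used below
    (3, 200, date(2020, 9, 3)),
]


def temporal_neighbors(customer_id, seed_time, rows):
    """Return the transactions of `customer_id` that are visible at `seed_time`.

    Assumption: a row dated exactly at the seed time is still visible.
    """
    return [
        t_id
        for t_id, c_id, t_dat in rows
        if c_id == customer_id and t_dat <= seed_time
    ]


seed_time = date(2020, 9, 7)
visible = temporal_neighbors(155, seed_time, transactions)
print(visible)  # [0, 1]
```

Transaction `2` is ignored because it lies in the future relative to the seed time, so it cannot leak information into the prediction.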

In [None]:
dataset.val_timestamp, dataset.test_timestamp

(Timestamp('2020-09-07 00:00:00'), Timestamp('2020-09-14 00:00:00'))

*   Information up to September 7, 2020 can be used for training.
*   Information up to September 14, 2020 can be used for validation.
*   Information after September 14, 2020 can be used for testing.
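
The three bullets above can be sketched as a small assignment rule. This is an illustrative helper (`split_of` is a hypothetical name, not a RelBench function), and the handling of rows that fall exactly on a boundary is an assumption.

```python
from datetime import datetime

# The two cut-off timestamps reported by the dataset above.
val_timestamp = datetime(2020, 9, 7)
test_timestamp = datetime(2020, 9, 14)


def split_of(ts):
    """Assign a row timestamp to a split (boundary handling is an assumption)."""
    if ts <= val_timestamp:
        return "train"
    if ts <= test_timestamp:
        return "val"
    return "test"


# Made-up row timestamps, one per split.
rows = [datetime(2020, 8, 30), datetime(2020, 9, 10), datetime(2020, 9, 20)]
splits = [split_of(ts) for ts in rows]
print(splits)  # ['train', 'val', 'test']
```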

### Relational Database

#### Load a Database

Let's check out the relational database itself...

In [None]:
# Load the dataset object and cache it in memory
db = dataset.get_db()

Loading Database object from /root/.cache/relbench/rel-hm/db...
Done in 1.47 seconds.


This returns a RelBench `Database` object.

By default, rows with a timestamp later than `test_timestamp` are excluded to prevent accidental test set leakage.

The complete database can be loaded with `dataset.get_db(upto_test_timestamp=False)`.

#### Check Full Timespan

With this we can double check the full timespan of the database:

In [None]:
db.min_timestamp, db.max_timestamp

(Timestamp('2019-09-07 00:00:00'), Timestamp('2020-09-14 00:00:00'))

Note that the `max_timestamp` is the same as `test_timestamp`.

#### Check Tables

In the selected dataset, we have the following tables:

In [None]:
db.table_dict.keys()

dict_keys(['customer', 'transactions', 'article'])

That's 3 tables total! Let's look more closely at one of them.

In [None]:
table = db.table_dict["article"]
table

Table(df=
        article_id  product_code               prod_name  product_type_no  \
0                0        108775               Strap top              253   
1                1        108775               Strap top              253   
2                2        108775           Strap top (1)              253   
3                3        110065       OP T-shirt (Idro)              306   
4                4        110065       OP T-shirt (Idro)              306   
...            ...           ...                     ...              ...   
105537      105537        953450  5pk regular Placement1              302   
105538      105538        953763       SPORT Malaga tank              253   
105539      105539        956217         Cartwheel dress              265   
105540      105540        957375        CLAIRE HAIR CLAW               72   
105541      105541        959461            Lounge dress              265   

       product_type_name  product_group_name  graphical_appearanc

The `article` table stores information on all products available on the H&M e-commerce platform. Note that the table comes with multiple bits of information:

*   The table itself, `table.df` which is a Pandas DataFrame.
*   The primary key column, `table.pkey_col`, which indicates that the `article_id` column holds the primary key for this particular table in the database.
*   The primary time column, `table.time_col`, which, if the entity is an event, records the time an event happened. In the case of articles, they are non-temporal entities, so `table.time_col=None`.
*   The other tables pointed to by the foreign keys, if any, `table.fkey_col_to_pkey_table`. Again, in the case of articles, this is not applicable.



In [None]:
table = db.table_dict["customer"]
table.df

Unnamed: 0,customer_id,FN,Active,club_member_status,fashion_news_frequency,age,postal_code
0,0,,,ACTIVE,NONE,49.0,52043ee2162cf5aa7ee79974281641c6f11a68d276429a...
1,1,,,ACTIVE,NONE,25.0,2973abc54daa8a5f8ccfe9362140c63247c5eee03f1d93...
2,2,,,ACTIVE,NONE,24.0,64f17e6a330a85798e4998f62d0930d14db8db1c054af6...
3,3,,,ACTIVE,NONE,54.0,5d36574f52495e81f019b680c843c443bd343d5ca5b1c2...
4,4,1.0,1.0,ACTIVE,Regularly,52.0,25fa5ddee9aac01b35208d01736e57942317d756b32ddd...
...,...,...,...,...,...,...,...
1371975,1371975,,,ACTIVE,NONE,24.0,7aa399f7e669990daba2d92c577b52237380662f36480b...
1371976,1371976,,,ACTIVE,NONE,21.0,3f47f1279beb72215f4de557d950e0bfa73789d24acb5e...
1371977,1371977,1.0,1.0,ACTIVE,Regularly,21.0,4563fc79215672cd6a863f2b4bf56b8f898f2d96ed590e...
1371978,1371978,1.0,1.0,ACTIVE,Regularly,18.0,8892c18e9bc3dca6aa4000cb8094fc4b51ee8db2ed14d7...


In [None]:
table = db.table_dict["transactions"]
table

Table(df=
              t_dat  customer_id  article_id     price  sales_channel_id
0        2019-09-07          155       51985  0.010153                 1
1        2019-09-07          155       51985  0.010153                 1
2        2019-09-07          155       83127  0.042356                 1
3        2019-09-07          155        6066  0.005068                 1
4        2019-09-07          155       78525  0.033881                 1
...             ...          ...         ...       ...               ...
15187282 2020-09-14      1371926       93801  0.025407                 1
15187283 2020-09-14      1371926       17155  0.033881                 1
15187284 2020-09-14      1371926       65802  0.030492                 1
15187285 2020-09-14      1371926       85883  0.016932                 1
15187286 2020-09-14      1371926      104763  0.042356                 1

[15187287 rows x 5 columns],
  fkey_col_to_pkey_table={'customer_id': 'customer', 'article_id': 'article'},
  pke

**Database schema:**

<img src="https://relbench.stanford.edu/img/rel-hm.png" width="800px">


#### Load a Task

Each RELBENCH dataset comes with multiple pre-defined predictive tasks. For any given RELBENCH dataset, you can check all the associated tasks with:

In [None]:
from relbench.tasks import get_task_names, get_task

get_task_names("rel-hm")

['user-item-purchase', 'user-churn', 'item-sales']

We will work with `user-item-purchase`, where the task is: **"Predict the list of articles each customer will purchase in the next seven days"** [[Source](https://relbench.stanford.edu/datasets/rel-hm/)]. The task itself is instantiated by calling:

In [None]:
task = get_task("rel-hm", "user-item-purchase", download=True)
task

Downloading file 'rel-hm/tasks/user-item-purchase.zip' from 'https://relbench.stanford.edu/download/rel-hm/tasks/user-item-purchase.zip' to '/root/.cache/relbench'.
100%|█████████████████████████████████████| 46.9M/46.9M [00:00<00:00, 44.9GB/s]
Unzipping contents of '/root/.cache/relbench/rel-hm/tasks/user-item-purchase.zip' to '/root/.cache/relbench/rel-hm/tasks/.'


UserItemPurchaseTask(dataset=HMDataset())

In [None]:
from relbench.base import TaskType
assert task.task_type == TaskType.LINK_PREDICTION

Next, we load the train / val / test labels.

In [None]:
train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

Each link prediction task table contains **triples** `(entity_1_id, entity_2_id, timestamp)` indicating:

*   The target entity label (`entity_2_id`) associated to `entity_1_id`
*   `timestamp`, the timepoint at which the prediction is made

The task table also indicates which database table(s) it is "attached" to --- in this case the *customer* and *article* tables.
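
As a rough illustration of how such a label table relates to the raw transactions, the sketch below groups the articles bought by each customer in the seven days after a seed time. It is not RelBench's task-building code: the toy rows, the helper name `build_labels`, and the exact window boundaries are assumptions.

```python
from collections import defaultdict
from datetime import date, timedelta

# Toy transactions: (customer_id, article_id, purchase_date). Values are made up.
transactions = [
    (155, 51985, date(2020, 3, 29)),  # before the seed time: input, not a label
    (155, 83127, date(2020, 3, 31)),
    (155, 6066, date(2020, 4, 4)),
    (200, 78525, date(2020, 4, 2)),
    (200, 93801, date(2020, 4, 10)),  # outside the 7-day window
]


def build_labels(rows, seed_time, horizon_days=7):
    """Collect, per customer, the articles bought in (seed_time, seed_time + horizon]."""
    window_end = seed_time + timedelta(days=horizon_days)
    labels = defaultdict(list)
    for customer_id, article_id, t_dat in rows:
        if seed_time < t_dat <= window_end:
            labels[customer_id].append(article_id)
    # One row per customer: (timestamp, customer_id, list of article ids)
    return [(seed_time, c, a) for c, a in sorted(labels.items())]


labels = build_labels(transactions, date(2020, 3, 30))
for row in labels:
    print(row)
```

Each resulting row has the same shape as the task table: a seed timestamp, a source entity, and the list of target entities to predict.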

In [None]:
train_table

Table(df=
         timestamp  customer_id  \
0       2019-12-09       149853   
1       2019-12-09       435491   
2       2019-12-09       600889   
3       2019-12-09      1271535   
4       2019-12-09       124560   
...            ...          ...   
3878446 2020-04-20       408061   
3878447 2020-04-20      1138840   
3878448 2020-03-30       140490   
3878449 2020-03-30      1094930   
3878450 2020-03-30      1217756   

                                                article_id  
0                                           [11667, 83069]  
1        [8061, 56842, 70123, 83386, 14038, 70122, 3315...  
2                                           [25756, 72271]  
3         [78428, 38992, 91389, 86016, 2556, 72566, 10378]  
4                                                  [80745]  
...                                                    ...  
3878446                                            [82437]  
3878447                                           [101299]  
3878448             

### Data Sampling

Data sampling is not required for the original RELBENCH GNNs we will use in the first part: everything can be run on Google Colab.

However, it may become necessary with larger models or more limited hardware resources.

Official sampling code from ContextGNN [[Source](https://github.com/snap-stanford/relbench/blob/6bcb12a94b163c52e01cc272dfd4817cd13eff69/examples/lightgbm_link.py#L113)].

In [None]:
# import numpy as np

# # Subsample train data if SAMPLE_SIZE is less than the current size
# TRAIN_SAMPLE_SIZE = 10000

# if TRAIN_SAMPLE_SIZE > 0 and TRAIN_SAMPLE_SIZE < len(train_table):
#     sampled_idx = np.random.permutation(len(train_table))[:TRAIN_SAMPLE_SIZE]
#     train_table.df = train_table.df.iloc[sampled_idx].reset_index(drop=True)  # Reset indices

# VAL_SAMPLE_SIZE = 1000
# if VAL_SAMPLE_SIZE > 0 and VAL_SAMPLE_SIZE < len(val_table):
#     sampled_idx = np.random.permutation(len(val_table))[:VAL_SAMPLE_SIZE]
#     val_table.df = val_table.df.iloc[sampled_idx].reset_index(drop=True)  # Reset indices

# TEST_SAMPLE_SIZE = 1000
# if TEST_SAMPLE_SIZE > 0 and TEST_SAMPLE_SIZE < len(test_table):
#     sampled_idx = np.random.permutation(len(test_table))[:TEST_SAMPLE_SIZE]
#     test_table.df = test_table.df.iloc[sampled_idx].reset_index(drop=True)  # Reset indices

## 🧑‍🏫 GNNs for Recommender Systems

### Preliminary of Recommendation

Information explosion in the Internet era:

*   10K+ movies in Netflix
*   12M products in Amazon
*   70M+ music tracks in Spotify
*   10B+ videos on YouTube
*   200B+ pins (images) in Pinterest

**Personalized recommendation (i.e., suggesting a small number of interesting items for each user)** is critical for users to effectively explore the content of their interest.



### Recommendation System as a Graph

A recommender system can be naturally modeled as a **bipartite graph**.

*   A graph with two node types: **users** and **items**
*   Edges connect users and items
    *    Indicates user-item interaction (e.g., click, purchase, review)
    *    Often associated with timestamp (timing of the interaction)

<img src="https://drive.google.com/uc?export=view&id=1QTprqNjLmxKU5xhjTVPVPRlUQmhiHKFS" width="150">

> For the AI model to produce accurate recommendations it needs a detailed understanding of product properties as well as customer preferences. **Products as well as customers are often associated with unstructured textual information like product names, product descriptions, customer reviews, and more. It is critical for the recommender system to include text understanding capability because so much information is stored in unstructured text.**

> At the same time, graph information is highly valuable in recommender systems because it captures **complex relationships between customers, products, and their interactions**. Graphs represent the relationships between customers and products as edges, allowing the system to consider not just **direct** interactions (like purchases or ratings) but also **indirect** connections (e.g., customers who like similar products or products that are liked by similar customers).

> In graph-based systems, higher-order connections (i.e., multi-hop relationships) can be leveraged. **For instance, if Customer A likes a product that Customer B liked, and Customer C is similar to Customer B, the system might recommend that product to Customer C as well.** Overall, graph-based models provide a nuanced understanding of customers’ preferences by considering both direct and indirect interactions. This enables highly personalized recommendations, taking into account more contextual information about the customer and their network.

⚠️ **We will not model our relational database as a bipartite graph, but this provides you with an idea about the popularity of graph representations within the RecSys field and the underlying motivations.**
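
To illustrate the multi-hop intuition quoted above, here is a minimal sketch on a toy bipartite graph. The user and item names and the helper `multi_hop_candidates` are hypothetical; the sketch only follows paths of the form user → item → other user → item and is not a full recommender.

```python
from collections import defaultdict

# Toy user-item interactions (made-up names).
interactions = [
    ("alice", "jeans"),
    ("alice", "t-shirt"),
    ("bob", "t-shirt"),
    ("bob", "dress"),
    ("carol", "hat"),
]

# Adjacency lists for both sides of the bipartite graph.
user_to_items = defaultdict(set)
item_to_users = defaultdict(set)
for user, item in interactions:
    user_to_items[user].add(item)
    item_to_users[item].add(user)


def multi_hop_candidates(user):
    """Items reached via user -> item -> other user -> item, minus items already seen."""
    seen = user_to_items[user]
    candidates = set()
    for item in seen:
        for other in item_to_users[item]:
            if other != user:
                candidates |= user_to_items[other] - seen
    return candidates


print(multi_hop_candidates("alice"))  # {'dress'}
```

Alice and Bob share the t-shirt, so Bob's other purchase becomes a candidate for Alice, while Carol, who shares nothing with anyone, gets no candidates.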

### Recommendation Task

Given

*   Past user-item interactions.

Task

*   Predict new items each user will interact with in the future.
*   Can be cast as a **link prediction** problem.
    *   Predict new user-item interaction edges given the past edges.

For $u \in U$, $v \in V$, we need to get a real-valued score $f(u, v)$.

<img src="https://drive.google.com/uc?export=view&id=1S17_5hMnCEfYfgeUwYcv0EKqbF9I4zxd" width="150">



### Modern Recommender System

**Problem:** Cannot evaluate $f(u,v)$ for every user $u$ $-$ item $v$ pair.
**Solution:** 2-stage process.
  *   Candidate generation (cheap, fast)
  *   Ranking (slow, accurate)
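
The two-stage process can be sketched as follows. Everything here is an assumption chosen for illustration: popularity as the cheap candidate generator, a dot product of made-up 2-dimensional embeddings as the ranker, and the helper names `generate_candidates` and `rank`.

```python
import heapq

# Hypothetical 2-D embeddings and popularity counts.
user_emb = (1.0, 0.5)
item_embs = {
    "a": (0.9, 0.1),
    "b": (0.2, 0.8),
    "c": (-0.5, 0.3),
    "d": (1.0, 1.0),
}
item_popularity = {"a": 120, "b": 40, "c": 300, "d": 75}


def generate_candidates(n):
    """Stage 1 (cheap, fast): keep the n most popular items."""
    return heapq.nlargest(n, item_popularity, key=item_popularity.get)


def rank(user, candidates):
    """Stage 2 (slow, accurate): score only the shortlisted candidates."""

    def score(item):
        return sum(u * v for u, v in zip(user, item_embs[item]))

    return sorted(candidates, key=score, reverse=True)


candidates = generate_candidates(3)
ranking = rank(user_emb, candidates)
print(candidates)  # ['c', 'a', 'd']
print(ranking)     # ['d', 'a', 'c']
```

The expensive score is computed for three items instead of four; at catalogue scale, the same pattern reduces millions of items to a few hundred before ranking.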

<img src="https://drive.google.com/uc?export=view&id=1KuJBKw9BnsnC9HoBF20pg48nuXnrRgTm" width="550">

### Top-K Recommendation

For each user, we recommend $K$ items (those from the user-item interaction pairs with the largest scores, excluding already-interacted items).

  *   For recommendation to be effective, $K$ needs to be much smaller than the total number of items (up to billions).
  *   $K$ is typically on the order of $10$ to $100$.

The goal is to include as many **positive items** as possible in the top-$K$ recommended items.

  *   **Positive items = Items that the user will interact with in the future**.
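
A minimal sketch of top-$K$ selection, assuming the scores for one user are already available in a dictionary (the scores and the helper name `top_k` are made up for illustration):

```python
import heapq


def top_k(scores, already_interacted, k):
    """Return the k highest-scoring items the user has not interacted with yet."""
    candidates = {
        item: s for item, s in scores.items() if item not in already_interacted
    }
    return heapq.nlargest(k, candidates, key=candidates.get)


# Made-up scores f(u, v) for a single user u.
scores = {"a": 0.9, "b": 0.7, "c": 0.4, "d": 0.8, "e": 0.1}
recommended = top_k(scores, already_interacted={"a"}, k=2)
print(recommended)  # ['d', 'b']
```

Item `a` has the highest score but is excluded because the user has already interacted with it.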

<img src="https://drive.google.com/uc?export=view&id=1mZPzFOQkgN0HOVkEZUW77DHLioi44PZj" width="250">

### Embedding-Based Models

We consider embedding-based models for scoring user$-$item interactions.

*   For each user $u \in U$, let $\mathbf{u} \in \mathbb{R}^D$ be its $D$-dimensional embedding.
*   For each item $v \in V$, let $\mathbf{v} \in \mathbb{R}^D$ be its $D$-dimensional embedding.
*   $f_{\theta}(\cdot, \cdot): \mathbb{R}^D \times \mathbb{R}^D \to \mathbb{R}$ is a parametrized function that maps a pair of embeddings to a real-valued score.

Thus, embedding-based models have **three kinds of parameters**:
*   An encoder to generate user embeddings $\{\mathbf{u}\}_{u \in U}$.
*   An encoder to generate item embeddings $\{\mathbf{v}\}_{v \in V}$.
*   Score function $f_{\theta}(\cdot, \cdot)$.
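
As a small illustration of these ingredients, the sketch below uses lookup tables as the two "encoders" and a dot product as the score function. These are assumptions: the embedding values are made up, and the dot product is only one common choice for $f$ (it has no parameters of its own, whereas a learned $f_{\theta}$ such as an MLP would).

```python
# Hypothetical embeddings with D = 3 (in practice these are learned).
user_embeddings = {
    "u1": [0.2, 0.9, -0.1],
    "u2": [-0.5, 0.1, 0.8],
}
item_embeddings = {
    "v1": [0.1, 1.0, 0.0],
    "v2": [-0.4, 0.0, 0.9],
}


def score(u, v):
    """f(u, v): dot product between the user and item embeddings."""
    return sum(a * b for a, b in zip(user_embeddings[u], item_embeddings[v]))


for u in user_embeddings:
    best = max(item_embeddings, key=lambda v: score(u, v))
    print(u, "->", best)
```

With these values, `u1` is matched to `v1` and `u2` to `v2`, because their embeddings point in similar directions.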

<br>
<img src="https://drive.google.com/uc?export=view&id=1aht8ewGJ7Hv7KhG2cA8yw1Lbtep3rBgV" width="210">

### Training Objective

Optimize the model parameters to **achieve high recall@$K$ on seen (i.e., training) user$-$item interactions**.
*   We hope this objective would lead to high recall@$K$ on unseen (i.e., test) interactions.

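Let $P_u$ be the set of positive items for user $u$ and $R_u$ the set of $K$ items recommended to $u$. Recall@$K$ for user $u$ is $\frac{|P_u \cap R_u|}{|P_u|}$: the fraction of positive items that appear among the recommendations. The final Recall@$K$ is computed by averaging the recall values across all users.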
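
A minimal sketch of this metric, where `positives[u]` plays the role of the positive items of user $u$ and `recommended[u]` the $K$ recommended items. The toy data and the helper name `recall_at_k` are assumptions, and the sketch assumes every user has at least one positive item.

```python
def recall_at_k(positive_items, recommended_items):
    """Fraction of a user's positive items that appear among the recommendations."""
    hits = len(positive_items & set(recommended_items))
    return hits / len(positive_items)


# Made-up ground truth and top-3 recommendations for two users.
positives = {"u1": {"a", "b", "c", "d"}, "u2": {"e", "f"}}
recommended = {"u1": ["a", "x", "c"], "u2": ["y", "z", "w"]}

per_user = {u: recall_at_k(positives[u], recommended[u]) for u in positives}
mean_recall = sum(per_user.values()) / len(per_user)
print(per_user)     # {'u1': 0.5, 'u2': 0.0}
print(mean_recall)  # 0.25
```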

<img src="https://drive.google.com/uc?export=view&id=1FDWeao5uWsL1KY6KCOV6gK4xQ3MpRW4j" width="400">

The original training objective (recall@$K$) is **not differentiable**.
*   The community employs **surrogate (differentiable) losses** that align well with the original training objective.
*   Widely-used surrogate loss: **Bayesian Personalized Ranking (BPR)**.

Considering this illustration...

<img src="https://drive.google.com/uc?export=view&id=1o4wz7GNpZ4uBAurKj22uV9On0G__oIIt" width="200">

BPR [[Rendle et al., 2012](https://arxiv.org/pdf/1205.2618)]
*   **For each user $u^*$, we want the scores of rooted positive edges $E(u^*)$ to be higher than those of rooted negative edges $E_{neg}(u^*)$**.
*    It supports **mini-batches**.
    *    In each mini-batch, we sample a subset of users $U_{mini} \subset U$.
    *    For each user $u^* \in U_{mini}$, we sample one positive item $v_{pos}$ and a set of sampled negative items $V_{neg}$.

<br>
<img src="https://drive.google.com/uc?export=view&id=1fgZu6WsPznqWotFRl_jCjXcoOLsiTjey" width="600">
<br>
<img src="https://drive.google.com/uc?export=view&id=1RwS73U0kYAY4zv4zp43tFZeZJRhMB6vM" width="600">

<br>
More information: https://web.stanford.edu/class/cs224w/slides/12-recsys.pdf
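
The BPR idea described above can be sketched for a single user with one positive item and a few sampled negatives. This is a plain-Python illustration with made-up scores, not the training code used later in the lab: each positive-negative pair contributes $-\log \sigma(f(u^*, v_{pos}) - f(u^*, v_{neg}))$, and the helper name `bpr_loss` is hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bpr_loss(pos_score, neg_scores):
    """Average of -log(sigmoid(pos - neg)) over the sampled negative items."""
    terms = [-math.log(sigmoid(pos_score - neg)) for neg in neg_scores]
    return sum(terms) / len(terms)


# Positive item scored above the negatives: small loss.
good = bpr_loss(pos_score=2.0, neg_scores=[0.0, -1.0])
# Positive item scored below the negatives: large loss.
bad = bpr_loss(pos_score=-1.0, neg_scores=[0.0, 2.0])
print(good < bad)  # True
```

Because the loss depends only on score differences, it is differentiable with respect to the scores and pushes positive items above negative ones, which is what a high recall@$K$ requires.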

### Why Do Embedding Models Work?

Underlying idea: **Collaborative filtering**

*   Recommend items for a user by **collecting preferences of many other similar users**.
*   **Similar users tend to prefer similar items**

Embedding-based models can capture similarity of users/items!

*   Low-dimensional embeddings cannot memorize all user$-$item interaction data.
*   Embeddings are forced to capture similarity between users/items to fit the data.
*   This allows the models to make effective predictions on *unseen* user$-$item interactions.

<img src="https://drive.google.com/uc?export=view&id=1VH59zqdaONHC45UkrQU2TQtpcNX6MxbP" width="300">




## 🤖 Train a GNN

To load the data we did not require any deep learning libraries. In this part of the notebook, we will work with PyTorch.

### Environment Setup

In [None]:
import torch

from torch_geometric.seed import seed_everything

# Check that it's cuda if you want it to run in reasonable time!
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.set_num_threads(1)
print(device)

# Set the seed for generating random numbers to ensure reproducibility
seed_everything(42)

# Path to the directory for caching graph data
root_dir = "./data"

cuda


### Construct the Graph and Initialize Features

**Heterogeneous Temporal Graph**

The first big move is to build a graph out of the database.
We will use a pre-prepared conversion function, `make_pkey_fkey_graph`, from Stanford [[Source Code](https://github.com/snap-stanford/relbench/blob/main/relbench/modeling/graph.py)].
Given a set of tables with primary-foreign key relations between them, we automatically construct a heterogeneous temporal graph, where:

*   Each **table** represents a **node type** (HETEROGENEOUS).
*   Each **row in a table** represents a **node**.
*   A **primary-foreign key relation** between two table rows (nodes) represents an **edge** between the respective nodes. Edges also come in multiple types (HETEROGENEOUS), depending on the attribute that connects two tables.

Some node types are associated with time attributes, representing the timestamp at which a node appears (TEMPORAL).
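
To give a feel for the conversion, here is a toy sketch in plain Python. It is not the implementation of `make_pkey_fkey_graph`: the three-row table and the helper name `fkey_edges` are made up, and the edge-type names simply mimic the `f2p_*` / `rev_f2p_*` pattern that appears in the graph object printed later in this notebook.

```python
# Toy transactions table; the row position plays the role of the node index.
transactions = [
    {"customer_id": 0, "article_id": 2, "t_dat": "2020-09-01"},
    {"customer_id": 1, "article_id": 0, "t_dat": "2020-09-02"},
    {"customer_id": 0, "article_id": 1, "t_dat": "2020-09-03"},
]
# Foreign-key column -> table holding the matching primary key.
fkey_col_to_pkey_table = {"customer_id": "customer", "article_id": "article"}


def fkey_edges(rows, fkey_col):
    """Return [source_rows, target_rows] for one foreign-key column."""
    src = list(range(len(rows)))           # index of the row holding the foreign key
    dst = [row[fkey_col] for row in rows]  # index of the row it points to
    return [src, dst]


edges = {}
for fkey_col, pkey_table in fkey_col_to_pkey_table.items():
    src, dst = fkey_edges(transactions, fkey_col)
    # One edge type per foreign key, plus its reverse so messages flow both ways.
    edges[("transactions", f"f2p_{fkey_col}", pkey_table)] = [src, dst]
    edges[(pkey_table, f"rev_f2p_{fkey_col}", "transactions")] = [dst, src]

# Temporal node types additionally keep one timestamp per node.
node_time = {"transactions": [row["t_dat"] for row in transactions]}

for edge_type, edge_index in edges.items():
    print(edge_type, edge_index)
```

Two foreign keys yield four edge types, and only the `transactions` node type carries timestamps.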

The heterogeneous temporal graph is represented as a **PyTorch Geometric** graph object.

💡 The graph representation allows GNNs to be used as predictive models.

**Feature Initialization**

Each node in the graph comes with a rich feature derived from diverse columns of the corresponding table.

We use **Tensor Frame provided by PyTorch Frame** to represent rich node features with diverse column types, e.g., numerical, categorical, timestamp, and text.

**PyTorch Frame also stores the `stype` (i.e., modality) of each column** and **lets us set the feature encoders (e.g., text encoders) to be used later**.

So, we need to configure the `stype` for each column, for which we use a function that tries to automatically detect the `stype`.

In [None]:
from relbench.modeling.utils import get_stype_proposal

db = dataset.get_db()
col_to_stype_dict = get_stype_proposal(db)
col_to_stype_dict

{'customer': {'customer_id': <stype.numerical: 'numerical'>,
  'FN': <stype.categorical: 'categorical'>,
  'Active': <stype.categorical: 'categorical'>,
  'club_member_status': <stype.categorical: 'categorical'>,
  'fashion_news_frequency': <stype.text_embedded: 'text_embedded'>,
  'age': <stype.numerical: 'numerical'>,
  'postal_code': <stype.text_embedded: 'text_embedded'>},
 'transactions': {'t_dat': <stype.timestamp: 'timestamp'>,
  'customer_id': <stype.numerical: 'numerical'>,
  'article_id': <stype.numerical: 'numerical'>,
  'price': <stype.numerical: 'numerical'>,
  'sales_channel_id': <stype.categorical: 'categorical'>},
 'article': {'article_id': <stype.numerical: 'numerical'>,
  'product_code': <stype.numerical: 'numerical'>,
  'prod_name': <stype.text_embedded: 'text_embedded'>,
  'product_type_no': <stype.numerical: 'numerical'>,
  'product_type_name': <stype.text_embedded: 'text_embedded'>,
  'product_group_name': <stype.text_embedded: 'text_embedded'>,
  'graphical_appea

Next, we also define our text encoding model. We will use **GloVe** embeddings for speed and convenience. Feel free to try alternatives here.

In [None]:
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor

class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return self.model.encode(sentences, convert_to_tensor=True)

🕒 *Please note that graph construction can take up to 10 minutes.*

In [None]:
import os

# Root directory where files will be stored
root_dir = "./data"

# Run the from-scratch graph computation
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

# Configure the text encoder
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device),
    batch_size=256
)

# Generate graph data
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # Column types
    text_embedder_cfg=text_embedder_cfg,  # Our chosen text encoder
    cache_dir=os.path.join(
        root_dir, "rel-hm_materialized_cache"
    ),  # Store materialized graph for convenience
)

Embedding raw data in mini-batch: 100%|██████████| 5360/5360 [00:28<00:00, 189.91it/s]
Embedding raw data in mini-batch: 100%|██████████| 5360/5360 [00:32<00:00, 166.32it/s]
Embedding raw data in mini-batch: 100%|██████████| 413/413 [00:02<00:00, 172.27it/s]
Embedding raw data in mini-batch: 100%|██████████| 413/413 [00:02<00:00, 164.90it/s]
Embedding raw data in mini-batch: 100%|██████████| 413/413 [00:04<00:00, 93.08it/s]
Embedding raw data in mini-batch: 100%|██████████| 413/413 [00:02<00:00, 171.70it/s]
Embedding raw data in mini-batch: 100%|██████████| 413/413 [00:02<00:00, 155.60it/s]
Embedding raw data in mini-batch: 100%|██████████| 413/413 [00:02<00:00, 167.76it/s]
Embedding raw data in mini-batch: 100%|██████████| 413/413 [00:02<00:00, 178.34it/s]
Embedding raw data in mini-batch: 100%|██████████| 413/413 [00:02<00:00, 161.93it/s]


We can now check out `data`, our main graph object, with node types given by the table it originates from.

In [None]:
data

HeteroData(
  customer={ tf=TensorFrame([1371980, 6]) },
  transactions={
    tf=TensorFrame([15187287, 3]),
    time=[15187287],
  },
  article={ tf=TensorFrame([105542, 24]) },
  (transactions, f2p_customer_id, customer)={ edge_index=[2, 15187287] },
  (customer, rev_f2p_customer_id, transactions)={ edge_index=[2, 15187287] },
  (transactions, f2p_article_id, article)={ edge_index=[2, 15187287] },
  (article, rev_f2p_article_id, transactions)={ edge_index=[2, 15187287] }
)

**What is `HeteroData`?**

`HeteroData` is a specialized container in PyTorch Geometric that extends the `Data` class to support heterogeneous graphs. It enables:
- Representation of **different node types** with distinct feature sets.
- Representation of **different edge types** capturing various relationships between nodes.
- Efficient storage of graph data for use in graph neural networks (GNNs).

---

**Node types**

We see 3 node types (`customer`, `transactions`, `article`). Each node type has its own `TensorFrame` (or tensor) to store features. Example: `customer={ tf=TensorFrame([1371980, 6]) }` means that there are 1,371,980 customer nodes, each with 6 features.

The `transactions` node type, `data['transactions']`, stores two feature types:

*   A `TensorFrame` object.
*   A timestamp for each node.

---

**Edge types**

An edge type defines the relationship between two node types in the graph. Each edge type consists of: `(source_node_type, relation_type, target_node_type)`.
We see 4 edge types. Example: `(transactions, f2p_customer_id, customer)={ edge_index=[2, 15187287] }` indicates 15,187,287 directed edges connecting `transactions` nodes to `customer` nodes via the relationship `f2p_customer_id`.

We can also check out the `TensorFrame` for one table like this:

In [None]:
data["customer"].tf

TensorFrame(
  num_cols=6,
  num_rows=1371980,
  categorical (3): ['Active', 'FN', 'club_member_status'],
  numerical (1): ['age'],
  embedding (2): ['fashion_news_frequency', 'postal_code'],
  has_target=False,
  device='cpu',
)

In [None]:
data["article"].tf

TensorFrame(
  num_cols=24,
  num_rows=105542,
  numerical (6): ['colour_group_code', 'department_no', 'graphical_appearance_no', 'product_code', 'product_type_no', 'section_no'],
  categorical (10): ['garment_group_name', 'garment_group_no', 'index_code', 'index_group_name', 'index_group_no', 'index_name', 'perceived_colour_master_id', 'perceived_colour_master_name', 'perceived_colour_value_id', 'perceived_colour_value_name'],
  embedding (8): ['colour_group_name', 'department_name', 'detail_desc', 'graphical_appearance_name', 'prod_name', 'product_group_name', 'product_type_name', 'section_name'],
  has_target=False,
  device='cpu',
)

In [None]:
data["transactions"].tf

TensorFrame(
  num_cols=3,
  num_rows=15187287,
  timestamp (1): ['t_dat'],
  numerical (1): ['price'],
  categorical (1): ['sales_channel_id'],
  has_target=False,
  device='cpu',
)

ℹ️ *Features marked as embedding in the TensorFrame (e.g., colour_group_name, department_name, etc.) represent pre-existing embeddings or dense representations, such as word embeddings for textual data. In our case, they are the GloVe embeddings previously computed at graph construction time.*

In [None]:
list(data["transactions"].keys())

['tf', 'time']

The `TensorFrame` object acts analogously to the usual tensor of node features, and you can simply use indexing to retrieve the features of a single row (node), or group of nodes.



In [None]:
# Features of node 10
data["transactions"].tf[10]

TensorFrame(
  num_cols=3,
  num_rows=1,
  timestamp (1): ['t_dat'],
  numerical (1): ['price'],
  categorical (1): ['sales_channel_id'],
  has_target=False,
  device='cpu',
)

In [None]:
# Features of nodes 10, 11, ..., 19
data["transactions"].tf[10:20]

TensorFrame(
  num_cols=3,
  num_rows=10,
  timestamp (1): ['t_dat'],
  numerical (1): ['price'],
  categorical (1): ['sales_channel_id'],
  has_target=False,
  device='cpu',
)

We can also check the edge indices between two different node types, such as `transactions` and `customer`. Since the edges are also heterogeneous, we need to specify which edge type we want to look at. Here we look at `f2p_customer_id`, which are the directed edges pointing from a transaction (the *f* stands for *foreign key*) to the customer that completed that transaction (the *p* stands for *primary key*).

In [None]:
data[("transactions", "f2p_customer_id", "customer")]

{'edge_index': tensor([[       0,        1,        2,  ..., 15187284, 15187285, 15187286],
        [     155,      155,      155,  ...,  1371926,  1371926,  1371926]])}

The `edge_index` tensor has the following structure:

*   It is a 2D tensor with shape `[2, num_edges]`.
*   The first row (`edge_index[0]`) contains the indices of the source nodes (in this case, transactions).
*   The second row (`edge_index[1]`) contains the indices of the target nodes (in this case, customers).
*   Each column in `edge_index` defines a directed edge between two nodes, connecting a source node to a target node.
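
As a small illustration of this layout, the sketch below reads a toy `edge_index` column by column and groups transactions by customer. The values are made up, and plain Python lists stand in for the tensor.

```python
from collections import defaultdict

# Toy edge index (hypothetical values): row 0 = transaction ids, row 1 = customer ids.
edge_index = [
    [0, 1, 2, 3, 4],
    [2, 0, 0, 1, 2],
]

num_edges = len(edge_index[0])
assert len(edge_index) == 2 and len(edge_index[1]) == num_edges

# Each column (source, target) is one directed edge transaction -> customer.
transactions_of = defaultdict(list)
for src, dst in zip(edge_index[0], edge_index[1]):
    transactions_of[dst].append(src)

for customer in sorted(transactions_of):
    print(f"customer {customer}: transactions {transactions_of[customer]}")
```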

### Two-Tower GNN Architecture

Please note that recommendation requires computing scores between pairs of source nodes and target nodes.

For this task type, our architecture is a **two-tower GNN** [[Wang et al., 2019](https://dl.acm.org/doi/abs/10.1145/3331184.3331267?casa_token=NWPeKZ6jwg4AAAAA:H8GuVLZSfjA_KaABlf-UUHUjGKatRbwP4UyTkZpPJJRcsrhnfuRBa2dBFolHE4S6l1ggI8j-thqdig)] that computes the pairwise score via inner product between source and target node embeddings.

Illustration of a two-tower GNN from the reference paper:

<img src="https://drive.google.com/uc?export=view&id=1TspJlRuCw33VBB3CD1ZJeLDDiZrIDEZ7" width="600">

Initial embeddings of the customer (left) and article (right) nodes under comparison are separately refined through a GNN with multiple layers (shared weights). The outputs are pooled to make the final binary prediction (0 = no link, 1 = link) via inner product.
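
The scoring step can be sketched without any library: once both towers have produced embeddings, ranking articles for a customer reduces to inner products followed by a top-k selection. The embeddings and the helper names below are hypothetical and only illustrate the idea.

```python
from typing import List


def inner_product(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def top_k_articles(customer_emb: List[float],
                   article_embs: List[List[float]],
                   k: int) -> List[int]:
    """Rank all articles for one customer by inner-product score."""
    scores = [inner_product(customer_emb, emb) for emb in article_embs]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


# Hypothetical 3-dimensional embeddings produced by the two towers.
customer_emb = [1.0, 0.0, 2.0]
article_embs = [
    [0.5, 1.0, 0.0],   # score 0.5
    [0.0, 0.0, 1.5],   # score 3.0
    [1.0, 1.0, 0.5],   # score 2.0
    [-1.0, 0.0, 0.0],  # score -1.0
]

print(top_k_articles(customer_emb, article_embs, k=2))
```

Because the score is a plain inner product, article embeddings can be computed once and reused for every customer.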

**Model Architecture for Temporal-Aware Heterogeneous Graph Neural Networks**

The following model implements a heterogeneous graph neural network with temporal awareness using PyTorch Geometric and PyTorch Frame.

It processes raw tabular data into node embeddings, incorporates temporal information, and performs neighbor aggregation through a GraphSAGE-inspired GNN architecture.
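
For intuition, a single GraphSAGE-style update with mean aggregation can be written in a few lines of plain Python. This is a sketch under simplifying assumptions (one node, one edge type, hand-picked weights); `sage_update` and all values are hypothetical and not the library implementation.

```python
from typing import List

Vector = List[float]


def mean(vectors: List[Vector], dim: int) -> Vector:
    if not vectors:
        return [0.0] * dim  # isolated node: empty neighborhood
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


def sage_update(h_self: Vector, h_neighbors: List[Vector], weight: List[Vector]) -> Vector:
    """One GraphSAGE-style update: ReLU(W @ concat(h_self, mean(h_neighbors)))."""
    concat = h_self + mean(h_neighbors, len(h_self))
    out = [sum(w * x for w, x in zip(row, concat)) for row in weight]
    return [max(0.0, value) for value in out]


# Hypothetical 2-dimensional states for a customer and two of its transactions.
h_customer = [1.0, -1.0]
h_transactions = [[2.0, 0.0], [0.0, 4.0]]

# Weight of shape [2, 4]; fixed by hand here, learned in a real model.
W = [
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, -1.0],
]

print(sage_update(h_customer, h_transactions, W))
```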

As in the RELBENCH paper, our modeling supports two representative predictive architectures.
1. **GraphSAGE [[Hamilton et al., 2017](https://proceedings.neurips.cc/paper/2017/hash/5dd9db5e033da9c6fb5ba83c7a7ebea9-Abstract.html)]**.
   * Each layer:
      * Concatenates the self and neighborhood states and applies a single trainable weight matrix to the result as the update function.
      * Samples a fixed number of neighbors at multiple distances and treats them as direct neighbors of node $v$.
   * It is trained with the *Bayesian Personalized Ranking (BPR) loss*.
2. **ID-GNN [[You et al., 2021](https://ojs.aaai.org/index.php/AAAI/article/view/17283)]**.
   * ID-GNN extends existing GNN architectures (GraphSAGE in our case) by inductively considering nodes' identities during message passing.
   * It uses an identity-coloring technique to distinguish the root node of each computational graph from the other nodes in its local neighborhood. The root node and the remaining nodes are updated with different sets of parameters, which increases expressive power.
   * It is trained with a standard *Binary Cross-Entropy loss*.
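
The two objectives can be compared with a minimal sketch in plain Python. BPR looks at pairs of scores (a purchased article versus a sampled negative), while binary cross-entropy treats each score as an independent link/no-link prediction. The function names are illustrative, and the code is not numerically hardened for extreme scores.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def bpr_loss(pos_scores: List[float], neg_scores: List[float]) -> float:
    """Mean of -log(sigmoid(pos - neg)) over (positive, negative) pairs."""
    losses = [-math.log(sigmoid(p - n)) for p, n in zip(pos_scores, neg_scores)]
    return sum(losses) / len(losses)


def bce_loss(scores: List[float], labels: List[int]) -> float:
    """Mean binary cross-entropy on raw scores (logits)."""
    losses = []
    for s, y in zip(scores, labels):
        p = sigmoid(s)
        losses.append(-(y * math.log(p) + (1 - y) * math.log(1.0 - p)))
    return sum(losses) / len(losses)


# A correctly ranked pair yields a smaller BPR loss than a wrongly ranked one.
print(bpr_loss([3.0], [0.0]), bpr_loss([0.0], [3.0]))
# BCE penalizes each score against its own label.
print(bce_loss([3.0, -3.0], [1, 0]))
```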


In [None]:
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder

class Model(torch.nn.Module):
    """
    Heterogeneous Graph Neural Network Model with Temporal Encoding.

    This model is designed for tasks on heterogeneous graphs with temporal information. It processes
    raw tabular data into embeddings, encodes temporal information, aggregates node features
    with a GNN, and produces task-specific predictions.

    Args:
        data (HeteroData): A heterogeneous graph object containing nodes, edges, and features.
        col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): Column statistics for raw tabular features.
            - This is a dictionary where:
                - `node_type` -> Feature name -> Statistic (e.g., mean, std).
                - Used by the HeteroEncoder for normalizing input features.
        num_layers (int): Number of layers in the GNN (depth of feature propagation).
        channels (int): Dimensionality of intermediate node embeddings (hidden layer size).
        out_channels (int): Dimensionality of output embeddings (final layer size).
        aggr (str): Aggregation function for GraphSAGE (e.g., "sum", "mean").
        norm (str): Normalization method applied in the MLP head (e.g., "batchnorm").
        shallow_list (List[NodeType], optional): List of node types to assign shallow embeddings to. Defaults to [].
        id_awareness (bool, optional): If True, adds a single shared embedding to the seed nodes for ID-awareness. Defaults to False.
    """

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        shallow_list: List[NodeType] = [],
        id_awareness: bool = False,
    ):
        super().__init__()

        """
        1. HeteroEncoder: Encoding Raw Tabular Features
        ---
        HeteroEncoder is responsible for converting raw tabular data from different node types into dense embeddings.

        - TYPE-AWARE FEATURE EMBEDDING
          - NUMERICAL FEATURES: These are passed through a linear layer (e.g., LinearEncoder), which maps them into a dense vector space.
          - CATEGORICAL FEATURES: These are embedded using an embedding layer (e.g., EmbeddingEncoder). For example,
            if 'sales_channel_id' has 5 categories, each category is mapped to a learned vector of size `channels`.
          - TIMESTAMP FEATURES: These are encoded with a timestamp encoder (e.g., TimestampEncoder) to add time sensitivity.
          - EMBEDDING FEATURES: These are already dense vectors, unlike raw categorical or numerical features. The HeteroEncoder processes
            them using a linear transformation layer (LinearEmbeddingEncoder), which adjusts their dimensionality.

        - AGGREGATION: The encoded features for each node type are aggregated by a ResNet tabular model. The ResNet combines raw feature
          embeddings and their interactions, producing a final dense embedding of size `channels` for the node type. This aggregated
          embedding is passed downstream to the GNN.

        - END-TO-END TRAINING: All these transformations are trainable and learned alongside the GNN, similar to an embedding layer
          in transformer models. The model optimizes the parameters of these encoders during backpropagation.

        - ROLE OF COLUMN STATISTICS (`col_stats_dict`): These are used to normalize the numerical features (e.g., subtract mean, divide by std).
          For example, 'price' may be normalized using its mean and std before being passed through LinearEncoder.
        """
        self.encoder = HeteroEncoder(
            channels=channels,  # target embedding dim
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },  # dictionary mapping from node type to column names
            node_to_col_stats=col_stats_dict,
        )

        """
        2. HeteroTemporalEncoder: Incorporating Time Sensitivity
        ---
        The HeteroTemporalEncoder is designed to encode time-related information for nodes that have temporal features.
        It ensures that the model respects the temporal ordering of events and can include relative time-based signals
        in the embeddings, which is critical for temporal tasks (e.g., predicting future behavior).

        - NODE TYPES WITH TIME FEATURES
          - The encoder processes only those node types where a "time" key is present in the `HeteroData` object.
          - For example, since the `transactions` node type has a time feature (see `list(data["transactions"].keys())`),
            it will be included in this temporal processing.

        - TEMPORAL ENCODING PROCESS:
          1. Compute **relative time differences**:
            - For each node, compute the difference between the `seed_time` (reference time for the batch) and the node's
              timestamp (`time_dict[node_type]`).
            - Convert the difference from seconds to days (scaling for interpretability and numerical stability).
          2. Apply **positional encoding**:
            - The relative time differences are passed through the `PositionalEncoding` module, producing dense embeddings.
          3. Apply **linear transformation**:
            - The embeddings from positional encoding are passed through a node-type-specific linear layer for further processing.
            - This ensures the temporal embeddings are adapted to the specific node type and task.
          4. Store the output:
            - The resulting temporal embeddings are stored in a dictionary (`out_dict`), keyed by node type.

        - OUTPUT:
          - A dictionary of temporal embeddings for each node type, with embeddings of size `channels`.

        - TRAINING:
          - The node-type-specific linear transformations are trainable, and gradients flow through them during
            backpropagation. The sinusoidal positional encoding itself is a fixed function without learnable parameters.

        - WHY IT MATTERS:
          - Temporal embeddings enable the model to capture time-sensitive patterns, such as sequential events or evolving
            relationships. By adding these embeddings to the base node embeddings, the model can better understand temporal
            dynamics within the graph.
        """
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )

        """
        3. HeteroGraphSAGE: Node Feature Aggregation in Heterogeneous Graphs
        ---
        The `HeteroGraphSAGE` module implements a heterogeneous version of the GraphSAGE model. It aggregates features from neighboring nodes
        using a series of message-passing layers (`HeteroConv`), followed by normalization (`LayerNorm`) and activation.

        - MAIN INPUTS:
          - **Node Features (`x_dict`)**:
            - A dictionary mapping each node type to its corresponding features.
            - Features are initially generated by the `HeteroEncoder` and optionally augmented with temporal embeddings.
          - **Edge Index (`edge_index_dict`)**:
            - A dictionary mapping each edge type to its corresponding sparse edge index tensor (i.e., graph connectivity).
            - Specifies how nodes of different types are connected in the graph.

        - ARCHITECTURE:
          1. **HeteroConv Layers (`self.convs`)**:
            - A stack of `HeteroConv` layers, each aggregating features from neighboring nodes across all edge types.
            - Each `HeteroConv` layer consists of:
              - Multiple `SAGEConv` layers (one per edge type), which perform the GraphSAGE aggregation for specific node pairs.
              - An aggregation function (`sum`) that combines the outputs of these `SAGEConv` layers into a unified embedding
                for each node type.
            - The number of layers is specified by `num_layers`, allowing multi-hop message passing.
          2. **Layer Normalization (`self.norms`)**:
            - A `LayerNorm` is applied to the aggregated features of each node type to stabilize training and improve convergence.
            - Each node type has its own normalization layer.
          3. **Non-Linearity (ReLU)**:
            - A ReLU activation is applied after normalization to introduce non-linearity into the learned embeddings.

        - MESSAGE-PASSING PROCESS:
          1. For each layer:
            - **Aggregation**:
              - Node embeddings are updated by aggregating features from neighboring nodes based on the graph structure.
              - This includes type-specific transformations via `SAGEConv` and cross-type aggregation via `HeteroConv`.
            - **Normalization**:
              - LayerNorm ensures that the embeddings for each node type are normalized.
            - **Non-Linearity**:
              - A ReLU activation is applied to add non-linear expressiveness.
          2. This process is repeated for `num_layers` steps, enabling multi-hop feature propagation across the graph.

        - TYPICAL GRAPHSAGE NEIGHBOR SAMPLING:
          - The module itself does not sample neighbors and takes no sampling configuration at construction time.
            Neighbor sampling, if performed, happens outside this module in the data loader, and the already-sampled
            subgraph is passed in at forward (call) time.
          - Optionally, `num_sampled_nodes_dict` and `num_sampled_edges_dict` provide information on sampled nodes and edges
            for each layer during mini-batch training, enabling efficient scaling to large graphs.

        - TRAINING:
          - All parameters in the `SAGEConv` layers and `LayerNorm` modules are trainable. Gradients flow back through the entire
            message-passing process during backpropagation.

        - OUTPUT:
          - A dictionary mapping each node type to its updated embeddings after `num_layers` of message-passing.

        - WHY IT MATTERS:
          - The `HeteroGraphSAGE` module enables the model to capture relationships and dependencies between nodes of different types
            (heterogeneous graph learning). By aggregating features from neighbors, the model can learn richer representations that
            incorporate both node attributes and graph structure.
        """
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )

        """
        4. MLP: Task-Specific Prediction Head
        ---
        The `MLP` (Multi-Layer Perceptron) from PyTorch Geometric is used as the final prediction head of the model.
        It transforms the node embeddings output by the `HeteroGraphSAGE` module into task-specific outputs.

        - INPUTS:
          - **Input Dimensionality (`channels`)**:
            - The dimensionality of the node embeddings output by the `HeteroGraphSAGE` module.
            - This represents the size of the features that the MLP processes.

          - **Output Dimensionality (`out_channels`)**:
            - The size of the final output. For example:
              - For node classification, this corresponds to the number of classes (e.g., `out_channels=4` for 4-way classification).
              - For regression tasks, this might correspond to the size of the predicted vector (e.g., `out_channels=1` for a single regression score).

          - **Normalization (`norm`)**:
            - Specifies the normalization method applied to the intermediate layers of the MLP.
            - Common options include `"batchnorm"` or `"layernorm"`. Normalization stabilizes training and accelerates convergence.

          - **Number of Layers (`num_layers`)**:
            - Controls the depth of the MLP. In this case, `num_layers=1`, meaning the MLP consists of a single linear layer.
              - PyG's `MLP` leaves its last layer plain by default (`plain_last=True`), so with `num_layers=1` no normalization
                or activation is applied and the head acts as a linear transformation.

        - TRAINABLE PARAMETERS:
          - All weights and biases of the MLP are trainable and are updated during backpropagation.

        - OUTPUT:
          - A tensor of size `[num_nodes, out_channels]`, where `num_nodes` corresponds to the number of nodes being predicted.

        - WHY IT MATTERS:
          - The `MLP` serves as the final layer that adapts the model's learned embeddings to the specific task at hand.
            It provides flexibility to handle a variety of prediction tasks with different output requirements.
        """
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )

        """
        5. Optional Shallow Embeddings for Selected Node Types
        ---
        These embeddings provide additional trainable representations for selected node types.
        Shallow embeddings are often used when certain node types require extra learnable parameters independent of their raw
        features or encoded embeddings.
        These embeddings are specific to each node type and are added to the node features before message passing.

        - INPUT:
          - `shallow_list`: A list of node types for which shallow embeddings are required.
          - `data.num_nodes_dict`: A dictionary mapping each node type to the number of nodes of that type.

        - IMPLEMENTATION:
          - For each node type in `shallow_list`, a separate embedding table is created using `torch.nn.Embedding`.
          - The size of the embedding for each node is `channels`, matching the size of other embeddings in the model.
          - These embeddings are trainable and initialized with random values.

        - USAGE:
          - During the forward pass, shallow embeddings are added directly to the node features for their corresponding node types.
          - This provides an additional trainable signal that can improve model performance in certain cases.
        """
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        """
        6. Optional ID-awareness Embedding
        ---
        ID-awareness in this model is not about assigning a unique embedding to each node.
        It serves as an additional, shared global feature vector added only to the embeddings of the "seed nodes" (the nodes for
        which predictions are being made).
        Thus, every seed node gets the same ID-awareness signal, which helps the model learn relationships better during training.
        """
        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Resets parameters for all model components."""
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        """
        Forward pass for node-level prediction tasks.

        Args:
            batch (HeteroData): Batch of sampled subgraphs.
            entity_table (NodeType): The node type for which predictions are made.

        Returns:
            Tensor: Predictions for the specified node type.
        """
        # Extract the seed time for temporal encoding
        seed_time = batch[entity_table].seed_time

        # Encode raw node features
        x_dict = self.encoder(batch.tf_dict)

        # Add temporal information to embeddings
        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        # Add shallow embeddings for specified node types
        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        # Apply the GNN to aggregate features
        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        # Return the final predictions for the specified node type
        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        """
        Forward pass with destination table readout for prediction tasks.

        Args:
            batch (HeteroData): Batch of sampled subgraphs.
            entity_table (NodeType): The node type for seed entities.
            dst_table (NodeType): The destination node type for predictions.

        Returns:
            Tensor: Predictions for the destination node type.

        Raises:
            RuntimeError: If `id_awareness` is not enabled.
        """
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )

        # Extract seed time for temporal encoding
        seed_time = batch[entity_table].seed_time

        # Encode raw node features
        x_dict = self.encoder(batch.tf_dict)

        # Add ID-awareness to the seed node embeddings
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        # Add temporal information to embeddings
        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        # Add shallow embeddings for specified node types
        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        # Apply the GNN to aggregate features
        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        # Return predictions for the destination node type
        return self.head(x_dict[dst_table])

<br>
<img src="https://drive.google.com/uc?export=view&id=1r9Xt3Zfi8Y-z8ECeRMgyIa_Y9awN3B-k" width="400">
<br>
<img src="https://drive.google.com/uc?export=view&id=1lLaGuXUjzoT1DPBbF04FTQ1uv2ALxhJ2" width="500">
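
The temporal encoding described in the model's docstrings (relative time in days, followed by a positional encoding) can be sketched in plain Python. The sinusoidal scheme below is the standard one; the exact frequencies and the `base_freq` default are assumptions and may differ from the library module.

```python
import math
from typing import List

SECONDS_PER_DAY = 60 * 60 * 24


def relative_time_in_days(seed_time: int, node_time: int) -> float:
    """Age of a node relative to the seed time, converted from seconds to days."""
    return (seed_time - node_time) / SECONDS_PER_DAY


def sinusoidal_encoding(value: float, channels: int, base_freq: float = 1e-4) -> List[float]:
    """Encode a scalar as [sin(value * f_i)..., cos(value * f_i)...]."""
    half = channels // 2
    # Frequencies go from 1 down to base_freq on a logarithmic scale.
    freqs = [base_freq ** (i / max(half - 1, 1)) for i in range(half)]
    return [math.sin(value * f) for f in freqs] + [math.cos(value * f) for f in freqs]


seed_time = 1704283200  # 2024-01-03 12:00:00 UTC
node_time = 1704110400  # 2024-01-01 12:00:00 UTC

days = relative_time_in_days(seed_time, node_time)
encoding = sinusoidal_encoding(days, channels=8)
print(days, len(encoding))
```

In the model, a vector of this kind is passed through a trainable linear layer and added to the node embedding of every node type that carries a timestamp.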

### Hyperparameters

Default values match those reported by the RELBENCH paper [[Source, Table 9](https://arxiv.org/pdf/2407.20060)].

In [None]:
from ipywidgets import VBox, HBox, Button, IntSlider, FloatSlider, Checkbox, Dropdown, Text, Label, Layout, HTML

# Default values (reorganized based on the appearance order in boxes)
default_params = {
    # Architecture
    "num_layers": 2,
    "channels": 128,
    "aggr": "sum",
    "norm": "layer_norm",

    # Data Loading
    "time_attr": "time",
    "temporal_strategy": "uniform",
    "batch_size": 512,
    "share_same_time": True,
    "num_neighbors": 128,
    "num_workers": 0,

    # Training
    "use_shallow": True,
    "lr": 0.001,
    "epochs": 20,
    "max_steps_per_epoch": 2000,
    "eval_epochs_interval": 1,
    "tune_metric": "link_prediction_map",
}

# Widget definitions (reorganized to match default_params and box order)
widgets = {
    # Architecture
    "num_layers": IntSlider(value=default_params["num_layers"], min=1, max=5, step=1,
                            description="Number of Layers:", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "channels": IntSlider(value=default_params["channels"], min=1, max=512, step=1,
                          description="Embedding Dimensions (Channels):", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "aggr": Dropdown(options=["sum", "mean", "max"], value=default_params["aggr"],
                     description="Aggregation Strategy for Neighbor Messages:", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "norm": Dropdown(options=["batch_norm", "layer_norm", "none"], value=default_params["norm"],
                     description="Prediction Head Normalization:", style={"description_width": "500px"}, layout=Layout(width="800px")),

    # Data Loading
    "time_attr": Text(value=default_params["time_attr"], description="Time Attribute Name for Temporal Neighbor Sampling:",
                      style={"description_width": "500px"}, layout=Layout(width="800px")),
    "temporal_strategy": Dropdown(options=["uniform", "recent", "future"], value=default_params["temporal_strategy"],
                                  description="Temporal Neighbor Sampling Strategy:", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "batch_size": IntSlider(value=default_params["batch_size"], min=1, max=1024, step=1,
                            description="Batch Size:", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "share_same_time": Checkbox(value=default_params["share_same_time"],
                                description="Only Batches where User Nodes have the Same Seed Time (No Shuffling)", style={"description_width": "300px"}, layout=Layout(width="800px")),
    "num_neighbors": IntSlider(value=default_params["num_neighbors"], min=1, max=256, step=1,
                               description="Max Number of Sampled Neighbors (100% Layer 1, Progressive 50% for Layers >1):", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "num_workers": IntSlider(value=default_params["num_workers"], min=0, max=4, step=1,
                             description="Number of Workers for Parallel Data Loading:", style={"description_width": "500px"}, layout=Layout(width="800px")),

    # Training
    "use_shallow": Checkbox(value=default_params["use_shallow"],
                            description="Use Shallow Embeddings for Article Nodes", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "lr": FloatSlider(value=default_params["lr"], min=0.0001, max=0.001, step=0.0001,
                      description="Learning Rate:", readout_format=".4f", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "epochs": IntSlider(value=default_params["epochs"], min=1, max=20, step=1,
                        description="Number of Epochs:", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "max_steps_per_epoch": IntSlider(value=default_params["max_steps_per_epoch"], min=100, max=5000, step=100,
                                     description="Max Steps per Epoch:", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "eval_epochs_interval": IntSlider(value=default_params["eval_epochs_interval"], min=1, max=10, step=1,
                                      description="Evaluation Interval:", style={"description_width": "500px"}, layout=Layout(width="800px")),
    "tune_metric": Dropdown(options=["link_prediction_map", "link_prediction_recall", "link_prediction_precision"], value=default_params["tune_metric"],
                             description="Tuning Metric for Best-Eval Checkpoint Saving:", style={"description_width": "500px"}, layout=Layout(width="800px")),
}

# Grouping architecture-related widgets
architecture_label = HTML(value="<b style='font-size:16px;'>Architecture</b>")
architecture_box = VBox([
    architecture_label,
    widgets["num_layers"],
    widgets["channels"],
    widgets["aggr"],
    widgets["norm"],
], layout=Layout(border="solid 1px black", padding="10px", margin="10px"))

# Grouping data-related widgets
data_label = HTML(value="<b style='font-size:16px;'>Data Loading</b>")
data_box = VBox([
    data_label,
    widgets["time_attr"],
    widgets["temporal_strategy"],
    widgets["batch_size"],
    widgets["share_same_time"],
    widgets["num_neighbors"],
    widgets["num_workers"],
], layout=Layout(border="solid 1px black", padding="10px", margin="10px"))

# Grouping training-related widgets
training_label = HTML(value="<b style='font-size:16px;'>Training</b>")
training_box = VBox([
    training_label,
    widgets["use_shallow"],
    widgets["lr"],
    widgets["epochs"],
    widgets["max_steps_per_epoch"],
    widgets["eval_epochs_interval"],
    widgets["tune_metric"],
], layout=Layout(border="solid 1px black", padding="10px", margin="10px"))

# Reset function
def reset_to_defaults(button):
    for key, widget in widgets.items():
        widget.value = default_params[key]

# Button for resetting to defaults
reset_button = Button(
    description="Reset to Default",
    layout=Layout(width="200px", height="40px")
)
reset_button.add_class("reset-button")
reset_button.on_click(reset_to_defaults)

# Main layout
form = VBox([architecture_box, data_box, training_box, reset_button])

# Display
form


In [None]:
args = {key: widget.value for key, widget in widgets.items()}
args["num_neighbors"] = [int(args["num_neighbors"] // 2**i) for i in range(args["num_layers"])]
args

{'num_layers': 2,
 'channels': 128,
 'aggr': 'sum',
 'norm': 'layer_norm',
 'time_attr': 'time',
 'temporal_strategy': 'uniform',
 'batch_size': 512,
 'share_same_time': True,
 'num_neighbors': [128, 64],
 'num_workers': 0,
 'use_shallow': True,
 'lr': 0.001,
 'epochs': 20,
 'max_steps_per_epoch': 2000,
 'eval_epochs_interval': 1,
 'tune_metric': 'link_prediction_map'}
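
The halving rule behind `num_neighbors` can be made explicit with a short sketch. The helper names are hypothetical, and the bound is only illustrative: it follows a single chain of edge types, whereas in a heterogeneous graph the budget typically applies to each edge type separately.

```python
from typing import List


def fanout_schedule(num_neighbors: int, num_layers: int) -> List[int]:
    """Halve the neighbor budget at every additional hop."""
    return [num_neighbors // 2**i for i in range(num_layers)]


def max_sampled_per_seed(fanouts: List[int]) -> int:
    """Upper bound on nodes reached from one seed along a single chain of edge types."""
    total, frontier = 0, 1
    for fanout in fanouts:
        frontier *= fanout
        total += frontier
    return total


fanouts = fanout_schedule(128, 2)
print(fanouts, max_sampled_per_seed(fanouts))
```

Shrinking the budget with depth keeps the sampled subgraph from growing too quickly as layers are added.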

### Data Loaders
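
Link-prediction loaders need, for every source node, the list of destination nodes it is linked to. A CSR (Compressed Sparse Row) layout stores exactly that. The following dependency-free sketch converts a toy purchase table into COO pairs and then into CSR; `coo_to_csr` and the `purchases` dictionary are hypothetical and only mirror the idea, not the library code.

```python
from typing import List, Tuple


def coo_to_csr(rows: List[int], cols: List[int], num_rows: int) -> Tuple[List[int], List[int]]:
    """Convert COO pairs into CSR row pointers and column indices."""
    counts = [0] * num_rows
    for r in rows:
        counts[r] += 1
    row_ptr = [0]
    for c in counts:
        row_ptr.append(row_ptr[-1] + c)
    # Stable sort by row keeps the original order of destinations within each row.
    order = sorted(range(len(rows)), key=lambda i: rows[i])
    col_idx = [cols[i] for i in order]
    return row_ptr, col_idx


# Hypothetical training table: each customer with the list of purchased articles.
purchases = {0: [1, 2], 1: [2, 3], 2: [1, 3]}

# Flatten the lists into COO pairs (customer, article).
rows = [c for c, articles in purchases.items() for _ in articles]
cols = [a for articles in purchases.values() for a in articles]

row_ptr, col_idx = coo_to_csr(rows, cols, num_rows=len(purchases))
print(row_ptr, col_idx)

# Destinations linked to customer 1 are a contiguous slice of col_idx.
print(col_idx[row_ptr[1]:row_ptr[2]])
```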

In [None]:
from relbench.modeling.graph import get_link_train_table_input, make_pkey_fkey_graph
from relbench.modeling.loader import LinkNeighborLoader

"""
Prepare the training table for the link prediction task.

This function processes the input table (`train_table`) and task metadata (`task`) to structure the data required for
link prediction into a `LinkTrainTableInput` object. It creates a well-defined structure for the source and destination
nodes, their relationships, and timestamps if available.

Specifically, the function computes the following:

1. **Source Nodes (`src_nodes`)**:
   - A tuple consisting of:
     - The source node type (`task.src_entity_table`), indicating the type of entity for the source nodes.
     - A tensor (`src_node_idx`) containing the indices of the source nodes. This tensor is derived from the column
       in `table.df` specified by `task.src_entity_col`.

   **Example:**
   - Suppose `task.src_entity_col = "customer_id"` and the table contains:
     ```
     customer_id  article_id          timestamp
     0            [1, 2]   2024-01-01 12:00:00
     1            [2, 3]   2024-01-02 12:00:00
     2            [1, 3]   2024-01-03 12:00:00
     ```
   - Then, `src_node_idx` is:
     ```
     src_node_idx = tensor([0, 1, 2])  # Indices for customers
     ```

2. **Destination Nodes (`dst_nodes`)**:
   - A tuple consisting of:
     - The destination node type (`task.dst_entity_table`), indicating the type of entity for the destination nodes.
     - A sparse CSR (Compressed Sparse Row) matrix (`dst_node_indices`) representing the links between source nodes
       and destination nodes. Each row corresponds to a source node, and non-zero entries in a row indicate linked
       destination nodes.

   **Example:**
   - Suppose `task.dst_entity_col = "article_id"` and the table contains the same data as above:
     ```
     customer_id  article_id          timestamp
     0            [1, 2]   2024-01-01 12:00:00
     1            [2, 3]   2024-01-02 12:00:00
     2            [1, 3]   2024-01-03 12:00:00
     ```
   - After flattening and converting to COO and then CSR:
     ```
     COO Representation:
     Row (source): [0, 0, 1, 1, 2, 2]
     Col (dest):   [1, 2, 2, 3, 1, 3]

     CSR Representation:
     dst_node_indices =
       (row pointers) [0, 2, 4, 6]  # Indices where rows start and end
       (column indices) [1, 2, 2, 3, 1, 3]  # Linked destination nodes
     ```

3. **Number of Destination Nodes (`num_dst_nodes`)**:
   - The total number of unique destination nodes (e.g., all articles in a recommendation system).
   - In the above example, `num_dst_nodes = 4` because article indices are zero-based (0 to 3), even though index 0 has no link here.

4. **Source Timestamps (`src_time`)**:
   - Temporal information associated with the source nodes.
   - Converted to Unix time:
     ```
     src_time = tensor([1704110400, 1704196800, 1704283200])  # Unix timestamps
     ```

This structured input enables efficient neighbor sampling and subgraph preparation for training link prediction models.
"""
table_input = get_link_train_table_input(
    table=train_table,
    task=task,
)

"""
Create a data loader for training the link prediction model using neighbor sampling.

The `LinkNeighborLoader` is a custom data loader designed for temporal heterogeneous graphs, specifically for link prediction tasks.
It samples subgraphs containing the neighbors of source and destination nodes while adhering to temporal and structural constraints.
The key arguments are:

1. **Graph Data (`data`)**:
   - The input graph, a `HeteroData` object, containing node types, edge types, features, and timestamps.

2. **Number of Neighbors (`num_neighbors`)**:
   - Defines the maximum number of neighbors to sample at each GNN layer.
   - Can be specified globally as a list (e.g., `[10, 5]` for two layers) or per edge type as a dictionary.

3. **Temporal Attribute (`time_attr`)**:
   - Specifies the column in `data` that contains timestamp information.
   - Used for temporal neighbor sampling to ensure only past or relevant neighbors are included, avoiding time leakage.

4. **Source Nodes (`src_nodes`)**:
5. **Destination Nodes (`dst_nodes`)**:
6. **Number of Destination Nodes (`num_dst_nodes`)**:
7. **Source Timestamps (`src_time`)**:
   - All four are taken from `table_input`, the output of `get_link_train_table_input` above.

8. **Shared Time Context (`share_same_time`)**:
   - If `True`, ensures all nodes in a mini-batch share the same seed time, enabling uniform temporal context for predictions.
   - If enabled, `shuffle` is automatically set to `False` since shuffling disrupts the timestamp alignment.

9. **Batch Size (`batch_size`)**:
   - The number of source-destination node pairs to process per mini-batch.

10. **Temporal Strategy (`temporal_strategy`)**:
    - Specifies the method for temporal neighbor sampling:
      - `"uniform"`: Uniform sampling of neighbors constrained by temporal validity.
      - Other strategies may emphasize different temporal patterns.

11. **Shuffle (`shuffle`)**:
    - Controls whether the input node pairs are shuffled. If `share_same_time=True`, shuffling must be disabled to
      maintain temporal alignment.

12. **Number of Workers (`num_workers`)**:
    - Specifies the number of workers for parallel data loading. Increasing this value can improve performance
      for large graphs, but requires careful resource management.

**Purpose**:
This loader dynamically samples the subgraphs needed for each training batch, incorporating structural and temporal
constraints. It ensures scalability for large graphs by limiting the neighborhood size and focusing on the nodes
relevant for each source-destination pair.

**Output**:
Each batch from the loader includes:
- Source subgraph (pre-computed sampled neighborhood of source nodes in the batch, to use during message-passing).
- Positive destination subgraph (sampled neighborhoods for positively linked destination nodes).
- Negative destination subgraph (sampled neighborhoods for unlinked destination nodes).
"""
train_loader = LinkNeighborLoader(
    data,  # Input graph data
    num_neighbors=args["num_neighbors"],  # Number of neighbors to sample per layer
    time_attr=args["time_attr"],  # Timestamp column for temporal sampling
    src_nodes=table_input.src_nodes,  # Source nodes (type and indices)
    dst_nodes=table_input.dst_nodes,  # Destination nodes (type and sparse CSR matrix)
    num_dst_nodes=table_input.num_dst_nodes,  # Total number of destination nodes
    src_time=table_input.src_time,  # Source node timestamps
    share_same_time=args["share_same_time"],  # Shared temporal context within a batch
    batch_size=args["batch_size"],  # Number of source-destination pairs per batch
    temporal_strategy=args["temporal_strategy"],  # Temporal sampling strategy
    shuffle=not args["share_same_time"],  # Shuffle only if shared time is disabled
    num_workers=args["num_workers"],  # Number of parallel workers for data loading
)

Loading Database object from /root/.cache/relbench/rel-hm/db...
Done in 1.39 seconds.


  dst_node_indices = sparse_coo.to_sparse_csr()
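
To make the COO-to-CSR step described in the docstring above concrete, the following minimal NumPy sketch rebuilds the row pointers for the three-customer example. It only illustrates the storage format; in the notebook the conversion happens inside `get_link_train_table_input`. The helper name `coo_to_csr` is hypothetical.

```python
import numpy as np

def coo_to_csr(rows, cols, num_rows):
    """Convert COO (row, col) pairs into CSR (row pointers, column indices)."""
    order = np.argsort(rows, kind="stable")          # group entries by source row
    counts = np.bincount(rows, minlength=num_rows)   # number of links per source row
    crow = np.concatenate(([0], np.cumsum(counts)))  # row i spans crow[i]:crow[i + 1]
    return crow, cols[order]

# Same links as in the docstring example
rows = np.array([0, 0, 1, 1, 2, 2])
cols = np.array([1, 2, 2, 3, 1, 3])
crow, col_idx = coo_to_csr(rows, cols, num_rows=3)

# Articles linked to customer 1 are stored in col_idx[crow[1]:crow[2]]
articles_of_customer_1 = col_idx[crow[1]:crow[2]]
```

The row pointers make it cheap to look up all positive destinations of a given source row, which is what the loader needs when it draws a positive article for each sampled customer.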


**🧸 TOY EXAMPLE FOR BETTER UNDERSTANDING**


```
import torch
from torch_geometric.data import HeteroData
from relbench.modeling.loader import LinkNeighborLoader  # RelBench loader (accepts src_nodes/dst_nodes)

# Create a sample heterogeneous graph:
data = HeteroData()
data['customer'].x = torch.randn(5, 16)  # 5 customers, 16 features
data['article'].x = torch.randn(6, 16)  # 6 articles, 16 features
data['customer', 'buys', 'article'].edge_index = torch.tensor(
    [[0, 1, 2, 3],  # Source nodes
     [1, 2, 0, 4]]  # Destination nodes
)

# Add temporal attributes:
data['customer', 'buys', 'article'].time = torch.tensor([100, 200, 300, 400])

# Initialize the loader:
loader = LinkNeighborLoader(
    data=data,
    num_neighbors=[2, 1],  # Sample 2 neighbors at layer 1, 1 neighbor at layer 2
    src_nodes=("customer", torch.tensor([0, 1, 2])),
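    # Positive links as a sparse CSR matrix: row i lists the articles bought by the i-th source node
    dst_nodes=("article", torch.sparse_coo_tensor(
        torch.tensor([[0, 1, 2], [1, 2, 0]]), torch.ones(3, dtype=torch.bool), (3, 6)
    ).to_sparse_csr()),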
    num_dst_nodes=6,  # Total number of articles
    src_time=torch.tensor([100, 200, 300]),
    batch_size=2,
    time_attr="time",
    temporal_strategy="uniform",
    share_same_time=False,
    num_workers=0,
)
```

$↓$

```
Source Subgraph:
HeteroData(
  customer={
    x=[2, 16], edge_index=[2, 4], time=[4]
  }
)

Positive Destination Subgraph:
HeteroData(
  article={
    x=[2, 16], edge_index=[2, 2], time=[2]
  }
)

Negative Destination Subgraph:
HeteroData(
  article={
    x=[2, 16], edge_index=[2, 2], time=[2]
  }
)
```

- **Source Subgraph**: contains the 2 sampled source nodes (`customer`) with their features, neighbors, and timestamps.
- **For each source node in a batch, we have ONE positive node and ONE negative node.**

In [None]:
from typing import Dict, Tuple
from torch_geometric.loader import NeighborLoader

"""
Prepare evaluation data loaders for validation and testing using `NeighborLoader`.

Key Details:
- Temporal Consistency:
  - Both loaders are initialized with `seed_time` to ensure all sampled neighbors respect the evaluation timestamp,
    avoiding future data leakage during validation and testing.
- Source vs. Destination Loaders:
  - The source loader processes the nodes for which predictions are made.
  - The destination loader processes the target (gold) nodes in the link prediction task.
- Output:
  - For each split (`val`, `test`), two loaders (`src_loader` and `dst_loader`) are available for evaluation tasks.
"""
eval_loaders_dict: Dict[str, Tuple[NeighborLoader, NeighborLoader]] = {}
for split in ["val", "test"]:

    # 1. Get the timestamp for the evaluation split
    timestamp = dataset.val_timestamp if split == "val" else dataset.test_timestamp
    seed_time = int(timestamp.timestamp())  # Convert to Unix time

    # 2. Retrieve the target table and source node indices
    target_table = task.get_table(split)
    src_node_indices = torch.from_numpy(target_table.df[task.src_entity_col].values)

    # 3. Create the source node loader
    src_loader = NeighborLoader(
        data,  # Graph data
        num_neighbors=args["num_neighbors"],  # Number of neighbors to sample per layer
        time_attr=args["time_attr"],  # Temporal attribute for neighbor sampling
        input_nodes=(task.src_entity_table, src_node_indices),  # Source node type and indices
        input_time=torch.full(  # Seed time tensor for all source nodes
            size=(len(src_node_indices),), fill_value=seed_time, dtype=torch.long
        ),
        batch_size=args["batch_size"],  # Batch size for source nodes
        shuffle=False,  # Keep the input order so predictions align with the rows of the target table
        num_workers=args["num_workers"],  # Parallel workers for loading
    )

    # 4. Create the destination node loader
    dst_loader = NeighborLoader(
        data,  # Graph data
        num_neighbors=args["num_neighbors"],  # Number of neighbors to sample per layer
        time_attr=args["time_attr"],  # Temporal attribute for neighbor sampling
        input_nodes=task.dst_entity_table,  # Destination node type
        input_time=torch.full(  # Seed time tensor for all destination nodes
            size=(task.num_dst_nodes,), fill_value=seed_time, dtype=torch.long
        ),
        batch_size=args["batch_size"],  # Batch size for destination nodes
        shuffle=False,  # Keep the node order so embedding row i corresponds to destination node i
        num_workers=args["num_workers"],  # Parallel workers for loading
    )

    # 5. Store loaders in the dictionary for the current split
    eval_loaders_dict[split] = (src_loader, dst_loader)

### Initialize the Model, Train and Evaluate

In [None]:
# Initialize the model for link prediction
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=args["num_layers"],
    channels=args["channels"],
    out_channels=1,  # A single scalar indicating the likelihood of the customer interacting with the article
    aggr=args["aggr"],
    norm=args["norm"],
    shallow_list=[task.dst_entity_table] if args["use_shallow"] else [],
).to(device)

# Initialize the optimizer for training
optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"])

We utilize the standard **Bayesian Personalized Ranking loss (w/ mini-batches, one negative only)** for training.
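
As a rough illustration of the loss used in the training cell, the sketch below computes the BPR objective `softplus(-(pos - neg))`, which equals `-log(sigmoid(pos - neg))`. NumPy stands in for PyTorch here, and the embeddings are random placeholders rather than model outputs; the function name `bpr_loss` is hypothetical.

```python
import numpy as np

def bpr_loss(pos_score, neg_score):
    """Mean of softplus(-(pos - neg)) over all compared (positive, negative) pairs."""
    diff = pos_score - neg_score
    return np.logaddexp(0.0, -diff).mean()  # numerically stable softplus(-diff)

rng = np.random.default_rng(0)
x_src = rng.normal(size=(4, 8))  # placeholder source embeddings
x_pos = rng.normal(size=(4, 8))  # placeholder positive destination embeddings
x_neg = rng.normal(size=(4, 8))  # placeholder negative destination embeddings

pos = (x_src * x_pos).sum(axis=1)  # [4]

# One negative per source node (the `share_same_time=False` branch)
loss_pairwise = bpr_loss(pos, (x_src * x_neg).sum(axis=1))

# Every source scored against every negative in the batch (the `share_same_time=True` branch):
# [4, 1] - [4, 4] broadcasts to a [4, 4] matrix of score differences
loss_shared = bpr_loss(pos[:, None], x_src @ x_neg.T)
```

The loss approaches 0 when positive scores exceed negative scores by a wide margin and equals $\log 2$ when both scores are identical.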

In [None]:
import warnings
import torch.nn.functional as F
from torch import Tensor  # Used in the type hints of `test`
import numpy as np
from tqdm import tqdm

def train() -> float:
    """
    Train the model for one epoch using Bayesian Personalized Ranking (BPR) loss.

    The function iterates over the training data, computes positive and negative scores for link prediction,
    and updates the model parameters using gradient descent. The BPR loss encourages the model to assign
    higher scores to positive links than to negative links.

    Returns:
        float: The average training loss for the epoch.
    """
    model.train()  # Set the model to training mode

    # Initialize accumulators for loss and count
    loss_accum = count_accum = 0
    steps = 0

    # Determine the maximum number of training steps
    total_steps = min(len(train_loader), args["max_steps_per_epoch"])

    for batch in tqdm(train_loader, total=total_steps):  # Iterate through batches with progress bar

        # Unpack the batch into source, positive destination, and negative destination
        src_batch, batch_pos_dst, batch_neg_dst = batch
        src_batch, batch_pos_dst, batch_neg_dst = (
            src_batch.to(device),
            batch_pos_dst.to(device),
            batch_neg_dst.to(device),
        )

        # Compute embeddings for source, positive, and negative destination nodes
        x_src = model(src_batch, task.src_entity_table)  # Source embeddings
        x_pos_dst = model(batch_pos_dst, task.dst_entity_table)  # Positive destination embeddings
        x_neg_dst = model(batch_neg_dst, task.dst_entity_table)  # Negative destination embeddings

        # Compute positive scores (dot product between source and positive destination embeddings)
        pos_score = torch.sum(x_src * x_pos_dst, dim=1)  # [batch_size]

        if args["share_same_time"]:
            # Compute negative scores as a matrix product for the shared time context
            neg_score = x_src @ x_neg_dst.t()  # [batch_size, batch_size]
            pos_score = pos_score.view(-1, 1)  # Reshape positive scores to [batch_size, 1]
        else:
            # Compute negative scores as a dot product for individual pairs
            neg_score = torch.sum(x_src * x_neg_dst, dim=1)  # [batch_size]

        optimizer.zero_grad()  # Reset gradients

        # Compute the Bayesian Personalized Ranking (BPR) loss
        diff_score = pos_score - neg_score  # Difference between positive and negative scores
        loss = F.softplus(-diff_score).mean()  # BPR loss with softplus activation
        loss.backward()  # Backpropagate gradients
        optimizer.step()  # Update model parameters

        # Accumulate the loss and count
        loss_accum += float(loss) * x_src.size(0)  # Weighted by batch size
        count_accum += x_src.size(0)  # Count the number of samples

        steps += 1
        if steps >= args["max_steps_per_epoch"]:  # Stop once the step budget is reached
            break

    # Warn if no valid destination nodes were sampled
    if count_accum == 0:
        warnings.warn(
            f"Did not sample a single '{task.dst_entity_table}' "
            f"node in any mini-batch. Try to increase the number "
            f"of layers/hops and re-try. If you run into memory "
            f"issues with deeper nets, decrease the batch size."
        )

    # Return the average loss for the epoch
    return loss_accum / count_accum if count_accum > 0 else float("nan")


@torch.no_grad()
def test(src_loader: NeighborLoader, dst_loader: NeighborLoader) -> np.ndarray:
    """
    Evaluate the model on a validation or test set using top-k predictions.

    The function computes embeddings for all destination nodes, then evaluates the top-k predicted links
    for each source node based on dot-product similarity between embeddings.

    Args:
        src_loader (NeighborLoader): Data loader for source nodes.
        dst_loader (NeighborLoader): Data loader for destination nodes.

    Returns:
        np.ndarray: An array of indices corresponding to the top-k predicted destination nodes
                    for each source node.
    """
    model.eval()  # Set the model to evaluation mode

    dst_embs: list[Tensor] = []  # List to accumulate destination node embeddings

    # Compute embeddings for all destination nodes
    for batch in tqdm(dst_loader):  # Iterate through the destination loader with a progress bar
        batch = batch.to(device)
        emb = model(batch, task.dst_entity_table).detach()  # Compute and detach embeddings
        dst_embs.append(emb)
    dst_emb = torch.cat(dst_embs, dim=0)  # Concatenate all destination embeddings
    del dst_embs  # Free memory

    pred_index_mat_list: list[Tensor] = []  # List to accumulate top-k indices

    # Compute top-k predictions for source nodes
    for batch in tqdm(src_loader):  # Iterate through the source loader
        batch = batch.to(device)
        emb = model(batch, task.src_entity_table)  # Compute source node embeddings
        # Compute dot-product similarity and retrieve top-k predictions
        _, pred_index_mat = torch.topk(emb @ dst_emb.t(), k=task.eval_k, dim=1)
        pred_index_mat_list.append(pred_index_mat.cpu())  # Move indices to CPU and store them

    # Concatenate all top-k predictions and convert to NumPy
    pred = torch.cat(pred_index_mat_list, dim=0).numpy()
    return pred  # Return the array of top-k predictions
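
The retrieval step inside `test` can be pictured with a small NumPy sketch: score every source against every destination with a dot product, then keep the indices of the $k$ highest scores. The embeddings below are made-up values chosen for readability, and `top_k_indices` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def top_k_indices(src_emb, dst_emb, k):
    """Return, for each source row, the indices of the k highest dot-product scores."""
    scores = src_emb @ dst_emb.T  # [num_src, num_dst]
    return np.argsort(-scores, axis=1, kind="stable")[:, :k]

src_emb = np.array([[1.0, 0.0],
                    [0.0, 1.0]])
dst_emb = np.array([[0.9, 0.1],
                    [0.2, 0.8],
                    [0.5, 0.5],
                    [-1.0, -1.0]])

pred = top_k_indices(src_emb, dst_emb, k=2)
print(pred)  # one row of destination indices per source node
```

Because the returned values are row positions in `dst_emb`, the destination embeddings must be stored in node-index order for the predictions to be meaningful.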

To recap, our task is to predict a list of top $K$ target entities given a source entity at a given seed time. The metric we use is **Mean Average Precision (MAP) @ $K$**.

MAP is a commonly used evaluation metric in recommender systems, particularly for ranking-based tasks like item recommendations or link prediction. It measures the quality of the ordered list of recommended items by **considering both relevance and the ranking position of relevant items**.

Key concepts:

* **Precision:** Precision at position $i$ of the ranked list is the proportion of relevant items among the top-$i$ recommendations:
$\text{Precision}@i=\frac{\text{Number of relevant items in top-}i}{i}$
* **Average Precision (AP):** For a single user, AP averages the precision values at all positions where a relevant item is found:
$\text{AP}@k=\frac{\sum_{i=1}^k\text{Precision}@i \cdot \text{relevance}(i)}{\text{Total number of relevant items in the top-}k}$. Here, $\text{relevance}(i)=1$ if the item at position $i$ is relevant, 0 otherwise.
* **Mean Average Precision (MAP):** MAP aggregates the AP values over all users and computes the mean: $\text{MAP}=\frac{1}{|U|}\sum_{u \in U}\text{AP}_u$. Here, $U$ is the set of all users, and $\text{AP}_u$ is the Average Precision for user $u$.

<img src="https://drive.google.com/uc?export=view&id=1KZtjMvjcBugLB-CijLR59ynqClXYBAOx" width="550">

MAP ranges **between 0 and 1, inclusive (the higher the better)**.
* $\text{MAP}=1$. Every relevant item for each user is included in the recommendation list, and all relevant items are ranked at the top.
* $\text{MAP}=0$. No relevant items are present in the recommendation lists for any user.
* A low MAP value (e.g., 0.2) means the system struggles to rank relevant items effectively, often placing them lower in the recommendation list or not recommending them at all.
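
The definitions above can be sketched in a few lines of plain Python. This is a minimal illustration that follows the formula as written in this section (dividing by the number of relevant items found in the top-$k$); implementations differ in the normalizer, and some divide by $\min(\text{number of relevant items}, k)$ instead. The function names and the toy rankings are hypothetical.

```python
def average_precision_at_k(ranked, relevant, k):
    """AP@k: mean of Precision@i over the positions i <= k that hold a relevant item."""
    hits = 0
    precision_sum = 0.0
    for i, item in enumerate(ranked[:k], start=1):
        if item in relevant:
            hits += 1
            precision_sum += hits / i  # Precision@i at a relevant position
    return precision_sum / hits if hits else 0.0

def map_at_k(rankings, relevants, k):
    """MAP@k: mean of AP@k over all users."""
    aps = [average_precision_at_k(r, rel, k) for r, rel in zip(rankings, relevants)]
    return sum(aps) / len(aps)

# User 1 has relevant items at ranks 1 and 3; user 2 has no relevant item in the list
rankings = [["a", "b", "c", "d"], ["x", "y", "z", "w"]]
relevants = [{"a", "c"}, {"q"}]

ap_user_1 = average_precision_at_k(rankings[0], relevants[0], k=4)  # (1/1 + 2/3) / 2
map_score = map_at_k(rankings, relevants, k=4)
```

For user 1 the relevant items sit at ranks 1 and 3, so $\text{AP}@4 = (1 + 2/3)/2 \approx 0.83$; user 2 contributes 0, which halves the mean.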

In RELBENCH, the value of $K$ is set per task.
Our task, `UserItemPurchaseTask`, uses **$K=12$** [[Source](https://github.com/snap-stanford/relbench/blob/6bcb12a94b163c52e01cc272dfd4817cd13eff69/relbench/tasks/hm.py#L19)].

In [None]:
# CHECK THE VALUE OF K THAT WILL BE USED FOR MAP EVALUATION
task.eval_k

12

In [None]:
# BEFORE TRAINING
import json
test_pred = test(*eval_loaders_dict["test"])
test_metrics = task.evaluate(test_pred, target_table=test_table)
print(f"Best test metrics: \n{json.dumps(test_metrics, indent=2)}")

100%|██████████| 207/207 [00:09<00:00, 22.24it/s]
100%|██████████| 132/132 [00:06<00:00, 19.95it/s]


Best test metrics: 
{
  "link_prediction_precision": 1.2411136264347273e-06,
  "link_prediction_recall": 4.964454505738909e-06,
  "link_prediction_map": 6.205568132173637e-07
}


🕒 *Training requires $\approx$13 minutes per epoch.*

In [None]:
import copy
import gdown

load_pretrained = False

if load_pretrained:

    # The model has been pretrained with default hyperparameters
    # Ensure the model is compatible before loading the checkpoint
    checkpoint_url = "https://drive.google.com/uc?id=16b6z77S-p9LNAOV9PL-19ICfK-VaosJn"
    gdown.download(checkpoint_url, "./relbench_hm_graphsage_checkpoint.pt", quiet=False)
    state_dict = torch.load("./relbench_hm_graphsage_checkpoint.pt", map_location=device)  # Load onto the active device
    model.load_state_dict(state_dict)

else:

    state_dict = None
    best_val_metric = 0
    for epoch in range(1, args["epochs"] + 1):
        train_loss = train()
        if epoch % args["eval_epochs_interval"] == 0:
            val_pred = test(*eval_loaders_dict["val"])
            val_metrics = task.evaluate(val_pred, target_table=val_table)
            print(
                f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
                f"Val metrics: \n{json.dumps(val_metrics, indent=2)}"
            )
            if val_metrics[args["tune_metric"]] >= best_val_metric:
                best_val_metric = val_metrics[args["tune_metric"]]
                state_dict = copy.deepcopy(model.state_dict())
    model.load_state_dict(state_dict)
    # Save the checkpoint
    torch.save(state_dict, "/content/relbench_hm_gnn.pt")

val_pred = test(*eval_loaders_dict["val"])
val_metrics = task.evaluate(val_pred, target_table=val_table)
print(f"Best Val metrics: {val_metrics}")

Epoch: 01, Train loss: 0.1087472868700703, Val metrics:
{'link_prediction_precision': 0.005003911051514135, 'link_prediction_recall': 0.01959143846237254, 'link_prediction_map': 0.005726242268206763}

Epoch: 02, Train loss: 0.08682966694935806, Val metrics:
{'link_prediction_precision': 0.0056833165716839865, 'link_prediction_recall': 0.023962679731169785, 'link_prediction_map': 0.007834079957611971}

Epoch: 03, Train loss: 0.08096918152242169, Val metrics:
{'link_prediction_precision': 0.005184936864454128, 'link_prediction_recall': 0.021360978176741836, 'link_prediction_map': 0.007129379139913946}

Epoch: 04, Train loss: 0.0776181455900793, Val metrics:
{'link_prediction_precision': 0.0059079226729243485, 'link_prediction_recall': 0.02450972418118223, 'link_prediction_map': 0.00784053437725389}

Epoch: 05, Train loss: 0.07522295702663974, Val metrics:
{'link_prediction_precision': 0.005837523745669906, 'link_prediction_recall': 0.025018195532812256, 'link_prediction_map': 0.0077185426961791734}

Epoch: 06, Train loss: 0.07302185772978205, Val metrics:
{'link_prediction_precision': 0.006036428651245948, 'link_prediction_recall': 0.025450903598061386, 'link_prediction_map': 0.007799122365245512}

Epoch: 07, Train loss: 0.07127932919704098, Val metrics:
{'link_prediction_precision': 0.00608112638283607, 'link_prediction_recall': 0.025176688056008544, 'link_prediction_map': 0.0077879343088768285}

Epoch: 08, Train loss: 0.06931001383779586, Val metrics:
{'link_prediction_precision': 0.006159347413118783, 'link_prediction_recall': 0.025842584524293773, 'link_prediction_map': 0.00799117079994164}

Epoch: 09, Train loss: 0.0682670115169765, Val metrics:
{'link_prediction_precision': 0.006333668566320258, 'link_prediction_recall': 0.02689625655466527, 'link_prediction_map': 0.008352073392629328}

Epoch: 10, Train loss: 0.06654648913227874, Val metrics:
{'link_prediction_precision': 0.006262152195776064, 'link_prediction_recall': 0.026929272878504707, 'link_prediction_map': 0.00853130592736636}

Epoch: 11, Train loss: 0.06582279922603548, Val metrics:
{'link_prediction_precision': 0.006080008939546317, 'link_prediction_recall': 0.02601899683266872, 'link_prediction_map': 0.00820869170336041}

Epoch: 12, Train loss: 0.06526682509490099, Val metrics:
{'link_prediction_precision': 0.006180578835624091, 'link_prediction_recall': 0.02633966436163992, 'link_prediction_map': 0.008524656209741353}

Epoch: 13, Train loss: 0.06395177540836157, Val metrics:
{'link_prediction_precision': 0.006586210749804446, 'link_prediction_recall': 0.027832192113045016, 'link_prediction_map': 0.009010682303600543}

Epoch: 14, Train loss: 0.06328508630022593, Val metrics:
{'link_prediction_precision': 0.006080008939546318, 'link_prediction_recall': 0.026080306645656606, 'link_prediction_map': 0.008498832402539992}

Epoch: 15, Train loss: 0.06332891656440505, Val metrics:
{'link_prediction_precision': 0.006514694379260252, 'link_prediction_recall': 0.028010604181663984, 'link_prediction_map': 0.008871537800073928}

Epoch: 16, Train loss: 0.06318439393751565, Val metrics:
{'link_prediction_precision': 0.0065649793272991395, 'link_prediction_recall': 0.02800220728207887, 'link_prediction_map': 0.009205209084666452}

Epoch: 17, Train loss: 0.06196682529895381, Val metrics:
{'link_prediction_precision': 0.006614146832048273, 'link_prediction_recall': 0.02863960764112953, 'link_prediction_map': 0.008910399061920309}

Epoch: 18, Train loss: 0.06157738820224747, Val metrics:
{'link_prediction_precision': 0.0063873058442284044, 'link_prediction_recall': 0.027685218907965155, 'link_prediction_map': 0.008744753840983485}

Epoch: 19, Train loss: 0.061020472694350326, Val metrics:
{'link_prediction_precision': 0.006463291987931612, 'link_prediction_recall': 0.027365864870720256, 'link_prediction_map': 0.008493917513842326}

Epoch: 20, Train loss: 0.06088389920769841, Val metrics:
{'link_prediction_precision': 0.006218571907475695, 'link_prediction_recall': 0.02645116573418713, 'link_prediction_map': 0.00831022733622273}

**Best Val metrics: {'link_prediction_precision': 0.006602972399150742, 'link_prediction_recall': 0.028296511136977648, 'link_prediction_map': 0.009285367460726964}**

**Best test metrics: {'link_prediction_precision': 0.005721533817864092, 'link_prediction_recall': 0.025377126447581024, 'link_prediction_map': 0.008338150296773754}**

In [None]:
# AFTER TRAINING
import json
test_pred = test(*eval_loaders_dict["test"])
test_metrics = task.evaluate(test_pred, target_table=test_table)
print(f"Best test metrics: \n{json.dumps(test_metrics, indent=2)}")

## ⚔️ Baselines

RELBENCH baseline description:

> Despite the importance of relational databases, the rich relational information is typically foregone, as no model architecture is capable of handling varied database structures. Instead, **data is "flattened" into a simpler format such as a single table**, often by manual feature engineering, **on which standard tabular models can be used**.



### Gradient-Boosted Trees Recap

Gradient boosting is one of the most popular machine learning algorithms for tabular datasets. It is flexible enough to capture complex non-linear relationships between the features and the model target.

Gradient boosting is an ensemble method in which **you create multiple weak models and combine them to get better performance as a whole**.

Suppose we have a non-linear relationship between a feature $x$ and a target $y$.

* *Iteration 0*. A first, very naive predictor ($F_0$) could be the overall average of $y$.
* *Iteration 1*.
  * To improve our prediction, we could focus on the **residuals** (i.e., prediction errors shown as vertical blue lines in the figure below) from the first step because this is what we want **to minimize** to get a better prediction.
  * To minimize these residuals, we can build a **regression tree model** with $x$ as its feature and the residuals $r_1 = y - F_0 = y - \text{mean}(y)$ as its target. For example:

  <img src="https://drive.google.com/uc?export=view&id=1DP0RY32Kk7ezc7s5-o9AKmnvpxXrFsJA" width="550">

  * The first predictor's outputs are then corrected by adding the outputs $\gamma_1$ of this second predictor, which reduces the residuals: $F_1 = F_0 + \gamma_1$.
* ...

By repeating this process, the combined prediction $F_m$ gets closer and closer to the target $y$.

<img src="https://drive.google.com/uc?export=view&id=1avR1wH3W7l7qA5HnMux0VY522WcFVrX9" width="650">

In summary, in the gradient-boosted trees algorithm, we iterate the following:
* We train a tree on the errors made at the previous iteration.
* We add the tree to the ensemble, and we predict with the new model.
* We compute the errors made for this iteration.

<img src="https://drive.google.com/uc?export=view&id=1PzE_lgfwm3xQ6R8reSdAgoT0Qwla8_8j" width="650">

Typically, to prevent overfitting, we apply **shrinkage**: we multiply the contribution of each new tree by a small factor (the learning rate) so that no single tree has too much impact on the overall prediction.

<img src="https://drive.google.com/uc?export=view&id=1Pc-xGUsD3u6bDC8WFwnVK43GJXTTwvwq" width="650">
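
The boosting loop and the shrinkage factor described above can be sketched with NumPy alone. This is a deliberately simplified illustration for squared-error regression on one feature, using depth-1 trees (stumps); it is not how XGBoost or LightGBM are implemented. The names `fit_stump`, `gradient_boost`, and the synthetic data are hypothetical.

```python
import numpy as np

def fit_stump(x, r):
    """Fit a depth-1 regression tree (a single split) to the residuals r."""
    best = None
    for t in np.unique(x)[:-1]:  # candidate thresholds (both sides stay non-empty)
        left, right = r[x <= t], r[x > t]
        sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if best is None or sse < best[0]:
            best = (sse, t, left.mean(), right.mean())
    _, t, left_value, right_value = best
    return lambda q: np.where(q <= t, left_value, right_value)

def gradient_boost(x, y, n_trees=50, shrinkage=0.1):
    f0 = y.mean()                            # Iteration 0: predict the overall average
    pred = np.full_like(y, f0, dtype=float)
    trees = []
    for _ in range(n_trees):
        residuals = y - pred                 # errors made at the previous iteration
        tree = fit_stump(x, residuals)       # train a tree on those errors
        pred = pred + shrinkage * tree(x)    # add its shrunken contribution
        trees.append(tree)

    def predict(q):
        return f0 + shrinkage * sum(tree(q) for tree in trees)

    return predict, pred

# Synthetic non-linear relationship between x and y
rng = np.random.default_rng(0)
x = np.linspace(0, 6, 200)
y = np.sin(x) + 0.1 * rng.normal(size=x.size)

predict, fitted = gradient_boost(x, y)
mse_baseline = np.mean((y - y.mean()) ** 2)  # error of the naive predictor F_0
mse_boosted = np.mean((y - fitted) ** 2)     # error after 50 boosting rounds
```

With a smaller `shrinkage`, each tree corrects only a fraction of the remaining error, so more trees are needed, but the ensemble is less likely to chase noise.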

[[Source 1](https://towardsdatascience.com/all-you-need-to-know-about-gradient-boosting-algorithm-part-1-regression-2520a34a502), [Source 2](https://newsletter.theaiedge.io/p/gbm-vs-xgboost-vs-lightgbm-vs-catboost)]

Both **XGBoost** (Extreme Gradient Boosting) and **LightGBM** (Light Gradient Boosting Machine) are advanced implementations of the gradient boosting framework and are particularly well-suited for modeling on structured tabular data. These algorithms build upon the foundation of gradient boosting, aiming to address its limitations while improving efficiency, scalability, and accuracy. Here's a detailed look at how they work and differ, especially in the context of making recommendations on single-table datasets.

#### The Single-Table Data Formalism

Tabular machine learning models, particularly those using methods like tree-based learning or feature aggregation, require data to be represented in a **flat**, single-table format. Relational databases, however, store data in normalized tables to reduce redundancy and optimize storage. Transforming relational data into a single table for machine learning requires **careful preprocessing** to preserve as much of the relational structure and information as possible.

The standard method of flattening relational data involves joining tables and engineering features to combine relevant information into a unified table. However, this process is often **computationally expensive** and risks **losing valuable predictive signals**. Flattening discards latent correlations and complex relationships between entities, which are essential for accurately capturing interactions in many predictive tasks.

**Core Preprocessing Operations**

The preprocessing phase of the `rel-hm` dataset for the link prediction task `user-item-purchase` involves a series of structured operations that transform relational data into a feature-rich table suitable for training and evaluating machine learning models. The key steps include:

1. **Entity Table Link**:
   - Data from related entity tables is merged with the interaction table. Specifically, each interaction entry is joined with the user's information and the characteristics of the purchased item.

2. **Feature Engineering**:
   - **Historical Context**: Features that capture past interaction patterns, such as the frequency of interactions between specific entities, are computed. These features provide insights into recurring behaviors and relationships.
   - **Global Popularity**: Metrics that quantify the overall popularity of the articles sold across the dataset are calculated. These metrics reflect trends and preferences within the data.

3. **Negative Sampling**:
   - Negative examples are generated by randomly pairing users with articles they never bought. These synthetic samples are essential for training the model to differentiate between meaningful and random associations.
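
The negative sampling step can be illustrated with the standard library only. This is a sketch under simple assumptions (purchases held in a dictionary, uniform sampling without replacement); the actual baseline pipeline may sample differently. The function name and the toy identifiers are hypothetical.

```python
import random

def sample_negatives(purchases, all_articles, num_neg, seed=0):
    """For each user, draw up to num_neg articles that the user never bought."""
    rng = random.Random(seed)
    negatives = {}
    for user, bought in purchases.items():
        candidates = sorted(set(all_articles) - set(bought))  # never-purchased articles
        negatives[user] = rng.sample(candidates, min(num_neg, len(candidates)))
    return negatives

purchases = {"u1": {"a1", "a2"}, "u2": {"a3"}}
all_articles = ["a1", "a2", "a3", "a4", "a5"]

negatives = sample_negatives(purchases, all_articles, num_neg=2)
```

Each sampled `(user, article)` pair would then be added to the flat table with label 0, next to the observed purchases with label 1.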


> ‼️ **NOTE**: In addition to modifying the data handling paradigm, the following models and experiments utilize tables that have been enhanced with two additional features:
- `global_popularity_fraction` ( $\forall \; \text{item}$ ): The number of sales of the item divided by the total number of sales.
- `num_past_visits` ( $\forall \; \langle \text{user}, \text{item}\rangle$ ): The total number of times each user has visited a specific item.
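
As a rough sketch of what these two features measure, the snippet below computes them from a toy list of `(user, item)` transactions with `collections.Counter`. It assumes every listed transaction lies in the past; in a temporal setting the counts would presumably be restricted to transactions before the prediction timestamp. The toy data is made up.

```python
from collections import Counter

# Toy transaction log: one (user, item) pair per purchase
transactions = [("u1", "a1"), ("u1", "a1"), ("u1", "a2"), ("u2", "a1")]

# global_popularity_fraction: share of all sales that each item accounts for
item_sales = Counter(item for _, item in transactions)
total_sales = len(transactions)
global_popularity_fraction = {item: n / total_sales for item, n in item_sales.items()}

# num_past_visits: how many times each user interacted with each item
num_past_visits = Counter(transactions)
```

Here `a1` accounts for three of the four sales, and user `u1` has interacted with `a1` twice.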

In [None]:
#@title Install Libraries

%%capture


import os
import torch

os.environ['TORCH'] = torch.__version__

!pip install torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install relbench[full]


!pip install lightgbm\
             xgboost \
             optuna_integration \
             dask[dataframe]

In [None]:
#@title Download data

print("Importing libraries...")
import warnings
warnings.filterwarnings('ignore')


import copy
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from torch import Tensor

import torch_frame
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.config.text_embedder import TextEmbedderConfig as TextEmbedderConfig2 # Renamed to avoid conflict
from torch_frame.data import Dataset as TorchFrameDataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.seed import seed_everything
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm

from relbench.base import Database, EntityTask, RecommendationTask, Table, TaskType, Dataset
from relbench.datasets import get_dataset
from relbench.modeling.graph import get_link_train_table_input
from relbench.modeling.loader import SparseTensor
from relbench.modeling.utils import get_stype_proposal, remove_pkey_fkey, to_unix_time
from relbench.tasks import get_task




device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


LINK_PRED_BASELINE_TARGET_COL_NAME = "link_pred_baseline_target_column_name"
PRED_SCORE_COL_NAME = "pred_score_col_name"
dataset_name = "rel-hm"
task_name = "user-item-purchase"


class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))


def make_pkey_fkey_graph(
    db: Database,
    col_to_stype_dict: Dict[str, Dict[str, stype]],
    text_embedder_cfg: Optional[TextEmbedderConfig] = None,
    cache_dir: Optional[str] = None,
) -> Tuple[HeteroData, Dict[str, Dict[str, Dict[StatType, Any]]]]:
    r"""Given a :class:`Database` object, construct a heterogeneous graph with primary-
    foreign key relationships, together with the column stats of each table.

    Args:
        db: A database object containing a set of tables.
        col_to_stype_dict: Column to stype for
            each table.
        text_embedder_cfg: Text embedder config.
        cache_dir: A directory for storing materialized tensor
            frames. If specified, we will either cache the file or use the
            cached file. If not specified, we will not use cached file and
            re-process everything from scratch without saving the cache.

    Returns:
        HeteroData: The heterogeneous :class:`PyG` object with
            :class:`TensorFrame` feature.
    """
    data = HeteroData()
    col_stats_dict = dict()
    if cache_dir is not None:
        os.makedirs(cache_dir, exist_ok=True)

    for table_name, table in db.table_dict.items():
        # Materialize the tables into tensor frames:
        df = table.df
        # Ensure that pkey is consecutive.
        if table.pkey_col is not None:
            assert (df[table.pkey_col].values == np.arange(len(df))).all()

        col_to_stype = col_to_stype_dict[table_name]

        # Remove pkey, fkey columns since they will not be used as input
        # feature.
        remove_pkey_fkey(col_to_stype, table)

        if len(col_to_stype) == 0:  # Add constant feature in case df is empty:
            col_to_stype = {"__const__": stype.numerical}
            # We need to add edges later, so we need to also keep the fkeys
            fkey_dict = {key: df[key] for key in table.fkey_col_to_pkey_table}
            df = pd.DataFrame({"__const__": np.ones(len(table.df)), **fkey_dict})

        path = (
            None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")
        )

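        # Materialize on the globally selected device instead of hard-coding "cuda",
        # so that the function also works on CPU-only machines.
        dataset = TorchFrameDataset(
            df=df,
            col_to_stype=col_to_stype,
            col_to_text_embedder_cfg=text_embedder_cfg,
        ).materialize(path=path, device=device)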

        data[table_name].tf = dataset.tensor_frame
        col_stats_dict[table_name] = dataset.col_stats

        # Add time attribute:
        if table.time_col is not None:
            data[table_name].time = torch.from_numpy(
                to_unix_time(table.df[table.time_col])
            )

        # Add edges:
        for fkey_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
            pkey_index = df[fkey_name]
            # Filter out dangling foreign keys
            mask = ~pkey_index.isna()
            fkey_index = torch.arange(len(pkey_index))
            # Filter dangling foreign keys:
            pkey_index = torch.from_numpy(pkey_index[mask].astype(int).values)
            fkey_index = fkey_index[torch.from_numpy(mask.values)]
            # Ensure no dangling fkeys
            assert (pkey_index < len(db.table_dict[pkey_table_name])).all()

            # fkey -> pkey edges
            edge_index = torch.stack([fkey_index, pkey_index], dim=0)
            edge_type = (table_name, f"f2p_{fkey_name}", pkey_table_name)
            data[edge_type].edge_index = sort_edge_index(edge_index)

            # pkey -> fkey edges.
            # "rev_" is added so that PyG loader recognizes the reverse edges
            edge_index = torch.stack([pkey_index, fkey_index], dim=0)
            edge_type = (pkey_table_name, f"rev_f2p_{fkey_name}", table_name)
            data[edge_type].edge_index = sort_edge_index(edge_index)

    data.validate()

    return data, col_stats_dict


print("Downloading data ...")
dataset: Dataset = get_dataset(dataset_name, download=True)
task: RecommendationTask = get_task(dataset_name, task_name, download=True)
target_col_name: str = LINK_PRED_BASELINE_TARGET_COL_NAME

train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

print("Getting data from tables ...")
db = dataset.get_db()
src_entity_table = db.table_dict[task.src_entity_table]
dst_entity_table = db.table_dict[task.dst_entity_table]

src_entity_df = src_entity_table.df
dst_entity_df = dst_entity_table.df
print("Data loaded successfully")


cache_dir = os.path.expanduser("~/.cache/relbench_examples")
stypes_cache_path = Path(f"{cache_dir}/{dataset_name}/stypes.json")
try:
    with open(stypes_cache_path, "r") as f:
        col_to_stype_dict = json.load(f)
    for table, c_to_s in col_to_stype_dict.items():
        for col, stype_str in c_to_s.items():
            c_to_s[col] = stype(stype_str)
except FileNotFoundError:
    col_to_stype_dict = get_stype_proposal(dataset.get_db())
    Path(stypes_cache_path).parent.mkdir(parents=True, exist_ok=True)
    with open(stypes_cache_path, "w") as f:
        json.dump(col_to_stype_dict, f, indent=2, default=str)


# -----------------------------------------
# Prepare col_to_stype dictionary
print("Preparing tables ...")
col_to_stype = {}
src_entity_table_col_to_stype = copy.deepcopy(col_to_stype_dict[task.src_entity_table])
dst_entity_table_col_to_stype = copy.deepcopy(col_to_stype_dict[task.dst_entity_table])

remove_pkey_fkey(src_entity_table_col_to_stype, src_entity_table)
remove_pkey_fkey(dst_entity_table_col_to_stype, dst_entity_table)

# Resolve naming conflicts by adding _x and _y suffixes
src_dst_intersection_column_names = set(src_entity_table_col_to_stype.keys()) & set(
    dst_entity_table_col_to_stype.keys()
)
for column_name in src_dst_intersection_column_names:
    src_entity_table_col_to_stype[f"{column_name}_x"] = src_entity_table_col_to_stype[column_name]
    del src_entity_table_col_to_stype[column_name]
    dst_entity_table_col_to_stype[f"{column_name}_y"] = dst_entity_table_col_to_stype[column_name]
    del dst_entity_table_col_to_stype[column_name]

col_to_stype.update(src_entity_table_col_to_stype)
col_to_stype.update(dst_entity_table_col_to_stype)
col_to_stype[target_col_name] = torch_frame.categorical
print("Download finished with success!")

Importing libraries...
Downloading data ...
Getting data from tables ...
Loading Database object from /root/.cache/relbench/rel-hm/db...
Done in 3.41 seconds.
Data loaded successfully
Preparing tables ...
Download finished with success!


##### Data Preprocessing

Let's define the preprocessing functions described above to obtain a single-table version of the data.

In [None]:
def add_past_label_feature(
    train_table_df: pd.DataFrame,
    past_table_df: pd.DataFrame,
    ) -> pd.DataFrame:
    """Add past visit count and percentage of global popularity to train table df used
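    for training, evaluation, or testing.

    Relies on the global column names `src_entity` and `dst_entity`, which must be
    defined before this function is called.

    Args:
        train_table_df (pd.DataFrame): The dataframe to enrich with the two features.
        past_table_df (pd.DataFrame): The dataframe containing labels in the
            past.

    Returns:
        pd.DataFrame: The input dataframe with the `num_past_visit` and
            `global_popularity_fraction` columns added.
    """
    # Add the number of past visits for each (src_entity, dst_entity) pair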
    # Explode the dst_entity list to get one row per (src_entity, dst_entity) pair
    exploded_past_table = past_table_df.explode(dst_entity)

    # Count occurrences of each (src_entity, dst_entity) pair
    dst_entity_count = (
        exploded_past_table.groupby([src_entity, dst_entity])
        .size()
        .reset_index(name="num_past_visit")
    )

    # Merge the count information with train_table_df
    train_table_df = train_table_df.merge(
        dst_entity_count, how="left", on=[src_entity, dst_entity]
    )

    # Fill NaN values with 0 (if there are any dst_entity in train_table_df not present in past_table_df)
    train_table_df["num_past_visit"] = (
        train_table_df["num_past_visit"].fillna(0).astype(int)
    )

    # Add percentage of global popularity for each dst_entity
    # Count occurrences of each dst_entity
    dst_entity_count = exploded_past_table[dst_entity].value_counts().reset_index()

    # Calculate the fraction
    # total_right_entities = len(exploded_past_table)
    dst_entity_count["global_popularity_fraction"] = (
        dst_entity_count["count"] / dst_entity_count["count"].max()
    )

    # Merge the fraction information with train_table_df
    train_table_df = train_table_df.merge(
        dst_entity_count[[dst_entity, "global_popularity_fraction"]],
        how="left",
        on=dst_entity,
    )

    # Fill NaN values with 0 (if there are any dst_entity in train_table_df not present in past_table_df)
    train_table_df["global_popularity_fraction"] = train_table_df[
        "global_popularity_fraction"
    ].fillna(0)

    return train_table_df

In [None]:
from collections import Counter

def dst_entities_aggr(dst_entities):
    """Concatenate the dst entity lists and return the `task.eval_k` most frequent entities."""
    dst_entities_concat = []
    for dst_entity_list in list(dst_entities):
        dst_entities_concat.extend(dst_entity_list)
    counter = Counter(dst_entities_concat)
    topk = [elem for elem, _ in counter.most_common(task.eval_k)]
    return topk


def prepare_for_link_pred_eval(
    evaluate_table_df: pd.DataFrame,
    past_table_df: pd.DataFrame
) -> pd.DataFrame:
    """Transform evaluation dataframe into the correct format for link prediction metric
    calculation.

    Args:
        evaluate_table_df (pd.DataFrame): The dataframe to evaluate on, with one
            row per src entity.
        past_table_df (pd.DataFrame): The dataframe containing labels in the
            past.
    Returns:
        (pd.DataFrame): The evaluation dataframe containing past visit and
            global popularity dst entities as candidate set.
    """

    def interleave_lists(list1, list2):
        interleaved = [item for pair in zip(list1, list2) for item in pair]
        longer_list = list1 if len(list1) > len(list2) else list2
        interleaved.extend(longer_list[len(interleaved) // 2 :])
        return interleaved

    grouped_ranked_past_table_df = (
        past_table_df.groupby(src_entity)[dst_entity]
        .apply(dst_entities_aggr)
        .reset_index()
    )
    evaluate_table_df = pd.merge(
        evaluate_table_df, grouped_ranked_past_table_df, how="left", on=src_entity
    )

    # collect the most popular dst entities
    all_dst_entities = [
        entity for sublist in past_table_df[dst_entity] for entity in sublist
    ]
    dst_entity_counter = Counter(all_dst_entities)
    top_dst_entities = [
        entity for entity, _ in dst_entity_counter.most_common(task.eval_k * 2)
    ]

    evaluate_table_df[dst_entity] = evaluate_table_df[dst_entity].apply(
        lambda x: (
            interleave_lists(x, top_dst_entities)
            if isinstance(x, list)
            else top_dst_entities
        )
    )
    # For each src entity, keep at most `task.eval_k * 2` dst entity candidates
    evaluate_table_df[dst_entity] = evaluate_table_df[dst_entity].apply(
        lambda x: (
            x[: task.eval_k * 2]
            if isinstance(x, list) and len(x) > task.eval_k * 2
            else x
        )
    )

    # Include src and dst entity table features for `evaluate_table_df`
    evaluate_table_df = pd.merge(
        evaluate_table_df,
        src_entity_df,
        how="left",
        left_on=src_entity,
        right_on=src_entity_table.pkey_col,
    )

    evaluate_table_df = evaluate_table_df.explode(dst_entity)
    evaluate_table_df = pd.merge(
        evaluate_table_df,
        dst_entity_df,
        how="left",
        left_on=dst_entity,
        right_on=dst_entity_table.pkey_col,
    )

    evaluate_table_df = add_past_label_feature(evaluate_table_df, past_table_df)
    return evaluate_table_df
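
To make the candidate construction concrete, here is a small sketch that applies the same `interleave_lists` logic to made-up data. The item names and the value `eval_k = 2` are assumptions for illustration; in the notebook the cut-off comes from `task.eval_k`.

```python
def interleave_lists(list1, list2):
    # Alternate elements of the two lists, then append the tail of the longer one.
    interleaved = [item for pair in zip(list1, list2) for item in pair]
    longer_list = list1 if len(list1) > len(list2) else list2
    interleaved.extend(longer_list[len(interleaved) // 2:])
    return interleaved


eval_k = 2  # hypothetical value standing in for task.eval_k

user_history = ["a", "b", "c"]       # items the user bought most often in the past
popular = ["p1", "p2", "p3", "p4"]   # globally most popular items

all_candidates = interleave_lists(user_history, popular)
candidates = all_candidates[: eval_k * 2]  # keep at most eval_k * 2 candidates

print(all_candidates)
print(candidates)
```

Note that the helper does not remove duplicates, so an item that appears in both lists can occur twice among the candidates.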

In [None]:
#@title Data sampling scheme
from IPython.display import display
import ipywidgets as widgets

# Default values for the variables
DEFAULT_TRAIN_SAMPLE_SIZE = 5000
DEFAULT_VAL_SAMPLE_SIZE = 1000
DEFAULT_TEST_SAMPLE_SIZE = 1000

# Define the text fields with placeholders for default values
var1_field = widgets.Text(
    value='',
    placeholder=f'(default {DEFAULT_TRAIN_SAMPLE_SIZE})',
    description='TRAIN_SIZE:',
    layout=widgets.Layout(width='400px')
)
var2_field = widgets.Text(
    value='',
    placeholder=f'(default {DEFAULT_VAL_SAMPLE_SIZE})',
    description='VAL_SIZE:',
    layout=widgets.Layout(width='400px')
)
var3_field = widgets.Text(
    value='',
    placeholder=f'(default {DEFAULT_TEST_SAMPLE_SIZE})',
    description='TEST_SIZE:',
    layout=widgets.Layout(width='400px')
)

# Define a button to confirm inputs
submit_button = widgets.Button(
    description="Submit",
    layout=widgets.Layout(width='150px')
)

# Output area
output = widgets.Output()

# Initialize variables with default values
TRAIN_SIZE = DEFAULT_TRAIN_SAMPLE_SIZE
VAL_SIZE = DEFAULT_VAL_SAMPLE_SIZE
TEST_SIZE = DEFAULT_TEST_SAMPLE_SIZE

# Function to handle the button click
def on_submit_button_clicked(b):
    global TRAIN_SIZE, VAL_SIZE, TEST_SIZE
    with output:
        # Clear earlier messages first, so that a new error message stays visible
        output.clear_output()
        try:
            # Use defaults if input is empty, otherwise parse the input
            TRAIN_SIZE = int(var1_field.value) if var1_field.value.strip() else DEFAULT_TRAIN_SAMPLE_SIZE
            VAL_SIZE = int(var2_field.value) if var2_field.value.strip() else DEFAULT_VAL_SAMPLE_SIZE
            TEST_SIZE = int(var3_field.value) if var3_field.value.strip() else DEFAULT_TEST_SAMPLE_SIZE
        except ValueError:
            print("Error: Please enter valid integer values or leave fields blank for default values.")

# Attach the button click event
submit_button.on_click(on_submit_button_clicked)

# Arrange widgets vertically
form = widgets.VBox([var1_field, var2_field, var3_field, submit_button, output])

# Display the form
display(form)



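As mentioned before, one of the major drawbacks of the one-table formalism is its **high computational requirements**. For this reason, the next cell randomly subsamples the train, validation, and test tables down to the sizes selected above.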
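
The subsampling in this notebook draws a fresh random permutation on every run. If reproducible subsets are preferred, a seeded generator can be used instead. The sketch below is one possible way to do it; the helper `sample_rows`, its `seed` argument, and the toy dataframe are hypothetical and not part of the original code.

```python
import numpy as np
import pandas as pd


def sample_rows(df: pd.DataFrame, size: int, seed: int = 0) -> pd.DataFrame:
    """Return at most `size` rows of `df`, drawn without replacement."""
    rng = np.random.default_rng(seed)          # seeded generator -> same sample every run
    idx = rng.permutation(len(df))[:size]      # shuffle row positions, keep the first `size`
    return df.iloc[idx]


# Toy table standing in for a task table's dataframe.
toy = pd.DataFrame({"customer_id": range(10)})
sample = sample_rows(toy, size=4, seed=42)
print(sample)
```

Requesting more rows than the table contains simply returns the whole table in shuffled order.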

In [None]:
# Sample train set
sampled_train_table = copy.deepcopy(train_table)
sampled_idx = np.random.permutation(len(sampled_train_table))[:TRAIN_SIZE]
sampled_train_table.df = sampled_train_table.df.iloc[sampled_idx]

# Sample validation set
sampled_val_table = copy.deepcopy(val_table)
sampled_idx = np.random.permutation(len(sampled_val_table))[:VAL_SIZE]
sampled_val_table.df = sampled_val_table.df.iloc[sampled_idx]

# Sample test set
sampled_test_table = copy.deepcopy(test_table)
sampled_idx = np.random.permutation(len(sampled_test_table))[:TEST_SIZE]
sampled_test_table.df = sampled_test_table.df.iloc[sampled_idx]

We can now prepare the dataset for training our baseline algorithms. Specifically, for each src entity (user), its corresponding dst entities (items) are used as **positive** labels, and the same number of randomly drawn dst entities are used as **negative** labels. Models will be trained and evaluated on this **binary classification task**.
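
The notebook draws negatives uniformly from all items without checking whether the user actually bought them. As a hedged variation, the sketch below uses rejection sampling so that a negative is never one of the user's own purchases. The function `sample_negatives` and the toy data are assumptions for illustration.

```python
import numpy as np


def sample_negatives(positives, all_items, rng):
    """For each (user, item) positive, draw one random item the user has not bought.

    Assumes every user has at least one item they did not buy,
    otherwise the rejection loop would never terminate.
    """
    bought = {}
    for user, item in positives:
        bought.setdefault(user, set()).add(item)

    negatives = []
    for user, _ in positives:
        while True:
            candidate = int(rng.choice(all_items))
            if candidate not in bought[user]:   # reject items the user purchased
                negatives.append((user, candidate))
                break
    return negatives


rng = np.random.default_rng(0)
positives = [(0, 10), (0, 11), (1, 10)]
all_items = np.array([10, 11, 12, 13])
negatives = sample_negatives(positives, all_items, rng)

# Rows of the binary classification task: (user, item, label).
rows = [(u, i, 1) for u, i in positives] + [(u, i, 0) for u, i in negatives]
print(rows)
```

With a catalog as large as the one in `rel-hm`, collisions between random items and true purchases are rare, which is why plain uniform sampling is a common shortcut.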

In [None]:
src_entity = list(sampled_train_table.fkey_col_to_pkey_table.keys())[0]
dst_entity = list(sampled_train_table.fkey_col_to_pkey_table.keys())[1]


def timed_execution(description, func, *args, **kwargs):
    """
    Measures the execution time of a task, prints a descriptive message, and returns the result.

    Args:
        description (str): A brief description of the task being executed.
        func (callable): The function to be executed.
        *args: Positional arguments for the function.
        **kwargs: Keyword arguments for the function.

    Returns:
        The result of the function execution.
    """
    start_time = time.time()
    print(f"{description}...")
    result = func(*args, **kwargs)
    print(f"Completed {description} in {time.time() - start_time:.2f} seconds\n")
    return result


def process_split(
    split_name,
    table,
    src_entity_df,
    dst_entity_df,
    target_col_name,
    train_table_df,
    src_entity_col,
    dst_entity_col,
    additional_data=None,
):
    """
    Unified function to process splits for training, validation, validation prediction, and testing.

    Args:
        split_name (str): Name of the split being processed (e.g., "train", "val", "val_pred", "test").
        table: The data table for the split.
        src_entity_df (pd.DataFrame): Dataframe for the source entity with additional features.
        dst_entity_df (pd.DataFrame): Dataframe for the destination entity with additional features.
        target_col_name (str): Name of the column for target labels.
        train_table_df (pd.DataFrame): Training table dataframe used for past label feature computation.
        src_entity_col (str): Column name for the source entity primary key.
        dst_entity_col (str): Column name for the destination entity primary key.
        additional_data (Dict[str, pd.DataFrame], optional): Additional data required for specific splits.

    Returns:
        pd.DataFrame: The processed dataframe for the split.
    """
    print(f"{'=' * 30}\nProcessing {split_name} split\n{'=' * 30}")

    # Train and validation splits
    if split_name in ["train", "val"]:
        def process_train_val():
            # Step 1: Ensure dtype consistency
            src_entity_df_typed = src_entity_df.astype(
                {src_entity_col: table.df[src_entity_col].dtype}
            )
            dst_entity_df_typed = dst_entity_df.astype(
                {dst_entity_col: table.df[dst_entity_col].dtype}
            )

            # Step 2: Merge source entity
            df = table.df.merge(
                src_entity_df_typed,
                how="left",
                left_on=src_entity_col,
                right_on=src_entity_col,
            )

            # Step 3: Explode destination entity
            df = df.explode(dst_entity_col)

            # Step 4: Add a target column indicating positive links
            df[target_col_name] = 1

            # Step 5: Create negative samples
            negative_sample_df_columns = list(df.columns)
            negative_sample_df_columns.remove(dst_entity_col)
            # Copy the slice so that the column assignments below do not write into a view of `df`
            negative_samples_df = df[negative_sample_df_columns].copy()
            negative_samples_df[dst_entity_col] = np.random.choice(
                dst_entity_df_typed[dst_entity_col], size=len(negative_samples_df)
            )
            negative_samples_df[target_col_name] = 0

            # Step 6: Combine positive and negative samples
            df = pd.concat([df, negative_samples_df], ignore_index=True)

            # Step 7: Merge destination entity features
            df = pd.merge(
                df,
                dst_entity_df_typed,
                how="left",
                left_on=dst_entity_col,
                right_on=dst_entity_col,
            )

            # Step 8: Add past label feature
            df = add_past_label_feature(df, train_table_df)

            return df

        return timed_execution(f"Parsing and creating positive-negatives for {split_name} split", process_train_val)

    # Validation prediction split
    elif split_name == "val_pred":
        def prepare_val_pred():
            """
            Prepares the validation prediction split for evaluation by merging historical training data
            with validation data, and removing irrelevant columns.
            """
            val_df_pred_columns = list(table.df.columns)
            val_df_pred_columns.remove(dst_entity_col)
            val_df_pred = table.df[val_df_pred_columns]

            val_past_table_df = train_table_df.copy()
            val_past_table_df.drop(columns=[table.time_col], inplace=True)

            return prepare_for_link_pred_eval(val_df_pred, val_past_table_df)

        return timed_execution("Preparing val_pred", prepare_val_pred)

    # Test split
    elif split_name == "test":
        def prepare_test():
            """
            Prepares the test split for evaluation by merging historical train and validation data
            with the test data, and removing irrelevant columns.
            """
            test_df_columns = list(table.df.columns)
            test_df_columns.remove(dst_entity_col)
            test_df = table.df[test_df_columns]

            test_past_table_df = pd.concat(additional_data.values(), axis=0)
            test_past_table_df.drop(columns=[table.time_col], inplace=True)

            return prepare_for_link_pred_eval(test_df, test_past_table_df)

        return timed_execution("Preparing test", prepare_test)

    else:
        raise ValueError(f"Unknown split name: {split_name}")


# Extract source and destination entity columns
src_entity_col = list(sampled_train_table.fkey_col_to_pkey_table.keys())[0]
dst_entity_col = list(sampled_train_table.fkey_col_to_pkey_table.keys())[1]

# Prepare all splits
dfs = {
    "train": process_split(
        "train",
        sampled_train_table,
        src_entity_df,
        dst_entity_df,
        target_col_name,
        sampled_train_table.df,
        src_entity_col,
        dst_entity_col,
    ),
    "val": process_split(
        "val",
        sampled_val_table,
        src_entity_df,
        dst_entity_df,
        target_col_name,
        sampled_train_table.df,
        src_entity_col,
        dst_entity_col,
    ),
    "val_pred": process_split(
        "val_pred",
        sampled_val_table,
        src_entity_df,
        dst_entity_df,
        target_col_name,
        sampled_train_table.df,
        src_entity_col,
        dst_entity_col,
    ),
    "test": process_split(
        "test",
        sampled_test_table,
        src_entity_df,
        dst_entity_df,
        target_col_name,
        sampled_train_table.df,
        src_entity_col,
        dst_entity_col,
        additional_data={"train": sampled_train_table.df, "val": sampled_val_table.df},
    ),
}

print("Processing complete for all splits!")

Processing train split
Parsing and creating positive-negatives for train split...
Completed Parsing and creating positive-negatives for train split in 1.11 seconds

Processing val split
Parsing and creating positive-negatives for val split...
Completed Parsing and creating positive-negatives for val split in 1.00 seconds

Processing val_pred split
Preparing val_pred...
Completed Preparing val_pred in 0.71 seconds

Processing test split
Preparing test...
Completed Preparing test in 0.48 seconds

Processing complete for all splits!


#### XGBoost

Developed by Tianqi Chen (first released in 2014) and described in [Chen & Guestrin (2016)](https://arxiv.org/pdf/1603.02754.pdf).

* *Tree Growth Strategy:* **Level-wise**, meaning all nodes at a given depth are split before increasing the depth of the tree. This can lead to unnecessary splits in regions where further partitioning isn’t helpful, which slows down training and increases computational cost.
* *Split Finding:* **Exact** in the original algorithm, which evaluates all possible split points for numeric features. While accurate, this is computationally expensive and can be slow for features with many distinct values. XGBoost also provides faster histogram-based approximations (the `hist` tree method), which the GPU configuration used below relies on.
* *Categorical Feature Handling:* Traditionally **requires manual encoding** (e.g., one-hot encoding). This can inflate dimensionality, especially for high-cardinality features, increasing computational overhead. Recent versions also offer native categorical support through `enable_categorical`.
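
To illustrate what exact split finding means, the following sketch scans every possible threshold of a single numeric feature. It is a simplification: it scores splits by the reduction in the sum of squared errors, whereas XGBoost uses regularized gradient statistics. The function name `exact_best_split` and the toy data are assumptions.

```python
import numpy as np


def exact_best_split(x: np.ndarray, y: np.ndarray):
    """Try every midpoint between consecutive distinct values of x.

    Returns (threshold, gain), where gain is the reduction in the
    sum of squared errors obtained by splitting at the threshold.
    """
    order = np.argsort(x)
    x_sorted, y_sorted = x[order], y[order]

    def sse(v):
        return float(((v - v.mean()) ** 2).sum()) if len(v) else 0.0

    parent = sse(y_sorted)
    best_threshold, best_gain = None, 0.0
    for i in range(1, len(x_sorted)):
        if x_sorted[i] == x_sorted[i - 1]:
            continue  # no threshold fits between two equal values
        gain = parent - sse(y_sorted[:i]) - sse(y_sorted[i:])
        if gain > best_gain:
            best_threshold = float((x_sorted[i] + x_sorted[i - 1]) / 2)
            best_gain = gain
    return best_threshold, best_gain


x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
threshold, gain = exact_best_split(x, y)
print(threshold, gain)
```

The loop visits one candidate per distinct feature value, which is why the exact method becomes expensive on large datasets with many unique values.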

<img src="https://drive.google.com/uc?export=view&id=16MhV7Dw69T8vDwlLKHW4eMzL1SzEqeQN" width="650">

In [None]:
#@title **XGBoost** class definition

import copy
from typing import Any

import numpy as np
import torch
from torch import Tensor

from torch_frame import Metric, TaskType, TensorFrame, stype
from torch_frame.gbdt import GBDT
import xgboost as xgb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def neg_to_nan(x: Tensor) -> Tensor:
    r"""Convert -1 category back to NaN that can be handled by GBDT.

    Args:
        x (Tensor): Input categ. feature, where `-1` represents `NaN`.

    Returns:
        x (Tensor): Output categ. feature, where `-1` is replaced with `NaN`
    """
    is_neg = x == -1
    if is_neg.any():
        x = copy.copy(x).to(torch.float32)
        x[is_neg] = torch.nan
    return x



class XGBoost(GBDT):
    """
    Optimized XGBoost implementation extending GBDT.
    Supports hyperparameter tuning and uses GPU acceleration for training.
    """
    def _to_xgboost_input(
        self, tf: TensorFrame
    ) -> xgb.DMatrix:
        """
        Convert TensorFrame to XGBoost-compatible input.
        Args:
            tf (TensorFrame): Input TensorFrame.
        Returns:
            xgb.DMatrix: Input data in DMatrix format for XGBoost.
        """
        tf = tf.cpu()  # Ensure data is on CPU for conversion to DMatrix
        y = tf.y.numpy() if tf.y is not None else None

        feats = []
        types = []

        # Process categorical features
        if stype.categorical in tf.feat_dict:
            feats.append(neg_to_nan(tf.feat_dict[stype.categorical]).numpy())
            types.extend(["c"] * len(tf.col_names_dict[stype.categorical]))

        # Process numerical features
        if stype.numerical in tf.feat_dict:
            feats.append(tf.feat_dict[stype.numerical].numpy())
            types.extend(["q"] * len(tf.col_names_dict[stype.numerical]))

        # Process embedding features
        if stype.embedding in tf.feat_dict:
            feat = tf.feat_dict[stype.embedding]
            feat = feat.values.view(feat.size(0), -1).numpy()
            feats.append(feat)
            types.extend(["q"] * feat.shape[1])

        if not feats:
            raise ValueError("TensorFrame contains no features.")

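        # Concatenate all features into a single matrix
        X = np.hstack(feats)
        # Pass the per-column feature types ("c" = categorical, "q" = quantitative);
        # for a NumPy input, XGBoost cannot infer which columns are categorical.
        dmatrix = xgb.DMatrix(X, label=y, feature_types=types, enable_categorical=True)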

        return dmatrix

    def objective(
        self, trial: Any, dtrain: xgb.DMatrix, dvalid: xgb.DMatrix, num_boost_round: int
    ) -> float:
        """
        Objective function for hyperparameter tuning.
        Args:
            trial (optuna.Trial): Trial object for hyperparameter optimization.
            dtrain (xgb.DMatrix): Training data.
            dvalid (xgb.DMatrix): Validation data.
            num_boost_round (int): Number of boosting rounds.
        Returns:
            float: Validation score.
        """
        import optuna

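        self.params = {
            # Histogram-based training that runs on the GPU when one is available.
            # Assumes XGBoost >= 2.0, where "hist" + "device" replaces the deprecated "gpu_hist".
            "tree_method": "hist",
            "device": "cuda" if torch.cuda.is_available() else "cpu",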
            "max_depth": trial.suggest_int("max_depth", 3, 10),
            "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
            "lambda": trial.suggest_float("lambda", 1e-9, 10.0, log=True),
            "alpha": trial.suggest_float("alpha", 1e-9, 10.0, log=True),
        }

        if self.task_type == TaskType.BINARY_CLASSIFICATION:
            self.params.update({"objective": "binary:logistic", "eval_metric": "auc"})
        elif self.task_type == TaskType.REGRESSION:
            self.params.update({"objective": "reg:squarederror", "eval_metric": "rmse"})
        elif self.task_type == TaskType.MULTICLASS_CLASSIFICATION:
            self.params.update(
                {
                    "objective": "multi:softmax",
                    "num_class": self._num_classes,
                    "eval_metric": "mlogloss",
                }
            )
        else:
            raise ValueError(f"Unsupported task type: {self.task_type}")

        pruning_callback = optuna.integration.XGBoostPruningCallback(
            trial, f"validation-{self.params['eval_metric']}"
        )

        booster = xgb.train(
            self.params,
            dtrain,
            num_boost_round=num_boost_round,
            early_stopping_rounds=50,
            evals=[(dvalid, "validation")],
            callbacks=[pruning_callback],
        )

        pred = booster.predict(dvalid)
        score = self.compute_metric(
            torch.tensor(dvalid.get_label()), torch.tensor(pred)
        )
        return score

    def _tune(
        self, tf_train: TensorFrame, tf_val: TensorFrame, num_trials: int, num_boost_round: int = 2000
    ):
        """
        Perform hyperparameter tuning using Optuna.
        Args:
            tf_train (TensorFrame): Training data.
            tf_val (TensorFrame): Validation data.
            num_trials (int): Number of tuning trials.
            num_boost_round (int): Number of boosting rounds.
        """
        import optuna

        dtrain = self._to_xgboost_input(tf_train)
        dvalid = self._to_xgboost_input(tf_val)

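        # Regression scores are error metrics and must be minimized;
        # classification scores are maximized.
        direction = "minimize" if self.task_type == TaskType.REGRESSION else "maximize"
        study = optuna.create_study(direction=direction)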
        study.optimize(
            lambda trial: self.objective(trial, dtrain, dvalid, num_boost_round),
            n_trials=num_trials,
        )
        self.params.update(study.best_params)

        # Train final model with best parameters
        self.model = xgb.train(
            self.params,
            dtrain,
            num_boost_round=num_boost_round,
            evals=[(dvalid, "validation")],
            early_stopping_rounds=50,
        )

    def _predict(self, tf_test: TensorFrame) -> Tensor:
        """
        Predict using the trained model.
        Args:
            tf_test (TensorFrame): Test data.
        Returns:
            Tensor: Predictions.
        """
        dtest = self._to_xgboost_input(tf_test)
        pred = self.model.predict(dtest)
        return torch.tensor(pred, device=tf_test.device)

    def _load(self, path: str):
        """
        Load a pre-trained model from file.
        Args:
            path (str): Path to the saved model.
        """
        self.model = xgb.Booster(model_file=path)

    def _save(self, path: str):
        """
        Save the trained model to file.
        Args:
            path (str): Path to save the model.
        """
        if self.model:
            self.model.save_model(path)
        else:
            raise ValueError("No model has been trained.")
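
The `early_stopping_rounds=50` argument passed to `xgb.train` in `_tune` stops boosting once the validation metric has not improved for 50 consecutive rounds. The rule itself can be sketched in a few lines of plain Python. This is only an illustration of the idea, not XGBoost's actual implementation; the helper `best_round_with_patience` and the toy AUC values are hypothetical.

```python
def best_round_with_patience(scores, patience, maximize=True):
    """Return (best_round, stop_round) for a sequence of validation scores.

    Training stops at the first round where `patience` rounds have passed
    without improving on the best score seen so far.
    """
    best_round, best_score = 0, scores[0]
    for rnd, score in enumerate(scores):
        improved = score > best_score if maximize else score < best_score
        if rnd == 0 or improved:
            best_round, best_score = rnd, score
        elif rnd - best_round >= patience:
            return best_round, rnd  # stopped early
    return best_round, len(scores) - 1  # ran out of rounds


# Toy validation AUC curve (hypothetical values, not taken from the lab run).
val_auc = [0.80, 0.82, 0.83, 0.825, 0.828, 0.829, 0.84]
best, stopped = best_round_with_patience(val_auc, patience=3)
print(best, stopped)  # 2 5
```

Note that the late improvement at round 6 is never seen: a small patience saves time but can stop before a later gain.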

#### LightGBM

Developed by Microsoft, [Ke et al. (2017)](https://proceedings.neurips.cc/paper_files/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf).

* *Tree Growth Strategy:* **Leaf-wise**, selecting the leaf with the largest loss reduction to split. This produces trees that are deeper in parts of the feature space with high variance while remaining shallow in simpler regions. It captures complex, localized interactions better but risks overfitting, especially if the dataset is small or lacks regularization.
* *Split Finding:* **Histogram-based**, grouping feature values into discrete bins and finding splits based on these bins. This approach is much faster, particularly for high-dimensional or sparse data.
* *Categorical Feature Handling:* **Natively handles categorical features** by directly learning optimal splits for them, avoiding the need for preprocessing. This greatly reduces preprocessing effort, memory usage, and runtime for datasets with categorical features.

<img src="https://drive.google.com/uc?export=view&id=1uA4EdbDYvfYegsq1t85vnmtSFe1HpkHH" width="650">
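
To make the histogram-based split finding described above concrete, here is a minimal pure-Python sketch for a single feature. It is a simplification under stated assumptions: it uses equal-width bins and a variance-reduction gain computed from raw targets, whereas LightGBM builds histograms of gradient statistics. The function name `histogram_split` and the toy data are hypothetical.

```python
def histogram_split(x, y, num_bins=4):
    """Find the best threshold for one feature using a histogram of targets.

    Gain is the variance-reduction score S_L^2/n_L + S_R^2/n_R - S^2/n,
    evaluated only at bin boundaries instead of at every distinct value.
    """
    lo, hi = min(x), max(x)
    width = (hi - lo) / num_bins or 1.0  # fall back to 1.0 for constant features
    sums = [0.0] * num_bins
    counts = [0] * num_bins
    for xi, yi in zip(x, y):
        b = min(int((xi - lo) / width), num_bins - 1)  # clamp the max value
        sums[b] += yi
        counts[b] += 1

    total_sum, total_cnt = sum(sums), sum(counts)
    best_gain, best_threshold = 0.0, None
    left_sum, left_cnt = 0.0, 0
    for b in range(num_bins - 1):  # candidate split after bin b
        left_sum += sums[b]
        left_cnt += counts[b]
        right_sum, right_cnt = total_sum - left_sum, total_cnt - left_cnt
        if left_cnt == 0 or right_cnt == 0:
            continue
        gain = (left_sum**2 / left_cnt + right_sum**2 / right_cnt
                - total_sum**2 / total_cnt)
        if gain > best_gain:
            best_gain, best_threshold = gain, lo + (b + 1) * width
    return best_threshold, best_gain


x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0]
y = [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
threshold, gain = histogram_split(x, y, num_bins=4)
print(threshold, gain)  # 4.0 2.0
```

Only `num_bins - 1` candidate thresholds are scored, regardless of how many distinct values the feature has, which is where the speed-up comes from.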

In [None]:
#@title #### **LightGBM** class definition

from typing import Any

import numpy as np
import pandas as pd
import torch
from torch import Tensor

from torch_frame import Metric, TaskType, TensorFrame, stype
from torch_frame.gbdt import GBDT
import optuna
import lightgbm

class LightGBM(GBDT):
    r"""LightGBM implementation with hyper-parameter tuning using Optuna.

    This implementation extends GBDT and aims to find optimal hyperparameters
    by optimizing the given objective function.
    """
    def _to_lightgbm_input(
        self,
        tf: TensorFrame,
        ) -> tuple[pd.DataFrame, np.ndarray, list[str]]:
        r"""Convert :class:`TensorFrame` into LightGBM-compatible input format:
        :obj:`(feat, y, cat_features)`.

        Args:
            tf (TensorFrame): Input :class:`TensorFrame` object.

        Returns:
            df (DataFrame): :obj:`DataFrame` that concatenates tensors of
                numerical and categorical features of the input
                :class:`TensorFrame`.
            y (numpy.ndarray, optional): Prediction label.
            cat_features (list[int]): Array containing indexes of
                categorical features.
        """
        tf = tf.cpu()
        y = tf.y
        if y is not None:
            y: np.ndarray = y.numpy()

        dfs: list[pd.DataFrame] = []
        cat_features_list: list[np.ndarray] = []
        offset: int = 0

        if stype.categorical in tf.feat_dict:
            feat = tf.feat_dict[stype.categorical].numpy()
            arange = np.arange(offset, offset + feat.shape[1])
            dfs.append(pd.DataFrame(feat, columns=arange))
            cat_features_list.append(arange)
            offset += feat.shape[1]

        if stype.numerical in tf.feat_dict:
            feat = tf.feat_dict[stype.numerical].numpy()
            arange = np.arange(offset, offset + feat.shape[1])
            dfs.append(pd.DataFrame(feat, columns=arange))
            offset += feat.shape[1]

        if stype.embedding in tf.feat_dict:
            feat = tf.feat_dict[stype.embedding]
            feat = feat.values
            feat = feat.view(feat.size(0), -1).numpy()
            arange = np.arange(offset, offset + feat.shape[1])
            dfs.append(pd.DataFrame(feat, columns=arange))
            offset += feat.shape[1]

        # TODO Add support for other stypes.

        if len(dfs) == 0:
            raise ValueError("The input TensorFrame object is empty.")

        df = pd.concat(dfs, axis=1)
        cat_features: list[int] = np.concatenate(
            cat_features_list,
            axis=0).tolist() if len(cat_features_list) else []

        return df, y, cat_features

    def _predict_helper(
        self,
        model: Any,
        x: pd.DataFrame,
    ) -> np.ndarray:
        r"""A helper function that applies the lightgbm model on DataFrame
        :obj:`x`.

        Args:
            model (lightgbm.Booster): The lightgbm model.
            x (DataFrame): The input `DataFrame`.

        Returns:
            pred (numpy.ndarray): The prediction output.
        """
        pred = model.predict(x)
        if self.task_type == TaskType.MULTICLASS_CLASSIFICATION:
            pred = pred.argmax(axis=1)

        return pred

    def objective(
        self,
        trial: Any,  # optuna.trial.Trial
        train_data: Any,  # lightgbm.Dataset
        eval_data: Any,  # lightgbm.Dataset
        cat_features: list[int],
        num_boost_round: int,
    ) -> float:
        r"""Objective function to be optimized.

        Args:
            trial (optuna.trial.Trial): Optuna trial object.
            train_data (lightgbm.Dataset): Train data.
            eval_data (lightgbm.Dataset): Validation data.
            cat_features (list[int]): Array containing indexes of
                categorical features.
            num_boost_round (int): Number of boosting rounds.

        Returns:
            float: Validation score under the configured ``self.metric``
            (e.g., ROC-AUC for binary classification).
        """
        self.params = {
            "verbosity":
            -1,
            "bagging_freq":
            1,
            "max_depth":
            trial.suggest_int("max_depth", 3, 11),
            "learning_rate":
            trial.suggest_float("learning_rate", 1e-3, 0.1, log=True),
            "num_leaves":
            trial.suggest_int("num_leaves", 2, 2**10),
            "subsample":
            trial.suggest_float("subsample", 0.05, 1.0),
            "colsample_bytree":
            trial.suggest_float("colsample_bytree", 0.05, 1.0),
            'lambda_l1':
            trial.suggest_float('lambda_l1', 1e-9, 10.0, log=True),
            'lambda_l2':
            trial.suggest_float('lambda_l2', 1e-9, 10.0, log=True),
            "min_data_in_leaf":
            trial.suggest_int("min_data_in_leaf", 1, 100),
        }

        if self.task_type == TaskType.REGRESSION:
            if self.metric == Metric.RMSE:
                self.params["objective"] = "regression"
                self.params["metric"] = "rmse"
            elif self.metric == Metric.MAE:
                self.params["objective"] = "regression_l1"
                self.params["metric"] = "mae"
        elif self.task_type == TaskType.BINARY_CLASSIFICATION:
            self.params["objective"] = "binary"
            if self.metric == Metric.ROCAUC:
                self.params["metric"] = "auc"
            elif self.metric == Metric.ACCURACY:
                self.params["metric"] = "binary_error"
        elif self.task_type == TaskType.MULTICLASS_CLASSIFICATION:
            self.params["objective"] = "multiclass"
            self.params["metric"] = "multi_error"
            self.params["num_class"] = self._num_classes or len(
                np.unique(train_data.label))
        else:
            raise ValueError(f"{self.__class__.__name__} is not supported for "
                             f"{self.task_type}.")

        boost = lightgbm.train(
            self.params, train_data, num_boost_round=num_boost_round,
            categorical_feature=cat_features, valid_sets=[eval_data],
            callbacks=[
                lightgbm.early_stopping(stopping_rounds=50, verbose=False),
                lightgbm.log_evaluation(period=2000)
            ])
        pred = self._predict_helper(boost, eval_data.data)
        score = self.compute_metric(torch.from_numpy(eval_data.label),
                                    torch.from_numpy(pred))
        return score

    def _tune(
        self,
        tf_train: TensorFrame,
        tf_val: TensorFrame,
        num_trials: int,
        num_boost_round=2000,
    ):
        if self.task_type == TaskType.REGRESSION:
            study = optuna.create_study(direction="minimize")
        else:
            study = optuna.create_study(direction="maximize")

        train_x, train_y, cat_features = self._to_lightgbm_input(tf_train)
        val_x, val_y, _ = self._to_lightgbm_input(tf_val)
        assert train_y is not None
        assert val_y is not None
        train_data = lightgbm.Dataset(train_x, label=train_y,
                                      free_raw_data=False)
        eval_data = lightgbm.Dataset(val_x, label=val_y, free_raw_data=False)

        study.optimize(
            lambda trial: self.objective(trial, train_data, eval_data,
                                         cat_features, num_boost_round),
            num_trials)
        self.params.update(study.best_params)

        self.model = lightgbm.train(
            self.params, train_data, num_boost_round=num_boost_round,
            categorical_feature=cat_features, valid_sets=[eval_data],
            callbacks=[
                lightgbm.early_stopping(stopping_rounds=50, verbose=False),
                lightgbm.log_evaluation(period=2000)
            ])

    def _predict(self, tf_test: TensorFrame) -> Tensor:
        device = tf_test.device
        test_x, _, _ = self._to_lightgbm_input(tf_test)
        pred = self._predict_helper(self.model, test_x)
        return torch.from_numpy(pred).to(device)

    def _load(self, path: str) -> None:
        self.model = lightgbm.Booster(model_file=path)
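
In `objective`, `learning_rate`, `lambda_l1`, and `lambda_l2` are suggested with `log=True`, so the search operates on a logarithmic scale. The sketch below illustrates what a log-scaled domain means using plain random sampling. It is an assumption-laden illustration only: Optuna's default sampler is not purely random, and `sample_log_uniform` is a hypothetical helper, not an Optuna function.

```python
import math
import random


def sample_log_uniform(rng, low, high):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_log_uniform(rng, 1e-3, 0.1) for _ in range(10_000)]

# Roughly half of the draws fall below the geometric midpoint (0.01),
# whereas a plain uniform draw on [1e-3, 0.1] would put only about 9% there.
below_mid = sum(s < 1e-2 for s in samples) / len(samples)
print(f"fraction below 0.01: {below_mid:.2f}")
```

This is why a log scale suits parameters spanning several orders of magnitude, such as the regularization strengths ranging from `1e-9` to `10.0`.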

#### Training and Evaluation

In [None]:
train_dataset = torch_frame.data.Dataset(
    df=dfs["train"],
    col_to_stype=col_to_stype,
    target_col=target_col_name,
    col_to_text_embedder_cfg=TextEmbedderConfig(
        text_embedder=GloveTextEmbedding(device=device),
        batch_size=512,
    ),
)
train_dataset = train_dataset.materialize(device=device)

Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 36.22it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 79.33it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 38.43it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 85.68it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 66.25it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 80.31it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 81.22it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 76.63it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 71.81it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 64.75it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 75.45it/s]
Embedding raw data in mini-batch: 100%|██████████| 14/14 [00:00<00:00, 74.00it/s]


We can now analyze the dataset and check out the newly introduced features.

In [None]:
train_dataset.df.head()

| | timestamp | customer_id | article_id | FN | Active | club_member_status | fashion_news_frequency | age | postal_code | link_pred_baseline_target_column_name | ... | index_name | index_group_no | index_group_name | section_no | section_name | garment_group_no | garment_group_name | detail_desc | num_past_visit | global_popularity_fraction |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2020-07-20 | 1277347 | 68263 | | | ACTIVE | NONE | 18.0 | 0584da44f04a07d5e88012ad8243f2602432f0642097a2... | 1 | ... | Ladieswear | 1 | Ladieswear | 15 | Womens Everyday Collection | 1009 | Trousers | Trousers in woven fabric. High waist with plea... | 1 | 0.555556 |
| 1 | 2019-12-02 | 705424 | 73070 | 1.0 | 1.0 | ACTIVE | Regularly | 41.0 | d9095bf603e105fb56b80dd6a05652110b255dfb1f46f4... | 1 | ... | Lingeries/Tights | 1 | Ladieswear | 62 | Womens Nightwear, Socks & Tigh | 1017 | Under-, Nightwear | Knee-length dressing gown in soft pile with a ... | 1 | 0.111111 |
| 2 | 2020-07-20 | 208528 | 98623 | | | ACTIVE | NONE | 24.0 | f35eaf6b194d1849fd777598f86c069974f0c65ea15c00... | 1 | ... | Divided | 2 | Divided | 53 | Divided Collection | 1005 | Jersey Fancy | Cropped top in ribbed jersey with short sleeve... | 1 | 0.111111 |
| 3 | 2020-07-20 | 208528 | 84867 | | | ACTIVE | NONE | 24.0 | f35eaf6b194d1849fd777598f86c069974f0c65ea15c00... | 1 | ... | Ladieswear | 1 | Ladieswear | 15 | Womens Everyday Collection | 1010 | Blouses | Short-sleeved wrapover blouse in linen with a ... | 1 | 0.222222 |
| 4 | 2020-07-20 | 208528 | 62642 | | | ACTIVE | NONE | 24.0 | f35eaf6b194d1849fd777598f86c069974f0c65ea15c00... | 1 | ... | Ladieswear | 1 | Ladieswear | 16 | Womens Everyday Basics | 1002 | Jersey Basic | Tops in soft, organic cotton jersey with narro... | 1 | 0.111111 |


Before training, we need to transform tables into tensors.

In [None]:
tf_train = train_dataset.tensor_frame
tf_val = train_dataset.convert_to_tensor_frame(dfs["val"])
tf_val_pred = train_dataset.convert_to_tensor_frame(dfs["val_pred"])
tf_test = train_dataset.convert_to_tensor_frame(dfs["test"])

Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 67.19it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 69.47it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 42.37it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 72.63it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 72.81it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 79.26it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 78.32it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 77.76it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 69.67it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 65.05it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 78.11it/s]
Embedding raw data in mini-batch: 100%|██████████| 13/13 [00:00<00:00, 75.26it/s]

In [None]:
NUM_TRIALS = 10

# Train LightGBM
tune_metric = Metric.ROCAUC
lightgbm_model = LightGBM(task_type=train_dataset.task_type,
                          metric=tune_metric)

In [None]:
lightgbm_model.tune(tf_train=tf_train,
                    tf_val=tf_val,
                    num_trials=NUM_TRIALS)

[I 2024-12-10 22:48:16,562] A new study created in memory with name: no-name-a34e7aa9-49e4-4e5d-bd0f-62e05e1e0143
[I 2024-12-10 22:48:50,406] Trial 0 finished with value: 0.8530609460390436 and parameters: {'max_depth': 10, 'learning_rate': 0.048499447167924634, 'num_leaves': 505, 'subsample': 0.6754336171443361, 'colsample_bytree': 0.8118279295826983, 'lambda_l1': 3.116058485973789e-06, 'lambda_l2': 0.0003298669825740467, 'min_data_in_leaf': 86}. Best is trial 0 with value: 0.8530609460390436.
[I 2024-12-10 22:49:33,688] Trial 1 finished with value: 0.8428766292224387 and parameters: {'max_depth': 10, 'learning_rate': 0.06270376970387177, 'num_leaves': 655, 'subsample': 0.6128006014420998, 'colsample_bytree': 0.2122855353422044, 'lambda_l1': 6.293166467265311e-05, 'lambda_l2': 8.522033681757318, 'min_data_in_leaf': 4}. Best is trial 0 with value: 0.8530609460390436.
[I 2024-12-10 22:49:50,679] Trial 2 finished with value: 0.8404196496274747 and parameters: {'max_depth': 5, 'learning_r

In [None]:
# Train XGBoost
tune_metric = Metric.ROCAUC
xgboost_model = XGBoost(task_type=train_dataset.task_type,
                        metric=tune_metric)

In [None]:
xgboost_model.tune(tf_train=tf_train,
                   tf_val=tf_val,
                   num_trials=NUM_TRIALS)

[I 2024-12-10 23:11:19,183] A new study created in memory with name: no-name-c965d43f-d8f3-4748-9676-5705938697c6


[0]	validation-auc:0.82086
[1]	validation-auc:0.82854
[2]	validation-auc:0.83051
[3]	validation-auc:0.83089
[4]	validation-auc:0.83187
[5]	validation-auc:0.83353
[6]	validation-auc:0.83529
[7]	validation-auc:0.83518
[8]	validation-auc:0.83536
[9]	validation-auc:0.83547
[10]	validation-auc:0.84004
[11]	validation-auc:0.84095
[12]	validation-auc:0.84242
[13]	validation-auc:0.84369
[14]	validation-auc:0.83749
[15]	validation-auc:0.83949
[16]	validation-auc:0.84010
[17]	validation-auc:0.84022
[18]	validation-auc:0.84138
[19]	validation-auc:0.83780
[20]	validation-auc:0.83899
[21]	validation-auc:0.84017
[22]	validation-auc:0.84097
[23]	validation-auc:0.83908
[24]	validation-auc:0.83944
[25]	validation-auc:0.84050
[26]	validation-auc:0.84149
[27]	validation-auc:0.84274
[28]	validation-auc:0.84336
[29]	validation-auc:0.84436
[30]	validation-auc:0.84449
[31]	validation-auc:0.84240
[32]	validation-auc:0.84344
[33]	validation-auc:0.84397
[34]	validation-auc:0.84463
[35]	validation-auc:0.84524
[3

[I 2024-12-10 23:11:21,401] Trial 0 finished with value: 0.8472706061796367 and parameters: {'max_depth': 3, 'learning_rate': 0.055745809595654096, 'subsample': 0.90813823473549, 'colsample_bytree': 0.8011343972341652, 'lambda': 3.2728397828723335e-08, 'alpha': 6.846296738400049e-06}. Best is trial 0 with value: 0.8472706061796367.


[0]	validation-auc:0.80652
[1]	validation-auc:0.82711
[2]	validation-auc:0.82721
[3]	validation-auc:0.83644
[4]	validation-auc:0.84271
[5]	validation-auc:0.83794
[6]	validation-auc:0.83079
[7]	validation-auc:0.83011
[8]	validation-auc:0.83022
[9]	validation-auc:0.82180
[10]	validation-auc:0.82161
[11]	validation-auc:0.82204
[12]	validation-auc:0.82418
[13]	validation-auc:0.82308
[14]	validation-auc:0.82684
[15]	validation-auc:0.82665
[16]	validation-auc:0.82700
[17]	validation-auc:0.82818
[18]	validation-auc:0.82958
[19]	validation-auc:0.82999
[20]	validation-auc:0.83082
[21]	validation-auc:0.83368
[22]	validation-auc:0.83264
[23]	validation-auc:0.83340
[24]	validation-auc:0.83332
[25]	validation-auc:0.83415
[26]	validation-auc:0.83348
[27]	validation-auc:0.83496
[28]	validation-auc:0.83440
[29]	validation-auc:0.83462
[30]	validation-auc:0.83506
[31]	validation-auc:0.83412
[32]	validation-auc:0.83554
[33]	validation-auc:0.83469
[34]	validation-auc:0.83533
[35]	validation-auc:0.83551
[3

[I 2024-12-10 23:11:22,918] Trial 1 finished with value: 0.8338651773581507 and parameters: {'max_depth': 5, 'learning_rate': 0.2474040447483419, 'subsample': 0.7905196664933589, 'colsample_bytree': 0.803132179398331, 'lambda': 8.73242667842769e-07, 'alpha': 4.911540726451648e-06}. Best is trial 0 with value: 0.8472706061796367.


[0]	validation-auc:0.82597
[1]	validation-auc:0.83771
[2]	validation-auc:0.82338
[3]	validation-auc:0.82918
[4]	validation-auc:0.82432
[5]	validation-auc:0.82094
[6]	validation-auc:0.81998
[7]	validation-auc:0.81888
[8]	validation-auc:0.82122
[9]	validation-auc:0.82363
[10]	validation-auc:0.82520
[11]	validation-auc:0.82651
[12]	validation-auc:0.82546
[13]	validation-auc:0.82733
[14]	validation-auc:0.82632
[15]	validation-auc:0.82895
[16]	validation-auc:0.82778
[17]	validation-auc:0.82852
[18]	validation-auc:0.82967
[19]	validation-auc:0.83003
[20]	validation-auc:0.83012
[21]	validation-auc:0.83075
[22]	validation-auc:0.83150
[23]	validation-auc:0.83166
[24]	validation-auc:0.83280
[25]	validation-auc:0.83353
[26]	validation-auc:0.83387
[27]	validation-auc:0.83564
[28]	validation-auc:0.83688
[29]	validation-auc:0.83558
[30]	validation-auc:0.83641
[31]	validation-auc:0.83537
[32]	validation-auc:0.83645
[33]	validation-auc:0.83722
[34]	validation-auc:0.83826
[35]	validation-auc:0.83788
[3

[I 2024-12-10 23:11:27,734] Trial 2 finished with value: 0.8524818666635923 and parameters: {'max_depth': 4, 'learning_rate': 0.030079090812104477, 'subsample': 0.8973097478044262, 'colsample_bytree': 0.8333003438681112, 'lambda': 7.238752402364147e-05, 'alpha': 1.1854026063298904}. Best is trial 2 with value: 0.8524818666635923.


[0]	validation-auc:0.72048
[1]	validation-auc:0.79816
[2]	validation-auc:0.81724
[3]	validation-auc:0.82673
[4]	validation-auc:0.83316
[5]	validation-auc:0.83096
[6]	validation-auc:0.83247
[7]	validation-auc:0.83438
[8]	validation-auc:0.83563
[9]	validation-auc:0.83356
[10]	validation-auc:0.83534
[11]	validation-auc:0.83639
[12]	validation-auc:0.83670
[13]	validation-auc:0.83709
[14]	validation-auc:0.83798
[15]	validation-auc:0.83805
[16]	validation-auc:0.83693
[17]	validation-auc:0.83691
[18]	validation-auc:0.83812
[19]	validation-auc:0.83871
[20]	validation-auc:0.83826
[21]	validation-auc:0.83853
[22]	validation-auc:0.83955
[23]	validation-auc:0.83963
[24]	validation-auc:0.84030
[25]	validation-auc:0.84065
[26]	validation-auc:0.84064
[27]	validation-auc:0.84013
[28]	validation-auc:0.83999
[29]	validation-auc:0.84070
[30]	validation-auc:0.84125
[31]	validation-auc:0.84127
[32]	validation-auc:0.84115
[33]	validation-auc:0.84148
[34]	validation-auc:0.84141
[35]	validation-auc:0.84197
[3

[I 2024-12-10 23:11:50,934] Trial 3 finished with value: 0.8432616684403373 and parameters: {'max_depth': 10, 'learning_rate': 0.03441864631825322, 'subsample': 0.8120883049488828, 'colsample_bytree': 0.6763690718113979, 'lambda': 2.556807890506349e-05, 'alpha': 3.2335517948689755e-09}. Best is trial 2 with value: 0.8524818666635923.


[0]	validation-auc:0.70272
[1]	validation-auc:0.78030
[2]	validation-auc:0.80218
[3]	validation-auc:0.81739
[4]	validation-auc:0.81528
[5]	validation-auc:0.82142
[6]	validation-auc:0.82357
[7]	validation-auc:0.82562
[8]	validation-auc:0.82692
[9]	validation-auc:0.82525
[10]	validation-auc:0.82771
[11]	validation-auc:0.82686
[12]	validation-auc:0.82882
[13]	validation-auc:0.82975
[14]	validation-auc:0.83110
[15]	validation-auc:0.83359
[16]	validation-auc:0.83510
[17]	validation-auc:0.83386
[18]	validation-auc:0.83534
[19]	validation-auc:0.83783
[20]	validation-auc:0.83861
[21]	validation-auc:0.83909
[22]	validation-auc:0.84024
[23]	validation-auc:0.83885
[24]	validation-auc:0.83836
[25]	validation-auc:0.83873
[26]	validation-auc:0.83856
[27]	validation-auc:0.83875
[28]	validation-auc:0.83759
[29]	validation-auc:0.83834
[30]	validation-auc:0.83721
[31]	validation-auc:0.83798
[32]	validation-auc:0.83773
[33]	validation-auc:0.83628
[34]	validation-auc:0.83658
[35]	validation-auc:0.83662
[3

[I 2024-12-10 23:12:00,529] Trial 4 finished with value: 0.8380256252836277 and parameters: {'max_depth': 10, 'learning_rate': 0.14532446284305878, 'subsample': 0.671630802841204, 'colsample_bytree': 0.9815616562720384, 'lambda': 6.000402459889521e-08, 'alpha': 1.8884356627529005e-06}. Best is trial 2 with value: 0.8524818666635923.


[0]	validation-auc:0.79575


[I 2024-12-10 23:12:00,561] Trial 5 pruned. Trial was pruned at iteration 0.
[I 2024-12-10 23:12:00,671] Trial 6 pruned. Trial was pruned at iteration 0.
[I 2024-12-10 23:12:00,736] Trial 7 pruned. Trial was pruned at iteration 0.


[0]	validation-auc:0.82550
[1]	validation-auc:0.83200
[2]	validation-auc:0.83174
[3]	validation-auc:0.83338
[4]	validation-auc:0.83577
[5]	validation-auc:0.83573
[6]	validation-auc:0.84034
[7]	validation-auc:0.84068
[8]	validation-auc:0.84478
[9]	validation-auc:0.83657
[10]	validation-auc:0.83951
[11]	validation-auc:0.83868
[12]	validation-auc:0.84059
[13]	validation-auc:0.84244
[14]	validation-auc:0.83997
[15]	validation-auc:0.83731
[16]	validation-auc:0.83808
[17]	validation-auc:0.83938
[18]	validation-auc:0.84090
[19]	validation-auc:0.84114
[20]	validation-auc:0.84162
[21]	validation-auc:0.83980
[22]	validation-auc:0.83920
[23]	validation-auc:0.83904
[24]	validation-auc:0.84010
[25]	validation-auc:0.84100
[26]	validation-auc:0.84196
[27]	validation-auc:0.84174
[28]	validation-auc:0.84230
[29]	validation-auc:0.84437
[30]	validation-auc:0.84245
[31]	validation-auc:0.84308
[32]	validation-auc:0.84404
[33]	validation-auc:0.84424
[34]	validation-auc:0.84509
[35]	validation-auc:0.84596
[3

[I 2024-12-10 23:12:02,591] Trial 8 finished with value: 0.8509406658332547 and parameters: {'max_depth': 3, 'learning_rate': 0.0933620125770937, 'subsample': 0.637933556195867, 'colsample_bytree': 0.6912506175835335, 'lambda': 0.0024297405841102702, 'alpha': 0.0001113233561046967}. Best is trial 2 with value: 0.8524818666635923.


[0]	validation-auc:0.73217


[I 2024-12-10 23:12:02,691] Trial 9 pruned. Trial was pruned at iteration 0.


[0]	validation-auc:0.82597
[1]	validation-auc:0.83771
[2]	validation-auc:0.82338
[3]	validation-auc:0.82918
[4]	validation-auc:0.82432
[5]	validation-auc:0.82094
[6]	validation-auc:0.81998
[7]	validation-auc:0.81888
[8]	validation-auc:0.82122
[9]	validation-auc:0.82363
[10]	validation-auc:0.82520
[11]	validation-auc:0.82651
[12]	validation-auc:0.82546
[13]	validation-auc:0.82733
[14]	validation-auc:0.82632
[15]	validation-auc:0.82895
[16]	validation-auc:0.82778
[17]	validation-auc:0.82852
[18]	validation-auc:0.82967
[19]	validation-auc:0.83003
[20]	validation-auc:0.83012
[21]	validation-auc:0.83075
[22]	validation-auc:0.83150
[23]	validation-auc:0.83166
[24]	validation-auc:0.83280
[25]	validation-auc:0.83353
[26]	validation-auc:0.83387
[27]	validation-auc:0.83564
[28]	validation-auc:0.83688
[29]	validation-auc:0.83558
[30]	validation-auc:0.83641
[31]	validation-auc:0.83537
[32]	validation-auc:0.83645
[33]	validation-auc:0.83722
[34]	validation-auc:0.83826
[35]	validation-auc:0.83788
[3

... or you can load our pre-trained weights.

In [None]:
import gdown

lightgbm_checkpoint = "/content/relbench_hm_lightgmb_checkpoint"
xgboost_checkpoint = "/content/relbench_hm_xgboost_checkpoint"

checkpoint_url = "https://drive.google.com/uc?id=1WaZTfw-Ni_oCqZ3L4DupTbTo6xbSzvCk"
gdown.download(checkpoint_url, lightgbm_checkpoint, quiet=False)
lightgbm_model.load(lightgbm_checkpoint)


checkpoint_url = "https://drive.google.com/uc?id=1NXAyzlIv_DLcj_qsXOZT8149MwNeHAB1"
gdown.download(checkpoint_url, xgboost_checkpoint, quiet=False)
xgboost_model.load(xgboost_checkpoint)

Downloading...
From: https://drive.google.com/uc?id=1WaZTfw-Ni_oCqZ3L4DupTbTo6xbSzvCk
To: /content/relbench_hm_lightgmb_checkpoint
100%|██████████| 3.39M/3.39M [00:00<00:00, 178MB/s]
Downloading...
From: https://drive.google.com/uc?id=1NXAyzlIv_DLcj_qsXOZT8149MwNeHAB1
To: /content/relbench_hm_xgboost_checkpoint
100%|██████████| 1.29M/1.29M [00:00<00:00, 110MB/s]


In [None]:
def evaluate(
    model_output: pd.DataFrame,
    src_entity_name: str,
    dst_entity_name: str,
    timestamp_col_name: str,
    eval_k: int,
    pred_score: str,
    train_table: Table,
    task: RecommendationTask,
) -> Dict[str, float]:
    """
    Evaluates the model predictions by computing recommendation metrics.

    Args:
        model_output (pd.DataFrame): DataFrame containing model predictions.
        src_entity_name (str): Name of the source entity column.
        dst_entity_name (str): Name of the destination entity column.
        timestamp_col_name (str): Name of the timestamp column.
        eval_k (int): Number of top-k predictions to evaluate.
        pred_score (str): Name of the prediction score column.
        train_table (Table): Training table to merge for past dst entities.
        task (RecommendationTask): Task containing evaluation logic.

    Returns:
        Dict[str, float]: Computed metrics for the model.
    """
    def adjust_past_dst_entities(values):
        if len(values) < eval_k:
            return values + [-1] * (eval_k - len(values))
        else:
            return values[:eval_k]

    grouped_df = (
        model_output.sort_values(pred_score, ascending=False)
        .groupby([src_entity_name, timestamp_col_name])[dst_entity_name]
        .apply(list)
        .reset_index()
    )
    grouped_df = train_table.df[[src_entity_name, timestamp_col_name]].merge(
        grouped_df, on=[src_entity_name, timestamp_col_name], how="left"
    )

    dst_entity_array = (
        grouped_df[dst_entity_name].apply(adjust_past_dst_entities).tolist()
    )
    dst_entity_array = np.array(dst_entity_array, dtype=int)
    metrics = task.evaluate(dst_entity_array, train_table)
    return metrics


def evaluate_model(model_name, model, splits, pred_col_name, task, eval_k):
    """
    Evaluates a model on multiple splits and prints the results.

    Args:
        model_name (str): Name of the model (e.g., 'LightGBM', 'XGBoost').
        model: Trained model to evaluate.
        splits (dict): Dictionary containing split data (train, val, test).
        pred_col_name (str): Name of the prediction score column.
        task (RecommendationTask): Task containing evaluation logic.
        eval_k (int): Number of top-k predictions to evaluate.
    """
    print(f"\n{'='*10} Evaluating {model_name} {'='*10}\n")

    metrics_results = {}

    for split_name, (tf_data, df_data, table) in splits.items():
        pred = model.predict(tf_test=tf_data).cpu().numpy()
        df_data[pred_col_name] = pred

        metrics = evaluate(
            df_data,
            src_entity,
            dst_entity,
            sampled_train_table.time_col,
            eval_k,
            pred_col_name,
            table,
            task,
        )

        metrics_results[split_name] = metrics
        print(f"{split_name.capitalize()} Metrics: {json.dumps(metrics, indent=2)}")

    print(f"\n{'='*10} {model_name} Evaluation Complete {'='*10}\n")
    return metrics_results


# Prepare split information
splits = {
    "train": (tf_train, dfs["train"], sampled_train_table),
    "val": (tf_val_pred, dfs["val_pred"], sampled_val_table),
    "test": (tf_test, dfs["test"], sampled_test_table),
}

# Evaluate both models
lightgbm_metrics = evaluate_model(
    model_name="LightGBM",
    model=lightgbm_model,
    splits=splits,
    pred_col_name=PRED_SCORE_COL_NAME,
    task=task,
    eval_k=task.eval_k,
)

xgboost_metrics = evaluate_model(
    model_name="XGBoost",
    model=xgboost_model,
    splits=splits,
    pred_col_name=PRED_SCORE_COL_NAME,
    task=task,
    eval_k=task.eval_k,
)



Train Metrics: {
  "link_prediction_precision": 0.26075,
  "link_prediction_recall": 0.9848450240702075,
  "link_prediction_map": 0.9085766007620526
}
Val Metrics: {
  "link_prediction_precision": 0.00125,
  "link_prediction_recall": 0.006019444444444444,
  "link_prediction_map": 0.0016770370370370368
}
Test Metrics: {
  "link_prediction_precision": 0.003083333333333333,
  "link_prediction_recall": 0.011755952380952379,
  "link_prediction_map": 0.003882435966810967
}




Train Metrics: {
  "link_prediction_precision": 0.262,
  "link_prediction_recall": 0.9861884777783351,
  "link_prediction_map": 0.9190281326405677
}
Val Metrics: {
  "link_prediction_precision": 0.002083333333333333,
  "link_prediction_recall": 0.009221464646464646,
  "link_prediction_map": 0.0033639911014911013
}
Test Metrics: {
  "link_prediction_precision": 0.003333333333333333,
  "link_prediction_recall": 0.011622619047619047,
  "link_prediction_map": 0.002937766955266955
}
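
To help read the numbers above, the following sketch shows how top-k lists are built and padded (as `adjust_past_dst_entities` does in `evaluate`) and how precision@k, recall@k, and average precision@k can be computed for one user. It uses a common textbook definition of these metrics; the exact normalization inside `task.evaluate` may differ, and the helper names and toy data are hypothetical.

```python
def top_k_padded(scored_items, k, pad=-1):
    """Sort (item, score) pairs by descending score, keep k, pad with `pad`."""
    ranked = [item for item, _ in sorted(scored_items, key=lambda p: -p[1])]
    return (ranked + [pad] * k)[:k]


def precision_recall_ap_at_k(predicted, relevant, k):
    """Precision@k, recall@k and average precision@k for one user."""
    hits, precision_sum = 0, 0.0
    for rank, item in enumerate(predicted[:k], start=1):
        if item in relevant:
            hits += 1
            precision_sum += hits / rank  # precision at each hit position
    precision = hits / k
    recall = hits / len(relevant)
    ap = precision_sum / min(len(relevant), k)
    return precision, recall, ap


# Toy example: 3 scored candidates, 2 ground-truth purchases, k = 4.
scored = [(10, 0.2), (11, 0.9), (12, 0.5)]
predicted = top_k_padded(scored, k=4)
p, r, ap = precision_recall_ap_at_k(predicted, relevant={11, 10}, k=4)
print(predicted, p, r, round(ap, 4))  # [11, 12, 10, -1] 0.5 1.0 0.8333
```

The toy case also suggests one plausible reason why precision can stay low while recall is close to 1, as in the train metrics above: when a user has fewer ground-truth items than `k`, precision is capped below 1 even if every relevant item is retrieved.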




### Transformer-Based Encoders

[Do Large Language Models make accurate personalized recommendations? (Kumo.ai)](https://kumo.ai/resources/blog/improving-predictions-with-large-language-models/)

Five researchers at Kumo.ai have compared the efficacy of (unsupervised) state-of-the-art transformer-based encoders and (supervised) GNNs on the `rel-hm` recommendation task.

In both cases, recommendations are produced by taking the embedding of a given customer and identifying 12 products whose embeddings have the highest cosine similarities with the customer embedding.

* **Transformer-based encoders**.
   * First, they textify products by concatenating all the product information into a single long sentence.
      * E.g., `“Product name: <name>. Product description: <description>. Color: <color>. Material: <material>, …”`
   * Then, they embed them with OpenAI `text-embedding-3-large`.
   * User embeddings are computed as the average of purchased product embeddings.

* **GNNs**.
   * Kumo's GNNs are used to embed users and products.
   * They try different text encoders for feature initialization.
      * GloVe (average word embeddings).
      * `intfloat/e5-base-v2`.
      * `text-embedding-3-large` (1024 output dimensions).
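
The transformer-based pipeline described above (textify products, embed them, average the purchased-product embeddings per user, rank by cosine similarity) can be sketched without calling any embedding API. This is a minimal illustration under stated assumptions: the 2-d vectors below are hypothetical stand-ins for real text-encoder outputs, and the helper names are not taken from the Kumo.ai blog post.

```python
import math


def textify(product):
    """Concatenate product attributes into a single sentence."""
    return " ".join(f"{key}: {value}." for key, value in product.items())


def mean_vector(vectors):
    """Element-wise average of equally sized vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def recommend(user_vec, item_vecs, k):
    """Return the ids of the k items most cosine-similar to the user."""
    ranked = sorted(item_vecs, key=lambda i: cosine(user_vec, item_vecs[i]),
                    reverse=True)
    return ranked[:k]


print(textify({"Product name": "Linen blouse", "Color": "White"}))

# Hypothetical 2-d embeddings standing in for a real text encoder's output.
item_vecs = {"blouse": [1.0, 0.0], "shirt": [0.9, 0.1], "boots": [0.0, 1.0]}
# User embedding = average of the embeddings of purchased products.
user_vec = mean_vector([item_vecs["blouse"], item_vecs["shirt"]])
print(recommend(user_vec, item_vecs, k=2))  # ['blouse', 'shirt']
```

Because the user vector is just an average of item vectors, no supervision signal from future purchases enters the representation, which matches the "unsupervised" label used for this baseline.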

<br>
<br>

| Method                          | MAP@12     | PRECISION@12 | RECALL@12   | F1@12      |
|---------------------------------|------------|--------------|-------------|------------|
| **LLM-only:** <br> OpenAI text-embedding-3-large | 0.00190 (-93.33%) | 0.00329 (-67.84%) | 0.00119 (-97.73%) | 0.0071 (-54.60%) |
| **Kumo-GNN-only:** <br> uses GloVe for text embeddings | 0.02856           | 0.01023          | 0.05234          | 0.01564          |
| **Kumo-GNN+HuggingFace:** <br> uses intfloat/e5-base-v2 for text embeddings | 0.0297 (+4.00%)  | 0.01099 (+7.43%) | 0.05531 (+5.67%) | 0.01673 (+6.97%) |
| **Kumo-GNN+OpenAI:** <br> uses text-embedding-3-large for text embeddings | 0.02976 (+4.20%) | 0.01139 (+11.34%)| 0.0567 (+8.33%)  | 0.0173 (+10.61%) |

<br>
<br>

<img src="https://drive.google.com/uc?export=view&id=18sNn5NT2jqpmjcLp0iGtD5do3Ilv2-9Z" width="650">

## 🤖 ContextGNN

<img src="https://drive.google.com/uc?export=view&id=1mvITYIirr4enLM5NJlYBkjiIthN61n3L" width="650">

ContextGNN is a recent GNN architecture proposed by Kumo.AI [[Yuan et al., 2024](https://arxiv.org/pdf/2411.19513)] for link prediction in recommendation systems.

### Problem

Traditionally, recommendation systems are modeled via different variants of a **two-tower paradigm**, where one tower embeds users and the other tower embeds items, which are then matched and ranked via an inner-product decoder.

This scheme is highly efficient for scaling up recommendation systems at inference time, as user and item representations can be pre-computed and the final ranking can be performed via fast (approximate) maximum inner product search.

However, one key limitation of two-tower based architectures for recommendation is that they learn a **pair-agnostic representation for users and items**.
* That is, the user representation is not aware of the item under consideration and, likewise, the item representation is not aware of the user, so it cannot capture that user's particular view of the item. As such, **neither representation captures knowledge about the pair-wise dependency for which the prediction is made**.
* For example, consider a user who restocks their cosmetic products on a regular basis. In this scenario, the fine-grained context of user-cosmetic pairs is crucial, and it cannot be adequately captured by two independent user and item representations alone. This lack of knowledge has severe consequences for the quality of predictions since, **e.g., the model is unable to distinguish between familiar purchases (i.e., users who repeatedly interact with a similar set of items) and exploratory purchases (i.e., users who like to explore new items).**
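
To make the pair-agnostic limitation concrete, here is a toy two-tower decoder in plain Python. The item names, the users `alice` and `bob`, and all vectors are hypothetical. The point is only that the score of a pair depends on two independently computed vectors, so two users whose tower outputs coincide receive the same ranking, whatever their individual history with each item.

```python
from typing import Dict, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def two_tower_rank(user_vec: Vector, item_vecs: Dict[str, Vector], k: int) -> List[str]:
    """Inner-product decoder: item vectors are pre-computed once and shared by all users."""
    scores = {item: dot(user_vec, vec) for item, vec in item_vecs.items()}
    return sorted(scores, key=scores.get, reverse=True)[:k]


# Item tower output, computed without knowing which user will be scored.
item_vecs = {"shampoo": [0.9, 0.1], "lipstick": [0.7, 0.6], "jacket": [0.1, 0.9]}

alice = [1.0, 0.2]  # imagine she restocks shampoo every month
bob = [1.0, 0.2]    # imagine he never bought shampoo, but his user vector happens to coincide

alice_top = two_tower_rank(alice, item_vecs, k=2)
bob_top = two_tower_rank(bob, item_vecs, k=2)
print(alice_top, bob_top)
```

A pair-wise model could score `("alice", "shampoo")` differently from `("bob", "shampoo")` because it sees the specific user-item context, which is what the method presented next exploits.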

### Proposed Method

ContextGNN fuses both pair-wise representations and two-tower representations into a single architecture, enabling GNN-based recommendation systems to capture both repeated patterns and exploratory user preferences.

It uses two separate models that share the same GNN backbone.

* First, pair-wise representations are learned within a user-centric subgraph using a bidirectional GNN, enabling the model to capture fine-grained local interaction signals such as repeat purchases and collaborative filtering. The root user node in the subgraph is augmented with a special ID embedding. This embedding explicitly identifies the node as the "seed" (the focal user), distinguishing it from other nodes in the subgraph.
* Second, a two-tower model with shallow item embeddings complements these local scores by handling exploratory recommendations and items outside the user subgraph.

Finally, a user-specific fusion score, produced by an MLP on the user GNN representation, is added to the scores of the pair-wise representation model. It captures how exploratory a specific user is.

#### Theoretical Formulation

Context-based Graph Neural Network (ContextGNN) introduces a novel approach to information modeling, addressing the expressiveness and efficiency limitations of the previously discussed GraphSAGE and ID-GNN methodologies. This is achieved by combining two distinct representation models to generate more meaningful data embeddings.

The first model leverages pairwise representations derived from the item candidate set within the local user-centric subgraph. In contrast, the second model employs a two-tower architecture built on shallow item representations, enabling the prediction of rankings for all user-item pairs beyond the user's immediate subgraph.

To integrate these models, a user-specific fusion score is computed using an MLP applied to the user's GNN representation. This score is combined with the outputs of the pairwise representation model, aligning the distinct scoring mechanisms of the two models. It also accounts for the user's exploratory behavior, assigning greater or lesser importance to each model as needed.

Let's define the temporal recommendation problem on a **heterogeneous graph snapshot** $\mathcal{G}^{(-\infty, T]} = (\mathcal{V}, \mathcal{E}, \phi, \psi)$ up to timestamp $T$. Here:

- $\mathcal{V}$ represents the set of nodes.
- $\mathcal{E} \subseteq \mathcal{V} \times \mathcal{V}$ denotes the set of edges.
- Each node $v \in \mathcal{V}$ belongs to a **node type** $\phi(v)$.
- Each edge $e \in \mathcal{E}$ belongs to an **edge type** $\psi(e)$.

We define two subsets of nodes within this heterogeneous graph:
- $\mathcal{L} \subset \mathcal{V}$, representing the set of users.
- $\mathcal{R} \subset \mathcal{V}$, representing the set of items.

The goal is to predict the set of ground-truth items $\mathcal{Y}_v^{(T, T+i]} \subseteq \mathcal{R}$, where a link exists between a user $v \in \mathcal{L} $ and an item within the time interval $(T, T+i]$ for a given interval size $i$. The model is restricted to using only historical information available up to timestamp $T$ for making predictions.
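
As a small illustration of this setup, the ground-truth sets can be derived from a timestamped interaction log by keeping only the links that fall in the half-open window $(T, T+i]$. The tuple layout `(user, item, timestamp)` and the toy values below are assumptions made for the example.

```python
from collections import defaultdict
from typing import Dict, Hashable, Iterable, Set, Tuple

Interaction = Tuple[Hashable, Hashable, int]  # (user, item, timestamp)


def ground_truth_items(interactions: Iterable[Interaction], T: int, i: int) -> Dict[Hashable, Set[Hashable]]:
    """Items each user links to within (T, T + i]: the prediction targets."""
    labels = defaultdict(set)
    for user, item, ts in interactions:
        if T < ts <= T + i:
            labels[user].add(item)
    return dict(labels)


def history(interactions: Iterable[Interaction], T: int) -> list:
    """Interactions the model is allowed to see: everything up to and including T."""
    return [(u, w, ts) for u, w, ts in interactions if ts <= T]


interactions = [("u1", "a", 5), ("u1", "b", 11), ("u1", "c", 17), ("u2", "a", 12), ("u2", "d", 18)]
labels = ground_truth_items(interactions, T=10, i=7)
visible = history(interactions, T=10)
print(labels, visible)
```

Note that the interaction at timestamp 18 is neither an input nor a target for this window, and that `u2` has no visible history at all at $T = 10$.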

##### **Pair-wise Representation**

Rather than learning a pair-wise representation via two **independent** user and item representations $h_v^{(k)}$ and $h_w^{(k)}$, ContextGNN utilizes the user-specific subgraph and reads out the GNN's item representations from it.

In more detail, the pair-wise model processes information following these steps:
1. Sample a $k$-hop subgraph  $$ \tilde{\mathcal{G}} \gets \mathcal{G}_k^{(-\infty, T]}[v] $$ with node set $\tilde{\mathcal{V}}$ around user $v \in \mathcal{L}$. To further facilitate the extraction of meaningful item node representations, the sampled sub-graph is transformed into a **bidirectional** graph.

2. Add an indicator representation to the user seed node:  
   $$ h_v^{(0)} \gets h_v^{(0)} + \text{INDICATOR}_\theta $$

3. Read out <u>both GNN user and item representations</u> at layer $k$:  
   $$ h_v^{(k)}, \{ h_w^{(k)} : w \in \tilde{\mathcal{V}} \cap \mathcal{R} \} \gets \text{GNN}_{\theta}^{(k)}(\tilde{\mathcal{G}}, \mathbf{H}^{(0)}) $$

4. Compute the final ranking for all items $w \in \tilde{\mathcal{V}} \cap \mathcal{R}$:  $$ y_{(v, w)}^{(\text{pair})} \gets h_v^{(k)} \cdot h_w^{(k)}$$

where $\text{INDICATOR}_\theta$ is a $d$-dimensional representation added to the seed user node to differentiate it from, and bias its embedding with respect to, the other sampled neighbouring nodes.
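
The four steps can be traced on a tiny hand-made subgraph. In the sketch below the GNN is replaced by a parameter-free mean-aggregation layer, which is purely an assumption to keep the example dependency-free. The node names, features and indicator vector are made up as well.

```python
from typing import Dict, Iterable, List, Tuple

Vector = List[float]


def add(a: Vector, b: Vector) -> Vector:
    return [x + y for x, y in zip(a, b)]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def mean_aggregate(h: Dict[str, Vector], edges: List[Tuple[str, str]]) -> Dict[str, Vector]:
    """Toy stand-in for one GNN layer: h_u <- h_u + mean of neighbour states."""
    neighbours: Dict[str, List[str]] = {u: [] for u in h}
    for u, w in edges:
        neighbours[u].append(w)
        neighbours[w].append(u)  # treat every edge as bidirectional
    out = {}
    for u, vec in h.items():
        if neighbours[u]:
            cols = zip(*(h[n] for n in neighbours[u]))
            mean = [sum(col) / len(neighbours[u]) for col in cols]
        else:
            mean = [0.0] * len(vec)
        out[u] = add(vec, mean)
    return out


def pairwise_scores(seed: str, h0: Dict[str, Vector], edges: List[Tuple[str, str]],
                    items: Iterable[str], indicator: Vector, num_layers: int = 2) -> Dict[str, float]:
    h = dict(h0)
    h[seed] = add(h[seed], indicator)      # step 2: mark the seed user
    for _ in range(num_layers):            # step 3: k rounds of message passing
        h = mean_aggregate(h, edges)
    # step 4: only items that were sampled into the subgraph receive a score
    return {w: dot(h[seed], h[w]) for w in items if w in h}


# Step 1 (sampling) is assumed done: user -> transactions -> items.
h0 = {"user": [1.0, 0.0], "txn1": [0.0, 0.0], "txn2": [0.0, 0.0],
      "item_a": [0.0, 1.0], "item_b": [1.0, 1.0]}
edges = [("user", "txn1"), ("txn1", "item_a"), ("user", "txn2"), ("txn2", "item_b")]
scores = pairwise_scores("user", h0, edges, items=["item_a", "item_b", "item_c"],
                         indicator=[0.5, 0.5])
print(scores)
```

`item_c` is a valid catalogue item but was not sampled into the subgraph, so the pair-wise model has nothing to say about it.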

This approach handles temporal, heterogeneous graphs efficiently, because the GNN integrates multi-behavior signals (such as recency and clicks) directly from the user's subgraph. Scoring a user requires only a single GNN forward pass, for an overall complexity of $\mathcal{O}(\lvert \mathcal{L} \rvert)$, which is lower than that of two-tower models. The approach also adapts to new users and items. However, its candidate set is limited to the items that appear in the sampled subgraph, which makes it unsuitable on its own for more diverse, exploratory recommendations.

##### **Two-Tower Representation**
As a fallback mechanism to supplement the pair-wise representations, ContextGNN's two-tower model ranks all user-item pairs **outside** the user's subgraph.

Instead of using another GNN -- whose problems of over-squashed representations in dense edge configurations are well known -- ContextGNN introduces a shallow embedding matrix $\mathbf{W} \in \mathbb{R}^{\lvert \mathcal{R} \rvert \times d}$ to learn item representations. While this embedding matrix retains high expressive power, it also supports training against a much larger corpus of negative samples, a critical weakness of previous GNN-only architectures that directly affects model performance.


##### **Unified Representation**

Shallow item embeddings $\omega_w \in \mathbf{W}$ are injected into the user's GNN forward pass to better align user representations with the corresponding item representations. This conditioning mechanism can be summarized as follows:
1. Sample a $k$-hop subgraph  
   $$ \tilde{\mathcal{G}} \gets \mathcal{G}_k^{(-\infty, T]}[v] $$  
   with node set $\tilde{\mathcal{V}}$ around user $v \in \mathcal{L}$

2. Add the shallow embedding to all sampled items $w \in \tilde{\mathcal{V}} \cap \mathcal{R}$:  
   $$ h_w^{(0)} \gets h_w^{(0)} + \omega_w $$

3. Read out the GNN user representation at layer $k$:  
   $$ h_v^{(k)} \gets \text{GNN}_{\theta}^{(k)}(\tilde{\mathcal{G}}, \mathbf{H}^{(0)})$$

4. Compute the final ranking for all items $w \in \mathcal{R} \setminus \tilde{\mathcal{V}}$:  
   $$ y_{(v, w)}^{(\text{tower})} \gets h_v^{(k)} \cdot \omega_w $$
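
Steps 2 and 4 are simple enough to sketch without a GNN. In the plain-Python block below, the shallow embedding table, the item names and the user vector are all hypothetical. In particular, `user_repr` is a placeholder for $h_v^{(k)}$, which step 3 would obtain from the GNN.

```python
from typing import Dict, List, Set

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def inject_shallow(h0: Dict[str, Vector], shallow_emb: Dict[str, Vector],
                   subgraph_items: Set[str]) -> Dict[str, Vector]:
    """Step 2: add the shallow embedding to the initial features of every sampled item."""
    h = dict(h0)
    for w in subgraph_items:
        h[w] = [x + y for x, y in zip(h0[w], shallow_emb[w])]
    return h


def tower_scores(user_repr: Vector, shallow_emb: Dict[str, Vector],
                 subgraph_items: Set[str]) -> Dict[str, float]:
    """Step 4: score only the items that lie outside the sampled subgraph."""
    return {w: dot(user_repr, omega)
            for w, omega in shallow_emb.items() if w not in subgraph_items}


# One row of the shallow embedding matrix per catalogue item.
shallow_emb = {"item_a": [0.1, 0.0], "item_b": [0.0, 0.2],
               "item_c": [1.0, 0.5], "item_d": [-0.5, 1.0]}
subgraph_items = {"item_a", "item_b"}
h0 = {"item_a": [0.0, 1.0], "item_b": [1.0, 1.0]}

h0_conditioned = inject_shallow(h0, shallow_emb, subgraph_items)
user_repr = [2.0, 1.0]  # placeholder for the GNN output of step 3
scores = tower_scores(user_repr, shallow_emb, subgraph_items)
print(h0_conditioned, scores)
```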


ContextGNN then fuses both pair-wise and two-tower representations into a single architecture. Namely, we use the pair-wise score $y^{(\text{pair})}_{(v,w)}$ for all items $w \in \tilde{\mathcal{V}} \cap \mathcal{R}$ inside the local user subgraph, and the two-tower score $y^{(\text{tower})}_{(v,w)}$ for all items $w \in \mathcal{R} \setminus \tilde{\mathcal{V}}$ outside the sampled subgraph.


To align diverse user behaviors, ContextGNN incorporates a user-specific fusion score, predicted from the GNN’s user embedding $h_v^{(k)}$ by an MLP parameterized by $\theta$. This personalized fusion score aligns the two kinds of scores by learning whether a user favors familiar items or exploratory purchases, and refines the final ranking accordingly. As a result, the final score is computed as follows:

$$
y_{(v, w)} =
\begin{cases}
y_{(v, w)}^{(\text{pair})} + \text{MLP}_\theta \left( h_v^{(k)} \right) & \text{if } w \in \tilde{\mathcal{V}} \cap \mathcal{R}, \\
y_{(v, w)}^{(\text{tower})} & \text{otherwise.}
\end{cases}
$$
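
The case distinction can be read as a dictionary merge. The sketch below assumes the pair-wise and tower scores have already been computed and represents the MLP output by a single hypothetical scalar `fusion_offset`. The item names and score values are made up.

```python
from typing import Dict, List


def fuse_scores(pair: Dict[str, float], tower: Dict[str, float],
                fusion_offset: float) -> Dict[str, float]:
    """Final score: pair-wise score plus user offset inside the subgraph, tower score elsewhere."""
    final = dict(tower)                      # items outside the subgraph
    for item, score in pair.items():         # items inside the subgraph
        final[item] = score + fusion_offset
    return final


def top_k(scores: Dict[str, float], k: int) -> List[str]:
    return sorted(scores, key=scores.get, reverse=True)[:k]


pair = {"item_a": 0.2, "item_b": -0.1}    # items already in the user's subgraph
tower = {"item_c": 0.5, "item_d": 0.1}    # all remaining catalogue items

familiar = top_k(fuse_scores(pair, tower, fusion_offset=1.0), k=2)
explorer = top_k(fuse_scores(pair, tower, fusion_offset=-1.0), k=2)
print(familiar, explorer)
```

With a positive offset the user's known items dominate the ranking, while a negative offset pushes them below the unseen items, which is how a single learned scalar per user can express familiar versus exploratory behavior.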

This design keeps ContextGNN computationally efficient, since a single GNN forward pass per user yields the user representation needed by both the pair-wise and the two-tower scores. The model is trained end-to-end, jointly optimizing both item scores and the fusion score to maximize its predictive performance for future user-item interactions. In practice, cross-entropy loss is used for optimization.
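
A softmax cross-entropy over a full row of item scores, with possibly several positive items per user, can be sketched as follows. This is a plain-Python illustration, not the library implementation. Dividing the summed loss by the number of users is an assumption about the normalisation, and the function names are made up.

```python
import math
from typing import List, Tuple


def logsumexp(row: List[float]) -> float:
    """Numerically stable log(sum(exp(x))) over one score row."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))


def sparse_softmax_cross_entropy(scores: List[List[float]],
                                 positives: List[Tuple[int, int]]) -> float:
    """scores[u][w] is the score of item w for user u; positives lists (u, w) ground-truth links."""
    lse = [logsumexp(row) for row in scores]
    # Negative log-softmax of each positive entry, summed over all positives.
    total = sum(lse[u] - scores[u][w] for u, w in positives)
    return total / len(scores)  # assumed normalisation: average over users


uniform_loss = sparse_softmax_cross_entropy([[0.0, 0.0, 0.0, 0.0]], [(0, 1)])
confident_loss = sparse_softmax_cross_entropy([[0.0, 5.0, 0.0, 0.0]], [(0, 1)])
print(uniform_loss, confident_loss)
```

Because the softmax runs over every item in the row, all non-purchased items act as negatives at once, without having to sample them explicitly.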


### Implementation and Experiments
We can download the official code from https://github.com/yiweny/ContextGNN.

In [None]:
%%capture

!git clone https://github.com/yiweny/ContextGNN.git
%cd ContextGNN

# install dependencies
!pip install -e '.[full]'

In [None]:
#@title Import libraries

import json
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd  # used by make_pkey_fkey_graph for tables without feature columns
import torch
import torch.nn.functional as F
from relbench.base import Database, Dataset, EntityTask, RecommendationTask, Table, TaskType
from relbench.datasets import get_dataset
from relbench.modeling.graph import get_link_train_table_input
from relbench.modeling.loader import SparseTensor
from relbench.modeling.utils import get_stype_proposal, remove_pkey_fkey, to_unix_time
from relbench.tasks import get_task
from sentence_transformers import SentenceTransformer
from torch import Tensor
from torch_frame import stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.data import Dataset as TorchFrameDataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.seed import seed_everything
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm

from contextgnn.nn.models import IDGNN, ContextGNN, ShallowRHSGNN
from contextgnn.utils import RHSEmbeddingMode

random_seed = 42  # Random seed for reproducibility

# Setup device and random seed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.set_num_threads(1)
seed_everything(random_seed)



class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))


def make_pkey_fkey_graph(
    db: Database,
    col_to_stype_dict: Dict[str, Dict[str, stype]],
    text_embedder_cfg: Optional[TextEmbedderConfig] = None,
    cache_dir: Optional[str] = None,
) -> Tuple[HeteroData, Dict[str, Dict[str, Dict[StatType, Any]]]]:
    r"""Given a :class:`Database` object, construct a heterogeneous graph with primary-
    foreign key relationships, together with the column stats of each table.

    Args:
        db: A database object containing a set of tables.
        col_to_stype_dict: Column to stype for
            each table.
        text_embedder_cfg: Text embedder config.
        cache_dir: A directory for storing materialized tensor
            frames. If specified, we will either cache the file or use the
            cached file. If not specified, we will not use cached file and
            re-process everything from scratch without saving the cache.

    Returns:
        HeteroData: The heterogeneous :class:`PyG` object with
            :class:`TensorFrame` feature.
    """
    data = HeteroData()
    col_stats_dict = dict()
    if cache_dir is not None:
        os.makedirs(cache_dir, exist_ok=True)

    for table_name, table in db.table_dict.items():
        # Materialize the tables into tensor frames:
        df = table.df
        # Ensure that pkey is consecutive.
        if table.pkey_col is not None:
            assert (df[table.pkey_col].values == np.arange(len(df))).all()

        col_to_stype = col_to_stype_dict[table_name]

        # Remove pkey, fkey columns since they will not be used as input
        # feature.
        remove_pkey_fkey(col_to_stype, table)

        if len(col_to_stype) == 0:  # Add constant feature in case df is empty:
            col_to_stype = {"__const__": stype.numerical}
            # We need to add edges later, so we need to also keep the fkeys
            fkey_dict = {key: df[key] for key in table.fkey_col_to_pkey_table}
            df = pd.DataFrame({"__const__": np.ones(len(table.df)), **fkey_dict})

        path = (
            None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")
        )

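        # Materialize on the globally selected device instead of hard-coding "cuda",
        # so that the cell also runs on CPU-only machines.
        dataset = TorchFrameDataset(
            df=df,
            col_to_stype=col_to_stype,
            col_to_text_embedder_cfg=text_embedder_cfg,
        ).materialize(path=path, device=device)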

        data[table_name].tf = dataset.tensor_frame
        col_stats_dict[table_name] = dataset.col_stats

        # Add time attribute:
        if table.time_col is not None:
            data[table_name].time = torch.from_numpy(
                to_unix_time(table.df[table.time_col])
            )

        # Add edges:
        for fkey_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
            pkey_index = df[fkey_name]
            # Filter out dangling (missing) foreign keys
            mask = ~pkey_index.isna()
            fkey_index = torch.arange(len(pkey_index))
            pkey_index = torch.from_numpy(pkey_index[mask].astype(int).values)
            fkey_index = fkey_index[torch.from_numpy(mask.values)]
            # Ensure no dangling fkeys
            assert (pkey_index < len(db.table_dict[pkey_table_name])).all()

            # fkey -> pkey edges
            edge_index = torch.stack([fkey_index, pkey_index], dim=0)
            edge_type = (table_name, f"f2p_{fkey_name}", pkey_table_name)
            data[edge_type].edge_index = sort_edge_index(edge_index)

            # pkey -> fkey edges.
            # "rev_" is added so that PyG loader recognizes the reverse edges
            edge_index = torch.stack([pkey_index, fkey_index], dim=0)
            edge_type = (pkey_table_name, f"rev_f2p_{fkey_name}", table_name)
            data[edge_type].edge_index = sort_edge_index(edge_index)

    data.validate()

    return data, col_stats_dict

Let's redefine some of the functions we used above, to keep the code clear and easy to run.

In [None]:
dataset_name = "rel-hm"               # Dataset to use
task_name = "user-item-purchase"      # Task to evaluate
tune_metric = "link_prediction_map"

learning_rate = 0.001                                           # Learning rate for training
epochs = 1                                                      # Number of training epochs
eval_epochs_interval = 1                                        # Evaluation interval
batch_size = 1                                                  # Batch size for training
channels = 128                                                   # Number of channels in the model
aggregation = "sum"                                             # Aggregation method
num_layers = 4                                                  # Number of layers in the model
num_neighbors = 128                                              # Number of neighbors for sampling
temporal_strategy = "last"                                      # Temporal sampling strategy
max_steps_per_epoch = 2000                                      # Max steps per epoch
num_workers = 0                                                 # Number of data loader workers
cache_dir = os.path.expanduser("~/.cache/relbench_examples")    # Cache directory

In [None]:
# Load dataset and task
dataset: Dataset = get_dataset(dataset_name, download=True)
task: RecommendationTask = get_task(dataset_name, task_name, download=True)

In [None]:
# Load column types and statistics
stypes_cache_path = Path(f"{cache_dir}/{dataset_name}/stypes.json")
dataset_cache_dir = "/content/cache/materialized"

try:
    with open(stypes_cache_path, "r") as f:
        col_to_stype_dict = json.load(f)
    for table, col_to_stype in col_to_stype_dict.items():
        for col, stype_str in col_to_stype.items():
            col_to_stype[col] = stype(stype_str)
except FileNotFoundError:
    col_to_stype_dict = get_stype_proposal(dataset.get_db())
    Path(stypes_cache_path).parent.mkdir(parents=True, exist_ok=True)
    with open(stypes_cache_path, "w") as f:
        json.dump(col_to_stype_dict, f, indent=2, default=str)

# Prepare graph data and column statistics
data, col_stats_dict = make_pkey_fkey_graph(
    dataset.get_db(),
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=TextEmbedderConfig(
        text_embedder=GloveTextEmbedding(device=device), batch_size=512),
    cache_dir=dataset_cache_dir,
)

# Define neighbors per layer
num_neighbors_per_layer = [
    int(num_neighbors // 2**i) for i in range(num_layers)
]

# Prepare loaders for train, val, and test sets
loader_dict: Dict[str, NeighborLoader] = {}
dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
num_dst_nodes_dict: Dict[str, int] = {}
for split in ["train", "val", "test"]:
    table = task.get_table(split)
    table_input = get_link_train_table_input(table, task)
    dst_nodes_dict[split] = table_input.dst_nodes
    num_dst_nodes_dict[split] = table_input.num_dst_nodes
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=num_neighbors_per_layer,
        time_attr="time",
        input_nodes=table_input.src_nodes,
        input_time=table_input.src_time,
        subgraph_type="bidirectional",
        batch_size=batch_size,
        temporal_strategy=temporal_strategy,
        shuffle=split == "train",
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )

Loading Database object from /root/.cache/relbench/rel-hm/db...
Done in 14.31 seconds.
Loading Database object from /root/.cache/relbench/rel-hm/db...
Done in 2.90 seconds.
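
The `num_neighbors_per_layer` list built in the cell above halves the sampling fan-out at every hop. The short sketch below reproduces that schedule for the hyperparameters used here and estimates how quickly the sampled frontier can grow. The helper names are illustrative, and the growth estimate follows a single edge type per hop, so it is only a rough indication for a heterogeneous graph, where the fan-out applies to each edge type.

```python
from typing import List


def neighbors_per_layer(num_neighbors: int, num_layers: int) -> List[int]:
    """Halve the sampling fan-out at every additional hop."""
    return [num_neighbors // 2**i for i in range(num_layers)]


def frontier_sizes(fanout: List[int]) -> List[int]:
    """Nodes reachable at each hop from one seed, following one edge type per hop
    and assuming every node has at least fanout[i] eligible neighbours."""
    sizes, current = [], 1
    for n in fanout:
        current *= n
        sizes.append(current)
    return sizes


fanout = neighbors_per_layer(num_neighbors=128, num_layers=4)
print(fanout)                  # [128, 64, 32, 16]
print(frontier_sizes(fanout))  # [128, 8192, 262144, 4194304]
```

In practice the temporal constraint (`time_attr`, `input_time`) and the limited number of purchases per customer keep real subgraphs far smaller than this worst case.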


In [None]:
#@title **Select** model type
model_type = "contextgnn"

import ipywidgets as widgets
from IPython.display import display

dropdown = widgets.Dropdown(
    options=['contextgnn', 'idgnn'],
    value='contextgnn',
    description='Select:',
    disabled=False,
)

def on_dropdown_change(change):
    global model_type
    if change['type'] == 'change' and change['name'] == 'value':
        model_type = change['new']

dropdown.observe(on_dropdown_change)
display(dropdown)

Dropdown(description='Select:', options=('contextgnn', 'idgnn'), value='contextgnn')

In [None]:
# Initialize the model
if model_type == "idgnn":
    model = IDGNN(
        data=data,
        col_stats_dict=col_stats_dict,
        num_layers=num_layers,
        channels=channels,
        out_channels=1,
        aggr=aggregation,
        norm="layer_norm",
        torch_frame_model_kwargs={
            "channels": channels,
            "num_layers": num_layers,
        },
    ).to(device)
else:
    #model_type == "contextgnn"
    model = ContextGNN(
        data=data,
        col_stats_dict=col_stats_dict,
        rhs_emb_mode=RHSEmbeddingMode.FUSION,
        dst_entity_table=task.dst_entity_table,
        num_nodes=num_dst_nodes_dict["train"],
        num_layers=num_layers,
        channels=channels,
        aggr=aggregation,
        norm="layer_norm",
        embedding_dim=64,
        torch_frame_model_kwargs={
            "channels": channels,
            "num_layers": num_layers,
        },
    ).to(device)

We now define the training and evaluation functions, which follow a slightly different processing workflow for ContextGNN compared to the original ID-GNN version.
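
One detail of the ID-GNN branch is how it builds its binary labels: each `(position in batch, item id)` pair is packed into a single integer, `batch + batch_size * n_id`, so that membership among the ground-truth pairs can be tested with one `torch.isin` call. The plain-Python sketch below mirrors that encoding with sets. The helper names and toy pairs are illustrative.

```python
from typing import List, Tuple

Pair = Tuple[int, int]  # (index of the user within the mini-batch, item id)


def encode(batch_idx: int, item_id: int, batch_size: int) -> int:
    """Pack a pair into one integer; unique because 0 <= batch_idx < batch_size."""
    return batch_idx + batch_size * item_id


def idgnn_targets(sampled: List[Pair], positives: List[Pair], batch_size: int) -> List[float]:
    """1.0 for every sampled (user, item) pair that is a ground-truth link, else 0.0."""
    positive_keys = {encode(b, w, batch_size) for b, w in positives}
    return [1.0 if encode(b, w, batch_size) in positive_keys else 0.0 for b, w in sampled]


batch_size = 2
sampled = [(0, 5), (0, 7), (1, 5), (1, 9)]    # item nodes found in each user's subgraph
positives = [(0, 7), (1, 5), (1, 3)]          # future purchases of the two users
targets = idgnn_targets(sampled, positives, batch_size)
print(targets)
```

The positive pair `(1, 3)` never appears among the sampled nodes, so ID-GNN receives no training signal for it. ContextGNN's branch instead applies a cross-entropy over the full row of item scores, so such out-of-subgraph items still contribute to the loss.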

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
def train() -> float:
    """
    Train the model for one epoch.

    Returns:
        float: The average training loss for the epoch.
    """
    model.train()

    loss_accum = count_accum = 0
    steps = 0
    total_steps = min(len(loader_dict["train"]), max_steps_per_epoch)
    sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)

    for batch in tqdm(loader_dict["train"], total=total_steps, desc="Train"):
        batch = batch.to(device)

        # Get ground-truth source and destination indices
        input_id = batch[task.src_entity_table].input_id
        src_batch, dst_index = sparse_tensor[input_id]

        # Reset gradients
        optimizer.zero_grad()

        if model_type == 'idgnn':
            # Forward pass for IDGNN
            out = model(batch, task.src_entity_table, task.dst_entity_table).flatten()
            batch_size = batch[task.src_entity_table].batch_size

            # Compute target labels for IDGNN
            target = torch.isin(
                batch[task.dst_entity_table].batch +
                batch_size * batch[task.dst_entity_table].n_id,
                src_batch + batch_size * dst_index,
            ).float()

            # Compute binary cross-entropy loss
            loss = F.binary_cross_entropy_with_logits(out, target)
            numel = out.numel()

        else:
            # model_type in ['contextgnn', 'shallowrhsgnn']:
            # Forward pass for ContextGNN and ShallowRHSGNN
            logits = model(batch, task.src_entity_table, task.dst_entity_table)

            # Construct edge label index
            edge_label_index = torch.stack([src_batch, dst_index], dim=0)

            # Compute sparse cross-entropy loss
            loss = sparse_cross_entropy(logits, edge_label_index)
            numel = len(batch[task.dst_entity_table].batch)

        # Backward pass and optimizer step
        loss.backward()
        optimizer.step()

        # Accumulate loss and count
        loss_accum += float(loss) * numel
        count_accum += numel

        # Print the training loss of the current batch
        print({"batch_train_loss": float(loss)})

        # Increment the step counter and stop once the per-epoch step budget is reached
        steps += 1
        if steps >= max_steps_per_epoch:
            break

    if count_accum == 0:
        warnings.warn(
            f"Did not sample a single '{task.dst_entity_table}' node in any mini-batch. "
            f"Try increasing the number of layers/hops or decreasing the batch size."
        )

    # Print the average training loss over the epoch
    avg_loss = loss_accum / count_accum if count_accum > 0 else float("nan")
    print({"train_loss": avg_loss})
    return avg_loss


@torch.no_grad()
def test(loader: NeighborLoader, desc: str) -> np.ndarray:
    """
    Evaluate the model using a data loader.

    Args:
        loader (NeighborLoader): Data loader for evaluation (val/test).
        desc (str): Description of the evaluation phase (e.g., "Validation").

    Returns:
        np.ndarray: Top-K predictions for the task.
    """
    model.eval()
    pred_list: List[Tensor] = []


    for batch in tqdm(loader, desc=desc):
        batch = batch.to(device)
        batch_size = batch[task.src_entity_table].batch_size

        if model_type == "idgnn":
            # Forward pass for IDGNN
            logits = model(batch, task.src_entity_table, task.dst_entity_table).detach().flatten()
            scores = torch.zeros(batch_size, task.num_dst_nodes, device=logits.device)
            scores[batch[task.dst_entity_table].batch, batch[task.dst_entity_table].n_id] = torch.sigmoid(logits)

        else:
            #model_type in ["contextgnn", "shallowrhsgnn"]:
            # Forward pass for ContextGNN and ShallowRHSGNN
            logits = model(batch, task.src_entity_table, task.dst_entity_table).detach()
            scores = torch.sigmoid(logits)

        # Collect top-K predictions
        _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
        pred_list.append(pred_mini)

    # Concatenate predictions and return as a NumPy array
    pred = torch.cat(pred_list, dim=0).cpu().numpy()
    return pred

In [None]:
print("EVALUATION BEFORE TRAINING...")
val_pred = test(loader_dict["val"], desc="Validation")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Validation Metrics: {json.dumps(val_metrics, indent=2)}")

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred, task.get_table("test"))
print(f"Test Metrics: {json.dumps(test_metrics, indent=2)}")

In [None]:
# Training and evaluation loop
best_state = None
best_val_metric = 0

for epoch in range(1, epochs + 1):
    # Train the model for one epoch
    train_loss = train()

    if epoch % eval_epochs_interval == 0:
        # Evaluate on the validation set
        val_pred = test(loader_dict["val"], desc="Validation")
        val_metrics = task.evaluate(val_pred, task.get_table("val"))
        print(f"Epoch: {epoch:02d}, Train loss: {train_loss:.4f}, Val metrics: {val_metrics}")

        # Save the best model state
        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}

# Ensure the best model state is saved
assert best_state is not None
model.load_state_dict(best_state)

... or load our pretrained checkpoint

In [None]:
import gdown

# Checkpoint configuration
#"config": {
#    "model_type": "contextgnn",
#    "seed": 42,
#    "learning_rate": 0.001,
#    "epochs": 20,
#    "eval_epochs_interval": 5,
#    "batch_size": 200,
#    "channels": 128,
#    "aggregation": "sum",
#    "num_layers": 4,
#    "num_neighbors": 128,
#    "temporal_strategy": "last",
#    "max_steps_per_epoch": 2000,
#    "num_workers": 0,
#    "embedding_dim": 64
#}

checkpoint_url = "https://drive.google.com/uc?id=1IK0F5iII2fC9MdU07JD-C9PySMP4uRje"
gdown.download(checkpoint_url, "/content/relbench_hm_contextgnn_checkpoint", quiet=False)

# map_location lets the checkpoint load on CPU-only machines as well
state_dict = torch.load("/content/relbench_hm_contextgnn_checkpoint", map_location=device)
model.load_state_dict(state_dict)

Downloading...
From (original): https://drive.google.com/uc?id=1IK0F5iII2fC9MdU07JD-C9PySMP4uRje
From (redirected): https://drive.google.com/uc?id=1IK0F5iII2fC9MdU07JD-C9PySMP4uRje&confirm=t&uuid=b3f19611-96d0-4247-824c-bdc86f588622
To: /content/relbench_hm_contextgnn_checkpoint
100%|██████████| 37.7M/37.7M [00:01<00:00, 22.7MB/s]


<All keys matched successfully>

In [None]:
print("EVALUATION AFTER TRAINING...")
# Evaluate on the validation set using the best model
val_pred = test(loader_dict["val"], desc="Best Validation")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best Validation Metrics: {val_metrics}")

# Evaluate on the test set using the best model
test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred, task.get_table("test"))
print(f"Test Metrics: {json.dumps(test_metrics, indent=2)}")

EVALUATION AFTER TRAINING...


Best Validation: 100%|██████████| 74575/74575 [25:54<00:00, 47.96it/s]


Best Validation Metrics: {'link_prediction_precision': 0.006614146832048273, 'link_prediction_recall': 0.034180961287795825, 'link_prediction_map': 0.021203250508317745}


Test: 100%|█████████▉| 67094/67144 [22:21<00:00, 55.03it/s]