<a href="https://colab.research.google.com/github/Bryan-Az/TimeGPT-Tabula9-RDL/blob/main/RelBench/RelBench_GNN_Inference_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RelBench

Relbench or 'Relational Bench' is a library that also provides benchmark datasets for evaluating relation deep learning algorithms. It can convert tabular data into a graph data structure and leverage the graph data structure for more efficient model inference.

This notebook uses the google colab L4 GPU environment, which is required to load the Rel-Stack database (as it requires more than 15gb ram as in the CPU environment). It also allows for faster inference using the GPU.

## Imports and Installs

In [1]:
%%capture
# install torch 2.1.0
!pip install torch==2.1.0+cu121 -f https://download.pytorch.org/whl/cu121/torch_stable.html

!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
!pip install pyg-lib -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
# install torch cluster and torch sparse with pip
!pip install torch_geometric -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
!pip install pytorch_frame[full] #PyTorch Frame for working with tabular data
!pip install relbench[full]


In [2]:
import torch
torch.__version__

'2.1.0+cu121'

In [3]:
import torch_geometric
torch_geometric.__version__


'2.6.1'

In [4]:
import relbench

relbench.__version__

'1.1.0'

In [5]:
# for data ETL
import numpy as np
from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task

# for converting the relation tables into a graph structure, GNN
import os
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame

# for text encoding and embeddings
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor

## Data Loading and Transformation

### What kind of data?
Relbench hosts a variety of databases with datasets in a variety of domains that have been pre-structured to work with their graph neural network architecture. As this is an example showcase of Relbench's capabilities - I will chose a database at random (rel-stack) and infer using the datasets within that database.

Rel-stack is a database consisting of data provided by Stack Exchange's question and answering platform. The base training data is from 2010 to 2019 while the validation to testing data ranges from 2019-2021.

### What kind of prediction or forecasting task?
The task must be defined when crafting the datasets. Rel-bench's databases offer pre-defined tasks for node classification, regression or link prediction. The meaning of each task is mentioned in more detail on the [documentation](https://relbench.stanford.edu/datasets/rel-stack/). For this example, I'll be using the user-badge classification task: For each user predict if a user will receive a new badge in the next 3 months, evaluated using AUROC metrics.


In [6]:
dataset = get_dataset("rel-stack", download=True)
# the task is the type of prediction or forecasting
# task to be used when creating the GNN
# Available tasks in rel-stack: user-engagement, user-badge, post-votes
task = get_task("rel-stack", "user-badge", download=True)

train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

out_channels = 1
loss_fn = L1Loss()
tune_metric = "roc_auc"
higher_is_better = False

Downloading file 'rel-stack/db.zip' from 'https://relbench.stanford.edu/download/rel-stack/db.zip' to '/root/.cache/relbench'.
100%|████████████████████████████████████████| 880M/880M [00:00<00:00, 879GB/s]
Unzipping contents of '/root/.cache/relbench/rel-stack/db.zip' to '/root/.cache/relbench/rel-stack/.'
Downloading file 'rel-stack/tasks/user-badge.zip' from 'https://relbench.stanford.edu/download/rel-stack/tasks/user-badge.zip' to '/root/.cache/relbench'.
100%|█████████████████████████████████████| 6.15M/6.15M [00:00<00:00, 2.37GB/s]
Unzipping contents of '/root/.cache/relbench/rel-stack/tasks/user-badge.zip' to '/root/.cache/relbench/rel-stack/tasks/.'


In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # check that it's cuda if you want it to run in reasonable time!
root_dir = "./data"

cuda


In [8]:
# this is the method relbench created to infer the data types to be used in the GNN
from relbench.modeling.utils import get_stype_proposal

db = dataset.get_db()

Loading Database object from /root/.cache/relbench/rel-stack/db...
Done in 10.43 seconds.


In [9]:
col_to_stype_dict = get_stype_proposal(db)
col_to_stype_dict

{'comments': {'Id': <stype.numerical: 'numerical'>,
  'PostId': <stype.numerical: 'numerical'>,
  'UserId': <stype.numerical: 'numerical'>,
  'ContentLicense': <stype.categorical: 'categorical'>,
  'UserDisplayName': <stype.text_embedded: 'text_embedded'>,
  'Text': <stype.text_embedded: 'text_embedded'>,
  'CreationDate': <stype.timestamp: 'timestamp'>},
 'postLinks': {'Id': <stype.numerical: 'numerical'>,
  'RelatedPostId': <stype.numerical: 'numerical'>,
  'PostId': <stype.numerical: 'numerical'>,
  'LinkTypeId': <stype.categorical: 'categorical'>,
  'CreationDate': <stype.timestamp: 'timestamp'>},
 'posts': {'Id': <stype.numerical: 'numerical'>,
  'OwnerUserId': <stype.numerical: 'numerical'>,
  'PostTypeId': <stype.numerical: 'numerical'>,
  'AcceptedAnswerId': <stype.numerical: 'numerical'>,
  'ParentId': <stype.numerical: 'numerical'>,
  'OwnerDisplayName': <stype.text_embedded: 'text_embedded'>,
  'Title': <stype.text_embedded: 'text_embedded'>,
  'Tags': <stype.text_embedded: 

In [10]:
# this is the embedding model Relbench recommends for speed and convenience
class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))

In [11]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph


text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)

data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types using previously inferred types
    text_embedder_cfg=text_embedder_cfg,  # the text encoder recommended by relbench
    cache_dir=os.path.join(
        root_dir, f"rel-stack_materialized_cache"
    ),  # the graph is stored in cache
)

modules.json:   0%|          | 0.00/248 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

(…)WordEmbeddings/wordembedding_config.json:   0%|          | 0.00/164 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/480M [00:00<?, ?B/s]

(…)beddings/whitespacetokenizer_config.json:   0%|          | 0.00/4.61M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Embedding raw data in mini-batch: 100%|██████████| 2438/2438 [00:38<00:00, 63.20it/s]
Embedding raw data in mini-batch: 100%|██████████| 2438/2438 [00:14<00:00, 173.81it/s]
Embedding raw data in mini-batch: 100%|██████████| 1305/1305 [01:11<00:00, 18.22it/s]
Embedding raw data in mini-batch: 100%|██████████| 1305/1305 [00:07<00:00, 175.59it/s]
Embedding raw data in mini-batch: 100%|██████████| 1305/1305 [00:08<00:00, 160.34it/s]
Embedding raw data in mini-batch: 100%|██████████| 1305/1305 [00:09<00:00, 136.00it/s]
Embedding raw data in mini-batch: 100%|██████████| 998/998 [00:07<00:00, 127.32it/s]
Embedding raw data in mini-batch: 100%|██████████| 998/998 [00:06<00:00, 146.51it/s]
Embedding raw data in mini-batch: 100%|██████████| 998/998 [00:06<00:00, 159.30it/s]
Embedding raw data in mini-batch: 100%|██████████| 998/998 [00:05<00:00, 176.46it/s]
Embedding raw data in mini-batch: 100%|██████████| 4592/4592 [00:30<00:00, 148.91it/s]
Embedding raw data in mini-batch: 100%|██████████| 45

In [12]:
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader

loader_dict = {}

for split, table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    table_input = get_node_train_table_input(
        table=table,
        task=task,
    )
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[
            128 for i in range(2)
        ],  # we sample subgraphs of depth 2, 128 neighbors per node.
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=split == "train",
        num_workers=0,
        persistent_workers=False,
    )

## Adding the Torch Model

In [13]:
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder


class Model(torch.nn.Module):

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])


model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)


# the rel-stack database is large (spanning 10 years d data) - less epochs for faster training
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 3

## Training, Testing, and Evaluation

In [14]:
def train() -> float:
    model.train()

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[entity_table].y.float())
        loss.backward()
        optimizer.step()

        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    return loss_accum / count_accum


@torch.no_grad()
def test(loader: NeighborLoader) -> np.ndarray:
    model.eval()

    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0).numpy()

Relbench's example notebook set the number of epochs to use for training to 10. However, they mentioned different tasks or databases might require a different number of epochs. With each epoch for the rel-stack user-badge classification task taking 12 minutes each, 3 epochs will take about 37 minutes to train.

In [15]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())


model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

100%|██████████| 6614/6614 [12:31<00:00,  8.80it/s]


Epoch: 01, Train loss: 0.04815797895185285, Val metrics: {'average_precision': 0.24755473428026917, 'accuracy': 0.9723805366251951, 'f1': 0.18838341845824919, 'roc_auc': 0.6924670991860463}


100%|██████████| 6614/6614 [12:37<00:00,  8.73it/s]


Epoch: 02, Train loss: 0.04552589356805092, Val metrics: {'average_precision': 0.15286483948946003, 'accuracy': 0.9722794848786166, 'f1': 0.2339142091152815, 'roc_auc': 0.5801168775211673}


100%|██████████| 6614/6614 [12:45<00:00,  8.64it/s]


Epoch: 03, Train loss: 0.04534391459263776, Val metrics: {'average_precision': 0.1521455560061642, 'accuracy': 0.9722673586690272, 'f1': 0.23842823842823843, 'roc_auc': 0.582264747324989}
Best Val metrics: {'average_precision': 0.15291737856147042, 'accuracy': 0.9722794848786166, 'f1': 0.2339142091152815, 'roc_auc': 0.5801169844838261}
Best test metrics: {'average_precision': 0.14227638002722787, 'accuracy': 0.9746906328320802, 'f1': 0.2334242675839165, 'roc_auc': 0.5779863225468362}
