In [None]:
### I ran this on a google colab, if running on a local machine, you might need to install different versions.

!pip install relbench[full]
!pip install pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cu121.html # PyG for working with graphs
!pip install git+https://github.com/pyg-team/pytorch_geometric.git # more PyG
!pip install pytorch_frame[full] #PyTorch Frame for working with tabular data

In [None]:
### data available at https://www.kaggle.com/datasets/mmohaiminulislam/ecommerce-data-analysis/data

In [2]:
import os
import time

import torch
import numpy as np
import pandas as pd
import pooch
import pyarrow as pa
import pyarrow.json
from relbench.base import Database, Dataset, Table
from relbench.datasets import get_dataset, get_dataset_names, register_dataset
from relbench.utils import unzip_processor
from sklearn.impute import SimpleImputer

In [3]:
import duckdb
import pandas as pd

from relbench.base import Database, EntityTask, Table, TaskType
from relbench.datasets import get_dataset
from relbench.metrics import r2, mae
from relbench.tasks import get_task, get_task_names, register_task

In [4]:
BASE_DIR = '.'

In [5]:
import csv
def load_csv_to_db(file_path):
    encodings = ['utf-8', 'ISO-8859-1', 'utf-16', 'Windows-1252']
    for encoding in encodings:
        try:
            with open(file_path, encoding=encoding) as file:
                data = list(csv.DictReader(file))
                df = pd.DataFrame(data)
                return df
        except UnicodeDecodeError:
            print(f"UnicodeDecodeError with encoding: {encoding} for file: {file_path}")
            continue  # Try the next encoding

In [6]:
### dataset available at https://www.kaggle.com/datasets/olistbr/brazilian-ecommerce

class EcommerceDataBase(Dataset):
    # example of creating your own dataset: https://github.com/snap-stanford/relbench/blob/main/tutorials/custom_dataset.ipynb

    val_timestamp = pd.Timestamp(year=2018, month=1, day=1)
    test_timestamp = pd.Timestamp(year=2020, month=1, day=1)

    def make_db(self) -> Database:

        tables = {}

        customers = load_csv_to_db(BASE_DIR + '/customer_dim.csv').drop(columns=['contact_no', 'nid']).rename(columns={'coustomer_key': 'customer_key'})
        stores = load_csv_to_db(BASE_DIR + '/store_dim.csv').drop(columns=['upazila'])
        products = load_csv_to_db(BASE_DIR + '/item_dim.csv')
        transactions = load_csv_to_db(BASE_DIR + '/fact_table.csv').rename(columns={'coustomer_key': 'customer_key'})
        times = load_csv_to_db(BASE_DIR + '/time_dim.csv')

        t = transactions.merge(times[['time_key', 'date']], on='time_key').drop(columns=['payment_key', 'time_key', 'unit'])
        t['date'] = pd.to_datetime(t.date)
        t = t.reset_index().rename(columns={'index': 't_id'})
        t['quantity'] = t.quantity.astype(int)
        t['unit_price'] = t.unit_price.astype(float)
        products['unit_price'] = products.unit_price.astype(float)
        t['total_price'] = t.total_price.astype(float)

        print(t.isna().sum(axis=0))
        print(products.isna().sum(axis=0))
        print(stores.isna().sum(axis=0))
        print(customers.isna().sum(axis=0))

        tables['products'] = Table(
            df=pd.DataFrame(products),
            pkey_col='item_key',
            fkey_col_to_pkey_table={},
            time_col=None
        )

        tables['customers'] = Table(
            df=pd.DataFrame(customers),
            pkey_col='customer_key',
            fkey_col_to_pkey_table={},
            time_col=None
        )

        tables['transactions'] = Table(
            df=pd.DataFrame(t),
            pkey_col='t_id',
            fkey_col_to_pkey_table={
                'customer_key': 'customers',
                'item_key': 'products',
                'store_key': 'stores'
            },
            time_col='date'
        )

        tables['stores'] = Table(
            df=pd.DataFrame(stores),
            pkey_col='store_key',
            fkey_col_to_pkey_table={}
        )

        return Database(tables)

In [19]:
class CustomerRevenueTask(EntityTask):
    # example of custom task: https://github.com/snap-stanford/relbench/blob/main/tutorials/custom_task.ipynb


    task_type = TaskType.REGRESSION
    entity_col = "customer_key"
    entity_table = "customers"
    time_col = "timestamp"
    target_col = "revenue"
    timedelta = pd.Timedelta(days=30) # how far we want to predict revenue into the future.
    metrics = [r2, mae]
    num_eval_timestamps = 40

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:

        timestamp_df = pd.DataFrame({"timestamp": timestamps})

        transactions = db.table_dict["transactions"].df

        df = duckdb.sql(f"""
            select
                timestamp,
                customer_key,
                sum(total_price) as revenue
            from
                timestamp_df t
            left join
                transactions ta
            on
                ta.date <= t.timestamp + INTERVAL '{self.timedelta}'
                and ta.date > t.timestamp
            group by timestamp, customer_key
        """).df().dropna()

        print(df)

        return Table(
            df=df,
            fkey_col_to_pkey_table={self.entity_col: self.entity_table},
            pkey_col=None,
            time_col=self.time_col,
        )

In [20]:
ecomm_ds = EcommerceDataBase()
db = ecomm_ds.get_db()

task = CustomerRevenueTask(ecomm_ds)

Making Database object from scratch...
(You can also use `get_dataset(..., download=True)` for datasets prepared by the RelBench team.)
UnicodeDecodeError with encoding: utf-8 for file: ./customer_dim.csv
UnicodeDecodeError with encoding: utf-8 for file: ./item_dim.csv


  t['date'] = pd.to_datetime(t.date)


t_id            0
customer_key    0
item_key        0
store_key       0
quantity        0
unit_price      0
total_price     0
date            0
dtype: int64
item_key       0
item_name      0
desc           0
unit_price     0
man_country    0
supplier       0
unit           0
dtype: int64
store_key    0
division     0
district     0
dtype: int64
customer_key    0
name            0
dtype: int64
Done in 15.60 seconds.


In [21]:
import numpy as np

from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task

train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

out_channels = 1
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False

Making task table for train split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Making Database object from scratch...
(You can also use `get_dataset(..., download=True)` for datasets prepared by the RelBench team.)
UnicodeDecodeError with encoding: utf-8 for file: ./customer_dim.csv
UnicodeDecodeError with encoding: utf-8 for file: ./item_dim.csv


  t['date'] = pd.to_datetime(t.date)


t_id            0
customer_key    0
item_key        0
store_key       0
quantity        0
unit_price      0
total_price     0
date            0
dtype: int64
item_key       0
item_name      0
desc           0
unit_price     0
man_country    0
supplier       0
unit           0
dtype: int64
store_key    0
division     0
district     0
dtype: int64
customer_key    0
name            0
dtype: int64
Done in 14.98 seconds.
        timestamp  customer_key  revenue
0      2014-01-22           110    456.0
1      2014-01-22          6366    527.0
2      2014-01-22          2427    400.0
3      2014-01-22           679    233.5
4      2014-01-22          3385    472.0
...           ...           ...      ...
318152 2017-12-02          5959    336.0
318153 2017-12-02          3171    140.0
318154 2017-12-02          6225    132.0
318155 2017-12-02          2312     16.0
318156 2017-12-02          7494     95.0

[318157 rows x 3 columns]
Done in 15.25 seconds.
Making task table for val split from sc

  t['date'] = pd.to_datetime(t.date)


t_id            0
customer_key    0
item_key        0
store_key       0
quantity        0
unit_price      0
total_price     0
date            0
dtype: int64
item_key       0
item_name      0
desc           0
unit_price     0
man_country    0
supplier       0
unit           0
dtype: int64
store_key    0
division     0
district     0
dtype: int64
customer_key    0
name            0
dtype: int64
Done in 15.33 seconds.
       timestamp  customer_key  revenue
0     2020-01-01          4827    192.0
1     2020-01-01          4731    245.0
2     2020-01-01          9090    360.0
3     2020-01-01          6492    184.0
4     2020-01-01          1595     45.0
...          ...           ...      ...
79027 2020-07-29           154    176.0
79028 2020-07-29          1181    140.0
79029 2020-07-29          5099     28.0
79030 2020-07-29          1344     26.0
79031 2020-07-29          1707     35.0

[79032 rows x 3 columns]
Done in 15.70 seconds.


In [10]:
import os
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame

# Some book keeping
from torch_geometric.seed import seed_everything

seed_everything(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # check that it's cuda if you want it to run in reasonable time!
root_dir = "./data"

cuda


In [11]:
from relbench.modeling.utils import get_stype_proposal

col_to_stype_dict = get_stype_proposal(db)

In [12]:
!pip install -U sentence-transformers # we need another package for text encoding
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor


class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))

Collecting sentence-transformers
  Downloading sentence_transformers-3.2.1-py3-none-any.whl.metadata (10 kB)
Downloading sentence_transformers-3.2.1-py3-none-any.whl (255 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m255.8/255.8 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentence-transformers
Successfully installed sentence-transformers-3.2.1


In [13]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)

data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # speficied column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, f"rel-ecomm_materialized_cache"
    ),  # store materialized graph for convenience
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/248 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

(…)WordEmbeddings/wordembedding_config.json:   0%|          | 0.00/164 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/480M [00:00<?, ?B/s]

(…)beddings/whitespacetokenizer_config.json:   0%|          | 0.00/4.61M [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Embedding raw data in mini-batch: 100%|██████████| 2/2 [00:02<00:00,  1.01s/it]
Embedding raw data in mini-batch: 100%|██████████| 2/2 [00:00<00:00, 54.91it/s]
Embedding raw data in mini-batch: 100%|██████████| 2/2 [00:00<00:00, 97.97it/s]
Embedding raw data in mini-batch: 100%|██████████| 36/36 [00:01<00:00, 30.88it/s]
Embedding raw data in mini-batch: 100%|██████████| 3/3 [00:00<00:00, 137.15it/s]


In [14]:
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader

loader_dict = {}

for split, table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    table_input = get_node_train_table_input(
        table=table,
        task=task,
    )
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[
            128 for i in range(2)
        ],  # we sample subgraphs of depth 2, 128 neighbors per node.
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=split == "train",
        num_workers=0,
        persistent_workers=False,
    )

In [15]:
from torch.nn import MSELoss
import copy
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder


class Model(torch.nn.Module):

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])


model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)


# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 100

In [17]:
def train() -> float:
    model.train()

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[entity_table].y.float())
        loss.backward()
        optimizer.step()

        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    return loss_accum / count_accum


@torch.no_grad()
def test(loader: NeighborLoader) -> np.ndarray:
    model.eval()

    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0).numpy()

In [18]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())


model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

100%|██████████| 622/622 [01:42<00:00,  6.10it/s]


Epoch: 01, Train loss: 115.22177720827277, Val metrics: {'r2': -0.050840411244314376, 'mae': 106.85230259719245}


100%|██████████| 622/622 [01:36<00:00,  6.42it/s]


Epoch: 02, Train loss: 106.49870062402098, Val metrics: {'r2': -0.06088264228589724, 'mae': 106.83409881596596}


100%|██████████| 622/622 [01:36<00:00,  6.47it/s]


Epoch: 03, Train loss: 106.49842213751197, Val metrics: {'r2': -0.07429948311107792, 'mae': 106.90080879044865}


100%|██████████| 622/622 [01:37<00:00,  6.38it/s]


KeyboardInterrupt: 