In [None]:
pip install -U "autogluon.tabular[all]" --quiet # For training tabular model

In [None]:
# Install required packages.
!pip install torch==2.4.0
!pip install torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
!pip install pytorch_frame --quiet
!pip install relbench --quiet

In [None]:
import os
import torch
import relbench

import os
import math
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch_geometric
import torch_frame

In [None]:
from relbench.datasets import get_dataset

# Load the RelBench "rel-trial" dataset; download=True lets RelBench fetch its files.
dataset = get_dataset(name="rel-trial", download=True)

In [None]:
from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset

from relbench.tasks import get_task

db = dataset.get_db()

# Entity table of the task (one row per study) and the study-outcome task itself.
table = db.table_dict["studies"]
task = get_task("rel-trial", "study-outcome", download=True)

train_task = task.get_table("train").df
val_task = task.get_table("val").df
test_task = task.get_table("test").df

## Learning tables: attach the raw study attributes to every task row.
# Left merges keep exactly the task rows, in the task table's order.
train_df = train_task.merge(table.df, how="left")
val_df = val_task.merge(table.df, how="left")
test_df = test_task.merge(table.df, how="left")

# The same frames were previously built a second time under these names; keep the
# names as aliases instead of recomputing the merges.
df_train, df_val, df_test = train_df, val_df, test_df

# study-outcome is tuned on ROC-AUC and later scored from class probabilities, i.e. it
# is binary classification: the GNN emits one logit per study, so the matching loss
# is BCE-with-logits (L1 on raw logits vs. 0/1 labels is not a classification loss).
out_channels = 1
loss_fn = BCEWithLogitsLoss()
tune_metric = "roc_auc"
higher_is_better = True

In [None]:
# Some book keeping
from torch_geometric.seed import seed_everything

# Seed all RNGs through PyG's helper so sampling and weight init are repeatable.
seed_everything(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Base directory for materialized graph caches.
root_dir = "./data"

# Build the graph (text columns are left out here; text is handled later by the tabular model)

In [None]:
from relbench.modeling.utils import get_stype_proposal

# Propose a torch_frame semantic type (stype) for every column of every table.
# The text_embedded ones are collected and removed in the following cells.
db = dataset.get_db()
col_to_stype_dict = get_stype_proposal(db)

In [None]:
## Identify text columns

stype = torch_frame.stype

# Every (table, column) pair whose proposed stype is free text.
text_columns = [
    (tbl, col)
    for tbl, col_map in col_to_stype_dict.items()
    for col in col_map
    if col_map[col] == stype.text_embedded
]

In [None]:
# drop columns

from collections import defaultdict

# Group the text columns by table.
columns_to_drop_by_table = defaultdict(list)
for tbl, col in text_columns:
    columns_to_drop_by_table[tbl].append(col)

# Replace each table's frame with a copy lacking those columns (only the ones that exist).
for tbl, requested in columns_to_drop_by_table.items():
    tbl_obj = db.table_dict[tbl]
    present = [c for c in requested if c in tbl_obj.df.columns]
    if not present:
        continue
    tbl_obj.df = tbl_obj.df.drop(columns=present)

In [None]:
# Forget the stype of every dropped text column; a missing column entry is ignored.
for tbl, col in text_columns:
    if col in col_to_stype_dict[tbl]:
        del col_to_stype_dict[tbl][col]

# Create the relational entity graph (REG)

In [None]:
from relbench.modeling.graph import make_pkey_fkey_graph

# Build the heterogeneous PyG graph from the primary/foreign-key links of `db`,
# using the stypes that remain after removing the text columns.
# The cache directory must identify this dataset and column setup. The previous name
# ("rel-f1_...") could pick up materialized tensors from a different dataset, or from a
# rel-trial run that kept text columns, if that directory already existed.
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types (no text_embedded columns)
    cache_dir=os.path.join(root_dir, "rel-trial_notext_materialized_cache"),
)

In [None]:
# One RelBench task table per split, fetched in train / val / test order.
train_table, val_table, test_table = (
    task.get_table(split_name) for split_name in ("train", "val", "test")
)

In [None]:
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader

loader_dict = {}

# `split_table` (not `table`) so the loop does not overwrite the studies table
# bound to `table` in an earlier cell.
for split, split_table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    # Seed nodes, seed times and a batch transform derived from the task table.
    # The transform presumably attaches the target as batch[entity_table].y (read by train()).
    table_input = get_node_train_table_input(
        table=split_table,
        task=task,
    )
    # Notebook-level name for the seed node type; later cells may read it.
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data,
        # Depth-2 subgraphs: up to 64 sampled neighbours per node at each of the 2 hops.
        num_neighbors=[64 for _ in range(2)],
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        # Temporal sampling: sample uniformly among temporally valid neighbours.
        temporal_strategy="uniform",
        shuffle=split == "train",  # shuffle only the training split
        num_workers=0,
        persistent_workers=False,
    )

# Model


In [None]:
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder


class Model(torch.nn.Module):
    """Heterogeneous temporal GNN over a sampled relational-entity subgraph.

    Per batch:
    - ``HeteroEncoder`` turns each table's TensorFrame into node features.
    - ``HeteroTemporalEncoder`` adds a relative-time encoding, computed from the
      seed times, for node types that carry ``time``.
    - Optional shallow per-node embeddings are added.
    - ``HeteroGraphSAGE`` performs message passing.
    - An MLP head maps node representations to ``out_channels`` outputs.
    """

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        """
        Args:
            data: graph whose per-node-type ``tf`` defines the encoder's columns.
            col_stats_dict: per-table column statistics for the encoder.
            num_layers: number of GraphSAGE layers.
            channels: hidden size shared by encoders, GNN and embeddings.
            out_channels: output size of the head (e.g. 1 for a single logit).
            aggr: neighbour aggregation passed to ``HeteroGraphSAGE``.
            norm: normalisation used by the MLP head.
            shallow_list: node types that get a learned ``Embedding`` indexed by
                global node id. Only read (never mutated), so the ``[]`` default is safe.
            id_awareness: allocate an embedding that ``forward_dst_readout`` adds
                to the seed rows.
        """
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all submodules; shallow embeddings use N(0, 0.1^2)."""
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def _embed_inputs(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        add_id_awareness: bool = False,
    ) -> Dict[NodeType, Tensor]:
        """Build pre-GNN node features for every node type in ``batch``.

        Shared by all forward variants so they cannot drift apart. The order
        (encoder -> optional ID-awareness -> relative time -> shallow embeddings)
        matches the original per-method code.
        """
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        if add_id_awareness:
            # Mark the seed (root) rows of the entity table.
            x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        return x_dict

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        """Head outputs for the seed rows of ``entity_table``: (num_seeds, out_channels)."""
        return self.head(self.encode_seed_nodes(batch, entity_table))

    def encode_seed_nodes(self, batch: HeteroData, entity_table: NodeType) -> Tensor:
        """GNN representations (before the head) of the seed rows of ``entity_table``.

        Seed rows are the first ``seed_time.size(0)`` rows of the entity node type.
        """
        x_dict = self._embed_inputs(batch, entity_table)
        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )
        num_seeds = batch[entity_table].seed_time.size(0)
        return x_dict[entity_table][:num_seeds]

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        """Head outputs for every node of ``dst_table`` with ID-aware seed rows.

        Raises:
            RuntimeError: if the model was built with ``id_awareness=False``.
        """
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        x_dict = self._embed_inputs(batch, entity_table, add_id_awareness=True)
        # Unlike encode_seed_nodes, the sampled node/edge counts are not passed here
        # (kept exactly as in the original implementation).
        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])


# Hyperparameters (per the original author, taken from the RelBench paper).
# out_channels comes from the task-setup cell instead of being hard-coded again.
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=out_channels,
    aggr="mean",
    norm="batch_norm",
).to(device)


optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 20

In [None]:
def train() -> float:
    """Run one epoch over ``loader_dict["train"]`` and return the mean per-example loss.

    Reads notebook globals: model, optimizer, loss_fn, loader_dict, task, device.

    Raises:
        RuntimeError: if the training loader yields no examples.
    """
    model.train()
    # Taken from the task rather than from the `entity_table` name left behind by the
    # loader cell, so this function does not depend on hidden notebook state.
    entity_table = task.entity_table

    loss_accum = 0.0
    count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(
            batch,
            entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[entity_table].y.float())
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per example, not per batch.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    if count_accum == 0:
        raise RuntimeError("Training loader yielded no examples.")
    return loss_accum / count_accum


@torch.no_grad()
def test(loader: NeighborLoader) -> np.ndarray:
    """Run the model in eval mode over every batch of ``loader``; return stacked outputs.

    Single-column outputs are flattened to shape (N,); otherwise they stay (N, C).
    Reads the notebook globals ``model``, ``task`` and ``device``.
    """
    model.eval()

    chunks = []
    for batch in loader:
        out = model(batch.to(device), task.entity_table)
        if out.size(1) == 1:
            out = out.view(-1)
        chunks.append(out.detach().cpu())
    return torch.cat(chunks, dim=0).numpy()

In [None]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    current_metric = val_metrics[tune_metric]
    if higher_is_better:
        improved = current_metric > best_val_metric
    else:
        improved = current_metric < best_val_metric
    if improved:
        best_val_metric = current_metric
        # state_dict() holds references to the live tensors, so snapshot a deep copy.
        state_dict = copy.deepcopy(model.state_dict())

# Restore the best checkpoint. If no epoch ever improved (epochs == 0, or the metric
# was NaN every time, since NaN comparisons are False), fail with a clear message
# instead of calling load_state_dict(None).
if state_dict is None:
    raise RuntimeError(
        f"No checkpoint selected: {tune_metric!r} never improved on its initial value."
    )
model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

In [None]:
import torch
import numpy as np

@torch.no_grad()
def extract_seed_embeddings(loader: NeighborLoader):
    """Collect global node ids and pre-head GNN embeddings for every seed row in ``loader``.

    Returns:
        (ids, emb) as NumPy arrays, in loader order: ids has shape (N,), and
        emb has one row per id.
    Reads the notebook globals ``model``, ``task`` and ``device``.
    """
    model.eval()
    entity_table = task.entity_table

    id_chunks, emb_chunks = [], []
    for batch in loader:
        batch = batch.to(device)
        # Seed rows come first in the sampled entity-table nodes.
        num_seeds = batch[entity_table].seed_time.size(0)

        id_chunks.append(batch[entity_table].n_id[:num_seeds].detach().cpu())
        emb_chunks.append(model.encode_seed_nodes(batch, entity_table).detach().cpu())

    return (
        torch.cat(id_chunks, dim=0).numpy(),
        torch.cat(emb_chunks, dim=0).numpy(),
    )

# Seed-node embeddings for each split, extracted in train, val, test order.
# NOTE(review): the GNN was fit on the train labels, so train embeddings are
# in-sample; features built from them may look better on train than on val/test.
(train_ids, train_emb), (val_ids, val_emb), (test_ids, test_emb) = (
    extract_seed_embeddings(loader_dict[split_name])
    for split_name in ("train", "val", "test")
)



# Map graph node ids back to the original table's `nct_id`

In [None]:
entity_table = task.entity_table
entity_df = db.table_dict[entity_table].df.reset_index(drop=True).copy()

# node_id is the row position used by PyG for that node type
# (assumes node i of this type corresponds to row i of the table).
entity_df["node_id"] = np.arange(len(entity_df), dtype=np.int64)

# Keep only what we need: node id -> study id.
node_to_nct = entity_df[["node_id", "nct_id"]]

# Later cells merge on nct_id; a duplicated study id here would multiply rows there.
assert node_to_nct["nct_id"].is_unique, "nct_id is not unique in the entity table"


In [None]:
def make_emb_df(ids: np.ndarray, emb: np.ndarray, split: str) -> pd.DataFrame:
    """Build a per-node embedding frame with columns node_id, gnn_0..gnn_{D-1}, split.

    Args:
        ids: node ids, shape (N,). Any 1-D array-like (NumPy array, list, pandas
            Series) is accepted; ids are matched to ``emb`` rows by position,
            never by a pandas index.
        emb: embeddings, shape (N, D).
        split: value written to every row of the ``split`` column.

    Returns:
        DataFrame with N rows and a default RangeIndex.

    Raises:
        ValueError: if ``ids`` is not 1-D, or ``emb`` is not (len(ids), D).
    """
    # Positional values only: a Series with a non-default index previously misaligned
    # against the RangeIndex of the embedding columns and produced NaN rows.
    ids = np.asarray(ids).astype(np.int64)
    emb = np.asarray(emb)
    if ids.ndim != 1:
        raise ValueError(f"ids must be 1-D, got shape {ids.shape}")
    if emb.ndim != 2 or emb.shape[0] != ids.shape[0]:
        raise ValueError(
            f"emb must have shape ({ids.shape[0]}, D) to match ids, got {emb.shape}"
        )

    df = pd.DataFrame(emb, columns=[f"gnn_{i}" for i in range(emb.shape[1])])
    df.insert(0, "node_id", ids)
    df["split"] = split
    return df

# Attach each embedding row's study id via the node_id -> nct_id lookup.
train_emb_df = make_emb_df(train_ids, train_emb, "train").merge(node_to_nct, on="node_id", how="left")
val_emb_df = make_emb_df(val_ids, val_emb, "val").merge(node_to_nct, on="node_id", how="left")
test_emb_df = make_emb_df(test_ids, test_emb, "test").merge(node_to_nct, on="node_id", how="left")

# Sanity checks:
# - every node_id must map to an nct_id;
# - each nct_id must occur once per split, otherwise the later nct_id merge would
#   duplicate task rows.
for _split_name, _split_emb_df in (
    ("train", train_emb_df),
    ("val", val_emb_df),
    ("test", test_emb_df),
):
    assert _split_emb_df["nct_id"].isna().sum() == 0, f"Some {_split_name} node_ids did not map to nct_id"
    assert _split_emb_df["nct_id"].is_unique, f"Duplicate nct_id in {_split_name} embeddings"


In [None]:
# Create augmented dataframes: each task row (with its raw study columns) gets the GNN
# embedding of the same study, joined on nct_id. Left merges keep task-row order.
# - node_id and split are bookkeeping columns from make_emb_df, not features: node_id is
#   a graph index and split is constant within each frame, so drop them before the
#   tabular learner sees them.
# - validate="many_to_one" fails loudly if an nct_id repeats on the embedding side,
#   which would otherwise silently duplicate task rows.
# NOTE(review): nct_id itself is still present and will be used as a feature by
# AutoGluon unless it is dropped or ignored — confirm that is intended.
_emb_bookkeeping_cols = ["node_id", "split"]

df_train_aug = train_df.merge(
    train_emb_df.drop(columns=_emb_bookkeeping_cols), on="nct_id", how="left", validate="many_to_one"
)
df_val_aug = val_df.merge(
    val_emb_df.drop(columns=_emb_bookkeeping_cols), on="nct_id", how="left", validate="many_to_one"
)
df_test_aug = test_df.merge(
    test_emb_df.drop(columns=_emb_bookkeeping_cols), on="nct_id", how="left", validate="many_to_one"
)

In [None]:
from autogluon.tabular import TabularDataset, TabularPredictor


# Directory where AutoGluon writes its models. Override with AG_MODEL_PATH; the default
# keeps the original absolute location (which needs write access to the filesystem root).
path = os.environ.get("AG_MODEL_PATH", "/models/")

predictor = TabularPredictor(
    label="outcome",
    path=path,
    eval_metric="roc_auc",  # same metric the GNN was tuned on
).fit(
    train_data=df_train_aug,
    tuning_data=df_val_aug,  # held-out validation data used by AutoGluon for model selection
    time_limit=3600,  # seconds
    # NOTE: medium_quality does not use bagging / stack ensembling (best_quality does).
    presets="medium_quality",
    included_model_types=[
        "GBM",      # LightGBM
        "CAT",      # CatBoost
        "XGB",      # XGBoost
        "RF",       # optional
        "XT",       # optional
        "REALMLP",  # MLP (if available in the installed AutoGluon version)
        # optionally: "NN_TORCH" (depending on your AG version/install)
    ],
)

In [None]:
# Per-model scores on the validation frame.
# NOTE(review): df_val_aug is also the tuning_data passed to fit(), so these scores are
# likely optimistic; use the test evaluation below for an unbiased number.
predictor.leaderboard(df_val_aug, silent=True).head(20)

In [None]:
proba = predictor.predict_proba(df_test_aug)
# Select the positive-class column by name instead of assuming it is labelled 1.
preds_proba = proba[predictor.positive_class]

# Rows of df_test_aug follow test_task's order (left merges keep the left order), which
# is what task.evaluate(pred) without a table scores against elsewhere in this notebook.
results = task.evaluate(preds_proba.to_numpy())

print(results)

In [None]:
# Output directory; override with RESULTS_DIR (the default keeps the original "/run/").
OUTPUT_PATH = os.environ.get("RESULTS_DIR", "/run/")
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Plain task name, the same one passed to get_task. Interpolating the task object
# itself used its string form, which is not a clean identifier.
task_name = "study-outcome"

df_GNN_DB = pd.DataFrame([results])
df_GNN_DB["model"] = "Tab+GNN(DB)"
df_GNN_DB["task"] = task_name

df_GNN_DB.to_csv(os.path.join(OUTPUT_PATH, f"Tab+GNN_DB_Trial_{task_name}.csv"))