In [1]:
# Change the working directory to the parent folder; later cells use relative
# paths such as "data/rel-trial-tensor-frame".
# NOTE(review): not idempotent - re-running this cell moves up one more level.
%cd ..

/home/lingze/embedding_fusion


In [2]:
import numpy as np
import pandas as pd
from relbench.datasets import get_dataset
from relbench.base import Table
from tqdm import tqdm
from typing import Any,Dict

import torch
import pickle
import os
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Fetch the "rel-trial" dataset (download=True) and obtain its database object.
dataset = get_dataset(name = "rel-trial", download = True)
db = dataset.get_db()
# Directory that holds the cached column-type dictionary read below.
# NOTE(review): the graph-building cell defines `cache_dir` pointing at the same folder.
cache_path = "data/rel-trial-tensor-frame"

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 7.93 seconds.


In [4]:
# [NOTE]: the dataset has been materialized

# Load the cached per-table column-type dictionary (table name -> {column: stype}).
# NOTE(review): pickle.load can execute arbitrary code - only load cache files
# that were produced locally by a trusted run.
type_path = os.path.join(cache_path, "col_type_dict.pkl")
with open(type_path, "rb") as type_file:  # close the handle deterministically
    col_type_dict = pickle.load(type_file)

# Add a placeholder "text_compress" column (all NaN) to every table.
for table_name, table in db.table_dict.items():
    table.df["text_compress"] = np.nan

In [5]:
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GloveTextEmbedding:
    """Callable text embedder backed by a SentenceTransformer GloVe model.

    Instances map a list of sentences to a tensor of sentence embeddings.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """Load the averaged GloVe word-embedding model on ``device``."""
        # Alternative model name kept for reference: "all-MiniLM-L12-v2"
        model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` and wrap the resulting NumPy array as a tensor."""
        embeddings = self.model.encode(sentences)
        return torch.from_numpy(embeddings)

# Text-embedder configuration: embed text with the GloVe model above (placed on
# `device`), using batch_size=512.
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=512
)

def remove_pkey_fkey(col_to_stype: Dict[str, Any], table:Table) -> dict:
    r"""Remove pkey, fkey columns since they will not be used as input feature.

    Args:
        col_to_stype: Mapping from column name to its stype. It is modified
            in place.
        table: Table whose ``pkey_col`` and ``fkey_col_to_pkey_table`` name
            the key columns to drop.

    Returns:
        The same ``col_to_stype`` object, after the key columns were removed
        (returned so the result matches the declared ``dict`` return type).
    """
    key_cols = list(table.fkey_col_to_pkey_table.keys())
    if table.pkey_col is not None:
        key_cols.append(table.pkey_col)
    for key_col in key_cols:
        # Key columns that are absent from the mapping are simply skipped.
        col_to_stype.pop(key_col, None)
    return col_to_stype

def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds).

    Args:
        ser: Series with a timezone-naive ``datetime64`` dtype in second,
            millisecond, microsecond or nanosecond resolution.

    Returns:
        ``int64`` array holding whole seconds since the UNIX epoch
        (sub-second parts are floored).

    Raises:
        AssertionError: If ``ser`` does not have one of the supported dtypes.
    """
    dtype = ser.dtype
    assert isinstance(dtype, np.dtype) and dtype.kind == "M", (
        f"expected a timezone-naive datetime64 series, got {dtype}"
    )
    ticks_per_second = {"s": 1, "ms": 10**3, "us": 10**6, "ns": 10**9}
    unit = np.datetime_data(dtype)[0]
    assert unit in ticks_per_second, f"unsupported datetime64 unit: {unit}"

    unix_time = ser.astype("int64").values
    divisor = ticks_per_second[unit]
    if divisor != 1:
        # Divide out of place: ``.values`` can be a read-only array (pandas
        # Copy-on-Write), on which an in-place ``//=`` raises.
        unix_time = unix_time // divisor
    return unix_time

  return self.fget.__get__(instance, owner)()


In [6]:
# build graph

# Build a HeteroData graph from the database: every table becomes a node type
# whose features are a materialized TensorFrame, and every foreign-key column
# yields a forward (f2p_*) and a reverse (rev_f2p_*) edge type.
# NOTE(review): `cache_dir` points at the same folder as `cache_path` above.
cache_dir = "./data/rel-trial-tensor-frame"
if cache_dir is not None:
    os.makedirs(cache_dir, exist_ok=True)
data = HeteroData()
# table name -> column statistics reported by the materialized dataset
col_stats_dict = {}
for table_name, table in db.table_dict.items():
    df = table.df
    # (important for foreignKey value) Ensure the pkey is consecutive:
    # primary keys must equal 0..len(df)-1 so they can be used as node indices.
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()

    col_to_stype = col_type_dict[table_name]

    # remove pkey, fkey
    # NOTE: this mutates the entry stored in `col_type_dict` in place.
    remove_pkey_fkey(col_to_stype, table)

    if len(col_to_stype) == 0:
        # for example, relationship table which only contains pkey and fkey
        raise KeyError(f"{table_name} has no column to build graph")

    # Per-table cache file handed to `materialize`.
    path = (
            None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")
    )

    print(f"-----> Materialize {table_name} Tensor Frame")
    # NOTE(review): rebinds `dataset`, which earlier held the object returned by
    # get_dataset(); from here on it refers to the last table's Dataset.
    dataset = Dataset(
        df = df,
        col_to_stype=col_to_stype,
        col_to_text_embedder_cfg=text_embedder_cfg,
    ).materialize(path=path)

    data[table_name].tf = dataset.tensor_frame
    col_stats_dict[table_name] = dataset.col_stats

    # Add time attribute (UNIX seconds) for tables that have a time column
    if table.time_col is not None:
        data[table_name].time = torch.from_numpy(
            to_unix_time(df[table.time_col])
        )

    # Add edges: one forward and one reverse edge type per foreign-key column
    for fkey_col_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
        pkey_index = df[fkey_col_name]
        # Rows whose foreign key is missing (NaN) produce no edge
        mask = ~pkey_index.isna()
        # Row position of each record in the current table
        fkey_index = torch.arange(len(pkey_index))

        # Keep only the rows with a valid foreign key on both sides
        pkey_index = torch.from_numpy(pkey_index[mask].astype(int).values)
        fkey_index = fkey_index[torch.from_numpy(mask.values)]

        # fkey -> pkey edges
        edge_index = torch.stack([fkey_index, pkey_index], dim=0)
        edge_type = (table_name, f"f2p_{fkey_col_name}", pkey_table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

        # pkey -> fkey edges.
        # "rev_" is added so that PyG loader recognizes the reverse edges
        edge_index = torch.stack([pkey_index, fkey_index], dim=0)
        edge_type = (pkey_table_name, f"rev_f2p_{fkey_col_name}", table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

# Sanity-check the assembled graph; the result is displayed as the cell output.
data.validate()

-----> Materialize interventions Tensor Frame
-----> Materialize interventions_studies Tensor Frame
-----> Materialize facilities_studies Tensor Frame
-----> Materialize sponsors Tensor Frame
-----> Materialize eligibilities Tensor Frame
-----> Materialize reported_event_totals Tensor Frame
-----> Materialize designs Tensor Frame
-----> Materialize conditions_studies Tensor Frame
-----> Materialize drop_withdrawals Tensor Frame
-----> Materialize studies Tensor Frame
-----> Materialize outcome_analyses Tensor Frame
-----> Materialize sponsors_studies Tensor Frame
-----> Materialize outcomes Tensor Frame
-----> Materialize conditions Tensor Frame
-----> Materialize facilities Tensor Frame


True

In [7]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy

In [8]:
# Load the "site-success" task of the rel-trial dataset and display the name of
# the table that holds its entities.
task_a = get_task("rel-trial", "site-success", download = True)
entity_table = task_a.entity_table
entity_table

'facilities'

In [9]:
def generate_loader_dict(
    task: BaseTask,
    data:HeteroData,
    num_neighbors: Optional[List[int]] = None,
    batch_size: int = 512,
) -> dict:
    """Build one temporal ``NeighborLoader`` per split of ``task``.

    Args:
        task: Task providing the "train", "val" and "test" tables.
        data: Heterogeneous graph to sample subgraphs from; node types are
            expected to expose their timestamps under the ``time`` attribute.
        num_neighbors: Number of neighbors sampled per hop. Defaults to
            ``[128, 128]`` (subgraphs of depth 2, 128 neighbors per node).
        batch_size: Number of seed nodes per batch. Defaults to 512.

    Returns:
        Dict mapping ``"train"``, ``"val"`` and ``"test"`` to their loader.
        Only the training loader shuffles its seed nodes.
    """
    if num_neighbors is None:
        num_neighbors = [128, 128]

    loader_dict = {}
    for split in ("train", "val", "test"):
        table_input = get_node_train_table_input(
            table=task.get_table(split),
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            # Copy so the loaders never share (or mutate) the caller's list.
            num_neighbors=list(num_neighbors),
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=batch_size,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [10]:
@torch.no_grad()
def valid(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask) -> tuple[Tensor, Tensor]:
    """Run ``model`` over ``loader`` and collect predictions with their targets.

    Gradients are disabled. Batches are moved to the module-level ``device``.
    NOTE(review): ``model.eval()`` is not called, so the model runs in whatever
    train/eval mode the caller left it in.

    Args:
        loader: Loader yielding batches that carry ``y`` on the task's entity table.
        model: Model called as ``model(batch, task.entity_table)``.
        task: Task whose ``entity_table`` names the node type to predict for.

    Returns:
        ``(predictions, targets)``, each concatenated over all batches on the CPU.
        Predictions with a single output column are flattened to 1-D.
    """
    pred_list = []
    target_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        # Flatten (N, 1) outputs; anything else (including 1-D) is kept as is.
        if pred.dim() == 2 and pred.size(1) == 1:
            pred = pred.view(-1)
        pred_list.append(pred.detach().cpu())
        target_list.append(batch[task.entity_table].y.detach().cpu())
    return torch.cat(pred_list, dim=0), torch.cat(target_list, dim=0)

@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask) -> Tensor:
    """Run ``model`` over ``loader`` and return the concatenated predictions.

    Gradients are disabled. Batches are moved to the module-level ``device``.
    NOTE(review): ``model.eval()` is not called, so the model runs in whatever
    train/eval mode the caller left it in.

    Args:
        loader: Loader yielding the batches to predict on.
        model: Model called as ``model(batch, task.entity_table)``.
        task: Task whose ``entity_table`` names the node type to predict for.

    Returns:
        Predictions of all batches concatenated along dim 0, on the CPU.
        Predictions with a single output column are flattened to 1-D.
    """
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        # Flatten (N, 1) outputs; anything else (including 1-D) is kept as is.
        if pred.dim() == 2 and pred.size(1) == 1:
            pred = pred.view(-1)
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [11]:
# construct bottom model
# Hidden width shared by every component built in this cell.
channels = 64
# Temporal encoder restricted to the node types that carry a `time` attribute.
temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
    channels=channels,
)

# Feature encoder built from the graph and the per-table column statistics.
feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)

# Node-representation part (project class from model.base).
# NOTE(review): uses num_layers=1 here while CompositeModel below gets
# num_layer=2 - confirm the two settings are meant to differ.
node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2
)


# Full model composed of the three parts above; out_channels=1 yields one
# prediction per entity.
net = CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.2,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder
)

In [12]:
# For the regression task, dropout is deactivated: the model is put in training
# mode and then only the dropout modules are switched to eval mode.
# Normalization layers are NOT touched here (see the commented-out tuple below).
# NOTE(review): this bare expression is not the cell's last statement, so it
# displays nothing.
task_a.task_type
# freeze_instances = (torch.nn.BatchNorm1d, torch.nn.LayerNorm, torch.nn.Dropout, torch.nn.BatchNorm2d)
deactive_nn_instances = (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)
net.train()
for module in net.modules():
    if isinstance(module, deactive_nn_instances):
        # eval mode turns the dropout module into a no-op
        module.eval()
        # Freeze any parameters the matched module may own.
        for param in module.parameters():
            param.requires_grad = False
# NOTE(review): a later net.train() call would re-enable dropout; the training
# loop keeps its net.train() call commented out.


In [13]:
# training for fine-tune
from torch.nn import L1Loss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error

# Loaders for the train/val/test splits of the task.
task_loader_dict = generate_loader_dict(task_a,data)
lr = 0.005  # Adam learning rate
epoches = 25  # number of training epochs
loss_fn = L1Loss()  # mean absolute error as training objective
tune_metric = "mae"  # validation metric used to select the best epoch
higher_is_better = False  # a lower value of the tune metric is better
# NOTE(review): despite its name, the training loop uses `early_stop` as a cap
# on the number of training batches per epoch, not as an early-stopping patience.
early_stop = 20
# Optimize only the parameters that still require gradients.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr = lr)

In [14]:
# Train for `epoches` epochs, evaluate on val/test after every epoch and keep a
# copy of the weights from the epoch with the best validation `tune_metric`.
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
# NOTE(review): despite its name, `early_stop` caps the number of training
# batches processed per epoch; it is not an early-stopping patience.
early_stop = 20

# The ground-truth targets do not change between epochs, so fetch them once
# instead of rebuilding them in every iteration.
val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()

# train
for epoch in range(1, epoches + 1):
    cnt = 0
    loss_accum = count_accum = 0
    # NOTE(review): net.train() stays commented out; calling it would switch
    # every submodule (dropout included) back to training mode.
    # net.train()
    for batch in tqdm(task_loader_dict["train"]):
        cnt += 1
        if cnt > early_stop:
            break
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = net(
            batch,
            task_a.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_a.entity_table].y.float())

        loss.backward()
        optimizer.step()

        # Weight the batch loss by its size to get a per-sample average below.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    # Guard against an empty loader / zero batch cap instead of dividing by zero.
    train_loss = loss_accum / count_accum if count_accum else math.nan

    val_logits = test(task_loader_dict["val"], net, task_a)
    val_logits = val_logits.numpy()

    val_metrics = {
        "mae": mean_absolute_error(val_pred_hat, val_logits),
        "r2": r2_score(val_pred_hat, val_logits),
        "rmse": root_mean_squared_error(val_pred_hat, val_logits),
    }

    # Test metrics are reported for monitoring only; model selection below
    # relies exclusively on the validation metric.
    logits = test(task_loader_dict["test"], net, task_a)
    logits = logits.numpy()
    test_metrics = {
            "mae": mean_absolute_error(pred_hat, logits),
            "r2": r2_score(pred_hat, logits),
            "rmse": root_mean_squared_error(pred_hat, logits),
    }

    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")
    print(f"Test metrics: {test_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        # Deep copy: state_dict() returns references to the live parameters.
        state_dict = copy.deepcopy(net.state_dict())

best_epoch

  0%|          | 0/296 [00:00<?, ?it/s]

  7%|▋         | 20/296 [00:02<00:39,  6.98it/s]


Epoch: 01, Train loss: 0.4122033104300499, Val metrics: {'mae': 0.42873566536526286, 'r2': -0.11780229176082635, 'rmse': 0.5050475492820236}
Test metrics: {'mae': 0.43314663560187006, 'r2': -0.1667494227404105, 'rmse': 0.5197958257453715}


  7%|▋         | 20/296 [00:01<00:19, 14.21it/s]


Epoch: 02, Train loss: 0.3531487837433815, Val metrics: {'mae': 0.42800956032713616, 'r2': -0.17552524252314217, 'rmse': 0.5179236585615973}
Test metrics: {'mae': 0.42664993200364576, 'r2': -0.1820978721610611, 'rmse': 0.5232035812365085}


  7%|▋         | 20/296 [00:01<00:23, 11.88it/s]


Epoch: 03, Train loss: 0.3328595906496048, Val metrics: {'mae': 0.4342383184396135, 'r2': -0.2880092873697544, 'rmse': 0.5421372737449065}
Test metrics: {'mae': 0.42970260533132315, 'r2': -0.2715932234498817, 'rmse': 0.5426478574490157}


  7%|▋         | 20/296 [00:01<00:21, 12.62it/s]


Epoch: 04, Train loss: 0.3061138838529587, Val metrics: {'mae': 0.42309131233019537, 'r2': -0.3303786429346032, 'rmse': 0.5509819889105757}
Test metrics: {'mae': 0.4347987276616598, 'r2': -0.4029152941000069, 'rmse': 0.5699801287382802}


  7%|▋         | 20/296 [00:01<00:22, 12.47it/s]


Epoch: 05, Train loss: 0.2931565657258034, Val metrics: {'mae': 0.42327213813914755, 'r2': -0.31198164720255295, 'rmse': 0.5471591297835268}
Test metrics: {'mae': 0.42303546081017207, 'r2': -0.29376912286091184, 'rmse': 0.5473591484939002}


  7%|▋         | 20/296 [00:01<00:21, 13.10it/s]


Epoch: 06, Train loss: 0.2772732928395271, Val metrics: {'mae': 0.4192808257629666, 'r2': -0.32650032251379235, 'rmse': 0.5501782912766678}
Test metrics: {'mae': 0.4146196518292489, 'r2': -0.2962320738325035, 'rmse': 0.5478799051453973}


  7%|▋         | 20/296 [00:01<00:19, 14.22it/s]


Epoch: 07, Train loss: 0.2731128364801407, Val metrics: {'mae': 0.4038496852875306, 'r2': -0.17675599311744117, 'rmse': 0.5181947153069292}
Test metrics: {'mae': 0.40987153462898235, 'r2': -0.21769372043503488, 'rmse': 0.5310226227772831}


  7%|▋         | 20/296 [00:01<00:21, 13.09it/s]


Epoch: 08, Train loss: 0.2702962763607502, Val metrics: {'mae': 0.40954955674437316, 'r2': -0.29542591614172675, 'rmse': 0.5436959035426382}
Test metrics: {'mae': 0.4214133589529854, 'r2': -0.3353287139648802, 'rmse': 0.5560810353958312}


  7%|▋         | 20/296 [00:01<00:20, 13.25it/s]


Epoch: 09, Train loss: 0.2583423532545567, Val metrics: {'mae': 0.3973759943373849, 'r2': -0.11194367481834044, 'rmse': 0.5037222848057502}
Test metrics: {'mae': 0.40312441427599427, 'r2': -0.1279667645241367, 'rmse': 0.5110838314225888}


  7%|▋         | 20/296 [00:01<00:21, 13.09it/s]


Epoch: 10, Train loss: 0.24823640063405036, Val metrics: {'mae': 0.37652905911939727, 'r2': -0.11125099150483786, 'rmse': 0.5035653639176089}
Test metrics: {'mae': 0.38940320133263284, 'r2': -0.14625675304742636, 'rmse': 0.5152107829076834}


  7%|▋         | 20/296 [00:01<00:20, 13.28it/s]


Epoch: 11, Train loss: 0.23855635449290274, Val metrics: {'mae': 0.37900324510297223, 'r2': -0.1165116865234257, 'rmse': 0.5047559032228921}
Test metrics: {'mae': 0.3952483179834642, 'r2': -0.1843488490326195, 'rmse': 0.5237014922224219}


  7%|▋         | 20/296 [00:01<00:21, 12.96it/s]


Epoch: 12, Train loss: 0.23172188475728034, Val metrics: {'mae': 0.3786475500506874, 'r2': -0.14768939003498605, 'rmse': 0.5117548339144693}
Test metrics: {'mae': 0.38480122388122895, 'r2': -0.15440828642571813, 'rmse': 0.5170394821651433}


  7%|▋         | 20/296 [00:01<00:20, 13.43it/s]


Epoch: 13, Train loss: 0.22471152171492575, Val metrics: {'mae': 0.37909371167759565, 'r2': -0.11730412242660604, 'rmse': 0.5049349948339151}
Test metrics: {'mae': 0.39079801834372035, 'r2': -0.161700210302957, 'rmse': 0.5186698748077172}


  7%|▋         | 20/296 [00:01<00:20, 13.47it/s]


Epoch: 14, Train loss: 0.2239884167909622, Val metrics: {'mae': 0.37636511110089194, 'r2': -0.10903338437218313, 'rmse': 0.5030626566834815}
Test metrics: {'mae': 0.391706801402307, 'r2': -0.16281323767477685, 'rmse': 0.5189182846608337}


  7%|▋         | 20/296 [00:01<00:20, 13.65it/s]


Epoch: 15, Train loss: 0.21556501984596252, Val metrics: {'mae': 0.37967925243635287, 'r2': -0.11195280011211528, 'rmse': 0.5037243517289691}
Test metrics: {'mae': 0.3785712118688413, 'r2': -0.10236160328475075, 'rmse': 0.5052496590194782}


  7%|▋         | 20/296 [00:01<00:20, 13.54it/s]


Epoch: 16, Train loss: 0.21520972549915313, Val metrics: {'mae': 0.3764004556216979, 'r2': -0.09996137002063188, 'rmse': 0.5010008778617818}
Test metrics: {'mae': 0.38567898599651884, 'r2': -0.13762929378820177, 'rmse': 0.5132682183149738}


  7%|▋         | 20/296 [00:01<00:21, 12.79it/s]


Epoch: 17, Train loss: 0.20965303033590316, Val metrics: {'mae': 0.3782940900617646, 'r2': -0.18064030731930836, 'rmse': 0.519049256442876}
Test metrics: {'mae': 0.39590506621167254, 'r2': -0.22988511909725795, 'rmse': 0.5336742686732351}


  7%|▋         | 20/296 [00:01<00:23, 11.68it/s]


Epoch: 18, Train loss: 0.21199334636330605, Val metrics: {'mae': 0.4076121951466724, 'r2': -0.35752719439551495, 'rmse': 0.556575441325432}
Test metrics: {'mae': 0.41414829624225613, 'r2': -0.3446634464034226, 'rmse': 0.5580213169880104}


  7%|▋         | 20/296 [00:01<00:20, 13.59it/s]


Epoch: 19, Train loss: 0.2112356096506119, Val metrics: {'mae': 0.3843613453440781, 'r2': -0.12819767971962115, 'rmse': 0.5073905464522983}
Test metrics: {'mae': 0.38651353010416617, 'r2': -0.1267839319784121, 'rmse': 0.5108157893248705}


  7%|▋         | 20/296 [00:01<00:20, 13.51it/s]


Epoch: 20, Train loss: 0.2046791784465313, Val metrics: {'mae': 0.3816656588579701, 'r2': -0.1067564962259584, 'rmse': 0.5025459878813047}
Test metrics: {'mae': 0.39655993273056717, 'r2': -0.16655573548286462, 'rmse': 0.519752679374508}


  7%|▋         | 20/296 [00:01<00:24, 11.40it/s]


Epoch: 21, Train loss: 0.2025813214480877, Val metrics: {'mae': 0.38140117953196456, 'r2': -0.1097591966398721, 'rmse': 0.5032272456544736}
Test metrics: {'mae': 0.3931524250621013, 'r2': -0.15166874012429155, 'rmse': 0.5164256200814352}


  7%|▋         | 20/296 [00:01<00:23, 11.69it/s]


Epoch: 22, Train loss: 0.2050161764025688, Val metrics: {'mae': 0.3789324745938837, 'r2': -0.15440708435645023, 'rmse': 0.5132503590366759}
Test metrics: {'mae': 0.38468717027369487, 'r2': -0.1467693879210099, 'rmse': 0.5153259776451379}


  7%|▋         | 20/296 [00:01<00:23, 11.67it/s]


Epoch: 23, Train loss: 0.196232870221138, Val metrics: {'mae': 0.3779022701359271, 'r2': -0.15149339648987814, 'rmse': 0.5126022357815025}
Test metrics: {'mae': 0.398528938328328, 'r2': -0.20582935950453551, 'rmse': 0.5284293329078786}


  7%|▋         | 20/296 [00:01<00:23, 11.51it/s]


Epoch: 24, Train loss: 0.18904469385743142, Val metrics: {'mae': 0.3907229980766078, 'r2': -0.2705569747886458, 'rmse': 0.5384518116681922}
Test metrics: {'mae': 0.4038138659927396, 'r2': -0.2689909172067664, 'rmse': 0.542092310634643}


  7%|▋         | 20/296 [00:01<00:20, 13.41it/s]


Epoch: 25, Train loss: 0.1982203058898449, Val metrics: {'mae': 0.3845771032781646, 'r2': -0.136425749908748, 'rmse': 0.5092374128456487}
Test metrics: {'mae': 0.38391513368555685, 'r2': -0.13537804503279416, 'rmse': 0.512760114901673}


14

In [15]:
# test
# Restore the weights saved at the best validation epoch and report the test
# metrics of that checkpoint.
# NOTE(review): `state_dict` is still None if no epoch improved the validation
# metric, in which case load_state_dict fails.
net.load_state_dict(state_dict)
logits = test(task_loader_dict["test"], net, task_a)
logits = logits.numpy()
# Despite the name, `pred_hat` holds the ground-truth targets of the test split;
# `logits` holds the model predictions.
pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()
test_metrics = {
        "mae": mean_absolute_error(pred_hat, logits),
        "r2": r2_score(pred_hat, logits),
        "rmse": root_mean_squared_error(pred_hat, logits),
}
test_metrics

{'mae': 0.3916951493793514,
 'r2': -0.16279380941943344,
 'rmse': 0.5189139496056754}