In [1]:
%cd ..

/home/lingze/embedding_fusion


In [2]:
import numpy as np
import pandas as pd
from relbench.datasets import get_dataset
from relbench.base import Table
from tqdm import tqdm
from typing import Any,Dict

import torch
import pickle
import os
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset = get_dataset(name = "rel-trial", download = True)
db = dataset.get_db()
cache_path = "data/rel-trial-tensor-frame"

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 7.82 seconds.


In [4]:
# [NOTE]: the dataset has been materialized

# get infer_type in cache
type_path = os.path.join(cache_path,"col_type_dict.pkl")
col_type_dict = pickle.load(open(type_path, "rb"))
len(col_type_dict)

# add "compress_text" in each table in case 
for table_name, table in db.table_dict.items():
    table.df["text_compress"] = np.nan

In [5]:
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            # "all-MiniLM-L12-v2",
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=512
)

def remove_pkey_fkey(col_to_stype: Dict[str, Any], table:Table) -> dict:
    r"""Remove pkey, fkey columns since they will not be used as input feature."""
    if table.pkey_col is not None:
        if table.pkey_col in col_to_stype:
            col_to_stype.pop(table.pkey_col)
    for fkey in table.fkey_col_to_pkey_table.keys():
        if fkey in col_to_stype:
            col_to_stype.pop(fkey)

def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds)."""
    assert ser.dtype in [np.dtype("datetime64[s]"), np.dtype("datetime64[ns]")]
    unix_time = ser.astype("int64").values
    if ser.dtype == np.dtype("datetime64[ns]"):
        unix_time //= 10**9
    return unix_time

  return self.fget.__get__(instance, owner)()


In [6]:
# build graph

# start build graph
cache_dir = "./data/rel-trial-tensor-frame"
if cache_dir is not None:
    os.makedirs(cache_dir, exist_ok=True)
data = HeteroData()
col_stats_dict = {}
for table_name, table in db.table_dict.items():
    df = table.df
    # (important for foreignKey value) Ensure the pkey is consecutive
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()
    
    col_to_stype = col_type_dict[table_name]
    
    # remove pkey, fkey
    remove_pkey_fkey(col_to_stype, table)
    
    if len(col_to_stype) == 0:
        # for example, relationship table which only contains pkey and fkey
        raise KeyError(f"{table_name} has no column to build graph")
    
    path = (
            None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")
    )
    
    print(f"-----> Materialize {table_name} Tensor Frame")
    dataset = Dataset(
        df = df,
        col_to_stype=col_to_stype,
        col_to_text_embedder_cfg=text_embedder_cfg,
    ).materialize(path=path)
    
    data[table_name].tf = dataset.tensor_frame
    col_stats_dict[table_name] = dataset.col_stats
    
    # Add time attribute
    if table.time_col is not None:
        data[table_name].time = torch.from_numpy(
            to_unix_time(df[table.time_col])
        )
    
    # Add edges normal edges
    for fkey_col_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
        pkey_index = df[fkey_col_name]
        # Filter out dangling foreign keys
        mask = ~pkey_index.isna()
        fkey_index = torch.arange(len(pkey_index))
        
        # filter dangling foreign keys:
        pkey_index = torch.from_numpy(pkey_index[mask].astype(int).values)
        fkey_index = fkey_index[torch.from_numpy(mask.values)]
        
        # fkey -> pkey edges
        edge_index = torch.stack([fkey_index, pkey_index], dim=0)
        edge_type = (table_name, f"f2p_{fkey_col_name}", pkey_table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

        # pkey -> fkey edges.
        # "rev_" is added so that PyG loader recognizes the reverse edges
        edge_index = torch.stack([pkey_index, fkey_index], dim=0)
        edge_type = (pkey_table_name, f"rev_f2p_{fkey_col_name}", table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)
    
data.validate()

-----> Materialize interventions Tensor Frame
-----> Materialize interventions_studies Tensor Frame
-----> Materialize facilities_studies Tensor Frame
-----> Materialize sponsors Tensor Frame
-----> Materialize eligibilities Tensor Frame
-----> Materialize reported_event_totals Tensor Frame
-----> Materialize designs Tensor Frame
-----> Materialize conditions_studies Tensor Frame
-----> Materialize drop_withdrawals Tensor Frame
-----> Materialize studies Tensor Frame
-----> Materialize outcome_analyses Tensor Frame
-----> Materialize sponsors_studies Tensor Frame
-----> Materialize outcomes Tensor Frame
-----> Materialize conditions Tensor Frame
-----> Materialize facilities Tensor Frame


True

In [7]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy


In [8]:
task_a = get_task("rel-trial", "study-adverse", download = True)
entity_table = task_a.entity_table
entity_table

'studies'

In [9]:
def generate_loader_dict(task: BaseTask, data:HeteroData) -> dict:
    loader_dict = {}
    for split, table in [
        ("train", task.get_table("train")),
        ("val",task.get_table("val")),
        ("test", task.get_table("test")),
    ]:
        table_input = get_node_train_table_input(
            table=table,
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            num_neighbors=[
                128,64
            ],  # we sample subgraphs of depth 2, 128 neighbors per node.
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=512,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [10]:
@torch.no_grad()
def valid(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    # model.eval()
    pred_list = []
    pred_hat_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
        pred_hat_list.append(batch[task.entity_table].y.detach().cpu())
    return torch.cat(pred_list, dim=0), torch.cat(pred_hat_list, dim=0)

@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    # model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [11]:
# add the additional edges
from utils.util import load_np_dict
edge_dict = load_np_dict("./edges/rel-trial-edges.npz")

for edge_name, edge_np in edge_dict.items():
    src_table, dst_table = edge_name.split('-')[0], edge_name.split('-')[1]
    edge_index = torch.from_numpy(edge_np.astype(int)).t()
    # [2, edge_num]
    edge_type = (src_table, f"appendix", dst_table)
    data[edge_type].edge_index = sort_edge_index(edge_index)
data.validate()

True

In [24]:
# construct bottom model
channels = 128
temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
    channels=channels,
)

feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)

node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2
)


net = CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.2,
    aggr="sum",
    norm="batch_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder
)

In [25]:
# for regression task, we need to deactivate the normalization and dropout layer
task_a.task_type
# freeze_instances = (torch.nn.BatchNorm1d, torch.nn.LayerNorm, torch.nn.Dropout, torch.nn.BatchNorm2d)
deactive_nn_instances = (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)
net.train()
for module in net.modules():
    if isinstance(module, deactive_nn_instances):
        module.eval()
        for param in module.parameters():
            param.requires_grad = False

In [26]:
# # read pre-trained parameters
pre_trained_model_param_path = './static/rel-trial-pre-trained-channel128-ep40.pth'
# pre_trained_model_param_path = './static/rel-trial-pre-trained.pth'
pre_trained_state_dict = torch.load(pre_trained_model_param_path)
net.load_state_dict(pre_trained_state_dict)

<All keys matched successfully>

In [27]:
# training for fine-tune
from torch.nn import L1Loss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error

task_loader_dict = generate_loader_dict(task_a,data)
lr = 0.01
epoches = 100
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False
early_stop = 10
max_round_epoch = 30
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr = lr)

In [28]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
patience = 0
# train
for epoch in range(1, epoches + 1):
    cnt = 0
    loss_accum = count_accum = 0
    # net.train()
    for batch in tqdm(task_loader_dict["train"], leave = False):
        cnt += 1
        if cnt > max_round_epoch:
            break
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = net(
            batch,
            task_a.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_a.entity_table].y.float())
        
        loss.backward()
        optimizer.step()
        
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    train_loss = loss_accum / count_accum
    val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
    val_logits = test(task_loader_dict["val"], net, task_a)
    val_logits = val_logits.numpy()
    
    val_metrics = {
        "mae": mean_absolute_error(val_pred_hat, val_logits),
        "r2": r2_score(val_pred_hat, val_logits),
        "rmse": root_mean_squared_error(val_pred_hat, val_logits),
    }
    logits = test(task_loader_dict["test"], net, task_a)
    logits = logits.numpy()
    pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()
    test_metrics = {
            "mae": mean_absolute_error(pred_hat, logits),
            "r2": r2_score(pred_hat, logits),
            "rmse": root_mean_squared_error(pred_hat, logits),
    }

    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")
    print(f"Test metrics: {test_metrics}")

    
    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        patience = 0
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(net.state_dict())
    else:
        patience += 1
    
    if patience >= early_stop:
        print(f"Early stop at epoch {epoch}")
        break   
    
best_epoch

  0%|          | 0/85 [00:00<?, ?it/s]

                                               

Epoch: 01, Train loss: 36.49122250874837, Val metrics: {'mae': 53.82199902685402, 'r2': -0.010854659452472104, 'rmse': 393.2326854546118}
Test metrics: {'mae': 54.136749339086144, 'r2': -0.02444283331601138, 'rmse': 252.71932937341663}


                                               

Epoch: 02, Train loss: 35.46702893575033, Val metrics: {'mae': 52.10411227292778, 'r2': -0.0010813661465787217, 'rmse': 391.3271133672281}
Test metrics: {'mae': 52.07485496339955, 'r2': 0.0029049332258721527, 'rmse': 249.3233076943158}


                                               

Epoch: 03, Train loss: 35.29884433746338, Val metrics: {'mae': 51.0425308381401, 'r2': 0.007507140289079395, 'rmse': 389.64485498645564}
Test metrics: {'mae': 50.53966613933132, 'r2': 0.028991390235299397, 'rmse': 246.04023689474633}


                                               

Epoch: 04, Train loss: 36.839137395222984, Val metrics: {'mae': 50.18318641536017, 'r2': 0.015132603939983946, 'rmse': 388.14512037333685}
Test metrics: {'mae': 49.389963091007786, 'r2': 0.05217812975043257, 'rmse': 243.08488701056598}


                                               

Epoch: 05, Train loss: 34.93049176534017, Val metrics: {'mae': 49.599875295255025, 'r2': 0.021457736717374054, 'rmse': 386.8967168056179}
Test metrics: {'mae': 48.69335172649562, 'r2': 0.0730441887417912, 'rmse': 240.39427010361513}


                                               

Epoch: 06, Train loss: 33.27241764068604, Val metrics: {'mae': 48.642347381302095, 'r2': 0.03268335161929792, 'rmse': 384.6711198164192}
Test metrics: {'mae': 47.112587559292436, 'r2': 0.10184254264909076, 'rmse': 236.63056256507863}


                                               

Epoch: 07, Train loss: 30.80533758799235, Val metrics: {'mae': 48.30641171597441, 'r2': 0.034953188340443364, 'rmse': 384.21953374627435}
Test metrics: {'mae': 46.30014245196205, 'r2': 0.1198911443450319, 'rmse': 234.24093427594943}


                                               

Epoch: 08, Train loss: 33.105711936950684, Val metrics: {'mae': 48.21658721601216, 'r2': 0.04903869693292584, 'rmse': 381.40525541054427}
Test metrics: {'mae': 45.62827035925412, 'r2': 0.15094634350965885, 'rmse': 230.0711506108693}


                                               

Epoch: 09, Train loss: 33.821931552886966, Val metrics: {'mae': 46.933922427187404, 'r2': 0.0577558444813, 'rmse': 379.6531229859638}
Test metrics: {'mae': 44.11048051864486, 'r2': 0.17290675012409717, 'rmse': 227.07631316446972}


                                               

Epoch: 10, Train loss: 32.20309534072876, Val metrics: {'mae': 46.780582982530866, 'r2': 0.05572137136440891, 'rmse': 380.06277136813827}
Test metrics: {'mae': 44.18129333244544, 'r2': 0.1776920585520314, 'rmse': 226.41846322672058}


                                               

Epoch: 11, Train loss: 30.464422957102457, Val metrics: {'mae': 46.44085332651201, 'r2': 0.06202666215891828, 'rmse': 378.7917375591487}
Test metrics: {'mae': 44.114365272784305, 'r2': 0.17982856577671824, 'rmse': 226.12413352287595}


                                               

Epoch: 12, Train loss: 29.925061734517417, Val metrics: {'mae': 46.30238806238155, 'r2': 0.06653704568983076, 'rmse': 377.8799020155483}
Test metrics: {'mae': 42.95666355344847, 'r2': 0.21179998294539026, 'rmse': 221.67300881869278}


                                               

Epoch: 13, Train loss: 32.12350104649862, Val metrics: {'mae': 45.6185352349399, 'r2': 0.07036217473206685, 'rmse': 377.1048722207749}
Test metrics: {'mae': 42.37913579167287, 'r2': 0.22642695144505764, 'rmse': 219.60653588200566}


                                               

Epoch: 14, Train loss: 30.880432256062825, Val metrics: {'mae': 46.108634887745936, 'r2': 0.07380655492105614, 'rmse': 376.40562253766615}
Test metrics: {'mae': 43.19321434339298, 'r2': 0.21992501321245195, 'rmse': 220.52750924289475}


                                               

Epoch: 15, Train loss: 31.883029556274415, Val metrics: {'mae': 45.86742529656958, 'r2': 0.06578033631166202, 'rmse': 378.03303468012376}
Test metrics: {'mae': 42.92779806885072, 'r2': 0.22521827550871354, 'rmse': 219.77803199004805}


                                               

Epoch: 16, Train loss: 31.60558058420817, Val metrics: {'mae': 46.17108017738672, 'r2': 0.0716541474742961, 'rmse': 376.84273862625673}
Test metrics: {'mae': 43.011736096372964, 'r2': 0.23064372668838284, 'rmse': 219.007176316283}


                                               

Epoch: 17, Train loss: 29.19881649017334, Val metrics: {'mae': 46.09924530639656, 'r2': 0.09658368414786256, 'rmse': 371.7484919254867}
Test metrics: {'mae': 42.79380213115857, 'r2': 0.2647447433843091, 'rmse': 214.0985200202878}


                                               

Epoch: 18, Train loss: 33.30828437805176, Val metrics: {'mae': 45.9667980410024, 'r2': 0.06800085953530122, 'rmse': 377.58349879213864}
Test metrics: {'mae': 42.83020966671638, 'r2': 0.2318136582753234, 'rmse': 218.84059491327994}


                                               

Epoch: 19, Train loss: 28.38090295791626, Val metrics: {'mae': 45.927766928047326, 'r2': 0.07948483053530553, 'rmse': 375.2500211703022}
Test metrics: {'mae': 42.9869097917298, 'r2': 0.23258642484644332, 'rmse': 218.73049451274971}


                                               

Epoch: 20, Train loss: 32.58191436131795, Val metrics: {'mae': 45.83143519208631, 'r2': 0.1057477422481139, 'rmse': 369.8582182585846}
Test metrics: {'mae': 42.91608916817677, 'r2': 0.2872660117377308, 'rmse': 210.79404168064724}


                                               

Epoch: 21, Train loss: 34.5609629313151, Val metrics: {'mae': 47.57163246968208, 'r2': 0.058999308306403764, 'rmse': 379.40252934988143}
Test metrics: {'mae': 46.0179071193567, 'r2': 0.17824122725181313, 'rmse': 226.34284514388426}


                                               

Epoch: 22, Train loss: 29.486788113911945, Val metrics: {'mae': 46.242669525183324, 'r2': 0.08383976159689377, 'rmse': 374.36132023567535}
Test metrics: {'mae': 44.28231151164266, 'r2': 0.22360297160020826, 'rmse': 220.00701485108985}


                                               

Epoch: 23, Train loss: 27.59923884073893, Val metrics: {'mae': 45.12582899907498, 'r2': 0.10288670282876122, 'rmse': 370.4494015119078}
Test metrics: {'mae': 42.67391195048183, 'r2': 0.27684565043469855, 'rmse': 212.32938286858487}


                                               

Epoch: 24, Train loss: 27.678485329945882, Val metrics: {'mae': 45.9212043246045, 'r2': 0.09508424757609124, 'rmse': 372.0568670211763}
Test metrics: {'mae': 43.469443925588436, 'r2': 0.2383709481116767, 'rmse': 217.90457419316618}


                                               

Epoch: 25, Train loss: 27.443993854522706, Val metrics: {'mae': 45.88833601648367, 'r2': 0.10406251648255782, 'rmse': 370.2065546586467}
Test metrics: {'mae': 43.65039728934714, 'r2': 0.27329051136414706, 'rmse': 212.85066508584745}


                                               

Epoch: 26, Train loss: 27.154906368255617, Val metrics: {'mae': 45.65287510329263, 'r2': 0.1117351600435802, 'rmse': 368.61795586111083}
Test metrics: {'mae': 42.90127087679792, 'r2': 0.27522579930410795, 'rmse': 212.56705661053323}


                                               

Epoch: 27, Train loss: 25.198499965667725, Val metrics: {'mae': 46.10751695029471, 'r2': 0.08324884685974632, 'rmse': 374.4820305398744}
Test metrics: {'mae': 42.98694807498158, 'r2': 0.2622060738794013, 'rmse': 214.46781831479979}


                                               

Epoch: 28, Train loss: 26.0025284131368, Val metrics: {'mae': 45.76443372072768, 'r2': 0.08662839921061782, 'rmse': 373.7911395869005}
Test metrics: {'mae': 43.54033808978641, 'r2': 0.2292189303984271, 'rmse': 219.20987582439727}


                                               

Epoch: 29, Train loss: 32.61659259796143, Val metrics: {'mae': 49.048423850540075, 'r2': 0.12569272711126234, 'rmse': 365.7103875085732}
Test metrics: {'mae': 45.67343896594959, 'r2': 0.3022152838495923, 'rmse': 208.57167192155345}


                                               

Epoch: 30, Train loss: 26.270057264963786, Val metrics: {'mae': 45.17660150475007, 'r2': 0.11926598714298509, 'rmse': 367.0520338179178}
Test metrics: {'mae': 41.42223892296894, 'r2': 0.31582711805133734, 'rmse': 206.5273270706167}


                                               

Epoch: 31, Train loss: 27.214138380686443, Val metrics: {'mae': 45.58656980087297, 'r2': 0.1164149417312188, 'rmse': 367.64565031935246}
Test metrics: {'mae': 42.234321507461146, 'r2': 0.31393008655278454, 'rmse': 206.81345185654044}


                                               

Epoch: 32, Train loss: 29.22033363978068, Val metrics: {'mae': 45.51680879472383, 'r2': 0.11433308447978141, 'rmse': 368.0785092986376}
Test metrics: {'mae': 42.71657207498767, 'r2': 0.28196286108661994, 'rmse': 211.5768027628446}


                                               

Epoch: 33, Train loss: 27.076058642069498, Val metrics: {'mae': 46.13558634858008, 'r2': 0.09900260982758646, 'rmse': 371.25047418138814}
Test metrics: {'mae': 42.74726786041901, 'r2': 0.2903158188093887, 'rmse': 210.34256162554934}
Early stop at epoch 33


23

In [29]:
# test
net.load_state_dict(state_dict)
net.eval()
logits = test(task_loader_dict["test"], net, task_a)
logits = logits.numpy()
pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()
test_metrics = {
        "mae": mean_absolute_error(pred_hat, logits),
        "r2": r2_score(pred_hat, logits),
        "rmse": root_mean_squared_error(pred_hat, logits),
}
test_metrics

{'mae': 42.673038310185554,
 'r2': 0.27674693131019423,
 'rmse': 212.34387511044832}