In [1]:
# Move the working directory up one level; later cells use paths relative to
# it (e.g. "data/rel-trial-tensor-frame").
# NOTE: not idempotent -- re-running this cell moves up another level.
%cd ..

/home/lingze/embedding_fusion


In [2]:
import os
import pickle
from typing import Any, Dict, Tuple

import numpy as np
import pandas as pd
import torch
from relbench.base import Table
from relbench.datasets import get_dataset
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Directory (relative to the working directory) that holds the cached
# column-type dictionary read in the next cell.
cache_path = "data/rel-trial-tensor-frame"

# Load the rel-trial dataset (download=True) and pull out its database object.
dataset = get_dataset(name="rel-trial", download=True)
db = dataset.get_db()

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 7.92 seconds.


In [4]:
# [NOTE]: the dataset has been materialized

# Load the cached per-table column-type mapping.
# NOTE(review): pickle.load can execute arbitrary code -- only point this at a
# cache file that you produced yourself.
type_path = os.path.join(cache_path, "col_type_dict.pkl")
with open(type_path, "rb") as type_file:  # close the handle deterministically
    col_type_dict = pickle.load(type_file)

# Add an all-NaN "text_compress" column to every table.
# NOTE: this unconditionally overwrites an existing "text_compress" column.
for table in db.table_dict.values():
    table.df["text_compress"] = np.nan

In [5]:
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class GloveTextEmbedding:
    """Callable text embedder built on a SentenceTransformer checkpoint.

    Loads ``sentence-transformers/average_word_embeddings_glove.6B.300d`` and
    maps a list of sentences to a tensor of their encodings.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """Load the encoder, forwarding ``device`` to SentenceTransformer."""
        # Alternative checkpoint (disabled): "all-MiniLM-L12-v2"
        checkpoint = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(checkpoint, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` and wrap the resulting array as a tensor."""
        encoded = self.model.encode(sentences)
        return torch.from_numpy(encoded)

# Text-embedding configuration handed to torch_frame when materializing the
# tables: the GloVe encoder on `device`, with batch_size=512.
glove_embedder = GloveTextEmbedding(device=device)
text_embedder_cfg = TextEmbedderConfig(text_embedder=glove_embedder, batch_size=512)

def remove_pkey_fkey(col_to_stype: Dict[str, Any], table: "Table") -> Dict[str, Any]:
    r"""Remove pkey, fkey columns since they will not be used as input feature.

    Args:
        col_to_stype: Mapping from column name to its type. Modified in place.
        table: Table whose ``pkey_col`` and ``fkey_col_to_pkey_table`` keys
            name the columns to drop. Columns absent from ``col_to_stype``
            are ignored.

    Returns:
        The same ``col_to_stype`` object, with the key columns removed.
    """
    if table.pkey_col is not None:
        col_to_stype.pop(table.pkey_col, None)
    for fkey in table.fkey_col_to_pkey_table:
        col_to_stype.pop(fkey, None)
    # Return the mapping so the declared return type is honoured; callers that
    # ignore the return value still see the in-place change.
    return col_to_stype

def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds).

    Args:
        ser: Series whose dtype is ``datetime64[s]`` or ``datetime64[ns]``.

    Returns:
        An ``int64`` array of seconds since the UNIX epoch, one per row.

    Raises:
        AssertionError: If ``ser`` has any other dtype.
    """
    assert ser.dtype in [np.dtype("datetime64[s]"), np.dtype("datetime64[ns]")]
    unix_time = ser.astype("int64").values
    if ser.dtype == np.dtype("datetime64[ns]"):
        # Out-of-place on purpose: the array returned by ``.values`` can be
        # read-only (pandas Copy-on-Write), which makes ``//=`` raise.
        unix_time = unix_time // 10**9
    return unix_time

  return self.fget.__get__(instance, owner)()


In [6]:
# build graph
#
# One node type per database table: each gets a materialized tensor frame
# (`.tf`) and, when the table has a time column, a `.time` attribute produced
# by `to_unix_time`. Every foreign-key column contributes a forward (f2p_*)
# and a reverse (rev_f2p_*) edge type.
cache_dir = "./data/rel-trial-tensor-frame"
if cache_dir is not None:
    os.makedirs(cache_dir, exist_ok=True)
data = HeteroData()
col_stats_dict = {}
for table_name, table in db.table_dict.items():
    df = table.df
    # (important for foreignKey value) Ensure the pkey is consecutive
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()

    # Work on a copy so the cached `col_type_dict` entry is not mutated when
    # the key columns are stripped.
    col_to_stype = dict(col_type_dict[table_name])

    # remove pkey, fkey
    remove_pkey_fkey(col_to_stype, table)

    if len(col_to_stype) == 0:
        # for example, relationship table which only contains pkey and fkey
        raise KeyError(f"{table_name} has no column to build graph")

    path = (
        None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")
    )

    print(f"-----> Materialize {table_name} Tensor Frame")
    # Named `table_dataset` so the relbench `dataset` object from the loading
    # cell is not clobbered by a torch_frame Dataset.
    table_dataset = Dataset(
        df=df,
        col_to_stype=col_to_stype,
        col_to_text_embedder_cfg=text_embedder_cfg,
    ).materialize(path=path)

    data[table_name].tf = table_dataset.tensor_frame
    col_stats_dict[table_name] = table_dataset.col_stats

    # Add time attribute
    if table.time_col is not None:
        data[table_name].time = torch.from_numpy(
            to_unix_time(df[table.time_col])
        )

    # Add edges normal edges
    for fkey_col_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
        pkey_index = df[fkey_col_name]
        # Rows whose foreign key is NaN have no target node: mask them out.
        mask = ~pkey_index.isna()
        fkey_index = torch.arange(len(pkey_index))

        # filter dangling foreign keys:
        pkey_index = torch.from_numpy(pkey_index[mask].astype(int).values)
        fkey_index = fkey_index[torch.from_numpy(mask.values)]

        # fkey -> pkey edges
        edge_index = torch.stack([fkey_index, pkey_index], dim=0)
        edge_type = (table_name, f"f2p_{fkey_col_name}", pkey_table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

        # pkey -> fkey edges.
        # "rev_" is added so that PyG loader recognizes the reverse edges
        edge_index = torch.stack([pkey_index, fkey_index], dim=0)
        edge_type = (pkey_table_name, f"rev_f2p_{fkey_col_name}", table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

data.validate()

-----> Materialize interventions Tensor Frame
-----> Materialize interventions_studies Tensor Frame
-----> Materialize facilities_studies Tensor Frame
-----> Materialize sponsors Tensor Frame
-----> Materialize eligibilities Tensor Frame
-----> Materialize reported_event_totals Tensor Frame
-----> Materialize designs Tensor Frame
-----> Materialize conditions_studies Tensor Frame
-----> Materialize drop_withdrawals Tensor Frame
-----> Materialize studies Tensor Frame
-----> Materialize outcome_analyses Tensor Frame
-----> Materialize sponsors_studies Tensor Frame
-----> Materialize outcomes Tensor Frame
-----> Materialize conditions Tensor Frame
-----> Materialize facilities Tensor Frame


True

In [7]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy


In [8]:
# Task A: the "study-outcome" task of rel-trial. `entity_table` is the name
# later passed to the model and used to look up the batch labels.
task_a = get_task("rel-trial", "study-outcome", download=True)
entity_table = task_a.entity_table

In [9]:
def generate_loader_dict(
    task: BaseTask,
    data: HeteroData,
    num_neighbors: Optional[List[int]] = None,
    batch_size: int = 512,
    num_workers: int = 0,
) -> dict:
    """Build one ``NeighborLoader`` per split ("train", "val", "test").

    Args:
        task: Task providing the split tables via ``task.get_table(split)``.
        data: Heterogeneous graph to sample from; node types are expected to
            carry the ``time`` attribute named by ``time_attr``.
        num_neighbors: Neighbors to sample per hop, one entry per hop.
            Defaults to ``[128, 128]`` (depth-2 subgraphs, 128 per node).
        batch_size: Number of seed nodes per batch.
        num_workers: Worker processes used by each loader.

    Returns:
        Dict mapping split name to its loader. Only the "train" loader
        shuffles.
    """
    if num_neighbors is None:
        # we sample subgraphs of depth 2, 128 neighbors per node.
        num_neighbors = [128, 128]

    loader_dict = {}
    for split in ("train", "val", "test"):
        table_input = get_node_train_table_input(
            table=task.get_table(split),
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            num_neighbors=list(num_neighbors),
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=batch_size,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=num_workers,
            persistent_workers=False,
        )
    return loader_dict

In [10]:
@torch.no_grad()
def valid(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask) -> Tuple[Tensor, Tensor]:
    """Run ``model`` over ``loader`` in eval mode, collecting outputs and labels.

    Each batch is moved to the module-level ``device`` before the forward
    pass.

    Args:
        loader: Loader yielding batches that hold labels at
            ``batch[task.entity_table].y``.
        model: Model called as ``model(batch, task.entity_table)``.
        task: Task supplying ``entity_table``.

    Returns:
        ``(predictions, labels)``, both on the CPU and concatenated along
        dim 0. Predictions are the raw model outputs, flattened to 1-D when
        the model emits a single column.
    """
    model.eval()
    pred_list = []
    label_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
        label_list.append(batch[task.entity_table].y.detach().cpu())
    return torch.cat(pred_list, dim=0), torch.cat(label_list, dim=0)

@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask) -> Tensor:
    """Run ``model`` over ``loader`` in eval mode and return its raw outputs.

    Each batch is moved to the module-level ``device`` before the forward
    pass.

    Args:
        loader: Loader yielding the batches to score.
        model: Model called as ``model(batch, task.entity_table)``.
        task: Task supplying ``entity_table``.

    Returns:
        CPU tensor of model outputs concatenated along dim 0, flattened to
        1-D when the model emits a single column. No sigmoid is applied here.
    """
    model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [17]:
# construct bottom model
channels = 128
args = {
    "channels": channels,
    "num_layers": 2,
    "dropout_prob": 0.2,
}

# Only node types that carry a `time` attribute get a temporal encoding.
timed_node_types = [
    node_type for node_type in data.node_types if "time" in data[node_type]
]
temporal_encoder = HeteroTemporalEncoder(
    node_types=timed_node_types,
    channels=args["channels"],
)

# Per-table feature encoder, driven by the column statistics collected while
# materializing the tensor frames.
feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=args["channels"],
)

# NOTE: uses a single layer here, independent of args["num_layers"].
node_encoder = NodeRepresentationPart(
    data=data,
    channels=args["channels"],
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2,
)

# Full model: the three encoders above plus a single-output head
# (out_channels=1).
net = CompositeModel(
    data=data,
    channels=args["channels"],
    out_channels=1,
    dropout=0.2,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder,
)

net.reset_parameters()

# if torch.cuda.device_count() > 1:
#   print("Let's use", torch.cuda.device_count(), "GPUs!")
#   # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
#   net = torch.nn.DataParallel(net)

In [18]:
# Training setup.
task_loader_dict = generate_loader_dict(task_a, data)

# Hyper-parameters.
lr = 0.0005
epoches = 20
# Despite the name, this caps the number of training batches consumed per
# epoch in the training loop; it is not a patience value.
early_stop = 20

# Model selection: keep the epoch with the highest validation AUROC.
tune_metric = "auroc"
higher_is_better = True

loss_fn = BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

In [19]:
def _binary_metrics(labels: np.ndarray, probs: np.ndarray, preds: np.ndarray, f1_key: str = "f1") -> dict:
    """Compute AUROC from probabilities and the thresholded metrics from hard predictions.

    Args:
        labels: Ground-truth binary labels.
        probs: Predicted probabilities, used for AUROC.
        preds: Hard 0/1 predictions, used for accuracy/precision/recall/F1.
        f1_key: Key under which the F1 score is stored, so each split keeps
            the key it is reported under ("f1" for val, "f1score" for test).
    """
    return {
        "auroc": roc_auc_score(labels, probs),
        "accuracy": accuracy_score(labels, preds),
        "precision": precision_score(labels, preds),
        "recall": recall_score(labels, preds),
        f1_key: f1_score(labels, preds),
    }


# Ground-truth labels do not change between epochs, so load them once.
# They are matched to predictions by position; the val/test loaders are built
# with shuffle=False (presumably preserving table order -- verify).
val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
test_pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()

best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
# Snapshot the initial weights so `state_dict` is always bound to this run's
# model, even if no epoch ever improves on the initial best_val_metric.
state_dict = copy.deepcopy(net.state_dict())
for epoch in range(1, epoches + 1):
    net.train()
    cnt = 0
    loss_accum = count_accum = 0
    for cnt, batch in enumerate(tqdm(task_loader_dict["train"], leave=False), start=1):
        # `early_stop` caps the number of training batches per epoch.
        if cnt > early_stop:
            break

        batch = batch.to(device)
        optimizer.zero_grad()
        pred = net(
            batch,
            entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[entity_table].y.float())

        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size to get a per-sample average.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    # Guard against an epoch in which no batch was processed.
    train_loss = loss_accum / count_accum if count_accum else float("nan")

    # NOTE: after the sigmoid these hold probabilities, not logits.
    val_logits = torch.sigmoid(test(task_loader_dict["val"], net, task_a)).numpy()
    val_pred = (val_logits > 0.5).astype(int)
    val_metrics = _binary_metrics(val_pred_hat, val_logits, val_pred, f1_key="f1")

    test_logits = torch.sigmoid(test(task_loader_dict["test"], net, task_a)).numpy()
    test_pred = (test_logits > 0.5).astype(int)
    test_metrics = _binary_metrics(test_pred_hat, test_logits, test_pred, f1_key="f1score")

    print("*"*30 + f"<Epoch: {epoch:02d}>" + "*"*30)
    print(f", Train loss: {train_loss}, Val metrics: {val_metrics}")
    print(f"Test metrics: {test_metrics}")

    # Checkpoint on validation improvement only; test metrics are reported
    # but never used for model selection.
    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(net.state_dict())

# print the best epoch
best_epoch

  0%|          | 0/24 [00:00<?, ?it/s]

                                               

******************************<Epoch: 01>******************************
, Train loss: 0.6538617521524429, Val metrics: {'auroc': 0.636488726272008, 'accuracy': 0.603125, 'precision': 0.5991189427312775, 'recall': 0.9696969696969697, 'f1': 0.7406398910823689}
Test metrics: {'auroc': 0.6826789195210248, 'accuracy': 0.6, 'precision': 0.5972045743329097, 'recall': 0.9730848861283644, 'f1score': 0.7401574803149606}


                                               

******************************<Epoch: 02>******************************
, Train loss: 0.6235794514417649, Val metrics: {'auroc': 0.6591076622036375, 'accuracy': 0.6125, 'precision': 0.6095017381228274, 'recall': 0.9376114081996435, 'f1': 0.7387640449438202}
Test metrics: {'auroc': 0.7050900197353287, 'accuracy': 0.6351515151515151, 'precision': 0.6213333333333333, 'recall': 0.9648033126293996, 'f1score': 0.7558799675587997}


                                               

******************************<Epoch: 03>******************************
, Train loss: 0.6134337723255158, Val metrics: {'auroc': 0.6751281054686626, 'accuracy': 0.634375, 'precision': 0.6329113924050633, 'recall': 0.8912655971479501, 'f1': 0.7401924500370096}
Test metrics: {'auroc': 0.6984974513578632, 'accuracy': 0.6484848484848484, 'precision': 0.6446776611694153, 'recall': 0.8902691511387164, 'f1score': 0.7478260869565218}


                                               

******************************<Epoch: 04>******************************
, Train loss: 0.6022158324718475, Val metrics: {'auroc': 0.6881151184556757, 'accuracy': 0.6395833333333333, 'precision': 0.6332094175960347, 'recall': 0.910873440285205, 'f1': 0.7470760233918129}
Test metrics: {'auroc': 0.7245892509050404, 'accuracy': 0.6678787878787878, 'precision': 0.6530014641288433, 'recall': 0.9233954451345756, 'f1score': 0.7650085763293311}


                                               

******************************<Epoch: 05>******************************
, Train loss: 0.5955698907375335, Val metrics: {'auroc': 0.6837414391593959, 'accuracy': 0.6427083333333333, 'precision': 0.6697819314641744, 'recall': 0.7664884135472371, 'f1': 0.714879467996675}
Test metrics: {'auroc': 0.7200610221205188, 'accuracy': 0.6775757575757576, 'precision': 0.702803738317757, 'recall': 0.7784679089026915, 'f1score': 0.7387033398821218}


                                               

******************************<Epoch: 06>******************************
, Train loss: 0.5842116326093674, Val metrics: {'auroc': 0.6892766676048411, 'accuracy': 0.65, 'precision': 0.6903553299492385, 'recall': 0.7272727272727273, 'f1': 0.7083333333333334}
Test metrics: {'auroc': 0.7099572603005098, 'accuracy': 0.6581818181818182, 'precision': 0.7098121085594989, 'recall': 0.7039337474120083, 'f1score': 0.7068607068607069}


                                               

******************************<Epoch: 07>******************************
, Train loss: 0.5788783550262451, Val metrics: {'auroc': 0.6969830994598797, 'accuracy': 0.6458333333333334, 'precision': 0.6808510638297872, 'recall': 0.7415329768270945, 'f1': 0.7098976109215017}
Test metrics: {'auroc': 0.7117855024033514, 'accuracy': 0.6557575757575758, 'precision': 0.6978131212723658, 'recall': 0.7267080745341615, 'f1score': 0.7119675456389453}


                                               

******************************<Epoch: 08>******************************
, Train loss: 0.5705364584922791, Val metrics: {'auroc': 0.6928685349737981, 'accuracy': 0.6375, 'precision': 0.651925820256776, 'recall': 0.8146167557932263, 'f1': 0.7242472266244057}
Test metrics: {'auroc': 0.7286937149637378, 'accuracy': 0.6593939393939394, 'precision': 0.6717687074829932, 'recall': 0.8178053830227743, 'f1score': 0.7376283846872083}


                                               

******************************<Epoch: 09>******************************
, Train loss: 0.5556879341602325, Val metrics: {'auroc': 0.676169032206184, 'accuracy': 0.6302083333333334, 'precision': 0.6650641025641025, 'recall': 0.7397504456327986, 'f1': 0.70042194092827}
Test metrics: {'auroc': 0.7088130955407843, 'accuracy': 0.6557575757575758, 'precision': 0.7018255578093306, 'recall': 0.7163561076604554, 'f1score': 0.7090163934426229}


                                               

******************************<Epoch: 10>******************************
, Train loss: 0.5533083319664002, Val metrics: {'auroc': 0.6859930575100854, 'accuracy': 0.6364583333333333, 'precision': 0.6749174917491749, 'recall': 0.7290552584670231, 'f1': 0.700942587832048}
Test metrics: {'auroc': 0.7204968944099379, 'accuracy': 0.6666666666666666, 'precision': 0.7015503875968992, 'recall': 0.7494824016563147, 'f1score': 0.7247247247247247}


                                               

******************************<Epoch: 11>******************************
, Train loss: 0.5514600217342377, Val metrics: {'auroc': 0.6894062250099402, 'accuracy': 0.6479166666666667, 'precision': 0.6784, 'recall': 0.7557932263814616, 'f1': 0.715008431703204}
Test metrics: {'auroc': 0.7136924436695604, 'accuracy': 0.6618181818181819, 'precision': 0.7007874015748031, 'recall': 0.7370600414078675, 'f1score': 0.7184661957618567}


                                               

******************************<Epoch: 12>******************************
, Train loss: 0.5339159369468689, Val metrics: {'auroc': 0.6775003462309964, 'accuracy': 0.6354166666666666, 'precision': 0.6671949286846276, 'recall': 0.750445632798574, 'f1': 0.7063758389261745}
Test metrics: {'auroc': 0.7061797004588767, 'accuracy': 0.6533333333333333, 'precision': 0.6927592954990215, 'recall': 0.7329192546583851, 'f1score': 0.7122736418511066}


                                               

******************************<Epoch: 13>******************************
, Train loss: 0.517328467965126, Val metrics: {'auroc': 0.6777773310281051, 'accuracy': 0.634375, 'precision': 0.6458333333333334, 'recall': 0.8288770053475936, 'f1': 0.7259953161592506}
Test metrics: {'auroc': 0.7001380262249827, 'accuracy': 0.64, 'precision': 0.6555183946488294, 'recall': 0.8115942028985508, 'f1score': 0.7252543940795559}


                                               

******************************<Epoch: 14>******************************
, Train loss: 0.5118883788585663, Val metrics: {'auroc': 0.686037732477361, 'accuracy': 0.6489583333333333, 'precision': 0.6521739130434783, 'recall': 0.8556149732620321, 'f1': 0.7401696222050886}
Test metrics: {'auroc': 0.7137287663603453, 'accuracy': 0.6654545454545454, 'precision': 0.6688417618270799, 'recall': 0.8488612836438924, 'f1score': 0.7481751824817519}


                                               

******************************<Epoch: 15>******************************
, Train loss: 0.5094190910458565, Val metrics: {'auroc': 0.6830311071797139, 'accuracy': 0.6458333333333334, 'precision': 0.6641901931649331, 'recall': 0.7967914438502673, 'f1': 0.7244732576985413}
Test metrics: {'auroc': 0.6974077706343152, 'accuracy': 0.6436363636363637, 'precision': 0.6660808435852372, 'recall': 0.7846790890269151, 'f1score': 0.720532319391635}


                                               

******************************<Epoch: 16>******************************
, Train loss: 0.48356842398643496, Val metrics: {'auroc': 0.6932795446727336, 'accuracy': 0.646875, 'precision': 0.7159533073929961, 'recall': 0.6559714795008913, 'f1': 0.6846511627906977}
Test metrics: {'auroc': 0.7028682818156503, 'accuracy': 0.6387878787878788, 'precision': 0.7228915662650602, 'recall': 0.6211180124223602, 'f1score': 0.6681514476614699}


                                               

******************************<Epoch: 17>******************************
, Train loss: 0.477664390206337, Val metrics: {'auroc': 0.6915774284195336, 'accuracy': 0.6395833333333333, 'precision': 0.6979742173112339, 'recall': 0.6755793226381461, 'f1': 0.6865942028985508}
Test metrics: {'auroc': 0.7051445037715061, 'accuracy': 0.6424242424242425, 'precision': 0.7136363636363636, 'recall': 0.650103519668737, 'f1score': 0.6803900325027086}


                                               

******************************<Epoch: 18>******************************
, Train loss: 0.45759028792381284, Val metrics: {'auroc': 0.6826424349644163, 'accuracy': 0.6395833333333333, 'precision': 0.6626323751891074, 'recall': 0.7807486631016043, 'f1': 0.7168576104746317}
Test metrics: {'auroc': 0.6927221435230588, 'accuracy': 0.6460606060606061, 'precision': 0.6752293577981652, 'recall': 0.7619047619047619, 'f1score': 0.7159533073929961}


                                               

******************************<Epoch: 19>******************************
, Train loss: 0.4508752182126045, Val metrics: {'auroc': 0.6783804430863254, 'accuracy': 0.65, 'precision': 0.68, 'recall': 0.7575757575757576, 'f1': 0.7166947723440135}
Test metrics: {'auroc': 0.6984853437942683, 'accuracy': 0.6496969696969697, 'precision': 0.694, 'recall': 0.7184265010351967, 'f1score': 0.7060020345879959}


                                               

******************************<Epoch: 20>******************************
, Train loss: 0.44676139652729036, Val metrics: {'auroc': 0.6795062522616703, 'accuracy': 0.634375, 'precision': 0.6909090909090909, 'recall': 0.6773618538324421, 'f1': 0.684068406840684}
Test metrics: {'auroc': 0.6885026576102091, 'accuracy': 0.6290909090909091, 'precision': 0.7034482758620689, 'recall': 0.6335403726708074, 'f1score': 0.6666666666666666}


7

In [20]:
# Restore the checkpoint saved by the training loop, then score the test split.
net.load_state_dict(state_dict)

# Unmasked test table, so the target column is available as ground truth.
test_pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()

# Sigmoid turns the raw model outputs into probabilities; 0.5 is the decision
# threshold for the hard predictions.
test_logits = torch.sigmoid(test(task_loader_dict["test"], net, task_a)).numpy()
test_pred = (test_logits > 0.5).astype(int)

test_metrics = dict(
    auroc=roc_auc_score(test_pred_hat, test_logits),
    accuracy=accuracy_score(test_pred_hat, test_pred),
    precision=precision_score(test_pred_hat, test_pred),
    recall=recall_score(test_pred_hat, test_pred),
    f1score=f1_score(test_pred_hat, test_pred),
)
test_metrics

{'auroc': 0.7116583729856041,
 'accuracy': 0.6557575757575758,
 'precision': 0.6978131212723658,
 'recall': 0.7267080745341615,
 'f1score': 0.7119675456389453}