In [9]:
%cd ..

/home/lingze/embedding_fusion


In [10]:
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder

In [1]:
import relbench
from relbench.base import Table, Database, Dataset, EntityTask
from relbench.datasets import get_dataset
from relbench.tasks import get_task
from relbench.base import BaseTask
from torch_geometric.seed import seed_everything
from relbench.modeling.utils import get_stype_proposal
from relbench.modeling.graph import make_pkey_fkey_graph
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader

import os
import math
import numpy as np
from tqdm import tqdm
import copy

import torch
from torch import Tensor
import torch_geometric
import torch_frame

from torch_frame.config.text_embedder import TextEmbedderConfig
from typing import List, Optional
from sentence_transformers import SentenceTransformer

from torch.nn import L1Loss, BCEWithLogitsLoss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

seed_everything(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [2]:
dataset = get_dataset(name="rel-trial", download=True)
db = dataset.get_db()
task_a = get_task("rel-trial", "study-outcome", download = True)
task_b = get_task("rel-trial", "site-success", download = True)

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 9.62 seconds.


In [3]:
col_to_stype_dict = get_stype_proposal(db)

class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)


root_dir = "/home/lingze/embedding_fusion/data"
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # speficied column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, f"rel-trial_materialized_cache"
    ),  # store materialized graph for convenience
)

  return self.fget.__get__(instance, owner)()


In [4]:
def generate_loader_dict(task: BaseTask) -> dict:
    loader_dict = {}
    for split, table in [
        ("train", task.get_table("train")),
        ("val",task.get_table("val")),
        ("test", task.get_table("test")),
    ]:
        table_input = get_node_train_table_input(
            table=table,
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            num_neighbors=[
                128 for i in range(2)
            ],  # we sample subgraphs of depth 2, 128 neighbors per node.
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=512,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [5]:
taska_loader_dict = generate_loader_dict(task_a)
taskb_loader_dict = generate_loader_dict(task_b)

In [6]:
@torch.no_grad()
def valid(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    model.eval()
    pred_list = []
    pred_hat_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
        pred_hat_list.append(batch[task.entity_table].y.detach().cpu())
    return torch.cat(pred_list, dim=0), torch.cat(pred_hat_list, dim=0)

@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [31]:
channels = 128

task_a_temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
)

task_a_feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)


task_a_node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2
)



task_a_model =  CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.2,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=task_a_feat_encoder,
    node_encoder=task_a_node_encoder,
    temporal_encoder=task_a_temporal_encoder
)

In [32]:
# Train Task A
optimizer = torch.optim.Adam(task_a_model.parameters(), lr=0.005)
epochs = 15
loss_fn = BCEWithLogitsLoss()
tune_metric = "auroc"
higher_is_better = True

In [33]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
task_a_model.to(device)
best_epoch = 0
task_a_model.reset_parameters()
for epoch in range(1, epochs + 1):
    task_a_model.train()
    loss_accum = count_accum = 0
    for batch in tqdm(taska_loader_dict["train"]):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = task_a_model(
            batch,
            task_a.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_a.entity_table].y.float())
        
        loss.backward()
        optimizer.step()
        
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)
    
    train_loss = loss_accum / count_accum
    val_logits = test(taska_loader_dict["val"], task_a_model, task_a)
    val_logits = torch.sigmoid(val_logits).numpy()
    
    val_pred = (val_logits > 0.5).astype(int)
    val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
    val_metrics = {
            "auroc": roc_auc_score(val_pred_hat, val_logits),
        "accuracy": accuracy_score(val_pred_hat, val_pred),
        "precision": precision_score(val_pred_hat, val_pred),
        "recall": recall_score(val_pred_hat, val_pred),
        "f1": f1_score(val_pred_hat, val_pred),
    }
    
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(task_a_model.state_dict())

best_epoch

  0%|          | 0/24 [00:00<?, ?it/s]

100%|██████████| 24/24 [00:05<00:00,  4.11it/s]


Epoch: 01, Train loss: 0.6894868534068893, Val metrics: {'auroc': 0.4877367214828515, 'accuracy': 0.584375, 'precision': 0.584375, 'recall': 1.0, 'f1': 0.73767258382643}


100%|██████████| 24/24 [00:04<00:00,  4.81it/s]


Epoch: 02, Train loss: 0.6564565871046925, Val metrics: {'auroc': 0.5614973262032086, 'accuracy': 0.584375, 'precision': 0.584375, 'recall': 1.0, 'f1': 0.73767258382643}


100%|██████████| 24/24 [00:04<00:00,  5.04it/s]


Epoch: 03, Train loss: 0.6493301573586062, Val metrics: {'auroc': 0.5979386970099044, 'accuracy': 0.584375, 'precision': 0.584375, 'recall': 1.0, 'f1': 0.73767258382643}


100%|██████████| 24/24 [00:05<00:00,  4.35it/s]


Epoch: 04, Train loss: 0.634345735145128, Val metrics: {'auroc': 0.6277905101434513, 'accuracy': 0.584375, 'precision': 0.584375, 'recall': 1.0, 'f1': 0.73767258382643}


100%|██████████| 24/24 [00:06<00:00,  4.00it/s]


Epoch: 05, Train loss: 0.6210746712757783, Val metrics: {'auroc': 0.6499805663892351, 'accuracy': 0.5854166666666667, 'precision': 0.5849843587069864, 'recall': 1.0, 'f1': 0.7381578947368421}


100%|██████████| 24/24 [00:05<00:00,  4.61it/s]


Epoch: 06, Train loss: 0.6072499076982568, Val metrics: {'auroc': 0.6390039269296235, 'accuracy': 0.5927083333333333, 'precision': 0.6693227091633466, 'recall': 0.5989304812834224, 'f1': 0.632173095014111}


100%|██████████| 24/24 [00:05<00:00,  4.37it/s]


Epoch: 07, Train loss: 0.593892188286491, Val metrics: {'auroc': 0.6393702616612833, 'accuracy': 0.6114583333333333, 'precision': 0.6127098321342925, 'recall': 0.910873440285205, 'f1': 0.7326164874551971}


100%|██████████| 24/24 [00:05<00:00,  4.51it/s]


Epoch: 08, Train loss: 0.5772828233029339, Val metrics: {'auroc': 0.6417156974432516, 'accuracy': 0.5989583333333334, 'precision': 0.6554770318021201, 'recall': 0.661319073083779, 'f1': 0.6583850931677019}


100%|██████████| 24/24 [00:05<00:00,  4.60it/s]


Epoch: 09, Train loss: 0.5672043125709335, Val metrics: {'auroc': 0.6388341620539764, 'accuracy': 0.6270833333333333, 'precision': 0.6415620641562064, 'recall': 0.8199643493761141, 'f1': 0.7198748043818466}


100%|██████████| 24/24 [00:04<00:00,  4.85it/s]


Epoch: 10, Train loss: 0.5455926545902314, Val metrics: {'auroc': 0.6450663199889206, 'accuracy': 0.6260416666666667, 'precision': 0.660828025477707, 'recall': 0.7397504456327986, 'f1': 0.6980656013456686}


100%|██████████| 24/24 [00:05<00:00,  4.70it/s]


Epoch: 11, Train loss: 0.5298433202851985, Val metrics: {'auroc': 0.6338886431765689, 'accuracy': 0.6208333333333333, 'precision': 0.6347469220246238, 'recall': 0.8270944741532977, 'f1': 0.718266253869969}


100%|██████████| 24/24 [00:05<00:00,  4.18it/s]


Epoch: 12, Train loss: 0.5176240227034395, Val metrics: {'auroc': 0.6308462779051014, 'accuracy': 0.5927083333333333, 'precision': 0.6585820895522388, 'recall': 0.6292335115864528, 'f1': 0.6435733819507748}


100%|██████████| 24/24 [00:05<00:00,  4.67it/s]


Epoch: 13, Train loss: 0.4970846722150656, Val metrics: {'auroc': 0.6132577432887031, 'accuracy': 0.596875, 'precision': 0.6570397111913358, 'recall': 0.6488413547237076, 'f1': 0.6529147982062781}


100%|██████████| 24/24 [00:05<00:00,  4.66it/s]


Epoch: 14, Train loss: 0.48224579991251904, Val metrics: {'auroc': 0.6213126398884913, 'accuracy': 0.621875, 'precision': 0.651840490797546, 'recall': 0.7575757575757576, 'f1': 0.7007419620774938}


100%|██████████| 24/24 [00:05<00:00,  4.76it/s]


Epoch: 15, Train loss: 0.46138886097313586, Val metrics: {'auroc': 0.6300197910105031, 'accuracy': 0.5770833333333333, 'precision': 0.6565656565656566, 'recall': 0.5793226381461676, 'f1': 0.615530303030303}


5

In [34]:
# test task A
task_a_model.load_state_dict(state_dict)
test_logits = test(taska_loader_dict["test"], task_a_model, task_a)
test_logits =  torch.sigmoid(test_logits).numpy()

test_pred = (test_logits > 0.5).astype(int)
test_pred_hat = task_a.get_table("test", mask_input_cols = False).df[task_a.target_col].to_numpy()
test_metrics = {
    "auroc": roc_auc_score(test_pred_hat, test_logits),
    "accuracy": accuracy_score(test_pred_hat, test_pred),
    "precision": precision_score(test_pred_hat, test_pred),
    "recall": recall_score(test_pred_hat, test_pred),
    "f1score": f1_score(test_pred_hat, test_pred),
}
test_metrics

{'auroc': 0.6789013596793917,
 'accuracy': 0.5878787878787879,
 'precision': 0.5868772782503038,
 'recall': 1.0,
 'f1score': 0.7396630934150077}

In [35]:
# save the individual model state
torch.save(task_a_model.state_dict(), "task_a_model.pth")

In [53]:
# train Task B
task_b_temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
)

task_b_feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)


task_b_node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.4
)

task_b_model =  CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.3,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=task_b_feat_encoder,
    node_encoder=task_b_node_encoder,
    temporal_encoder=task_b_temporal_encoder
)

In [54]:
optimizer = torch.optim.Adam(task_b_model.parameters(), lr=0.005)
epochs = 20
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False

In [58]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
task_b_model.to(device)
best_epoch = 0
early_stop = 40
# train
task_b_model.reset_parameters()
for epoch in range(1, epochs + 1):
    task_b_model.train()
    
    cnt = 0
    loss_accum = count_accum = 0
    for batch in tqdm(taskb_loader_dict["train"]):
        cnt += 1
        if cnt > early_stop:
            break
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = task_b_model(
            batch,
            task_b.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_b.entity_table].y.float())
        
        loss.backward()
        optimizer.step()
        
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    train_loss = loss_accum / count_accum
    val_pred_hat = task_b.get_table("val").df[task_b.target_col].to_numpy()
    val_logits = test(taskb_loader_dict["val"], task_b_model, task_b)
    val_logits = val_logits.numpy()
    
    val_metrics = {
        "mae": mean_absolute_error(val_pred_hat, val_logits),
        "r2": r2_score(val_pred_hat, val_logits),
        "rmse": root_mean_squared_error(val_pred_hat, val_logits),
    }
    
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    
    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(task_b_model.state_dict())

best_epoch

  0%|          | 0/296 [00:00<?, ?it/s]

 14%|█▎        | 40/296 [00:06<00:41,  6.17it/s]


Epoch: 01, Train loss: 0.4110280603170395, Val metrics: {'mae': 0.3936864982174104, 'r2': -0.30282323912695497, 'rmse': 0.5452460380940995}


 14%|█▎        | 40/296 [00:06<00:42,  5.96it/s]


Epoch: 02, Train loss: 0.24289779141545295, Val metrics: {'mae': 0.384304029025163, 'r2': -0.2736703055791698, 'rmse': 0.5391111102901006}


 14%|█▎        | 40/296 [00:06<00:42,  5.99it/s]


Epoch: 03, Train loss: 0.2302646704018116, Val metrics: {'mae': 0.3877429890546543, 'r2': -0.2695288294198446, 'rmse': 0.5382339077140139}


 14%|█▎        | 40/296 [00:06<00:41,  6.23it/s]


Epoch: 04, Train loss: 0.22985964342951776, Val metrics: {'mae': 0.3908682909753259, 'r2': -0.33945272075606603, 'rmse': 0.5528578297369027}


 14%|█▎        | 40/296 [00:06<00:38,  6.58it/s]


Epoch: 05, Train loss: 0.21786684431135656, Val metrics: {'mae': 0.39870456898865636, 'r2': -0.38730088352122327, 'rmse': 0.5626458263218033}


 14%|█▎        | 40/296 [00:06<00:41,  6.20it/s]


Epoch: 06, Train loss: 0.2138837032020092, Val metrics: {'mae': 0.3794568005223825, 'r2': -0.3743733489618122, 'rmse': 0.5600181888565388}


 14%|█▎        | 40/296 [00:06<00:40,  6.27it/s]


Epoch: 07, Train loss: 0.2111013986170292, Val metrics: {'mae': 0.4065624115894496, 'r2': -0.40869757010953034, 'rmse': 0.5669681373353582}


 14%|█▎        | 40/296 [00:06<00:39,  6.43it/s]


Epoch: 08, Train loss: 0.20857503078877926, Val metrics: {'mae': 0.3900789748812933, 'r2': -0.38278624839372455, 'rmse': 0.5617295829243674}


 14%|█▎        | 40/296 [00:05<00:35,  7.17it/s]


Epoch: 09, Train loss: 0.2060912225395441, Val metrics: {'mae': 0.4063927779075654, 'r2': -0.5193313758323912, 'rmse': 0.5888111474175233}


 14%|█▎        | 40/296 [00:05<00:35,  7.16it/s]


Epoch: 10, Train loss: 0.20251142829656602, Val metrics: {'mae': 0.4020019833621722, 'r2': -0.4002314430996885, 'rmse': 0.5652618598386858}


 14%|█▎        | 40/296 [00:05<00:34,  7.32it/s]


Epoch: 11, Train loss: 0.19478647522628306, Val metrics: {'mae': 0.3873272033693128, 'r2': -0.38843926023356556, 'rmse': 0.562876623971106}


 14%|█▎        | 40/296 [00:06<00:39,  6.52it/s]


Epoch: 12, Train loss: 0.1929924976080656, Val metrics: {'mae': 0.38693799886738184, 'r2': -0.3918509976059299, 'rmse': 0.5635677629355763}


 14%|█▎        | 40/296 [00:06<00:39,  6.41it/s]


Epoch: 13, Train loss: 0.18783764094114302, Val metrics: {'mae': 0.39364775330928603, 'r2': -0.4563227805582635, 'rmse': 0.5764724962189511}


 14%|█▎        | 40/296 [00:06<00:38,  6.62it/s]


Epoch: 14, Train loss: 0.18445986583828927, Val metrics: {'mae': 0.3983604497899226, 'r2': -0.4403371393483424, 'rmse': 0.5732998785991873}


 14%|█▎        | 40/296 [00:06<00:42,  5.96it/s]


Epoch: 15, Train loss: 0.1852744035422802, Val metrics: {'mae': 0.40969337951042534, 'r2': -0.5093403171949973, 'rmse': 0.5868719555921155}


 14%|█▎        | 40/296 [00:05<00:38,  6.70it/s]


Epoch: 16, Train loss: 0.18679688014090062, Val metrics: {'mae': 0.4025752679474265, 'r2': -0.4999514545499366, 'rmse': 0.5850437874459633}


 14%|█▎        | 40/296 [00:06<00:40,  6.35it/s]


Epoch: 17, Train loss: 0.1825265321880579, Val metrics: {'mae': 0.4052837893479933, 'r2': -0.48206966643968485, 'rmse': 0.5815460088863766}


 14%|█▎        | 40/296 [00:06<00:38,  6.64it/s]


Epoch: 18, Train loss: 0.1769079878926277, Val metrics: {'mae': 0.3983733381981597, 'r2': -0.4803401211491005, 'rmse': 0.5812065836234902}


 14%|█▎        | 40/296 [00:06<00:39,  6.45it/s]


Epoch: 19, Train loss: 0.17650280594825746, Val metrics: {'mae': 0.4072830869679515, 'r2': -0.48362593122006525, 'rmse': 0.5818512584067168}


 14%|█▎        | 40/296 [00:06<00:42,  5.99it/s]


Epoch: 20, Train loss: 0.18459568172693253, Val metrics: {'mae': 0.3996612915887673, 'r2': -0.5034183352771859, 'rmse': 0.585719511441545}


6

In [61]:
# test
task_b_model.load_state_dict(state_dict)
logits = test(taskb_loader_dict["test"], task_b_model, task_b)
logits = logits.numpy()
pred_hat = task_b.get_table("test", mask_input_cols=False).df[task_b.target_col].to_numpy()
test_metrics = {
        "mae": mean_absolute_error(pred_hat, logits),
        "r2": r2_score(pred_hat, logits),
        "rmse": root_mean_squared_error(pred_hat, logits),
}
test_metrics

{'mae': 0.40999596248485687,
 'r2': -0.5214476645866393,
 'rmse': 0.5935707603008625}

In [62]:
# save the individual model state
torch.save(task_b_model.state_dict(), "task_b_model.pth")

: 