In [1]:
# Move from the notebook's directory up to the project root so the project
# packages (`utils`, `model`) and the relative paths used below (./data, ./edges,
# ./static) resolve. NOTE(review): re-running this cell moves up another level.
%cd ..
from tqdm import tqdm
from utils.data import preprocess_event_database
import numpy as np
import torch
import pickle
import os

from torch_geometric.data import HeteroData
from relbench.datasets import get_dataset

# Hard-coded to GPU index 7 when CUDA is available; falls back to CPU otherwise.
# NOTE(review): adjust the index on machines with fewer GPUs.
device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the RelBench "rel-event" dataset and its relational database, then apply
# the project's preprocessing. The return value of preprocess_event_database is
# not used, so it presumably modifies `db` in place — TODO confirm in utils/data.py.
# NOTE(review): the pandas warnings it emits say its chained
# `event_df["event_id"].replace(..., inplace=True)` may act on a copy and not take
# effect; that should be fixed in utils/data.py.
dataset = get_dataset('rel-event')
db = dataset.get_db()
preprocess_event_database(db)

Loading Database object from /home/lingze/.cache/relbench/rel-event/db...
Done in 3.01 seconds.


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  event_df["event_id"].replace(event_id2index, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  event_df["event_id"].replace(event_id2index, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] 

In [3]:
# Reuse the column-type inference cached alongside the materialized tensor frames.
cache_path = "./data/rel-event-tensor-frame/"
# [NOTE]: the dataset has been materialized

# Load the cached column types; the context manager guarantees the file is closed.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted caches.
type_path = os.path.join(cache_path, "col_type_dict.pkl")
with open(type_path, "rb") as type_file:
    col_type_dict = pickle.load(type_file)

# Add an all-NaN "text_compress" column to every table.
# NOTE(review): the reason isn't visible here — presumably it keeps each table's
# schema consistent with the cached col_type_dict; confirm.
for table in db.table_dict.values():
    table.df["text_compress"] = np.nan

In [4]:
from utils.resource import get_text_embedder_cfg

# Text-embedding config for text columns: the sentence-transformers
# "average_word_embeddings_glove.6B.300d" model, run on CPU.
text_embedder_cfg = get_text_embedder_cfg(
    model_name="sentence-transformers/average_word_embeddings_glove.6B.300d",
    device="cpu",
)

  return self.fget.__get__(instance, owner)()


In [5]:
from utils.builder import build_pyg_hetero_graph

# Materialize every table into a tensor frame (using `cache_path`) and assemble
# the heterogeneous PyG graph together with per-table column statistics.
# NOTE(review): the meaning of the trailing positional `True` isn't visible here —
# confirm against utils/builder.py and consider passing it by keyword.
data, col_stats_dict = build_pyg_hetero_graph(
    db, col_type_dict, text_embedder_cfg, cache_path, True
)

-----> Materialize event_attendees Tensor Frame
-----> Build edge between users and users
-----> Materialize events Tensor Frame
-----> Materialize event_interest Tensor Frame
-----> Materialize users Tensor Frame


In [6]:
# Add the extra ("appendix") edges stored in ./edges/rel-event-edges.npz.
# Keys look like "<src_table>-<dst_table>"; each value is an [edge_num, 2] array
# of (src_index, dst_index) pairs that is transposed to PyG's [2, edge_num] layout.
from utils.util import load_np_dict
from torch_geometric.utils import sort_edge_index
edge_dict = load_np_dict("./edges/rel-event-edges.npz")

for edge_name, edge_np in edge_dict.items():
    src_table, dst_table = edge_name.split('-')[:2]
    # PyG expects int64 (torch.long) edge indices; plain `int` maps to int32 on
    # some platforms (e.g. Windows), so cast explicitly. Result: [2, edge_num].
    edge_index = torch.from_numpy(edge_np.astype(np.int64)).t()
    edge_type = (src_table, "appendix", dst_table)
    data[edge_type].edge_index = sort_edge_index(edge_index)
data.validate()

True

In [7]:
# get the relbench tasks
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy

In [8]:
# Fine-tuning target: the RelBench "user-attendance" task on rel-event.
task_a = get_task("rel-event", "user-attendance", download=True)
# Node type whose rows are the prediction targets for this task.
entity_table = task_a.entity_table

In [9]:
def generate_loader_dict(task: BaseTask, data: HeteroData, num_neighbors=None, batch_size: int = 256) -> dict:
    """Build temporal NeighborLoaders for the train/val/test splits of a task.

    Args:
        task: RelBench task whose split tables supply the seed nodes, their seed
            times and the target transform (via get_node_train_table_input).
        data: heterogeneous graph to sample from; temporal sampling uses each
            node type's "time" attribute.
        num_neighbors: neighbors sampled per node at each hop. Defaults to
            [128, 128] (subgraphs of depth 2, 128 neighbors per node).
        batch_size: number of seed nodes per mini-batch (default 256).

    Returns:
        dict mapping "train", "val" and "test" to a NeighborLoader. Only the
        train loader shuffles; val/test keep the split table's row order, which
        the evaluation code relies on to align predictions with targets.
    """
    hop_sizes = [128, 128] if num_neighbors is None else list(num_neighbors)
    loader_dict = {}
    for split in ("train", "val", "test"):
        table_input = get_node_train_table_input(
            table=task.get_table(split),
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            num_neighbors=list(hop_sizes),  # fresh list per loader
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=batch_size,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [10]:
@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask, early_stop: int = 0) -> torch.Tensor:
    """Run `model` over `loader` without gradients and collect its predictions.

    The model's train/eval mode is left as the caller set it (no `model.eval()`
    call here, matching the notebook's choice to keep normalization in train mode).

    Args:
        loader: loader yielding mini-batches for the task's split.
        model: network called as ``model(batch, task.entity_table)``.
        task: supplies ``entity_table``.
        early_stop: if > 0, evaluate only the first ``early_stop`` batches
            (capped at ``len(loader)``); otherwise evaluate the whole loader.

    Returns:
        CPU tensor of predictions concatenated in loader order: 1-D when the
        model emits a single output column, otherwise 2-D. Empty if no batch ran.
    """
    from itertools import islice

    total = len(loader)
    # Previously `idx > early_stop` let early_stop + 1 batches through.
    num_batches = min(early_stop, total) if early_stop > 0 else total
    pred_list = []
    # islice stops before fetching (and sampling) a batch that would be discarded.
    for batch in tqdm(islice(loader, num_batches), leave=False, total=num_batches):
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        # Flatten single-output heads so predictions line up with 1-D targets.
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    if not pred_list:
        # torch.cat raises on an empty list (e.g. an empty loader).
        return torch.empty(0)
    return torch.cat(pred_list, dim=0)

In [11]:
# Construct the model: a temporal encoder, a per-table feature encoder and a node
# encoder, combined by CompositeModel into a single-output head (out_channels=1).
# All parts share the same hidden width. The architecture must match the
# pre-trained checkpoint loaded next.
channels = 128

# Time encodings only for node types that carry a "time" attribute.
temporal_encoder = HeteroTemporalEncoder(
    node_types=[nt for nt in data.node_types if "time" in data[nt]],
    channels=channels,
)

feat_encoder = FeatureEncodingPart(
    data=data, node_to_col_stats=col_stats_dict, channels=channels
)

node_encoder = NodeRepresentationPart(
    data=data, channels=channels, num_layers=1,
    normalization="layer_norm", dropout_prob=0.2,
)

# NOTE(review): aggr/norm/num_layer suggest a 2-layer GNN over the sampled
# subgraph — confirm in model/base.py.
net = CompositeModel(
    data=data, channels=channels, out_channels=1,
    dropout=0.2, aggr="sum", norm="batch_norm", num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder,
)

In [12]:
# Load the pre-trained weights into `net`.
# map_location="cpu" lets a checkpoint saved from a GPU load on any machine; `net`
# has not been moved to `device` yet, and load_state_dict copies the tensors into
# its parameters anyway.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
pre_trained_model_param_path = './static/rel-event-pre-trained-channel128-ep100-best-state.pth'
pre_trained_state_dict = torch.load(pre_trained_model_param_path, map_location="cpu")
net.load_state_dict(pre_trained_state_dict)

<All keys matched successfully>

In [13]:
# For this regression task, fine-tune with dropout disabled while the rest of the
# network (including its normalization layers) stays in training mode.
# NOTE(review): any later `net.train()` call re-enables dropout.
deactive_nn_instances = (
    torch.nn.Dropout,
    torch.nn.Dropout2d,
    torch.nn.Dropout3d,
    torch.nn.AlphaDropout,
    torch.nn.FeatureAlphaDropout,
)
net.train()
for module in net.modules():
    if isinstance(module, deactive_nn_instances):
        # Dropout layers have no parameters, so eval mode alone disables them.
        module.eval()


In [14]:
# Fine-tuning configuration: L1 loss, checkpoint selection on validation MAE.
from torch.nn import L1Loss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error

task_loader_dict = generate_loader_dict(task_a, data)
lr = 0.005
epoches = 80            # maximum number of fine-tuning epochs
loss_fn = L1Loss()
tune_metric = "mae"     # validation metric used to pick the best checkpoint
higher_is_better = False
early_stop = 10         # patience: epochs without val improvement before stopping
max_round_epoch = 20    # cap on training batches per epoch
# Only optimize parameters that still require gradients.
optimizer = torch.optim.Adam((p for p in net.parameters() if p.requires_grad), lr=lr)

In [15]:
# Fine-tuning loop with early stopping on the validation metric.
# Each epoch trains on at most `max_round_epoch` batches, then evaluates on the
# first `test_early_stop` val and test batches. The weights with the best
# validation score are kept in `state_dict`. Test metrics are printed for
# monitoring only; checkpoint selection uses validation alone.
import itertools

state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
patience = 0
test_early_stop = 50

# Targets don't change between epochs, so fetch them once. The val/test loaders
# don't shuffle, so the first N targets align with the first N predictions.
val_target_all = task_a.get_table("val").df[task_a.target_col].to_numpy()
test_target_all = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()

# train
for epoch in range(1, epoches + 1):
    loss_accum = count_accum = 0
    # net.train() is intentionally not called here (it would presumably undo the
    # dropout layers that were switched to eval mode).
    train_loader = task_loader_dict["train"]
    num_train_batches = min(len(train_loader), max(max_round_epoch, 0))
    # islice stops before fetching a batch that would just be discarded.
    for batch in tqdm(itertools.islice(train_loader, num_train_batches), leave=False, total=num_train_batches):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = net(
            batch,
            task_a.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_a.entity_table].y.float())

        loss.backward()
        optimizer.step()

        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    # Sample-weighted mean training loss; NaN if no batch was processed.
    train_loss = loss_accum / count_accum if count_accum else float("nan")

    val_logits = test(task_loader_dict["val"], net, task_a, test_early_stop)
    val_logits = val_logits.numpy()
    val_n = len(val_logits)
    val_pred_hat = val_target_all[:val_n]  # ground truth for the evaluated rows
    val_metrics = {
        "mae": mean_absolute_error(val_pred_hat, val_logits),
    }
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    logits = test(task_loader_dict["test"], net, task_a, test_early_stop)
    logits = logits.numpy()
    test_n = len(logits)
    pred_hat = test_target_all[:test_n]  # ground truth for the evaluated rows
    test_metrics = {
        "mae": mean_absolute_error(pred_hat, logits),
    }

    current = val_metrics[tune_metric]
    improved = current > best_val_metric if higher_is_better else current < best_val_metric
    if improved:
        patience = 0
        best_epoch = epoch
        best_val_metric = current
        state_dict = copy.deepcopy(net.state_dict())
        print(f"Update the best scores\t Test metrics: {test_metrics}")
    else:
        patience += 1
        print(f"Test metrics: {test_metrics}")

    if patience >= early_stop:
        print(f"Early stop at epoch {epoch}")
        break

best_epoch

  0%|          | 0/76 [00:00<?, ?it/s]

                                              

Epoch: 01, Train loss: 0.5699805051088334, Val metrics: {'mae': 0.3432623903867148}


                                             

Update the best scores	 Test metrics: {'mae': 0.35461682975254155}


                                               

Epoch: 02, Train loss: 0.4374914050102234, Val metrics: {'mae': 0.28964452126460827}


                                             

Update the best scores	 Test metrics: {'mae': 0.2921969253908504}


                                               

Epoch: 03, Train loss: 0.3904915928840637, Val metrics: {'mae': 0.2835505448070795}


                                             

Update the best scores	 Test metrics: {'mae': 0.28733951893118537}


                                               

Epoch: 04, Train loss: 0.3894182026386261, Val metrics: {'mae': 0.3032261137394868}


                                             

Test metrics: {'mae': 0.30694277010348164}


                                               

Epoch: 05, Train loss: 0.37682534456253053, Val metrics: {'mae': 0.2915356442243438}


                                             

Test metrics: {'mae': 0.2972936766475165}


                                               

Epoch: 06, Train loss: 0.3843789428472519, Val metrics: {'mae': 0.2730883817251825}


                                             

Update the best scores	 Test metrics: {'mae': 0.2754468128591228}


                                               

Epoch: 07, Train loss: 0.382332918047905, Val metrics: {'mae': 0.26746500797765915}


                                             

Update the best scores	 Test metrics: {'mae': 0.2690662868526423}


                                               

Epoch: 08, Train loss: 0.3542475521564484, Val metrics: {'mae': 0.27042995164169936}


                                             

Test metrics: {'mae': 0.2720089895085672}


                                               

Epoch: 09, Train loss: 0.384050115942955, Val metrics: {'mae': 0.26921774306317436}


                                             

Test metrics: {'mae': 0.2717647494537046}


                                               

Epoch: 10, Train loss: 0.3512449860572815, Val metrics: {'mae': 0.2702910683445829}


                                             

Test metrics: {'mae': 0.2725076460250034}


                                               

Epoch: 11, Train loss: 0.37159939408302306, Val metrics: {'mae': 0.2741619596490058}


                                             

Test metrics: {'mae': 0.27697954545667025}


                                               

Epoch: 12, Train loss: 0.3962375074625015, Val metrics: {'mae': 0.2630382315567993}


                                             

Update the best scores	 Test metrics: {'mae': 0.26771669186826746}


                                               

Epoch: 13, Train loss: 0.3518824428319931, Val metrics: {'mae': 0.26873472378664626}


                                             

Test metrics: {'mae': 0.2749006167716281}


                                               

Epoch: 14, Train loss: 0.3463559508323669, Val metrics: {'mae': 0.28035442331612437}


                                             

Test metrics: {'mae': 0.28565735903029177}


                                               

Epoch: 15, Train loss: 0.3697954475879669, Val metrics: {'mae': 0.26686147999775334}


                                             

Test metrics: {'mae': 0.27177416960936157}


                                               

Epoch: 16, Train loss: 0.3409941703081131, Val metrics: {'mae': 0.2705951159088042}


                                             

Test metrics: {'mae': 0.2774050770378819}


                                               

Epoch: 17, Train loss: 0.3652694195508957, Val metrics: {'mae': 0.270032754524158}


                                             

Test metrics: {'mae': 0.2759255292607553}


                                               

Epoch: 18, Train loss: 0.36918156445026395, Val metrics: {'mae': 0.26789408708819396}


                                             

Test metrics: {'mae': 0.27461760427776166}


                                               

Epoch: 19, Train loss: 0.3510738581418991, Val metrics: {'mae': 0.25724170375124944}


                                             

Update the best scores	 Test metrics: {'mae': 0.2626398668526477}


                                               

Epoch: 20, Train loss: 0.3475362777709961, Val metrics: {'mae': 0.2542767147637693}


                                             

Update the best scores	 Test metrics: {'mae': 0.26072556172917094}


                                               

Epoch: 21, Train loss: 0.36185933351516725, Val metrics: {'mae': 0.2589302062488615}


                                             

Test metrics: {'mae': 0.2635268403181047}


                                               

Epoch: 22, Train loss: 0.3418229728937149, Val metrics: {'mae': 0.2522156445861929}


                                             

Update the best scores	 Test metrics: {'mae': 0.25636788444504127}


                                               

Epoch: 23, Train loss: 0.34176734685897825, Val metrics: {'mae': 0.25726805755223614}


                                             

Test metrics: {'mae': 0.26400017465699743}


                                               

Epoch: 24, Train loss: 0.34679768085479734, Val metrics: {'mae': 0.24884594278849598}


                                             

Update the best scores	 Test metrics: {'mae': 0.2555970136757343}


                                               

Epoch: 25, Train loss: 0.34958581924438475, Val metrics: {'mae': 0.26132788451391253}


                                             

Test metrics: {'mae': 0.2699885992761378}


                                               

Epoch: 26, Train loss: 0.3397695288062096, Val metrics: {'mae': 0.261216866323457}


                                             

Test metrics: {'mae': 0.27033874634139626}


                                               

Epoch: 27, Train loss: 0.31011336892843244, Val metrics: {'mae': 0.25028484912429116}


                                             

Test metrics: {'mae': 0.2573699193119059}


                                               

Epoch: 28, Train loss: 0.3484345942735672, Val metrics: {'mae': 0.25365779451541026}


                                             

Test metrics: {'mae': 0.2612730852034289}


                                               

Epoch: 29, Train loss: 0.3232641726732254, Val metrics: {'mae': 0.24921075761932612}


                                             

Test metrics: {'mae': 0.2559494163167489}


                                               

Epoch: 30, Train loss: 0.327875480055809, Val metrics: {'mae': 0.24633208961146008}


                                             

Update the best scores	 Test metrics: {'mae': 0.2501799647617358}


                                               

Epoch: 31, Train loss: 0.33861030638217926, Val metrics: {'mae': 0.256317492257998}


                                             

Test metrics: {'mae': 0.2619849353499779}


                                               

Epoch: 32, Train loss: 0.342205348610878, Val metrics: {'mae': 0.24803230900204543}


                                             

Test metrics: {'mae': 0.2511066182728655}


                                               

Epoch: 33, Train loss: 0.32831390798091886, Val metrics: {'mae': 0.26275683476987965}


                                             

Test metrics: {'mae': 0.2680519497701261}


                                               

Epoch: 34, Train loss: 0.3327591001987457, Val metrics: {'mae': 0.2624991287824768}


                                             

Test metrics: {'mae': 0.2677500827201814}


                                               

Epoch: 35, Train loss: 0.3419374078512192, Val metrics: {'mae': 0.25172724508552247}


                                             

Test metrics: {'mae': 0.25680216959086877}


                                               

Epoch: 36, Train loss: 0.33245505690574645, Val metrics: {'mae': 0.24864729387816126}


                                             

Test metrics: {'mae': 0.2527486488786245}


                                               

Epoch: 37, Train loss: 0.34144198298454287, Val metrics: {'mae': 0.2517703043040623}


                                             

Test metrics: {'mae': 0.2587724453574608}


                                               

Epoch: 38, Train loss: 0.3284305721521378, Val metrics: {'mae': 0.2696133920884938}


                                             

Test metrics: {'mae': 0.27677312451657343}


                                               

Epoch: 39, Train loss: 0.33001417517662046, Val metrics: {'mae': 0.25070852361771223}


                                             

Test metrics: {'mae': 0.25415188552615775}


                                               

Epoch: 40, Train loss: 0.3540904402732849, Val metrics: {'mae': 0.2449256539459526}


                                             

Update the best scores	 Test metrics: {'mae': 0.2475858534076484}


                                               

Epoch: 41, Train loss: 0.3231784254312515, Val metrics: {'mae': 0.25360227416936165}


                                             

Test metrics: {'mae': 0.25496500509432435}


                                               

Epoch: 42, Train loss: 0.3258392482995987, Val metrics: {'mae': 0.2524790181058944}


                                             

Test metrics: {'mae': 0.253831721927906}


                                               

Epoch: 43, Train loss: 0.32127413153648376, Val metrics: {'mae': 0.2685628931893529}


                                             

Test metrics: {'mae': 0.26631758221452706}


                                               

Epoch: 44, Train loss: 0.3024706542491913, Val metrics: {'mae': 0.2547666651399883}


                                             

Test metrics: {'mae': 0.2565660569702672}


                                               

Epoch: 45, Train loss: 0.3269731789827347, Val metrics: {'mae': 0.25884612600082674}


                                             

Test metrics: {'mae': 0.2615894917341613}


                                               

Epoch: 46, Train loss: 0.3330889016389847, Val metrics: {'mae': 0.24468506213649552}


                                             

Update the best scores	 Test metrics: {'mae': 0.24984155952801074}


                                               

Epoch: 47, Train loss: 0.32237555384635924, Val metrics: {'mae': 0.2688981809214283}


                                             

Test metrics: {'mae': 0.2739221748417439}


                                               

Epoch: 48, Train loss: 0.29991266429424285, Val metrics: {'mae': 0.24537119446104874}


                                             

Test metrics: {'mae': 0.25027407778843674}


                                               

Epoch: 49, Train loss: 0.3230536639690399, Val metrics: {'mae': 0.24852238955985712}


                                             

Test metrics: {'mae': 0.2520461046537413}


                                               

Epoch: 50, Train loss: 0.3402813091874123, Val metrics: {'mae': 0.241547257861079}


                                             

Update the best scores	 Test metrics: {'mae': 0.24641801687613876}


                                               

Epoch: 51, Train loss: 0.34125052094459535, Val metrics: {'mae': 0.25665706421456347}


                                             

Test metrics: {'mae': 0.26214144707479503}


                                               

Epoch: 52, Train loss: 0.33093594014644623, Val metrics: {'mae': 0.2477139711317084}


                                             

Test metrics: {'mae': 0.25212689060851}


                                               

Epoch: 53, Train loss: 0.3447796881198883, Val metrics: {'mae': 0.2552322836762125}


                                             

Test metrics: {'mae': 0.2577516295613228}


                                               

Epoch: 54, Train loss: 0.32788940966129304, Val metrics: {'mae': 0.24283191215818756}


                                             

Test metrics: {'mae': 0.24566926162374092}


                                               

Epoch: 55, Train loss: 0.31224247217178347, Val metrics: {'mae': 0.2567213510162666}


                                             

Test metrics: {'mae': 0.25837512567936155}


                                               

Epoch: 56, Train loss: 0.3116843208670616, Val metrics: {'mae': 0.2440805018958587}


                                             

Test metrics: {'mae': 0.24444004780859271}


                                               

Epoch: 57, Train loss: 0.329606419801712, Val metrics: {'mae': 0.25379664521762463}


                                             

Test metrics: {'mae': 0.2572497612586487}


                                               

Epoch: 58, Train loss: 0.31226799041032793, Val metrics: {'mae': 0.2462977359633036}


                                             

Test metrics: {'mae': 0.24962049304289824}


                                               

Epoch: 59, Train loss: 0.3224899858236313, Val metrics: {'mae': 0.25089142362180483}


                                             

Test metrics: {'mae': 0.2543080276861954}


                                               

Epoch: 60, Train loss: 0.31362627148628236, Val metrics: {'mae': 0.24441277897711777}


                                             

Test metrics: {'mae': 0.2473541111956456}
Early stop at epoch 60




50

In [17]:
# Final evaluation on the full test split using the best checkpoint selected on
# the validation `tune_metric` during fine-tuning.
if state_dict is None:
    raise RuntimeError("No best checkpoint was recorded; run the fine-tuning cell first.")
net.load_state_dict(state_dict)
logits = test(task_loader_dict["test"], net, task_a)
logits = logits.numpy()
pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()
# The test loader doesn't shuffle, so predictions and targets are row-aligned;
# make sure the whole split was evaluated.
assert len(logits) == len(pred_hat), (len(logits), len(pred_hat))
test_metrics = {
        "mae": mean_absolute_error(pred_hat, logits),
        "r2": r2_score(pred_hat, logits),
        "rmse": root_mean_squared_error(pred_hat, logits),
}
test_metrics

  0%|          | 0/8 [00:00<?, ?it/s]

                                             

{'mae': 0.24648418025027075,
 'r2': 0.05870032800885938,
 'rmse': 0.6232577427551715}