In [1]:
%cd ..
from tqdm import tqdm
from utils.data import preprocess_event_database
import numpy as np
import torch
import pickle
import os

from torch_geometric.data import HeteroData
from relbench.datasets import get_dataset

# Device used for the model and every batch. GPU 7 is preferred, but only when
# the machine actually exposes that many devices; otherwise fall back to the
# first GPU, and to the CPU when CUDA is unavailable. Checking
# `is_available()` alone would let an out-of-range index fail later, at the
# first `.to(device)` call.
if torch.cuda.is_available():
    _preferred_gpu = 7
    _gpu_index = _preferred_gpu if torch.cuda.device_count() > _preferred_gpu else 0
    device = torch.device(f'cuda:{_gpu_index}')
else:
    device = torch.device('cpu')

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the 'rel-event' dataset and obtain its database object.
dataset = get_dataset('rel-event')
db = dataset.get_db()
# Project-specific preprocessing. The return value is discarded and `db` is
# used by later cells, so this presumably modifies `db` in place —
# TODO confirm against utils.data.
preprocess_event_database(db)

Loading Database object from /home/lingze/.cache/relbench/rel-event/db...
Done in 3.97 seconds.


In [3]:
# Cache directory for this dataset: col_type_dict.pkl is read from it and it is
# passed to build_pyg_hetero_graph. The path is relative to the working
# directory set by `%cd ..` in the first cell.
cache_path = "./data/rel-event-tensor-frame/"

In [4]:
# [NOTE]: the dataset has been materialized, so the column-type mapping is read
# back from the cache directory.
type_path = os.path.join(cache_path, "col_type_dict.pkl")

# SECURITY: unpickling can execute arbitrary code. Only load cache files that
# were written by this project on a trusted machine.
# The context manager closes the file handle as soon as loading is done.
with open(type_path, "rb") as type_file:
    col_type_dict = pickle.load(type_file)

# Add an all-NaN "text_compress" column to every table.
# NOTE(review): presumably the cached column types / graph builder expect this
# column to exist — confirm against utils.builder.
for table in db.table_dict.values():
    table.df["text_compress"] = np.nan

In [5]:
from utils.resource import get_text_embedder_cfg

# Text-embedder configuration: the GloVe averaged-word-embedding model is
# requested on the CPU.
text_embedder_cfg = get_text_embedder_cfg(
    device="cpu",
    model_name="sentence-transformers/average_word_embeddings_glove.6B.300d",
)

  return self.fget.__get__(instance, owner)()


In [6]:
from utils.builder import build_pyg_hetero_graph
# Build the graph from the database. Two values come back: `data` (the graph
# handed to the loaders and model parts) and `col_stats_dict` (later passed to
# FeatureEncodingPart as `node_to_col_stats`).
data, col_stats_dict = build_pyg_hetero_graph(
    db,
    col_type_dict,
    text_embedder_cfg,
    cache_path,
    True,  # NOTE(review): meaning of this positional flag is not visible here — confirm against utils.builder
)

-----> Materialize event_attendees Tensor Frame
-----> Build edge between users and users
-----> Materialize events Tensor Frame
-----> Materialize event_interest Tensor Frame
-----> Materialize users Tensor Frame


In [7]:
# get the relbench tasks
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# training / evaluation utilities for fine-tuning on task A
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy

In [8]:
# Fetch the "user-ignore" task of the "rel-event" dataset.
task_a = get_task("rel-event", "user-ignore", download = True)
# Name of the task's entity table; it is passed to the model on every forward
# call and used to read the labels (`batch[entity_table].y`) during training.
entity_table = task_a.entity_table

In [9]:
def generate_loader_dict(
    task: BaseTask,
    data: HeteroData,
    num_neighbors=(128, 64),
    batch_size: int = 256,
) -> dict:
    """Build one temporal ``NeighborLoader`` per split of ``task``.

    Args:
        task: Task providing the "train", "val" and "test" tables.
        data: Heterogeneous graph to sample subgraphs from.
        num_neighbors: Neighbours to sample at each hop, one entry per hop.
            The default samples 2-hop subgraphs with 128 neighbours at the
            first hop and 64 at the second.
        batch_size: Number of seed nodes per batch.

    Returns:
        Dict mapping "train", "val" and "test" to the loader for that split.
        Only the "train" loader shuffles its seed nodes.
    """
    loader_dict = {}
    for split in ("train", "val", "test"):
        table_input = get_node_train_table_input(
            table=task.get_table(split),
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            # A fresh list per loader, so the default tuple is never shared.
            num_neighbors=list(num_neighbors),
            # Temporal sampling: neighbours are constrained by the "time"
            # attribute relative to each seed's timestamp.
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=batch_size,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [10]:
@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask) -> torch.Tensor:
    """Run ``model`` over every batch of ``loader`` and collect its raw outputs.

    The model is switched to eval mode and is NOT switched back; callers that
    keep training must call ``model.train()`` again. Gradients are disabled
    for the whole call.

    Args:
        loader: Loader yielding the batches to score, in order.
        model: Network called as ``model(batch, task.entity_table)``.
        task: Task whose ``entity_table`` names the node type being predicted.

    Returns:
        CPU tensor with the outputs of all batches concatenated along dim 0.
        Outputs of shape ``(N, 1)`` are flattened to ``(N,)``. These are the
        model's raw outputs; no sigmoid is applied here.
    """
    model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        # Flatten a single output column. The dim() guard keeps an already
        # 1-D output from raising IndexError on size(1).
        if pred.dim() > 1 and pred.size(1) == 1:
            pred = pred.view(-1)
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [11]:
# Build the model bottom-up: the individual parts first, then the composite
# network that wires them together.
channels = 128  # hidden width passed to every component below

# Temporal encoding is only configured for node types that carry a "time"
# attribute.
temporal_encoder = HeteroTemporalEncoder(
    channels=channels,
    node_types=[ntype for ntype in data.node_types if "time" in data[ntype]],
)

# Column-level feature encoder, driven by the statistics returned by the
# graph builder.
feat_encoder = FeatureEncodingPart(
    channels=channels,
    node_to_col_stats=col_stats_dict,
    data=data,
)

node_encoder = NodeRepresentationPart(
    channels=channels,
    data=data,
    dropout_prob=0.2,
    normalization="batch_norm",
    num_layers=1,
)

# Full network with a single output channel.
net = CompositeModel(
    temporal_encoder=temporal_encoder,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    data=data,
    channels=channels,
    out_channels=1,
    num_layer=2,
    aggr="sum",
    norm="batch_norm",
    dropout=0.2,
)

In [12]:
# ---- training configuration ----
lr = 0.005                # Adam learning rate
epoches = 50              # upper bound on the number of epochs
max_round_epoch = 20      # at most this many training batches per epoch
early_stop = 5            # stop once patience exceeds this many epochs
tune_metric = "auroc"     # validation metric used for model selection
higher_is_better = True   # direction in which `tune_metric` improves

# ---- training objects ----
task_loader_dict = generate_loader_dict(task_a, data)
loss_fn = BCEWithLogitsLoss()
# Every parameter of `net` is optimised; filter on `requires_grad` instead if
# parts of the model are frozen.
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

In [13]:
# Ground-truth labels are identical in every epoch, so fetch them once instead
# of re-reading the task tables inside the loop.
# (`*_pred_hat` hold the true labels; the names are kept for compatibility.)
val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
test_pred_hat = task_a.get_table("test", mask_input_cols = False).df[task_a.target_col].to_numpy()

best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
patience = 0
# Snapshot the initial weights so `state_dict` is always defined, even if no
# epoch ever improves the validation metric (e.g. the metric comes out as NaN).
state_dict = copy.deepcopy(net.state_dict())

for epoch in range(1, epoches + 1):
    # ---- train on at most `max_round_epoch` batches ----
    net.train()
    loss_accum = count_accum = 0
    for step, batch in enumerate(tqdm(task_loader_dict["train"], leave=False), start=1):
        if step > max_round_epoch:
            break

        batch = batch.to(device)
        optimizer.zero_grad()
        pred = net(
            batch,
            entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[entity_table].y.float())

        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so `train_loss` is a per-sample mean.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    # Guard against an epoch that processed no batch at all.
    train_loss = loss_accum / count_accum if count_accum else math.nan

    # ---- validation: used for model selection ----
    # `test()` returns raw outputs; the sigmoid turns them into probabilities.
    val_logits = test(task_loader_dict["val"], net, task_a)
    val_logits = torch.sigmoid(val_logits).numpy()
    val_metrics = {
        "auroc": roc_auc_score(val_pred_hat, val_logits),
    }

    # ---- test: reported for monitoring only, never used for selection ----
    test_logits = test(task_loader_dict["test"], net, task_a)
    test_logits = torch.sigmoid(test_logits).numpy()
    print("*"*30 + f"<Epoch: {epoch:02d}>" + "*"*30)
    print(f", Train loss: {train_loss}, Val metrics: {val_metrics}")

    test_metrics = {
        "auroc": roc_auc_score(test_pred_hat, test_logits),
    }

    print(f"Test metrics: {test_metrics}")

    # ---- checkpointing / early stopping on the validation metric ----
    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        patience = 0
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        # Deep copy: later optimizer steps must not alter the checkpoint.
        state_dict = copy.deepcopy(net.state_dict())
    else:
        patience += 1

    # Stops after `early_stop + 1` consecutive epochs without improvement.
    if patience > early_stop:
        break

# print the best epoch
best_epoch

                                               

******************************<Epoch: 01>******************************
, Train loss: 0.3594540059566498, Val metrics: {'auroc': 0.855975456421885}
Test metrics: {'auroc': 0.7961468356038233}


                                               

******************************<Epoch: 02>******************************
, Train loss: 0.2994861871004105, Val metrics: {'auroc': 0.8746414520224045}
Test metrics: {'auroc': 0.8133136451684608}


                                               

******************************<Epoch: 03>******************************
, Train loss: 0.28578757494688034, Val metrics: {'auroc': 0.8840981549314882}
Test metrics: {'auroc': 0.8253767743164714}


                                               

******************************<Epoch: 04>******************************
, Train loss: 0.28483045995235445, Val metrics: {'auroc': 0.8834067993294183}
Test metrics: {'auroc': 0.8403721698262996}


                                               

******************************<Epoch: 05>******************************
, Train loss: 0.2815660312771797, Val metrics: {'auroc': 0.8820692094501619}
Test metrics: {'auroc': 0.8286570766250675}


                                               

******************************<Epoch: 06>******************************
, Train loss: 0.28003799095749854, Val metrics: {'auroc': 0.8727245527840766}
Test metrics: {'auroc': 0.8148969546854656}


                                               

******************************<Epoch: 07>******************************
, Train loss: 0.268093466758728, Val metrics: {'auroc': 0.8983207599279027}
Test metrics: {'auroc': 0.8294217395446317}


                                               

******************************<Epoch: 08>******************************
, Train loss: 0.2757302105426788, Val metrics: {'auroc': 0.8780500804310327}
Test metrics: {'auroc': 0.8022889079419517}


                                               

******************************<Epoch: 09>******************************
, Train loss: 0.26034917309880257, Val metrics: {'auroc': 0.8943597738240594}
Test metrics: {'auroc': 0.7945844844558764}


                                               

******************************<Epoch: 10>******************************
, Train loss: 0.2391987420618534, Val metrics: {'auroc': 0.887798103571913}
Test metrics: {'auroc': 0.8068197262709981}


                                               

******************************<Epoch: 11>******************************
, Train loss: 0.2618034064769745, Val metrics: {'auroc': 0.8751980493051922}
Test metrics: {'auroc': 0.7980711949445872}


                                               

******************************<Epoch: 12>******************************
, Train loss: 0.25437475070357324, Val metrics: {'auroc': 0.9028389489699011}
Test metrics: {'auroc': 0.8147947032485472}


                                               

******************************<Epoch: 13>******************************
, Train loss: 0.2466282732784748, Val metrics: {'auroc': 0.9146207628350487}
Test metrics: {'auroc': 0.8191048236003937}


                                               

******************************<Epoch: 14>******************************
, Train loss: 0.2568252310156822, Val metrics: {'auroc': 0.9198172858887146}
Test metrics: {'auroc': 0.8252815090025721}


                                               

******************************<Epoch: 15>******************************
, Train loss: 0.2456231825053692, Val metrics: {'auroc': 0.8963081671415003}
Test metrics: {'auroc': 0.794778190594138}


                                               

******************************<Epoch: 16>******************************
, Train loss: 0.25482316240668296, Val metrics: {'auroc': 0.9019310715739287}
Test metrics: {'auroc': 0.7731142231113652}


                                               

******************************<Epoch: 17>******************************
, Train loss: 0.24870089143514634, Val metrics: {'auroc': 0.905367560129465}
Test metrics: {'auroc': 0.8015636213521324}


                                               

******************************<Epoch: 18>******************************
, Train loss: 0.24473128616809844, Val metrics: {'auroc': 0.9172632372037134}
Test metrics: {'auroc': 0.8079616398336034}


                                               

******************************<Epoch: 19>******************************
, Train loss: 0.2380901075899601, Val metrics: {'auroc': 0.9163341618698763}
Test metrics: {'auroc': 0.8104245657489443}


                                               

******************************<Epoch: 20>******************************
, Train loss: 0.231086066365242, Val metrics: {'auroc': 0.9152894063608348}
Test metrics: {'auroc': 0.803284112921152}


                                               

******************************<Epoch: 21>******************************
, Train loss: 0.23276707753539086, Val metrics: {'auroc': 0.910373907397717}
Test metrics: {'auroc': 0.8068533866819091}


                                               

******************************<Epoch: 22>******************************
, Train loss: 0.24851371869444847, Val metrics: {'auroc': 0.9101249830416496}
Test metrics: {'auroc': 0.8006370073989394}


                                               

******************************<Epoch: 23>******************************
, Train loss: 0.22576750442385674, Val metrics: {'auroc': 0.894703180417466}
Test metrics: {'auroc': 0.7572455622241274}


                                               

******************************<Epoch: 24>******************************
, Train loss: 0.22801138758659362, Val metrics: {'auroc': 0.9008154543868829}
Test metrics: {'auroc': 0.7883534978247753}


                                               

******************************<Epoch: 25>******************************
, Train loss: 0.22643422409892083, Val metrics: {'auroc': 0.9088058656511038}
Test metrics: {'auroc': 0.790845003334286}


14

In [14]:
# Restore the checkpoint selected on the validation split and score it once on
# the test split.
net.load_state_dict(state_dict)

# Raw model outputs -> probabilities -> hard 0/1 predictions at threshold 0.5.
test_logits = torch.sigmoid(test(task_loader_dict["test"], net, task_a)).numpy()
test_pred = (test_logits > 0.5).astype(int)

# Ground-truth labels of the test split (read with mask_input_cols=False).
test_pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()

# AUROC is computed on the probabilities; the remaining metrics on the
# thresholded predictions.
test_metrics = {"auroc": roc_auc_score(test_pred_hat, test_logits)}
test_metrics.update({
    metric_name: metric_fn(test_pred_hat, test_pred)
    for metric_name, metric_fn in (
        ("accuracy", accuracy_score),
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1score", f1_score),
    )
})
test_metrics

{'auroc': 0.8253958273792512,
 'accuracy': 0.901747277791846,
 'precision': 0.6083916083916084,
 'recall': 0.38666666666666666,
 'f1score': 0.47282608695652173}