In [1]:
%cd ..
from tqdm import tqdm
from utils.data import StackDataset
import numpy as np

import torch
import pickle
import os

from torch_geometric.data import HeteroData


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = StackDataset(cache_dir="/home/lingze/.cache/relbench/stack")
db = dataset.get_db()

Loading Database object from /home/lingze/.cache/relbench/stack/db...
Done in 9.89 seconds.


In [3]:
for table_name, table in db.table_dict.items():
    n = len(table.df)
    print(f"Table {table_name} has {n} rows")

Table tags has 1597 rows
Table postHistory has 1175368 rows
Table comments has 623967 rows
Table badges has 463463 rows
Table postTag has 648577 rows
Table users has 255360 rows
Table postLinks has 77337 rows
Table votes has 1317876 rows
Table posts has 333893 rows


In [4]:
cache_path = "./data/stack-tensor-frame/"

In [5]:
# [NOTE]: the dataset has been materialized

# get infer_type in cache
type_path = os.path.join(cache_path,"col_type_dict.pkl")
col_type_dict = pickle.load(open(type_path, "rb"))
len(col_type_dict)

# add "compress_text" in each table in case 
for table_name, table in db.table_dict.items():
    table.df["text_compress"] = np.nan

In [6]:
from utils.resource import get_text_embedder_cfg
text_embedder_cfg = get_text_embedder_cfg(
    model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d", 
    device = device)

  return self.fget.__get__(instance, owner)()


In [7]:
from utils.builder import build_pyg_hetero_graph
data, col_stats_dict = build_pyg_hetero_graph(
    db,
    col_type_dict,
    text_embedder_cfg,
    cache_path,
    True,
)

-----> Materialize tags Tensor Frame
-----> Materialize postHistory Tensor Frame
-----> Materialize comments Tensor Frame
-----> Materialize badges Tensor Frame
-----> Build edge between posts and tags
-----> Materialize users Tensor Frame
-----> Materialize postLinks Tensor Frame
-----> Materialize votes Tensor Frame
-----> Materialize posts Tensor Frame


In [9]:
# get the relbench tasks
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy

In [10]:
task_a = get_task("rel-stack", "user-engagement", download = True)
entity_table = task_a.entity_table

In [11]:
def generate_loader_dict(task: BaseTask, data:HeteroData) -> dict:
    loader_dict = {}
    for split, table in [
        ("train", task.get_table("train")),
        ("val",task.get_table("val")),
        ("test", task.get_table("test")),
    ]:
        table_input = get_node_train_table_input(
            table=table,
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            num_neighbors=[
                128, 64
            ],  # we sample subgraphs of depth 2, 128 neighbors per node.
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=512,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [12]:
@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [21]:
# construct bottom model
channels = 128

temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
    channels=channels,
)

feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels,
)

node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.4
)

net = CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.4,
    aggr="sum",
    norm="batch_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder
)

In [22]:
# training
task_loader_dict = generate_loader_dict(task_a,data)
lr = 0.005
epoches = 40
loss_fn = BCEWithLogitsLoss()
tune_metric = "auroc"
higher_is_better = True
early_stop = 5
max_round_epoch = 50
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr = lr)
optimizer = torch.optim.Adam(net.parameters(), lr = lr)

In [23]:
best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
patience = 0
for epoch in range(1, epoches + 1):
    net.train()
    cnt = 0
    loss_accum = count_accum = 0
    for batch in tqdm(task_loader_dict["train"], leave=False):
        cnt += 1
        if cnt > max_round_epoch:
            break
        
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = net(
            batch,
            entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[entity_table].y.float())
        
        loss.backward()
        optimizer.step()
        
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)
    
    train_loss = loss_accum / count_accum
    val_logits = test(task_loader_dict["val"], net, task_a)
    val_logits = torch.sigmoid(val_logits).numpy()
    
    val_pred = (val_logits > 0.5).astype(int)
    val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
    val_metrics = {
            "auroc": roc_auc_score(val_pred_hat, val_logits),
        # "accuracy": accuracy_score(val_pred_hat, val_pred),
        # "precision": precision_score(val_pred_hat, val_pred),
        # "recall": recall_score(val_pred_hat, val_pred),
        # "f1": f1_score(val_pred_hat, val_pred),
    }
    
    test_logits = test(task_loader_dict["test"], net, task_a)
    test_logits =  torch.sigmoid(test_logits).numpy()

    test_pred = (test_logits > 0.5).astype(int)
    test_pred_hat = task_a.get_table("test", mask_input_cols = False).df[task_a.target_col].to_numpy()
    test_metrics = {
        "auroc": roc_auc_score(test_pred_hat, test_logits),
        # "accuracy": accuracy_score(test_pred_hat, test_pred),
        # "precision": precision_score(test_pred_hat, test_pred),
        # "recall": recall_score(test_pred_hat, test_pred),
        # "f1score": f1_score(test_pred_hat, test_pred),
    }
    
    print("*"*30 + f"<Epoch: {epoch:02d}>" + "*"*30)
    print(f", Train loss: {train_loss}, Val metrics: {val_metrics}")
    print(f"Test metrics: {test_metrics}")

    
    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        patience = 0
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(net.state_dict())
    else:
        patience += 1
    
    if patience > early_stop:
        break

# print the best epoch
best_epoch

                                                 

******************************<Epoch: 01>******************************
, Train loss: 0.17751104071736334, Val metrics: {'auroc': 0.8619929158978273}
Test metrics: {'auroc': 0.858006987973499}


                                                 

******************************<Epoch: 02>******************************
, Train loss: 0.1480710458755493, Val metrics: {'auroc': 0.8757654654152972}
Test metrics: {'auroc': 0.8727768179023552}


                                                 

******************************<Epoch: 03>******************************
, Train loss: 0.14739177137613296, Val metrics: {'auroc': 0.8831352357130179}
Test metrics: {'auroc': 0.8822768896684354}


                                                 

******************************<Epoch: 04>******************************
, Train loss: 0.14512174502015113, Val metrics: {'auroc': 0.8861831023207394}
Test metrics: {'auroc': 0.8862909664063041}


                                                 

******************************<Epoch: 05>******************************
, Train loss: 0.14434975162148475, Val metrics: {'auroc': 0.8905428075698991}
Test metrics: {'auroc': 0.8928316610638354}


                                                 

******************************<Epoch: 06>******************************
, Train loss: 0.13796303808689117, Val metrics: {'auroc': 0.8904670105591859}
Test metrics: {'auroc': 0.8922880207892395}


                                                 

******************************<Epoch: 07>******************************
, Train loss: 0.13990494415163993, Val metrics: {'auroc': 0.8944207946270052}
Test metrics: {'auroc': 0.8946814870597576}


                                                 

******************************<Epoch: 08>******************************
, Train loss: 0.1419109396636486, Val metrics: {'auroc': 0.8955138977915741}
Test metrics: {'auroc': 0.8964668890523301}


                                                 

******************************<Epoch: 09>******************************
, Train loss: 0.13703391045331956, Val metrics: {'auroc': 0.8938731903084609}
Test metrics: {'auroc': 0.8970190253315733}


                                                 

******************************<Epoch: 10>******************************
, Train loss: 0.1385793524980545, Val metrics: {'auroc': 0.8939954494052046}
Test metrics: {'auroc': 0.8967827652797862}


                                                 

******************************<Epoch: 11>******************************
, Train loss: 0.1315681503713131, Val metrics: {'auroc': 0.8961990886490785}
Test metrics: {'auroc': 0.897590555821881}


                                                 

******************************<Epoch: 12>******************************
, Train loss: 0.13453764393925666, Val metrics: {'auroc': 0.8958722159047275}
Test metrics: {'auroc': 0.8980793760619341}


                                                 

******************************<Epoch: 13>******************************
, Train loss: 0.1305397157371044, Val metrics: {'auroc': 0.8928251248665765}
Test metrics: {'auroc': 0.8962720760528274}


                                                 

******************************<Epoch: 14>******************************
, Train loss: 0.13592741042375564, Val metrics: {'auroc': 0.8959228342481996}
Test metrics: {'auroc': 0.8983444262479205}


                                                 

******************************<Epoch: 15>******************************
, Train loss: 0.1345680770277977, Val metrics: {'auroc': 0.8963333094149666}
Test metrics: {'auroc': 0.8994625023948234}


                                                 

******************************<Epoch: 16>******************************
, Train loss: 0.13461943164467813, Val metrics: {'auroc': 0.8968818856812741}
Test metrics: {'auroc': 0.8980832732895784}


                                                 

******************************<Epoch: 17>******************************
, Train loss: 0.13729683712124824, Val metrics: {'auroc': 0.8958469750924889}
Test metrics: {'auroc': 0.8978122381618214}


                                                 

******************************<Epoch: 18>******************************
, Train loss: 0.13193594723939894, Val metrics: {'auroc': 0.8955650406388262}
Test metrics: {'auroc': 0.8978887336524122}


                                                 

******************************<Epoch: 19>******************************
, Train loss: 0.12987299248576165, Val metrics: {'auroc': 0.8968905735519432}
Test metrics: {'auroc': 0.8978447682798435}


                                                 

******************************<Epoch: 20>******************************
, Train loss: 0.13078380689024927, Val metrics: {'auroc': 0.8963608520779178}
Test metrics: {'auroc': 0.897477028201694}


                                                 

******************************<Epoch: 21>******************************
, Train loss: 0.1308639046549797, Val metrics: {'auroc': 0.8972662301194362}
Test metrics: {'auroc': 0.8975754483193118}


                                                 

******************************<Epoch: 22>******************************
, Train loss: 0.1299095007777214, Val metrics: {'auroc': 0.896272069248499}
Test metrics: {'auroc': 0.8971157835029516}


                                                 

******************************<Epoch: 23>******************************
, Train loss: 0.12588203951716423, Val metrics: {'auroc': 0.8952392691038334}
Test metrics: {'auroc': 0.8957296913096701}


                                                 

******************************<Epoch: 24>******************************
, Train loss: 0.12427199766039848, Val metrics: {'auroc': 0.8925782650495782}
Test metrics: {'auroc': 0.8939860919823329}


                                                 

******************************<Epoch: 25>******************************
, Train loss: 0.12541361883282662, Val metrics: {'auroc': 0.8919093536956538}
Test metrics: {'auroc': 0.8956035454775695}


                                                 

******************************<Epoch: 26>******************************
, Train loss: 0.13090210080146789, Val metrics: {'auroc': 0.8925468644251742}
Test metrics: {'auroc': 0.8930661962718545}


                                                 

******************************<Epoch: 27>******************************
, Train loss: 0.12665733680129052, Val metrics: {'auroc': 0.8910961565720247}
Test metrics: {'auroc': 0.8913882014860983}


21

In [24]:
net.load_state_dict(state_dict)
test_logits = test(task_loader_dict["test"], net, task_a)
test_logits =  torch.sigmoid(test_logits).numpy()

test_pred = (test_logits > 0.5).astype(int)
test_pred_hat = task_a.get_table("test", mask_input_cols = False).df[task_a.target_col].to_numpy()
test_metrics = {
    "auroc": roc_auc_score(test_pred_hat, test_logits),
    "accuracy": accuracy_score(test_pred_hat, test_pred),
    "precision": precision_score(test_pred_hat, test_pred),
    "recall": recall_score(test_pred_hat, test_pred),
    "f1score": f1_score(test_pred_hat, test_pred),
}
test_metrics

{'auroc': 0.8975646468783236,
 'accuracy': 0.9744602153465627,
 'precision': 0.5742115027829313,
 'recall': 0.25673994193280797,
 'f1score': 0.354829464029808}