In [1]:
%cd ..
from tqdm import tqdm
from utils.data import StackDataset
import numpy as np
import torch
import pickle
import os

from torch_geometric.data import HeteroData


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = StackDataset(cache_dir="/home/lingze/.cache/relbench/stack")
db = dataset.get_db()

Loading Database object from /home/lingze/.cache/relbench/stack/db...
Done in 9.69 seconds.


In [3]:
cache_path = "./data/stack-tensor-frame/"

In [4]:
# [NOTE]: the dataset has been materialized

# get infer_type in cache
type_path = os.path.join(cache_path,"col_type_dict.pkl")
col_type_dict = pickle.load(open(type_path, "rb"))
len(col_type_dict)

# add "compress_text" in each table in case 
for table_name, table in db.table_dict.items():
    table.df["text_compress"] = np.nan

In [5]:
from utils.resource import get_text_embedder_cfg
text_embedder_cfg = get_text_embedder_cfg(
    model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d", 
    device = device)

  return self.fget.__get__(instance, owner)()


In [6]:
from utils.builder import build_pyg_hetero_graph
data, col_stats_dict = build_pyg_hetero_graph(
    db,
    col_type_dict,
    text_embedder_cfg,
    cache_path,
    True,
)

-----> Materialize tags Tensor Frame
-----> Materialize postHistory Tensor Frame
-----> Materialize comments Tensor Frame
-----> Materialize badges Tensor Frame
-----> Build edge between posts and tags
-----> Materialize users Tensor Frame
-----> Materialize postLinks Tensor Frame
-----> Materialize votes Tensor Frame
-----> Materialize posts Tensor Frame


In [7]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy


In [8]:
task_a = get_task("rel-stack", "post-votes", download = True)
entity_table = task_a.entity_table

In [9]:
def generate_loader_dict(task: BaseTask, data:HeteroData) -> dict:
    loader_dict = {}
    for split, table in [
        ("train", task.get_table("train")),
        ("val",task.get_table("val")),
        ("test", task.get_table("test")),
    ]:
        table_input = get_node_train_table_input(
            table=table,
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            num_neighbors=[
                128 for i in range(2)
            ],  # we sample subgraphs of depth 2, 128 neighbors per node.
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=512,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [None]:
# from typing import List
# def generate_loader_dict_specific_node(
#     task: BaseTask, 
#     data:HeteroData, 
#     node_idxs: np.ndarray,
#     num_neighbors: List[int] = [128, 64]
# ):
#     nodes = (task.entity_table, node_idxs)
    

In [10]:
@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask, early_stop: int = 0)-> np.ndarray:
    # model.eval()
    pred_list = []
    early_stop = early_stop if early_stop > 0 else len(loader)
    for idx,batch in tqdm(enumerate(loader), leave=False, total=len(loader)):
        if idx > early_stop:
            break
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [11]:
# construct bottom model
channels = 128
temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
    channels=channels,
)

feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)

node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="batch_norm",
    dropout_prob=0.2
)


net = CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.2,
    aggr="sum",
    norm="batch_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder
)

In [12]:
# for regression task, we need to deactivate the normalization and dropout layer
task_a.task_type
# freeze_instances = (torch.nn.BatchNorm1d, torch.nn.LayerNorm, torch.nn.Dropout, torch.nn.BatchNorm2d)
deactive_nn_instances = (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)
net.train()
for module in net.modules():
    if isinstance(module, deactive_nn_instances):
        module.eval()
        for param in module.parameters():
            param.requires_grad = False


In [13]:
# training for fine-tune
from torch.nn import L1Loss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error

task_loader_dict = generate_loader_dict(task_a,data)
lr = 0.005
epoches = 80
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False
early_stop = 10
max_round_epoch = 30
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr = lr)

In [14]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
patience = 0
test_early_stop = 50
# train
for epoch in range(1, epoches + 1):
    cnt = 0
    loss_accum = count_accum = 0
    # net.train()
    for batch in tqdm(task_loader_dict["train"], leave = False):
        cnt += 1
        if cnt > max_round_epoch:
            break
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = net(
            batch,
            task_a.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_a.entity_table].y.float())
        
        loss.backward()
        optimizer.step()
        
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    train_loss = loss_accum / count_accum

    val_logits = test(task_loader_dict["val"], net, task_a, test_early_stop)
    val_logits = val_logits.numpy()
    val_n = len(val_logits)
    val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()[:val_n]
    val_metrics = {
        "mae": mean_absolute_error(val_pred_hat, val_logits),
        # "r2": r2_score(val_pred_hat, val_logits),
        # "rmse": root_mean_squared_error(val_pred_hat, val_logits),
    }
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")
    


    
    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        patience = 0
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(net.state_dict())
        # calculate test metrics
        logits = test(task_loader_dict["test"], net, task_a, test_early_stop)
        logits = logits.numpy()
        test_n = len(logits)
        pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()[:test_n]
        test_metrics = {
                "mae": mean_absolute_error(pred_hat, logits),
                # "r2": r2_score(pred_hat, logits),
                # "rmse": root_mean_squared_error(pred_hat, logits),
        }
        print(f"Update the best scores\t Test metrics: {test_metrics}")
    else:
        patience += 1
    
    if patience >= early_stop:
        print(f"Early stop at epoch {epoch}")
        break   
    
best_epoch

                                                  

Epoch: 01, Train loss: 0.2988541970650355, Val metrics: {'mae': 0.10188150754216693}


                                                

Update the best scores	 Test metrics: {'mae': 0.10851538996354851}


                                                 

Epoch: 02, Train loss: 0.11430266151825587, Val metrics: {'mae': 0.08766572650190438}


                                                

Update the best scores	 Test metrics: {'mae': 0.09452000738041105}


                                                 

Epoch: 03, Train loss: 0.1177882768213749, Val metrics: {'mae': 0.11989024395427574}


                                                 

Epoch: 04, Train loss: 0.1299741397301356, Val metrics: {'mae': 0.10632366070429128}


                                                 

Epoch: 05, Train loss: 0.11530769628783068, Val metrics: {'mae': 0.10124846532619582}


                                                 

Epoch: 06, Train loss: 0.11626069297393163, Val metrics: {'mae': 0.08674288658178686}


                                                

Update the best scores	 Test metrics: {'mae': 0.09373517212738722}


                                                 

Epoch: 07, Train loss: 0.10260674630602201, Val metrics: {'mae': 0.07941188197968158}


                                                

Update the best scores	 Test metrics: {'mae': 0.08654264810705653}


                                                 

Epoch: 08, Train loss: 0.09734583447376886, Val metrics: {'mae': 0.07511450996646411}


                                                

Update the best scores	 Test metrics: {'mae': 0.08223723034661459}


                                                 

Epoch: 09, Train loss: 0.10251803075273831, Val metrics: {'mae': 0.08261271754949406}


                                                 

Epoch: 10, Train loss: 0.09902743523319563, Val metrics: {'mae': 0.07114545625931808}


                                                

Update the best scores	 Test metrics: {'mae': 0.07833163935449157}


                                                 

Epoch: 11, Train loss: 0.09139947108924389, Val metrics: {'mae': 0.07084893097635359}


                                                

Update the best scores	 Test metrics: {'mae': 0.07816547922287581}


                                                 

Epoch: 12, Train loss: 0.0909802682697773, Val metrics: {'mae': 0.06742855518884835}


                                                

Update the best scores	 Test metrics: {'mae': 0.07496149396605438}


                                                 

Epoch: 13, Train loss: 0.09124470253785451, Val metrics: {'mae': 0.07076582333948132}


                                                 

Epoch: 14, Train loss: 0.0931877575814724, Val metrics: {'mae': 0.0664060892058307}


                                                

Update the best scores	 Test metrics: {'mae': 0.0739216845628535}


                                                 

Epoch: 15, Train loss: 0.08498411116500695, Val metrics: {'mae': 0.06957770751386323}


                                                 

Epoch: 16, Train loss: 0.09033782668411731, Val metrics: {'mae': 0.07011042113627614}


                                                 

Epoch: 17, Train loss: 0.08856107965111733, Val metrics: {'mae': 0.06577601294395886}


                                                

Update the best scores	 Test metrics: {'mae': 0.07277351735647958}


                                                 

Epoch: 18, Train loss: 0.09163592060407003, Val metrics: {'mae': 0.06819643576758001}


                                                 

Epoch: 19, Train loss: 0.08368172347545624, Val metrics: {'mae': 0.06804966834854842}


                                                 

Epoch: 20, Train loss: 0.09337529217203458, Val metrics: {'mae': 0.0661360410688526}


                                                 

Epoch: 21, Train loss: 0.09264710880815982, Val metrics: {'mae': 0.06882739195319956}


                                                 

Epoch: 22, Train loss: 0.0876142393797636, Val metrics: {'mae': 0.0689975908023246}


                                                 

Epoch: 23, Train loss: 0.08783800279100736, Val metrics: {'mae': 0.06732169781895994}


                                                 

Epoch: 24, Train loss: 0.09072480723261833, Val metrics: {'mae': 0.06848202060953947}


                                                 

Epoch: 25, Train loss: 0.09023997709155082, Val metrics: {'mae': 0.06631657025784579}


                                                 

Epoch: 26, Train loss: 0.08708882927894593, Val metrics: {'mae': 0.06537021931402634}


                                                

Update the best scores	 Test metrics: {'mae': 0.07196546741219137}


                                                 

Epoch: 27, Train loss: 0.0836745098233223, Val metrics: {'mae': 0.06537763979159696}


                                                 

Epoch: 28, Train loss: 0.08575024406115214, Val metrics: {'mae': 0.06593768214196305}


                                                 

Epoch: 29, Train loss: 0.09083325043320656, Val metrics: {'mae': 0.06658710503893862}


                                                 

Epoch: 30, Train loss: 0.08647281800707181, Val metrics: {'mae': 0.06592833402001085}


                                                 

Epoch: 31, Train loss: 0.08539213252564272, Val metrics: {'mae': 0.06591201357497423}


                                                 

Epoch: 32, Train loss: 0.09393340547879538, Val metrics: {'mae': 0.06496944878047846}


                                                

Update the best scores	 Test metrics: {'mae': 0.07136015656894468}


                                                 

Epoch: 33, Train loss: 0.08622072227299213, Val metrics: {'mae': 0.06553802702662734}


                                                 

Epoch: 34, Train loss: 0.08976320375998816, Val metrics: {'mae': 0.06515914596830757}


                                                 

Epoch: 35, Train loss: 0.08273353974024454, Val metrics: {'mae': 0.06574916728823155}


                                                 

Epoch: 36, Train loss: 0.091374629860123, Val metrics: {'mae': 0.0650337233879755}


                                                 

Epoch: 37, Train loss: 0.08930601999163627, Val metrics: {'mae': 0.0664266195483958}


                                                 

Epoch: 38, Train loss: 0.0837939412643512, Val metrics: {'mae': 0.06663893344340763}


                                                 

Epoch: 39, Train loss: 0.08212159586449465, Val metrics: {'mae': 0.06482596127393751}


                                                

Update the best scores	 Test metrics: {'mae': 0.07077645427100424}


                                                 

Epoch: 40, Train loss: 0.0814968328922987, Val metrics: {'mae': 0.06490774026775291}


                                                 

Epoch: 41, Train loss: 0.0856102659056584, Val metrics: {'mae': 0.06574970262720756}


                                                 

Epoch: 42, Train loss: 0.08136231005191803, Val metrics: {'mae': 0.06474992035659473}


                                                

Update the best scores	 Test metrics: {'mae': 0.07105183878931319}


                                                 

Epoch: 43, Train loss: 0.08768302674094836, Val metrics: {'mae': 0.06444837342053791}


                                                

Update the best scores	 Test metrics: {'mae': 0.07050292178730314}


                                                 

Epoch: 44, Train loss: 0.09271791030963263, Val metrics: {'mae': 0.06545482316075286}


                                                 

Epoch: 45, Train loss: 0.08501300079127153, Val metrics: {'mae': 0.0668853823842749}


                                                 

Epoch: 46, Train loss: 0.087934560328722, Val metrics: {'mae': 0.0655809343845815}


                                                 

Epoch: 47, Train loss: 0.08525236360728741, Val metrics: {'mae': 0.06588489615945027}


                                                 

Epoch: 48, Train loss: 0.08644945832590262, Val metrics: {'mae': 0.06512880872575838}


                                                 

Epoch: 49, Train loss: 0.08544362969696521, Val metrics: {'mae': 0.06525857923603486}


                                                 

Epoch: 50, Train loss: 0.07839766989151636, Val metrics: {'mae': 0.06607129172938113}


                                                 

Epoch: 51, Train loss: 0.09211349574228128, Val metrics: {'mae': 0.06731295384355351}


                                                 

Epoch: 52, Train loss: 0.08585178442299365, Val metrics: {'mae': 0.06518343883887909}


                                                 

Epoch: 53, Train loss: 0.08251386967798074, Val metrics: {'mae': 0.06658790314126388}
Early stop at epoch 53




43

In [15]:
# test
net.load_state_dict(state_dict)
logits = test(task_loader_dict["test"], net, task_a)
logits = logits.numpy()
pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()
test_metrics = {
        "mae": mean_absolute_error(pred_hat, logits),
        "r2": r2_score(pred_hat, logits),
        "rmse": root_mean_squared_error(pred_hat, logits),
}
test_metrics

  0%|          | 0/315 [00:00<?, ?it/s]

                                                 

{'mae': 0.0642314010633084,
 'r2': 0.232499361038208,
 'rmse': 0.3233393236664935}