In [1]:
%cd ..
import torch
import math
import argparse
import copy
from tqdm import tqdm
import numpy as np
from torch_geometric.loader import NeighborLoader
from torch.nn import L1Loss, BCEWithLogitsLoss
from sklearn.metrics import mean_absolute_error, roc_auc_score
from relbench.base import TaskType
from typing import Dict


from utils.util import load_col_types
from utils.resource import get_text_embedder_cfg
from utils.builder import build_pyg_hetero_graph
from utils.data import DatabaseFactory
from utils.sample import get_node_train_table_input_with_sample
from model import HeteroGCN, HeteroGAT, HGT

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_cache_dir = "/home/lingze/.cache/relbench/ratebeer"
cache_dir = "./data/ratebeer-tensor-frame"
db_name = "ratebeer"
task_name = "place-positive"

In [3]:

db = DatabaseFactory.get_db(
    db_name, cache_dir=data_cache_dir
)

dataset = DatabaseFactory.get_dataset(
    db_name, cache_dir = data_cache_dir
)

task = DatabaseFactory.get_task(
    db_name, task_name, dataset
)

col_type_dict = load_col_types(
    cache_path=cache_dir,
    file_name = "col_type_dict.pkl"
)

Loading Database object from /home/lingze/.cache/relbench/ratebeer/db...
Done in 0.74 seconds.


In [4]:
data, col_stats_dict = build_pyg_hetero_graph(
    db,
    col_type_dict,
    get_text_embedder_cfg(device="cpu"),
    cache_dir=cache_dir,
    verbose=True,
)

  return self.fget.__get__(instance, owner)()


-----> Materialize favorites Tensor Frame
-----> Materialize beers Tensor Frame
-----> Materialize places Tensor Frame
-----> Materialize availability Tensor Frame
-----> Materialize place_ratings Tensor Frame
-----> Materialize beer_ratings Tensor Frame
-----> Materialize countries Tensor Frame
-----> Materialize brewers Tensor Frame
-----> Materialize users Tensor Frame


In [5]:
validation_ratio = 1
test_ratio = 1
batch_size = 256
num_neighbors = [128, 64]

channels = 128
out_channels = 1
norm = "layer_norm"
aggr = "sum"
edge_aggr = "sum"
dropout = 0.3
num_layers = 2
heads = 4

In [6]:


data_loader_dict: Dict[str, NeighborLoader] = {}
for split, sample_ratio, table in [
    ("train", 1, task.get_table("train")),
    ("valid", validation_ratio, task.get_table("val")),
    ("test", test_ratio, task.get_table("test", mask_input_cols=False)),
]:

    _, table_input = get_node_train_table_input_with_sample(
        table=table,
        task=task,
        sample_ratio=sample_ratio,
        shuffle=False,
    )

    data_loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=num_neighbors,
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=batch_size,
        shuffle=split == "train"
    )


In [7]:
net = HeteroGCN(
    data,
    col_stats_dict,
    channels=channels,
    out_channels=out_channels,
    num_layers = num_layers,
    aggr = aggr,
    edge_aggr = edge_aggr,
    dropout = dropout,
    norm = norm,
)

# net = HeteroGAT(
#     data,
#     col_stats_dict,
#     channels=channels,
#     out_channels=out_channels,
#     num_layers = num_layers,
#     aggr = aggr,
#     edge_aggr = edge_aggr,
#     dropout = dropout,
#     norm = norm,
# )

# net =  net = HGT(
#         data,
#         col_stats_dict,
#         channels=channels,
#         out_channels=out_channels,
#         num_layers=num_layers,
#         aggr=aggr,
#         edge_aggr=edge_aggr,
#         dropout=dropout,
#         norm=norm,
#         heads=heads,
#     )

In [8]:
is_regression = task.task_type == TaskType.REGRESSION


def deactivate_dropout(net: torch.nn.Module):
    """ Deactivate dropout layers in the model. for regression task
    """
    deactive_nn_instances = (
        torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)
    for module in net.modules():
        if isinstance(module, deactive_nn_instances):
            module.eval()
            for param in module.parameters():
                param.requires_grad = False
    return net
net = deactivate_dropout(net) if is_regression else net

In [9]:
loss_fn = L1Loss() if is_regression else BCEWithLogitsLoss()
evaluate_metric_func = mean_absolute_error if is_regression else roc_auc_score
higher_is_better = False if is_regression else True

In [10]:
def test(net: torch.nn.Module, loader: torch.utils.data.DataLoader, entity_table: str, early_stop: int = -1, is_regression: bool = False):
    pred_list = []
    y_list = []
    early_stop = early_stop if early_stop > 0 else len(loader.dataset)

    if not is_regression:
        net.eval()

    for idx, batch in tqdm(enumerate(loader), total=len(loader), leave=False, desc="Testing"):
        with torch.no_grad():
            batch = batch.to(device)
            y = batch[entity_table].y.float()
            pred = net(batch, entity_table)
            pred = pred.view(-1) if pred.size(1) == 1 else pred
            pred_list.append(pred.detach().cpu())
            y_list.append(y.detach().cpu())
        if idx > early_stop:
            break

    pred_list = pred_logits = torch.cat(pred_list, dim=0)
    pred_list = torch.sigmoid(pred_list).numpy()
    y_list = torch.cat(y_list, dim=0).numpy()
    return pred_logits.numpy(), pred_list,  y_list


In [11]:
lr = 0.001
num_epochs = 500
early_stop_threshold = 2
max_round_epoch = 10


In [12]:
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, net.parameters()), lr=lr
)


In [13]:
# init training variables
net.to(device)
patience = 0
best_epoch = 0
best_val_metric = -math.inf if higher_is_better else math.inf
best_model_state = None

In [14]:

for epoch in range(num_epochs):
    loss_accum = count_accum = 0
    net.train()
    for idx, batch in tqdm(enumerate(data_loader_dict["train"]),
                           leave=False,
                           total=len(data_loader_dict["train"]),
                           desc="Training"):
        if idx > max_round_epoch:
            break
        optimizer.zero_grad()
        batch = batch.to(device)
        y = batch[task.entity_table].y.float()
        pred = net(batch, task.entity_table)
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()

        loss_accum += loss.item()
        count_accum += 1
    train_loss = loss_accum / count_accum
    val_logits, _, val_pred_hat = test(
        net, data_loader_dict["valid"], task.entity_table, early_stop=-1, is_regression=is_regression
    )
    val_metric = evaluate_metric_func(val_pred_hat, val_logits)
    print(
        f"==> Epcoh: {epoch} => Train Loss: {train_loss:.6f}, Val {evaluate_metric_func.__name__} Metric: {val_metric:.6f} \t{patience}/{early_stop_threshold}")

    if (higher_is_better and val_metric > best_val_metric) or \
       (not higher_is_better and val_metric < best_val_metric):
        best_val_metric = val_metric
        best_epoch = epoch
        best_model_state = copy.deepcopy(net.state_dict())
        patience = 0

        if True:
            test_logits, _, test_pred_hat = test(
                net, data_loader_dict["test"], task.entity_table, is_regression=is_regression)
            test_metric = evaluate_metric_func(test_pred_hat, test_logits)

            print(
                f"Update the best scores => Test {evaluate_metric_func.__name__} Metric: {test_metric:.6f}")
        else:
            print(
                f"Update the best scores \t "
            )
    else:
        patience += 1
        if patience > early_stop_threshold:
            print(f"Early stopping at epoch {epoch}")
            break

                                                         

==> Epcoh: 0 => Train Loss: 0.715412, Val roc_auc_score Metric: 0.638780 	0/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.651510


                                                         

==> Epcoh: 1 => Train Loss: 0.653358, Val roc_auc_score Metric: 0.642663 	0/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.669429


                                                         

==> Epcoh: 2 => Train Loss: 0.642796, Val roc_auc_score Metric: 0.665023 	0/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.684276


                                                         

==> Epcoh: 3 => Train Loss: 0.639465, Val roc_auc_score Metric: 0.711785 	0/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.722610


                                                         

==> Epcoh: 4 => Train Loss: 0.618940, Val roc_auc_score Metric: 0.749072 	0/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.753519


                                                         

==> Epcoh: 5 => Train Loss: 0.579253, Val roc_auc_score Metric: 0.763426 	0/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.761799


                                                         

==> Epcoh: 6 => Train Loss: 0.567418, Val roc_auc_score Metric: 0.779911 	0/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.769194


                                                         

==> Epcoh: 7 => Train Loss: 0.553163, Val roc_auc_score Metric: 0.807928 	0/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.784731


                                                         

==> Epcoh: 8 => Train Loss: 0.519167, Val roc_auc_score Metric: 0.823271 	0/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.789309


                                                         

==> Epcoh: 9 => Train Loss: 0.485041, Val roc_auc_score Metric: 0.835955 	0/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.793740


                                                         

==> Epcoh: 10 => Train Loss: 0.497251, Val roc_auc_score Metric: 0.835817 	0/2


                                                         

==> Epcoh: 11 => Train Loss: 0.453633, Val roc_auc_score Metric: 0.845749 	1/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.808568


                                                         

==> Epcoh: 12 => Train Loss: 0.481697, Val roc_auc_score Metric: 0.838476 	0/2


                                                         

==> Epcoh: 13 => Train Loss: 0.455922, Val roc_auc_score Metric: 0.828255 	1/2


                                                         

==> Epcoh: 14 => Train Loss: 0.452501, Val roc_auc_score Metric: 0.848882 	2/2


                                                        

Update the best scores => Test roc_auc_score Metric: 0.814017


                                                         

==> Epcoh: 15 => Train Loss: 0.451969, Val roc_auc_score Metric: 0.846257 	0/2


                                                         

==> Epcoh: 16 => Train Loss: 0.439286, Val roc_auc_score Metric: 0.847974 	1/2


                                                         

==> Epcoh: 17 => Train Loss: 0.449091, Val roc_auc_score Metric: 0.831111 	2/2
Early stopping at epoch 17




In [15]:
net.load_state_dict(best_model_state)
table = task.get_table("test", mask_input_cols=False)
_, table_input = get_node_train_table_input_with_sample(
    table=table,
    task=task,
    sample_ratio=1,
    shuffle=False,
)
loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    time_attr="time",
    input_nodes=table_input.nodes,
    input_time=table_input.time,
    transform=table_input.transform,
    batch_size=batch_size,
    shuffle=False,
)

In [16]:
test_logits, _, test_pred_hat = test(
    net, loader, task.entity_table, is_regression=is_regression)
test_metric = evaluate_metric_func(test_pred_hat, test_logits)
print(
    f"Test {evaluate_metric_func.__name__} Metric: {test_metric:.6f}")

                                                        

Test roc_auc_score Metric: 0.814014
