In [1]:
import pandas as pd
from utils.data import DatabaseFactory
from typing import Any, Dict, List, Optional, Type
from tqdm import tqdm
import torch
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 1
%aimport model.aida
%aimport model.layer.fusion_layer

In [3]:
db_name = "event"
db = DatabaseFactory.get_db(db_name,
                            with_text_compress=True)
dataset = DatabaseFactory.get_dataset(db_name)

Loading Database object from /home/lingze/.cache/relbench/rel-event/db...
Done in 2.99 seconds.


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  event_df.replace({"event_id": event_id2index}, inplace=True)


In [4]:
from utils.builder import build_pyg_hetero_graph
from utils.util import load_col_types
from utils.resource import get_text_embedder_cfg


In [5]:
cache_dir="/home/lingze/embedding_fusion/data/rel-event-tensor-frame"

col_type_dict = load_col_types(
    cache_path=cache_dir, file_name="col_type_dict.pkl")

data, col_stats_dict = build_pyg_hetero_graph(
    db,
    col_type_dict,
    get_text_embedder_cfg(device="cpu"),
    cache_dir=cache_dir,
    verbose=True,
)

-----> Materialize event_attendees Tensor Frame
-----> Build edge between users and users
-----> Materialize events Tensor Frame
-----> Materialize event_interest Tensor Frame
-----> Materialize users Tensor Frame
Build pyg hetero graph takes 0.094462 seconds


In [6]:
task_name = "user-repeat"
task = DatabaseFactory.get_task(db_name, task_name, dataset)

In [7]:
from torch_geometric.loader import NeighborLoader
from utils.sample import get_node_train_table_input_with_sample
num_neighbors = [64, 64]
batch_size = 128
data_loader_dict: Dict[str, NeighborLoader] = {}
for split, sample_ratio, table in [
    ("train", 1, task.get_table("train")),
    ("valid", 1, task.get_table("val")),
    ("test", 1, task.get_table("test", mask_input_cols=False)),
]:

    _, table_input = get_node_train_table_input_with_sample(
        table=table,
        task=task,
        sample_ratio=sample_ratio,
        shuffle=False,
    )

    data_loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=num_neighbors,
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        temporal_strategy="last",
        batch_size=batch_size,
        shuffle=split == "train"
    )

In [8]:
from model.aida import AIDABaseFeatureEncoder, AIDASharedTableEncoder, AIDAXFormer, AIDATableEncoder
from relbench.modeling.nn import HeteroTemporalEncoder
from model.graphsage import HeteroGraphSAGE
from model.aida import AIDABasicFormer
from model.encoder import build_encoder

In [9]:
from model.aida import construct_default_AIDAXFormer

In [15]:
channels = 128
feat_layer_num = 2
dropout_prob = 0.1
feat_nhead = 1
aggr = "max"
# specific_table_encoder = {
#         task.entity_table: build_encoder(
#             encoder_type="TabM",
#             channels=channels,
#             num_layers=2,
#             dropout_prob=dropout_prob
#         )
# }
specific_table_encoder = None
net = construct_default_AIDAXFormer(
    data,
    col_stats_dict,
    channels=channels,
    out_channels=1,
    feat_layer_num=feat_layer_num,
    dropout_prob=dropout_prob,
    feat_nhead=feat_nhead,
    relation_aggr=aggr,
    deactivate_relation_module=True
)
net.reset_parameters()



In [16]:
from relbench.base import TaskType
from sklearn.metrics import mean_absolute_error, roc_auc_score
from torch.nn import L1Loss, BCEWithLogitsLoss

is_regression = task.task_type == TaskType.REGRESSION


def deactivate_dropout(net: torch.nn.Module):
    """ Deactivate dropout layers in the model. for regression task
    """
    deactive_nn_instances = (
        torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)
    for module in net.modules():
        if isinstance(module, deactive_nn_instances):
            module.eval()
            for param in module.parameters():
                param.requires_grad = False
    return net


net = deactivate_dropout(net) if is_regression else net
loss_fn = L1Loss() if is_regression else BCEWithLogitsLoss()
evaluate_metric_func = mean_absolute_error if is_regression else roc_auc_score
higher_is_better = False if is_regression else True

In [17]:
@torch.no_grad()
def test(net: torch.nn.Module, loader: torch.utils.data.DataLoader, entity_table: str, early_stop: int = -1, is_regression: bool = False):
    pred_list = []
    y_list = []
    early_stop = early_stop if early_stop > 0 else len(loader.dataset)

    if not is_regression:
        net.eval()

    for idx, batch in tqdm(enumerate(loader), total=len(loader), leave=False, desc="Testing"):
        with torch.no_grad():
            batch = batch.to(device)
            y = batch[entity_table].y.float()
            pred = net(batch, entity_table)
            pred = pred.view(-1) if pred.size(1) == 1 else pred
            
            # apply a sigmoid
            if not is_regression:
                pred = torch.sigmoid(pred)

            pred_list.append(pred.detach().cpu())
            y_list.append(y.detach().cpu())
        if idx > early_stop:
            break

    pred_list = pred_logits = torch.cat(pred_list, dim=0)
    pred_list = torch.sigmoid(pred_list).numpy()
    y_list = torch.cat(y_list, dim=0).numpy()
    pred_list = pred_logits.numpy() if is_regression else pred_list
    return pred_list,  y_list

In [18]:
# init training variables
import math
net.to(device)
patience = 0
best_epoch = 0
best_val_metric = -math.inf if higher_is_better else math.inf
best_model_state = None

In [19]:
import copy
num_epochs = 500
early_stop_threshold = 5
max_round_epoch = 50
lr = 1e-3
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, net.parameters()), lr=lr
)


In [None]:
# loss-step
import time
start_time = time.time()
step_loss = []
val_metrics_log = []
for epoch in range(num_epochs):
    loss_accum = count_accum = 0
    net.train()
    for idx, batch in tqdm(enumerate(data_loader_dict["train"]),
                           leave=False,
                           total=len(data_loader_dict["train"]),
                           desc="Training"):
        if idx > max_round_epoch:
            break
        optimizer.zero_grad()
        batch = batch.to(device)
        y = batch[task.entity_table].y.float()
        pred = net(batch, task.entity_table)
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()

        # record
        step_loss.append(loss.item())
        loss_accum += loss.item()
        count_accum += 1
    train_loss = loss_accum / count_accum
    val_logits, val_pred_hat = test(
        net, data_loader_dict["test"], task.entity_table, early_stop=-1, is_regression=is_regression
    )
    val_metric = evaluate_metric_func(val_pred_hat, val_logits)
    val_metrics_log.append(val_metric)

    # best_val_metric = 0
    if (higher_is_better and val_metric > best_val_metric) or \
       (not higher_is_better and val_metric < best_val_metric):
        best_val_metric = val_metric
        best_epoch = epoch
        best_model_state = copy.deepcopy(net.state_dict())
        patience = 0

        # test_logits,  test_pred_hat = test(
        #     net, data_loader_dict["test"], task.entity_table, is_regression=is_regression)
        # test_metric = evaluate_metric_func(test_pred_hat, test_logits)

        # print(
        #     f"Update the best scores => Test {evaluate_metric_func.__name__} Metric: {test_metric:.6f}")
    else:
        patience += 1
        # test_logits, test_pred_hat = test(
        #     net, data_loader_dict["test"], task.entity_table, is_regression=is_regression)
        # test_metric = evaluate_metric_func(test_pred_hat, test_logits)
        # print(
        #     f"No improvement in epoch {epoch} => Test {evaluate_metric_func.__name__} Metric: {test_metric:.6f}")
        if patience > early_stop_threshold:
            print(f"Early stopping at epoch {epoch}")
            break
    
    print(
        f"==> Epcoh: {epoch} => Train Loss: {train_loss:.6f}, Val {evaluate_metric_func.__name__} Metric: {val_metric:.6f} \t{patience}/{early_stop_threshold}")
    
end_time = time.time()
total_time = end_time - start_time
print(f"Total training time: {total_time:.2f} seconds")

                                                         

==> Epcoh: 0 => Train Loss: 0.877890, Val roc_auc_score Metric: 0.670588 	0/5


                                                         

==> Epcoh: 1 => Train Loss: 0.691973, Val roc_auc_score Metric: 0.657754 	1/5


                                                         

==> Epcoh: 2 => Train Loss: 0.714578, Val roc_auc_score Metric: 0.677674 	0/5


                                                         

==> Epcoh: 3 => Train Loss: 0.667862, Val roc_auc_score Metric: 0.747794 	0/5


                                                         

==> Epcoh: 4 => Train Loss: 0.665311, Val roc_auc_score Metric: 0.787433 	0/5


                                                         

==> Epcoh: 5 => Train Loss: 0.654173, Val roc_auc_score Metric: 0.748864 	1/5


                                                         

==> Epcoh: 6 => Train Loss: 0.653727, Val roc_auc_score Metric: 0.774933 	2/5


                                                         

==> Epcoh: 7 => Train Loss: 0.653966, Val roc_auc_score Metric: 0.767380 	3/5


                                                         

==> Epcoh: 8 => Train Loss: 0.665608, Val roc_auc_score Metric: 0.814505 	0/5


                                                         

==> Epcoh: 9 => Train Loss: 0.685539, Val roc_auc_score Metric: 0.779011 	1/5


                                                         

==> Epcoh: 10 => Train Loss: 0.656779, Val roc_auc_score Metric: 0.793382 	2/5


                                                         

==> Epcoh: 11 => Train Loss: 0.648645, Val roc_auc_score Metric: 0.771858 	3/5


Training:  77%|███████▋  | 24/31 [00:04<00:01,  5.53it/s]

In [16]:
net.load_state_dict(best_model_state)
loader = data_loader_dict["test"]

In [17]:
importance_list = []
attn_weights_list = []
pad_mask_list = []
for _, batch in enumerate(loader):
    batch = batch.to(device)
    tokens, seq_types, attn_weights = net.get_attn_weights(batch, task.entity_table)
    pad_mask = (attn_weights == 0)
    valid_counts = (~pad_mask).sum(dim=1, keepdim=True).clamp_min(1)
    

    uniform = (1.0 / valid_counts).expand_as(attn_weights)
    importance = torch.clamp(attn_weights - uniform, min=0.0)
    
    # filter out valid counts == 1 rows
    valid_counts = valid_counts.squeeze(-1)
    attn_weights = attn_weights[valid_counts > 1]
    importance = importance[valid_counts > 1]
    pad_mask = pad_mask[valid_counts > 1]
    
    # We applied the filtering to remove rows with valid counts == 1
    attn_weights_list.append(attn_weights.cpu())
    pad_mask_list.append(pad_mask.cpu())
    importance_list.append(importance.cpu())

In [None]:
# concat all importance tensors
importances = torch.cat(importance_list, dim=0)
importances = importances.numpy()

pad_masks = torch.cat(pad_mask_list, dim=0)
pad_masks = pad_masks.numpy()

In [None]:
# replace the importances Nan corresponding to pad mask
importances[pad_masks] = float('nan')

In [None]:
seq_types

['customer_self',
 'transactions__f2p_customer_id__customer_max',
 'transactions__f2p_customer_id__customer_min',
 'transactions__f2p_customer_id__customer_sum',
 'transactions__f2p_customer_id__customer_mean']

In [None]:
# save as csv
df = pd.DataFrame(importances, columns = seq_types, index = None)

In [None]:
# order the columns according to the mean
df.mean(axis=0, skipna=True).sort_values(ascending=False)

customer_self                                   5.013261e-01
transactions__f2p_customer_id__customer_min     1.508319e-03
transactions__f2p_customer_id__customer_mean    1.487556e-05
transactions__f2p_customer_id__customer_max     8.915920e-07
transactions__f2p_customer_id__customer_sum     0.000000e+00
dtype: float32

In [None]:
df.to_csv("./aida/logs/importances_hm_user_churn.csv", index=False)