In [1]:
# Change the working directory to the parent folder before anything else runs.
# NOTE(review): presumably so that project-local packages (e.g. `model`) import correctly — confirm.
%cd ..

/home/lingze/embedding_fusion


In [2]:
import relbench
from relbench.base import Table, Database, Dataset, EntityTask
from relbench.datasets import get_dataset
from relbench.tasks import get_task
from relbench.base import BaseTask
from torch_geometric.seed import seed_everything
from relbench.modeling.utils import get_stype_proposal
from relbench.modeling.graph import make_pkey_fkey_graph
from relbench.modeling.graph import get_node_train_table_input

import os
import math
import numpy as np
from tqdm import tqdm
import copy

import torch
from torch import Tensor
import torch_geometric
import torch_frame

from torch_frame.config.text_embedder import TextEmbedderConfig
from typing import List, Optional
from sentence_transformers import SentenceTransformer

from torch.nn import L1Loss, BCEWithLogitsLoss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Fix the random seed (42) up front so that repeated runs are comparable.
seed_everything(42)


# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression: the notebook shows the selected device as the cell output.
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [3]:
# Fetch the rel-avito dataset (download enabled) and load its Database object.
dataset = get_dataset(name="rel-avito", download=True)
db = dataset.get_db()

Loading Database object from /home/lingze/.cache/relbench/rel-avito/db...
Done in 5.57 seconds.


In [4]:
# Proposed semantic types (stypes) for the columns of `db`; handed to the graph builder later on.
col_to_stype_dict = get_stype_proposal(db)

class GloveTextEmbedding:
    """Callable text embedder built on a SentenceTransformer GloVe checkpoint.

    Instances load the ``average_word_embeddings_glove.6B.300d`` checkpoint
    once and can then be called directly on a list of strings.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """Load the checkpoint, placing the model on ``device`` when given."""
        checkpoint = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(checkpoint, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` and wrap the encoder's NumPy output as a tensor."""
        encoded = self.model.encode(sentences)
        return torch.from_numpy(encoded)

# Pair the GloVe embedder (placed on `device`) with a batch size of 256 in a
# torch_frame text-embedder config.
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)


  return self.fget.__get__(instance, owner)()


In [5]:
# Directory that holds the materialized graph cache. The default is the
# original location; set EMBEDDING_FUSION_DATA_DIR to run the notebook on
# another machine without editing this cell.
root_dir = os.environ.get(
    "EMBEDDING_FUSION_DATA_DIR", "/home/lingze/embedding_fusion/data"
)

# Build the primary-key/foreign-key graph from the database, together with the
# per-column statistics that the feature encoder consumes later.
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, "rel-avito_materialized_cache"
    ),  # store materialized graph for convenience
)

In [6]:
# Load the "user-ad-visit" task of rel-avito (download enabled).
task = get_task("rel-avito", "user-ad-visit", download=True)

In [7]:
from relbench.modeling.graph import get_link_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader
from relbench.modeling.loader import LinkNeighborLoader

In [8]:
# Sampling and batching hyper-parameters.
neighbors, batch_size = 128, 512
# Row cap; not used by the active loader in this cell.
n = 30_000

# Turn the training split into the src/dst node inputs used by the link loader.
train_table = task.get_table("train")
train_table_input = get_link_train_table_input(train_table, task)

# NOTE(review): `neighbors` is a plain int, while PyG documents `num_neighbors`
# as a per-hop list (or a dict of lists) — confirm the loader accepts an int.
train_loader_kwargs = dict(
    data=data,
    num_neighbors=neighbors,
    time_attr="time",
    src_nodes=train_table_input.src_nodes,
    dst_nodes=train_table_input.dst_nodes,
    num_dst_nodes=train_table_input.num_dst_nodes,
    src_time=train_table_input.src_time,
    share_same_time=False,
    batch_size=batch_size,
    temporal_strategy="uniform",
    shuffle=True,
    num_workers=0,
)
train_loader = LinkNeighborLoader(**train_loader_kwargs)
# train_loader = LinkNeighborLoader(
#     data = data,
#     num_neighbors = neighbors,
#     time_attr="time",
#     src_nodes = (train_table_input.src_nodes[0], train_table_input.src_nodes[1][:n]),
#     dst_nodes=(train_table_input.dst_nodes[0], train_table_input.dst_nodes[1][:n]),
#     num_dst_nodes=train_table_input.num_dst_nodes,
#     src_time=train_table_input.src_time,
#     share_same_time=False,
#     batch_size=batch_size,
#     temporal_strategy="uniform",
#     shuffle= True,
#     num_workers=0,
# )

Loading Database object from /home/lingze/.cache/relbench/rel-avito/db...
Done in 6.31 seconds.


  dst_node_indices = sparse_coo.to_sparse_csr()


In [9]:
# Validation and test loaders: one (src_loader, dst_loader) pair per split.
split_timestamps = {
    "val": dataset.val_timestamp,
    "test": dataset.test_timestamp,
}

# Settings common to every evaluation loader.
# NOTE(review): `neighbors` is a plain int, while PyG documents `num_neighbors`
# as a per-hop list (or a dict of lists) — confirm the loader accepts an int.
shared_loader_kwargs = dict(
    num_neighbors=neighbors,
    time_attr="time",
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

eval_loaders_dict = {}
for split, ts in split_timestamps.items():
    # All seed nodes of a split are sampled at that split's timestamp.
    seed_time = int(ts.timestamp())
    target_table = task.get_table(split)
    src_node_indices = torch.from_numpy(
        target_table.df[task.src_entity_col].values
    )

    # Source side: one seed per row of the split's table.
    src_loader = NeighborLoader(
        data,
        input_nodes=(task.src_entity_table, src_node_indices),
        input_time=torch.full(
            (len(src_node_indices),), seed_time, dtype=torch.long
        ),
        **shared_loader_kwargs,
    )

    # Destination side: every node of the destination entity table.
    dst_loader = NeighborLoader(
        data,
        input_nodes=task.dst_entity_table,
        input_time=torch.full(
            (task.num_dst_nodes,), seed_time, dtype=torch.long
        ),
        **shared_loader_kwargs,
    )

    eval_loaders_dict[split] = (src_loader, dst_loader)

In [10]:
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder

In [11]:
# Channel width passed to every sub-module below.
channels = 128

# Only node types that carry a "time" attribute are given to the temporal encoder.
temporal_node_types = [
    node_type for node_type in data.node_types if "time" in data[node_type]
]
temporal_encoder = HeteroTemporalEncoder(
    node_types=temporal_node_types,
    channels=channels,
)

# Column-level feature encoder, driven by the per-column statistics.
feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels,
)

node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2,
)

# Assemble the full model from the three encoders built above.
model = CompositeModel(
    data=data,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder,
    channels=channels,
    out_channels=1,
    num_layer=2,
    aggr="mean",
    norm="layer_norm",
    dropout=0.2,
)

In [12]:
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Number of training epochs.
epochs = 15

In [None]:
from relbench.metrics import link_prediction_map, link_prediction_precision, link_prediction_recall

In [None]:
model.to(device)

# Run validation every `eval_epochs_interval` epochs. The value must be a
# positive integer for evaluation to happen; values <= 0 disable evaluation
# (the guard below also prevents a modulo-by-zero).
eval_epochs_interval = 1
# Upper bound on the number of optimizer steps taken per epoch.
max_step_per_epoch = 100
total_steps = min(len(train_loader), max_step_per_epoch)
# Number of top-scoring destination nodes kept per source node at evaluation.
eval_k = 10

for epoch in range(1, epochs + 1):
    # ------------------------------ training ------------------------------
    model.train()
    loss_accum = count_accum = 0
    steps = 0
    for batch in tqdm(train_loader, total=total_steps):
        optimizer.zero_grad()

        # Each batch holds the source nodes plus one positive and one negative
        # destination sample for them.
        src_batch, batch_pos_dst, batch_neg_dst = batch
        src_batch = src_batch.to(device)
        batch_pos_dst = batch_pos_dst.to(device)
        batch_neg_dst = batch_neg_dst.to(device)

        x_src = model(src_batch, task.src_entity_table)
        x_pos_dst = model(batch_pos_dst, task.dst_entity_table)
        x_neg_dst = model(batch_neg_dst, task.dst_entity_table)

        # [batch_size, ]
        pos_score = torch.sum(x_src * x_pos_dst, dim=1)
        neg_score = torch.sum(x_src * x_neg_dst, dim=1)

        # Pairwise ranking loss: softplus(-(pos - neg)) shrinks as the positive
        # pair outscores the negative pair.
        diff_score = pos_score - neg_score
        loss = torch.nn.functional.softplus(-diff_score).mean()
        loss.backward()

        optimizer.step()

        # Weight the batch loss by its size so the epoch average is per sample.
        loss_accum += float(loss) * x_src.size(0)
        count_accum += x_src.size(0)

        steps += 1
        if steps >= total_steps:
            break

    train_loss = loss_accum / count_accum if count_accum > 0 else float("nan")
    print(f"Epoch {epoch:02d} | train loss: {train_loss:.4f}")

    # ----------------------------- validation -----------------------------
    if eval_epochs_interval > 0 and epoch % eval_epochs_interval == 0:
        model.eval()
        src_loader, dst_loader = eval_loaders_dict["val"]

        # No gradients are needed here; skipping autograd keeps memory bounded
        # while embedding every destination node.
        with torch.no_grad():
            dst_embs: list[Tensor] = []
            for batch in tqdm(dst_loader):
                batch = batch.to(device)
                emb = model(batch, task.dst_entity_table).detach()
                dst_embs.append(emb)
            dst_emb = torch.cat(dst_embs, dim=0)
            del dst_embs

            pred_index_mat_list: list[Tensor] = []
            for batch in tqdm(src_loader):
                batch = batch.to(device)
                emb = model(batch, task.src_entity_table).detach()
                # Score each source against all destinations, keep the top eval_k.
                scores = emb @ dst_emb.t()
                pred_index_mat = torch.topk(scores, k=eval_k, dim=1).indices
                pred_index_mat_list.append(pred_index_mat)

        # Move to host memory first: `.numpy()` is not available on CUDA tensors.
        pred = torch.cat(pred_index_mat_list, dim=0).cpu().numpy()
        val_table = task.get_table("val", mask_input_cols=False)

        expect_pred_shape = (len(val_table), eval_k)
        assert pred.shape == expect_pred_shape, f"Expected shape {expect_pred_shape}, got {pred.shape}"

        # For every row: which predictions hit a true destination, and how many
        # true destinations that row has.
        pred_isin_list = []
        dst_count_list = []
        for true_dst_nodes, pred_dst_nodes in zip(
            val_table.df[task.dst_entity_col], pred
        ):
            pred_isin_list.append(
                np.isin(np.array(pred_dst_nodes), np.array(true_dst_nodes))
            )
            dst_count_list.append(len(true_dst_nodes))
        pred_isin = np.stack(pred_isin_list)
        dst_count = np.array(dst_count_list)

        val_metrics = {
            "map": link_prediction_map(pred_isin, dst_count),
            "precision": link_prediction_precision(pred_isin, dst_count),
            "recall": link_prediction_recall(pred_isin, dst_count),
        }
        print(f"Epoch {epoch:02d} | val metrics: {val_metrics}")
        
    