In [1]:
%cd ..
import numpy as np
import pandas as pd
from relbench.datasets import get_dataset
from relbench.base import Table
from tqdm import tqdm
from typing import Any,Dict

import torch
import pickle
import os
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index
import time

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = get_dataset(name = "rel-avito", download = True)
db = dataset.get_db()
cache_path = "data/rel-avito-tensor-frame"

Loading Database object from /home/lingze/.cache/relbench/rel-avito/db...
Done in 4.60 seconds.


In [3]:
# [NOTE]: the dataset has been materialized

# get infer_type in cache
type_path = os.path.join(cache_path,"col_type_dict.pkl")
col_type_dict = pickle.load(open(type_path, "rb"))
len(col_type_dict)

# add "compress_text" in each table in case 
for table_name, table in db.table_dict.items():
    table.df["text_compress"] = np.nan

In [4]:
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            # "all-MiniLM-L12-v2",
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=512
)

def remove_pkey_fkey(col_to_stype: Dict[str, Any], table:Table) -> dict:
    r"""Remove pkey, fkey columns since they will not be used as input feature."""
    if table.pkey_col is not None:
        if table.pkey_col in col_to_stype:
            col_to_stype.pop(table.pkey_col)
    for fkey in table.fkey_col_to_pkey_table.keys():
        if fkey in col_to_stype:
            col_to_stype.pop(fkey)

def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds)."""
    assert ser.dtype in [np.dtype("datetime64[s]"), np.dtype("datetime64[ns]")]
    unix_time = ser.astype("int64").values
    if ser.dtype == np.dtype("datetime64[ns]"):
        unix_time //= 10**9
    return unix_time

  return self.fget.__get__(instance, owner)()


In [5]:
# build graph

# start build graph
cache_dir = cache_path
if cache_dir is not None:
    os.makedirs(cache_dir, exist_ok=True)
data = HeteroData()
col_stats_dict = {}
for table_name, table in db.table_dict.items():
    df = table.df
    # (important for foreignKey value) Ensure the pkey is consecutive
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()
    
    col_to_stype = col_type_dict[table_name]
    
    # remove pkey, fkey
    remove_pkey_fkey(col_to_stype, table)
    
    if len(col_to_stype) == 0:
        # for example, relationship table which only contains pkey and fkey
        raise KeyError(f"{table_name} has no column to build graph")
    
    path = (
            None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")
    )
    
    print(f"-----> Materialize {table_name} Tensor Frame")
    dataset = Dataset(
        df = df,
        col_to_stype=col_to_stype,
        col_to_text_embedder_cfg=text_embedder_cfg,
    ).materialize(path=path)
    
    data[table_name].tf = dataset.tensor_frame
    col_stats_dict[table_name] = dataset.col_stats
    
    # Add time attribute
    if table.time_col is not None:
        data[table_name].time = torch.from_numpy(
            to_unix_time(df[table.time_col])
        )
    
    # Add edges normal edges
    for fkey_col_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
        pkey_index = df[fkey_col_name]
        # Filter out dangling foreign keys
        mask = ~pkey_index.isna()
        fkey_index = torch.arange(len(pkey_index))
        
        # filter dangling foreign keys:
        pkey_index = torch.from_numpy(pkey_index[mask].astype(int).values)
        fkey_index = fkey_index[torch.from_numpy(mask.values)]
        
        # fkey -> pkey edges
        edge_index = torch.stack([fkey_index, pkey_index], dim=0)
        edge_type = (table_name, f"f2p_{fkey_col_name}", pkey_table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

        # pkey -> fkey edges.
        # "rev_" is added so that PyG loader recognizes the reverse edges
        edge_index = torch.stack([pkey_index, fkey_index], dim=0)
        edge_type = (pkey_table_name, f"rev_f2p_{fkey_col_name}", table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)
    
data.validate()

-----> Materialize PhoneRequestsStream Tensor Frame
-----> Materialize Location Tensor Frame
-----> Materialize SearchInfo Tensor Frame
-----> Materialize UserInfo Tensor Frame
-----> Materialize SearchStream Tensor Frame
-----> Materialize VisitStream Tensor Frame
-----> Materialize AdsInfo Tensor Frame
-----> Materialize Category Tensor Frame


True

In [6]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy


In [7]:
# add the additional edges
from utils.util import load_np_dict
edge_dict = load_np_dict("./edges/rel-avito-edges.npz")

In [8]:
for edge_name, edge_np in edge_dict.items():
    src_table, dst_table = edge_name.split('-')[0], edge_name.split('-')[1]
    edge_index = torch.from_numpy(edge_np.astype(int)).t()
    # [2, edge_num]
    edge_type = (src_table, f"appendix", dst_table)
    data[edge_type].edge_index = sort_edge_index(edge_index)
data.validate()

True

In [9]:
# read the pre-extracted sample
from utils.util import load_np_dict
sample_dict = load_np_dict("./samples/rel-avito-samples.npz")
sample_dict.keys()

dict_keys(['Location', 'UserInfo', 'Category'])

In [10]:
from relbench.base import Database
def neighborsample_batch(
    db: Database,
    entity_table: str,
    node_idxs: np.ndarray,
    num_neighbors: List[int] = [128,128],
):
    # node_idxs: [n]
    nodes = (entity_table, torch.from_numpy(node_idxs))
    n = node_idxs.shape[0]
    input_time = torch.from_numpy(
        to_unix_time(pd.Series([db.max_timestamp] * n)))

    if db.table_dict[entity_table].time_col:
        time_col = db.table_dict[entity_table].time_col
        time_values = db.table_dict[entity_table].df[time_col].loc[node_idxs.tolist(
        )]
        input_time = torch.from_numpy(to_unix_time(time_values))

    loader = NeighborLoader(
        data,
        num_neighbors=num_neighbors,
        input_nodes=nodes,
        time_attr = "time",
        input_time=input_time,
        batch_size=n,
        temporal_strategy="uniform",
        shuffle=False,
        disjoint=True,
        num_workers=0,
        persistent_workers=False,
    )
    return next(iter(loader))
    

In [12]:
# construct bottom model
channels = 128
args = {
    "channels": channels,
    "num_layers": 2,
    "dropout_prob": 0.2,
}

temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
    channels=args["channels"],
)

feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=args["channels"],
)

node_encoder = NodeRepresentationPart(
    data=data,
    channels=args["channels"],
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.3
)

net = CompositeModel(
    data=data,
    channels=args["channels"],
    out_channels=1,
    dropout=0.3,
    aggr="mean",
    norm="batch_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder
)
# net.reset_parameters()


# if torch.cuda.device_count() > 1:
#     print("Let's use", torch.cuda.device_count(), "GPUs!")
#   # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
#   # net = torch.nn.DataParallel(net)
#     net = DataParallel(net)

In [13]:
for sample_table, sample_np in sample_dict.items():
    print(f"Start to train on {sample_table}")
    print(f"shape of sample_np: {sample_np.shape}")

Start to train on Location
shape of sample_np: (162, 21)
Start to train on UserInfo
shape of sample_np: (17615, 21)
Start to train on Category
shape of sample_np: (6, 21)


In [14]:
negative_num_dict = {
    "Location" : 10,
    "UserInfo": 20,
    "Category": 10,
}

In [15]:
from model.utils import InfoNCE
lr = 5e-4
negative_sample_pool_size = 512
temprature = 0.01
net.reset_parameters()
optimizer = torch.optim.Adam(net.parameters(), lr = lr)
epoches = 20
early_restart_steps = 20
batch_size = 256
max_steps_in_epoch = 30
loss_fn = InfoNCE(temperature=temprature, negative_mode='paired')

In [16]:
net.to(device)
for epoch in range(1, epoches + 1):
    net.train()
    print("*"*30 + f"<Epoch: {epoch:02d}>" + "*"*30)
    for sample_table, sample_np in sample_dict.items():
        loss_accum = count_accum = 0
        shuffle_sample_np = sample_np[np.random.permutation(len(sample_np))]
        anchor_nodes_np = shuffle_sample_np[:, 0]
        positive_pool_np = shuffle_sample_np[:, 1:]
        # choose the positive samples
        n = sample_np.shape[0]
        negative_num = negative_num_dict[sample_table]
        m = len(db.table_dict[sample_table].df)
        now = time.time()
        cnt = 0
        for batch_idx in tqdm(range(0, n, batch_size), leave=False):
            cnt += 1
            if cnt > max_steps_in_epoch:
                break
            anchor_nodes = anchor_nodes_np[batch_idx:batch_idx+batch_size]
            positive_pool_batch_np = positive_pool_np[batch_idx:batch_idx+batch_size]
            positive_nodes = []
            # random select the positive samples
            for row in positive_pool_batch_np:
                valid = row[row != -1]
                random_choice = np.random.choice(valid, 1)[0]
                positive_nodes.append(random_choice)

            positive_nodes = np.array(positive_nodes)
            B = positive_nodes.size
            # random select the negative sample, negative ratio is 1:20
            # for one batch, we still extract batch_size negative samples
            # for each positive-negative pair, we extract 20 from this 256 batch as negative samples
            excluded = set(positive_nodes.tolist()).union(
                set(anchor_nodes.tolist()))
            negative_candidates = list(set(range(m)) - excluded)
            # print(negative_candidates)
            
            sample_size = min(negative_sample_pool_size, len(negative_candidates))
            
            # if sample_size < positive_nodes.size:
            #     # special case, for those number of positive pairs is too small, 
            #     # we employ the pure random to select the negative samples
            #     negative_candidates = list(range(m))
            #     sample_size = positive_nodes.size
                
            #     print("==> Candidate:" + str(negative_candidates))
            
            negative_nodes = np.random.choice(
                negative_candidates, size=sample_size, replace=True)
            # [batch_size]
            # print(negative_nodes.shape)
            # print(B)
            # neighbor hood loader
            anchor_nodes_batch = neighborsample_batch(
                db, sample_table, anchor_nodes)
            positive_nodes_batch = neighborsample_batch(
                db, sample_table, positive_nodes)
            negative_nodes_batch = neighborsample_batch(
                db, sample_table, negative_nodes)

            optimizer.zero_grad()

            anchor_nodes_batch, positive_nodes_batch, negative_nodes_batch = \
                anchor_nodes_batch.to(device), positive_nodes_batch.to(
                    device), negative_nodes_batch.to(device)

            anchor_nodes_embedding = net.get_node_embedding(
                anchor_nodes_batch, sample_table)[sample_table][:B]
            positive_nodes_embedding = net.get_node_embedding(
                positive_nodes_batch, sample_table)[sample_table][:B]
            negative_nodes_embedding = net.get_node_embedding(
                negative_nodes_batch, sample_table)[sample_table][:sample_size]

            # negative_nodes_embedding = net.get_node_embedding(negative_nodes_batch, sample_table)[sample_table][:B]
            # [B, D]
            
            negative_indices = torch.stack([torch.randperm(sample_size)[
                                           :negative_num] for _ in range(B)]).to(device)
            negative_nodes_embedding = negative_nodes_embedding[negative_indices]
            # [B, negative_num, D]

            loss = loss_fn(anchor_nodes_embedding,
                           positive_nodes_embedding, negative_nodes_embedding)
            loss.backward()
            optimizer.step()
            loss_accum += loss.detach().item()
            count_accum += B
        end = time.time()
        mins, secs = divmod(end - now, 60)
        train_loss = loss_accum / count_accum
        
        print(f"====> In {sample_table}, Train loss: {train_loss}, Cost Time {mins:.0f}m {secs:.0f}s")

******************************<Epoch: 01>******************************


  0%|          | 0/1 [00:00<?, ?it/s]

                                             

====> In Location, Train loss: 0.08918241806972174, Cost Time 0m 13s


                                               

====> In UserInfo, Train loss: 0.024141385747740666, Cost Time 5m 21s


                                             

====> In Category, Train loss: 2.634373505910238, Cost Time 0m 11s
******************************<Epoch: 02>******************************


                                             

====> In Location, Train loss: 0.07407037122749988, Cost Time 0m 12s


                                               

====> In UserInfo, Train loss: 0.009554343406731884, Cost Time 5m 18s


                                             

====> In Category, Train loss: 1.6830333073933919, Cost Time 0m 11s
******************************<Epoch: 03>******************************


                                             

====> In Location, Train loss: 0.06527979579972631, Cost Time 0m 11s


                                               

====> In UserInfo, Train loss: 0.006813572781781355, Cost Time 5m 22s


                                             

====> In Category, Train loss: 1.0626503626505535, Cost Time 0m 10s
******************************<Epoch: 04>******************************


                                             

====> In Location, Train loss: 0.060166376608389395, Cost Time 0m 11s


                                               

====> In UserInfo, Train loss: 0.005733799515292048, Cost Time 5m 17s


                                             

====> In Category, Train loss: 1.5939575831095378, Cost Time 0m 11s
******************************<Epoch: 05>******************************


                                             

====> In Location, Train loss: 0.05451224762716411, Cost Time 0m 10s


                                               

====> In UserInfo, Train loss: 0.0049683120256910724, Cost Time 5m 24s


                                             

====> In Category, Train loss: 1.6488850911458333, Cost Time 0m 11s
******************************<Epoch: 06>******************************


                                             

====> In Location, Train loss: 0.04804271827509374, Cost Time 0m 11s


                                               

====> In UserInfo, Train loss: 0.004513030460414787, Cost Time 5m 36s


                                             

====> In Category, Train loss: 2.2902706464131675, Cost Time 0m 12s
******************************<Epoch: 07>******************************


                                             

====> In Location, Train loss: 0.04503394056249548, Cost Time 0m 13s


                                               

====> In UserInfo, Train loss: 0.004201838583685458, Cost Time 5m 47s


                                             

====> In Category, Train loss: 2.034142812093099, Cost Time 0m 12s
******************************<Epoch: 08>******************************


                                             

====> In Location, Train loss: 0.03845165393970631, Cost Time 0m 13s


                                               

====> In UserInfo, Train loss: 0.004059394417951504, Cost Time 5m 32s


                                             

====> In Category, Train loss: 1.7021479606628418, Cost Time 0m 11s
******************************<Epoch: 09>******************************


                                             

====> In Location, Train loss: 0.038684059072423865, Cost Time 0m 11s


                                               

====> In UserInfo, Train loss: 0.0038486236628765863, Cost Time 4m 59s


                                             

====> In Category, Train loss: 1.6874440511067708, Cost Time 0m 10s
******************************<Epoch: 10>******************************


                                             

====> In Location, Train loss: 0.034639611656283154, Cost Time 0m 11s


                                               

====> In UserInfo, Train loss: 0.003654822101816535, Cost Time 5m 14s


                                             

====> In Category, Train loss: 1.0534934997558594, Cost Time 0m 10s
******************************<Epoch: 11>******************************


                                             

====> In Location, Train loss: 0.038157015670964745, Cost Time 0m 11s


                                               

====> In UserInfo, Train loss: 0.0034491363214328883, Cost Time 5m 12s


                                             

====> In Category, Train loss: 1.524783452351888, Cost Time 0m 11s
******************************<Epoch: 12>******************************


                                             

====> In Location, Train loss: 0.031006783614923924, Cost Time 0m 10s


                                               

====> In UserInfo, Train loss: 0.0034276931391408047, Cost Time 5m 12s


                                             

====> In Category, Train loss: 1.5290369987487793, Cost Time 0m 12s
******************************<Epoch: 13>******************************


                                             

====> In Location, Train loss: 0.030287792653213314, Cost Time 0m 14s


                                               

====> In UserInfo, Train loss: 0.003338953270576894, Cost Time 5m 57s


                                             

====> In Category, Train loss: 0.9579149087270101, Cost Time 0m 13s
******************************<Epoch: 14>******************************


                                             

====> In Location, Train loss: 0.03347723572342484, Cost Time 0m 13s


                                               

====> In UserInfo, Train loss: 0.003075275290757418, Cost Time 5m 35s


                                             

====> In Category, Train loss: 1.1142091751098633, Cost Time 0m 11s
******************************<Epoch: 15>******************************


                                             

====> In Location, Train loss: 0.032191041075153116, Cost Time 0m 11s


                                               

====> In UserInfo, Train loss: 0.0030806957200790446, Cost Time 5m 39s


                                             

====> In Category, Train loss: 1.6216039657592773, Cost Time 0m 11s
******************************<Epoch: 16>******************************


                                             

====> In Location, Train loss: 0.03176107818697706, Cost Time 0m 11s


                                               

====> In UserInfo, Train loss: 0.0030247592522452274, Cost Time 5m 17s


                                             

====> In Category, Train loss: 0.8297781149546305, Cost Time 0m 10s
******************************<Epoch: 17>******************************


                                             

====> In Location, Train loss: 0.026165635497481736, Cost Time 0m 11s


                                               

====> In UserInfo, Train loss: 0.0028386797367905576, Cost Time 5m 5s


                                             

====> In Category, Train loss: 0.39744122823079425, Cost Time 0m 10s
******************************<Epoch: 18>******************************


                                             

====> In Location, Train loss: 0.027041576526783132, Cost Time 0m 10s


                                               

====> In UserInfo, Train loss: 0.0028791360634689528, Cost Time 5m 13s


                                             

====> In Category, Train loss: 1.1440845330556233, Cost Time 0m 11s
******************************<Epoch: 19>******************************


                                             

====> In Location, Train loss: 0.02837402143596131, Cost Time 0m 10s


                                               

====> In UserInfo, Train loss: 0.0027808332039664188, Cost Time 5m 18s


                                             

====> In Category, Train loss: 1.058437665303548, Cost Time 0m 11s
******************************<Epoch: 20>******************************


                                             

====> In Location, Train loss: 0.027290547335589374, Cost Time 0m 11s


                                               

====> In UserInfo, Train loss: 0.0026335918887828787, Cost Time 5m 18s


                                             

====> In Category, Train loss: 0.8737351099650065, Cost Time 0m 11s




In [17]:
pre_trained_state = copy.deepcopy(net.state_dict())
import torch
import json
torch.save(pre_trained_state, "./static/rel-avito-pre-trained-channel128.pth")