In [1]:
%cd ..
from tqdm import tqdm
from utils.data import StackDataset
import numpy as np
import torch
import pickle
import os

from torch_geometric.data import HeteroData


# Preferred GPU on the shared training host. Clamped to the highest existing
# index so the notebook still runs on machines with fewer GPUs; CPU otherwise.
GPU_INDEX = 7
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(GPU_INDEX, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# relbench's local rel-stack cache, resolved from the current user's home
# directory rather than a hard-coded /home/<user>/ absolute path.
RELBENCH_STACK_CACHE = os.path.expanduser("~/.cache/relbench/stack")
dataset = StackDataset(cache_dir=RELBENCH_STACK_CACHE)
db = dataset.get_db()

Loading Database object from /home/lingze/.cache/relbench/stack/db...
Done in 10.11 seconds.


In [3]:
# Directory of the pre-materialized tensor frames and the cached column-type dict.
_data_root = "./data"
cache_path = f"{_data_root}/stack-tensor-frame/"

In [4]:
# [NOTE]: the dataset has already been materialized, so the inferred column
# types are read back from the cache instead of being re-inferred.
type_path = os.path.join(cache_path, "col_type_dict.pkl")
# `with` closes the handle (the bare open() leaked it).
# pickle.load can execute arbitrary code: only load caches you produced yourself.
with open(type_path, "rb") as type_file:
    col_type_dict = pickle.load(type_file)

# Add an all-NaN "text_compress" column to every table.
# NOTE(review): the original rationale comment was truncated ("in case ...");
# presumably the cached column types / tensor frames expect this column — confirm.
for table in db.table_dict.values():
    table.df["text_compress"] = np.nan

# Last expression so the cell actually displays the number of typed tables.
len(col_type_dict)

In [5]:
from utils.resource import get_text_embedder_cfg
# Text-embedder configuration (GloVe average word embeddings, run on CPU);
# it is passed to build_pyg_hetero_graph in the next cell.
_text_model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
_text_device = torch.device("cpu")
text_embedder_cfg = get_text_embedder_cfg(model_name=_text_model_name, device=_text_device)

  return self.fget.__get__(instance, owner)()


In [6]:
from utils.builder import build_pyg_hetero_graph
# Build the PyG heterogeneous graph (`data`) and the per-table column statistics
# later fed to FeatureEncodingPart as `node_to_col_stats`.
# NOTE(review): the trailing positional flag is passed through unchanged; given the
# pre-materialized tensor frames under `cache_path` it presumably toggles cache
# reuse — confirm against utils.builder.build_pyg_hetero_graph.
_builder_flag = True
data, col_stats_dict = build_pyg_hetero_graph(
    db, col_type_dict, text_embedder_cfg, cache_path, _builder_flag
)

-----> Materialize tags Tensor Frame
-----> Materialize postHistory Tensor Frame
-----> Materialize comments Tensor Frame
-----> Materialize badges Tensor Frame
-----> Build edge between posts and tags
-----> Materialize users Tensor Frame
-----> Materialize postLinks Tensor Frame
-----> Materialize votes Tensor Frame
-----> Materialize posts Tensor Frame


In [7]:
# add new edges:
from utils.util import load_np_dict
from torch_geometric.utils import sort_edge_index
# Attach the pre-computed extra edges. Each key is "<src_table>-<dst_table>";
# each value is an [edge_num, 2] array that is transposed to PyG's [2, edge_num].
edge_dict = load_np_dict("./edges/rel-stack-edges.npz")

for edge_name, edge_np in edge_dict.items():
    src_table, dst_table = edge_name.split('-')[:2]
    # PyG requires int64 (torch.long) edge indices; astype(int) is int32 on some
    # platforms (e.g. Windows with NumPy < 2), so request int64 explicitly.
    edge_index = torch.from_numpy(edge_np.astype(np.int64)).t()
    edge_type = (src_table, "appendix", dst_table)
    data[edge_type].edge_index = sort_edge_index(edge_index)
data.validate()

True

In [8]:
# read the pre-extracted sample
from utils.util import load_np_dict
# Pre-extracted contrastive samples, one array per table name.
_samples_path = "./samples/rel-stack-samples.npz"
sample_dict = load_np_dict(_samples_path)
sample_dict.keys()

dict_keys(['tags', 'badges', 'users', 'posts'])

In [9]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy
from relbench.modeling.utils import to_unix_time
from torch_geometric.loader import NeighborLoader
import pandas as pd

from typing import List, Dict, Any, Tuple
from relbench.base import Database

In [10]:
def neighborsample_batch(
    db: Database,
    entity_table: str,
    node_idxs: np.ndarray,
    num_neighbors: Tuple[int, ...] = (64, 64),
    graph: Any = None,
):
    """Sample one disjoint temporal neighbourhood subgraph per seed node.

    All seeds go into a single mini-batch.

    Args:
        db: relbench Database; supplies the table's time column and the
            database-wide ``max_timestamp``.
        entity_table: node type / table name that the seed rows belong to.
        node_idxs: 1-D array of row ids in ``entity_table`` used as seeds.
        num_neighbors: neighbours sampled per hop, one entry per hop (any
            sequence of ints; copied before being handed to the loader).
        graph: graph to sample from; defaults to the notebook-global ``data``.

    Returns:
        The single batch produced by ``NeighborLoader``.

    Raises:
        ValueError: if ``node_idxs`` is empty.
    """
    if graph is None:
        graph = data  # notebook-global HeteroData built earlier

    n = node_idxs.shape[0]
    if n == 0:
        # NeighborLoader would otherwise fail with an opaque batch_size=0 error.
        raise ValueError(f"node_idxs for table {entity_table!r} is empty")

    table = db.table_dict[entity_table]
    if table.time_col:
        # Each seed is anchored at its own row timestamp.
        time_values = table.df[table.time_col].loc[node_idxs.tolist()]
    else:
        # Tables without a time column: every seed uses the database's latest timestamp.
        time_values = pd.Series([db.max_timestamp] * n)
    input_time = torch.from_numpy(to_unix_time(time_values))

    loader = NeighborLoader(
        graph,
        num_neighbors=list(num_neighbors),
        input_nodes=(entity_table, torch.from_numpy(node_idxs)),
        time_attr="time",
        input_time=input_time,
        batch_size=n,  # everything in one batch
        temporal_strategy="uniform",
        shuffle=False,
        disjoint=True,  # one independent subgraph per seed
        num_workers=0,
        persistent_workers=False,
    )
    return next(iter(loader))
    

In [11]:
# Construct the bottom model: feature, node and temporal encoders wrapped by
# CompositeModel. Modules are created in the same order as before so parameter
# initialisation consumes the RNG identically.
channels = 128
_shared_dropout = 0.4

# HeteroTemporalEncoder is built only for node types that carry a "time" attribute.
timed_node_types = [nt for nt in data.node_types if "time" in data[nt]]
temporal_encoder = HeteroTemporalEncoder(node_types=timed_node_types, channels=channels)

feat_encoder = FeatureEncodingPart(data=data, node_to_col_stats=col_stats_dict, channels=channels)

node_encoder = NodeRepresentationPart(
    data=data, channels=channels, num_layers=1,
    normalization="layer_norm", dropout_prob=_shared_dropout,
)

net = CompositeModel(
    data=data, channels=channels, out_channels=1,
    dropout=_shared_dropout, aggr="sum", norm="batch_norm", num_layer=2,
    feature_encoder=feat_encoder, node_encoder=node_encoder,
    temporal_encoder=temporal_encoder,
)

In [12]:
from model.utils import InfoNCE
import time
import random
# Seed every RNG the training loop draws from (python `random` for table order,
# numpy for sample shuffling / positive & negative choice, torch for weight init
# and negative indices), so runs are repeatable.
# NOTE(review): NeighborLoader's native sampler may keep its own RNG; full
# determinism of the sampled subgraphs is not guaranteed — confirm if needed.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

lr = 0.0005
negative_sample_pool_size = 512  # max negatives drawn per batch, shared by its anchors
temprature = 0.01                # InfoNCE temperature (misspelled name kept as-is)
epoches = 300
batch_size = 128
max_steps_in_epoch = 10          # batches used per table per epoch
negative_num = 20                # negatives per anchor, taken from the shared pool

net.reset_parameters()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss_fn = InfoNCE(temperature=temprature, negative_mode='paired', reduction='mean')

In [None]:
def _sample_positive_nodes(positive_pool: np.ndarray) -> np.ndarray:
    """Pick one positive per anchor, uniformly among its row's entries that are not -1.

    Args:
        positive_pool: 2-D array with one row of candidate positive ids per
            anchor; -1 marks padding.

    Returns:
        1-D array with one chosen positive id per row.

    Raises:
        ValueError: if a row has no valid (non -1) candidate.
    """
    chosen = []
    for row_idx, row in enumerate(positive_pool):
        valid = row[row != -1]
        if valid.size == 0:
            raise ValueError(
                f"row {row_idx} of the batch has no valid positive (all entries are -1)")
        chosen.append(np.random.choice(valid, 1)[0])
    return np.array(chosen)


def _sample_negative_nodes(num_rows: int, excluded: np.ndarray, pool_size: int) -> np.ndarray:
    """Draw up to `pool_size` row ids (with replacement) from [0, num_rows) minus `excluded`.

    A boolean mask replaces the per-batch `set(range(num_rows))`, which cost
    O(num_rows) interpreted Python work on every step. Ids in `excluded` outside
    [0, num_rows) are ignored, matching the previous set-difference semantics.
    """
    excluded = np.asarray(excluded, dtype=np.int64)
    in_range = excluded[(excluded >= 0) & (excluded < num_rows)]
    keep = np.ones(num_rows, dtype=bool)
    keep[in_range] = False
    candidates = np.flatnonzero(keep)
    sample_size = min(pool_size, candidates.size)
    return np.random.choice(candidates, size=sample_size, replace=True)


net.to(device)
best_loss = math.inf
best_state = None
patience = 0
early_stop = 20  # epochs without improvement before stopping
n_tables = len(sample_dict)
tables = list(sample_dict.keys())
for epoch in range(1, epoches + 1):
    net.train()
    print("*"*30 + f"<Epoch: {epoch:02d}>" + "*"*30)
    random.shuffle(tables)
    table_losses = []
    for sample_table in tables:
        sample_np = sample_dict[sample_table]
        # Column 0 is the anchor id; remaining columns are its candidate
        # positives, padded with -1.
        shuffle_sample_np = sample_np[np.random.permutation(len(sample_np))]
        anchor_nodes_np = shuffle_sample_np[:, 0]
        positive_pool_np = shuffle_sample_np[:, 1:]
        n = sample_np.shape[0]
        m = len(db.table_dict[sample_table].df)
        # Only the first `max_steps_in_epoch` batches are used per table per
        # epoch; bounding the range (instead of breaking) keeps tqdm's total right.
        n_used = min(n, max_steps_in_epoch * batch_size)

        loss_accum = count_accum = 0
        now = time.time()
        for batch_idx in tqdm(range(0, n_used, batch_size), leave=False):
            batch = slice(batch_idx, batch_idx + batch_size)
            anchor_nodes = anchor_nodes_np[batch]
            positive_nodes = _sample_positive_nodes(positive_pool_np[batch])
            B = positive_nodes.size

            # Negatives: one pool of up to `negative_sample_pool_size` rows that
            # excludes this batch's anchors and positives; each anchor then gets
            # `negative_num` of them (positive:negative ratio 1:negative_num).
            negative_nodes = _sample_negative_nodes(
                m, np.concatenate([anchor_nodes, positive_nodes]), negative_sample_pool_size)
            sample_size = negative_nodes.size

            anchor_nodes_batch = neighborsample_batch(db, sample_table, anchor_nodes).to(device)
            positive_nodes_batch = neighborsample_batch(db, sample_table, positive_nodes).to(device)
            negative_nodes_batch = neighborsample_batch(db, sample_table, negative_nodes).to(device)

            optimizer.zero_grad()

            # Seed nodes come first in each sampled batch, so slicing keeps the seeds.
            anchor_nodes_embedding = net.get_node_embedding(
                anchor_nodes_batch, sample_table)[sample_table][:B]
            positive_nodes_embedding = net.get_node_embedding(
                positive_nodes_batch, sample_table)[sample_table][:B]
            negative_pool_embedding = net.get_node_embedding(
                negative_nodes_batch, sample_table)[sample_table][:sample_size]
            # [B, D]

            negative_indices = torch.stack(
                [torch.randperm(sample_size)[:negative_num] for _ in range(B)]).to(device)
            negative_nodes_embedding = negative_pool_embedding[negative_indices]
            # [B, negative_num, D]

            loss = loss_fn(anchor_nodes_embedding,
                           positive_nodes_embedding, negative_nodes_embedding)
            loss.backward()
            optimizer.step()
            loss_accum += loss.detach().item() * B
            count_accum += B

        end = time.time()
        if count_accum == 0:
            # Empty sample table: nothing to average (previously ZeroDivisionError).
            print(f"====> In {sample_table}, no samples, skipped")
            continue
        train_loss = loss_accum / count_accum
        table_losses.append(train_loss)
        mins, secs = divmod(end - now, 60)
        print(
            f"====> In {sample_table}, Train loss: {train_loss} Count accum :{count_accum}, Cost Time {mins:.0f}m {secs:.0f}s")

    if not table_losses:
        raise RuntimeError("no table produced a training loss; check sample_dict")
    ave_loss = sum(table_losses) / len(table_losses)
    if ave_loss < best_loss:
        best_loss = ave_loss
        best_state = copy.deepcopy(net.state_dict())
        print(f"Save best model at epoch {epoch} with loss {ave_loss}")
        patience = 0
    else:
        patience += 1
        if patience >= early_stop:
            print(f"Early stopping at epoch {epoch}")
            break

In [None]:
# Snapshot of the weights as they stand after the training loop (the lowest-loss
# snapshot is kept separately in `best_state`).
_final_weights = net.state_dict()
pre_trained_state = copy.deepcopy(_final_weights)

In [None]:
import torch
import json
# torch.save does not create missing parent directories.
# NOTE(review): "ep100" in the filename does not match `epoches = 300` in the
# training config (and early stopping may end sooner) — confirm the intended name.
os.makedirs("./static", exist_ok=True)
torch.save(pre_trained_state, "./static/rel-stack-pre-trained-channel128-ep100.pth")

In [None]:
import torch
import json
# best_state stays None if the training loop never ran; refuse to save that silently.
if best_state is None:
    raise RuntimeError("best_state is None: run the training cell before saving it")
# torch.save does not create missing parent directories.
os.makedirs("./static", exist_ok=True)
torch.save(best_state, "./static/rel-stack-pre-trained-channel128-ep100-best-state.pth")