In [1]:
# Move one directory up (the output below shows the resulting working directory).
# Later cells use repo-relative paths such as "data/rel-trial-tensor-frame" and import
# the local `model` package, so they presumably require this to be the repo root.
%cd ..

/home/lingze/embedding_fusion


In [2]:
import numpy as np
import pandas as pd
from relbench.datasets import get_dataset
from relbench.base import Table
from tqdm import tqdm
from typing import Any,Dict

import torch
import pickle
import os
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Directory holding col_type_dict.pkl and the per-table materialized frames.
cache_path = "data/rel-trial-tensor-frame"
# RelBench "rel-trial" dataset (downloaded on first use) and its relational database.
dataset_ = get_dataset(name="rel-trial", download=True)
db = dataset_.get_db()

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 7.94 seconds.


In [4]:
# NOTE: assumes the tables were already materialized once, so the cached
# column-type inference (per table: column name -> stype) exists in `cache_path`.

# pickle can execute arbitrary code when loading, so only point `cache_path`
# at a trusted, locally produced cache.
type_path = os.path.join(cache_path, "col_type_dict.pkl")
with open(type_path, "rb") as type_file:
    col_type_dict = pickle.load(type_file)

# The graph-building cell looks up every table in `col_type_dict`; fail early with
# a clear message rather than a bare KeyError there.
missing_tables = set(db.table_dict) - set(col_type_dict)
assert not missing_tables, f"col_type_dict.pkl has no entry for tables: {sorted(missing_tables)}"

# Add an all-NaN "text_compress" column to every table
# (placeholder — presumably populated by a later step; TODO confirm).
for table in db.table_dict.values():
    table.df["text_compress"] = np.nan

In [5]:
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class GloveTextEmbedding:
    """Callable text embedder backed by averaged GloVe 6B 300d word vectors.

    An instance is handed to ``TextEmbedderConfig`` as ``text_embedder``. Calling
    it with a list of strings returns the encoder's embeddings as a torch tensor.
    """

    def __init__(self, device: Optional[torch.device] = None):
        # Swap in e.g. "all-MiniLM-L12-v2" to try a transformer sentence encoder.
        model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        embeddings = self.model.encode(sentences)
        return torch.from_numpy(embeddings)

# Text columns are embedded with the GloVe encoder; batch_size is forwarded to
# torch_frame's TextEmbedderConfig.
text_embedder_cfg = TextEmbedderConfig(
    batch_size=512,
    text_embedder=GloveTextEmbedding(device=device),
)

def remove_pkey_fkey(col_to_stype: Dict[str, Any], table: Table) -> dict:
    r"""Remove pkey, fkey columns since they will not be used as input feature.

    Key columns are used to build edges, not as node features.

    Args:
        col_to_stype: Mapping from column name to stype. Modified in place.
        table: Table whose ``pkey_col`` (may be ``None``) and the keys of
            ``fkey_col_to_pkey_table`` name the key columns to drop.

    Returns:
        The same ``col_to_stype`` object after removal, as the ``-> dict``
        annotation promises.
    """
    key_cols = list(table.fkey_col_to_pkey_table)
    if table.pkey_col is not None:
        key_cols.append(table.pkey_col)
    for col in key_cols:
        # Key columns missing from the mapping are simply ignored.
        col_to_stype.pop(col, None)
    return col_to_stype

def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds).

    Accepts timezone-naive ``datetime64`` series in ``s``, ``ms``, ``us`` or ``ns``
    resolution (pandas 2 can produce non-nanosecond units, e.g. from parquet).

    Args:
        ser: Timezone-naive ``datetime64`` series.

    Returns:
        ``int64`` array of seconds since the epoch. Uses floor division, so
        sub-second values before 1970 round toward negative infinity.

    Raises:
        AssertionError: If ``ser`` is not a supported ``datetime64`` dtype.
    """
    ticks_per_second = {"s": 1, "ms": 10**3, "us": 10**6, "ns": 10**9}
    assert isinstance(ser.dtype, np.dtype) and ser.dtype.kind == "M", ser.dtype
    unit, step = np.datetime_data(ser.dtype)
    assert unit in ticks_per_second and step == 1, ser.dtype
    unix_time = ser.astype("int64").values
    # Divide out of place: under pandas copy-on-write `.values` can be read-only.
    return unix_time // ticks_per_second[unit]

  return self.fget.__get__(instance, owner)()


In [6]:
# Build a heterogeneous graph with one node type per table (features stored as
# TensorFrames). Every foreign key adds "f2p_*" edges and "rev_f2p_*" reverse edges.

# Materialized frames live next to col_type_dict.pkl.
cache_dir = cache_path
os.makedirs(cache_dir, exist_ok=True)

data = HeteroData()
col_stats_dict = {}
for table_name, table in db.table_dict.items():
    df = table.df
    # Foreign-key values are used directly as node indices below, so the primary
    # key must equal the row position 0..len(df)-1.
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()

    # Copy so the cached col_type_dict is not mutated (keeps re-runs idempotent).
    col_to_stype = dict(col_type_dict[table_name])
    # Key columns only define edges; they are not input features.
    remove_pkey_fkey(col_to_stype, table)

    if len(col_to_stype) == 0:
        # e.g. a relationship table that only contains pkey and fkey columns
        raise KeyError(f"{table_name} has no column to build graph")

    path = os.path.join(cache_dir, f"{table_name}.pt")

    print(f"-----> Materialize {table_name} Tensor Frame")
    dataset = Dataset(
        df=df,
        col_to_stype=col_to_stype,
        col_to_text_embedder_cfg=text_embedder_cfg,
    ).materialize(path=path)

    data[table_name].tf = dataset.tensor_frame
    col_stats_dict[table_name] = dataset.col_stats

    # Node timestamps in UNIX seconds; loaders sample with time_attr="time".
    if table.time_col is not None:
        data[table_name].time = torch.from_numpy(to_unix_time(df[table.time_col]))

    for fkey_col_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
        pkey_ser = df[fkey_col_name]
        # Drop dangling (missing) foreign keys, keeping the source row index of the rest.
        mask = pkey_ser.notna()
        pkey_index = torch.from_numpy(pkey_ser[mask].astype(int).values)
        fkey_index = torch.arange(len(pkey_ser))[torch.from_numpy(mask.values)]

        # fkey -> pkey edges
        edge_index = torch.stack([fkey_index, pkey_index], dim=0)
        edge_type = (table_name, f"f2p_{fkey_col_name}", pkey_table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

        # pkey -> fkey edges; the "rev_" prefix lets the PyG loader recognize reverse edges.
        edge_index = torch.stack([pkey_index, fkey_index], dim=0)
        edge_type = (pkey_table_name, f"rev_f2p_{fkey_col_name}", table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

data.validate()

-----> Materialize interventions Tensor Frame
-----> Materialize interventions_studies Tensor Frame
-----> Materialize facilities_studies Tensor Frame
-----> Materialize sponsors Tensor Frame
-----> Materialize eligibilities Tensor Frame
-----> Materialize reported_event_totals Tensor Frame
-----> Materialize designs Tensor Frame
-----> Materialize conditions_studies Tensor Frame
-----> Materialize drop_withdrawals Tensor Frame
-----> Materialize studies Tensor Frame
-----> Materialize outcome_analyses Tensor Frame
-----> Materialize sponsors_studies Tensor Frame
-----> Materialize outcomes Tensor Frame
-----> Materialize conditions Tensor Frame
-----> Materialize facilities Tensor Frame


True

In [7]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy


In [8]:
# Link-prediction task: predict dst entities for each src entity.
task_a = get_task("rel-trial", "condition-sponsor-run", download=True)
src_table, dst_table = task_a.src_entity_table, task_a.dst_entity_table
print(f"src table: {src_table}, dst table: {dst_table}")

src table: conditions, dst table: sponsors


In [9]:
from relbench.modeling.graph import get_link_train_table_input
from relbench.modeling.loader import LinkNeighborLoader

In [10]:
# Neighbors sampled per hop (two hops) by every loader.
num_neighbors = [128, 64]
# When True, training scores each src against every negative in the batch
# (x_src @ x_neg_dst.T); otherwise it uses one negative per src.
share_same_time = True

# Seed the global torch/numpy RNGs before the loaders and model are created, so
# parameter initialization and dropout are repeatable. Neighbor sampling may use
# backend RNGs that this does not fully control — not verified here.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [11]:
train_table_input = get_link_train_table_input(task_a.get_table("train"), task_a)
# DataLoader-level shuffling is turned off exactly when share_same_time is on,
# presumably because LinkNeighborLoader then batches seeds itself — confirm in relbench.
train_loader = LinkNeighborLoader(
    data=data,
    num_neighbors=num_neighbors,
    time_attr="time",
    src_nodes=train_table_input.src_nodes,
    dst_nodes=train_table_input.dst_nodes,
    num_dst_nodes=train_table_input.num_dst_nodes,
    src_time=train_table_input.src_time,
    share_same_time=share_same_time,
    shuffle=not share_same_time,
    temporal_strategy="uniform",
    batch_size=512,
    num_workers=0,
    persistent_workers=False,
)

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...


Done in 7.48 seconds.


  dst_node_indices = sparse_coo.to_sparse_csr()


In [12]:
from typing import Tuple
def _make_eval_loader(input_nodes, seed_time: int, num_seeds: int) -> NeighborLoader:
    """Build a non-shuffling NeighborLoader whose seeds all share ``seed_time``.

    Args:
        input_nodes: Passed to ``NeighborLoader(input_nodes=...)``, either a node
            type name or a ``(node_type, index_tensor)`` pair.
        seed_time: UNIX seconds assigned to every seed node.
        num_seeds: Number of seed nodes (length of the ``input_time`` tensor).
    """
    return NeighborLoader(
        data,
        num_neighbors=num_neighbors,
        time_attr="time",
        input_nodes=input_nodes,
        input_time=torch.full(size=(num_seeds,), fill_value=seed_time, dtype=torch.long),
        batch_size=512,
        # Order must be preserved: prediction rows are matched to node ids by position.
        shuffle=False,
        num_workers=0,
    )


# For each evaluation split, build one loader over the split's src entities and one
# over all dst entities. Both are anchored at the split's cutoff timestamp.
eval_loaders_dict: Dict[str, Tuple[NeighborLoader, NeighborLoader]] = {}
for split in ["val", "test"]:
    timestamp = dataset_.val_timestamp if split == "val" else dataset_.test_timestamp
    seed_time = int(timestamp.timestamp())
    src_node_indices = torch.from_numpy(
        task_a.get_table(split).df[task_a.src_entity_col].values
    )
    src_loader = _make_eval_loader(
        (task_a.src_entity_table, src_node_indices), seed_time, len(src_node_indices)
    )
    dst_loader = _make_eval_loader(task_a.dst_entity_table, seed_time, task_a.num_dst_nodes)
    eval_loaders_dict[split] = (src_loader, dst_loader)

In [13]:
@torch.no_grad()
def test(model: torch.nn.Module, task: BaseTask, src_loader: NeighborLoader, dst_loader: NeighborLoader) -> np.ndarray:
    """Rank dst entities for every src entity and return the top-k dst indices.

    Embeds every dst node first. Each src batch is then scored against all dst
    embeddings by dot product, and the ``task.eval_k`` highest scores are kept.

    Args:
        model: Called as ``model(batch, node_type)``; assumed to return one
            embedding row per seed node (as in the training loop).
        task: Link task providing ``src_entity_table``, ``dst_entity_table`` and ``eval_k``.
        src_loader: Loader over src seed nodes.
        dst_loader: Loader over all dst nodes. It must iterate in node order (the
            eval loaders use ``shuffle=False``), because row i is treated as dst node i.

    Returns:
        Integer array of shape ``(num_src, eval_k)`` with predicted dst indices,
        highest score first.
    """
    model.eval()

    dst_emb = torch.cat(
        [
            model(batch.to(device), task.dst_entity_table).detach()
            for batch in tqdm(dst_loader, leave=False)
        ],
        dim=0,
    )

    pred_index_mat_list: list[Tensor] = []
    for batch in tqdm(src_loader, leave=False):
        batch = batch.to(device)
        emb = model(batch, task.src_entity_table)
        scores = emb @ dst_emb.t()
        _, pred_index_mat = torch.topk(scores, k=task.eval_k, dim=1)
        pred_index_mat_list.append(pred_index_mat.cpu())
    return torch.cat(pred_index_mat_list, dim=0).numpy()

In [14]:
# Construct the model. The sub-modules are created in the same order as before,
# so parameter initialization consumes the RNG identically.
channels = 128

# Temporal encoding only for node types that carry a "time" attribute.
timed_node_types = [nt for nt in data.node_types if "time" in data[nt]]
temporal_encoder = HeteroTemporalEncoder(node_types=timed_node_types, channels=channels)

# Presumably encodes each table's TensorFrame into `channels`-dim node features.
feat_encoder = FeatureEncodingPart(data=data, node_to_col_stats=col_stats_dict, channels=channels)

node_encoder = NodeRepresentationPart(
    data=data, channels=channels, num_layers=1,
    normalization="layer_norm", dropout_prob=0.2,
)

# Full model wiring the encoders above together (num_layer=2).
net = CompositeModel(
    data=data, channels=channels, out_channels=1,
    dropout=0.2, aggr="sum", norm="layer_norm", num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder,
)

In [15]:
# Optimization and early-stopping settings.
lr = 0.001
epoches = 80                         # maximum number of training epochs
early_stop = 10                      # stop after this many epochs without a val improvement
max_round_epoch = 200                # cap on training batches per epoch
tune_metric = "link_prediction_map"  # validation metric used for model selection
higher_is_better = True              # MAP: larger is better
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

In [16]:
def _train_one_epoch(max_steps: int) -> float:
    """Run at most ``max_steps`` training batches and return the mean loss per src example.

    Returns NaN (after printing a hint) when the loader yielded no batch.
    """
    net.train()
    loss_accum = count_accum = 0
    for step, batch in enumerate(tqdm(train_loader, leave=False)):
        src_batch, batch_pos_dst, batch_neg_dst = batch
        src_batch = src_batch.to(device)
        batch_pos_dst = batch_pos_dst.to(device)
        batch_neg_dst = batch_neg_dst.to(device)

        x_src = net(src_batch, src_table)
        x_pos_dst = net(batch_pos_dst, dst_table)
        x_neg_dst = net(batch_neg_dst, dst_table)

        pos_score = torch.sum(x_src * x_pos_dst, dim=1)
        if share_same_time:
            # Score every src against every negative: (B, 1) - (B, N_neg) broadcasts.
            neg_score = x_src @ x_neg_dst.t()
            pos_score = pos_score.view(-1, 1)
        else:
            neg_score = torch.sum(x_src * x_neg_dst, dim=1)

        optimizer.zero_grad()
        diff_score = pos_score - neg_score
        # Pairwise ranking loss: softplus(-d) == -log(sigmoid(d)).
        loss = torch.nn.functional.softplus(-diff_score).mean()
        loss.backward()
        optimizer.step()

        loss_accum += float(loss) * x_src.size(0)
        count_accum += x_src.size(0)

        # Stop after exactly `max_steps` batches (previously one extra ran).
        if step + 1 >= max_steps:
            break

    if count_accum == 0:
        print(
            f"Did not sample a single '{task_a.dst_entity_table}' "
            f"node in any mini-batch. Try to increase the number "
            f"of layers/hops and re-try. If you run into memory "
            f"issues with deeper nets, decrease the batch size."
        )
        return float("nan")
    return loss_accum / count_accum


def _is_improvement(current: float, best: float) -> bool:
    """True when ``current`` beats ``best`` in the direction set by ``higher_is_better``."""
    return current > best if higher_is_better else current < best


state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
patience = 0
val_src_loader, val_dst_loader = eval_loaders_dict["val"]
test_src_loader, test_dst_loader = eval_loaders_dict["test"]
val_table = task_a.get_table("val")
# `epoches + 1` so that exactly `epoches` epochs can run (range end is exclusive).
for epoch in range(1, epoches + 1):
    train_loss = _train_one_epoch(max_round_epoch)

    val_pred = test(net, task_a, val_src_loader, val_dst_loader)
    val_metrics = task_a.evaluate(val_pred, val_table)
    test_pred = test(net, task_a, test_src_loader, test_dst_loader)
    test_metrics = task_a.evaluate(test_pred)

    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics[tune_metric]}")
    print(f"Test metrics: {test_metrics[tune_metric]}")

    # Model selection uses the validation metric only; keep a copy of the best weights.
    if _is_improvement(val_metrics[tune_metric], best_val_metric):
        best_val_metric = val_metrics[tune_metric]
        best_epoch = epoch
        patience = 0
        state_dict = copy.deepcopy(net.state_dict())
    else:
        patience += 1
        if patience >= early_stop:
            break


  0%|          | 0/62 [00:00<?, ?it/s]

                                                 

Epoch: 01, Train loss: 0.36909774378422766, Val metrics: 0.01710755570047977
Test metrics: 0.014183734405914188


                                                 

Epoch: 02, Train loss: 0.228062704205513, Val metrics: 0.01880277286221944
Test metrics: 0.014806334609142095


                                                 

Epoch: 03, Train loss: 0.21489714590772505, Val metrics: 0.008363026907596825
Test metrics: 0.005988504217113843


                                                 

Epoch: 04, Train loss: 0.21768812931353046, Val metrics: 0.006104222519092888
Test metrics: 0.00626817255827951


                                                 

Epoch: 05, Train loss: 0.21134365734554106, Val metrics: 0.0066324124289044994
Test metrics: 0.006363088164558753


                                                 

Epoch: 06, Train loss: 0.20842920604252047, Val metrics: 0.004541106750066499
Test metrics: 0.004658906032569134


                                                 

Epoch: 07, Train loss: 0.20573318941939261, Val metrics: 0.005020771866152979
Test metrics: 0.004205624704733439


                                                 

Epoch: 08, Train loss: 0.20491278471965943, Val metrics: 0.003559088824986631
Test metrics: 0.0039459401758867


                                                

KeyboardInterrupt: 

In [None]:
# Restore the checkpoint with the best validation metric and report its test score.
if state_dict is None:
    raise RuntimeError(
        "No checkpoint recorded: no epoch improved the validation metric "
        "(e.g. training was interrupted before the first evaluation). "
        "Re-run the training cell first."
    )
net.load_state_dict(state_dict)
test_pred = test(net, task_a, *eval_loaders_dict["test"])
test_metrics = task_a.evaluate(test_pred)
print(f"Best epoch: {best_epoch}, Test metrics: {test_metrics[tune_metric]}")

                                                 

Best epoch: 5, Test metrics: 0.004980636193871488


