In [1]:
# Switch the working directory to the parent folder (presumably the project root,
# so that the project-local `utils` package imported below resolves — confirm).
# NOTE: the path is relative, so re-running this cell moves up one more level each time.
%cd ..

/home/lingze/embedding_fusion


In [2]:
import numpy as np
import pandas as pd
from relbench.datasets import get_dataset
from relbench.base import Table
from tqdm import tqdm
from typing import Any,Dict

In [4]:
import torch
import os
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index
from utils.data import StackDataset

In [5]:
# Load the RelBench dataset by name (with download=True) and obtain its database object.
# NOTE(review): the outputs saved in this notebook mention a `stack` cache path and
# StackExchange-style tables (tags, postHistory, ...), which does not match 'rel-trial';
# they look like leftovers from an earlier run — restart and re-run to refresh them.
dataset = get_dataset('rel-trial', download=True)
db = dataset.get_db()

Loading Database object from /home/lingze/.cache/relbench/stack/db...
Done in 9.84 seconds.


In [6]:
from utils.preprocess import infer_type_in_db
# Infer a semantic type for every column. The cells below consume the result as a
# nested mapping {table_name: {column_name: stype}}.
# NOTE(review): the meaning of the positional `True` is not visible here (the output
# suggests it enables logging of the inference rules) — confirm in utils.preprocess.
col_type_dict = infer_type_in_db(db, True)

[rule 0]: tags Inferred Id from numerical as categorical
[rule 0]: postHistory Inferred Id from numerical as categorical
[rule 0]: postHistory Inferred PostId from numerical as categorical
[rule 0]: postHistory Inferred UserId from numerical as categorical
[rule 0]: postHistory Inferred PostHistoryTypeId from numerical as categorical
[rule 0]: postHistory Inferred ContentLicense from categorical as text_embedded
[rule 1]: postHistory Inferred ContentLicense from text_embedded as categorical
[rule 0]: postHistory Inferred RevisionGUID from text_embedded as categorical
[rule 0]: comments Inferred Id from numerical as categorical
[rule 0]: comments Inferred PostId from numerical as categorical
[rule 0]: comments Inferred UserId from numerical as categorical
[rule 1]: comments Inferred Score from numerical as categorical
[rule 0]: comments Inferred ContentLicense from categorical as text_embedded
[rule 1]: comments Inferred ContentLicense from text_embedded as categorical
[rule 0]: badges 

In [7]:
# Profile every text-embedded column: how many distinct values and how many
# missing values it has, both relative to the row count of its table.
for table, type_dict in col_type_dict.items():
    table_df = db.table_dict[table].df
    n_rows = len(table_df)
    # The loop variable must not be called `stype`: that would rebind the imported
    # `torch_frame.stype` for every later cell and make the comparison below depend
    # on looking up one enum member through another.
    for col_name, col_stype in type_dict.items():
        if col_stype != stype.text_embedded:
            continue
        column = table_df[col_name]
        n_unique = len(column.unique())  # NaN, if present, counts as one distinct value
        n_missing = column.isnull().sum()
        print(f"{table}.{col_name}: {col_stype}, unique values: {n_unique}/{n_rows} Nan Value: {n_missing}/{n_rows}")

tags.TagName: text_embedded, unique values: 1597/1597 Nan Value: 0/1597
postHistory.UserDisplayName: text_embedded, unique values: 3061/1175368 Nan Value: 1150131/1175368
postHistory.Text: text_embedded, unique values: 960746/1175368 Nan Value: 114938/1175368
postHistory.Comment: text_embedded, unique values: 125619/1175368 Nan Value: 714309/1175368
comments.UserDisplayName: text_embedded, unique values: 2222/623967 Nan Value: 612281/623967
comments.Text: text_embedded, unique values: 621044/623967 Nan Value: 0/623967
badges.Name: text_embedded, unique values: 327/463463 Nan Value: 0/463463
users.DisplayName: text_embedded, unique values: 218050/255360 Nan Value: 19/255360
users.Location: text_embedded, unique values: 11514/255360 Nan Value: 184185/255360
users.AboutMe: text_embedded, unique values: 47461/255360 Nan Value: 205676/255360
posts.OwnerDisplayName: text_embedded, unique values: 4531/333893 Nan Value: 325346/333893
posts.Title: text_embedded, unique values: 163648/333893 Nan

In [8]:
def remove_pkey_fkey(col_to_stype: Dict[str, Any], table: "Table") -> dict:
    r"""Remove pkey, fkey columns since they will not be used as input feature.

    The mapping is modified in place and also returned, so both
    ``remove_pkey_fkey(d, t)`` and ``d = remove_pkey_fkey(d, t)`` work.

    Args:
        col_to_stype: Mapping from column name to its semantic type.
        table: Object exposing ``pkey_col`` (a column name or ``None``) and
            ``fkey_col_to_pkey_table`` (a mapping keyed by foreign-key column name).

    Returns:
        The same ``col_to_stype`` object, without the key columns.
    """
    key_cols = list(table.fkey_col_to_pkey_table.keys())
    if table.pkey_col is not None:
        key_cols.append(table.pkey_col)
    for key_col in key_cols:
        # Key columns that were never typed are simply skipped.
        col_to_stype.pop(key_col, None)
    return col_to_stype

def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds).

    Args:
        ser: Timezone-naive ``datetime64`` series with second, millisecond,
            microsecond or nanosecond resolution.

    Returns:
        A newly allocated ``int64`` array holding whole seconds since the epoch
        (sub-second parts are floored).

    Raises:
        AssertionError: If ``ser`` does not have a supported ``datetime64`` dtype.
    """
    ticks_per_second = {"s": 1, "ms": 10**3, "us": 10**6, "ns": 10**9}
    dtype = ser.dtype
    assert isinstance(dtype, np.dtype) and dtype.kind == "M", (
        f"expected a timezone-naive datetime64 series, got dtype {dtype}"
    )
    unit, _ = np.datetime_data(dtype)
    assert unit in ticks_per_second, f"unsupported datetime64 resolution: {unit}"

    ticks = ser.astype("int64").to_numpy()
    # Divide out-of-place: the array handed back by pandas can be read-only
    # (e.g. with Copy-on-Write enabled), in which case an in-place `//=` fails.
    return ticks // ticks_per_second[unit]

In [9]:
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
# Device handed to the text embedder below: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GloveTextEmbedding:
    """Callable sentence embedder wrapping a pretrained ``SentenceTransformer``.

    Note that, despite the class name, the model loaded is ``all-MiniLM-L12-v2``.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """Load the ``all-MiniLM-L12-v2`` model, passing ``device`` through to it."""
        self.model = SentenceTransformer("all-MiniLM-L12-v2", device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` with the model and return the result as a tensor."""
        embeddings = self.model.encode(sentences)
        return torch.from_numpy(embeddings)

# Build the embedder once, then wrap it in the torch_frame config together with
# the batch size (512) used for text embedding.
sentence_embedder = GloveTextEmbedding(device=device)
text_embedder_cfg = TextEmbedderConfig(text_embedder=sentence_embedder, batch_size=512)

In [8]:
# Preprocess each table by merging its short text columns into a single column:
#
#        | text_col_1 | text_col_2 | text_col_3 |
# row 1  |     A      |     B      |     C      |
#
# becomes one new text column holding
#   "text_col_1 is A, text_col_2 is B, text_col_3 is C"
#
# Only the merged column then has to be converted to a vector; the original
# text columns are dropped to save memory and computation.

# Text columns whose average word count reaches this threshold are treated as
# long text and kept as separate columns (the original note describes 128 as
# half of a 256-token maximum length).
MAX_AVG_WORDS_TO_COMPRESS = 128


def row_to_text(row):
    """Render the non-null cells of ``row`` as "col is value, ..." (None if all are null)."""
    present = row.dropna()
    if present.empty:
        return None
    return ", ".join(f"{key} is {value}" for key, value in present.items())


for table_name, type_dict in col_type_dict.items():
    df = db.table_dict[table_name].df

    # Collect the text columns. The comprehension variable is not called `stype`
    # so that it does not shadow the imported `torch_frame.stype`.
    text_cols = [col for col, col_stype in type_dict.items() if col_stype == stype.text_embedded]

    # Long text columns stay on their own; only short ones are merged.
    compress_cols = []
    for col in text_cols:
        avg_word_count = df[col].dropna().apply(lambda x: len(str(x).split())).mean()
        if avg_word_count < MAX_AVG_WORDS_TO_COMPRESS:
            compress_cols.append(col)

    if len(compress_cols) <= 1:
        # With at most one short text column there is nothing to merge.
        continue

    print(f"----> Compressing {table_name} text columns: {compress_cols}")

    text_list = df[compress_cols].apply(row_to_text, axis=1).tolist()

    # Drop the merged columns in place so the table held by `db` itself is updated.
    df.drop(columns=compress_cols, inplace=True)
    df["text_compress"] = text_list

    # Keep the type dict in sync with the new column layout.
    for col in compress_cols:
        type_dict.pop(col)
    type_dict["text_compress"] = stype.text_embedded

----> Compressing postHistory text columns: ['UserDisplayName', 'Text', 'Comment']
----> Compressing comments text columns: ['UserDisplayName', 'Text']
----> Compressing users text columns: ['DisplayName', 'Location', 'AboutMe']
----> Compressing posts text columns: ['OwnerDisplayName', 'Title']


In [9]:
# Print the text-embedded columns again, with their distinct-value and
# missing-value counts relative to the row count of each table.
for table, type_dict in col_type_dict.items():
    table_df = db.table_dict[table].df
    n_rows = len(table_df)
    # The loop variable must not be called `stype`: that would rebind the imported
    # `torch_frame.stype` for every later cell and make the comparison below depend
    # on looking up one enum member through another.
    for col_name, col_stype in type_dict.items():
        if col_stype != stype.text_embedded:
            continue
        column = table_df[col_name]
        n_unique = len(column.unique())  # NaN, if present, counts as one distinct value
        n_missing = column.isnull().sum()
        print(f"{table}.{col_name}: {col_stype} --> unique values: {n_unique}/{n_rows} || Nan Value: {n_missing}/{n_rows}")
            

tags.TagName: text_embedded --> unique values: 1597/1597 || Nan Value: 0/1597
postHistory.text_compress: text_embedded --> unique values: 1040950/1175368 || Nan Value: 48747/1175368
comments.text_compress: text_embedded --> unique values: 621070/623967 || Nan Value: 0/623967
badges.Name: text_embedded --> unique values: 327/463463 || Nan Value: 0/463463
users.text_compress: text_embedded --> unique values: 230262/255360 || Nan Value: 9/255360
posts.Body: text_embedded --> unique values: 333357/333893 || Nan Value: 493/333893
posts.text_compress: text_embedded --> unique values: 164938/333893 || Nan Value: 166739/333893


In [None]:
# Start building the graph: for every table in `db`, materialize a torch_frame
# Dataset, store its tensor frame (and, when the table has a time column, its UNIX
# timestamps) on the matching node type of `data`, and collect the column statistics.
# NOTE(review): only per-table attributes are set in this cell; no edges are added here.
cache_dir = "./data/rel-trial-tensor-frame"
# Alternative cache location for the stack dataset (disabled):
# cache_dir = "./data/stack-tensor-frame"
if cache_dir is not None:
    os.makedirs(cache_dir, exist_ok=True)
data = HeteroData()
col_stats_dict = {}
for table_name, table in db.table_dict.items():
    df = table.df
    # (Important for foreign-key values) Ensure the pkey is consecutive, i.e. equal
    # to 0..len(df)-1 in row order.
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()
    
    # This is the very dict stored in col_type_dict, not a copy.
    col_to_stype = col_type_dict[table_name]
    
    # Remove pkey, fkey. The helper pops the keys in place, so col_type_dict itself
    # no longer lists the key columns once this cell has run.
    remove_pkey_fkey(col_to_stype, table)
    
    if len(col_to_stype) == 0:
        # For example, a relationship table which only contains pkey and fkey
        raise KeyError(f"{table_name} has no column to build graph")
    
    # Per-table file path handed to materialize(); None when no cache_dir is set.
    # NOTE(review): presumably an existing file at this path is reused instead of
    # recomputing, so stale caches may need deleting after preprocessing changes — confirm.
    path = (
            None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")
    )
    
    print(f"-----> Materialize {table_name} Tensor Frame")
    # NOTE: this rebinds `dataset`, which earlier held the object returned by get_dataset().
    dataset = Dataset(
        df = df,
        col_to_stype=col_to_stype,
        col_to_text_embedder_cfg=text_embedder_cfg,
    ).materialize(path=path)
    
    data[table_name].tf = dataset.tensor_frame
    col_stats_dict[table_name] = dataset.col_stats
    
    # Add time attribute (UNIX seconds) for tables that declare a time column
    if table.time_col is not None:
        data[table_name].time = torch.from_numpy(
            to_unix_time(df[table.time_col])
        )

-----> Materialize interventions Tensor Frame
-----> Materialize interventions_studies Tensor Frame
-----> Materialize facilities_studies Tensor Frame
-----> Materialize sponsors Tensor Frame
-----> Materialize eligibilities Tensor Frame
-----> Materialize reported_event_totals Tensor Frame
-----> Materialize designs Tensor Frame
-----> Materialize conditions_studies Tensor Frame
-----> Materialize drop_withdrawals Tensor Frame
-----> Materialize studies Tensor Frame
-----> Materialize outcome_analyses Tensor Frame
-----> Materialize sponsors_studies Tensor Frame
-----> Materialize outcomes Tensor Frame
-----> Materialize conditions Tensor Frame
-----> Materialize facilities Tensor Frame


True

In [10]:
# Persist the column-type mapping as a pickle inside the cache directory.
import pickle

col_type_path = f"{cache_dir}/col_type_dict.pkl"
with open(col_type_path, "wb") as pkl_file:
    pickle.dump(col_type_dict, pkl_file)