In [1]:
import os
from pathlib import Path

# Run from the repository root (presumably one level above this notebook) so that
# the project-local `utils` package imported further down resolves.
# Guarded so that re-running this cell does not keep moving up the directory tree,
# which a bare `%cd ..` would do.
if not Path("utils").is_dir():
    os.chdir("..")
print(os.getcwd())
from relbench.datasets import get_dataset
from tqdm import tqdm
import numpy as np

/home/lingze/embedding_fusion


In [2]:
import torch
from torch import Tensor
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from utils.data import StackDataset
from utils.preprocess import infer_type_in_db
from utils.tokenize import tokenize_database
from utils.builder import build_pyg_hetero_graph

In [4]:
import os

# Cache directory for the Stack Exchange dataset, resolved against the current
# user's home directory rather than a hard-coded absolute path so the notebook
# runs on other machines/accounts.
STACK_CACHE_DIR = os.path.expanduser("~/.cache/relbench/stack")
dataset = StackDataset(cache_dir=STACK_CACHE_DIR)

In [5]:
# Load the relational database behind the dataset. `db.table_dict` maps table
# name -> table object whose `.df` holds the rows (as used in the cells below).
# The captured log shows it being read back from "<cache_dir>/db".
db = dataset.get_db()

Loading Database object from /home/lingze/.cache/relbench/stack/db...
Done in 10.99 seconds.


In [6]:
# Report how many rows each table of the relational database holds.
row_counts = {name: len(tbl.df) for name, tbl in db.table_dict.items()}
for name, count in row_counts.items():
    print(f"Table {name} has {count} rows")

Table tags has 1597 rows
Table postHistory has 1175368 rows
Table comments has 623967 rows
Table badges has 463463 rows
Table postTag has 648577 rows
Table users has 255360 rows
Table postLinks has 77337 rows
Table votes has 1317876 rows
Table posts has 333893 rows


In [7]:

# Infer a torch_frame semantic type per column; the result maps
# table name -> {column name -> stype} (that is how the cells below consume it).
# NOTE(review): the meaning of the positional `True` is defined in
# utils.preprocess; the "[rule N] ... Inferred ..." lines printed by this cell
# suggest it enables logging of each inference rule — TODO confirm.
col_type_dict = infer_type_in_db(db, True)

[rule 0]: tags Inferred Id from numerical as categorical
[rule 0]: postHistory Inferred Id from numerical as categorical
[rule 0]: postHistory Inferred PostId from numerical as categorical
[rule 0]: postHistory Inferred UserId from numerical as categorical
[rule 0]: postHistory Inferred PostHistoryTypeId from numerical as categorical
[rule 0]: postHistory Inferred ContentLicense from categorical as text_embedded
[rule 1]: postHistory Inferred ContentLicense from text_embedded as categorical
[rule 0]: postHistory Inferred RevisionGUID from text_embedded as categorical
[rule 0]: comments Inferred Id from numerical as categorical
[rule 0]: comments Inferred PostId from numerical as categorical
[rule 0]: comments Inferred UserId from numerical as categorical
[rule 1]: comments Inferred Score from numerical as categorical
[rule 0]: comments Inferred ContentLicense from categorical as text_embedded
[rule 1]: comments Inferred ContentLicense from text_embedded as categorical
[rule 0]: badges 

In [8]:
# Profile every column inferred as `text_embedded`: number of distinct values
# and number of missing values, each out of the table's row count.
#
# `stype` is imported explicitly instead of reusing a column's enum member as
# the loop variable. This avoids member-from-member enum access and leaves the
# enum class (not a leftover loop value) bound to `stype` for later cells.
from torch_frame import stype

for table, type_dict in col_type_dict.items():
    for col_name, col_stype in type_dict.items():
        if col_stype != stype.text_embedded:
            continue
        # Look the table up once per column instead of three times.
        table_df = db.table_dict[table].df
        column = table_df[col_name]
        n = len(column.unique())  # NaN counts as one distinct value here
        nm = len(table_df)
        nan_num = column.isnull().sum()
        print(f"{table}.{col_name}: {col_stype}, unique values: {n}/{nm} Nan Value: {nan_num}/{nm}")

tags.TagName: text_embedded, unique values: 1597/1597 Nan Value: 0/1597
postHistory.UserDisplayName: text_embedded, unique values: 3061/1175368 Nan Value: 1150131/1175368
postHistory.Text: text_embedded, unique values: 960746/1175368 Nan Value: 114938/1175368
postHistory.Comment: text_embedded, unique values: 125619/1175368 Nan Value: 714309/1175368
comments.UserDisplayName: text_embedded, unique values: 2222/623967 Nan Value: 612281/623967
comments.Text: text_embedded, unique values: 621044/623967 Nan Value: 0/623967
badges.Name: text_embedded, unique values: 327/463463 Nan Value: 0/463463
users.DisplayName: text_embedded, unique values: 218050/255360 Nan Value: 19/255360
users.Location: text_embedded, unique values: 11514/255360 Nan Value: 184185/255360
users.AboutMe: text_embedded, unique values: 47461/255360 Nan Value: 205676/255360
posts.OwnerDisplayName: text_embedded, unique values: 4531/333893 Nan Value: 325346/333893
posts.Title: text_embedded, unique values: 163648/333893 Nan

In [None]:
# Merge the short text columns of each table into a single text column.
#
#          | text_col_1 | text_col_2 | text_col_3 |
#   row 1  |     A      |     B      |     C      |
#   ------> new column "text_compress" with the value
#   "text_col_1 is A, text_col_2 is B, text_col_3 is C"
#
# Only the merged column then has to be embedded (instead of one embedding per
# original column), which saves memory and computation. Long text columns keep
# their own column. NOTE: this mutates `db` and `col_type_dict` in place.
from torch_frame import stype  # explicit import; do not rely on a name leaked by an earlier cell

# Text columns whose mean length (in whitespace-separated words) is below this
# are merged; longer ones stay separate. Columns that are entirely null have
# no mean (NaN) and are therefore also kept separate.
MAX_AVG_WORDS_TO_COMPRESS = 128
COMPRESSED_COL = "text_compress"


def _avg_word_count(series):
    """Return the mean number of whitespace-separated words over non-null entries (NaN if none)."""
    return series.dropna().apply(lambda value: len(str(value).split())).mean()


def _rows_to_text(frame, cols):
    """Render each row of ``frame[cols]`` as ``"col is value, col is value"``.

    Missing cells are skipped; a row whose cells are all missing becomes ``None``.
    Returns a list aligned positionally with ``frame``'s rows.
    """
    subset = frame[cols]
    present = subset.notna().to_numpy()
    values = subset.to_numpy(dtype=object)
    texts = []
    # Plain iteration over numpy rows: much faster than DataFrame.apply(axis=1)
    # on tables with ~1M rows, and it also handles an empty frame.
    for row_values, row_present in zip(values, present):
        tokens = [
            f"{col} is {value}"
            for col, value, ok in zip(cols, row_values, row_present)
            if ok
        ]
        texts.append(", ".join(tokens) if tokens else None)
    return texts


for table_name, type_dict in col_type_dict.items():
    df = db.table_dict[table_name].df
    text_cols = [col for col, col_stype in type_dict.items() if col_stype == stype.text_embedded]
    # Keep only the short columns for merging; long text (e.g. post bodies)
    # stays as its own column.
    compress_cols = [
        col for col in text_cols if _avg_word_count(df[col]) < MAX_AVG_WORDS_TO_COMPRESS
    ]

    if len(compress_cols) <= 1:
        # Zero or one short column: nothing to merge.
        continue

    if COMPRESSED_COL in df.columns and COMPRESSED_COL not in compress_cols:
        # Writing the merged text would silently overwrite an existing column.
        raise ValueError(f"{table_name} already has a '{COMPRESSED_COL}' column")

    print(f"----> Compressing {table_name} text columns: {compress_cols}")

    text_list = _rows_to_text(df, compress_cols)

    # Replace the merged columns with the combined text column.
    df.drop(columns=compress_cols, inplace=True)
    df[COMPRESSED_COL] = text_list

    # Keep the type dictionary in sync with the frame.
    for col in compress_cols:
        type_dict.pop(col)
    type_dict[COMPRESSED_COL] = stype.text_embedded

----> Compressing postHistory text columns: ['UserDisplayName', 'Text', 'Comment']
----> Compressing comments text columns: ['UserDisplayName', 'Text']
----> Compressing users text columns: ['DisplayName', 'Location', 'AboutMe']
----> Compressing posts text columns: ['OwnerDisplayName', 'Title']


In [9]:

# Embed on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class GloveTextEmbedding:
    """Sentence embedder backed by averaged GloVe (6B, 300-d) word vectors.

    Thin adapter around ``SentenceTransformer`` so it can be passed as the
    ``text_embedder`` of a torch_frame ``TextEmbedderConfig``: it is called with
    a list of strings and returns the model's embeddings as a ``Tensor``
    (presumably one row per input string — per sentence-transformers' ``encode``).
    """

    def __init__(self, device: Optional[torch.device] = None):
        # Alternative model that was tried here: "all-MiniLM-L12-v2".
        glove_model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(glove_model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        # `encode` returns a numpy array; wrap it without copying.
        embeddings = self.model.encode(sentences)
        return torch.from_numpy(embeddings)

# Wrap the GloVe embedder in a torch_frame text-embedding config.
# batch_size=512 is handed to torch_frame (presumably the number of texts sent
# to the embedder per call — confirm in TextEmbedderConfig).
glove_embedder = GloveTextEmbedding(device=device)
text_embedder_cfg = TextEmbedderConfig(text_embedder=glove_embedder, batch_size=512)


  return self.fget.__get__(instance, owner)()


In [10]:
# Materialise every table as a torch_frame TensorFrame (text columns embedded via
# `text_embedder_cfg`) and assemble them into a PyG heterogeneous graph, judging
# by the function name and the "Materialize ... / Build edge ..." log lines.
# NOTE(review): `cache_dir` is presumably where materialised frames are cached,
# and the trailing positional `True` is defined in utils.builder — TODO confirm.
cache_dir = "./data/stack-tensor-frame"
graph_inputs = (db, col_type_dict, text_embedder_cfg, cache_dir, True)
data, col_stats_dict = build_pyg_hetero_graph(*graph_inputs)

-----> Materialize tags Tensor Frame
-----> Materialize postHistory Tensor Frame
-----> Materialize comments Tensor Frame
-----> Materialize badges Tensor Frame
-----> Build edge between posts and tags
-----> Materialize users Tensor Frame
-----> Materialize postLinks Tensor Frame
-----> Materialize votes Tensor Frame
-----> Materialize posts Tensor Frame


True