In [1]:
%cd ..
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
import pickle
import os

from torch_geometric.data import HeteroData
from relbench.datasets import get_dataset

from utils.data import DatabaseFactory

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Resolve the relbench cache under the current user's home directory rather
# than a hard-coded absolute path, so the notebook is not tied to one account.
# For the original author this expands to the same location as before.
cache_dir = os.path.expanduser("~/.cache/relbench/ratebeer")
dataset = DatabaseFactory.get_dataset("ratebeer", cache_dir=cache_dir)

In [4]:
# Load the "ratebeer" database object via the project's DatabaseFactory,
# passing the relbench cache directory held in `cache_dir`.
db = DatabaseFactory.get_db("ratebeer", cache_dir=cache_dir)

Loading Database object from /home/lingze/.cache/relbench/ratebeer/db...
Done in 0.84 seconds.


In [5]:
from utils.preprocess import infer_type_in_db

# Infer a column type for every table in `db`. The result is consumed later as
# a mapping: table name -> {column name -> stype}.
# NOTE(review): the second positional argument (True) presumably turns on the
# "[rule N] ..." logging shown in the cell output -- confirm in utils.preprocess.
col_type_dict = infer_type_in_db(db, True)

[rule 0]: favorites Inferred favorite_id from numerical as categorical
[rule 0]: favorites Inferred user_id from numerical as categorical
[rule 0]: favorites Inferred beer_id from numerical as categorical
[rule 0]: favorites Inferred list_id from numerical as categorical
[rule 0]: beers Inferred beer_id from numerical as categorical
[rule 0]: beers Inferred brewer_id from numerical as categorical
[rule 0]: beers Inferred style_id from numerical as categorical
[rule 1]: beers Inferred view_count from numerical as categorical
[rule 1]: beers Inferred last_9m_count from numerical as categorical
[rule 0]: places Inferred place_id from numerical as categorical
[rule 0]: places Inferred state_id from numerical as categorical
[rule 0]: places Inferred country_id from numerical as categorical
[rule 1]: places Inferred has_cigars from numerical as categorical
[rule 1]: places Inferred takes_reservations from numerical as categorical
[rule 1]: places Inferred has_games from numerical as categori

In [7]:
# Preprocess each table by merging its short text columns into a single one.
#
#        / --- text_col_1 --- / --- text_col_2 --- / --- text_col_3 --- /
# row 1  / ------- A -------- / ------- B -------- / ------- C -------- /
# -------> generate one new text column:
# "text_col_1 is A, text_col_2 is B, text_col_3 is C"
#
# Only the merged column then has to be converted to a vector; the original
# text columns are dropped to save memory and computation.
from torch_frame import stype

# A text column is merged only if its average word count is below this limit
# (half of the BERT max length of 256 assumed by this notebook). Longer text
# columns stay as separate columns.
MAX_AVG_WORDS_TO_COMPRESS = 128


def row_to_text(row):
    """Render one row as "col is value, col is value, ...".

    Null cells are skipped. Returns None when every cell in the row is null.
    """
    present = row.dropna()
    if present.empty:
        return None
    return ", ".join(f"{key} is {value}" for key, value in present.items())


for table_name, type_dict in col_type_dict.items():
    df = db.table_dict[table_name].df

    # Collect the text columns. The loop variable must not be called `stype`:
    # that would shadow the imported `stype`, and the comparison would then
    # look up `text_embedded` on the stored value instead of the import.
    text_cols = [
        col for col, col_stype in type_dict.items()
        if col_stype == stype.text_embedded
    ]

    # Keep only the short text columns as merge candidates.
    compress_cols = []
    for col in text_cols:
        avg_word_count = df[col].dropna().apply(lambda x: len(str(x).split())).mean()
        # An all-null column yields NaN here, which fails the comparison and
        # therefore leaves the column untouched.
        if avg_word_count < MAX_AVG_WORDS_TO_COMPRESS:
            compress_cols.append(col)

    if len(compress_cols) <= 1:
        # With zero or one short text column there is nothing to merge.
        print(f"----> No need to compress {table_name} text columns: {compress_cols}")
        continue

    print(f"----> Compressing {table_name} text columns: {compress_cols}")

    text_list = df[compress_cols].apply(row_to_text, axis=1).tolist()

    # Replace the merged columns with the single combined column. This mutates
    # the table's DataFrame in place so that `db` sees the change.
    df.drop(columns=compress_cols, inplace=True)
    df["text_compress"] = text_list

    # Keep the type dict in sync with the new column layout.
    for col in compress_cols:
        type_dict.pop(col)
    type_dict["text_compress"] = stype.text_embedded

----> No need to compress favorites text columns: []
----> No need to compress beers text columns: []
----> No need to compress places text columns: []
----> No need to compress availability text columns: []
----> No need to compress place_ratings text columns: []
----> No need to compress beer_ratings text columns: []
----> No need to compress countries text columns: []
----> No need to compress brewers text columns: []
----> No need to compress users text columns: []


In [8]:
from utils.resource import get_text_embedder_cfg

# Configuration for the text embedder, running on `device`.
# Alternative model kept for reference:
#   "sentence-transformers/average_word_embeddings_glove.6B.300d"
text_embedder_cfg = get_text_embedder_cfg(model_name="all-MiniLM-L12-v2", device=device)

In [9]:
# Materialize the tensor frame of every table through build_pyg_hetero_graph,
# which returns the graph data object and the per-column statistics.
from utils.builder import build_pyg_hetero_graph
# NOTE(review): this rebinds `cache_dir`, which earlier held the relbench cache
# path; from here on it is the tensor-frame directory. Re-running the earlier
# loading cells after this one would hand them the wrong path.
cache_dir = "./data/ratebeer-tensor-frame"
data, col_stats_dict = build_pyg_hetero_graph(
    db,
    col_type_dict,
    text_embedder_cfg,
    cache_dir,
    True,  # NOTE(review): meaning of this flag is not visible here -- confirm in utils.builder
)

-----> Materialize favorites Tensor Frame
-----> Materialize beers Tensor Frame
-----> Materialize places Tensor Frame
-----> Materialize availability Tensor Frame
-----> Materialize place_ratings Tensor Frame
-----> Materialize beer_ratings Tensor Frame
-----> Materialize countries Tensor Frame
-----> Materialize brewers Tensor Frame
-----> Materialize users Tensor Frame


In [10]:
# Persist col_type_dict as a pickle file inside cache_dir.
col_type_path = os.path.join(cache_dir, "col_type_dict.pkl")
with open(col_type_path, "wb") as pkl_file:
    pickle.dump(col_type_dict, pkl_file)