In [1]:
%cd ..
import torch
import math
import torch_frame
import copy
from tqdm import tqdm
from utils.data import DatabaseFactory
from utils.builder import build_pyg_hetero_graph
from utils.resource import get_text_embedder_cfg
from utils.util import load_col_types
from utils.document import generate_document_given_table
from utils.builder import identify_entity_table
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cache_dir = "/home/lingze/.cache/relbench/ratebeer"
db = DatabaseFactory.get_db("ratebeer", cache_dir = cache_dir)

Loading Database object from /home/lingze/.cache/relbench/ratebeer/db...
Done in 0.59 seconds.


In [3]:
db.table_dict['beers'].df

Unnamed: 0,beer_id,brewer_id,style_id,alcohol_pct,ibu,is_seasonal,created_at,view_count,avg_rating,rating_count,...,rating_std_dev,overall_percentile,style_percentile,last_9m_avg,last_9m_count,straight_avg_rating,straight_rating_count,year4_avg,year4_overall,year4_count
0,0,3,15.0,5.0,20.0,False,2000-04-02 00:00:00.000,317,2.948404,254,...,0.58,25.620046,26.787973,5.946154,2,3.048845,303,2.766827,10.264648,10
1,1,1,7.0,0.0,,False,2000-04-02 00:00:00.000,311,3.285759,4,...,0.65,0.000000,0.000000,,0,3.471429,7,,,0
2,2,0,54.0,4.8,,False,2000-04-02 00:00:00.000,179,2.463477,211,...,0.53,8.040534,4.697825,6.433333,1,2.476818,220,,,0
3,3,3,62.0,7.0,,False,2000-04-02 00:00:00.000,188,3.137644,122,...,0.51,42.056885,75.973324,6.576190,1,3.189474,133,3.119702,,3
4,4,2,5.0,5.5,,False,2000-04-02 00:00:00.000,241,3.388736,110,...,0.63,71.576494,70.683458,6.600000,1,3.385000,120,2.940864,,3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
719159,719159,16978,12.0,6.1,,False,2024-05-31 23:40:45.363,0,3.043142,1,...,,,,,0,3.500000,1,,,0
719160,719160,36113,181.0,5.4,0.0,False,2024-05-31 23:43:16.477,0,3.140665,2,...,,,,,0,3.550000,2,,,0
719161,719161,122,169.0,11.6,0.0,True,2024-05-31 23:45:22.057,0,3.472321,9,...,,78.375949,87.807467,,0,3.630000,10,,,0
719162,719162,13788,72.0,8.0,0.0,True,2024-05-31 23:47:02.870,0,3.257796,6,...,,,,,0,3.470000,10,,,0


In [4]:
entity_tables = identify_entity_table(db)
entity_tables.append('beers')
entity_tables

['places', 'countries', 'brewers', 'users', 'beers']

In [5]:
edge_candidates_pairs = []
edge_candidates_pairs = [(src, dst) for src in entity_tables for dst in entity_tables if src != dst]
edge_candidates_pairs

[('places', 'countries'),
 ('places', 'brewers'),
 ('places', 'users'),
 ('places', 'beers'),
 ('countries', 'places'),
 ('countries', 'brewers'),
 ('countries', 'users'),
 ('countries', 'beers'),
 ('brewers', 'places'),
 ('brewers', 'countries'),
 ('brewers', 'users'),
 ('brewers', 'beers'),
 ('users', 'places'),
 ('users', 'countries'),
 ('users', 'brewers'),
 ('users', 'beers'),
 ('beers', 'places'),
 ('beers', 'countries'),
 ('beers', 'brewers'),
 ('beers', 'users')]

In [6]:
from utils.builder import generate_hop_matrix
hop_matrix = generate_hop_matrix(db)

In [7]:
hop_matrix.search_tables('places', 2)
hop_matrix.search_tables('countries', 2)

['countries', 'place_ratings', 'beers', 'users', 'favorites', 'beer_ratings']

In [8]:
edge_candidates_pairs_ = []
for src, dst in edge_candidates_pairs:
    src_more_hop_tables = hop_matrix.search_tables(src, 2)
    if dst in src_more_hop_tables:
        edge_candidates_pairs_.append((src, dst))

In [9]:
edge_candidates_pairs_

[('places', 'brewers'),
 ('places', 'users'),
 ('places', 'beers'),
 ('countries', 'users'),
 ('countries', 'beers'),
 ('brewers', 'places'),
 ('brewers', 'users'),
 ('users', 'places'),
 ('users', 'countries'),
 ('users', 'brewers'),
 ('users', 'beers'),
 ('beers', 'places'),
 ('beers', 'countries'),
 ('beers', 'users')]

In [10]:
# homoGraph
from utils.builder import HomoGraph, make_homograph_from_db
homoGraph = make_homograph_from_db(db, verbose=True)

table favorites -> table beers has 142597 edges
table favorites -> table users has 142597 edges
table beers -> table brewers has 719084 edges
table places -> table countries has 36183 edges
table availability -> table beers has 45759 edges
table availability -> table countries has 45768 edges
table availability -> table places has 45768 edges
table place_ratings -> table places has 79789 edges
table place_ratings -> table users has 71072 edges


table beer_ratings -> table beers has 4231639 edges
table beer_ratings -> table users has 4232417 edges
table brewers -> table countries has 39032 edges


In [11]:
from utils.preprocess import infer_type_in_db
from utils.tokenize import tokenize_database
col_type_dict = infer_type_in_db(db, True)

[rule 0]: favorites Inferred favorite_id from numerical as categorical
[rule 0]: favorites Inferred user_id from numerical as categorical
[rule 0]: favorites Inferred beer_id from numerical as categorical
[rule 0]: favorites Inferred list_id from numerical as categorical
[rule 0]: beers Inferred beer_id from numerical as categorical
[rule 0]: beers Inferred brewer_id from numerical as categorical
[rule 0]: beers Inferred style_id from numerical as categorical
[rule 1]: beers Inferred view_count from numerical as categorical
[rule 1]: beers Inferred last_9m_count from numerical as categorical
[rule 0]: places Inferred place_id from numerical as categorical
[rule 0]: places Inferred state_id from numerical as categorical
[rule 0]: places Inferred country_id from numerical as categorical
[rule 1]: places Inferred has_smoking from numerical as categorical
[rule 1]: places Inferred has_cigars from numerical as categorical
[rule 1]: places Inferred takes_reservations from numerical as catego

In [12]:
tk_db = tokenize_database(db, col_type_dict, './tmp_docs/ratebeer', True)

----------------> Tokenizing favorites each column
-> Load tokenized data from ./tmp_docs/ratebeer/favorites.npy
----------------> Tokenizing beers each column
-> Load tokenized data from ./tmp_docs/ratebeer/beers.npy
----------------> Tokenizing places each column
-> Load tokenized data from ./tmp_docs/ratebeer/places.npy
----------------> Tokenizing availability each column
-> Load tokenized data from ./tmp_docs/ratebeer/availability.npy
----------------> Tokenizing place_ratings each column
-> Load tokenized data from ./tmp_docs/ratebeer/place_ratings.npy
----------------> Tokenizing beer_ratings each column
-> Load tokenized data from ./tmp_docs/ratebeer/beer_ratings.npy
----------------> Tokenizing countries each column
-> Load tokenized data from ./tmp_docs/ratebeer/countries.npy
----------------> Tokenizing brewers each column
-> Load tokenized data from ./tmp_docs/ratebeer/brewers.npy
----------------> Tokenizing users each column
-> Load tokenized data from ./tmp_docs/ratebeer

In [13]:
# # generated the documents and build the retrieval index
# entity_to_docs = {}
# walk_length = 10
# round = 10
# for entity in entity_tables:
#    _, entity_to_docs[entity] = generate_document_given_table(
#         homoGraph, 
#         tk_db, 
#         entity, 
#         walk_length=walk_length, 
#         round = round, 
#         verbose=True
#     )

In [14]:
# # temporarily save the index
# import bm25s
# entity_to_retriver = {}
# for entity, docs in entity_to_docs.items():
#     retriever = bm25s.BM25(backend="numba")
#     retriever.index(docs)
#     retriever.activate_numba_scorer()
#     entity_to_retriver[entity] = retriever

# # save the retriever
# for entity, retriever in entity_to_retriver.items():
#     retriever.save(f"./tmp/ratebeer/{entity}_retriever_bm25")

In [15]:
import bm25s
entity_to_retriver = {}

# load the retriever
entity_to_retriver = {}
for entity in entity_tables:
    path = f"./tmp/ratebeer/{entity}_retriever_bm25"
    retriever = bm25s.BM25.load(path)
    retriever.activate_numba_scorer()
    entity_to_retriver[entity] = retriever
    print(f"load {path}")

load ./tmp/ratebeer/places_retriever_bm25
load ./tmp/ratebeer/countries_retriever_bm25
load ./tmp/ratebeer/brewers_retriever_bm25
load ./tmp/ratebeer/users_retriever_bm25
load ./tmp/ratebeer/beers_retriever_bm25


In [16]:
# resample the candidate docs, and retrieve the related docs in the bm25 retrievers
# generated the documents and build the retrieval index
walk_length = 8
round = 10
entity_to_docs = {}
entity_candidate_pkys = {}
# for each
sample_size_dict = {
    "places": 0.5,
    "countries": 1,
    "brewers": 0.5,
    "users": 0.5,
    "beers": 0.1,
}

for entity in entity_tables:
    n = len(db.table_dict[entity].df)
    sample_size = int(sample_size_dict.get(entity, 0.1) * n)
    entity_candidate_pkys[entity], entity_to_docs[entity] = generate_document_given_table(
        homoGraph, 
        tk_db, 
        entity, 
        walk_length=walk_length, 
        round = round,
        sample_size = sample_size,
        verbose=True
    )

- Walks for table places - shape torch.Size([18091, 10, 9])


                                                       

- Walks for table countries - shape torch.Size([251, 10, 9])


                                                   

- Walks for table brewers - shape torch.Size([19516, 10, 9])


                                                       

- Walks for table users - shape torch.Size([44831, 10, 9])


                                                       

- Walks for table beers - shape torch.Size([71916, 10, 9])


                                                       

In [17]:
import numpy as np
topn = 10
edge_dict = {}
# (src_table, des_table) -> edge 2-D array
edge_candidates_pairs = edge_candidates_pairs_
for entity, retrieve_entity in edge_candidates_pairs:

    # retrieve the related docs
    entity_query_docs = entity_to_docs[entity]
    entity_query_pkys = entity_candidate_pkys[entity]
    retriever = entity_to_retriver[retrieve_entity]
    
    related_pkys, scores = retriever.retrieve(entity_query_docs, k = topn, n_threads = 24)
    
    score_np = np.array(scores)
    related_pkys_np = np.array(related_pkys)
    threshold = score_np.mean() + 2*scores.std()
    
    # Get indices where the score is above the threshold
    mask = score_np > threshold

    # Apply the mask
    filtered_cols = related_pkys_np[mask]

    # Generate the corresponding query entities
    entity_query_pkys = np.array(entity_query_pkys)  # shape [n]

    # Repeat each query item the number of True values per row in the mask
    row_repeats = mask.sum(axis=1)  # how many times to repeat each query
    filtered_rows = np.repeat(entity_query_pkys, row_repeats)
    
    
    filtered_edge = np.stack([filtered_rows, filtered_cols], axis=1)
    # added edge
    num_edges = filtered_rows.shape[0]
    edge_dict[(entity, retrieve_entity)] = filtered_edge
    print(f"Add cross table edges #{num_edges} between {entity} and {retrieve_entity}")

Add cross table edges #8408 between places and brewers
Add cross table edges #6187 between places and users
Add cross table edges #7826 between places and beers
Add cross table edges #83 between countries and users
Add cross table edges #71 between countries and beers
Add cross table edges #8144 between brewers and places
Add cross table edges #10119 between brewers and users
Add cross table edges #26109 between users and places
Add cross table edges #16350 between users and countries
Add cross table edges #18127 between users and brewers
Add cross table edges #20677 between users and beers
Add cross table edges #35062 between beers and places
Add cross table edges #25480 between beers and countries
Add cross table edges #29988 between beers and users


In [18]:
# (src_table, des_table) -> edge 2-D array
npz_data = {
    f"{src}-{dst}": edge_array
    for (src, dst), edge_array in edge_dict.items()
}

path = f"./edges/ratebeer-edges.npz"
np.savez(path, **npz_data)

In [23]:
# self-entity correlation
# which can generate the positive pairs in the contrastive learning
topn = 11
# the most related doc should be itself, so we need to retrieve topn + 1
positive_pool_dict = {}
# entity -> positive candidate, padding the non-value
threshold = 0.75
batch_size = 1024
for entity, retriever in entity_to_retriver.items():
    # retrieve the related docs
    entity_query_docs = entity_to_docs[entity]
    entity_query_pkys = entity_candidate_pkys[entity]
    score_np = []
    related_pkys_np = []
    print(f"--------> {entity}")
    for batch_idx in tqdm(range(0, len(entity_query_docs), batch_size)):
        batch_query_docs = entity_query_docs[batch_idx:batch_idx + batch_size]
        related_pkys, scores = retriever.retrieve(batch_query_docs, k = topn, n_threads=-1)
        score_np.append(np.array(scores))
        related_pkys_np.append(np.array(related_pkys))
    
    score_np = np.concatenate(score_np, axis = 0)
    related_pkys_np = np.concatenate(related_pkys_np, axis = 0)
    # Get indices where the score is above the threshold
    # the first one is the most related one, should be itself
    mask = score_np > (score_np[:,[0]] * threshold)
    # add padding for those non-related docs which is filtered out.
    related_pkys_np[~mask] = -1
    rows_num = np.sum(mask, axis = 1)
    # except itself, still has similar docs
    rows_mask = rows_num > 1
    positive_pool = related_pkys_np[rows_mask]
    
    positive_pool_dict[entity] = positive_pool
    print(f"Generate positive pools #{len(positive_pool)}, original candidate {len(entity_query_docs)} in {entity} table")

--------> places


100%|██████████| 18/18 [00:44<00:00,  2.49s/it]


Generate positive pools #3312, original candidate 18091 in places table
--------> countries


100%|██████████| 1/1 [00:00<00:00, 14.82it/s]


Generate positive pools #7, original candidate 251 in countries table
--------> brewers


100%|██████████| 20/20 [00:40<00:00,  2.01s/it]


Generate positive pools #1263, original candidate 19516 in brewers table
--------> users


100%|██████████| 44/44 [02:28<00:00,  3.37s/it]


Generate positive pools #6558, original candidate 44831 in users table
--------> beers


100%|██████████| 71/71 [2:25:39<00:00, 123.09s/it]  

Generate positive pools #41506, original candidate 71916 in beers table





In [24]:
path = "./samples/ratebeer-samples.npz"
np.savez(path, **positive_pool_dict)