In [1]:
%cd ..
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
import pickle
import os

from torch_geometric.data import HeteroData
from relbench.datasets import get_dataset

from utils.data import preprocess_event_database

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = get_dataset('rel-event')
db = dataset.get_db()
preprocess_event_database(db)

Loading Database object from /home/lingze/.cache/relbench/rel-event/db...
Done in 3.02 seconds.


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  event_df["event_id"].replace(event_id2index, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  event_df["event_id"].replace(event_id2index, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] 

In [3]:
from utils.document import generate_document_given_table
from utils.builder import identify_entity_table
from utils.builder import generate_hop_matrix
entity_tables = identify_entity_table(db)
entity_tables

['events', 'users']

In [4]:
edge_candidates_pairs = []
edge_candidates_pairs.extend([("events", "users"), ("users", "events")])

In [5]:
# homoGraph
from utils.builder import HomoGraph, make_homograph_from_db
homoGraph = make_homograph_from_db(db, verbose=True)

table event_attendees -> table events has 49822 edges
table event_attendees -> table users has 49822 edges
table user_friends -> table users has 213703 edges
table user_friends -> table users has 213703 edges
table events -> table users has 86 edges
table event_interest -> table events has 14135 edges
table event_interest -> table users has 14135 edges


In [6]:
from utils.preprocess import infer_type_in_db
from utils.tokenize import tokenize_database
col_type_dict = infer_type_in_db(db, True)

[rule 0]: event_attendees Inferred id from numerical as categorical
[rule 0]: event_attendees Inferred user_id from numerical as categorical
[rule 0]: user_friends Inferred id from numerical as categorical
[rule 0]: events Inferred event_id from numerical as categorical
[rule 0]: events Inferred user_id from numerical as categorical
[rule 1]: events Inferred c_1 from numerical as categorical
[rule 1]: events Inferred c_2 from numerical as categorical
[rule 1]: events Inferred c_3 from numerical as categorical
[rule 1]: events Inferred c_4 from numerical as categorical
[rule 1]: events Inferred c_5 from numerical as categorical
[rule 1]: events Inferred c_6 from numerical as categorical
[rule 1]: events Inferred c_7 from numerical as categorical
[rule 1]: events Inferred c_8 from numerical as categorical
[rule 1]: events Inferred c_9 from numerical as categorical
[rule 1]: events Inferred c_10 from numerical as categorical
[rule 1]: events Inferred c_11 from numerical as categorical
[ru

In [7]:
# check all col types
for table_name, col_types in col_type_dict.items():
    print(f"Table {table_name}")
    for col, type_ in col_types.items():
        print(f"{col}: {type_}")
    print("*"*40)

Table event_attendees
id: categorical
event: categorical
status: categorical
user_id: categorical
start_time: timestamp
****************************************
Table user_friends
id: categorical
user: categorical
friend: categorical
****************************************
Table events
event_id: categorical
user_id: categorical
start_time: timestamp
city: text_embedded
state: text_embedded
zip: text_embedded
country: text_embedded
lat: numerical
lng: numerical
c_1: categorical
c_2: categorical
c_3: categorical
c_4: categorical
c_5: categorical
c_6: categorical
c_7: categorical
c_8: categorical
c_9: categorical
c_10: categorical
c_11: categorical
c_12: categorical
c_13: categorical
c_14: categorical
c_15: categorical
c_16: categorical
c_17: categorical
c_18: categorical
c_19: categorical
c_20: categorical
c_21: categorical
c_22: categorical
c_23: categorical
c_24: categorical
c_25: categorical
c_26: categorical
c_27: categorical
c_28: categorical
c_29: categorical
c_30: categorical
c_3

In [8]:
tk_db = tokenize_database(db, col_type_dict, './tmp_docs/rel-events', True)

----------------> Tokenizing event_attendees each column
-> Load tokenized data from ./tmp_docs/rel-events/event_attendees.npy
----------------> Tokenizing user_friends each column
-> Load tokenized data from ./tmp_docs/rel-events/user_friends.npy
----------------> Tokenizing events each column
-> Load tokenized data from ./tmp_docs/rel-events/events.npy
----------------> Tokenizing event_interest each column
-> Load tokenized data from ./tmp_docs/rel-events/event_interest.npy
----------------> Tokenizing users each column
-> Load tokenized data from ./tmp_docs/rel-events/users.npy


In [9]:
generated the documents and build the retrieval index
# entity_to_docs = {}
# walk_length = 10
# round = 10
# for entity in entity_tables:
#    _, entity_to_docs[entity] = generate_document_given_table(
#         homoGraph, 
#         tk_db, 
#         entity, 
#         walk_length=walk_length, 
#         round = round, 
#         verbose=True
#     )

SyntaxError: invalid syntax (710976225.py, line 1)

In [None]:
# # temporarily save the index
# import bm25s
# entity_to_retriver = {}
# for entity, docs in entity_to_docs.items():
#     retriever = bm25s.BM25(backend="numba")
#     retriever.index(docs)
#     retriever.activate_numba_scorer()
#     entity_to_retriver[entity] = retriever

# # save the retriever
# for entity, retriever in entity_to_retriver.items():
#     retriever.save(f"./tmp/event/{entity}_retriever_bm25")

                                                                                         

In [10]:
import bm25s
entity_to_retriver = {}

# load the retriever
entity_to_retriver = {}
for entity in entity_tables:
    path = f"./tmp/event/{entity}_retriever_bm25"
    retriever = bm25s.BM25.load(path)
    retriever.activate_numba_scorer()
    entity_to_retriver[entity] = retriever
    print(f"load {path}")

load ./tmp/event/events_retriever_bm25
load ./tmp/event/users_retriever_bm25


In [11]:
# resample the candidate docs, and retrieve the related docs in the bm25 retrievers
# generated the documents and build the retrieval index
walk_length = 8
round = 10
entity_to_docs = {}
entity_candidate_pkys = {}
# for each
for entity in entity_tables:
    n = len(db.table_dict[entity].df)
    sample_size = n
    entity_candidate_pkys[entity], entity_to_docs[entity] = generate_document_given_table(
        homoGraph, 
        tk_db, 
        entity, 
        walk_length=walk_length, 
        round = round,
        sample_size = sample_size,
        verbose=True
    )

- Walks for table events - shape torch.Size([11465, 10, 6])


                                                       

- Walks for table users - shape torch.Size([37143, 10, 6])


                                                        

In [12]:
# Add the cross-table edges,
import numpy as np
topn = 20
edge_dict = {}
batch_size = 2048
# (src_table, des_table) -> edge 2-D array
for entity, retrieve_entity in edge_candidates_pairs:

    # retrieve the related docs
    entity_query_docs = entity_to_docs[entity]
    entity_query_pkys = entity_candidate_pkys[entity]
    retriever = entity_to_retriver[retrieve_entity]
    
    print(f"--------> {entity} ---- {retrieve_entity}")
    score_np = []
    related_pkys_np = []
    for batch_idx in tqdm(range(0, len(entity_query_docs), batch_size)):
        batch_query_docs = entity_query_docs[batch_idx:batch_idx+batch_size]
        related_pkys, scores = retriever.retrieve(batch_query_docs, k = topn, n_threads = -1)
        score_np.append(np.array(scores))
        related_pkys_np.append(np.array(related_pkys))
    
    score_np = np.concatenate(score_np, axis = 0)
    related_pkys_np = np.concatenate(related_pkys_np, axis = 0)
    
    threshold = score_np.mean() + 2*scores.std()
    # Get indices where the score is above the threshold
    mask = score_np > threshold

    # Apply the mask
    filtered_cols = related_pkys_np[mask]

    # Generate the corresponding query entities
    entity_query_pkys = np.array(entity_query_pkys)  # shape [n]

    # Repeat each query item the number of True values per row in the mask
    row_repeats = mask.sum(axis=1)  # how many times to repeat each query
    filtered_rows = np.repeat(entity_query_pkys, row_repeats)
    
    
    filtered_edge = np.stack([filtered_rows, filtered_cols], axis=1)
    # added edge
    num_edges = filtered_rows.shape[0]
    edge_dict[(entity, retrieve_entity)] = filtered_edge
    print(f"Add cross table edges #{num_edges} between {entity} and {retrieve_entity}")
    

--------> events ---- users


100%|██████████| 6/6 [00:19<00:00,  3.22s/it]


Add cross table edges #8174 between events and users
--------> users ---- events


100%|██████████| 19/19 [00:07<00:00,  2.62it/s]

Add cross table edges #30902 between users and events





In [13]:
# (src_table, des_table) -> edge 2-D array
npz_data = {
    f"{src}-{dst}": edge_array
    for (src, dst), edge_array in edge_dict.items()
}

path = f"./edges/rel-event-edges.npz"
np.savez(path, **npz_data)

In [14]:
# resample the candidate docs, and retrieve the related docs in the bm25 retrievers
# generated the documents and build the retrieval index
walk_length = 10
round = 10
entity_to_docs = {}
entity_candidate_pkys = {}
# for each
for entity in entity_tables:
    n = len(db.table_dict[entity].df)
    sample_size = n
    entity_candidate_pkys[entity], entity_to_docs[entity] = generate_document_given_table(
        homoGraph, 
        tk_db, 
        entity, 
        walk_length=walk_length, 
        round = round,
        sample_size = sample_size,
        verbose=True
    )

- Walks for table events - shape torch.Size([11465, 10, 6])


                                                       

- Walks for table users - shape torch.Size([37143, 10, 6])


                                                        

In [15]:
# self-entity correlation
# which can generate the positive pairs in the contrastive learning
entity_topn = {
    "events": 21,
    "users": 21,
}
# the most related doc should be itself, so we need to retrieve topn + 1
positive_pool_dict = {}
# entity -> positive candidate, padding the non-value
threshold = 0.7
batch_size = 2048
for entity, retriever in entity_to_retriver.items():
    # retrieve the related docs
    topn = entity_topn[entity]
    entity_query_docs = entity_to_docs[entity]
    entity_query_pkys = entity_candidate_pkys[entity]
    score_np = []
    related_pkys_np = []
    print(f"--------> {entity}")
    for batch_idx in tqdm(range(0, len(entity_query_docs), batch_size)):
        batch_query_docs = entity_query_docs[batch_idx:batch_idx + batch_size]
        related_pkys, scores = retriever.retrieve(batch_query_docs, k = topn, n_threads=-1)
        score_np.append(np.array(scores))
        related_pkys_np.append(np.array(related_pkys))
    
    score_np = np.concatenate(score_np, axis = 0)
    related_pkys_np = np.concatenate(related_pkys_np, axis = 0)
    # Get indices where the score is above the threshold
    # the first one is the most related one, should be itself
    mask = score_np > (score_np[:,[0]] * threshold)
    # add padding for those non-related docs which is filtered out.
    related_pkys_np[~mask] = -1
    rows_num = np.sum(mask, axis = 1)
    # except itself, still has similar docs
    rows_mask = rows_num > 1
    positive_pool = related_pkys_np[rows_mask]
    
    positive_pool_dict[entity] = positive_pool
    print(f"Generate positive pools #{len(positive_pool)}, original candidate {len(entity_query_docs)} in {entity} table")

--------> events


100%|██████████| 6/6 [00:14<00:00,  2.43s/it]


Generate positive pools #3554, original candidate 11465 in events table
--------> users


100%|██████████| 19/19 [00:09<00:00,  1.94it/s]

Generate positive pools #24091, original candidate 37143 in users table





In [17]:
path = "./samples/rel-event-samples.npz"
np.savez(path, **positive_pool_dict)