In [1]:
# Move the kernel's working directory up one level; the recorded output shows
# it resolving to /home/lingze/embedding_fusion.
# NOTE(review): presumably this is the project root so that the `utils.*`
# imports in later cells resolve -- confirm. Re-running this cell moves up
# another level each time, so run it only once per kernel session.
%cd ..
from relbench.datasets import get_dataset
from tqdm import tqdm
import numpy as np

/home/lingze/embedding_fusion


In [2]:
# Fetch the relbench "rel-trial" dataset (download=True) and obtain its
# database object. The recorded output shows the database being loaded from the
# local relbench cache directory.
dataset = get_dataset(name="rel-trial", download = True)
# `db` is the object every later cell works on (db.table_dict[...] is read below).
db = dataset.get_db()

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 7.65 seconds.


In [3]:
# homoGraph: build a single graph object from the whole database.
# With verbose=True the helper prints one line per "table A -> table B" link
# together with its edge count (see the recorded output).
# NOTE(review): `HomoGraph` is imported but not referenced in the cells shown here.
from utils.builder import HomoGraph, make_homograph_from_db
homoGraph = make_homograph_from_db(db, verbose=True)

table interventions_studies -> table studies has 171771 edges
table interventions_studies -> table interventions has 171771 edges
table facilities_studies -> table studies has 1798765 edges
table facilities_studies -> table facilities has 1798765 edges
table eligibilities -> table studies has 249730 edges
table reported_event_totals -> table studies has 383064 edges
table designs -> table studies has 249093 edges
table conditions_studies -> table studies has 408422 edges
table conditions_studies -> table conditions has 408422 edges
table drop_withdrawals -> table studies has 381199 edges
table outcome_analyses -> table studies has 225846 edges
table outcome_analyses -> table outcomes has 225846 edges
table sponsors_studies -> table studies has 391462 edges
table sponsors_studies -> table sponsors has 391462 edges
table outcomes -> table studies has 411933 edges


In [4]:
from utils.preprocess import infer_type_in_db
from utils.tokenize import tokenize_database

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Infer a column-type description for the tables in `db`.
# The recorded output shows "[rule 0]" messages re-typing id-like numerical
# columns as categorical.
# NOTE(review): the positional `True` is presumably a verbose flag -- confirm
# against utils.preprocess.infer_type_in_db.
col_type_dict = infer_type_in_db(db, True)

[rule 0]: interventionsInferred intervention_id from numerical as categorical
[rule 0]: interventions_studiesInferred id from numerical as categorical
[rule 0]: interventions_studiesInferred nct_id from numerical as categorical
[rule 0]: interventions_studiesInferred intervention_id from numerical as categorical
[rule 0]: facilities_studiesInferred id from numerical as categorical
[rule 0]: facilities_studiesInferred nct_id from numerical as categorical
[rule 0]: facilities_studiesInferred facility_id from numerical as categorical
[rule 0]: sponsorsInferred sponsor_id from numerical as categorical
[rule 0]: eligibilitiesInferred id from numerical as categorical
[rule 0]: eligibilitiesInferred nct_id from numerical as categorical
[rule 0]: reported_event_totalsInferred id from numerical as categorical
[rule 0]: reported_event_totalsInferred nct_id from numerical as categorical
[rule 0]: designsInferred id from numerical as categorical
[rule 0]: designsInferred nct_id from numerical as c

In [6]:
# Tokenize every table of `db` using the inferred column types.
# The recorded output shows per-table "<table>.npy" files being loaded from
# ./tmp_docs/rel-trails, i.e. the directory argument acts as an on-disk cache.
# NOTE(review): the directory is spelled "rel-trails" while the dataset is
# "rel-trial"; it is left untouched because existing cache files live there.
# NOTE(review): the trailing positional `True` is presumably a verbose flag -- confirm.
tk_db = tokenize_database(db, col_type_dict, './tmp_docs/rel-trails', True)

----------------> Tokenizing interventions each column
-> Load tokenized data from ./tmp_docs/rel-trails/interventions.npy
----------------> Tokenizing interventions_studies each column
-> Load tokenized data from ./tmp_docs/rel-trails/interventions_studies.npy
----------------> Tokenizing facilities_studies each column
-> Load tokenized data from ./tmp_docs/rel-trails/facilities_studies.npy
----------------> Tokenizing sponsors each column
-> Load tokenized data from ./tmp_docs/rel-trails/sponsors.npy
----------------> Tokenizing eligibilities each column
-> Load tokenized data from ./tmp_docs/rel-trails/eligibilities.npy
----------------> Tokenizing reported_event_totals each column
-> Load tokenized data from ./tmp_docs/rel-trails/reported_event_totals.npy
----------------> Tokenizing designs each column
-> Load tokenized data from ./tmp_docs/rel-trails/designs.npy
----------------> Tokenizing conditions_studies each column
-> Load tokenized data from ./tmp_docs/rel-trails/condition

In [7]:
from utils.document import generate_document_given_table
from utils.builder import identify_entity_table
from utils.builder import generate_hop_matrix

In [8]:
# Ask the builder helper which tables of `db` count as entity tables; the
# resulting list of table names drives every per-entity loop below.
entity_tables = identify_entity_table(db)
# Bare last expression so the notebook displays the list.
entity_tables

['interventions',
 'sponsors',
 'eligibilities',
 'designs',
 'studies',
 'conditions']

In [10]:
# Generate the documents for every row of every entity table; these documents
# are the corpus that the BM25 retrievers are built on.
# No sample_size is passed, and the recorded walk shapes show one entry per
# table row, e.g. [249730, 8, 11] for `studies`.
entity_to_docs = {}
walk_length = 10
# Kept under the name `num_rounds` so the builtin round() is not shadowed for
# the rest of the kernel session; the helper's keyword is still `round`.
num_rounds = 8
for entity in entity_tables:
    # The first return value is discarded here; only the documents are needed
    # to build the index.
    _, entity_to_docs[entity] = generate_document_given_table(
        homoGraph,
        tk_db,
        entity,
        walk_length=walk_length,
        round=num_rounds,
        verbose=True,
    )

- Walks for table interventions - shape torch.Size([3462, 8, 11])


                                                     

- Walks for table sponsors - shape torch.Size([53241, 8, 11])


                                                       

- Walks for table eligibilities - shape torch.Size([249730, 8, 11])


                                                         

- Walks for table designs - shape torch.Size([249093, 8, 11])


                                                         

- Walks for table studies - shape torch.Size([249730, 8, 11])


                                                         

- Walks for table conditions - shape torch.Size([3973, 8, 11])


                                                     

In [11]:
# Build one BM25 retriever per entity table from its generated documents and
# persist each index under ./tmp so later sessions can reload it instead of
# re-indexing.
import bm25s


def _build_bm25(corpus):
    """Index `corpus` with a numba-backed BM25 model and enable the numba scorer."""
    bm25 = bm25s.BM25(backend="numba")
    bm25.index(corpus)
    bm25.activate_numba_scorer()
    return bm25


# entity table name -> BM25 retriever over that table's documents
entity_to_retriver = {
    table_name: _build_bm25(table_docs)
    for table_name, table_docs in entity_to_docs.items()
}

# Write every index to disk only after all of them have been built.
for table_name, bm25_index in entity_to_retriver.items():
    bm25_index.save(f"./tmp/{table_name}_retriever_bm25")

                                                                                          

In [9]:
import bm25s

# Reload the BM25 indexes that an earlier cell saved under ./tmp, one per
# entity table, and enable the numba scorer on each of them.
# entity table name -> BM25 retriever
entity_to_retriver = {}
for entity in entity_tables:
    path = f"./tmp/{entity}_retriever_bm25"
    restored = bm25s.BM25.load(path)
    restored.activate_numba_scorer()
    # Register the retriever only once it is fully initialised.
    entity_to_retriver[entity] = restored
    print(f"load {path}")

load ./tmp/interventions_retriever_bm25
load ./tmp/sponsors_retriever_bm25
load ./tmp/eligibilities_retriever_bm25
load ./tmp/designs_retriever_bm25
load ./tmp/studies_retriever_bm25
load ./tmp/conditions_retriever_bm25


In [17]:
# Resample the candidate documents, then query the BM25 retrievers for their related documents.

In [10]:
# Resample a subset of rows per entity table and generate their documents.
# These documents are the *queries* later sent to the BM25 retrievers; the
# retrieval index itself is not rebuilt here.
walk_length = 10
# Kept under the name `num_rounds` so the builtin round() is not shadowed for
# the rest of the kernel session; the helper's keyword is still `round`.
num_rounds = 8
entity_to_docs = {}         # entity table -> generated query documents
entity_candidate_pkys = {}  # entity table -> first return value of the helper
                            # (presumably the sampled rows' primary keys -- confirm)
for entity in entity_tables:
    n = len(db.table_dict[entity].df)
    # Request 10% of the table, but never fewer than 4096 rows. The recorded
    # output shows small tables yielding all of their rows (e.g. 3462 for
    # `interventions`), so the helper evidently copes with sample_size > n.
    sample_size = max(n // 10, 4096)
    entity_candidate_pkys[entity], entity_to_docs[entity] = generate_document_given_table(
        homoGraph,
        tk_db,
        entity,
        walk_length=walk_length,
        round=num_rounds,
        sample_size=sample_size,
        verbose=True,
    )

- Walks for table interventions - shape torch.Size([3462, 8, 11])


                                                     

- Walks for table sponsors - shape torch.Size([5324, 8, 11])


                                                     

- Walks for table eligibilities - shape torch.Size([24973, 8, 11])


                                                       

- Walks for table designs - shape torch.Size([24909, 8, 11])


                                                       

- Walks for table studies - shape torch.Size([24973, 8, 11])


                                                       

- Walks for table conditions - shape torch.Size([3973, 8, 11])


                                                     

In [11]:
# Cross-table edges are only mined between entity tables that are NOT already
# adjacent in the hop matrix, so first enumerate those ordered pairs.
hop_matrix = generate_hop_matrix(db)

# Every ordered (source, target) pair of distinct entity tables where the
# target is absent from the source's entry in hop_matrix.graph.
edge_candidates_pairs = [
    (src_table, dst_table)
    for src_table in entity_tables
    for dst_table in entity_tables
    if src_table != dst_table and dst_table not in hop_matrix.graph[src_table]
]
edge_candidates_pairs

[('interventions', 'sponsors'),
 ('interventions', 'eligibilities'),
 ('interventions', 'designs'),
 ('interventions', 'studies'),
 ('interventions', 'conditions'),
 ('sponsors', 'interventions'),
 ('sponsors', 'eligibilities'),
 ('sponsors', 'designs'),
 ('sponsors', 'studies'),
 ('sponsors', 'conditions'),
 ('eligibilities', 'interventions'),
 ('eligibilities', 'sponsors'),
 ('eligibilities', 'designs'),
 ('eligibilities', 'conditions'),
 ('designs', 'interventions'),
 ('designs', 'sponsors'),
 ('designs', 'eligibilities'),
 ('designs', 'conditions'),
 ('studies', 'interventions'),
 ('studies', 'sponsors'),
 ('studies', 'conditions'),
 ('conditions', 'interventions'),
 ('conditions', 'sponsors'),
 ('conditions', 'eligibilities'),
 ('conditions', 'designs'),
 ('conditions', 'studies')]

In [12]:
# Quick look at the first (source table, target table) candidate pair.
edge_candidates_pairs[0]

('interventions', 'sponsors')

In [13]:
import numpy as np

# For every candidate (source, target) table pair, query the target table's
# BM25 retriever with the source table's sampled documents and keep only the
# hits whose score is unusually high for that pair.
topn = 20  # number of hits requested per query document
edge_dict = {}
# (src_table, des_table) -> edge 2-D array of shape [num_edges, 2]
for entity, retrieve_entity in edge_candidates_pairs:

    # Query side: documents of the source table and their identifiers.
    entity_query_docs = entity_to_docs[entity]
    entity_query_pkys = np.array(entity_candidate_pkys[entity])  # shape [n]
    retriever = entity_to_retriver[retrieve_entity]

    # NOTE(review): the first return value is named `related_pkys`, but it is
    # whatever the retriever returns per hit -- presumably the position of the
    # document in the indexed corpus. Confirm that this equals the target
    # table's primary key before using the edges.
    related_pkys, scores = retriever.retrieve(entity_query_docs, k=topn, n_threads=24)

    score_np = np.array(scores)
    related_pkys_np = np.array(related_pkys)

    # Keep hits scoring more than two standard deviations above the mean of
    # all scores for this table pair. Both statistics are taken from the same
    # converted array so the threshold does not depend on the raw return type.
    threshold = score_np.mean() + 2 * score_np.std()
    mask = score_np > threshold

    # Boolean indexing flattens row by row, so the surviving targets come out
    # grouped by query, in query order.
    filtered_cols = related_pkys_np[mask]

    # Repeat each query identifier once per surviving hit in its row, which
    # lines it up with the flattened targets above.
    row_repeats = mask.sum(axis=1)
    filtered_rows = np.repeat(entity_query_pkys, row_repeats)

    filtered_edge = np.stack([filtered_rows, filtered_cols], axis=1)
    num_edges = filtered_rows.shape[0]
    edge_dict[(entity, retrieve_entity)] = filtered_edge
    print(f"Add cross table edges #{num_edges} between {entity} and {retrieve_entity}")
    

Add cross table edges #2378 between interventions and sponsors
Add cross table edges #2764 between interventions and eligibilities
Add cross table edges #2792 between interventions and designs
Add cross table edges #2810 between interventions and studies
Add cross table edges #2445 between interventions and conditions
Add cross table edges #2736 between sponsors and interventions
Add cross table edges #5050 between sponsors and eligibilities
Add cross table edges #5098 between sponsors and designs
Add cross table edges #5029 between sponsors and studies
Add cross table edges #2947 between sponsors and conditions
Add cross table edges #17213 between eligibilities and interventions


In [26]:
import os

# Persist the cross-table edges mined by the previous cell.
# `edge_dict` is deliberately NOT re-initialised here: resetting it to {} would
# discard the mined edges and write an empty archive.
if not edge_dict:
    raise ValueError("edge_dict is empty; run the cross-table edge cell before saving.")

# (src_table, des_table) -> edge 2-D array.
# np.savez takes the arrays as keyword arguments, so each table pair is
# flattened into a "src-dst" string key.
npz_data = {
    f"{src}-{dst}": edge_array
    for (src, dst), edge_array in edge_dict.items()
}

path = "./edges/rel-trail-edges.npz"
# np.savez does not create missing parent directories.
os.makedirs(os.path.dirname(path), exist_ok=True)
np.savez(path, **npz_data)

Self-entity correlation, which generates the positive pairs used in contrastive learning.

In [10]:
# Self-entity correlation: sample rows from each entity table and generate
# their documents. The next cell queries each table's own retriever with them
# to find positive pairs for contrastive learning.
walk_length = 10
# Kept under the name `num_rounds` so the builtin round() is not shadowed for
# the rest of the kernel session; the helper's keyword is still `round`.
num_rounds = 8
entity_to_docs = {}         # entity table -> generated query documents
entity_candidate_pkys = {}  # entity table -> first return value of the helper
                            # (presumably the sampled rows' primary keys -- confirm)

for entity in entity_tables:
    n = len(db.table_dict[entity].df)
    # Request 20% of the table, but never fewer than 4096 rows. The recorded
    # output shows small tables yielding all of their rows (e.g. 3973 for
    # `conditions`), so the helper evidently copes with sample_size > n.
    sample_size = max(n // 5, 4096)
    pkys, docs = generate_document_given_table(
        homoGraph,
        tk_db,
        entity,
        walk_length=walk_length,
        round=num_rounds,
        sample_size=sample_size,
        verbose=True,
    )
    entity_candidate_pkys[entity] = pkys
    entity_to_docs[entity] = docs

- Walks for table interventions - shape torch.Size([3462, 8, 11])


                                                     

- Walks for table sponsors - shape torch.Size([10648, 8, 11])


                                                       

- Walks for table eligibilities - shape torch.Size([49946, 8, 11])


                                                       

- Walks for table designs - shape torch.Size([49818, 8, 11])


                                                       

- Walks for table studies - shape torch.Size([49946, 8, 11])


                                                       

- Walks for table conditions - shape torch.Size([3973, 8, 11])


                                                     

In [11]:
# Mine positive pairs inside each entity table: query the table's own BM25
# retriever with the sampled documents and keep the hits that score close to
# the best hit of the same query.
topn = 21         # one more than the 20 wanted neighbours: the best hit is expected to be the query itself
threshold = 0.7   # a hit is kept when its score exceeds this fraction of the row's best score
batch_size = 1024
# entity -> positive candidates, with filtered-out slots padded with -1
positive_pool_dict = {}

for entity, retriever in entity_to_retriver.items():
    entity_query_docs = entity_to_docs[entity]
    entity_query_pkys = entity_candidate_pkys[entity]
    print(f"--------> {entity}")

    # Retrieve in fixed-size batches and collect the per-batch arrays.
    score_chunks = []
    pky_chunks = []
    for batch_idx in tqdm(range(0, len(entity_query_docs), batch_size)):
        batch_query_docs = entity_query_docs[batch_idx:batch_idx + batch_size]
        related_pkys, scores = retriever.retrieve(batch_query_docs, k=topn, n_threads=-1)
        score_chunks.append(np.array(scores))
        pky_chunks.append(np.array(related_pkys))

    score_np = np.concatenate(score_chunks, axis=0)
    related_pkys_np = np.concatenate(pky_chunks, axis=0)

    # Compare every hit with the first (best) hit of its own row; the column
    # slice keeps a trailing axis so the comparison broadcasts per row.
    mask = score_np > threshold * score_np[:, :1]

    # Pad the rejected slots in place.
    related_pkys_np[~mask] = -1

    # A row is useful only if something besides the best hit survived.
    has_extra_hit = mask.sum(axis=1) > 1
    positive_pool = related_pkys_np[has_extra_hit]

    positive_pool_dict[entity] = positive_pool
    print(f"Generate positive pools #{len(positive_pool)}, original candidate {len(entity_query_docs)} in {entity} table")

--------> interventions


100%|██████████| 4/4 [00:06<00:00,  1.67s/it]


Generate positive pools #787, original candidate 3462 in interventions table
--------> sponsors


100%|██████████| 11/11 [00:06<00:00,  1.63it/s]


Generate positive pools #3942, original candidate 10648 in sponsors table
--------> eligibilities


100%|██████████| 49/49 [07:29<00:00,  9.17s/it]


Generate positive pools #2491, original candidate 49946 in eligibilities table
--------> designs


100%|██████████| 49/49 [06:36<00:00,  8.08s/it]


Generate positive pools #2980, original candidate 49818 in designs table
--------> studies


100%|██████████| 49/49 [06:54<00:00,  8.46s/it]


Generate positive pools #6019, original candidate 49946 in studies table
--------> conditions


100%|██████████| 4/4 [00:01<00:00,  3.80it/s]

Generate positive pools #1299, original candidate 3973 in conditions table





In [None]:
import os

# Save the positive pools to their own archive. The previous code reused the
# leftover `path` variable, which still pointed at ./edges/rel-trail-edges.npz
# and would have overwritten the cross-table edges.
positive_pool_path = "./edges/rel-trail-positive-pools.npz"
# np.savez does not create missing parent directories.
os.makedirs(os.path.dirname(positive_pool_path), exist_ok=True)
# One array per entity table, keyed by the table name.
np.savez(positive_pool_path, **positive_pool_dict)