In [1]:
%cd ..
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder

import os

# Disable tokenizers parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

from relbench.datasets import get_dataset
from relbench.tasks import get_task
from relbench.base import BaseTask
from torch_geometric.seed import seed_everything
from relbench.modeling.utils import get_stype_proposal

import os
import math
import numpy as np
from tqdm import tqdm
import copy

import torch






seed_everything(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [3]:
dataset = get_dataset(name="rel-trial", download=True)
db = dataset.get_db()
task_a = get_task("rel-trial", "study-outcome", download = True)
task_b = get_task("rel-trial", "site-success", download = True)

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 7.80 seconds.


In [4]:
import networkx as nx
import numpy as np
import pandas as pd
import torch

from relbench.base import Database
from torch_frame.config import TextEmbedderConfig
from torch_frame import stype

from typing import Dict, Optional, Tuple

In [5]:
gid_to_entity, entity_to_gid = {}, {}
# node_id -> (table_name, pkey_index)
# table_name -> { pkey_index -> node_id}

# initialize node
for table_name, table in db.table_dict.items():
    df = table.df
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()

    gid_init = len(gid_to_entity)
    gid_end = gid_init + len(df)
    gids = np.arange(gid_init, gid_end)
    pkey_idxs = np.arange(len(df))

    tmp_id = [(table_name, i) for i  in range(len(df))]
    gid_to_entity.update(dict(zip(gids, tmp_id)))
    entity_to_gid[table_name] = dict(zip(pkey_idxs, gids))

# initialize edge
row, col = [], []
# source node, target node
for table_name, table in db.table_dict.items():
    df = table.df
    for fkey_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
        pkey_index = df[fkey_name]
        # Filter out dangling foreign keys (missing value)
        mask = ~pkey_index.isna()
        fkey_index = pd.Series(np.arange(len(pkey_index)))

        # Filter missing value
        pkey_index = pkey_index[mask].astype(int)
        fkey_index = fkey_index[mask]
        
        # convert to node id
        pkey_gid = pkey_index.map(entity_to_gid[pkey_table_name]).values
        fkey_gid = fkey_index.map(entity_to_gid[table_name]).values

        pkey_gid = torch.LongTensor(pkey_gid)
        fkey_gid = torch.LongTensor(fkey_gid)

        # fkey -> pkey edges
        row.append(fkey_gid)
        col.append(pkey_gid)

        # pkey -> fkey edges
        row.append(pkey_gid)
        col.append(fkey_gid)

row = torch.cat(row, dim=0)
col = torch.cat(col, dim=0)

In [6]:
from torch_cluster import random_walk

In [7]:
# find the target node
node_gid = [ v for k,v in entity_to_gid[task_a.entity_table].items()]
node_gid = sorted(node_gid)
target_node = torch.LongTensor(node_gid)

In [8]:
walk_length = len(db.table_dict)
round = 10
walks = []
for _ in range(round):
    walk = random_walk(row, col, target_node, walk_length, p=10, q=0.5)
    # walk -> [num_target_node, walk_length]
    walk = walk.unsqueeze(1)
    # walk -> [num_target_node, 1, walk_length]
    walks.append(walk)

In [9]:
node_bags = torch.concatenate(walks, dim = 1)
# node_bags -> [num_target_node, round, walk_length]
node_bags.shape

torch.Size([249730, 10, 16])

In [7]:
# prepare the data for BM25 relevance score

# remove the foreignKey columns in db.df table, since the graph is already constructed.
# the foreignKey column is duplicated data
for table_name in db.table_dict.keys():
    for fk_col,  _ in db.table_dict[table_name].fkey_col_to_pkey_table.items():
        db.table_dict[table_name].df.drop(fk_col, axis=1, inplace=True)

In [8]:
# convert the text and numerical data to categroical data through cluster or bin
col_to_stype_dict = get_stype_proposal(db)


# rule 0
# based on the column name
# we predefined some
numerical_keywords = [
    'count', 'num', 'amount', 'total', 'length', 'height', 'value', 'rate',  'number',
    'score', 'level', 'size', 'price', 'percent', 'ratio', 'volume', 'index', 'avg', 'max', 'min'
]
categorical_keywords = [
    'type', 'category', 'class', 'label', 'status', 'code', 'id',
    'region', 'zone', 'flag', 'is_', 'has_', 'mode', 'city', 'state', 'zip'
]

text_keywords = [
    'description', 'comments', 'content', 'name', 'review', 'message', 'note', 'query', 'summary'
]


# rule 1
# unique_value < 0.02 * total_value -> categorical data
# rule 1, general rule for text and numerical data

for table_name, table in db.table_dict.items():
    df = table.df

    for col_name in df.columns:
        if col_name not in col_to_stype_dict[table_name]:
            continue
        guess_type = col_to_stype_dict[table_name][col_name]

        # rule 0
        if any([kw in col_name.lower() for kw in text_keywords]):
            if guess_type == stype.text_embedded:
                continue

        if any([kw in col_name.lower() for kw in numerical_keywords]):
            # check the data can be converted to numerical data
            is_convertible = (
                pd.to_numeric(df[col_name], errors='coerce').notna()
                + df[col_name].isna()).all()
            
            if is_convertible:
                if guess_type != stype.numerical:
                    print(
                        f"[Rule 0] Convert {table_name}.{col_name} from {guess_type} to numerical data")
                col_to_stype_dict[table_name][col_name] = stype.numerical
                continue

        unique_value = len(df[col_name].unique())
        count_value = (~df[col_name].isna()).sum()

        if any([kw in col_name.lower() for kw in categorical_keywords]):
            if guess_type != stype.categorical:
                # print the unique value and count value for check
                print(
                    f"[Rule 0] Convert {table_name}.{col_name} from {guess_type} to categorical data")
                print(
                    f"Unique value: {unique_value}, Count value: {count_value}")

            col_to_stype_dict[table_name][col_name] = stype.categorical
            continue

        # rule 1
        if guess_type == stype.categorical or guess_type == stype.timestamp:
            continue
        # check whether can convert to numerical
        is_convertible = (
            pd.to_numeric(df[col_name], errors='coerce').notna()
            + df[col_name].isna()).all()
        
        if is_convertible and guess_type == stype.numerical:
            continue

        # for  type  numerical or text_embedding check Rule 1
        if unique_value*1.0 / count_value < 0.02:
            # minimum average frequency is 50.
            col_to_stype_dict[table_name][col_name] = stype.categorical
            print(
                f"[Rule 1] Convert {table_name}.{col_name} from {guess_type} to categorical data")
            print(f"Unique value: {unique_value}, Count value: {count_value}")

[Rule 0] Convert interventions.intervention_id from numerical to categorical data
Unique value: 3462, Count value: 3462
[Rule 0] Convert interventions_studies.id from numerical to categorical data
Unique value: 171771, Count value: 171771
[Rule 0] Convert facilities_studies.id from numerical to categorical data
Unique value: 1798765, Count value: 1798765
[Rule 0] Convert sponsors.sponsor_id from numerical to categorical data
Unique value: 53241, Count value: 53241
[Rule 0] Convert sponsors.agency_class from text_embedded to categorical data
Unique value: 9, Count value: 53241
[Rule 0] Convert eligibilities.id from numerical to categorical data
Unique value: 249730, Count value: 249730
[Rule 1] Convert eligibilities.minimum_age from text_embedded to categorical data
Unique value: 289, Count value: 234048
[Rule 1] Convert eligibilities.maximum_age from text_embedded to categorical data
Unique value: 435, Count value: 132548
[Rule 0] Convert reported_event_totals.id from numerical to cate

In [9]:
# for numerical data, we try to cluster and bin it convert it to categorical data
# for text data, we encoded it using pre-trained model and cluster them to categorical data

table_data = {}
# table_name -> pd.df
for table_name, table in db.table_dict.items():
    df = table.df.copy()
    table_data[table_name] = df
    for col_name in df.columns:
        if col_name not in col_to_stype_dict[table_name]:
            continue

        dtype = col_to_stype_dict[table_name][col_name]
        if dtype == stype.numerical:
            # bin
            # we assign equal-width bin, and dynamically adjust the bin size

            # step 1. dynamically adjust the bin number
            n = (~df[col_name].isna()).sum()
            binned = pd.Series(index=df.index, dtype='object')

            if n > 1_000:
                # Rice rule
                bin_num = math.ceil(2 * n**(1/3))
            else:
                # n < 1000,
                # Sturges' formula
                bin_num = math.ceil(math.log2(n) + 1)

            # Step 2. bin the outlier first
            q1 = df[col_name].quantile(0.1)
            q2 = df[col_name].quantile(0.9)
            upper_bound = q2 + 1.5 * (q2 - q1)
            lower_bound = q1 - 1.5 * (q2 - q1)

            upper_outlier_mask = df[col_name] > upper_bound
            lower_outlier_mask = df[col_name] < lower_bound
            normal_mask = ~upper_outlier_mask & ~lower_outlier_mask
            normal_mask = normal_mask & df[col_name].notna()

            upper_outlier_label = f"larger than {upper_bound}"
            lower_outlier_label = f"smaller than {lower_bound}"

            binned[upper_outlier_mask] = upper_outlier_label
            binned[lower_outlier_mask] = lower_outlier_label

            # Step 3. bin the normal data

            # first get the bin labeled
            _, bin_edges = pd.cut(df[col_name][normal_mask], bins=bin_num,
                                  labels=False, retbins=True, include_lowest=True)

            bin_labels = [
                f"{bin_edges[i]:.0f}-{bin_edges[i+1]:.0f}" if bin_edges[i].is_integer() and bin_edges[i+1].is_integer()
                else f"{bin_edges[i]:.2f}-{bin_edges[i+1]:.2f}"
                for i in range(len(bin_edges) - 1)
            ]
            
            binned[normal_mask] = pd.cut(df[col_name][normal_mask], bins=bin_edges, labels=bin_labels)
            
            print(f"Bin {table_name}.{col_name} to {bin_num} bins, convert numerical data to categorical data")
            df[col_name] = binned
            
        if dtype == stype.text_embedded:
            # encode the text in next step
            pass

Bin reported_event_totals.subjects_affected to 137 bins, convert numerical data to categorical data
Bin reported_event_totals.subjects_at_risk to 136 bins, convert numerical data to categorical data
Bin drop_withdrawals.count to 146 bins, convert numerical data to categorical data
Bin studies.enrollment to 126 bins, convert numerical data to categorical data
Bin studies.number_of_arms to 115 bins, convert numerical data to categorical data
Bin studies.number_of_groups to 65 bins, convert numerical data to categorical data
Bin outcome_analyses.param_value to 107 bins, convert numerical data to categorical data
Bin outcome_analyses.dispersion_value to 70 bins, convert numerical data to categorical data
Bin outcome_analyses.p_value to 115 bins, convert numerical data to categorical data
Bin outcome_analyses.ci_percent to 109 bins, convert numerical data to categorical data
Bin outcome_analyses.ci_lower_limit to 105 bins, convert numerical data to categorical data
Bin outcome_analyses.ci_u

In [13]:
# for text type data, basically, text type data is more diverse than numerical data and categorical data
# there are several steps to process:
# --- 1. using pre-trained model to encode the text data.
# --- 2. generally the text embedding is high-dimensional, we reduce the dimensionality in each attribute.
# --- 3. we cluster these low-dimensional vector to bins
# --- 4. we convert the cluster result to categorical data

# <Fail>
# from sentence_transformers import SentenceTransformer
# import umap
# import hdbscan

# encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# for table_name, df in table_data.items():
#     for col_name in df.columns:
#         if col_name not in col_to_stype_dict[table_name]:
#             continue
#         dtype = col_to_stype_dict[table_name][col_name]
#         if dtype != stype.text_embedded:
#             continue
        
#         # preprocess text embedding
#         series = df[col_name]
#         sentences = series[series.notna()].tolist()
        
#         sentence_set = list(set(sentences))
#         embedding_set = encoder.encode(sentence_set)
#         # reduce the dimension from 384 to 10 using UMAP
#         std_embedding_set = umap.UMAP(n_components=10, random_state = 2025, n_jobs = 1).fit_transform(embedding_set)
        
#         # using HDBSCAN to cluster the data
#         labels = hdbscan.HDBSCAN(
#             min_samples = 10,
#             min_cluster_size = 30,
#         ).fit_predict(std_embedding_set)
        
#         unique_elements, counts = np.unique(labels, return_counts = True)
       
#         # convert outlier to outlier label
#         outlier_mask = labels == -1
#         outlier_num = outlier_mask.sum()
#         print(outlier_num)

#         labels[outlier_mask] = list(range(len(unique_elements) - 1, len(unique_elements) - 1 + outlier_num))
#         unique_num = len(unique_elements) - 1 + outlier_num
        
#         print(f"Cluster {table_name}.{col_name} to {unique_num} bins, contains {outlier_num} outlier, Text Embedding to Categorical data")
        
#         # convert the label to categorical data
#         sentence_to_label = {st: int(y) for (st, y) in zip(sentence_set, labels)}
#         # series[series.notna()] = series[series.notna()].map(sentence_to_label)

In [14]:
for table_name, col_types in col_to_stype_dict.items():
    for col_name, col_type in col_types.items():
        if col_type == stype.text_embedded:
            print(f"{table_name}.{col_name} is text_embedded")

interventions.mesh_term is text_embedded
sponsors.name is text_embedded
eligibilities.population is text_embedded
eligibilities.criteria is text_embedded
eligibilities.gender_description is text_embedded
designs.masking_description is text_embedded
designs.intervention_model_description is text_embedded
drop_withdrawals.reason is text_embedded
studies.target_duration is text_embedded
studies.acronym is text_embedded
studies.baseline_population is text_embedded
studies.brief_title is text_embedded
studies.official_title is text_embedded
studies.source is text_embedded
studies.biospec_description is text_embedded
studies.detailed_descriptions is text_embedded
studies.brief_summaries is text_embedded
outcome_analyses.non_inferiority_description is text_embedded
outcome_analyses.p_value_description is text_embedded
outcome_analyses.method_description is text_embedded
outcome_analyses.estimate_description is text_embedded
outcome_analyses.groups_description is text_embedded
outcome_analyses

In [None]:
# for text type data, basically, text type data is more diverse than numerical data and categorical data
# there are several steps to process:

# --- 1. using pre-trained model to extract the keywords from text.
# --- 2. these keywords are concatenated with the column name to form the token in "doc"

# in KeyBERT model, we set the setting to default, only extract one-gram keywors for simplicity.

# why just extract keywords from this keywords
# -- reduces noise by focusing only on meaningful and frequent concepts.
# -- creates a more compact document with relevant terms.
from keybert import KeyBERT
kw_model = KeyBERT()
text_dict_in_col_table = {}
for table_name, df in table_data.items():
    for col_name in df.columns:
        if col_name not in col_to_stype_dict[table_name]:
            continue

        dtype = col_to_stype_dict[table_name][col_name]
        if dtype != stype.text_embedded:
            continue

        # only text embedding data
        series = df[col_name]
        sentences =  series[series.notna()].tolist()
        sentence_set = list(set(sentences))
        print(f"=> Table @{table_name} and col @{col_name} has {len(series)} records, has {len(sentence_set)} unique sentences")
        
        # convert the sentences to keyword set
        # construct a dict {sentence : keyword set}
        batch_size = 2048
        sentence_to_kws = {}
        for i in range(0, len(sentence_set), batch_size):
            batch = sentence_set[i:i+batch_size]
            keywords = kw_model.extract_keywords(batch)
            sentence_to_kws.update(dict(zip(batch, keywords)))

        sentence_to_kws = {st: [kw for kw, _ in kws] for st, kws in sentence_to_kws.items()}
        # remove relevance score
        
        print(f"Convert the [{table_name}.{col_name}] from text to keywords set")
        # assign keywords_sets to the dataframe
        # df[mask][col_name] = keywords_series
        text_dict_in_col_table[(table_name, col_name)] = sentence_to_kws
        # each keyword is like this [('keyword1', 0.9), ('keyword2', 0.8), ('keyword3', 0.7)]

In [16]:
# [IO READ]
# read the text_dict
import pickle
tmp_text_dict_file_path = './tmp/text_dict.pkl'
with open(tmp_text_dict_file_path, "rb") as f:
    text_dict_in_col_table = pickle.load(f)

In [None]:
# [IO WRIRTE]
# temporarily save the text_dict_in_col_table
import pickle
tmp_text_dict_file_path = './tmp/text_dict.pkl'
with open(tmp_text_dict_file_path, 'wb') as f:
    pickle.dump(text_dict_in_col_table, f)
file_size = os.path.getsize(tmp_text_dict_file_path)
print(f"Save the text_dict_in_col_table to {tmp_text_dict_file_path}, size: {file_size} bytes")

In [17]:
# collect the walk related node and convert it to "Doc"
# the "word" should be table_name + col_name + value
# if the column is type, the "word" should be table_name + col_name + keyword

In [18]:
# first preprocess
# directly process a gid -> doc Dict
from functools import reduce
from tqdm import tqdm

table_idx_to_doc = {}

for table_name, df in tqdm(table_data.items()):
    pkey_to_gid = entity_to_gid[table_name]
    
    tmp_df = pd.DataFrame()
    
    for col_name in df.columns:
        if col_name not in col_to_stype_dict[table_name]:
            continue
        
        col_type = col_to_stype_dict[table_name][col_name]
        
        if col_type is stype.numerical or col_type is stype.categorical or col_type is stype.timestamp:
            series = df[col_name].apply(lambda x: [f"{table_name}.{col_name}.{x}"] if x is not None else [])
            
        elif col_type is stype.text_embedded:
            sentence_to_kws = text_dict_in_col_table[(table_name, col_name)]
            sentence_to_words = {st: [f"{table_name}.{col_name}.{kw}" for kw in kws] for st, kws in sentence_to_kws.items()}
            series = df[col_name].apply(lambda x: sentence_to_words[x] if x is not None else [])

        tmp_df[col_name] = series
        
    # reduce each row
    idx_to_keywords = tmp_df.apply(lambda x: reduce(lambda a, b: a + b, x.tolist()), axis=1).tolist()
    table_idx_to_doc[table_name] = idx_to_keywords

100%|██████████| 15/15 [01:36<00:00,  6.43s/it]


In [None]:
# [IO]
# temporarily save the gid_to_doc and entity_to_gid


tmp_gid_to_doc_file_path = './tmp/table_idx_to_doc.pkl'
with open(tmp_gid_to_doc_file_path, 'wb') as f:
    pickle.dump(table_idx_to_doc, f)
# table-> List[List[kw]], the idx is pkey 
file_size = os.path.getsize(tmp_gid_to_doc_file_path)
print(f"Save the gid_to_doc to {tmp_gid_to_doc_file_path}, size: {file_size} bytes")

tmp_entity_to_gid_file_path = './tmp/entity_to_gid.pkl'
with open(tmp_entity_to_gid_file_path, 'wb') as f:
    pickle.dump(entity_to_gid, f)
file_size = os.path.getsize(tmp_entity_to_gid_file_path)
print(f"Save the entity_to_gid to {tmp_entity_to_gid_file_path}, size: {file_size} bytes")


tmp_git_to_entity_file_path = './tmp/gid_to_entity.pkl'
with open(tmp_git_to_entity_file_path, 'wb') as f:
    pickle.dump(gid_to_entity, f)
# gid -> (table_name, pkey_index)
file_size = os.path.getsize(tmp_git_to_entity_file_path)
print(f"Save the gid_to_entity to {tmp_git_to_entity_file_path}, size: {file_size} bytes")

In [19]:
# based on walk and gid_to_doc, to generate these docs
from functools import reduce
docs = []
for idx in tqdm(range(node_bags.shape[0])):
    walks = node_bags[idx].tolist() # [round, walk_length]
    # for each walks, we remove repeated node to construct subgraph.
    # make sure in each round of sample (RandomWalk), the node is unique
    doc = reduce(lambda a, b: list(set(a)) + list(set(b)), walks)
    doc = map(lambda x: table_idx_to_doc[gid_to_entity[x][0]][gid_to_entity[x][1]], doc)
    doc = reduce(lambda a, b: a + b, doc)
    docs.append(doc)

100%|██████████| 249730/249730 [01:06<00:00, 3764.37it/s]


In [21]:
# [IO WRITE]
# tmporarily save these docs
# the target node is table in task_a
entity = task_a.entity_table
file_name = f"{entity}_docs.pkl"
file_path = f"./tmp/{file_name}"
with open(file_path, 'wb') as f:
    pickle.dump(docs, f)
file_size = os.path.getsize(file_path)
print(f"Save the {entity} docs to {file_path}, size: {file_size} bytes")

Save the studies docs to ./tmp/studies_docs.pkl, size: 1666987496 bytes
