### Construct self-supervised learning process

In [1]:
%cd ..
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from model.utils import InfoNCE
from relbench.modeling.nn import HeteroTemporalEncoder
from relbench.datasets import get_dataset
from relbench.tasks import get_task
from relbench.base import BaseTask
from torch_geometric.seed import seed_everything
from relbench.modeling.utils import get_stype_proposal
import os
import math
import numpy as np
from tqdm import tqdm
import copy

import torch

seed_everything(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [2]:
dataset = get_dataset(name="rel-trial", download = True)
db = dataset.get_db()
task_a = get_task("rel-trial", "study-outcome", download = True)

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 7.74 seconds.


In [3]:
from torch_frame import stype
from torch import Tensor
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
from typing import List, Optional
import pandas as pd

In [4]:
col_to_stype_dict = get_stype_proposal(db)

class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)

  return self.fget.__get__(instance, owner)()


In [5]:
# preprocess col_to_stype_dict using our own rules

# rule 0
# based on the column name
# we predefined some
numerical_keywords = [
    'count', 'num', 'amount', 'total', 'length', 'height', 'value', 'rate',  'number',
    'score', 'level', 'size', 'price', 'percent', 'ratio', 'volume', 'index', 'avg', 'max', 'min'
]
categorical_keywords = [
    'type', 'category', 'class', 'label', 'status', 'code', 'id',
    'region', 'zone', 'flag', 'is_', 'has_', 'mode', 'city', 'state', 'zip'
]

text_keywords = [
    'description', 'comments', 'content', 'name', 'review', 'message', 'note', 'query', 'summary'
]


# rule 1
# unique_value < 0.02 * total_value -> categorical data
# rule 1, general rule for text and numerical data

for table_name, table in db.table_dict.items():
    df = table.df

    for col_name in df.columns:
        if col_name not in col_to_stype_dict[table_name]:
            continue
        guess_type = col_to_stype_dict[table_name][col_name]

        # rule 0
        if any([kw in col_name.lower() for kw in text_keywords]):
            if guess_type == stype.text_embedded:
                continue

        if any([kw in col_name.lower() for kw in numerical_keywords]):
            # check the data can be converted to numerical data
            is_convertible = (
                pd.to_numeric(df[col_name], errors='coerce').notna()
                + df[col_name].isna()).all()
            
            if is_convertible:
                if guess_type != stype.numerical:
                    print(
                        f"[Rule 0] Convert {table_name}.{col_name} from {guess_type} to numerical data")
                col_to_stype_dict[table_name][col_name] = stype.numerical
                continue

        unique_value = len(df[col_name].unique())
        count_value = (~df[col_name].isna()).sum()

        if any([kw in col_name.lower() for kw in categorical_keywords]):
            if guess_type != stype.categorical:
                # print the unique value and count value for check
                print(
                    f"[Rule 0] Convert {table_name}.{col_name} from {guess_type} to categorical data")
                print(
                    f"Unique value: {unique_value}, Count value: {count_value}")

            col_to_stype_dict[table_name][col_name] = stype.categorical
            continue

        # rule 1
        if guess_type == stype.categorical or guess_type == stype.timestamp:
            continue
        # check whether can convert to numerical
        is_convertible = (
            pd.to_numeric(df[col_name], errors='coerce').notna()
            + df[col_name].isna()).all()
        
        if is_convertible and guess_type == stype.numerical:
            continue

        # for  type  numerical or text_embedding check Rule 1
        if unique_value*1.0 / count_value < 0.02:
            # minimum average frequency is 50.
            col_to_stype_dict[table_name][col_name] = stype.categorical
            print(
                f"[Rule 1] Convert {table_name}.{col_name} from {guess_type} to categorical data")
            print(f"Unique value: {unique_value}, Count value: {count_value}")

[Rule 0] Convert interventions.intervention_id from numerical to categorical data
Unique value: 3462, Count value: 3462
[Rule 0] Convert interventions_studies.id from numerical to categorical data
Unique value: 171771, Count value: 171771
[Rule 0] Convert interventions_studies.nct_id from numerical to categorical data
Unique value: 90364, Count value: 171771
[Rule 0] Convert interventions_studies.intervention_id from numerical to categorical data
Unique value: 3432, Count value: 171771
[Rule 0] Convert facilities_studies.id from numerical to categorical data
Unique value: 1798765, Count value: 1798765
[Rule 0] Convert facilities_studies.nct_id from numerical to categorical data
Unique value: 227838, Count value: 1798765
[Rule 0] Convert facilities_studies.facility_id from numerical to categorical data
Unique value: 431513, Count value: 1798765
[Rule 0] Convert sponsors.sponsor_id from numerical to categorical data
Unique value: 53241, Count value: 53241
[Rule 0] Convert sponsors.agency

In [7]:
# build heterogeneous graph.
# remove the primary key
from relbench.modeling.graph import make_pkey_fkey_graph
root_dir = "/home/lingze/embedding_fusion/data"
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # speficied column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, f"rel-trial_materialized_cache"
    ),  # store materialized graph for convenience
)

In [8]:
from torch_geometric.loader import NeighborLoader

In [9]:
# construct input node loader
BATCH_SIZE = 512
def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds)."""
    assert ser.dtype in [np.dtype("datetime64[s]"), np.dtype("datetime64[ns]")]
    unix_time = ser.astype("int64").values
    if ser.dtype == np.dtype("datetime64[ns]"):
        unix_time //= 10**9
    return unix_time

def generate_loader(entity_table: str, node_idxs: Tensor):
    nodes = (entity_table, node_idxs)
    n = node_idxs.size(0)
    # set the time to the db max timestamp
    time_series = pd.Series([db.max_timestamp] * n)
    input_time = torch.from_numpy(to_unix_time(time_series))
    # no temporal information, we use max_timestamp in training
    loader = NeighborLoader(
        data,
        num_neighbors= [
            128 for i in range(2)
        ], # number of neighbors to sample for each layer
        input_nodes = nodes,
        time_attr = "time",
        input_time = input_time,
        batch_size = BATCH_SIZE,
        shuffle=True,
        num_workers = 0,
    )
    return loader    

In [10]:
entity_table = task_a.entity_table
entity_table

'studies'

In [11]:
# construct bottom model
channels = 64
temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
    channels=channels,
)

feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)

node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2
)


net = CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.2,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder
)

In [12]:
# load the pre-calculated positive sample index
path = './tmp/studies_positive_samples.npy'
positive_sample_index = np.load(path)
positive_sample_index.shape

(249730, 100)

In [13]:
node_idx = torch.from_numpy(db.table_dict[entity_table].df.index.astype(int).values)
ssl_loader = generate_loader(entity_table, node_idx)

In [17]:
lr = 0.001
negative_num = 20
temprature = 0.01
net.reset_parameters()
optimizer = torch.optim.Adam(net.parameters(), lr = 0.001)
epoches = 20
early_restart_steps = 20
loss_fn = InfoNCE(temperature=temprature, negative_mode='paired')

In [19]:
net.to(device)
for epoch in range(1, epoches + 1):
    net.train()
    loss_accum = count_accum = 0
    
    for cnt, batch in enumerate(ssl_loader):
        # early restart new epoch    
        if cnt > early_restart_steps:
            break
        
        anchor_nodes = batch[entity_table].n_id.numpy()
        B = anchor_nodes.shape[0]
        # dynamically to sample positive and negative samples
        positive_sample_nodes = []
        negative_sample_indexs = []
        for anchor_idx, anchor_node in enumerate(anchor_nodes):
            # positive sample, just from the pre-calculated positive sample pool
            positive_sample_pool = positive_sample_index[anchor_node]
            positive_sample_nodes.append(
                np.random.choice(positive_sample_pool)
            )
            
            # negative sample, we choose negative samples within the batch
            # we sample the negative node index in the batch.
            negative_sample_index_pool = np.arange(B)
            negative_sample_index_pool = negative_sample_index_pool[negative_sample_index_pool != anchor_idx]
            negative_sample_index = np.random.choice(negative_sample_index_pool, negative_num, replace=False)
            negative_samples = anchor_nodes[negative_sample_index]
            
            loop_n = 0
            while set(negative_samples.tolist()) & set(positive_sample_pool.tolist()) and loop_n < 5: 
                negative_sample_index = np.random.choice(negative_sample_index_pool, negative_num, replace=False)
                negative_samples = anchor_nodes[negative_sample_index]
                loop_n += 1 # avoid infinite loop
            
            negative_sample_indexs.append(negative_sample_index)
        
        # generate positive samples neighbors
        time_series = pd.Series([db.max_timestamp] * B)
        input_time = torch.from_numpy(to_unix_time(time_series))
        positive_loader = NeighborLoader(
            data,
            num_neighbors = [128 for i in range(2)],
            input_time= input_time,
            time_attr= "time",
            input_nodes = (entity_table, torch.from_numpy(np.array(positive_sample_nodes)).to(torch.long)),
            batch_size = BATCH_SIZE,
            shuffle = False,
            num_workers = 0,
        )
        positive_batch = next(iter(positive_loader))
        # [B,]
        negative_sample_indexs = torch.from_numpy(np.array(negative_sample_indexs)).to(torch.long)
        # [B, neg_num]
        
        # calculate the loss
        
        optimizer.zero_grad()
        batch, positive_batch = batch.to(device), positive_batch.to(device)
        negative_sample_indexs = negative_sample_indexs.to(device)
        
        anchor_nodes_embedding = net.get_node_embedding(batch, entity_table)[entity_table]
        positive_nodes_embedding = net.get_node_embedding(positive_batch, entity_table)[entity_table]
        negative_nodes_embedding = anchor_nodes_embedding[negative_sample_indexs]
        
        loss = loss_fn(anchor_nodes_embedding, positive_nodes_embedding, negative_nodes_embedding)
        loss.backward()
        optimizer.step()
        
        loss_accum += loss.detach().item() 
        count_accum += B
    train_loss = loss_accum / count_accum
    print(f"Epoch {epoch}, Loss: {train_loss}")

Epoch 1, Loss: 0.021847360279588474
Epoch 2, Loss: 0.018506837920064016
Epoch 3, Loss: 0.018144231910506885
Epoch 4, Loss: 0.01786864921450615
Epoch 5, Loss: 0.01788589047888915
