In [1]:
import sys
sys.path.append('../../../../')
from aiagents4pharma.talk2knowledgegraphs.datasets.biobridge_datamodule import BioBridgeDataModule

  from .autonotebook import tqdm as notebook_tqdm
INFO:aiagents4pharma.talk2scholars.tools.pdf.question_and_answer:Loaded Question and Answer tool configuration.


In [2]:
from aiagents4pharma.talk2knowledgegraphs.models.nbfnet import tasks, util
import torch
import pprint
from torch_geometric.data import Data

In [3]:
# Build the BioBridge-PrimeKG data module, load the raw files and create the
# train/val/test split (the log below shows the individual stages).
configs = dict(
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
    num_workers=0,
    batch_size=32,
    random_state=0,  # presumably the seed of the split -- confirm in BioBridgeDataModule
)
biobridge_dataset = BioBridgeDataModule(
    primekg_dir="../../../../../data/primekg",
    biobridge_dir="../../../../../data/biobridge_primekg",
    configs=configs,
)
biobridge_dataset.prepare_data()
biobridge_dataset.setup()

Loading PrimeKG dataset...
Loading nodes of PrimeKG dataset ...
../../../../../data/primekg/primekg_nodes.tsv.gz already exists. Loading the data from the local directory.
Loading edges of PrimeKG dataset ...
../../../../../data/primekg/primekg_edges.tsv.gz already exists. Loading the data from the local directory.
Loading data config file of BioBridgePrimeKG...
File data_config.json already exists in ../../../../../data/biobridge_primekg.
Building node embeddings...
Building full triplets...
Building train-test split...


In [4]:
# Pull the full dataset object and its three splits out of the data module.
# NOTE(review): a later cell rebinds train_data/test_data to dataset[0..2] and names
# the validation split `valid_data`, so these bindings (and `val_data`) look unused
# downstream -- confirm which source of splits is intended.
dataset = biobridge_dataset.data["set"]
train_data, val_data, test_data = (
    biobridge_dataset.data[split] for split in ("train", "val", "test")
)

In [5]:
# Model/training configuration for NBFNet on BioBridge-PrimeKG.
args_config = "../../../../aiagents4pharma/talk2knowledgegraphs/models/config/transductive/biobridge.yaml"
args_seed = 1024
# Template context for util.load_config. "null" is rendered into the YAML, where it
# parses as None (cfg.train.gpus). Named `config_context` rather than `vars` so the
# built-in vars() is not shadowed for the rest of the notebook.
config_context = {"gpus": "null"}

cfg = util.load_config(args_config, context=config_context)
working_dir = util.create_working_directory(cfg)

In [6]:
# Offset the base seed by the process rank so that distributed ranks draw different
# random samples while a fixed args_seed keeps the run reproducible.
torch.manual_seed(seed=util.get_rank() + args_seed)

<torch._C.Generator at 0x7bd05e30f370>

In [7]:
logger = util.get_root_logger()
# Report the run setup from the main process only, so a multi-process run does not
# print it once per rank. Arguments are passed to the logger for lazy formatting.
if util.get_rank() == 0:
    logger.warning("Random seed: %d", args_seed)
    logger.warning("Config file: %s", args_config)
    logger.warning(pprint.pformat(cfg))

             'root': '~/datasets/knowledge_graphs/'},
 'model': {'aggregate_func': 'pna',
           'class': 'NBFNet',
           'dependent': True,
           'hidden_dims': [32, 32, 32, 32, 32, 32],
           'input_dim': 32,
           'layer_norm': True,
           'message_func': 'distmult',
           'remove_one_hop': True,
           'short_cut': True},
 'optimizer': {'class': 'Adam', 'lr': 0.005},
 'output_dir': '~/experiments/',
 'task': {'adversarial_temperature': 0.5,
          'metric': ['mr', 'mrr', 'hits@1', 'hits@3', 'hits@10'],
          'num_negative': 32,
          'strict_negative': True},
 'train': {'batch_size': 8, 'gpus': None, 'log_interval': 100, 'num_epoch': 1}}


In [8]:
# Decide transductive vs. inductive evaluation from the configured dataset class
# name: inductive dataset classes are recognised by an "Ind" prefix.
is_inductive = cfg.dataset["class"].startswith("Ind")
# The graph comes from BioBridgeDataModule (`dataset`) instead of
# util.build_dataset(cfg); only its relation count is copied into the model config.
# NOTE(review): training later fails with "IndexError: index out of range in self".
# If the model also indexes inverse relations (as upstream NBFNet does for
# head-prediction queries), num_relation must count them too -- confirm that
# dataset.num_relations (18 here) includes inverse edge types.
cfg.model.num_relation = dataset.num_relations
cfg

{'output_dir': '~/experiments/',
 'dataset': {'class': 'BioBridgeDataModule',
  'root': '~/datasets/knowledge_graphs/'},
 'model': {'class': 'NBFNet',
  'input_dim': 32,
  'hidden_dims': [32, 32, 32, 32, 32, 32],
  'message_func': 'distmult',
  'aggregate_func': 'pna',
  'short_cut': True,
  'layer_norm': True,
  'dependent': True,
  'remove_one_hop': True,
  'num_relation': 18},
 'task': {'num_negative': 32,
  'strict_negative': True,
  'adversarial_temperature': 0.5,
  'metric': ['mr', 'mrr', 'hits@1', 'hits@3', 'hits@10']},
 'optimizer': {'class': 'Adam', 'lr': 0.005},
 'train': {'gpus': None, 'batch_size': 8, 'num_epoch': 1, 'log_interval': 100}}

In [9]:
import pandas as pd

# Tabulate the node-type id -> input embedding dimension mapping for inspection.
a = pd.DataFrame.from_dict(
    {
        "node_type": list(biobridge_dataset.mapper["ntid2dim"].keys()),
        "node_dim": list(biobridge_dataset.mapper["ntid2dim"].values()),
    }
)

In [10]:
# Show the table as a {node_type: node_dim} dict (one (type, dim) pair per row).
dict(map(tuple, a.values))

{0: 768, 7: 768, 2: 768, 6: 512, 5: 768, 1: 2560}

In [11]:
# Attach BioBridge-specific model inputs to the config: the node table and a
# node-type -> input-dimension table. Judging by the model summary, the model builds
# one input projection per node type from the latter.
cfg.model.biobridge = dict(
    nodes=biobridge_dataset.nodes,
    mapper_ntid2dim=pd.DataFrame.from_dict(
        {
            "node_type": biobridge_dataset.mapper["ntid2dim"].keys(),
            "node_dim": biobridge_dataset.mapper["ntid2dim"].values(),
        }
    ),
)

In [12]:
# Build the model from cfg.model (which now also carries num_relation and the
# BioBridge inputs) and display its module tree.
model = util.build_model(cfg)
model

NBFNet(
  (project): ModuleDict(
    (node_type_0): Sequential(
      (0): Linear(in_features=768, out_features=32, bias=True)
      (1): ReLU()
    )
    (node_type_7): Sequential(
      (0): Linear(in_features=768, out_features=32, bias=True)
      (1): ReLU()
    )
    (node_type_2): Sequential(
      (0): Linear(in_features=768, out_features=32, bias=True)
      (1): ReLU()
    )
    (node_type_6): Sequential(
      (0): Linear(in_features=512, out_features=32, bias=True)
      (1): ReLU()
    )
    (node_type_5): Sequential(
      (0): Linear(in_features=768, out_features=32, bias=True)
      (1): ReLU()
    )
    (node_type_1): Sequential(
      (0): Linear(in_features=2560, out_features=32, bias=True)
      (1): ReLU()
    )
  )
  (layers): ModuleList(
    (0-5): 6 x GeneralizedRelationalConv()
  )
  (query): Embedding(18, 32)
  (mlp): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True

In [13]:
# Move the model and the three graph splits onto the configured device.
device = util.get_device(cfg)
model = model.to(device)
# NOTE(review): this rebinds train_data/test_data to dataset[0..2], replacing the
# splits read from biobridge_dataset.data earlier, and calls the validation split
# `valid_data` (earlier `val_data`) -- confirm which split source is intended.
train_data, valid_data, test_data = (dataset[idx].to(device) for idx in range(3))

In [14]:
# Graph used to build the ranking masks during evaluation:
#   - inductive: None, so test() masks with the evaluated split's own graph;
#   - transductive: the target edges of the whole dataset.
filtered_data = (
    None
    if is_inductive
    else Data(
        edge_index=dataset.data.target_edge_index,
        edge_type=dataset.data.target_edge_type,
    ).to(device)
)



### Training

In [15]:
import math
import torch
from torch import optim
from torch import nn
from torch.nn import functional as F
from torch import distributed as dist
from torch.utils import data as torch_data

In [16]:
# Banner strings that delimit sections in the training/evaluation logs.
separator, line = ">" * 30, "-" * 30

In [17]:
@torch.no_grad()
def test(cfg, model, test_data, filtered_data=None):
    """Evaluate filtered link-prediction ranking metrics on one graph split.

    For every target triplet of ``test_data``, the model scores a tail-query batch
    and a head-query batch built by ``tasks.all_negative``. The rank of the true
    entity is computed by ``tasks.compute_ranking`` under the mask from
    ``tasks.strict_negative_mask``. Per-rank results are combined across
    processes. Rank 0 logs every metric in ``cfg.task.metric``, and every rank
    returns the mean reciprocal rank.

    Args:
        cfg: Run config. Reads ``cfg.train.batch_size`` and ``cfg.task.metric``
            (supported names: ``"mr"``, ``"mrr"``, ``"hits@K"``, ``"hits@K_N"``).
        model: Link-prediction model, called as ``model(test_data, batch)``.
        test_data: Graph split providing ``target_edge_index`` and
            ``target_edge_type``.
        filtered_data: Graph used to build the ranking masks. ``None`` means the
            masks are built from ``test_data`` itself.

    Returns:
        0-dim tensor: mean reciprocal rank over all head and tail queries.

    Raises:
        ValueError: if ``cfg.task.metric`` names an unsupported metric, or if
            ``test_data`` has no target edges.
    """
    # Validate metric names up front (on every rank) instead of failing on rank 0
    # only after the whole evaluation has run.
    for metric in cfg.task.metric:
        if metric not in ("mr", "mrr") and not metric.startswith("hits@"):
            raise ValueError("Unsupported metric: %s" % metric)

    world_size = util.get_world_size()
    rank = util.get_rank()

    # One (h, t, r) row per target edge.
    test_triplets = torch.cat([test_data.target_edge_index, test_data.target_edge_type.unsqueeze(0)]).t()
    if len(test_triplets) == 0:
        raise ValueError("test_data has no target edges to evaluate")
    sampler = torch_data.DistributedSampler(test_triplets, world_size, rank)
    test_loader = torch_data.DataLoader(test_triplets, cfg.train.batch_size, sampler=sampler)

    # Masks come from the filtering graph when given, otherwise from the split itself.
    mask_graph = test_data if filtered_data is None else filtered_data

    model.eval()
    rankings = []
    num_negatives = []
    for batch in test_loader:
        t_batch, h_batch = tasks.all_negative(test_data, batch)
        t_pred = model(test_data, t_batch)
        h_pred = model(test_data, h_batch)

        t_mask, h_mask = tasks.strict_negative_mask(mask_graph, batch)
        pos_h_index, pos_t_index, _ = batch.t()
        t_ranking = tasks.compute_ranking(t_pred, pos_t_index, t_mask)
        h_ranking = tasks.compute_ranking(h_pred, pos_h_index, h_mask)

        rankings += [t_ranking, h_ranking]
        num_negatives += [t_mask.sum(dim=-1), h_mask.sum(dim=-1)]

    ranking = torch.cat(rankings)
    num_negative = torch.cat(num_negatives)

    # Allocate the gather buffers on the device of the local results, not on a
    # notebook-global `device` that may be undefined or different.
    device = ranking.device
    # Each rank writes into its own slice of zero-filled buffers. Summing the
    # buffers across ranks then acts as an all-gather.
    all_size = torch.zeros(world_size, dtype=torch.long, device=device)
    all_size[rank] = len(ranking)
    if world_size > 1:
        dist.all_reduce(all_size, op=dist.ReduceOp.SUM)
    cum_size = all_size.cumsum(0)
    start, end = cum_size[rank] - all_size[rank], cum_size[rank]
    all_ranking = torch.zeros(all_size.sum(), dtype=torch.long, device=device)
    all_ranking[start:end] = ranking
    all_num_negative = torch.zeros(all_size.sum(), dtype=torch.long, device=device)
    all_num_negative[start:end] = num_negative
    if world_size > 1:
        dist.all_reduce(all_ranking, op=dist.ReduceOp.SUM)
        dist.all_reduce(all_num_negative, op=dist.ReduceOp.SUM)

    mrr = (1 / all_ranking.float()).mean()

    if rank == 0:
        for metric in cfg.task.metric:
            if metric == "mr":
                score = all_ranking.float().mean()
            elif metric == "mrr":
                score = mrr
            else:  # "hits@K" or "hits@K_N" (validated above)
                values = metric[5:].split("_")
                threshold = int(values[0])
                if len(values) > 1:
                    num_sample = int(values[1])
                    # Unbiased estimate of hits@K as if only num_sample - 1 negatives
                    # were sampled. fp_rate is the fraction of negatives ranked above
                    # the positive; the binomial sum is the probability that fewer
                    # than K of the sampled negatives outrank it.
                    fp_rate = (all_ranking - 1).float() / all_num_negative
                    score = 0
                    for i in range(threshold):
                        # Exact binomial coefficient. The former factorial(n - 1) /
                        # factorial(i) / ... form overflows a float for large n.
                        num_comb = float(math.comb(num_sample - 1, i))
                        score += num_comb * (fp_rate ** i) * ((1 - fp_rate) ** (num_sample - i - 1))
                    score = score.mean()
                else:
                    score = (all_ranking <= threshold).float().mean()
            logger.warning("%s: %g" % (metric, score))

    return mrr

In [18]:
def _weighted_bce_loss(pred, adversarial_temperature, num_negative):
    """Mean binary cross-entropy of a batch of (positive, negatives) score rows.

    Column 0 of ``pred`` is the positive sample (target 1) and the remaining
    columns are negatives (target 0). Negatives are weighted either by a softmax
    over their scores (when ``adversarial_temperature > 0``) or uniformly by
    ``1 / num_negative``. Each row's loss is normalised by its total weight.
    """
    target = torch.zeros_like(pred)
    target[:, 0] = 1
    loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
    neg_weight = torch.ones_like(pred)
    if adversarial_temperature > 0:
        # Higher-scoring negatives get larger weights. The weights are computed
        # without gradient, so they act as constants in the backward pass.
        with torch.no_grad():
            neg_weight[:, 1:] = F.softmax(pred[:, 1:] / adversarial_temperature, dim=-1)
    else:
        neg_weight[:, 1:] = 1 / num_negative
    loss = (loss * neg_weight).sum(dim=-1) / neg_weight.sum(dim=-1)
    return loss.mean()


def train_and_validate(cfg, model, train_data, valid_data, filtered_data=None):
    """Train ``model`` on ``train_data`` with periodic checkpointing and validation.

    Epochs run in chunks of ``ceil(num_epoch / 10)`` (at most ten chunks). After
    each chunk:

    - rank 0 saves ``model_epoch_<E>.pth`` (model and optimizer state) in the
      current working directory;
    - the model is scored with ``test`` on ``valid_data``.

    At the end, the checkpoint with the highest validation MRR is loaded back into
    ``model``.

    Args:
        cfg: Run config. Reads ``cfg.optimizer`` (``class`` plus optimizer
            keyword arguments), ``cfg.train`` (``num_epoch``, ``batch_size``,
            ``log_interval``) and ``cfg.task`` (``num_negative``,
            ``strict_negative``, ``adversarial_temperature``, ``metric``).
            ``cfg`` is not modified, so the function can be called again.
        model: Link-prediction model, called as ``model(graph, batch)``.
        train_data: Graph split whose target edges are the training triplets.
        valid_data: Graph split used for validation.
        filtered_data: Passed through to ``test`` for filtered ranking.

    Returns:
        None. ``model`` is updated in place. Returns immediately when
        ``cfg.train.num_epoch == 0``.
    """
    if cfg.train.num_epoch == 0:
        return

    world_size = util.get_world_size()
    rank = util.get_rank()
    # Use the model's own device instead of relying on a notebook-global `device`.
    device = next(model.parameters()).device

    train_triplets = torch.cat([train_data.target_edge_index, train_data.target_edge_type.unsqueeze(0)]).t()
    sampler = torch_data.DistributedSampler(train_triplets, world_size, rank)
    train_loader = torch_data.DataLoader(train_triplets, cfg.train.batch_size, sampler=sampler)

    # Work on a copy so that removing "class" leaves cfg.optimizer intact. A second
    # call (e.g. re-running the cell) would otherwise fail with KeyError.
    optimizer_kwargs = dict(cfg.optimizer)
    optimizer_cls = getattr(optim, optimizer_kwargs.pop("class"))
    optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
    if world_size > 1:
        parallel_model = nn.parallel.DistributedDataParallel(model, device_ids=[device])
    else:
        parallel_model = model

    step = math.ceil(cfg.train.num_epoch / 10)
    best_result = float("-inf")
    best_epoch = -1

    batch_id = 0
    for i in range(0, cfg.train.num_epoch, step):
        parallel_model.train()
        for epoch in range(i, min(cfg.train.num_epoch, i + step)):
            if rank == 0:
                logger.warning(separator)
                logger.warning("Epoch %d begin" % epoch)

            losses = []
            sampler.set_epoch(epoch)
            for batch in train_loader:
                batch = tasks.negative_sampling(train_data, batch, cfg.task.num_negative,
                                                strict=cfg.task.strict_negative)
                pred = parallel_model(train_data, batch)
                loss = _weighted_bce_loss(pred, cfg.task.adversarial_temperature, cfg.task.num_negative)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if rank == 0 and batch_id % cfg.train.log_interval == 0:
                    logger.warning(separator)
                    logger.warning("binary cross entropy: %g" % loss)
                losses.append(loss.item())
                batch_id += 1

            if rank == 0:
                # An empty loader would otherwise raise ZeroDivisionError here.
                avg_loss = sum(losses) / len(losses) if losses else float("nan")
                logger.warning(separator)
                logger.warning("Epoch %d end" % epoch)
                logger.warning(line)
                logger.warning("average binary cross entropy: %g" % avg_loss)

        epoch = min(cfg.train.num_epoch, i + step)
        if rank == 0:
            logger.warning("Save checkpoint to model_epoch_%d.pth" % epoch)
            state = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict()
            }
            torch.save(state, "model_epoch_%d.pth" % epoch)
        util.synchronize()

        if rank == 0:
            logger.warning(separator)
            logger.warning("Evaluate on valid")
        result = test(cfg, model, valid_data, filtered_data=filtered_data)
        # Always accept the first evaluated checkpoint. A NaN metric never compares
        # greater, which would otherwise leave best_epoch at -1 and make the final
        # torch.load look for a non-existent model_epoch_-1.pth.
        if best_epoch < 0 or result > best_result:
            best_result = result
            best_epoch = epoch

    if rank == 0:
        logger.warning("Load checkpoint from model_epoch_%d.pth" % best_epoch)
    state = torch.load("model_epoch_%d.pth" % best_epoch, map_location=device)
    model.load_state_dict(state["model"])
    util.synchronize()

In [19]:
import os

# Toolchain settings for the rspmm extension, which is compiled on first use
# ("Load rspmm extension. This may take a while..." when training starts).
# Shell equivalent, if you prefer to export these before launching Jupyter:
#   export TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0+PTX"
#   export CUDA_HOME=/usr/local/cuda-12.8
#   export PATH=$CUDA_HOME/bin:$PATH
#   export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
# NOTE(review): LD_LIBRARY_PATH is not set from Python here. If the build or runtime
# needs it, export it before starting the kernel.
os.environ["TORCH_CUDA_ARCH_LIST"] = "5.0;6.0;7.0;7.5;8.0;8.6;9.0+PTX"
os.environ["CUDA_HOME"] = "/usr/local/cuda-12.8"

# Prepend CUDA's bin directory only once, so re-running this cell does not keep
# growing PATH. Also tolerate PATH being unset.
_cuda_bin = os.environ["CUDA_HOME"] + "/bin"
_current_path = os.environ.get("PATH", "")
if _cuda_bin not in _current_path.split(os.pathsep):
    os.environ["PATH"] = _cuda_bin + os.pathsep + _current_path if _current_path else _cuda_bin

In [20]:
# Train with periodic validation. At the end, the best validation checkpoint is
# reloaded into `model`.
# NOTE(review): the recorded run of this cell stopped with
# "IndexError: index out of range in self" -- investigate before trusting the
# evaluation cells below.
train_and_validate(
    cfg=cfg,
    model=model,
    train_data=train_data,
    valid_data=valid_data,
    filtered_data=filtered_data,
)



Load rspmm extension. This may take a while...




IndexError: index out of range in self

### Evaluation on Validation Set

In [None]:
# Score the trained model on the validation split. Rank 0 logs a banner, test()
# logs each metric, and the cell displays the returned MRR.
if util.get_rank() == 0:
    logger.warning(separator)
    logger.warning("Evaluate on valid")
test(cfg=cfg, model=model, test_data=valid_data, filtered_data=filtered_data)


### Evaluation on Testing Set

In [None]:
# Score the trained model on the held-out test split. Rank 0 logs a banner, test()
# logs each metric, and the cell displays the returned MRR.
if util.get_rank() == 0:
    logger.warning(separator)
    logger.warning("Evaluate on test")
test(cfg=cfg, model=model, test_data=test_data, filtered_data=filtered_data)