In [None]:
import os
from tqdm import tqdm
import numpy as np
from itertools import product
from hyperopt import hp
from ray import tune
from ray.tune.search.hyperopt import HyperOptSearch
from ray.air.config import RunConfig
from ogb.linkproppred import LinkPropPredDataset, Evaluator

In [None]:
# Which OGB link-prediction benchmark to ensemble: 'ogbl-biokg' or 'ogbl-wikikg2'.
dataset_name = 'ogbl-biokg'

# Per-model rank files are read from rank_path; tuned per-relation weights are cached at checkpoint_path.
rank_path = "ranks/" + dataset_name
checkpoint_path = "weights/" + dataset_name + "_rel_weights.npy"

# Ray Tune search budget, spent once per relation.
num_samples = 100            # trials per relation
num_initial_points = 20      # passed to HyperOptSearch as n_initial_points (random warm-up trials)
max_concurrent_trials = 8    # trials allowed to run in parallel

In [None]:
# Base models whose precomputed rank files are ensembled for each supported dataset.
# The dataset is validated first so an unsupported name fails with our own clear message.
if dataset_name == 'ogbl-biokg':
    model_names = ['KGBench', 'ComplexRP', 'TripleRE']
elif dataset_name == 'ogbl-wikikg2':
    model_names = ['Text', 'InterHTPlus', 'StarGraph']
else:
    raise NotImplementedError(f"Unsupported dataset: {dataset_name}")

# OGB evaluator used by eval_model to turn scores into per-triple MRR.
evaluator = Evaluator(name=dataset_name)

In [None]:
print(f"Loading dataset: {dataset_name}")
# ranks[split][k] is the rank array saved for model_names[k] on that split.
ranks = {
    split: [np.load(f"{rank_path}/{model}_{split}_ranks.npy") for model in model_names]
    for split in ('valid', 'test')
}

# Number of base models being ensembled.
n_model = len(ranks['test'])

In [None]:
def eval_model(sub_ranks):
    """Compute the mean MRR of a (possibly ensembled) rank matrix.

    Args:
        sub_ranks: 2-D array, one row per triple. Column 0 is the positive and
            columns 1-500 its negatives for the first ("head") side; column 501
            is the positive and columns 502+ its negatives for the second
            ("tail") side. Values are ranks; presumably in [1, 501] -- TODO confirm.

    Returns:
        dict with key 'mrr': the average of the head-side and tail-side mean MRR.
    """
    # Flip ranks into scores (lower rank -> higher score) for the OGB evaluator.
    scores = 502 - sub_ranks

    side_mrrs = []
    for pos_col, neg_cols in ((0, slice(1, 501)), (501, slice(502, None))):
        side = evaluator.eval({'y_pred_pos': scores[:, pos_col], 'y_pred_neg': scores[:, neg_cols]})
        side_mrrs.append(side['mrr_list'].mean())

    head_mrr, tail_mrr = side_mrrs
    return {'mrr': (head_mrr + tail_mrr) / 2}

In [None]:
# Load the OGB dataset and unpack its dict-like train/valid/test triple splits.
dataset = LinkPropPredDataset(name=dataset_name)
split_edge = dataset.get_edge_split()
train_triples = split_edge["train"]
valid_triples = split_edge["valid"]
test_triples = split_edge["test"]

In [None]:
# Relation id of every evaluated row, per split, and the total number of relation ids.
# NOTE(review): presumably these arrays are row-aligned with the rank files -- TODO confirm.
if dataset_name == 'ogbl-biokg':
    test_relation = test_triples['relation']
    valid_relation = valid_triples['relation']
    # Cover ids from every split (as the wikikg2 branch does) so no valid/test row is
    # dropped from rel_indexes; ndarray.max() avoids iterating the array in Python.
    num_relation = int(max(train_triples['relation'].max(), valid_relation.max(), test_relation.max())) + 1
elif dataset_name == 'ogbl-wikikg2':
    origin_num_relation = int(max(train_triples['relation'].max(), valid_triples['relation'].max(), test_triples['relation'].max()))+1
    # Each split's relations appear twice; the second copy is offset by origin_num_relation
    # so it gets its own id range (and thus its own tuned weights).
    test_relation = np.concatenate((test_triples['relation'], test_triples['relation'] + origin_num_relation), axis=0)
    valid_relation = np.concatenate((valid_triples['relation'], valid_triples['relation'] + origin_num_relation), axis=0)
    num_relation = int(max(test_relation.max(), valid_relation.max())) + 1
else:
    raise NotImplementedError(f"Unsupported dataset: {dataset_name}")

print(num_relation)

In [None]:
# rel_indexes[split][relation_id] -> ascending row positions of that split's triples
# whose relation is relation_id (empty array if the relation never occurs there).
rel_indexes = {
    split: {rel: np.flatnonzero(split_rel == rel) for rel in range(num_relation)}
    for split, split_rel in (('valid', valid_relation), ('test', test_relation))
}


In [None]:
def objective(config, data):
    """Score one candidate set of ensemble weights (used as the Ray Tune trainable).

    Args:
        config: mapping with keys 'w_0', 'w_1', ...; one weight per model.
        data: list of per-model rank arrays of identical shape.

    Returns:
        The metrics dict produced by eval_model on the weighted-average ranks.
    """
    n_parts = len(data)
    weights = [config["w_" + str(k)] for k in range(n_parts)]
    blended_ranks = np.average(data, axis=0, weights=weights)
    return eval_model(blended_ranks)


In [None]:
# Fallback weights for relations that have no validation triples to tune on:
# weight every model equally. Keys mirror the search space (w_0 .. w_{n_model-1}),
# so this stays valid whatever the number of models in model_names.
default_config = {f"w_{i}": 1.0 / n_model for i in range(n_model)}

In [None]:
def _search_relation_weights(subranks):
    """Run one HyperOpt search over the ensemble weights for a single relation.

    Args:
        subranks: list with one validation rank array per model, restricted to
            the rows of this relation.

    Returns:
        float64 array of length n_model with the best weights, ordered w_0..w_{n_model-1}.
    """
    search_space = {f"w_{i}": hp.uniform(f"w_{i}", 0, 1) for i in range(n_model)}
    # A fresh searcher per relation: a HyperOptSearch instance keeps its trial history,
    # so sharing one would bias this relation's search with other relations' results
    # and apply the n_initial_points random warm-up only to the first relation.
    hyperopt_search = HyperOptSearch(search_space, metric="mrr", mode="max", n_initial_points=num_initial_points)
    tuner = tune.Tuner(
        tune.with_parameters(objective, data=subranks),
        param_space=search_space,
        tune_config=tune.TuneConfig(num_samples=num_samples, search_alg=hyperopt_search,
                                    max_concurrent_trials=max_concurrent_trials),
        run_config=RunConfig(verbose=0),
    )
    best_config = tuner.fit().get_best_result(metric="mrr", mode="max").config
    # Read weights by name rather than relying on the config dict's key order.
    return np.array([best_config[f"w_{i}"] for i in range(n_model)], dtype=np.float64)


if checkpoint_path and os.path.exists(checkpoint_path):
    print("Load existing models")
    rel_weights = np.load(checkpoint_path)
    # Guard against a stale cache written for a different model list or relation count.
    assert rel_weights.shape == (num_relation, n_model), (
        f"{checkpoint_path} has shape {rel_weights.shape}, expected {(num_relation, n_model)}")
else:
    print("Searching for ensemble weights")
    rel_weights = np.zeros((num_relation, n_model))
    fallback_weights = np.array([default_config[f"w_{i}"] for i in range(n_model)], dtype=np.float64)

    for rel_id in tqdm(range(num_relation)):
        valid_rows = rel_indexes['valid'][rel_id]
        if len(valid_rows) == 0:
            # No validation triples for this relation: nothing to tune on, use the defaults.
            rel_weights[rel_id] = fallback_weights
            continue

        subranks = [model_rank[valid_rows] for model_rank in ranks['valid']]
        rel_weights[rel_id] = _search_relation_weights(subranks)

    if checkpoint_path:
        # Save to the same path the load branch checks, so the next run can skip the search.
        os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
        np.save(checkpoint_path, rel_weights)


In [None]:
print("Evaluating")
# rel_res[split][rel_id] holds the metrics dict for that relation; relations with
# no triples in a split keep the {'mrr': 0} placeholder.
rel_res = {
    'test': [{'mrr': 0} for _ in range(num_relation)],
    'valid': [{'mrr': 0} for _ in range(num_relation)]
}

for rel_id in tqdm(range(num_relation)):
    config = {f"w_{i}": rel_weights[rel_id][i] for i in range(n_model)}
    # Score each split independently: a relation with no test triples must still get
    # its validation score (previously it was skipped, understating validation MRR).
    for split in ('test', 'valid'):
        rows = rel_indexes[split][rel_id]
        if len(rows) == 0:
            continue
        sub_ranks = [model_rank[rows] for model_rank in ranks[split]]
        rel_res[split][rel_id] = objective(config, sub_ranks)

In [None]:
# Overall MRR per split: each relation's MRR weighted by its triple count, divided by
# the number of rows in that split's rank arrays.
overall_mrr = {
    split: sum(rel_res[split][r]['mrr'] * len(rel_indexes[split][r]) for r in range(num_relation))
           / ranks[split][0].shape[0]
    for split in ('test', 'valid')
}
test_mrr, valid_mrr = overall_mrr['test'], overall_mrr['valid']

print(f"Test MRR: {test_mrr}\nValidation MRR: {valid_mrr}")