# Extract selected quadruples that should be explained and store them as input for the explainer

In [4]:
from matplotlib import pyplot as plt 

import numpy as np
import sys
import util_scripts.stats_utils as su
import util_scripts.dataset_utils as du
sys.path.append('..')

import sys
import os
from tqdm import tqdm

# Add the project root directory
sys.path.append(os.path.abspath("../rule_based"))  # or the absolute path to your project

from rule_based.rule_dataset import RuleDataset
from rule_based.eval import evaluate
import rule_utils
from tgb.linkproppred.evaluate import Evaluator
import os

## For the given rankings and filters (src/rel/dst/t ids), find all test or validation quadruples whose reciprocal rank (rr) lies strictly between the given thresholds

In [None]:
# --- Configuration -----------------------------------------------------------
dataset_name = 'tkgl-icews14'
# Replace with the name of your rankings file (looked up in ../files/rankings/<dataset_name>/).
rankings_filename = 'tkgl-icews14-rankings_test_conf_0_corr_conf_0_noisyor_crules_frules_zrules_pvalue_30_num_top_rules_10_multi.txt'
eval_mode = 'test'  # split to evaluate: 'test' or 'val'
rankings_arules_name = os.path.join('..', 'files', 'rankings', dataset_name, rankings_filename)
rule_dataset = RuleDataset(name=dataset_name)

# Keep only quadruples whose reciprocal rank (rr) satisfies lower_limit < rr < upper_limit.
upper_limit = 1    # rr < upper_limit
lower_limit = 0.2  # rr > lower_limit

exp_name = 'compare'
outpath = os.path.join('..', 'files', 'explanations', exp_name, 'input')
# exist_ok=True replaces the racy "check if exists, then create" pattern.
os.makedirs(outpath, exist_ok=True)

explanations_path = os.path.join(outpath, "quadruples.txt")  # this can be used as input for the explainer

# Filters applied in the extraction cell: each is either the string 'all'
# (no filtering) or a collection of ids to keep.
# NOTE: `dst_of_interst` (sic) keeps its spelling because the extraction cell reads it by this name.
src_of_interest = 'all'
dst_of_interst = 'all'
t_of_interest = 'all'
rel_of_interest = [5]

raw file found, skipping download
Dataset directory is  c:\Users\jgasting\PythonScripts\Rules\GraphTRuCoLa\tgb/datasets\tkgl_icews14
loading processed file
num_rels:  230
>>> loading and indexing of dataset 2.896 seconds
>>> average number of time steps for a triple: 1.804
>>> checked order of time steps, everything is fine


In [6]:
# Set up the dataset handles, the evaluator and the negative sampler for the chosen split.
dataset = rule_dataset.dataset
num_nodes = dataset.num_nodes
split_mode = eval_mode
evaluator = Evaluator(name=dataset.name, k_value=[1, 10, 100])
neg_sampler = dataset.negative_sampler

if eval_mode == "val":
    testdata = rule_dataset.val_data
    print("loading negative val samples")
    dataset.load_val_ns()  # load negative samples, i.e. the nodes that are not used for time aware filter mrr
elif eval_mode == "test":
    testdata = rule_dataset.test_data
    print("loading negative test samples")
    dataset.load_test_ns()  # load negative samples, i.e. the nodes that are not used for time aware filter mrr
else:
    # Fail fast: otherwise `testdata` would be undefined, or (in a notebook)
    # silently left over from a previous run on a different split.
    raise ValueError(f"eval_mode must be 'val' or 'test', got {eval_mode!r}")

# Rankings produced by the rule-based model, read with the node count of the dataset.
rankings_arules = rule_utils.read_rankings_order(rankings_arules_name, num_nodes)


loading negative test samples


In [9]:


def _is_of_interest(value, selection):
    """Return True if ``value`` passes the filter ``selection``.

    ``selection`` is either the string ``'all'`` (accept every value) or a
    collection of accepted values (list, set, numpy array, ...).
    """
    # isinstance guard: comparing a numpy array to 'all' would be elementwise,
    # and `not <array>` raises for arrays with more than one element.
    if isinstance(selection, str) and selection == 'all':
        return True
    return value in selection


print('>>> starting evaluation for every triple, in the ', eval_mode, 'set')

in_threshold_counter = 0

with open(explanations_path, 'w') as file_explanations:
    file_explanations.write("subject rel object timestep\n")
    # testdata columns are read as (src, rel, dst, t) -> column indices 0, 1, 2, 3.
    for src, dst, t, rel in zip(testdata[:, 0], testdata[:, 2], testdata[:, 3], testdata[:, 1]):
        # Only process the src, dst, rel, t of interest (checked in this order).
        if not (_is_of_interest(src, src_of_interest)
                and _is_of_interest(dst, dst_of_interst)
                and _is_of_interest(rel, rel_of_interest)
                and _is_of_interest(t, t_of_interest)):
            continue

        # Presumably maps the internal timestamp id to the original timestamp
        # expected by the negative sampler; confirm against RuleDataset.
        original_t = rule_dataset.timestamp_id2orig[t]

        # Query negative batch list - all negative samples for the given positive edge
        # that are not temporal conflicts (time aware mrr)
        neg_batch_list = neg_sampler.query_batch(np.array([src]), np.array([dst]), np.array([original_t]),
                                                 edge_type=np.array([rel]), split_mode=split_mode)

        # Score every node for the query (src, rel, t) from the rule-based rankings,
        # then pick out the negatives and the true object dst.
        scores_array_arules = rule_utils.create_scores_array(rankings_arules[(src, rel, t)], num_nodes)
        predictions_neg_arules = scores_array_arules[neg_batch_list[0]]
        predictions_pos_arules = np.array(scores_array_arules[dst])

        input_dict = {
            "y_pred_pos": predictions_pos_arules,
            "y_pred_neg": predictions_neg_arules,
            "eval_metric": ['mrr'],
        }
        predictions_arules = evaluator.eval(input_dict)
        # For a single query the evaluator's 'mrr' is the reciprocal rank of dst.
        mrr_arule = float(predictions_arules['mrr'])

        # Keep the quadruple if its reciprocal rank lies strictly between the limits.
        if lower_limit < mrr_arule < upper_limit:
            file_explanations.write(f"{src} {rel} {dst} {t}\n")
            in_threshold_counter += 1

print("in total we had", (in_threshold_counter), "cases mrr was in threshold")
print("explanations written to", explanations_path)


>>> starting evaluation for every triple, in the  test set
in total we had 63 cases mrr was in threshold
explanations written to ..\files\explanations\compare\input\quadruples.txt
