In [1]:
import os
import time

import numpy as np
import pandas as pd
import tqdm

from sharp import ShaRP
from sharp.utils import scores_to_ordering
from xai_ranking.preprocessing import preprocess_atp_data
from xai_ranking.datasets import fetch_atp_data
from xai_ranking.scorers import atp_score
from xai_ranking.metrics import kendall_agreement, fidelity, sensitivity
from mlresearch.utils import check_random_states

# Master seed handed to check_random_states, and the number of repeated runs
# performed for every ShaRP configuration in the experiment.
RNG_SEED, N_RUNS = 42, 5

In [2]:
# Random states for the repeated runs, derived from RNG_SEED
# (indexed as random_states[i] for i in range(N_RUNS) further down).
random_states = check_random_states(RNG_SEED, N_RUNS)

# Dataset under study: its name (used in the output file name), the
# preprocessed data (presumably a tuple whose first item is the feature
# matrix -- TODO confirm), and the scoring function being explained.
dataset = dict(
    name="ATP",
    data=preprocess_atp_data(fetch_atp_data()),
    scorer=atp_score,
)

# Default ShaRP keyword arguments for the baseline configuration.
default_kwargs = dict(
    qoi="rank",
    measure="shapley",
    sample_size=None,
    coalition_size=None,
    replace=True,
    n_jobs=-1,
)

# Values to try for each ShaRP parameter in the sweep.
parameters_to_change = {
    "coalition_size": list(range(1, 6)),
    "sample_size": list(range(6, 90, 5)),
    "n_jobs": list(range(1, 32)),
}

In [9]:
X = dataset["data"][0]
scorer = dataset["scorer"]
# Reference ordering of X under the target scorer; consumed by the sensitivity metric.
ranking = scores_to_ordering(scorer(X))
result_cols = (
    ["parameter", "parameter_value", "avg_time"]
    + [f"time_{i}" for i in range(N_RUNS)]
    + [f"agreement_{i}" for i in range(N_RUNS)]
    + [f"fidelity_{i}" for i in range(N_RUNS)]
    + [f"sensitivity_{i}" for i in range(N_RUNS)]
)


def run_configuration(sharp_kwargs, progress=False):
    """Fit ShaRP ``N_RUNS`` times with ``sharp_kwargs`` and collect metrics.

    Run ``i`` is seeded with ``random_states[i]``, so every configuration
    (baseline included) is evaluated on the same sequence of seeds.

    Parameters
    ----------
    sharp_kwargs : dict
        Keyword arguments forwarded to ``ShaRP`` in addition to
        ``target_function`` and ``random_state``.
    progress : bool, default False
        Show a ``tqdm`` progress bar over the runs.

    Returns
    -------
    metrics : list
        ``[avg_time] + times + agreements + fidelities + sensitivities``,
        i.e. the columns of ``result_cols`` that follow ``parameter`` and
        ``parameter_value``.
    contr
        Contributions returned by ``ShaRP.all`` in the last run.
    """
    times, agreements, fidelities, sensitivities = [], [], [], []
    contr = None
    runs = tqdm.tqdm(range(N_RUNS)) if progress else range(N_RUNS)
    for i in runs:
        # Only fitting + explaining is timed; metric computation is excluded.
        start = time.perf_counter()
        explainer = ShaRP(
            target_function=scorer,
            random_state=random_states[i],
            **sharp_kwargs,
        )
        explainer.fit(X)
        contr = explainer.all(X)
        times.append(time.perf_counter() - start)

        # TODO(review): agreement is still a placeholder; presumably it should be
        # kendall_agreement(baseline_contr, contr) -- confirm its signature/return.
        agreements.append(1)
        # Metrics are computed on *this* run's contributions, not the baseline's.
        fidelities.append(fidelity(X, scorer, contr, random_state=random_states[i])[0])
        sensitivities.append(sensitivity(X, contr, ranking)[0])

    return [np.mean(times)] + times + agreements + fidelities + sensitivities, contr


# Baseline row: default_kwargs as-is; parameter / parameter_value are NaN.
baseline_metrics, baseline_contr = run_configuration(default_kwargs, progress=True)
exact_results_row = [np.nan, np.nan] + baseline_metrics
result_df = [exact_results_row]

for parameter, parameter_values in parameters_to_change.items():
    print(f"Alternating parameter: {parameter}")
    for parameter_value in tqdm.tqdm(parameter_values):
        # Override a single key on a fresh copy so default_kwargs is never
        # mutated and one sweep's last value cannot leak into the next sweep.
        sharp_kwargs = {**default_kwargs, parameter: parameter_value}
        metrics, _ = run_configuration(sharp_kwargs)
        result_df.append([parameter, parameter_value] + metrics)

results = pd.DataFrame(result_df, columns=result_cols)
# to_csv does not create missing directories.
os.makedirs("results", exist_ok=True)
results.to_csv(os.path.join("results", "time-experiment-" + dataset["name"] + ".csv"))

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [01:44<00:00, 20.88s/it]


Alternating parameter: coalition_size


  0%|          | 0/5 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [10]:
# Build the results table from the rows accumulated in `result_df`; this also
# allows inspecting partial results if the experiment cell was interrupted.
results = pd.DataFrame(data=result_df, columns=result_cols)
results

Unnamed: 0,parameter,parameter_value,avg_time,time_0,time_1,time_2,time_3,time_4,agreement_0,agreement_1,...,fidelity_0,fidelity_1,fidelity_2,fidelity_3,fidelity_4,sensitivity_0,sensitivity_1,sensitivity_2,sensitivity_3,sensitivity_4
0,,,19.882732,17.299065,19.700095,20.802603,20.754905,20.85699,1,1,...,1.831773,1.820014,1.847919,1.833094,1.84208,0.167442,0.167752,0.168837,0.160775,0.164341
