In [1]:
import sys

sys.path.append("..")

import numpy as np
import pandas as pd
from lightgbm import LGBMRanker
from xai_ranking.benchmarks import (
    human_in_the_loop,
    hierarchical_ranking_explanation,
    lime_experiment,
    lime_batch_experiment,
    shap_experiment,
    shap_batch_experiment,
    sharp_experiment,
    sharp_batch_experiment,
    # participation_experiment,
)
from xai_ranking.preprocessing import (
    preprocess_atp_data,
    preprocess_csrank_data,
    preprocess_higher_education_data,
    preprocess_movers_data,
)
from xai_ranking.datasets import (
    fetch_atp_data,
    fetch_csrank_data,
    fetch_higher_education_data,
    fetch_movers_data,
)
from xai_ranking.scorers import (
    atp_score,
    csrank_score,
    higher_education_score,
)

RNG_SEED = 42

In [2]:
# Set up ranker for the moving company dataset:
X, ranks, score = preprocess_movers_data(fetch_movers_data()) 
qids_train = X.index.value_counts().to_numpy()

model = LGBMRanker(
    objective="lambdarank", label_gain=list(range(max(ranks) + 1)), verbose=-1
)
model.fit(
    X=X,
    y=ranks,
    group=qids_train,
)

In [3]:
datasets = [
    {
        "name": "ATP",
        "data": fetch_atp_data().head(20),
        "preprocess": preprocess_atp_data,
        "scorer": atp_score,
    },
    {
        "name": "CSRank",
        "data": fetch_csrank_data().head(20),
        "preprocess": preprocess_csrank_data,
        "scorer": csrank_score,
    },
    {
        "name": "Higher Education",
        "data": fetch_higher_education_data(year=2021).head(20),
        "preprocess": preprocess_higher_education_data,
        "scorer": higher_education_score,
    },
    {
        "name": "Moving Company",
        "data": fetch_movers_data(),
        "preprocess": preprocess_movers_data,
        "scorer": model.predict,
    },
]
xai_methods = [
    {"name": "LIME", "experiment": lime_experiment},
    {"name": "BATCH_LIME", "experiment": lime_batch_experiment},
    {"name": "SHAP", "experiment": shap_experiment},
    {"name": "BATCH_SHAP", "experiment": shap_batch_experiment},
    {"name": "ShaRP", "experiment": sharp_experiment},
    {"name": "BATCH_ShaRP", "experiment": sharp_batch_experiment},
    # {"name": "Participation", "experiment": participation_experiment},
    {"name": "HRE", "experiment": hierarchical_ranking_explanation},
    {"name": "HIL", "experiment": human_in_the_loop},
]

In [4]:
results = {}
for dataset in datasets:
    results[dataset["name"]] = {}
    for xai_method in xai_methods:
        
        experiment_func = xai_method["experiment"]
        preprocess_func = dataset["preprocess"]
        score_func = dataset["scorer"]
        
        X, ranks, scores = preprocess_func(dataset["data"])
        contributions = experiment_func(X, score_func)
        
        results[dataset["name"]][xai_method["name"]] = contributions

        result_df = pd.DataFrame(contributions, columns=X.columns, index=X.index)
        result_df.to_csv(f"results/_contributions_{dataset['name']}_{xai_method['name']}.csv")
        # with open(f"_contributions_{dataset['name']}_{xai_method['name']}.npy", "wb") as f:
        #     np.save(f, contributions)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Usage of np.ndarray subset (sliced data) is not recommended due to it will double the peak memory cost in LightGBM.


  0%|          | 0/20 [00:00<?, ?it/s]

In [5]:
results

array([[-0.95      ,  1.95      , -2.55      ],
       [-0.6       , -0.75      ,  0.8       ],
       [ 1.2       , -0.45      ,  8.7       ],
       [ 1.55      , -0.45      ,  7.35      ],
       [ 1.61666667,  1.21666667,  3.61666667],
       [-0.78333333, -0.93333333, -3.83333333],
       [-0.65      , -0.8       , -3.1       ],
       [-0.63333333, -0.78333333, -2.13333333],
       [ 2.25      , -0.75      ,  0.95      ],
       [-0.78333333, -0.93333333, -3.83333333],
       [-0.83333333,  1.56666667, -0.28333333],
       [ 2.21666667,  1.81666667,  1.41666667],
       [-1.05      , -1.2       , -5.3       ],
       [-1.18333333, -1.33333333, -6.03333333],
       [-0.3       , -0.45      ,  5.2       ],
       [ 2.56666667, -0.63333333,  1.51666667],
       [-0.43333333, -0.43333333,  8.31666667],
       [-1.31666667, -1.46666667, -6.76666667],
       [-1.2       ,  3.25      , -4.6       ],
       [-0.68333333,  1.56666667,  0.56666667]])