# shap-knn

## 使用可解释工具增强训练集的公平性

![图片](assets/2024-05-08-流程图.drawio.svg)

In [1]:
# 导入依赖
import numpy as np
import pandas as pd
from rich import print
import matplotlib.pyplot as plt

In [2]:
import shap
import xgboost

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [3]:
# Font settings: matplotlib's default fonts lack CJK glyphs, so Chinese titles/labels would render as boxes.
import matplotlib as mpl  # same module object that `pylab.mpl` exposed; importing via pylab is discouraged

# Use a font that contains Chinese glyphs (Microsoft YaHei ships with Windows).
mpl.rcParams["font.sans-serif"] = ["Microsoft YaHei"]
# With a CJK font active, the Unicode minus sign (U+2212) is often missing from the font,
# so negative tick labels -- common on SHAP plots -- show as boxes; use the ASCII hyphen-minus instead.
mpl.rcParams["axes.unicode_minus"] = False

In [4]:
# 导入自定义函数
%load_ext autoreload
%autoreload 2
from utils.helper import fairness_metrics, test_model # 测试各种指标并输出图像
from utils.shap_helper_0508 import get_ext_train_comp_by_k
from model.models import train_model_and_test


In [6]:
# 测试集
from mydata.adult_0508 import (
    X_train,
    y_train,
    X_test,
    y_test,
    sex,
    feature_index,
    X_train_plus,
    y_train_plus,
    idxs_plus,
    X_train_minus,
    y_train_minus,
    idxs_minus
)  # 获得 数据集 测试集 和 测试集上的敏感属性

In [7]:
# 如果
def shap_xgboost(X_train, y_train, k):
    model = xgboost.XGBRegressor()
    model.fit(X_train, y_train)
    explainer = shap.Explainer(model)
    shap_values = explainer(X_train)
    arr: np.ndarray = shap_values[:, feature_index].values
    print("arr 的长度是", len(arr))
    sorted_indices = np.argsort(np.abs(arr))
    top_k_indices = sorted_indices[-k:]
    return top_k_indices

# Positions (within X_train_plus) of the 30 rows with the largest |SHAP| on column `feature_index`.
top_k_plus = shap_xgboost(X_train=X_train_plus, y_train=y_train_plus, k=30)

In [8]:
from model.models import train_model_and_test

In [9]:
# Baseline: fit an XGBoost regressor on the full training set and evaluate it on the test set
# via `test_model`, with `sex` passed as the sensitive feature.
# NOTE(review): despite the `_plus` suffix, this model is trained on the full X_train,
# not X_train_plus -- confirm the naming is intended.
xgboost_plus, res_ard = train_model_and_test(
    X_train, y_train,
    X_test, y_test,
    model_cls=xgboost.XGBRegressor,
    sensitive_feature=sex,
    test_func=test_model,
    desc="ard测试组",
)

In [11]:
# Ask get_ext_train_comp_by_k (utils.shap_helper_0508) for a k=100 selection from the training
# set, driven by the baseline model and column `feature_index`.
# TODO(review): confirm the exact selection criterion inside the helper.
result = get_ext_train_comp_by_k(
    model=xgboost_plus, X_train=X_train, y_train=y_train,
    feature_index=feature_index, k=100,
)


In [12]:
# Pull the selected feature rows and the matching y values out of the helper's result object.
X_train_top_plus, y_train_top_plus = result.X_train_top, result.y_train_top

In [13]:
# Preview the selected training rows; index labels are the original row labels from X_train
# (pandas truncates long frames in the rendered output).
X_train_top_plus

Unnamed: 0,age,workclass,fnlwgt,education,education.num,marital.status,occupation,relationship,race,sex,capital.gain,capital.loss,hours.per.week,native.country
32098,0.101484,2.600478,-1.494279,-0.332263,1.133894,-0.402341,-0.782234,2.214196,0.392980,-1.430470,-0.145189,-0.217407,-1.662414,0.262317
25206,0.028248,-1.884720,0.438778,0.184396,-0.423425,-0.402341,-0.026696,-0.899410,0.392980,0.699071,-0.145189,-0.217407,-0.200753,0.262317
23491,0.247956,-0.090641,0.045292,1.217715,-0.034095,0.926666,-0.782234,-0.276689,0.392980,-1.430470,-0.145189,-0.217407,-0.038346,0.262317
12367,-0.850587,-1.884720,0.793152,0.184396,-0.423425,0.926666,-0.530388,0.968753,0.392980,0.699071,-0.145189,-0.217407,-0.038346,0.262317
7054,-0.044989,-2.781760,-0.853275,0.442726,1.523223,-0.402341,-0.782234,-0.899410,0.392980,0.699071,-0.145189,-0.217407,-0.038346,0.262317
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
31092,0.907082,-0.090641,-0.485221,-0.590592,0.355234,-1.731347,1.484379,1.591474,-1.963453,0.699071,-0.145189,-0.217407,0.611281,0.262317
6865,0.687373,-0.090641,-0.667563,-0.332263,1.133894,-1.731347,1.232533,1.591474,0.392980,0.699071,-0.145189,-0.217407,0.286467,0.262317
16823,0.467665,-1.884720,-0.042608,0.442726,1.523223,-1.731347,0.728841,0.968753,0.392980,0.699071,-0.145189,-0.217407,1.179704,0.262317
32545,0.467665,-1.884720,-0.667563,-0.848922,0.744564,-1.731347,0.728841,1.591474,0.392980,0.699071,-0.145189,-0.217407,0.611281,0.262317
