In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
# Extend the module search path with the parent, grandparent, current and
# ./scripts directories. The paths are relative, so they resolve against the
# working directory the kernel was started in.
# NOTE(review): presumably needed for project-local imports; none are used in
# the visible cells -- confirm before removing.
sys.path.append('..')
sys.path.append('../..')
sys.path.append('.')
sys.path.append('./scripts')

In [None]:
# Global matplotlib styling applied to every figure created below: large axis
# labels, small default tick labels (individual plots override these), no top
# or right spine, and thick black axes.
plt.rcParams.update({
    'axes.labelsize': 30,
    'xtick.labelsize': 10,
    'ytick.labelsize': 12,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.edgecolor': 'black',
    'axes.linewidth': 2.0,
})

### Regression

In [None]:
# Load the per-seed result files for the "linear" setting on every regression
# dataset, then aggregate train/test AUROC across seeds.
dgp = "linear"
datasets = ['openml_361260', 'openml_361254', 'openml_361259', 'openml_361253', 'openml_361243', 'openml_361242']
feature_seeds = [1,2,3,4,5,6,7,8,9,10]
sample_seeds = [1,2,3]

# Collect every frame first and concatenate once: calling pd.concat inside the
# loop re-copies the accumulated frame on every iteration, and concatenating
# onto an empty pd.DataFrame() raises a FutureWarning in recent pandas.
seed_frames = []
for data in datasets:
    ablation_directory = f"./results_new/mdi_local.real_data_regression_{data}_{dgp}/{data}_{dgp}/varying_heritability_sample_row_n"
    for sample_seed in sample_seeds:
        for feature_seed in feature_seeds:
            seed_df = pd.read_csv(os.path.join(ablation_directory, f"seed_{feature_seed}_{sample_seed}/results.csv"))
            seed_df["data"] = data  # tag rows with the dataset they came from
            seed_frames.append(seed_df)
combined_df = pd.concat(seed_frames, ignore_index=True)

# Mean / std / count of both AUROC columns for each
# (sample size, heritability, method, dataset) combination.
agg_df = combined_df.groupby(['sample_row_n', 'heritability', 'fi', 'data'])[
    ["auroc_train", "auroc_test"]
].agg(['mean', 'std', 'count']).reset_index()
# Flatten the two-level column index: ('auroc_test', 'mean') -> 'auroc_test_mean';
# the grouping keys have an empty second level and keep their plain name.
agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]
# Standard error of the mean = std / sqrt(number of rows in the group).
agg_df["auroc_train_sem"] = agg_df["auroc_train_std"] / np.sqrt(agg_df["auroc_train_count"])
agg_df["auroc_test_sem"] = agg_df["auroc_test_std"] / np.sqrt(agg_df["auroc_test_count"])

# Keep only the three sample sizes that are plotted.
df = agg_df[agg_df["sample_row_n"].isin([150, 500, 1000])]

In [None]:
# Methods to plot, in drawing/legend order; the strings match the `fi` column
# of the result files.
methods = ['lmdi+', 'lmdi', 'LIME', 'Treeshap']

# Line colour per method.
color_map = dict([
    ('LIME', '#71BEB7'),
    ('Treeshap', 'orange'),
    ('lmdi', '#9B5DFF'),
    ('lmdi+', 'black'),
])

# One row per regression dataset: (dataset key, display name, value shown as
# "p" in the row label -- presumably the number of features; confirm).
_regression_datasets = [
    ("openml_361260", "Miami Housing", 15),
    ("openml_361259", "Puma Robot", 32),
    ("openml_361253", "Wave Energy", 48),
    ("openml_361254", "SARCOS", 21),
    ("openml_361242", "Super Conductivity", 81),
    ("openml_361243", "Geographic Origin of Music", 72),
]
data_name = {key: label for key, label, _ in _regression_datasets}
feature_values = {key: p for key, _, p in _regression_datasets}

# Legend label per method.
methods_name = dict([
    ('LIME', 'LIME'),
    ('lmdi', 'LMDI'),
    ('Treeshap', 'TreeSHAP'),
    ('lmdi+', 'LMDI+'),
])

In [None]:
# Figure: test AUROC vs. sample size for the "linear" regression results.
# One row per dataset and one column per heritability level (titled "PVE");
# each panel overlays one error-bar curve (mean +/- SEM) per entry of `methods`.
# Reads `df`, `datasets`, `methods`, `methods_name`, `color_map`, `data_name`
# and `feature_values` defined by earlier cells.
heritability_all = df["heritability"].unique()[::-1]  # reversed order of first appearance in df
marker_size = 7

n_cols = len(heritability_all)
n_rows = len(datasets)

# squeeze=False always yields a 2-D array of Axes, so axs[row, col] is valid
# even when there is a single row or a single column.
fig, axs = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(8 * n_cols, 6.5 * n_rows),
    sharey=False,
    squeeze=False,
)

for row_idx, dataset in enumerate(datasets):
    for col_idx, heritability in enumerate(heritability_all):
        ax = axs[row_idx, col_idx]
        subset = df[(df["data"] == dataset) & (df["heritability"] == heritability)]

        for method in methods:
            method_data = subset[subset["fi"] == method]
            # LIME, TreeSHAP and LMDI are drawn at 50% opacity; any other
            # method (here 'lmdi+') keeps the default full opacity.
            extra_style = {"alpha": 0.5} if method in ('LIME', 'Treeshap', 'lmdi') else {}
            ax.errorbar(
                method_data["sample_row_n"], method_data['auroc_test_mean'], yerr=method_data["auroc_test_sem"],
                linestyle='solid', marker='o', markersize=marker_size,
                label=methods_name[method], color=color_map[method], linewidth=3, **extra_style
            )

        ax.set_xticks([150,500,1000])
        ax.set_xticklabels(["150", "500", "1000"], fontsize=25)
        ax.tick_params(axis='y', labelsize=25)
        # x-axis label only on the bottom row.
        if row_idx == n_rows - 1:
            ax.set_xlabel("Sample Size", fontsize=30)

        # Row label (bold dataset name and p, then "AUROC") only on the first column.
        if col_idx == 0:
            dataset_label = data_name[dataset]
            p_val = feature_values[dataset]
            # Spaces are ignored in math mode, so turn them into explicit "\ ".
            dataset_label = dataset_label.replace(' ', r'\ ')
            # Backslashes are doubled below: a bare "\ " in a non-raw string is
            # an invalid escape sequence (SyntaxWarning on Python 3.12+).
            if dataset == "openml_361243":
                # This dataset's name is written out by hand, split over two bold lines.
                ax.set_ylabel(
                    f"$\\mathbf{{Geographic\\ Origin\\ of}}$\n$\\mathbf{{Music\\ (p={p_val})}}$\nAUROC",
                    fontsize=30
                )
            else:
                ax.set_ylabel(f"$\\mathbf{{{dataset_label} \\ (p={p_val})}}$\nAUROC", fontsize=30)
        else:
            ax.set_ylabel("")

        # Column title only on the top row.
        if row_idx == 0:
            ax.set_title(r"$\bf{PVE}$=" + rf"$\bf{{{heritability}}}$", fontsize=30)

        # Legend only in the last column.
        if col_idx == n_cols - 1:
            ax.legend(fontsize=22, loc='lower right')

# Leave the top 5% of the figure free for the super-title.
plt.tight_layout(rect=[0, 0, 1, 0.95])
# usetex=True renders the title through an external LaTeX installation.
plt.suptitle(r"\textbf{Linear}", fontsize=40, usetex=True)
plt.savefig("feature_ranking_linear_full.pdf", format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# Load the per-seed result files for the "interaction" setting on every
# regression dataset, then aggregate train/test AUROC across seeds.
dgp = "interaction"
datasets = ['openml_361260', 'openml_361254', 'openml_361259', 'openml_361253', 'openml_361243', 'openml_361242']
feature_seeds = [1,2,3,4,5,6,7,8,9,10]
sample_seeds = [1,2,3]

# Collect every frame first and concatenate once: calling pd.concat inside the
# loop re-copies the accumulated frame on every iteration, and concatenating
# onto an empty pd.DataFrame() raises a FutureWarning in recent pandas.
seed_frames = []
for data in datasets:
    ablation_directory = f"./results_new/mdi_local.real_data_regression_{data}_{dgp}/{data}_{dgp}/varying_heritability_sample_row_n"
    for sample_seed in sample_seeds:
        for feature_seed in feature_seeds:
            seed_df = pd.read_csv(os.path.join(ablation_directory, f"seed_{feature_seed}_{sample_seed}/results.csv"))
            seed_df["data"] = data  # tag rows with the dataset they came from
            seed_frames.append(seed_df)
combined_df = pd.concat(seed_frames, ignore_index=True)

# Mean / std / count of both AUROC columns for each
# (sample size, heritability, method, dataset) combination.
agg_df = combined_df.groupby(['sample_row_n', 'heritability', 'fi', 'data'])[
    ["auroc_train", "auroc_test"]
].agg(['mean', 'std', 'count']).reset_index()
# Flatten the two-level column index: ('auroc_test', 'mean') -> 'auroc_test_mean';
# the grouping keys have an empty second level and keep their plain name.
agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]
# Standard error of the mean = std / sqrt(number of rows in the group).
agg_df["auroc_train_sem"] = agg_df["auroc_train_std"] / np.sqrt(agg_df["auroc_train_count"])
agg_df["auroc_test_sem"] = agg_df["auroc_test_std"] / np.sqrt(agg_df["auroc_test_count"])

# Keep only the three sample sizes that are plotted.
df = agg_df[agg_df["sample_row_n"].isin([150, 500, 1000])]

In [None]:
# Figure: test AUROC vs. sample size for the "interaction" regression results.
# One row per dataset and one column per heritability level (titled "PVE");
# each panel overlays one error-bar curve (mean +/- SEM) per entry of `methods`.
# Reads `df`, `datasets`, `methods`, `methods_name`, `color_map`, `data_name`
# and `feature_values` defined by earlier cells.
heritability_all = df["heritability"].unique()[::-1]  # reversed order of first appearance in df
marker_size = 7

n_cols = len(heritability_all)
n_rows = len(datasets)

# squeeze=False always yields a 2-D array of Axes, so axs[row, col] is valid
# even when there is a single row or a single column.
fig, axs = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(8 * n_cols, 6.5 * n_rows),
    sharey=False,
    squeeze=False,
)

for row_idx, dataset in enumerate(datasets):
    for col_idx, heritability in enumerate(heritability_all):
        ax = axs[row_idx, col_idx]
        subset = df[(df["data"] == dataset) & (df["heritability"] == heritability)]

        for method in methods:
            method_data = subset[subset["fi"] == method]
            # LIME, TreeSHAP and LMDI are drawn at 50% opacity; any other
            # method (here 'lmdi+') keeps the default full opacity.
            extra_style = {"alpha": 0.5} if method in ('LIME', 'Treeshap', 'lmdi') else {}
            ax.errorbar(
                method_data["sample_row_n"], method_data['auroc_test_mean'], yerr=method_data["auroc_test_sem"],
                linestyle='solid', marker='o', markersize=marker_size,
                label=methods_name[method], color=color_map[method], linewidth=3, **extra_style
            )

        ax.set_xticks([150,500,1000])
        ax.set_xticklabels(["150", "500", "1000"], fontsize=25)
        ax.tick_params(axis='y', labelsize=25)
        # x-axis label only on the bottom row.
        if row_idx == n_rows - 1:
            ax.set_xlabel("Sample Size", fontsize=30)

        # Row label (bold dataset name and p, then "AUROC") only on the first column.
        if col_idx == 0:
            dataset_label = data_name[dataset]
            p_val = feature_values[dataset]
            # Spaces are ignored in math mode, so turn them into explicit "\ ".
            dataset_label = dataset_label.replace(' ', r'\ ')
            # Backslashes are doubled below: a bare "\ " in a non-raw string is
            # an invalid escape sequence (SyntaxWarning on Python 3.12+).
            if dataset == "openml_361243":
                # This dataset's name is written out by hand, split over two bold lines.
                ax.set_ylabel(
                    f"$\\mathbf{{Geographic\\ Origin\\ of}}$\n$\\mathbf{{Music\\ (p={p_val})}}$\nAUROC",
                    fontsize=30
                )
            else:
                ax.set_ylabel(f"$\\mathbf{{{dataset_label} \\ (p={p_val})}}$\nAUROC", fontsize=30)
        else:
            ax.set_ylabel("")

        # Column title only on the top row.
        if row_idx == 0:
            ax.set_title(r"$\bf{PVE}$=" + rf"$\bf{{{heritability}}}$", fontsize=30)

        # Legend only in the last column.
        if col_idx == n_cols - 1:
            ax.legend(fontsize=22, loc='lower right')

# Leave the top 5% of the figure free for the super-title.
plt.tight_layout(rect=[0, 0, 1, 0.95])
# usetex=True renders the title through an external LaTeX installation.
plt.suptitle(r"\textbf{Interaction}", fontsize=40, usetex=True)
plt.savefig("feature_ranking_interaction_full.pdf", format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# Load the per-seed result files for the "linear_lss" setting on every
# regression dataset, then aggregate train/test AUROC across seeds.
dgp = "linear_lss"
datasets = ['openml_361260', 'openml_361254', 'openml_361259', 'openml_361253', 'openml_361243', 'openml_361242']
feature_seeds = [1,2,3,4,5,6,7,8,9,10]
sample_seeds = [1,2,3]

# Collect every frame first and concatenate once: calling pd.concat inside the
# loop re-copies the accumulated frame on every iteration, and concatenating
# onto an empty pd.DataFrame() raises a FutureWarning in recent pandas.
seed_frames = []
for data in datasets:
    ablation_directory = f"./results_new/mdi_local.real_data_regression_{data}_{dgp}/{data}_{dgp}/varying_heritability_sample_row_n"
    for sample_seed in sample_seeds:
        for feature_seed in feature_seeds:
            seed_df = pd.read_csv(os.path.join(ablation_directory, f"seed_{feature_seed}_{sample_seed}/results.csv"))
            seed_df["data"] = data  # tag rows with the dataset they came from
            seed_frames.append(seed_df)
combined_df = pd.concat(seed_frames, ignore_index=True)

# Mean / std / count of both AUROC columns for each
# (sample size, heritability, method, dataset) combination.
agg_df = combined_df.groupby(['sample_row_n', 'heritability', 'fi', 'data'])[
    ["auroc_train", "auroc_test"]
].agg(['mean', 'std', 'count']).reset_index()
# Flatten the two-level column index: ('auroc_test', 'mean') -> 'auroc_test_mean';
# the grouping keys have an empty second level and keep their plain name.
agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]
# Standard error of the mean = std / sqrt(number of rows in the group).
agg_df["auroc_train_sem"] = agg_df["auroc_train_std"] / np.sqrt(agg_df["auroc_train_count"])
agg_df["auroc_test_sem"] = agg_df["auroc_test_std"] / np.sqrt(agg_df["auroc_test_count"])

# Keep only the three sample sizes that are plotted.
df = agg_df[agg_df["sample_row_n"].isin([150, 500, 1000])]

In [None]:
# Figure: test AUROC vs. sample size for the "linear_lss" regression results.
# One row per dataset and one column per heritability level (titled "PVE");
# each panel overlays one error-bar curve (mean +/- SEM) per entry of `methods`.
# Reads `df`, `datasets`, `methods`, `methods_name`, `color_map`, `data_name`
# and `feature_values` defined by earlier cells.
heritability_all = df["heritability"].unique()[::-1]  # reversed order of first appearance in df
marker_size = 7

n_cols = len(heritability_all)
n_rows = len(datasets)

# squeeze=False always yields a 2-D array of Axes, so axs[row, col] is valid
# even when there is a single row or a single column.
fig, axs = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(8 * n_cols, 6.5 * n_rows),
    sharey=False,
    squeeze=False,
)

for row_idx, dataset in enumerate(datasets):
    for col_idx, heritability in enumerate(heritability_all):
        ax = axs[row_idx, col_idx]
        subset = df[(df["data"] == dataset) & (df["heritability"] == heritability)]

        for method in methods:
            method_data = subset[subset["fi"] == method]
            # LIME, TreeSHAP and LMDI are drawn at 50% opacity; any other
            # method (here 'lmdi+') keeps the default full opacity.
            extra_style = {"alpha": 0.5} if method in ('LIME', 'Treeshap', 'lmdi') else {}
            ax.errorbar(
                method_data["sample_row_n"], method_data['auroc_test_mean'], yerr=method_data["auroc_test_sem"],
                linestyle='solid', marker='o', markersize=marker_size,
                label=methods_name[method], color=color_map[method], linewidth=3, **extra_style
            )

        ax.set_xticks([150,500,1000])
        ax.set_xticklabels(["150", "500", "1000"], fontsize=25)
        ax.tick_params(axis='y', labelsize=25)
        # x-axis label only on the bottom row.
        if row_idx == n_rows - 1:
            ax.set_xlabel("Sample Size", fontsize=30)

        # Row label (bold dataset name and p, then "AUROC") only on the first column.
        if col_idx == 0:
            dataset_label = data_name[dataset]
            p_val = feature_values[dataset]
            # Spaces are ignored in math mode, so turn them into explicit "\ ".
            dataset_label = dataset_label.replace(' ', r'\ ')
            # Backslashes are doubled below: a bare "\ " in a non-raw string is
            # an invalid escape sequence (SyntaxWarning on Python 3.12+).
            if dataset == "openml_361243":
                # This dataset's name is written out by hand, split over two bold lines.
                ax.set_ylabel(
                    f"$\\mathbf{{Geographic\\ Origin\\ of}}$\n$\\mathbf{{Music\\ (p={p_val})}}$\nAUROC",
                    fontsize=30
                )
            else:
                ax.set_ylabel(f"$\\mathbf{{{dataset_label} \\ (p={p_val})}}$\nAUROC", fontsize=30)
        else:
            ax.set_ylabel("")

        # Column title only on the top row.
        if row_idx == 0:
            ax.set_title(r"$\bf{PVE}$=" + rf"$\bf{{{heritability}}}$", fontsize=30)

        # Legend only in the last column.
        if col_idx == n_cols - 1:
            ax.legend(fontsize=22, loc='lower right')

# Leave the top 5% of the figure free for the super-title.
plt.tight_layout(rect=[0, 0, 1, 0.95])
# usetex=True renders the title through an external LaTeX installation.
plt.suptitle(r"\textbf{Linear + LSS}", fontsize=40, usetex=True)
plt.savefig("feature_ranking_linear_lss_full.pdf", format='pdf', bbox_inches='tight')
plt.show()

### Classification

In [None]:
# Load the per-seed result files for the "logistic_linear" setting on every
# classification dataset, then aggregate train/test AUROC across seeds.
dgp = "logistic_linear"
datasets = ['openml_361063', 'openml_361069', 'openml_361062', 'openml_9978', 'openml_361071', 'openml_43']
feature_seeds = [1,2,3,4,5,6,7,8,9,10]
sample_seeds = [1,2,3]

# Collect every frame first and concatenate once: calling pd.concat inside the
# loop re-copies the accumulated frame on every iteration, and concatenating
# onto an empty pd.DataFrame() raises a FutureWarning in recent pandas.
seed_frames = []
for data in datasets:
    ablation_directory = f"./results_new/mdi_local.real_data_classification_{data}_{dgp}/{data}_{dgp}/varying_frac_label_corruption_sample_row_n"
    for sample_seed in sample_seeds:
        for feature_seed in feature_seeds:
            seed_df = pd.read_csv(os.path.join(ablation_directory, f"seed_{feature_seed}_{sample_seed}/results.csv"))
            seed_df["data"] = data  # tag rows with the dataset they came from
            seed_frames.append(seed_df)
combined_df = pd.concat(seed_frames, ignore_index=True)

# Mean / std / count of both AUROC columns for each
# (sample size, label-corruption fraction, method, dataset) combination.
agg_df = combined_df.groupby(['sample_row_n', 'frac_label_corruption', 'fi', 'data'])[
    ["auroc_train", "auroc_test"]
].agg(['mean', 'std', 'count']).reset_index()
# Flatten the two-level column index: ('auroc_test', 'mean') -> 'auroc_test_mean';
# the grouping keys have an empty second level and keep their plain name.
agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]
# Standard error of the mean = std / sqrt(number of rows in the group).
agg_df["auroc_train_sem"] = agg_df["auroc_train_std"] / np.sqrt(agg_df["auroc_train_count"])
agg_df["auroc_test_sem"] = agg_df["auroc_test_std"] / np.sqrt(agg_df["auroc_test_count"])

# Keep only the three sample sizes that are plotted.
df = agg_df[agg_df["sample_row_n"].isin([150, 500, 1000])]

In [None]:
# Methods to plot, in drawing/legend order; the strings match the `fi` column
# of the result files.
methods = ['lmdi+', 'lmdi', 'LIME', 'Treeshap']

# Line colour per method.
color_map = dict([
    ('LIME', '#71BEB7'),
    ('Treeshap', 'orange'),
    ('lmdi', '#9B5DFF'),
    ('lmdi+', 'black'),
])

# One row per classification dataset: (dataset key, display name, value shown
# as "p" in the row label -- presumably the number of features; confirm).
_classification_datasets = [
    ("openml_43", "Spam", 57),
    ("openml_361062", "Pol", 26),
    ("openml_361071", "Jannis", 54),
    ("openml_9978", "Ozone", 47),
    ("openml_361069", "Higgs", 24),
    ("openml_361063", "House 16H", 16),
]
data_name = {key: label for key, label, _ in _classification_datasets}
feature_values = {key: p for key, _, p in _classification_datasets}

# Legend label per method.
methods_name = dict([
    ('LIME', 'LIME'),
    ('lmdi', 'LMDI'),
    ('Treeshap', 'TreeSHAP'),
    ('lmdi+', 'LMDI+'),
])

In [None]:
# Figure: test AUROC vs. sample size for the "logistic_linear" classification
# results. One row per dataset and one column per label-corruption fraction;
# each panel overlays one error-bar curve (mean +/- SEM) per entry of `methods`.
# Reads `df`, `datasets`, `methods`, `methods_name`, `color_map`, `data_name`
# and `feature_values` defined by earlier cells.
frac_label_corruption_all = df["frac_label_corruption"].unique()
marker_size = 7

# The column count must follow the corruption levels actually plotted. It was
# previously taken from `heritability_all`, a variable that only exists if a
# regression plotting cell has been run in the same kernel.
n_cols = len(frac_label_corruption_all)
n_rows = len(datasets)

# squeeze=False always yields a 2-D array of Axes, so axs[row, col] is valid
# even when there is a single row or a single column.
fig, axs = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(8 * n_cols, 6.5 * n_rows),
    sharey=False,
    squeeze=False,
)

for row_idx, dataset in enumerate(datasets):
    for col_idx, frac_label_corruption in enumerate(frac_label_corruption_all):
        ax = axs[row_idx, col_idx]
        subset = df[(df["data"] == dataset) & (df["frac_label_corruption"] == frac_label_corruption)]

        for method in methods:
            method_data = subset[subset["fi"] == method]
            # LIME, TreeSHAP and LMDI are drawn at 50% opacity; any other
            # method (here 'lmdi+') keeps the default full opacity.
            extra_style = {"alpha": 0.5} if method in ('LIME', 'Treeshap', 'lmdi') else {}
            ax.errorbar(
                method_data["sample_row_n"], method_data['auroc_test_mean'], yerr=method_data["auroc_test_sem"],
                linestyle='solid', marker='o', markersize=marker_size,
                label=methods_name[method], color=color_map[method], linewidth=3, **extra_style
            )

        ax.set_xticks([150,500,1000])
        ax.set_xticklabels(["150", "500", "1000"], fontsize=25)
        ax.tick_params(axis='y', labelsize=25)
        # x-axis label only on the bottom row.
        if row_idx == n_rows - 1:
            ax.set_xlabel("Sample Size", fontsize=30)

        # Row label (bold dataset name and p, then "AUROC") only on the first column.
        if col_idx == 0:
            dataset_label = data_name[dataset]
            p_val = feature_values[dataset]
            # Spaces are ignored in math mode, so turn them into explicit "\ ".
            dataset_label = dataset_label.replace(' ', r'\ ')
            # "\\ " instead of "\ ": a bare "\ " in a non-raw string is an
            # invalid escape sequence (SyntaxWarning on Python 3.12+).
            ax.set_ylabel(f"$\\mathbf{{{dataset_label}\\ (p={p_val})}}$\nAUROC", fontsize=30)
        else:
            ax.set_ylabel("")

        # Column title only on the top row. round() before int() so float
        # error (e.g. 0.57 * 100 == 56.99999999999999) cannot drop a percent.
        if row_idx == 0:
            ax.set_title(rf"$\bf{{{int(round(frac_label_corruption*100))}}} \% \ $" + r"$\bf{Corrupted}$", fontsize=30)

        # Legend only in the last column.
        if col_idx == n_cols - 1:
            ax.legend(fontsize=22, loc='lower right')

# Leave the top 5% of the figure free for the super-title.
plt.tight_layout(rect=[0, 0, 1, 0.95])
# usetex=True renders the title through an external LaTeX installation.
plt.suptitle(r"\textbf{Logistic}", fontsize=40, usetex=True)
plt.savefig("feature_ranking_logistic_linear_full.pdf", format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# Load the per-seed result files for the "logistic_interaction" setting on
# every classification dataset, then aggregate train/test AUROC across seeds.
dgp = "logistic_interaction"
datasets = ['openml_361063', 'openml_361069', 'openml_361062', 'openml_9978', 'openml_361071', 'openml_43']
feature_seeds = [1,2,3,4,5,6,7,8,9,10]
sample_seeds = [1,2,3]

# Collect every frame first and concatenate once: calling pd.concat inside the
# loop re-copies the accumulated frame on every iteration, and concatenating
# onto an empty pd.DataFrame() raises a FutureWarning in recent pandas.
seed_frames = []
for data in datasets:
    ablation_directory = f"./results_new/mdi_local.real_data_classification_{data}_{dgp}/{data}_{dgp}/varying_frac_label_corruption_sample_row_n"
    for sample_seed in sample_seeds:
        for feature_seed in feature_seeds:
            seed_df = pd.read_csv(os.path.join(ablation_directory, f"seed_{feature_seed}_{sample_seed}/results.csv"))
            seed_df["data"] = data  # tag rows with the dataset they came from
            seed_frames.append(seed_df)
combined_df = pd.concat(seed_frames, ignore_index=True)

# Mean / std / count of both AUROC columns for each
# (sample size, label-corruption fraction, method, dataset) combination.
agg_df = combined_df.groupby(['sample_row_n', 'frac_label_corruption', 'fi', 'data'])[
    ["auroc_train", "auroc_test"]
].agg(['mean', 'std', 'count']).reset_index()
# Flatten the two-level column index: ('auroc_test', 'mean') -> 'auroc_test_mean';
# the grouping keys have an empty second level and keep their plain name.
agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]
# Standard error of the mean = std / sqrt(number of rows in the group).
agg_df["auroc_train_sem"] = agg_df["auroc_train_std"] / np.sqrt(agg_df["auroc_train_count"])
agg_df["auroc_test_sem"] = agg_df["auroc_test_std"] / np.sqrt(agg_df["auroc_test_count"])

# Keep only the three sample sizes that are plotted.
df = agg_df[agg_df["sample_row_n"].isin([150, 500, 1000])]

In [None]:
# Figure: test AUROC vs. sample size for the "logistic_interaction"
# classification results. One row per dataset and one column per
# label-corruption fraction; each panel overlays one error-bar curve
# (mean +/- SEM) per entry of `methods`.
# Reads `df`, `datasets`, `methods`, `methods_name`, `color_map`, `data_name`
# and `feature_values` defined by earlier cells.
frac_label_corruption_all = df["frac_label_corruption"].unique()
marker_size = 7

# The column count must follow the corruption levels actually plotted. It was
# previously taken from `heritability_all`, a variable that only exists if a
# regression plotting cell has been run in the same kernel.
n_cols = len(frac_label_corruption_all)
n_rows = len(datasets)

# squeeze=False always yields a 2-D array of Axes, so axs[row, col] is valid
# even when there is a single row or a single column.
fig, axs = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(8 * n_cols, 6.5 * n_rows),
    sharey=False,
    squeeze=False,
)

for row_idx, dataset in enumerate(datasets):
    for col_idx, frac_label_corruption in enumerate(frac_label_corruption_all):
        ax = axs[row_idx, col_idx]
        subset = df[(df["data"] == dataset) & (df["frac_label_corruption"] == frac_label_corruption)]

        for method in methods:
            method_data = subset[subset["fi"] == method]
            # LIME, TreeSHAP and LMDI are drawn at 50% opacity; any other
            # method (here 'lmdi+') keeps the default full opacity.
            extra_style = {"alpha": 0.5} if method in ('LIME', 'Treeshap', 'lmdi') else {}
            ax.errorbar(
                method_data["sample_row_n"], method_data['auroc_test_mean'], yerr=method_data["auroc_test_sem"],
                linestyle='solid', marker='o', markersize=marker_size,
                label=methods_name[method], color=color_map[method], linewidth=3, **extra_style
            )

        ax.set_xticks([150,500,1000])
        ax.set_xticklabels(["150", "500", "1000"], fontsize=25)
        ax.tick_params(axis='y', labelsize=25)
        # x-axis label only on the bottom row.
        if row_idx == n_rows - 1:
            ax.set_xlabel("Sample Size", fontsize=30)

        # Row label (bold dataset name and p, then "AUROC") only on the first column.
        if col_idx == 0:
            dataset_label = data_name[dataset]
            p_val = feature_values[dataset]
            # Spaces are ignored in math mode, so turn them into explicit "\ ".
            dataset_label = dataset_label.replace(' ', r'\ ')
            # "\\ " instead of "\ ": a bare "\ " in a non-raw string is an
            # invalid escape sequence (SyntaxWarning on Python 3.12+).
            ax.set_ylabel(f"$\\mathbf{{{dataset_label}\\ (p={p_val})}}$\nAUROC", fontsize=30)
        else:
            ax.set_ylabel("")

        # Column title only on the top row. round() before int() so float
        # error (e.g. 0.57 * 100 == 56.99999999999999) cannot drop a percent.
        if row_idx == 0:
            ax.set_title(rf"$\bf{{{int(round(frac_label_corruption*100))}}} \% \ $" + r"$\bf{Corrupted}$", fontsize=30)

        # Legend only in the last column.
        if col_idx == n_cols - 1:
            ax.legend(fontsize=22, loc='lower right')

# Leave the top 5% of the figure free for the super-title.
plt.tight_layout(rect=[0, 0, 1, 0.95])
# usetex=True renders the title through an external LaTeX installation.
plt.suptitle(r"\textbf{Logistic Interaction}", fontsize=40, usetex=True)
plt.savefig("feature_ranking_logistic_interaction_full.pdf", format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# Load the per-seed result files for the "logistic_linear_lss" setting on
# every classification dataset, then aggregate train/test AUROC across seeds.
dgp = "logistic_linear_lss"
datasets = ['openml_361063', 'openml_361069', 'openml_361062', 'openml_9978', 'openml_361071', 'openml_43']
feature_seeds = [1,2,3,4,5,6,7,8,9,10]
sample_seeds = [1,2,3]

# Collect every frame first and concatenate once: calling pd.concat inside the
# loop re-copies the accumulated frame on every iteration, and concatenating
# onto an empty pd.DataFrame() raises a FutureWarning in recent pandas.
seed_frames = []
for data in datasets:
    ablation_directory = f"./results_new/mdi_local.real_data_classification_{data}_{dgp}/{data}_{dgp}/varying_frac_label_corruption_sample_row_n"
    for sample_seed in sample_seeds:
        for feature_seed in feature_seeds:
            seed_df = pd.read_csv(os.path.join(ablation_directory, f"seed_{feature_seed}_{sample_seed}/results.csv"))
            seed_df["data"] = data  # tag rows with the dataset they came from
            seed_frames.append(seed_df)
combined_df = pd.concat(seed_frames, ignore_index=True)

# Mean / std / count of both AUROC columns for each
# (sample size, label-corruption fraction, method, dataset) combination.
agg_df = combined_df.groupby(['sample_row_n', 'frac_label_corruption', 'fi', 'data'])[
    ["auroc_train", "auroc_test"]
].agg(['mean', 'std', 'count']).reset_index()
# Flatten the two-level column index: ('auroc_test', 'mean') -> 'auroc_test_mean';
# the grouping keys have an empty second level and keep their plain name.
agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]
# Standard error of the mean = std / sqrt(number of rows in the group).
agg_df["auroc_train_sem"] = agg_df["auroc_train_std"] / np.sqrt(agg_df["auroc_train_count"])
agg_df["auroc_test_sem"] = agg_df["auroc_test_std"] / np.sqrt(agg_df["auroc_test_count"])

# Keep only the three sample sizes that are plotted.
df = agg_df[agg_df["sample_row_n"].isin([150, 500, 1000])]

In [None]:
# Figure: test AUROC vs. sample size for the "logistic_linear_lss"
# classification results. One row per dataset and one column per
# label-corruption fraction; each panel overlays one error-bar curve
# (mean +/- SEM) per entry of `methods`.
# Reads `df`, `datasets`, `methods`, `methods_name`, `color_map`, `data_name`
# and `feature_values` defined by earlier cells.
frac_label_corruption_all = df["frac_label_corruption"].unique()
marker_size = 7

# The column count must follow the corruption levels actually plotted. It was
# previously taken from `heritability_all`, a variable that only exists if a
# regression plotting cell has been run in the same kernel.
n_cols = len(frac_label_corruption_all)
n_rows = len(datasets)

# squeeze=False always yields a 2-D array of Axes, so axs[row, col] is valid
# even when there is a single row or a single column.
fig, axs = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(8 * n_cols, 6.5 * n_rows),
    sharey=False,
    squeeze=False,
)

for row_idx, dataset in enumerate(datasets):
    for col_idx, frac_label_corruption in enumerate(frac_label_corruption_all):
        ax = axs[row_idx, col_idx]
        subset = df[(df["data"] == dataset) & (df["frac_label_corruption"] == frac_label_corruption)]

        for method in methods:
            method_data = subset[subset["fi"] == method]
            # LIME, TreeSHAP and LMDI are drawn at 50% opacity; any other
            # method (here 'lmdi+') keeps the default full opacity.
            extra_style = {"alpha": 0.5} if method in ('LIME', 'Treeshap', 'lmdi') else {}
            ax.errorbar(
                method_data["sample_row_n"], method_data['auroc_test_mean'], yerr=method_data["auroc_test_sem"],
                linestyle='solid', marker='o', markersize=marker_size,
                label=methods_name[method], color=color_map[method], linewidth=3, **extra_style
            )

        ax.set_xticks([150,500,1000])
        ax.set_xticklabels(["150", "500", "1000"], fontsize=25)
        ax.tick_params(axis='y', labelsize=25)
        # x-axis label only on the bottom row.
        if row_idx == n_rows - 1:
            ax.set_xlabel("Sample Size", fontsize=30)

        # Row label (bold dataset name and p, then "AUROC") only on the first column.
        if col_idx == 0:
            dataset_label = data_name[dataset]
            p_val = feature_values[dataset]
            # Spaces are ignored in math mode, so turn them into explicit "\ ".
            dataset_label = dataset_label.replace(' ', r'\ ')
            # "\\ " instead of "\ ": a bare "\ " in a non-raw string is an
            # invalid escape sequence (SyntaxWarning on Python 3.12+).
            ax.set_ylabel(f"$\\mathbf{{{dataset_label}\\ (p={p_val})}}$\nAUROC", fontsize=30)
        else:
            ax.set_ylabel("")

        # Column title only on the top row. round() before int() so float
        # error (e.g. 0.57 * 100 == 56.99999999999999) cannot drop a percent.
        if row_idx == 0:
            ax.set_title(rf"$\bf{{{int(round(frac_label_corruption*100))}}} \% \ $" + r"$\bf{Corrupted}$", fontsize=30)

        # Legend only in the last column.
        if col_idx == n_cols - 1:
            ax.legend(fontsize=22, loc='lower right')

# Leave the top 5% of the figure free for the super-title.
plt.tight_layout(rect=[0, 0, 1, 0.95])
# usetex=True renders the title through an external LaTeX installation.
plt.suptitle(r"\textbf{Logistic + LSS}", fontsize=40, usetex=True)
plt.savefig("feature_ranking_logistic_linear_lss_full.pdf", format='pdf', bbox_inches='tight')
plt.show()