
# 3.7 Feature Importance via Surrogate Models
Train an interpretable model to mimic a complex one, then read feature importance off the interpretable model.

- Surrogate Decision Trees (on ensembles or neural nets)
- Linear Approximation of Non-linear Models
- Local Surrogate Models (e.g., in LIME)
- Model Distillation-based Attribution

----------

## 3.7.1 Surrogate Decision Trees (on ensembles or neural nets)

<details>
<summary style="cursor: pointer">
<h2> { Understanding Surrogate Decision Trees } </h2>
</summary>
<h3> What are Surrogate Decision Trees? </h3>
<p> A surrogate decision tree is a small, shallow tree trained to mimic the predictions of a complex model (such as a neural network or tree ensemble). It approximates the complex model's decision boundaries with a handful of readable if/then splits.</p>
<h3> Role in Model Interpretability: </h3>
<ul>
    <li> Acts as a transparent proxy for an opaque model (e.g., XGBoost, deep nets).</li>
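    <li> Offers global interpretability: the tree's feature importances show which features the surrogate needs in order to reproduce the original model's predictions.</li>
    <li> These explanations are only as trustworthy as the surrogate's fidelity, i.e. how closely it matches the black-box predictions (accuracy or R² measured against those predictions, not against the true labels).</li>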
</ul>
<h3> Resources: </h3>
<ol>
    <li><a href="https://christophm.github.io/interpretable-ml-book/surrogate.html" target="_blank">Interpretable ML Book — Surrogate Models</a></li>
    <li><a href="https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html" target="_blank">Scikit-learn DecisionTreeClassifier</a></li>
</ol>
</details>
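
A surrogate's explanations are only as good as its fidelity, i.e. how often it reproduces the black-box predictions. The sketch below is a minimal illustration under stated assumptions: a scikit-learn `RandomForestClassifier` stands in for the black box and the data come from `make_classification`. It fits a depth-3 tree to the forest's predictions and scores agreement against those predictions rather than the true labels.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic data and an opaque "black-box" model (assumptions for illustration)
X, y = make_classification(n_samples=1000, n_features=6, n_informative=3, random_state=0)
black_box = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

# The surrogate's training target is the black box's output, not y
y_black_box = black_box.predict(X)
surrogate = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y_black_box)

# Fidelity: agreement between surrogate and black box on the same inputs
fidelity = accuracy_score(y_black_box, surrogate.predict(X))
print(f"Surrogate fidelity: {fidelity:.3f}")
print("Surrogate importances:", np.round(surrogate.feature_importances_, 3))
```

For a regression black box, R² between the black-box predictions and the surrogate's predictions (for example via `sklearn.metrics.r2_score`) plays the same role.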

##### Parameters:
- X: Features (DataFrame or array-like)
- model: Trained black-box scikit-learn classifier or regressor (ensemble, neural network, etc.)
- feature_names: List of feature names (optional; defaults to the columns of X)
- tree_max_depth: Maximum depth of the surrogate decision tree (default=5)
- random_state: Random seed for reproducibility
- show_plot: Whether to plot feature importances
- plot_size: Tuple indicating plot size (width, height)

##### Returns:
- Fitted surrogate decision tree
- DataFrame of feature importances
- Displays a plot of feature importances if show_plot=True

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.base import is_classifier, is_regressor

def surrogate_tree_feature_importance(X,
                                       model,
                                       feature_names=None,
                                       tree_max_depth=5,
                                       random_state=42,
                                       show_plot=True,
                                       plot_size=(12, 8)):
    """Fit a shallow decision tree to a black-box model's predictions and
    return the tree together with its impurity-based feature importances."""

    if not isinstance(X, pd.DataFrame):
        X = pd.DataFrame(X)

    # Explicit None check: `feature_names or ...` raises for NumPy arrays
    feature_names = X.columns.tolist() if feature_names is None else list(feature_names)
    if len(feature_names) != X.shape[1]:
        raise ValueError("feature_names must have one entry per column of X.")

    if not (is_classifier(model) or is_regressor(model)):
        raise ValueError("Unsupported model type for surrogate modeling.")

    # The surrogate's target is the black-box model's predictions, not the true labels
    y_pred = model.predict(X)

    # Classifier surrogate for predicted class labels, regressor surrogate for continuous outputs
    if is_classifier(model):
        surrogate = DecisionTreeClassifier(max_depth=tree_max_depth, random_state=random_state)
    else:
        surrogate = DecisionTreeRegressor(max_depth=tree_max_depth, random_state=random_state)

    # Fit surrogate tree
    surrogate.fit(X, y_pred)

    # Impurity-based importances of the surrogate (normalized to sum to 1 once the tree has a split)
    importance_df = pd.DataFrame({
        'Feature': feature_names,
        'Importance': surrogate.feature_importances_
    }).sort_values('Importance', ascending=False).reset_index(drop=True)

    if show_plot and not importance_df.empty:
        plt.figure(figsize=plot_size)
        colors = plt.cm.plasma(np.linspace(0.2, 1, len(importance_df)))

        # String labels keep the sorted order (numeric names would be placed by value)
        bars = plt.barh(importance_df['Feature'].astype(str),
                        importance_df['Importance'],
                        color=colors,
                        alpha=0.9)

        plt.gca().invert_yaxis()  # most important feature on top
        plt.xlabel('Feature Importance', fontsize=12)
        plt.title('Surrogate Decision Tree Feature Importances', fontsize=14, pad=20)
        plt.grid(axis='x', linestyle='--', alpha=0.3)

        for bar in bars:
            width = bar.get_width()
            plt.text(width + 0.01,
                     bar.get_y() + bar.get_height() / 2,
                     f'{width:.2f}',
                     va='center',
                     fontsize=9)

        method_text = (
            f"Method: Surrogate Tree\n"
            f"Estimator: {type(model).__name__}\n"
            f"Tree Depth: {tree_max_depth}"
        )
        plt.annotate(method_text,
                     xy=(0.02, 0.02),
                     xycoords='axes fraction',
                     ha='left',
                     va='bottom',
                     fontsize=9,
                     bbox=dict(boxstyle='round', alpha=0.1))

        plt.tight_layout()
        plt.show()

    return surrogate, importance_df
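
Because the surrogate is a shallow tree, its learned rules can also be printed as text. A minimal sketch, assuming a `GradientBoostingRegressor` trained on synthetic `make_regression` data as the black box, uses scikit-learn's `export_text`; the feature names `f0` to `f3` are hypothetical labels made up for illustration.

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.tree import DecisionTreeRegressor, export_text

# Black box trained on synthetic data (illustrative assumption)
X, y = make_regression(n_samples=500, n_features=4, noise=0.1, random_state=0)
black_box = GradientBoostingRegressor(random_state=0).fit(X, y)

# Depth-2 surrogate fitted to the black box's predictions
surrogate = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, black_box.predict(X))

# Human-readable if/then rules approximating the black box
feature_names = [f"f{i}" for i in range(X.shape[1])]  # hypothetical names
rules = export_text(surrogate, feature_names=feature_names)
print(rules)
```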

--------

## 3.7.2 Linear Approximation of Non-linear Models

<details>
<summary style="cursor: pointer">
<h2> { Understanding Linear Approximation of Non-linear Models } </h2>
</summary>
<h3> What is Linear Approximation? </h3>
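<p> Linear approximation fits a linear model (such as linear or logistic regression) to a complex model's predictions. Fitted over the whole dataset, as in the function below, it gives a global first-order summary of the model; fitted only on points near a given input, it approximates the model's local behavior around that point.</p>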
<h3> Role in Interpretability: </h3>
<ul>
    <li> The sign and magnitude of each coefficient summarize the direction and strength of a feature's linearized effect (magnitudes are comparable only when features share a scale).</li>
    <li> Locally weighted linear fits are the core idea behind LIME and Kernel SHAP.</li>
</ul>
<h3> Resources: </h3>
<ol>
    <li><a href="https://arxiv.org/abs/1602.04938" target="_blank">LIME Paper (explains local linear approximations)</a></li>
    <li><a href="https://scikit-learn.org/stable/modules/linear_model.html" target="_blank">Scikit-learn Linear Models</a></li>
</ol>
</details>
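
To see what approximating a model's local behavior means, the sketch below uses a known nonlinear function as a stand-in for the black box, f(x1, x2) = x1² + 3·x2 (a hypothetical example). Fitting ordinary least squares on points sampled close to x0 = (1, 0) recovers approximately the local gradient (2·x1, 3) = (2, 3), while a global fit over a wide symmetric range averages the x1² curvature away and reports an x1 slope near 0.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

def black_box(X):
    # Hypothetical nonlinear model: f(x1, x2) = x1**2 + 3 * x2
    return X[:, 0] ** 2 + 3 * X[:, 1]

rng = np.random.default_rng(0)
x0 = np.array([1.0, 0.0])

# Local fit: sample a small neighborhood around x0 and query the black box there
neighborhood = x0 + rng.normal(scale=0.01, size=(500, 2))
local_fit = LinearRegression().fit(neighborhood, black_box(neighborhood))

# Global fit: sample a wide symmetric range, where the x1**2 term has no linear trend
wide = rng.uniform(-3, 3, size=(5000, 2))
global_fit = LinearRegression().fit(wide, black_box(wide))

print("Local slopes: ", np.round(local_fit.coef_, 3))
print("Global slopes:", np.round(global_fit.coef_, 3))
```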

##### Parameters:
- X: Features (DataFrame or array-like)
- model: Trained black-box scikit-learn model (neural network, ensemble, etc.); classifiers must be binary and support predict_proba
- feature_names: List of feature names (optional; defaults to the columns of X)
- random_state: Random seed for reproducibility
- show_plot: Whether to plot feature coefficients
- plot_size: Tuple indicating plot size (width, height)
- sample_size: Number of rows to sample from X before fitting (optional, for speed)

##### Returns:
- Fitted linear approximation model
- DataFrame of feature coefficients (importance)
- Displays a plot of feature coefficients if show_plot=True

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.base import is_classifier, is_regressor

def linear_approximation_feature_importance(X,
                                             model,
                                             feature_names=None,
                                             random_state=42,
                                             show_plot=True,
                                             plot_size=(12, 8),
                                             sample_size=None):
    """Fit a global linear surrogate to a black-box model's outputs and
    return it together with its coefficients."""

    if not isinstance(X, pd.DataFrame):
        X = pd.DataFrame(X)

    # Explicit None check: `feature_names or ...` raises for NumPy arrays
    feature_names = X.columns.tolist() if feature_names is None else list(feature_names)
    if len(feature_names) != X.shape[1]:
        raise ValueError("feature_names must have one entry per column of X.")

    # Subsample without replacement (sklearn's resample() bootstraps by default, duplicating rows)
    if sample_size is not None and sample_size < len(X):
        X = X.sample(n=sample_size, random_state=random_state)

    # Target: probability of class 1 for binary classifiers, raw output for regressors
    if is_classifier(model):
        proba = model.predict_proba(X)
        if proba.shape[1] != 2:
            raise ValueError("Only binary classifiers are supported (uses the probability of class 1).")
        y_pred = proba[:, 1]
    elif is_regressor(model):
        y_pred = model.predict(X)
    else:
        raise ValueError("Unsupported model type for linear approximation.")

    # Ordinary least squares on the model's outputs (a linear probability model for classifiers)
    surrogate = LinearRegression()
    surrogate.fit(X, y_pred)

    # Coefficients are per unit of each feature: compare magnitudes only on a common scale
    importance_df = pd.DataFrame({
        'Feature': feature_names,
        'Coefficient': surrogate.coef_
    }).sort_values('Coefficient', key=lambda x: np.abs(x), ascending=False).reset_index(drop=True)

    if show_plot and not importance_df.empty:
        plt.figure(figsize=plot_size)
        # Color by sign: blue for positive effects, red for negative ones
        colors = ['tab:blue' if c >= 0 else 'tab:red' for c in importance_df['Coefficient']]

        bars = plt.barh(importance_df['Feature'].astype(str),
                        importance_df['Coefficient'],
                        color=colors,
                        alpha=0.9)

        plt.gca().invert_yaxis()  # largest |coefficient| on top
        plt.axvline(0, color='black', linewidth=0.8)
        plt.xlabel('Feature Coefficient (Linear Approximation)', fontsize=12)
        plt.title('Linear Approximation of Non-linear Model (Feature Importance)', fontsize=14, pad=20)
        plt.grid(axis='x', linestyle='--', alpha=0.3)

        for bar in bars:
            width = bar.get_width()
            # Label just past the end of the bar, on the side the bar points to
            plt.text(width,
                     bar.get_y() + bar.get_height() / 2,
                     f' {width:.2f} ',
                     va='center',
                     ha='left' if width >= 0 else 'right',
                     fontsize=9)

        method_text = (
            f"Method: Linear Approximation\n"
            f"Estimator: {type(model).__name__}"
        )
        plt.annotate(method_text,
                     xy=(0.02, 0.02),
                     xycoords='axes fraction',
                     ha='left',
                     va='bottom',
                     fontsize=9,
                     bbox=dict(boxstyle='round', alpha=0.1))

        plt.tight_layout()
        plt.show()

    return surrogate, importance_df
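
Raw linear coefficients are expressed per unit of each feature, so ranking them by absolute value mixes scales. One common remedy, sketched below on hypothetical synthetic data, is to compare coefficients fitted on standardized features. For ordinary least squares this equals each raw coefficient multiplied by the feature's standard deviation, which the code checks numerically.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Two features on very different scales (hypothetical data)
X = np.column_stack([rng.normal(0, 1, 1000), rng.normal(0, 1000, 1000)])
y = 2.0 * X[:, 0] + 0.004 * X[:, 1] + rng.normal(0, 0.1, 1000)

raw = LinearRegression().fit(X, y)
scaled = LinearRegression().fit(StandardScaler().fit_transform(X), y)

# Feature 1's raw coefficient looks tiny, but per standard deviation it has the larger effect
print("Raw coefficients:         ", np.round(raw.coef_, 4))
print("Standardized coefficients:", np.round(scaled.coef_, 4))

# For OLS: standardized coefficient = raw coefficient * feature std (population std, as StandardScaler uses)
assert np.allclose(scaled.coef_, raw.coef_ * X.std(axis=0))
```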

--------

## 3.7.3 Local Surrogate Models (e.g., in LIME)

<details>
<summary style="cursor: pointer">
<h2> { Understanding Local Surrogate Models (LIME) } </h2>
</summary>
<h3> What is LIME? </h3>
<p> LIME (Local Interpretable Model-agnostic Explanations) explains a single prediction by fitting an interpretable model (typically a regularized linear model such as ridge regression) to the black-box model's outputs in the neighborhood of that instance.</p>
<h3> Role in Explainability: </h3>
<ul>
    <li> Provides localized, interpretable explanations around individual predictions.</li>
    <li> Generates perturbed samples near the instance and records how the black-box predictions change.</li>
    <li> The surrogate is trained on perturbed samples weighted by similarity to the instance.</li>
</ul>
<h3> Resources: </h3>
<ol>
    <li><a href="https://github.com/marcotcr/lime" target="_blank">LIME GitHub Repo</a></li>
    <li><a href="https://arxiv.org/abs/1602.04938" target="_blank">Original LIME Paper</a></li>
</ol>
</details>
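
The proximity weighting described above can be written as an exponential kernel over distances, w = exp(-d² / σ²), where σ is the kernel width; this is the form used for `kernel_weights` in the function below. The `lime` package's tabular explainer, in the versions we are aware of, uses the square root of this expression with a default width of 0.75·√(number of features), so the exact form should be treated as an implementation choice. The sketch evaluates the kernel at a few distances for two widths.

```python
import numpy as np

def exponential_kernel(distances, kernel_width=0.75):
    # Weight 1 at distance 0, decaying smoothly; larger widths give more global fits
    distances = np.asarray(distances, dtype=float)
    return np.exp(-(distances ** 2) / kernel_width ** 2)

d = np.array([0.0, 0.75, 1.5])
print(np.round(exponential_kernel(d), 4))                    # roughly 1, 0.3679, 0.0183
print(np.round(exponential_kernel(d, kernel_width=3.0), 4))  # wider kernel: slower decay
```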

##### Parameters:
- model: Trained black-box model (neural network, ensemble, etc.)
- X: Features (DataFrame or array-like)
- instance: Single instance (row) to explain (array-like or Series)
- num_features: Number of top features to explain locally
- kernel_width: Controls locality around the instance (higher = more global)
- random_state: Random seed for reproducibility
- show_plot: Whether to plot local feature importances
- plot_size: Tuple indicating plot size (width, height)
- sample_size: Number of perturbed samples to generate around instance

##### Returns:
- Fitted local surrogate model
- DataFrame of local feature importance (weights)
- Displays a plot of local feature contributions if show_plot=True

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
from sklearn.utils import check_random_state

def local_surrogate_lime_feature_importance(model,
                                             X,
                                             instance,
                                             num_features=10,
                                             kernel_width=0.75,
                                             random_state=42,
                                             show_plot=True,
                                             plot_size=(10, 6),
                                             sample_size=5000):
    """LIME-style explanation of one prediction: perturb the instance, weight
    the samples by proximity, and fit a weighted ridge regression."""

    if not isinstance(X, pd.DataFrame):
        X = pd.DataFrame(X)

    feature_names = X.columns.tolist()
    rng = check_random_state(random_state)

    # Accept a list, array, Series or one-row DataFrame; reshape to (1, n_features)
    instance = np.asarray(instance, dtype=float).reshape(1, -1)
    if instance.shape[1] != X.shape[1]:
        raise ValueError("instance must have the same number of features as X.")

    # Per-feature standard deviations from X (constant features get 1 to avoid division by zero)
    scale = X.std(ddof=0).to_numpy(dtype=float)
    scale[scale == 0] = 1.0

    # Gaussian perturbations of 0.1 standard deviations per feature, clipped to X's observed range
    perturbations = instance + rng.normal(scale=0.1, size=(sample_size, X.shape[1])) * scale
    perturbations = np.clip(perturbations, X.min().to_numpy(), X.max().to_numpy())

    # Query the model with the same column names it may have been fitted with
    perturbed_df = pd.DataFrame(perturbations, columns=X.columns)
    if hasattr(model, 'predict_proba'):
        preds = model.predict_proba(perturbed_df)[:, 1]  # probability of class 1 (binary only)
    else:
        preds = model.predict(perturbed_df)

    # Standardized offsets from the instance, so kernel_width does not depend on feature units
    offsets = (perturbations - instance) / scale
    distances = np.sqrt((offsets ** 2).sum(axis=1))

    # Exponential kernel: samples close to the instance get weights near 1
    kernel_weights = np.exp(-(distances ** 2) / (kernel_width ** 2))

    # Weighted ridge on the standardized offsets: each coefficient is the local change in model
    # output per one-standard-deviation change in that feature (the surrogate's inputs are offsets)
    surrogate = Ridge(alpha=1.0, fit_intercept=True)
    surrogate.fit(offsets, preds, sample_weight=kernel_weights)

    importance_df = pd.DataFrame({
        'Feature': feature_names,
        'Local_Weight': surrogate.coef_
    }).sort_values('Local_Weight', key=lambda x: np.abs(x), ascending=False)

    importance_df = importance_df.head(num_features).reset_index(drop=True)

    if show_plot and not importance_df.empty:
        plt.figure(figsize=plot_size)
        # Color by sign: blue pushes the prediction up, red pushes it down
        colors = ['tab:blue' if w >= 0 else 'tab:red' for w in importance_df['Local_Weight']]

        bars = plt.barh(importance_df['Feature'].astype(str),
                        importance_df['Local_Weight'],
                        color=colors,
                        alpha=0.9)

        plt.gca().invert_yaxis()  # largest |weight| on top
        plt.axvline(0, color='black', linewidth=0.8)
        plt.xlabel('Local Feature Weight (per standard deviation)', fontsize=12)
        plt.title('LIME-style Local Surrogate Model (Feature Importance)', fontsize=14, pad=20)
        plt.grid(axis='x', linestyle='--', alpha=0.3)

        for bar in bars:
            width = bar.get_width()
            # Label just past the end of the bar, on the side the bar points to
            plt.text(width,
                     bar.get_y() + bar.get_height() / 2,
                     f' {width:.2f} ',
                     va='center',
                     ha='left' if width >= 0 else 'right',
                     fontsize=9)

        method_text = (
            f"Method: LIME-style Local Surrogate\n"
            f"Kernel width: {kernel_width}\n"
            f"Estimator: {type(model).__name__}"
        )
        plt.annotate(method_text,
                     xy=(0.02, 0.02),
                     xycoords='axes fraction',
                     ha='left',
                     va='bottom',
                     fontsize=9,
                     bbox=dict(boxstyle='round', alpha=0.1))

        plt.tight_layout()
        plt.show()

    return surrogate, importance_df
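
A quick way to sanity-check a LIME-style explainer is to apply the same recipe (Gaussian perturbations, exponential kernel weights, weighted ridge regression) to a function whose local slopes are known. The sketch below is a self-contained re-implementation of that recipe, not a call to the function above; the test function f(x) = sin(x1) + 0.5·x2 and the instance (0, 0) are arbitrary choices, so the local weights should come out near (cos 0, 0.5) = (1, 0.5).

```python
import numpy as np
from sklearn.linear_model import Ridge

def black_box(X):
    # Hypothetical nonlinear model with known partial derivatives
    return np.sin(X[:, 0]) + 0.5 * X[:, 1]

rng = np.random.default_rng(0)
instance = np.array([[0.0, 0.0]])

# Perturb around the instance and query the black box
perturbations = instance + rng.normal(scale=0.1, size=(5000, 2))
preds = black_box(perturbations)

# Proximity weights from an exponential kernel on Euclidean distance
distances = np.linalg.norm(perturbations - instance, axis=1)
weights = np.exp(-(distances ** 2) / 0.75 ** 2)

# Weighted ridge regression as the local surrogate (small alpha to limit shrinkage)
local_model = Ridge(alpha=1e-3).fit(perturbations, preds, sample_weight=weights)
print("Local weights:", np.round(local_model.coef_, 3))
```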

--------

## 3.7.4 Model Distillation-based Attribution

<details>
<summary style="cursor: pointer">
<h2> { Understanding Model Distillation-based Attribution } </h2>
</summary>
<h3> What is Model Distillation? </h3>
<p> Model distillation involves training a simpler (interpretable) model to imitate a more complex one by learning from its soft predictions (logits/probabilities) rather than the hard labels.</p>
<h3> Role in Interpretability: </h3>
<ul>
    <li> Transfers knowledge from a "teacher" model (e.g., a deep net) to a "student" model (e.g., a decision tree or logistic regression).</li>
    <li> Attribution can then be read off the simpler model (tree importances or linear coefficients).</li>
    <li> To the extent that the student reproduces the teacher's predictions, this shows what the complex model has learned and how it weighs features.</li>
</ul>
<h3> Resources: </h3>
<ol>
    <li><a href="https://arxiv.org/abs/1503.02531" target="_blank">Distilling the Knowledge in a Neural Network (Hinton et al.)</a></li>
    <li><a href="https://christophm.github.io/interpretable-ml-book/distillation.html" target="_blank">Interpretable ML Book — Model Distillation</a></li>
</ol>
</details>
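
The soft predictions mentioned above are, in Hinton et al.'s formulation, class probabilities from a softmax with a temperature T: p_i = exp(z_i / T) / Σ_j exp(z_j / T). Raising T flattens the distribution and exposes how the teacher ranks the non-top classes. A minimal NumPy sketch with made-up teacher logits:

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    # Divide logits by T, subtract the max for numerical stability, then normalize
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max()
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

teacher_logits = [4.0, 1.0, 0.5]  # hypothetical teacher logits for one sample
for T in (1.0, 5.0):
    print(f"T={T}:", np.round(softmax_with_temperature(teacher_logits, T), 3))
```

When the function below fits on teacher probabilities from `predict_proba`, those targets correspond to T = 1.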

##### Parameters:
- teacher_model: Trained complex model (e.g., ensemble, neural net) to explain
- X: Input features (DataFrame or array-like)
- y: True targets (array-like); used only when fit_on_predictions=False
- student_model: Unfitted interpretable estimator (e.g., DecisionTreeRegressor, LinearRegression) that exposes feature_importances_ or coef_ after fitting
- random_state: Random seed for reproducibility
- show_plot: Whether to visualize feature importances
- plot_size: Tuple indicating plot size (width, height)
- fit_on_predictions: Whether to train the student on the teacher's outputs (True) or on the true labels y (False)

##### Returns:
- Fitted student model (interpretable surrogate)
- DataFrame of global feature importances
- Displays a plot of distilled feature importance if show_plot=True

In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.base import is_classifier

def model_distillation_feature_importance(teacher_model,
                                           X,
                                           y,
                                           student_model,
                                           random_state=42,
                                           show_plot=True,
                                           plot_size=(10, 6),
                                           fit_on_predictions=True):
    """Distill a teacher model into an interpretable student and return the
    student together with its global feature importances."""

    if not isinstance(X, pd.DataFrame):
        X = pd.DataFrame(X)

    # Seed the student if it accepts a random_state (e.g. decision trees)
    if hasattr(student_model, 'get_params') and 'random_state' in student_model.get_params():
        student_model.set_params(random_state=random_state)

    # Distillation: fit on the teacher's outputs if preferred
    if fit_on_predictions:
        if is_classifier(student_model):
            # Classifier students need discrete targets, so use the teacher's hard labels
            distilled_y = teacher_model.predict(X)
        elif hasattr(teacher_model, 'predict_proba'):
            # Regressor students learn the teacher's soft output: probability of class 1 (binary only)
            distilled_y = teacher_model.predict_proba(X)[:, 1]
        else:
            distilled_y = teacher_model.predict(X)
    else:
        distilled_y = y

    # Train the student model
    student_model.fit(X, distilled_y)

    # Extract feature importances
    if hasattr(student_model, "feature_importances_"):
        importances = student_model.feature_importances_
    elif hasattr(student_model, "coef_"):
        # coef_ is 2-D for linear classifiers: average |coef| over classes to get one value per feature.
        # Magnitudes are comparable only when features share a scale.
        importances = np.abs(np.atleast_2d(student_model.coef_)).mean(axis=0)
    else:
        raise ValueError("Student model must have either feature_importances_ or coef_ attribute.")

    feature_names = X.columns.tolist()

    importance_df = pd.DataFrame({
        'Feature': feature_names,
        'Importance': importances
    }).sort_values('Importance', ascending=False).reset_index(drop=True)

    if show_plot and not importance_df.empty:
        plt.figure(figsize=plot_size)
        colors = plt.cm.magma(np.linspace(0.2, 1, len(importance_df)))

        # String labels keep the sorted order (numeric names would be placed by value)
        bars = plt.barh(importance_df['Feature'].astype(str),
                        importance_df['Importance'],
                        color=colors,
                        alpha=0.9)

        plt.xlabel('Distilled Feature Importance', fontsize=12)
        plt.title('Model Distillation (Feature Attribution)', fontsize=14, pad=20)
        plt.gca().invert_yaxis()  # most important feature on top
        plt.grid(axis='x', linestyle='--', alpha=0.3)

        for bar in bars:
            width = bar.get_width()
            plt.text(width + 0.01,
                     bar.get_y() + bar.get_height() / 2,
                     f'{width:.3f}',
                     va='center',
                     fontsize=9)

        method_text = (
            f"Teacher: {type(teacher_model).__name__}\n"
            f"Student: {type(student_model).__name__}\n"
            f"Fit on: {'Teacher predictions' if fit_on_predictions else 'True labels'}"
        )
        plt.annotate(method_text,
                     xy=(0.02, 0.02),
                     xycoords='axes fraction',
                     ha='left',
                     va='bottom',
                     fontsize=9,
                     bbox=dict(boxstyle='round', alpha=0.1))

        plt.tight_layout()
        plt.show()

    return student_model, importance_df
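
Distilled attributions are only meaningful if the student actually agrees with the teacher, ideally on data neither model was fitted on. The sketch below assumes a `GradientBoostingClassifier` teacher, a `LogisticRegression` student trained on the teacher's hard labels, and synthetic data. It also shows that a binary `LogisticRegression` stores `coef_` with shape (1, n_features), which has to be flattened before building a per-feature importance table.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=2000, n_features=5, n_informative=3, random_state=0)
X = StandardScaler().fit_transform(X)  # common scale so |coef| values are comparable
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

teacher = GradientBoostingClassifier(random_state=0).fit(X_train, y_train)
# The student learns the teacher's labels rather than the true ones
student = LogisticRegression().fit(X_train, teacher.predict(X_train))

# Held-out agreement (fidelity) between student and teacher
agreement = np.mean(student.predict(X_test) == teacher.predict(X_test))
print(f"Held-out student/teacher agreement: {agreement:.3f}")

print("coef_ shape:", student.coef_.shape)   # (1, 5) for a binary problem
importances = np.abs(student.coef_).ravel()  # one value per feature
print("Importances:", np.round(importances, 3))
```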

--------