In [2]:
import pandas as pd
import numpy as np

In [62]:
#Single Feature ALE
def ALE(model, X, select_feature, grid_resolution):
    # model: fitted machine learning model
    # X: Training set
    # select_feature: which feature is calculated on
    # grid_resolution: number of intervals
    
    #find the number of points in one interval
    nj_k = int(np.floor(X.shape[0]/grid_resolution))

    #sort X according to select feature
    X_copy = X.copy()
    X_sorted = X_copy[X_copy[:, select_feature].argsort()]

    #partition X
    X_splitted = []
    z = [X_sorted[:, select_feature].min()-X_sorted[:, select_feature].min()/10]
    for i in range(0, X_sorted.shape[0], nj_k):
        if i+nj_k > X_sorted.shape[0]:
            X_i = X_sorted[i:, :]
        else:
            X_i = X_sorted[i:i+nj_k, :]
        X_splitted.append(X_i)
        z.append(X_i[:, select_feature].max())
    
    
    #calculate uncentered effect
    ale = []
    ale_i = 0
    for i in range(len(X_splitted)):
        X_i = X_splitted[i]
        X_i_lower = X_i.copy()
        X_i_upper = X_i.copy()
        X_i_lower[:, select_feature] = z[i]
        X_i_upper[:, select_feature] = z[i+1]
        pred_upper = model.predict(X_i_upper)
        pred_lower = model.predict(X_i_lower)
        diff = (pred_upper - pred_lower)
        ave_diff = diff.mean()
        ale_i+=ave_diff
        ale.append(ale_i)
    
    #calculate centered effect
    ale_expectation = np.sum([X_splitted[i].shape[0]*ale[i] for i in range(len(ale))])/X_sorted.shape[0]
    ale_centered = ale - ale_expectation 
    
    return ale_centered, z[1:]

In [4]:
from sklearn.datasets import load_boston

X_all, y_all = load_boston().data, load_boston().target

In [5]:
from sklearn.ensemble import RandomForestRegressor

RF = RandomForestRegressor(random_state=0, n_estimators=100)
RF.fit(X_all, y_all)
RF.score(X_all, y_all)

0.9813474035784516

In [153]:
ale, z = ALE(model=RF, X=X_all, select_feature=3, grid_resolution=50)

In [154]:
len(ale)

51

In [155]:
len(z)

51

In [156]:
ale

array([-0.01121976, -0.01121976, -0.01121976, -0.01121976, -0.01121976,
       -0.01121976, -0.01121976, -0.01121976, -0.01121976, -0.01121976,
       -0.01121976, -0.01121976, -0.01121976, -0.01121976, -0.01121976,
       -0.01121976, -0.01121976, -0.01121976, -0.01121976, -0.01121976,
       -0.01121976, -0.01121976, -0.01121976, -0.01121976, -0.01121976,
       -0.01121976, -0.01121976, -0.01121976, -0.01121976, -0.01121976,
       -0.01121976, -0.01121976, -0.01121976, -0.01121976, -0.01121976,
       -0.01121976, -0.01121976, -0.01121976, -0.01121976, -0.01121976,
       -0.01121976, -0.01121976, -0.01121976, -0.01121976, -0.01121976,
       -0.01121976, -0.01121976,  0.14648024,  0.14648024,  0.14648024,
        0.14648024])

In [157]:
z

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0]

In [161]:
X_all[:10]

array([[6.3200e-03, 1.8000e+01, 2.3100e+00, 0.0000e+00, 5.3800e-01,
        6.5750e+00, 6.5200e+01, 4.0900e+00, 1.0000e+00, 2.9600e+02,
        1.5300e+01, 3.9690e+02, 4.9800e+00],
       [2.7310e-02, 0.0000e+00, 7.0700e+00, 0.0000e+00, 4.6900e-01,
        6.4210e+00, 7.8900e+01, 4.9671e+00, 2.0000e+00, 2.4200e+02,
        1.7800e+01, 3.9690e+02, 9.1400e+00],
       [2.7290e-02, 0.0000e+00, 7.0700e+00, 0.0000e+00, 4.6900e-01,
        7.1850e+00, 6.1100e+01, 4.9671e+00, 2.0000e+00, 2.4200e+02,
        1.7800e+01, 3.9283e+02, 4.0300e+00],
       [3.2370e-02, 0.0000e+00, 2.1800e+00, 0.0000e+00, 4.5800e-01,
        6.9980e+00, 4.5800e+01, 6.0622e+00, 3.0000e+00, 2.2200e+02,
        1.8700e+01, 3.9463e+02, 2.9400e+00],
       [6.9050e-02, 0.0000e+00, 2.1800e+00, 0.0000e+00, 4.5800e-01,
        7.1470e+00, 5.4200e+01, 6.0622e+00, 3.0000e+00, 2.2200e+02,
        1.8700e+01, 3.9690e+02, 5.3300e+00],
       [2.9850e-02, 0.0000e+00, 2.1800e+00, 0.0000e+00, 4.5800e-01,
        6.4300e+00, 5.8700e