In [1]:
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, accuracy_score, mean_absolute_error
from lofo_importance import LOFOImportance

In [2]:
# Build a synthetic dataset with a known generating formula so the
# importance scores computed on it can be sanity-checked by eye.
DATA_SIZE = 1000
SEED = 42

# Seed NumPy's global RNG before any draw so that Restart & Run All
# regenerates exactly the same data.
np.random.seed(SEED)

df = pd.DataFrame()

# Four independent uniform [0, 1) features.
df["A"] = np.random.rand(DATA_SIZE)
df["B"] = np.random.rand(DATA_SIZE)
df["C"] = np.random.rand(DATA_SIZE)
df["D"] = np.random.rand(DATA_SIZE)

# target = uniform noise + an A*D interaction + a linear term in B.
# "C" does not appear in the formula, so it carries no signal.
df["target"] = 0.2*np.random.rand(DATA_SIZE) + df["A"]*df["D"] + 2*df["B"]

# Binary label: 1 where target is strictly above its median, else 0.
df["binary_target"] = (df["target"] > df["target"].median()).astype(int)
df.head()

Unnamed: 0,A,B,C,D,target,binary_target
0,0.797733,0.659586,0.399888,0.810615,1.996598,1
1,0.200037,0.982512,0.641724,0.257339,2.109312,1
2,0.611341,0.093312,0.049886,0.373671,0.534735,0
3,0.039541,0.372044,0.796767,0.050315,0.755376,0
4,0.665328,0.549839,0.190598,0.05431,1.275965,0


In [3]:
# Models whose feature importances are inspected in the cells below:
# a linear regressor for the continuous "target" and a shallow random
# forest for "binary_target".
lr = LinearRegression()

# random_state pins the forest's bootstrap sampling and per-split feature
# selection, so refitting on the same data yields the same model and the
# reported importances are repeatable across runs.
rf = RandomForestClassifier(max_depth=5, random_state=42)

In [4]:
# LOFO importance for the linear regressor on the continuous target,
# scored with mean absolute error. MAE is an error measure (lower is
# better), hence greater_is_better=False.
fi = LOFOImportance(
    lr,
    mean_absolute_error,
    df,
    ["A", "B", "C", "D"],
    "target",
    needs_proba=False,
    greater_is_better=False,
)

importances = fi.get_importance()
importances

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Unnamed: 0,feature,importance_mean,importance_std
1,B,0.42523,0.006381
0,A,0.056764,0.005378
3,D,0.05527,0.002609
2,C,-4.3e-05,7.9e-05


In [5]:
# LOFO importance for the random forest on the binary label, scored with
# accuracy. accuracy_score takes hard class labels, so needs_proba=False;
# higher accuracy is better, so greater_is_better=True.
fi = LOFOImportance(
    rf,
    accuracy_score,
    df,
    ["A", "B", "C", "D"],
    "binary_target",
    needs_proba=False,
    greater_is_better=True,
)

importances = fi.get_importance()
importances

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Unnamed: 0,feature,importance_mean,importance_std
1,B,0.372,0.010198
0,A,0.044,0.012329
3,D,0.04,0.011662
2,C,-0.007,0.004359


In [6]:
# LOFO importance for the random forest on the binary label, scored with
# ROC AUC. needs_proba=True presumably makes LOFO pass predicted
# probabilities (rather than hard labels) to the metric -- confirm against
# the installed lofo_importance version. Higher AUC is better.
fi = LOFOImportance(
    rf,
    roc_auc_score,
    df,
    ["A", "B", "C", "D"],
    "binary_target",
    needs_proba=True,
    greater_is_better=True,
)

importances = fi.get_importance()
importances

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Unnamed: 0,feature,importance_mean,importance_std
1,B,0.370519,0.016843
0,A,0.025242,0.006508
3,D,0.016057,0.000733
2,C,-0.001103,0.003727
