In [1]:
from itertools import combinations

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from pmlb import fetch_data

import pdpexplorer
from pdpexplorer.pdp import partial_dependence

In [2]:
# Load the 'churn' dataset through pmlb's fetch_data.
df = fetch_data(dataset_name='churn')

In [3]:
# Column-oriented view of the frame: {column name: list of that column's values}.
# NOTE(review): this renders every value of every column, which makes for a
# very large cell output; consider previewing a slice instead.
{name: column.tolist() for name, column in df.items()}

{'state': [16,
  35,
  31,
  35,
  36,
  1,
  19,
  24,
  18,
  49,
  15,
  39,
  12,
  26,
  12,
  34,
  13,
  46,
  45,
  43,
  9,
  5,
  3,
  40,
  45,
  29,
  50,
  26,
  24,
  11,
  14,
  30,
  18,
  3,
  36,
  10,
  0,
  19,
  0,
  31,
  10,
  20,
  2,
  13,
  48,
  37,
  22,
  8,
  13,
  50,
  12,
  15,
  45,
  44,
  50,
  15,
  5,
  5,
  48,
  36,
  13,
  4,
  23,
  41,
  27,
  47,
  50,
  23,
  44,
  43,
  31,
  23,
  32,
  33,
  11,
  23,
  7,
  34,
  23,
  20,
  49,
  37,
  5,
  10,
  43,
  46,
  15,
  49,
  17,
  31,
  7,
  18,
  2,
  21,
  3,
  0,
  26,
  3,
  25,
  13,
  12,
  19,
  46,
  17,
  15,
  3,
  22,
  32,
  5,
  44,
  22,
  19,
  1,
  8,
  3,
  21,
  31,
  33,
  24,
  13,
  29,
  43,
  20,
  25,
  33,
  3,
  25,
  25,
  43,
  29,
  29,
  4,
  31,
  43,
  2,
  1,
  34,
  9,
  0,
  33,
  8,
  20,
  22,
  46,
  46,
  22,
  47,
  44,
  18,
  39,
  12,
  29,
  35,
  39,
  36,
  18,
  35,
  45,
  24,
  9,
  40,
  45,
  12,
  22,
  36,
  8,
  10,
  7,
  3,
  1,
  5,
  

In [4]:
# Scratch check: index positions where the values 1..6 equal 3.
# Returns a 1-tuple of index arrays, one per dimension.
np.nonzero(np.arange(1, 7) == 3)

(array([2]),)

In [5]:
# Display the loaded frame; the rendered table is abbreviated with "..." for
# the middle rows and columns.
df

Unnamed: 0,state,account length,area code,phone number,international plan,voice mail plan,number vmail messages,total day minutes,total day calls,total day charge,...,total eve calls,total eve charge,total night minutes,total night calls,total night charge,total intl minutes,total intl calls,total intl charge,number customer service calls,target
0,16,128.0,415.0,2845,0,1,25.0,265.1,110.0,45.07,...,99.0,16.78,244.7,91.0,11.01,10.0,3.0,2.70,1.0,0
1,35,107.0,415.0,2301,0,1,26.0,161.6,123.0,27.47,...,103.0,16.62,254.4,103.0,11.45,13.7,3.0,3.70,1.0,0
2,31,137.0,415.0,1616,0,0,0.0,243.4,114.0,41.38,...,110.0,10.30,162.6,104.0,7.32,12.2,5.0,3.29,0.0,0
3,35,84.0,408.0,2510,1,0,0.0,299.4,71.0,50.90,...,88.0,5.26,196.9,89.0,8.86,6.6,7.0,1.78,2.0,0
4,36,75.0,415.0,155,1,0,0.0,166.7,113.0,28.34,...,122.0,12.61,186.9,121.0,8.41,10.1,3.0,2.73,3.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4995,11,50.0,408.0,2000,0,1,40.0,235.7,127.0,40.07,...,126.0,18.96,297.5,116.0,13.39,9.9,5.0,2.67,2.0,0
4996,49,152.0,415.0,394,0,0,0.0,184.2,90.0,31.31,...,73.0,21.83,213.6,113.0,9.61,14.7,2.0,3.97,3.0,1
4997,7,61.0,415.0,313,0,0,0.0,140.6,89.0,23.90,...,128.0,14.69,212.4,97.0,9.56,13.6,4.0,3.67,1.0,0
4998,7,109.0,510.0,3471,0,0,0.0,188.8,67.0,32.10,...,92.0,14.59,224.4,89.0,10.10,8.5,6.0,2.30,0.0,0


In [6]:
# Model inputs: every column except the label ('target') and five further
# fields that are excluded from modelling.
# NOTE(review): presumably the three charge columns are dropped because they
# duplicate the matching minutes columns; 'total intl charge' is kept - confirm
# that is intentional.
df_X = df.drop(
    [
        'target',
        'state',
        'phone number',
        'total day charge',
        'total night charge',
        'total eve charge',
    ],
    axis=1,
)

# Label vector as a plain array.
y = df.loc[:, 'target'].values

In [7]:
# Fit the random forest whose predict() is explained in the cells below.
# random_state pins the forest's randomness (bootstrap sampling and the
# per-split feature sub-sampling enabled by max_features='sqrt') so that a
# fresh Restart & Run All reproduces the same model and the same plots.
regr = RandomForestRegressor(
    n_estimators=100,
    max_features='sqrt',
    random_state=0,
)
regr.fit(df_X, y)

In [8]:
# Names of all model input columns, in frame order.
features = df_X.columns.tolist()

# Every unordered pair of distinct features.
pairs = [*combinations(features, 2)]

In [9]:
# Compute partial-dependence data for every single feature and every feature
# pair, using the fitted forest's predict function on the input frame.
# NOTE(review): n_instances and resolution are interpreted by pdpexplorer -
# presumably rows sampled and grid points per feature; confirm against its docs.
pd_data = partial_dependence(
    df=df_X,
    predict=regr.predict,
    one_way_features=features,
    two_way_feature_pairs=pairs,
    categorical_features={'voice mail plan', 'international plan', 'area code'},
    resolution=10,
    n_instances=1000,
    n_jobs=4,
)

In [10]:
# Build the explorer widget from the precomputed partial-dependence data,
# the same predict function, and the same input frame.
w = pdpexplorer.PDPExplorerWidget(
    df=df_X,
    predict=regr.predict,
    pd_data=pd_data,
    height=600,
    n_jobs=4,
)

# Last expression of the cell, so the widget is rendered as the cell output.
w

PDPExplorerWidget(dataset={'account length': [120.0, 74.0, 125.0, 132.0, 36.0, 99.0, 49.0, 107.0, 73.0, 157.0,…

In [11]:
# Inspect the widget's `dataset` attribute; the output is a dict mapping each
# column name to a list of values.
# NOTE(review): the values differ from the first rows of `df`, so this is
# presumably a sampled subset - confirm against pdpexplorer.
w.dataset

{'account length': [120.0,
  74.0,
  125.0,
  132.0,
  36.0,
  99.0,
  49.0,
  107.0,
  73.0,
  157.0,
  120.0,
  35.0,
  64.0,
  106.0,
  53.0,
  72.0,
  35.0,
  69.0,
  122.0,
  80.0,
  42.0,
  132.0,
  55.0,
  43.0,
  80.0,
  111.0,
  128.0,
  124.0,
  64.0,
  146.0,
  160.0,
  106.0,
  86.0,
  167.0,
  158.0,
  53.0,
  54.0,
  79.0,
  72.0,
  17.0,
  118.0,
  94.0,
  41.0,
  95.0,
  118.0,
  114.0,
  64.0,
  42.0,
  173.0,
  89.0,
  137.0,
  146.0,
  90.0,
  1.0,
  122.0,
  137.0,
  123.0,
  81.0,
  136.0,
  121.0,
  146.0,
  30.0,
  116.0,
  24.0,
  45.0,
  93.0,
  169.0,
  174.0,
  183.0,
  13.0,
  74.0,
  120.0,
  210.0,
  99.0,
  85.0,
  58.0,
  175.0,
  77.0,
  161.0,
  134.0,
  106.0,
  139.0,
  130.0,
  69.0,
  136.0,
  75.0,
  105.0,
  102.0,
  157.0,
  122.0,
  112.0,
  59.0,
  63.0,
  140.0,
  115.0,
  98.0,
  121.0,
  69.0,
  119.0,
  94.0,
  7.0,
  95.0,
  110.0,
  136.0,
  114.0,
  27.0,
  60.0,
  130.0,
  107.0,
  176.0,
  93.0,
  19.0,
  44.0,
  85.0,
  109.0,
  169.