In [173]:
from __future__ import print_function
from ipywidgets import interact, IntSlider, FloatSlider
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestClassifier

%matplotlib inline

# Seed passed to both the dataset generator and the forest so redraws are reproducible.
random_state = 42

# Fill colours used for the decision-region contour plot.
colours = ['blue', 'red', 'green', 'orange']

# Tableau-palette variants of the same hues, indexed by class label for the sample markers.
tabcolours = list(map(lambda name: 'tab:' + name, colours))

def gen_range(r):
    """Sample the half-open interval [r["min"], r["max"]) on a 0.1 grid.

    `r` is a mapping with "min" and "max" keys; as with ``np.arange`` the
    upper bound is excluded.
    """
    start = r["min"]
    stop = r["max"]
    return np.arange(start, stop, 0.1)

def plot_boundary(clf, X, y, groups):
    """Draw clf's decision regions over the first two features of X, with the samples overlaid.

    Parameters
    ----------
    clf : fitted classifier; its ``predict`` is called on an (n, 2) array of grid points.
    X : 2-D array; only columns 0 and 1 are used.
    y : integer-valued class labels, used to index ``tabcolours`` (so expected in range(groups)).
    groups : number of classes; selects how many region fill colours are used.
    """
    # Pad the plotted window by one unit beyond the data on each axis.
    ranges = [{"min": X[:, i].min() - 1, "max": X[:, i].max() + 1} for i in (0, 1)]
    xc, yc = np.meshgrid(gen_range(ranges[0]), gen_range(ranges[1]))
    # Classify every grid point, then restore the grid shape for contouring.
    z = clf.predict(np.c_[xc.ravel(), yc.ravel()]).reshape(xc.shape)
    _, ax = plt.subplots(figsize=(14, 8))
    # Boundaries half-way between consecutive integer labels, so band k holds
    # exactly class k and is filled with colours[k], matching the marker colour
    # tabcolours[k] below. An integer `levels` would let matplotlib choose "nice"
    # boundaries that need not separate the classes.
    levels = np.arange(groups + 1) - 0.5
    ax.contourf(xc, yc, z, levels=levels, alpha=0.3, colors=colours[:groups])
    marker_colours = [tabcolours[int(label)] for label in y]
    ax.scatter(X[:, 0], X[:, 1], c=marker_colours, s=70, marker='x')
    plt.show()

def gen_set(samples, noise, groups):
    """Build a labelled 2-D toy dataset from scikit-learn's S-curve.

    Features are the first two coordinate columns of the generated points.
    Each label is the sample's curve position rounded to the nearest integer
    and folded into ``range(groups)`` with ``%``.

    Returns a ``(features, labels)`` pair: an (samples, 2) array and a list.
    """
    points, position = datasets.make_s_curve(n_samples=samples, noise=noise, random_state=random_state)
    features = np.column_stack((points[:, 0], points[:, 1]))
    labels = [round(value) % groups for value in position]
    return features, labels

def random_forest_classifier(X, y, trees, max_depth, min_samples_split):
    """Fit a seeded RandomForestClassifier on (X, y) and return the fitted model."""
    model = RandomForestClassifier(
        n_estimators=trees,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        random_state=random_state,
    )
    model.fit(X, y)
    return model

def action(samples, noise, groups, trees, max_depth, min_samples_split):
    """Widget callback: regenerate the data, refit the forest and redraw its boundary.

    All arguments are the current slider values; nothing is returned.
    """
    features, labels = gen_set(samples, noise, groups)
    forest = random_forest_classifier(features, labels, trees, max_depth, min_samples_split)
    plot_boundary(forest, features, labels, groups)

def int_slider(min, max):
    """Return an IntSlider over [min, max] with continuous updates disabled."""
    # Parameter names shadow the builtins but are kept for caller compatibility.
    options = dict(min=min, max=max, continuous_update=False)
    return IntSlider(**options)

def float_slider(min, max, step):
    """Return a FloatSlider over [min, max] in increments of `step`, with continuous updates disabled."""
    # Parameter names shadow the builtins but are kept for caller compatibility.
    options = dict(min=min, max=max, step=step, continuous_update=False)
    return FloatSlider(**options)

# Only `interact` is imported at the top of the cell, but this code calls
# `interactive`. Import it (and `display`) explicitly so the cell also runs
# in a fresh kernel.
from ipywidgets import interactive
from IPython.display import display

# One slider per argument of `action`; each slider's range bounds the value
# passed to that keyword.
display(interactive(
    action,
    samples=int_slider(10, 1000),
    noise=float_slider(0.1, 0.9, 0.1),
    groups=int_slider(2, 4),
    trees=int_slider(2, 1000),
    max_depth=int_slider(2, 1000),
    min_samples_split=int_slider(2, 100)
))

interactive(children=(IntSlider(value=10, continuous_update=False, description='samples', max=1000, min=10), F…