In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import sys
sys.path.append('../..')
import dvu
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from sklearn.tree import export_text, DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
from copy import deepcopy
import imodels
from viz import *
from ipywidgets import interactive
from ipywidgets.widgets import *


# Notebook-wide figure defaults: high-resolution rendering and no top/right
# axis spines on any plot.
mpl.rcParams.update({
    'figure.dpi': 250,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# look at model example

In [69]:
def plot_sim(n=50, std=1, reg_param=0, show=True, linear_data=False):
    """Simulate noisy 1-D data and plot CART vs. shrunk-tree predictions.

    Draws, on the current matplotlib figure: the noisy samples, the
    ground-truth function, the prediction of an unconstrained
    ``DecisionTreeRegressor`` and the prediction of
    ``imodels.ShrunkTreeRegressor`` wrapped around a copy of that tree.

    Parameters
    ----------
    n : int
        Number of samples drawn uniformly from [0, 4].
    std : float
        Multiplier applied to standard-normal noise added to the targets.
    reg_param : float
        Forwarded as ``reg_param`` to ``imodels.ShrunkTreeRegressor``; also
        used in the output file name when saving.
    show : bool or str
        ``'save'`` writes the figure to ``gif/{reg_param}.svg``; any other
        truthy value calls ``plt.show()``; a falsy value leaves the figure
        open so the caller can save or modify it.
    linear_data : bool
        If True the ground truth is the identity ``y = X``; otherwise it is a
        piecewise-constant 0/1 function with breakpoints at 1, 2 and 3.

    Notes
    -----
    Reseeds NumPy's global RNG with a fixed seed on every call, so repeated
    calls with the same arguments draw the same data.
    """
    np.random.seed(13)

    if linear_data:
        def gt_func(X):
            return X
    else:
        def gt_func(X):
            # Piecewise-constant target: 1 on [0, 1), 0 on (1, 2),
            # 1 on [2, 3), 0 on (3, 4]. The zero-weighted terms are kept so
            # all four segments are visible in the formula.
            return +1 * (X < 2) * (X < 1) + \
                   -0 * (X < 2) * (X > 1) + \
                   +1 * (X >= 2) * (X < 3) + \
                   +0 * (X >= 2) * (X > 3)

    # data to fit
    X = np.random.uniform(0, 4, n)
    y = gt_func(X) + np.random.normal(0, 1, n) * std
    plt.plot(X, y, 'o', color='black', ms=4, alpha=0.5, markeredgewidth=0)

    # data to plot: a dense grid over the same range as the samples
    X_tile = np.linspace(0, 4, 500)
    y_tile = gt_func(X_tile)
    plt.plot(X_tile, y_tile, label='Groundtruth', color='black', lw=3)

    # Unconstrained CART fit on the noisy samples
    m1 = DecisionTreeRegressor(random_state=1) #, max_leaf_nodes=15)
    m1.fit(X.reshape(-1, 1), y)
    y_pred_dt = m1.predict(X_tile.reshape(-1, 1))
    plt.plot(X_tile, y_pred_dt, '-', label='CART', color=cb, alpha=0.5, lw=4)

    # The shrunk model wraps a copy so m1 itself is left untouched.
    # NOTE(review): no fit() is called on the wrapper; this presumably relies
    # on it accepting an already-fitted estimator - confirm against imodels.
    mshrunk = imodels.ShrunkTreeRegressor(deepcopy(m1), reg_param=reg_param)
    y_pred_shrunk = mshrunk.predict(X_tile.reshape(-1, 1))
    plt.plot(X_tile, y_pred_shrunk, label='Shrunk', color='#ff4b33', alpha=0.5, lw=4)

    plt.xlabel('X')
    plt.ylabel('Y')
    dvu.line_legend(adjust_text_labels=False)

    # 'save' must be tested before the truthiness check: the string is truthy,
    # so testing `if show:` first would always display and never save.
    if show == 'save':
        os.makedirs('gif', exist_ok=True)  # savefig does not create directories
        plt.savefig(f'gif/{reg_param}.svg')
    elif show:
        plt.show()

#     print('dt', export_text(m1, feature_names=['X']))
#     print('dt', export_text(mshrunk.estimator_, feature_names=['X']))

# plot_sim(n=100)



In [None]:
# Wire plot_sim to widgets: explicit sliders for n and std, and an
# abbreviated (min, max) range for reg_param.
interactive_plot = interactive(
    plot_sim,
    n=IntSlider(min=20, max=300, step=10, value=50),
    std=FloatSlider(min=0.01, max=3, value=0.5),
    reg_param=(0, 150),
)

# Give the last child of the widget container a fixed height.
# NOTE(review): presumably the last child is the output area - confirm.
output = interactive_plot.children[-1]
output.layout.height = '800px'
interactive_plot # note this can't be exported to static

In [None]:
# Draw with show=False so the figure is still open when it is written to
# disk; saving after plt.show() can produce an empty file because the inline
# backend closes the figure once it has been displayed.
plot_sim(n=400, std=1, reg_param=100, show=False)
plt.savefig('intro_indicators.pdf')
plt.show()

In [None]:
# Draw with show=False so the figure is still open when it is written to
# disk; saving after plt.show() can produce an empty file because the inline
# backend closes the figure once it has been displayed.
plot_sim(n=400, std=1, reg_param=50, linear_data=True, show=False)
plt.savefig('intro_linear.pdf')
plt.show()

# try exporting to webpage
- ipywidgets doesn't properly support exporting interactive plots to a static webpage
- bokeh would require writing custom JS code
- solution: save a series of SVGs and then display them with a slider
    - ex: https://github.com/JanSellner/ImageSequenceAnimation

In [44]:
# save all images
# One SVG per shrinkage strength, written to gif/<reg_param>.svg.
os.makedirs('gif', exist_ok=True)  # savefig does not create directories
for reg_param in [0, 1, 5, 10, 181]:
    # Draw without displaying, then save while the figure is still open.
    plot_sim(reg_param=reg_param, show=False)
    plt.savefig(f'gif/{reg_param}.svg')
    # Close each figure so successive plots do not accumulate on the same axes.
    plt.close()