In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import sys
sys.path.append('../..')
import dvu
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from sklearn.tree import export_text, DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
from copy import deepcopy
import imodels
from viz import *
import plotly.graph_objects as go
from tqdm import tqdm

mpl.rcParams['figure.dpi'] = 250
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False

# look at model example

In [2]:
def plot_sim(n=50, std=1, reg_param=0, show=True, linear_data=False, return_curves=False):
    np.random.seed(13)
    
    if linear_data:
        def gt_func(X):
            return X
    else:
        def gt_func(X):
            return +1 * (X < 2) * (X < 1) + \
                   -0 * (X < 2) * (X > 1) + \
                   +1 * (X >= 2) * (X < 3) + \
                   +0 * (X >= 2) * (X > 3)
    

    # data to fit
    X = np.random.uniform(0, 4, n)
    X = np.sort(X)
    y = gt_func(X) + np.random.normal(0, 1, n) * std
    plt.plot(X, y, 'o', color='black', ms=4, alpha=0.5, markeredgewidth=0)


    # data to plot
    X_tile = np.linspace(0, 4, 500)
    y_tile = gt_func(X_tile)
    plt.plot(X_tile, y_tile, label='Groundtruth', color='black', lw=3)


    m1 = DecisionTreeRegressor(random_state=1) #, max_leaf_nodes=15)
    m1.fit(X.reshape(-1, 1), y)
    y_pred_dt = m1.predict(X_tile.reshape(-1, 1))
    plt.plot(X_tile, y_pred_dt, '-', label='CART', color=cb, alpha=0.5, lw=4)

    mshrunk = imodels.ShrunkTreeRegressor(deepcopy(m1), reg_param=reg_param)
    y_pred_shrunk = mshrunk.predict(X_tile.reshape(-1, 1))
    plt.plot(X_tile, y_pred_shrunk, label='Shrunk', color='#ff4b33', alpha=0.5, lw=4)


    plt.xlabel('X')
    plt.ylabel('Y')
    dvu.line_legend(adjust_text_labels=False)
    if show:
        plt.show()
    elif show=='save':
        plt.savefig(f'gif/{reg_param}.svg')
    else:
        plt.close()
        
    if return_curves:
        return X, y, X_tile, y_tile, y_pred_dt, y_pred_shrunk

#     print('dt', export_text(m1, feature_names=['X']))
#     print('dt', export_text(mshrunk.estimator_, feature_names=['X']))

# plot_sim(n=100)



In [None]:
from ipywidgets import interactive
from ipywidgets.widgets import *
interactive_plot = interactive(plot_sim,
                               n=IntSlider(value=50, min=20, max=300, step=10),
                               std=FloatSlider(value=0.5, min=0.01, max=3),
                               reg_param=(0, 150))
output = interactive_plot.children[-1]
output.layout.height = '800px'
interactive_plot # note this can't be exported to static

In [None]:
plot_sim(n=400, std=1, reg_param=100)
plt.savefig('intro_indicators.pdf')

In [None]:
plot_sim(n=400, std=1, reg_param=50, linear_data=True)
plt.savefig('intro_linear.pdf')

# try exporting to webpage
- ipywidgets doesn't support this properly
- bokeh requires writing JS code
- soln: save a bunch of svgs and then reopen them with a slider
    - ex: https://github.com/JanSellner/ImageSequenceAnimation
    - ```# save all images
for reg_param in [0, 1, 5, 10, 181]:
    plot_sim(reg_param=reg_param, show='save')
    plt.close()```
- soln: plotly can just store all the curves and then export

# plotly soln

In [4]:
!pip install plotly

Collecting plotly
  Using cached plotly-5.4.0-py2.py3-none-any.whl (25.3 MB)
Collecting tenacity>=6.2.0
  Using cached tenacity-8.0.1-py3-none-any.whl (24 kB)
Installing collected packages: tenacity, plotly
Successfully installed plotly-5.4.0 tenacity-8.0.1


In [3]:
X, y, X_tile, y_tile, y_pred_dt, y_pred_shrunk = plot_sim(n=50, std=1, reg_param=0,
                                                          show=False, linear_data=False, return_curves=True)

In [35]:
# Create figure
fig = go.Figure()
fig.layout.template = 'plotly_white'


# add permanent curves
kwargs = dict(visible=True, line=None)
fig.add_trace(
    go.Scatter(
        **kwargs,
        mode='markers',
        name="Raw data points",
        x=X,
        y=y,
    ))    
fig.add_trace(
    go.Scatter(
        **kwargs,
        name="Groundtruth data",
        line=dict('color'='black')
        x=X_tile,
        y=y_tile,
    ))    
fig.add_trace(
    go.Scatter(
        **kwargs,
        name="CART predictions",
        x=X_tile,
        y=y_pred_dt,
    ))   
OFFSET = 3
ACTIVE = 1

# Add traces, one for each slider step
VALS = [0, 1, 5, 10, 25, 50, 100]
for reg_param in tqdm(VALS):
    X, y, X_tile, y_tile, y_pred_dt, y_pred_shrunk = plot_sim(n=50, std=1, reg_param=reg_param,
                                                          show=False, linear_data=False, return_curves=True)
    fig.add_trace(
        go.Scatter(
            visible=False,
            line=None, #dict(color="#00CED1", width=6),
            marker=dict(),
            name=f"Shrunk predictions",
            x=X_tile,
            y=y_pred_shrunk,
        ))




# Create and add slider
steps = []
for i in range(len(fig.data) - OFFSET): # remember offset
    step = dict(
        method="update",
        args=[{"visible": [False] * len(fig.data)},
              {"title": "Reg-param: " + str(VALS[i])}],  # layout attribute
    )
    step["args"][0]["visible"][i + OFFSET] = True  # Toggle i'th trace to "visible"
    step["args"][0]["visible"][0] = True  # final
    step["args"][0]["visible"][1] = True  # final    
    step["args"][0]["visible"][2] = True  # final    
    steps.append(step)

    
# Make ACTIVE trace visible
fig.data[ACTIVE + OFFSET].visible = True    
sliders = [dict(
    active=ACTIVE,
    currentvalue={"prefix": "Frequency: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(
    sliders=sliders
)


    
fig.write_html("export.html")
fig.show()



SyntaxError: invalid syntax (166774797.py, line 21)