In [None]:
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf

import ipywidgets as w
import bqplot.pyplot as plt
import bqplot as bq

In [None]:
# Engel household data (income vs. food expenditure), with the columns renamed
# to the generic ``x`` / ``y`` names that the regression formulas below expect.
original_data = (
    sm.datasets.engel.load_pandas()
    .data
    .rename(columns={'income': 'x', 'foodexp': 'y'})
)

In [None]:
def quantile_loss(x, q=.5):
    """Element-wise quantile ("pinball") loss of ``x`` at quantile ``q``.

    Positive values are weighted by ``q`` and negative values by ``1 - q``;
    equivalently ``max(q * x, (q - 1) * x)``.
    """
    upper = q * x
    lower = (q - 1) * x
    return np.maximum(upper, lower)

def fit_quantile_reg(q, data):
    """Fit the linear quantile regression ``y ~ x`` at quantile ``q``.

    ``data`` must provide ``x`` and ``y`` columns. Returns ``(intercept, slope)``.
    """
    coefs = smf.quantreg('y ~ x', data).fit(q=q).params
    intercept, slope = coefs['Intercept'], coefs['x']
    return intercept, slope

def fit_ols_reg(data):
    """Ordinary least-squares fit of ``y ~ x``.

    ``data`` must provide ``x`` and ``y`` columns. Returns ``(intercept, slope)``.
    """
    coefs = smf.ols('y ~ x', data).fit().params
    intercept, slope = coefs['Intercept'], coefs['x']
    return intercept, slope

In [None]:
# Left panel: the quantile slider and the shape of the loss for its value.
qloss_fig_title_tmpl = 'Quantile Loss Function (q = {q})'
q_slider = w.FloatSlider(description='Quantile', 
                         value=.5, min=0.05, max=.95,
                         step=.05, continuous_update=False)

qloss_fig = plt.figure(animation_duration=750,
                       layout={'width': '500px', 'height': '400px'})
# Grid of residual values the loss curve is drawn over; the curve's y values
# start empty and are filled in by update_reg_lines.
x = np.arange(-10, 10, .1)
qloss_line = plt.plot(x, [], colors=['dodgerblue'])
# Label the axes so the panel reads on its own (the pyplot label helpers act
# on the current figure, i.e. qloss_fig).
plt.xlabel('Residual')
plt.ylabel('Loss')

In [None]:
# Right panel: the data as draggable points, with the OLS and quantile-regression fits.
scat_fig = plt.figure(
    title='Quantile Regression',
    animation_duration=750,
    layout=dict(width='900px', height='600px'),
    legend_location='top-left',
)
# Fixed-range scales are registered before any mark, so every mark created
# below in this figure picks them up.
axis_scales = {
    'x': bq.LinearScale(min=300, max=2600),
    'y': bq.LinearScale(min=200, max=1600),
}
plt.scales(scales=axis_scales)

# Points can be dragged (enable_move) and new points added by clicking.
scat = plt.scatter(
    original_data['x'], original_data['y'],
    colors=['limegreen'],
    default_opacities=[.6],
    enable_move=True,
    interactions={'click': 'add'},
    default_size=50,
    stroke='black',
)
plt.xlabel('X')
plt.ylabel('Y')

# The two headline fits, shown in the legend; their y values start empty.
reg_lines = plt.plot(
    original_data['x'], [],
    colors=['#ccc', 'magenta'],
    labels=['Linear Regression', 'Quantile Regression'],
    display_legend=True,
)

# Faint dash-dot lines for the whole family of quantile fits.
qr_lines = plt.plot(original_data['x'], [], 'y-.', opacities=[.3])

show_qr_lines_cb = w.Checkbox(description='Display all QR lines', value=False)

reset_btn = w.Button(description='Reset Points', button_style='success')
reset_btn.layout.margin = '0px 0px 0px 50px'

def update_reg_lines(*args):
    """Refit all models on the current scatter points and redraw every mark.

    Registered as a traitlets observer, so any positional change arguments are
    accepted and ignored; it is also called directly for the initial draw.
    """
    new_data = pd.DataFrame({'x': scat.x, 'y': scat.y})
    q = q_slider.value

    # Evaluate the loss on the line's own x grid rather than the generic
    # notebook-global ``x``, which any later cell could silently overwrite.
    qloss_line.y = quantile_loss(qloss_line.x, q=q)
    qloss_fig.title = qloss_fig_title_tmpl.format(q=q)

    qr_a, qr_b = fit_quantile_reg(q, new_data)
    ols_a, ols_b = fit_ols_reg(new_data)
    with reg_lines.hold_sync():
        reg_lines.x = new_data['x']
        reg_lines.y = [ols_a + ols_b * new_data['x'], 
                       qr_a + qr_b * new_data['x']]

    # Family of decile fits 0.1 ... 0.9. Integer arange / 10 yields the
    # correctly rounded deciles with a guaranteed length of 9, unlike a
    # float-step arange (whose length depends on rounding and which yields
    # values such as 0.30000000000000004).
    decile_fits = []
    for decile in np.arange(1, 10) / 10:
        intercept, slope = fit_quantile_reg(decile, new_data)
        decile_fits.append(intercept + slope * new_data['x'])

    with qr_lines.hold_sync():
        qr_lines.x = new_data['x']
        qr_lines.y = decile_fits
        
def reset_points():
    """Restore the scatter to the original observations.

    Both coordinates are assigned inside ``hold_trait_notifications`` so that
    observers only run once ``x`` and ``y`` have both been replaced.
    """
    pristine = {'x': original_data['x'], 'y': original_data['y']}
    with scat.hold_trait_notifications():
        for trait_name, values in pristine.items():
            setattr(scat, trait_name, values)

# Wire the widgets to the plots.
reset_btn.on_click(lambda btn: reset_points())
# Keep the quantile-fit lines' visibility in sync with the checkbox.
_ = w.link((show_qr_lines_cb, 'value'), (qr_lines, 'visible'))

# Refit whenever points are moved or added.
scat.observe(update_reg_lines, names=['x', 'y'])
# Only react to the slider's value. Without ``names``, traitlets calls the
# observer for every trait change on the widget, which would trigger the
# whole (expensive) refit for unrelated changes.
q_slider.observe(update_reg_lines, names='value')

In [None]:
# Draw the initial fits, then lay out the dashboard: loss panel on the left,
# regression panel and its controls on the right.
update_reg_lines()

loss_panel = w.VBox([q_slider, qloss_fig])
controls = w.HBox([reset_btn, show_qr_lines_cb])
regression_panel = w.VBox([scat_fig, controls])
w.HBox([loss_panel, regression_panel])