# Visualisation tool for cointegration 
This tool creates an interactive plot that lets users visualise the relationship between a pair of stocks (available in the dataset `close_df.csv`). Try different pairs, estimation methods and rolling windows to see how the dynamics change.

In [1]:
from CointegrationCalculation import CointegrationCalculation
import ipywidgets as ipw
from IPython.display import display
import plotly.express as px 
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.offline import iplot
from matplotlib.dates import date2num, num2date

In [2]:
# NOTE: the output area shows an error until both text boxes contain valid
# tickers; once two valid tickers are entered the plots render normally.
input_stock1 = ipw.Text(description='Stock 1:', placeholder='type 1st ticker')
input_stock2 = ipw.Text(description='Stock 2:', placeholder='type 2nd ticker')

# One checkbox per estimation method; only total least squares starts ticked.
tls_cal_checkbox = ipw.Checkbox(description='Total Least Square calculation', value=True)
roll_reg_checkbox = ipw.Checkbox(description='Rolling Regression calculation', value=False)
kf_cal_checkbox = ipw.Checkbox(description='Kalman Filter calculation', value=False)

# Window passed to the rolling regression. It starts disabled; plot_data
# enables or disables it to follow roll_reg_checkbox.
roll_window_slider = ipw.IntSlider(
    description='Rolling window: ',
    value=252,
    min=10,
    max=378,
    step=1,
    disabled=True,
    layout=ipw.Layout(width='350px'),
    style={'description_width': 'initial'},
)


def _add_estimate_traces(fig, label, index, intercept, slope, spread, adf_stat):
    """Add one method's intercept, slope and spread lines to rows 3, 4 and 5.

    ``label`` is used as the legend group of all three traces and as the
    prefix of their legend names; ``adf_stat`` is formatted to two decimals
    and appended to the spread trace's name.
    """
    rows = (
        (3, f'{label} intercept', intercept),
        (4, f'{label} slope', slope),
        (5, f'{label} spread: adf t-stat={adf_stat:.2f}', spread),
    )
    for row, name, values in rows:
        fig.add_trace(
            go.Scatter(x=index,
                       y=values,
                       mode='lines',
                       name=name,
                       legendgroup=label,
                       showlegend=True),
            row=row,
            col=1,
        )


def plot_data(stock1, stock2, show_tls, show_roll_reg, roll_window, show_kf):
    """Build and display the five-row cointegration figure for two tickers.

    Rows: (1) both series from ``cm.data`` over time, (2) a scatter of one
    series against the other coloured by date, (3) estimated intercept,
    (4) estimated slope, (5) estimated spread. Rows 3-5 receive one line per
    enabled estimation method.

    Parameters
    ----------
    stock1, stock2 : str
        Tickers handed to ``CointegrationCalculation`` and used as column
        names of ``cm.data``.
    show_tls : bool
        Plot the total-least-squares result (``cm.tls_res``, indexed as
        intercept, slope, spread series).
    show_roll_reg : bool
        Plot the rolling-regression result; also enables the window slider.
    roll_window : int
        Window passed to ``cm.rolling_regression``.
    show_kf : bool
        Plot the Kalman-filter result (``cm.kf_res``).

    Returns
    -------
    None
        The figure is rendered with ``iplot``.
    """
    cm = CointegrationCalculation(stock1, stock2)
    dates = cm.data.index

    fig = make_subplots(rows=5,
                        cols=1,
                        subplot_titles=(f'log prices of {stock1} and {stock2}',
                                        f'{stock1} vs {stock2}',
                                        'Estimated Intercept',
                                        'Estimated Slope (Cointegration Ratio)',
                                        'Estimated Spread'),
                        vertical_spacing=0.07)

    # Colour-bar ticks: roughly every quarter of the sample. Clamp the step to
    # at least 1 so fewer than four rows does not produce a zero slice step.
    tick_step = max(1, len(dates) // 4)
    selected_numeric_dates = date2num(dates[::tick_step])
    cb_text = [num2date(date_num).date() for date_num in selected_numeric_dates]

    # Method traces are added before the price traces so the legend order and
    # the default colour cycle stay the same as before.
    if show_tls:
        tls_intercept, tls_slope, tls_spread = cm.tls_res[0], cm.tls_res[1], cm.tls_res[2]
        n_points = len(tls_spread.index)
        # The TLS estimates are single values, drawn as flat lines across the
        # spread's index.
        _add_estimate_traces(fig,
                             'tls',
                             tls_spread.index,
                             [tls_intercept] * n_points,
                             [tls_slope] * n_points,
                             tls_spread.values,
                             cm.adf_test(tls_spread))

    # The slider is only meaningful while rolling regression is selected.
    roll_window_slider.disabled = not show_roll_reg
    if show_roll_reg:
        # Use the window passed in by the caller instead of re-reading the
        # global slider, so direct calls honour their argument.
        cm.rolling_reg_res = cm.rolling_regression(roll_window)
        rr = cm.rolling_reg_res
        _add_estimate_traces(fig,
                             'roll reg',
                             rr.index,
                             rr.rolling_beta0,
                             rr.rolling_beta1,
                             rr.rolling_spread,
                             cm.adf_test(rr.rolling_spread))

    if show_kf:
        kf = cm.kf_res
        _add_estimate_traces(fig,
                             'kf',
                             kf.index,
                             kf.kf_beta0,
                             kf.kf_beta1,
                             kf.kf_spread,
                             cm.adf_test(kf.kf_spread))

    # Row 1: both series over time.
    for ticker in (stock1, stock2):
        fig.add_trace(
            go.Scatter(x=dates, y=cm.data[ticker], mode='lines', name=ticker),
            row=1,
            col=1,
        )

    # Row 2: stock1 against stock2, each point coloured by its date.
    price_scatter = go.Scatter(
        x=cm.data[stock1],
        y=cm.data[stock2],
        mode='markers',
        name='log price',
        legendgroup='compare prices',
        showlegend=False,
        marker=dict(
            size=6,
            color=date2num(dates),
            line=dict(color='black', width=0.8),
            # Position and length place the colour bar beside the second row.
            colorbar={
                'tickvals': selected_numeric_dates,
                'ticktext': cb_text,
                'thickness': 10,
                'x': 1,
                'y': 0.714,
                'len': 0.164,
            },
            colorscale='Jet',
            showscale=True,
        ),
    )
    fig.add_trace(price_scatter, row=2, col=1)

    fig['layout'].update(height=1200,
                         width=950,
                         legend={
                             'orientation': 'h',
                             'yanchor': 'top'
                         }
                        )
    iplot(fig)

# Re-run plot_data whenever any control changes; each key is the name of the
# plot_data parameter that the widget's value is passed to.
output = ipw.interactive_output(
    plot_data,
    dict(
        stock1=input_stock1,
        stock2=input_stock2,
        show_tls=tls_cal_checkbox,
        show_roll_reg=roll_reg_checkbox,
        roll_window=roll_window_slider,
        show_kf=kf_cal_checkbox,
    ),
)

# Layout, top to bottom: ticker inputs, one row per method (the rolling
# regression row also holds its window slider), then the plot output.
# Left as the cell's final expression so the notebook displays it.
ipw.VBox(children=[
    ipw.HBox(children=[input_stock1, input_stock2]),
    tls_cal_checkbox,
    ipw.HBox(children=[roll_reg_checkbox, roll_window_slider]),
    kf_cal_checkbox,
    output,
])

VBox(children=(HBox(children=(Text(value='', description='Stock 1:', placeholder='type 1st ticker'), Text(valu…