In [1]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

In [2]:
df = pd.read_csv('../forecasts/data/target.csv')
df.dates = pd.to_datetime(df.dates)
# let's add a column that it will be used in the legend following this hack: https://github.com/altair-viz/altair/issues/984
df['legend'] = 'Data'
df.head()

Unnamed: 0,dates,target,legend
0,2021-12-26,54.0,Data
1,2022-01-02,212.0,Data
2,2022-01-09,236.0,Data
3,2022-01-16,174.0,Data
4,2022-01-23,115.0,Data


In [3]:
df_for = pd.read_csv('../forecasts/data/forecasts.csv')
df_for.dates = pd.to_datetime(df_for.dates)
df_for

Unnamed: 0,dates,model,predictions,lower,upper
0,2021-12-26,RF,224.030000,0.000000,657.151984
1,2022-01-02,RF,311.150000,0.000000,926.336532
2,2022-01-09,RF,989.630000,0.000000,2139.728589
3,2022-01-16,RF,1550.550000,104.709219,2996.390781
4,2022-01-23,RF,368.890000,0.000000,2020.111839
...,...,...,...,...,...
315,2023-06-04,DL - cluster,304.602886,158.264750,426.872812
316,2023-06-11,DL - cluster,346.068766,160.559812,536.159720
317,2023-06-18,DL - cluster,343.693966,191.552979,543.737511
318,2023-06-25,DL - cluster,363.179827,180.028784,503.317822


In [4]:
#Unique models
df_for.model.unique()

array(['RF', 'DL', 'RF - cluster', 'DL - cluster'], dtype=object)

In [279]:
#creating separate dataframes for each model
df_for_RF = df_for[df_for.model == 'RF']
df_for_DL = df_for[df_for.model == 'DL']
df_for_RLCluster = df_for[df_for.model == 'RF - cluster']
df_for_DLCluster = df_for[df_for.model == 'DL - cluster']

In [6]:
# Simple Scatter plot
scatter = go.Scatter(x=df.dates, y=df.target, mode='markers', name='Data')

#forecasts as line and colores by model
frf = go.Scatter(x=df_for_RF.dates, 
                 y=df_for_RF.predictions, 
                 mode='lines', trace='RF', 
                 )
fdl = go.Scatter(x=df_for_DL.dates, 
                 y=df_for_DL.predictions, 
                 mode='lines', trace='DL', 
                 )

frfc = go.Scatter(x=df_for_RLCluster.dates, 
                  y=df_for_RLCluster.predictions, 
                  mode='lines', trace='RL - Cluster', 
                 )

fdlc = go.Scatter(x=df_for_DLCluster.dates, 
                  y=df_for_DLCluster.predictions, 
                  mode='lines', name='DL - Cluster', 
                 )

#Select the line plot on the second subplot based on the selection of the model on the first subplot

# Data component 
temp_fig = make_subplots(rows=1, cols=2)
temp_fig.add_trace(scatter, row=1, col=1)
temp_fig.add_trace(frf, row=1, col=1)
temp_fig.add_trace(fdl, row=1, col=1)
temp_fig.add_trace(frfc, row=1, col=1)
temp_fig.add_trace(fdlc, row=1, col=1)
temp_fig.add_trace(scatter, row=1, col=2)
temp_fig.update_layout(title_text="Subplots")
temp_fig.show()

In [308]:
#forecasts as line and colores by model
f = go.FigureWidget()
for model in df_for.model.unique(): 
    df_for_model = df_for[df_for.model == model]
    f.add_trace(go.Scatter(x=df_for_model.dates, y=df_for_model.predictions, mode='lines', name=model))
f.add_trace(go.Scatter(x=df.dates, y=df.target, mode='markers', name='Data'))

FigureWidget({
    'data': [{'mode': 'lines',
              'name': 'RF',
              'type': 'scatter',
   …

In [305]:
default_linewidth = 2
highlighted_linewidth_delta = 2

m = {"RF": df_for_RF, "DL": df_for_DL, "RF - cluster": df_for_RLCluster, "DL - cluster": df_for_DLCluster}

def update_trace(trace, points, selector):
    # this list stores the points which were clicked on
    # in all but one trace they are empty
    if len(points.point_inds) == 0:
        return
        
    for i,_ in enumerate(f.data):
        f.data[i]['line']['width'] = default_linewidth + highlighted_linewidth_delta * (i == points.trace_index)

    if len(f.data) > 6:
        f.data = f.data[:6]

    #redraw the clicked trace on the second subplot
    trace = points.trace_index
    model = f.data[trace].name
    df_for_model = m[model]
    x = m[model].dates.to_list()
    x_rev = x[::-1]

    y = m[model].predictions.to_list()
    y_upper = m[model].upper.to_list()
    y_lower = m[model].lower.to_list()
    y_rev = y[::-1]

    color_hex = f.data[trace]['line']['color'].replace('#','')
    color_rgba = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
    color_rgba = 'rgba' + str(color_rgba).replace(')', ', 0.2)')

    f.add_trace(go.Scatter(
        name="var1",
        x=x+x_rev,
        y=y_upper,
        fill='none',
        fillcolor=color_rgba,
        line_color='rgba(255,255,255,0)',
        showlegend=False,
    ), row=1, col=2)

    f.add_trace(go.Scatter(
        name="var2",
        x = df_for_RF.dates,
        y=y_lower,
        fillcolor=color_rgba,
        line_color='rgba(255,255,255,0)',
        fill='tonexty',
        showlegend=False,),  row=1, col=2)
    
    f.add_trace(go.Scatter(x=df_for_model.dates, 
                           y=df_for_model.predictions, 
                           mode='lines', 
                           line_color=f.data[trace]['line']['color'],
                           name=model, 
                           showlegend=False,), 
                            row=1, col=2)


In [307]:
#Creating a figure with two subplots using FigureWidget
sub = make_subplots(rows=1, cols=2)

f = go.FigureWidget(sub)
for model in df_for.model.unique(): 
    df_for_model = df_for[df_for.model == model]
    f.add_trace(go.Scatter(x=df_for_model.dates, y=df_for_model.predictions, mode='lines', name=model))

# we need to add the on_click event to each trace separately       
for i in range( len(f.data) ):
    f.data[i].on_click(update_trace)

f.add_trace(go.Scatter(x=df.dates, y=df.target,
                        mode='markers', 
                        marker={'color':'rgba(50,50,50,0.7)'},
                        name='Data'), row=1, col=1)
f.add_trace(go.Scatter(x=df.dates, y=df.target, 
                       mode='markers',
                       name='Data',
                       marker={'color':'rgba(50,50,50,0.7)'},
                       showlegend=False), row=1, col=2)



f.update_layout(title_text="Forecast")
f

FigureWidget({
    'data': [{'mode': 'lines',
              'name': 'RF',
              'type': 'scatter',
   …