In [2]:
%load_ext autoreload
%autoreload 2
from Functions import PredictiveAnalysis

# reload
# %reload_ext autoreload

In [None]:
import plotly.graph_objects as go
import dash
from dash import html, dcc, ALL, MATCH, Output, Input, State, no_update, ctx, Patch
import dash_mantine_components as dmc
import dash_bootstrap_components as dbc


In [None]:

# Set up each component.
# Chip legend: one outlined chip per option in `sc_opts`, all checked initially.
# NOTE(review): `sc_opts` and `colors` are not defined anywhere in this notebook
# view, and `colors` is indexed by position here while the Memo section later
# rebinds `colors` to a dict keyed by series name — TODO confirm their source.
chip_items = [
    dmc.Chip(children=str(opt), value=str(opt), variant="outline", color=colors[idx])
    for idx, opt in enumerate(sc_opts)
]
legends = dmc.ChipGroup(
    id={'func': 'compare_perf', 'obj': 'legend'},
    children=chip_items,
    value=[str(opt) for opt in sc_opts],
    multiple=True,
    style={'display': 'flex', 'justifyContent': 'center'},
)

def _figure_column(idx, figure):
    """Wrap one figure in a 5px-margin card inside an auto-width column.

    The graph id uses the pattern-matching dict
    {'func': 'compare_perf', 'obj': 'fig', 'id': str(idx)}.
    """
    graph = dcc.Graph(
        id={'func': 'compare_perf', 'obj': 'fig', 'id': str(idx)},
        figure=figure,
    )
    card = dbc.Card(dbc.CardBody([graph]), style={'margin': '5px'})
    return dbc.Col(card, width='auto')


# One horizontally scrollable row holding a card per figure in `figs`.
# NOTE(review): `figs` is not defined in this notebook view — TODO confirm source.
graphs = html.Div(
    dbc.Row(
        [_figure_column(idx, figure) for idx, figure in enumerate(figs)],
        style={'display': 'flex', 'overflowX': 'auto', 'width': '100%'},
    ),
    style={'maxWidth': '100vw'},
)

# Registered with `dash.callback` instead of `@app.callback`: `app` is only
# created in a later cell, so `app.callback` fails with NameError on a fresh
# kernel (or attaches to a stale app left in the kernel). Globally registered
# callbacks are picked up by the Dash app created afterwards.
@dash.callback(
    output=Output({'func': 'compare_perf', 'obj': 'fig', 'id': ALL}, 'figure'),
    inputs=Input({'func': 'compare_perf', 'obj': 'legend'}, 'value'),
    state=State({'func': 'compare_perf', 'obj': 'fig', 'id': ALL}, 'figure')
)
def update_visibility(value, fig):
    """Show only the traces whose name is among the checked chips, in every figure.

    Args:
        value: list of checked chip values from the compare_perf ChipGroup
            (may be None/empty, in which case every trace is hidden).
        fig: list of current figure dicts, one per compare_perf graph.

    Returns:
        A list with one `Patch` per figure that sets each trace's `visible`
        flag, or `no_update` when nothing triggered the callback (initial call).
    """
    if not ctx.triggered_id:
        return no_update

    checked = set(value or [])
    outputs = []
    for figure in fig:
        # A Patch sends only the changed `visible` flags, not whole figures.
        patch = Patch()
        for i, trace in enumerate(figure['data']):
            # .get: unnamed traces simply count as "not checked".
            patch['data'][i]['visible'] = trace.get('name') in checked
        outputs.append(patch)
    return outputs

In [None]:
app = dash.Dash(__name__)

# Heading and explanatory notes for the comparison view.
# NOTE(review): these two components previously sat after `run_server` as a
# mis-indented, dangling fragment, which made this whole cell a syntax error.
# They appear intended as layout children, so they are placed above the chips.
# The white text presumes a dark page background — TODO confirm styling.
page_title = html.H3(
    'Comparing the Logistic Regression Model Results on Different Conditions',
    style={'color': 'white'},
)
page_notes = html.Div(id='fixed-legend', style={'color': 'white'}, children=[
    dcc.Markdown("""
            - Scopes: how many months of the latest data is used for parameter adjustments.
            - Error Bars: Showing mean, min, and max of each measure among various future predictions.
        """)
])

# Set up the Dash app layout: title, notes, chip legend, then the figure row.
app.layout = html.Div([
    page_title, page_notes, legends, graphs
])

if __name__ == '__main__':
    # `run_server` is deprecated (removed in Dash 3); `run` takes the same arguments.
    app.run(debug=False)

### Memo: prototype of a custom legend that toggles each series across all subplots

In [None]:
import dash
from dash import Dash, html, dcc, Input, Output, State, ctx
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [None]:
# Synthetic stand-ins for real results: three frames with the same structure
# (a daily 'X' date column plus random-walk columns Y1-Y3) at different scales.
SEED = 42  # fixed seed so Restart & Run All reproduces the same plots
N_POINTS = 100


def make_sample(rng, scale=1.0, n=N_POINTS):
    """Build one sample frame of random walks.

    Args:
        rng: numpy Generator used for the standard-normal draws.
        scale: multiplier applied to every random walk.
        n: number of rows.

    Returns:
        DataFrame with column 'X' (daily dates starting 2020-01-01) and
        columns 'Y1', 'Y2', 'Y3', each a cumulative sum of standard-normal
        draws multiplied by `scale`.
    """
    frame = {'X': pd.date_range(start='1/1/2020', periods=n)}
    for col in ('Y1', 'Y2', 'Y3'):
        frame[col] = rng.standard_normal(n).cumsum() * scale
    return pd.DataFrame(frame)


rng = np.random.default_rng(SEED)
s1 = make_sample(rng)
s2 = make_sample(rng, scale=0.5)
s3 = make_sample(rng, scale=2)


# One row of three subplots (one per sample frame) sharing the y-axis.
fig = make_subplots(rows=1, cols=3, shared_yaxes=True)

# One color per series, reused in every subplot so a series looks the same everywhere.
colors = {'Y1': 'blue', 'Y2': 'red', 'Y3': 'green'}

for i, s in enumerate([s1, s2, s3], start=1):
    for col in ['Y1', 'Y2', 'Y3']:
        fig.add_trace(
            go.Scatter(
                x=s['X'],
                y=s[col],
                name=col,
                mode='lines+markers',
                # Set BOTH line and marker color: with only marker.color, each
                # line takes the next default colorway entry, so the same series
                # got a different line color in every subplot.
                line=dict(color=colors[col]),
                marker=dict(color=colors[col]),
                showlegend=False,  # no trace uses Plotly's legend; custom toggle buttons replace it
                visible=True
            ),
            row=1, col=i
        )

# Dark theme for the figure.
fig.update_layout(
    plot_bgcolor='black',
    paper_bgcolor='black',
    font={'color': 'white'},
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    uirevision='constant'  # keep user UI state (e.g. zoom/pan) when a callback returns an updated figure
)


app = dash.Dash(__name__)

# Toggle buttons standing in for the Plotly legend; their ids are the inputs
# of the visibility callback below.
legend_bar = html.Div(
    id='custom-legend',
    children=[
        html.Button(label, id=button_id, n_clicks=0)
        for label, button_id in [('Y1', 'legend-y1'), ('Y2', 'legend-y2'), ('Y3', 'legend-y3')]
    ],
    style={'display': 'flex', 'justifyContent': 'center'},
)

# A 600px-wide viewport that scrolls horizontally over a 1500px-wide figure.
scrolling_graph = html.Div(
    style={'width': '600px', 'overflowX': 'scroll'},
    children=[dcc.Graph(id='subplots-graph', figure=fig, style={'width': '1500px'})],
)

# Set up the Dash app layout
app.layout = html.Div([legend_bar, scrolling_graph])


@app.callback(
    output=Output('subplots-graph', 'figure'),
    inputs=dict(
        data=dict(
            y1=Input('legend-y1', 'n_clicks'),
            y2=Input('legend-y2', 'n_clicks'),
            y3=Input('legend-y3', 'n_clicks'),
        ),
    ),
    state=dict(fig=State('subplots-graph', 'figure'))
)
def update_graph_visibility(data, fig):
    """Toggle every trace of the clicked series across all subplots.

    An odd click count on a legend button hides that series' traces; an even
    count shows them again.

    Args:
        data: dict with keys 'y1', 'y2', 'y3' holding each button's n_clicks.
        fig: current figure dict of 'subplots-graph'.

    Returns:
        The updated figure, or ``dash.no_update`` when no legend button
        triggered the call (e.g. the initial call). This avoids sending the
        unchanged figure back to the browser.
    """
    triggered_id = ctx.triggered_id
    if triggered_id not in {'legend-y1', 'legend-y2', 'legend-y3'}:
        return dash.no_update

    # 'legend-y1' -> 'y1': both the key in `data` and the lowercased trace name.
    series_name = triggered_id.split('-')[-1]
    visible = data[series_name] % 2 == 0

    for trace in fig['data']:
        # `or ''` guards against traces without a name.
        if (trace.get('name') or '').lower() == series_name:
            trace['visible'] = visible

    return fig


if __name__ == '__main__':
    # `run_server` is deprecated (removed in Dash 3); `run` takes the same arguments.
    app.run(debug=True)
