In [None]:
import os; os.environ['KERAS_BACKEND'] = 'jax'
import dataset
import keras
import pandas as pd
import model
import plotly.express as px
import plotly.graph_objs as go
import plotly.subplots

In [None]:
# HACK, but this is a demo anyways: later cells use paths relative to the
# parent of the notebook's directory (e.g. 'ann.keras').
# A bare os.chdir('..') climbs one more level every time the cell is re-run,
# so remember the directory the kernel started in and always move to *its*
# parent. Re-running the cell is then a no-op.
if 'NOTEBOOK_START_DIR' not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

In [None]:
# Fetch the testing data, then split it into the three arrays used below:
#   ts - used as the index of the prediction table,
#   xs - fed to the network,
#   ys - observations; column 0 is what gets plotted against the predictions.
testing_data = dataset.fetch_testing_data()
ts, xs, ys = dataset.mk_txy(testing_data)

In [None]:
# Restore the saved network (path is relative to the current working
# directory) and print its layer table.
saved_model_path = 'ann.keras'
nn: keras.Model = keras.models.load_model(saved_model_path)
nn.summary()

In [None]:
# Render the architecture diagram top-to-bottom, with nested sub-models
# expanded and every box annotated with name, shapes and activation.
keras.utils.plot_model(
    nn,
    rankdir='TB',
    expand_nested=True,
    show_layer_names=True,
    show_shapes=True,
    show_layer_activations=True,
)

In [None]:
# Run the network on the test inputs; its two output columns are labelled
# 'mu' and 'sigma' and indexed by ts.
outputs = pd.DataFrame(nn.predict(xs), columns=['mu', 'sigma'], index=ts)

# Edges of the band mu +/- sigma, used for the shaded region in the plot below.
outputs['upper_bound'] = outputs['mu'] + outputs['sigma']
outputs['lower_bound'] = outputs['mu'] - outputs['sigma']

In [None]:
# Two stacked panels on a shared x axis. row_heights are relative, so the
# top panel is short (1 part) and the bottom panel tall (5 parts).
fig = plotly.subplots.make_subplots(rows=2, row_heights=[1, 5], shared_xaxes=True)

# Print the subplot grid layout as text for reference.
fig.print_grid()

In [None]:
# This cell appends traces to the `fig` created in the previous cell. Drop
# whatever traces a previous run of this cell left behind so that re-running
# it does not stack duplicate lines on the figure.
fig.data = []

# Number of steps by which the observed target is shifted against the
# prediction index in the bottom panel.
# NOTE(review): the predictions-vs-observations scatter further down pairs
# ys with mu *without* this shift - confirm which alignment is intended.
target_offset = 12

# Colour shared by the sigma and mu lines.
prediction_line = dict(color='rgb(0,100,80)')

# Invisible (zero-width) lines: the two bounds only exist to delimit the
# shaded band and are kept out of the legend.
band_edge_style = dict(
    mode='lines',
    marker=dict(color="#444"),
    line=dict(width=0),
    showlegend=False,
)

fig.add_trace(
    # Top panel: sigma on its own. Named so the legend does not fall back
    # to the generic "trace 0".
    go.Scatter(
        name='sigma',
        x=outputs.index,
        y=outputs['sigma'],
        line=prediction_line,
        mode='lines',
    ),
    row=1,
    col=1,
).add_traces(
    [
        # Bottom panel: mu, named for the legend.
        go.Scatter(
            name='mu',
            x=outputs.index,
            y=outputs['mu'],
            line=prediction_line,
            mode='lines',
        ),
        go.Scatter(
            name='Upper Bound',
            x=outputs.index,
            y=outputs['upper_bound'],
            **band_edge_style,
        ),
        # Must directly follow the upper bound: 'tonexty' fills the area
        # between this trace and the previous one, producing the band.
        go.Scatter(
            name='Lower Bound',
            x=outputs.index,
            y=outputs['lower_bound'],
            fillcolor='rgba(68, 68, 68, 0.3)',
            fill='tonexty',
            **band_edge_style,
        ),
        go.Scatter(
            name='Target',
            x=outputs.index[:-target_offset],
            y=ys[target_offset:, 0],
            marker=dict(color="red"),
            mode='lines',
        ),
    ],
    rows=[2, 2, 2, 2],
    cols=[1, 1, 1, 1],
).show()

In [None]:
# Pair each observation (column 0 of ys) with the predicted mu. The frame
# takes its index from the `outputs` Series.
# NOTE(review): the time-series figure above plots ys shifted by 12 steps
# against the predictions, while this pairing is unshifted - confirm which
# alignment is intended.
df = pd.DataFrame(dict(y_true=ys[:, 0], y_pred=outputs['mu']))

px.scatter(
    df,
    x='y_true',
    y='y_pred',
    title='Predictions vs Observations',
    width=600,
    height=500,
)

In [None]:
def _embedding_trace(head_name, sampling_layer_name, color, coloraxis):
    """Build one 3-D scatter trace of a head's sampling-layer output.

    The test inputs ``xs`` are passed through the ``'sequential'`` sub-model
    of ``nn`` and the result is passed directly to the layer
    ``sampling_layer_name`` inside the sub-model ``head_name``. The output
    is laid out as columns x/y/z, so it is expected to have three columns.

    Parameters
    ----------
    head_name : str
        Name of the sub-model of ``nn`` that contains the sampling layer.
    sampling_layer_name : str
        Name of the layer inside ``head_name`` whose output is plotted.
    color : pandas.Series
        One value per point, mapped onto the colour scale.
    coloraxis : str
        Layout colour axis the trace is bound to, e.g. ``'coloraxis'`` or
        ``'coloraxis2'``.

    Returns
    -------
    plotly.graph_objs.Scatter3d
        The trace, ready to be added to a subplot figure.
    """
    encoded = nn.get_layer('sequential')(xs)
    embedding = nn.get_layer(head_name).get_layer(sampling_layer_name)(encoded)
    points = pd.DataFrame(embedding, columns=['x', 'y', 'z'])
    trace = px.scatter_3d(points, x='x', y='y', z='z', color=color).data[0]
    # Plotly Express always binds its traces to the first colour axis. Left
    # as-is, both panels would share one colour scale (and one colorbar) even
    # though they are coloured by different quantities.
    trace.marker.coloraxis = coloraxis
    return trace


fig = plotly.subplots.make_subplots(
    cols=2,
    specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
    column_titles=['mu embeddings', 'sigma embeddings'],
).add_trace(
    _embedding_trace('sequential_1', 'variational_sampling_layer', outputs['mu'], 'coloraxis'),
    row=1,
    col=1,
).add_trace(
    _embedding_trace('sequential_2', 'variational_sampling_layer_1', outputs['sigma'], 'coloraxis2'),
    row=1,
    col=2,
).update_layout(
    height=500,
    width=900,
    # One colorbar per panel: the first sits in the gap between the two
    # scenes, the second keeps the default position at the right edge.
    coloraxis=dict(colorbar=dict(title=dict(text='mu'), x=0.45)),
    coloraxis2=dict(colorbar=dict(title=dict(text='sigma'))),
)
fig