In [1]:
import numpy as np
import tensorflow.keras as keras
import plotly
import plotly.graph_objects as go
from ipywidgets import widgets

## Load the pre-trained TensorFlow model

In [2]:
# Horizon used to slice the model's single output row: columns [0, T) are
# plotted as consumption and columns [T, 2T] as savings.
# NOTE(review): presumably this must equal the horizon the model was trained on.
T = 50
# Initial parameter values: they seed the sliders and the first prediction,
# and are fed to the model in the order [sigma, beta].
# NOTE(review): presumably sigma is risk aversion and beta the discount factor — confirm.
sigma = 0.7
beta = 0.9

# Path of the saved Keras model, relative to the notebook's working directory.
MODEL_PATH = "11_26_03"
# Load without restoring the saved training configuration, then attach a plain
# MSE loss (only relevant for fit/evaluate; predict does not use the loss).
model = keras.models.load_model(MODEL_PATH, compile=False)
model.compile(loss="mse")

## Set up interactive plotting

In [3]:
# Slider for beta, starting at the configured initial value. With
# continuous_update=False the value is only committed when the handle is
# released, so observers fire once per drag instead of on every step.
beta_slider = widgets.FloatSlider(
    description='beta:',
    value=beta,
    min=0.7, max=0.99, step=0.01,
    continuous_update=False,
)

In [4]:
# Slider for sigma, starting at the configured initial value. With
# continuous_update=False the value is only committed when the handle is
# released, so observers fire once per drag instead of on every step.
sigma_slider = widgets.FloatSlider(
    description='sigma:',
    value=sigma,
    min=0.5, max=1.5, step=0.05,
    continuous_update=False,
)

In [9]:
# Seed the figure with the model's prediction at the initial (sigma, beta) pair.
# The first output row is split into T consumption values (x = 0..T-1)
# and T + 1 savings values (x = 0..T).
prediction = model.predict(np.array([[sigma, beta]]))
trace1 = go.Scatter(
    name="consumption",
    mode='markers',
    x=np.arange(T),
    y=prediction[0, :T],
)
trace2 = go.Scatter(
    name="savings",
    mode='markers',
    x=np.arange(T + 1),
    y=prediction[0, T:2 * T + 1],
)

In [10]:
# Re-run the model whenever either slider is released and push the new
# consumption/savings paths into the live figure `g`.

# If this cell was run before, detach the old handler first; otherwise each
# re-run stacks another observer and every slider change re-runs the model
# once per stale handler.
_previous_response = globals().get("response")
if _previous_response is not None:
    for _slider in (sigma_slider, beta_slider):
        try:
            _slider.unobserve(_previous_response, names="value")
        except ValueError:
            # Handler was not registered on this slider; nothing to remove.
            pass


def _predict_paths(sigma_value, beta_value):
    """Run the model on one (sigma, beta) pair and split its output row.

    Returns a ``(consumption, savings)`` tuple: the first ``T`` outputs and
    the following ``T + 1`` outputs, i.e. the same slices used to build the
    initial traces.
    """
    row = model.predict(np.array([[sigma_value, beta_value]]))[0]
    return row[:T], row[T:2 * T + 1]


def response(change):
    """Slider observer: recompute both paths and redraw the figure.

    ``change`` is the traitlets change notification. It is unused: both
    slider values are read directly, so either slider triggers a full update.

    Reads the global FigureWidget ``g`` from the plotting cell. The sliders
    are only displayed by that cell, so user interaction cannot happen
    before ``g`` exists.
    """
    consumption, savings = _predict_paths(sigma_slider.value, beta_slider.value)
    # Run inference before opening the batch, so the batch brackets only the
    # two trace assignments and they reach the front end as one update.
    with g.batch_update():
        g.data[0].y = consumption
        g.data[1].y = savings


sigma_slider.observe(response, names="value")
beta_slider.observe(response, names="value")

In [11]:
# Live figure. Trace order matters: `response` writes consumption into
# g.data[0] and savings into g.data[1].
# (barmode is dropped: it only affects bar traces and was a no-op here.)
g = go.FigureWidget(
    data=[trace1, trace2],
    layout=go.Layout(
        title=dict(text='Life-Cycle Consumption-Savings Example'),
        xaxis=dict(title=dict(text='period')),
        yaxis=dict(title=dict(text='consumption / savings')),
        showlegend=True,
    ),
)
# Sliders side by side above the figure; the VBox is the cell's displayed output.
container = widgets.HBox(children=[beta_slider, sigma_slider])
widgets.VBox([container, g])

VBox(children=(HBox(children=(FloatSlider(value=0.99, continuous_update=False, description='beta:', max=0.99, …