In [1]:
import jax.numpy as jnp

In [13]:
# PLOTS
import plotly.graph_objects as go


def update_layout_of_graph(fig: go.Figure, title: str = 'Plot') -> go.Figure:
    fig.update_layout(
        width=800,
        height=600,
        autosize=False,
        plot_bgcolor='rgba(0,0,0,0)',
        title=title,

    )
    fig.update_layout(plot_bgcolor='rgba(0,0,0,0)',
                      xaxis_title='input values',
                      yaxis_title='output values',
                      legend=dict(yanchor="top",
                                  y=0.9,
                                  xanchor="right",
                                  x=0.95),
                      title={
                          'x': 0.5,
                          'xanchor': 'center'
                      })
    fig.update_xaxes(showline=True, linewidth=1, linecolor='black')
    fig.update_yaxes(showline=True, linewidth=1, linecolor='black')
    return fig


def line_scatter(
        visible: bool = True,
        x_lines: jnp.array = jnp.array([]),
        y_lines: jnp.array = jnp.array([]),
        name_line: str = 'Predicted function',
        showlegend: bool = True,
) -> go.Scatter:
    # Adding the lines
    return go.Scatter(
        visible=visible,
        line=dict(color="blue", width=2),
        x=x_lines,
        y=y_lines,
        name=name_line,
        showlegend=showlegend
    )


def dot_scatter(
        visible: bool = True,
        x_dots: jnp.array = jnp.array([]),
        y_dots: jnp.array = jnp.array([]),
        name_dots: str = 'Observed points',
        showlegend: bool = True
) -> go.Scatter:
    # Adding the dots
    return go.Scatter(
        x=x_dots,
        visible=visible,
        y=y_dots,
        mode="markers",
        name=name_dots,
        marker=dict(color='red', size=8),
        showlegend=showlegend
    )


def uncertainty_area_scatter(
        visible: bool = True,
        x_lines: jnp.array = jnp.array([]),
        y_upper: jnp.array = jnp.array([]),
        y_lower: jnp.array = jnp.array([]),
        name: str = "mean plus/minus standard deviation",
) -> go.Scatter:
    return go.Scatter(
        visible=visible,
        x=jnp.concatenate((x_lines, x_lines[::-1])),  # x, then x reversed
        # upper, then lower reversed
        y=jnp.concatenate((y_upper, y_lower[::-1])),
        fill='toself',
        fillcolor='rgba(189,195,199,0.5)',
        line=dict(color='rgba(200,200,200,0)'),
        hoverinfo="skip",
        showlegend=True,
        name=name,
    )


def add_slider_GPR(figure: go.Figure, parameters):
    figure.data[0].visible = True
    figure.data[1].visible = True

    # Create and add slider
    steps = []
    for i in range(int((len(figure.data) - 1) / 2)):
        step = dict(
            method="update",
            label=f'{parameters[i]: .2f}',
            args=[{
                "visible": [False] * (len(figure.data) - 1) + [True]
            }],
        )
        step["args"][0]["visible"][2 *
                                   i] = True  # Toggle i'th trace to "visible"
        step["args"][0]["visible"][2 * i + 1] = True
        steps.append(step)

    sliders = [dict(
        active=0,
        pad={"t": 50},
        steps=steps,
    )]
    figure.update_layout(sliders=sliders, )
    return figure


def add_slider_to_function(figure: go.Figure, parameters):
    figure.data[0].visible = True

    # Create and add slider
    steps = []
    for i in range(len(figure.data)):
        step = dict(
            method="update",
            label=f'{parameters[i]: .2f}',
            args=[{
                "visible": [False] * len(figure.data)
            }],
        )
        step["args"][0]["visible"][i] = True  # Toggle i'th trace to "visible"
        steps.append(step)

    sliders = [dict(
        active=0,
        pad={"t": 50},
        steps=steps,
    )]
    figure.update_layout(sliders=sliders, )
    return figure


def plot_GPR(data_x, data_y, model, x, visible=True) -> list:
    mean = model.predict(x)

    std = jnp.sqrt(model._memory['variance'])
    data = []

    for i in range(1, 4):
        data.append(
            uncertainty_area_scatter(
                x_lines=x,
                y_lower=mean - i * std,
                y_upper=mean + i * std,
                name=f"mean plus/minus {i}*standard deviation",
                visible=visible))

    data.append(line_scatter(x_lines=x, y_lines=mean, visible=visible))
    data.append(dot_scatter(x_dots=data_x, y_dots=data_y, visible=visible))
    return data

In [3]:
class SquaredExponentialKernel:
    def __init__(self, sigma_f: float = 1, length: float = 1):
        self.sigma_f = sigma_f
        self.length = length

    def __call__(self, argument_1, argument_2) -> float:
        return float(self.sigma_f * jnp.exp(-(jnp.linalg.norm(argument_1 - argument_2) ** 2) / (2 * self.length ** 2)))


# Helper function to calculate the respective covariance matrices
def cov_matrix(x1, x2, cov_function):
    return jnp.array([[cov_function(a, b) for a in x1] for b in x2])


class GPR:
    def __init__(self, data_x, data_y, covariance_function=SquaredExponentialKernel(), white_noise_sigma: float = 0):
        self.noise = white_noise_sigma
        self.data_x = data_x
        self.data_y = data_y
        self.covariance_function = covariance_function

        # Store the inverse of covariance matrix of input (+ machine epsilon on diagonal) since it is needed for every prediction
        self._inverse_of_covariance_matrix_of_input = jnp.linalg.inv(
            cov_matrix(data_x, data_x, covariance_function) +
            (3e-7 + self.noise) * jnp.identity(len(self.data_x)))

        self._memory = None

    # function to predict output at new input values. Store the mean and covariance matrix in memory.
    def predict(self, at_values):
        k_lower_left = cov_matrix(self.data_x, at_values, self.covariance_function)
        k_lower_right = cov_matrix(at_values, at_values, self.covariance_function)

        # Mean.
        mean_at_values = jnp.dot(k_lower_left,
                                 jnp.dot(self.data_y, self._inverse_of_covariance_matrix_of_input.T).T).flatten()

        # Covariance.
        cov_at_values = k_lower_right - jnp.dot(k_lower_left,
                                                jnp.dot(self._inverse_of_covariance_matrix_of_input, k_lower_left.T))

        # Adding value larger than machine epsilon to ensure positive semi definite
        cov_at_values = cov_at_values + 3e-7 * jnp.ones(jnp.shape(cov_at_values)[0])

        var_at_values = jnp.diag(cov_at_values)

        self._memory = {
            'mean': mean_at_values,
            'covariance_matrix': cov_at_values,
            'variance': var_at_values
        }
        return mean_at_values

In [4]:
x_values = jnp.array([0, 0.3, 1, 3.1, 4.7])

y_values = jnp.array([1, 0, 1.4, 0, -0.9])

x = jnp.arange(-1, 7, 0.1)

In [5]:
model = GPR(x_values, y_values, SquaredExponentialKernel())
model.predict(x)

DeviceArray([ 4.74131870e+00,  4.68346167e+00,  4.52992773e+00,
              4.28164673e+00,  3.94466066e+00,  3.53031325e+00,
              3.05494356e+00,  2.53927827e+00,  2.00740910e+00,
              1.48544598e+00,  9.99980271e-01,  5.76358259e-01,
              2.37036407e-01,  2.08460606e-05, -1.22379646e-01,
             -1.24415830e-01, -7.33258575e-03,  2.20683411e-01,
              5.45146406e-01,  9.46194351e-01,  1.39999259e+00,
              1.88031638e+00,  2.36026430e+00,  2.81392074e+00,
              3.21788573e+00,  3.55252504e+00,  3.80291080e+00,
              3.95935941e+00,  4.01756907e+00,  3.97844481e+00,
              3.84754109e+00,  3.63432932e+00,  3.35124612e+00,
              3.01271892e+00,  2.63413572e+00,  2.23095226e+00,
              1.81786871e+00,  1.40819466e+00,  1.01338220e+00,
              6.42745674e-01,  3.03342104e-01, -2.42002034e-06,
             -2.64539897e-01, -4.89366204e-01, -6.75126791e-01,
             -8.23698878e-01, -9.3786734

In [12]:
x_lines = jnp.arange(-10, 10, 0.1)
kernel = SquaredExponentialKernel(length=5, sigma_f=5)

fig0 = go.FigureWidget(data=[
    line_scatter(
        x_lines=x_lines,
        y_lines=jnp.array([kernel(x, 0) for x in x_lines]),
    )
])

fig0 = update_layout_of_graph(fig0, title='Squared exponential kernel')
fig0

AttributeError: type object 'DOMWidget' has no attribute '_ipython_display_'

In [22]:
model = GPR(x_values, y_values, SquaredExponentialKernel(sigma_f=0.5, length=.5))
data = plot_GPR(data_x=x_values, data_y=y_values, x=x, model=model)
fig4 = go.Figure(data=data)
fig4 = update_layout_of_graph(fig=fig4, title='GPR with length 1, sigma 0 and noise 0')

fig4.show()