# Radial Basis Function Neural Network

In [1]:
# Load IPython's autoreload extension; mode 2 re-imports all modules before
# each cell executes, so edits to the project package are picked up live.
%load_ext autoreload
%autoreload 2
# Interactive (ipympl) matplotlib backend: figures render as widgets.
%matplotlib widget

In [2]:
from itertools import product

import numpy as np
import xarray as xr
import xarray.ufuncs as xf
import matplotlib.pyplot as plt

from system_identification.rbfnn_model import RadialBasisFunctionNeuralNetworkModel

## 1D

In [13]:
# Build the 1-D network over the input interval [-1, 1].
# NOTE(review): grid_size presumably sets the number of RBF centres per input
# dimension (9 here), placed on a regular grid — confirm against the model class.
model = RadialBasisFunctionNeuralNetworkModel.new_grid_placement(
    log_dir="./",
    n_inputs=1,
    input_range=np.array([[-1, 1]]),
    grid_size=[9],
    rbf_amplitude=1,
    rbf_width=1.75,
)
# Bare expression so the notebook displays the model's repr.
model

In [14]:
def foo(x):
    """Target function for the 1-D demo: the downward parabola ``-x**2``.

    Uses only arithmetic operators, so it works element-wise on NumPy arrays
    as well as on scalars.
    """
    return -x**2


# Dedicated, seeded generator so the noisy data set is identical on every
# Restart & Run All.
noise_rng = np.random.default_rng(0)

inputs = np.linspace(-1, 1, 1000)
# Evaluate the target on the whole array at once instead of mapping in Python.
reference_outputs = foo(inputs)
# Additive uniform noise in [-0.05, 0.05).
reference_outputs_noisy = reference_outputs + (noise_rng.random(reference_outputs.shape) - 0.5) * 0.1

# Reshape everything to (n_samples, 1, 1) before handing it to the model.
inputs = inputs.reshape(-1, 1, 1)
reference_outputs = reference_outputs.reshape(-1, 1, 1)
reference_outputs_noisy = reference_outputs_noisy.reshape(-1, 1, 1)

In [15]:
# Fit the 1-D model on the noisy samples (not the clean reference curve).
model.train(inputs, reference_outputs_noisy)

In [16]:

# Compare the noise-free target, the noisy training samples and the model's
# output over the same inputs.
fig, ax = plt.subplots()
outputs = model.evaluate(inputs)
ax.plot(inputs.squeeze(), reference_outputs.squeeze(), label="reference")
ax.plot(inputs.squeeze(), reference_outputs_noisy.squeeze(), ".", markersize=1, label="noisy samples")
ax.plot(inputs.squeeze(), outputs.squeeze(), label="model output")
ax.set_xlabel("input")
ax.set_ylabel("output")
ax.set_title("1D: reference vs. model output")
# Trailing semicolon suppresses the Legend repr in the cell output.
ax.legend();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7fa339b6b790>]

## 2D

In [7]:
# Build the 2-D network: first input spans [-1, 1], second input spans [-3, 5].
# NOTE(review): grid_size presumably gives the number of RBF centres along each
# input dimension (5 and 15) — confirm against the model class.
model = RadialBasisFunctionNeuralNetworkModel.new_grid_placement(
    log_dir="./",
    n_inputs=2,
    input_range=np.array([[-1, 1], [-3, 5]]),
    grid_size=[5, 15],
    rbf_amplitude=1,
    rbf_width=1.75,
)
# Bare expression so the notebook displays the model's repr.
model

In [8]:
def foo(x):
    """Target surface for the 2-D demo: ``-x[0]**2 + sin(2 * x[1])``.

    ``x`` is indexed along its first axis: ``x[0]`` is the first input and
    ``x[1]`` the second. Each may be a scalar or an array, in which case the
    function is applied element-wise.
    """
    return -x[0]**2 + np.sin(x[1]*2)


resolution = 100
# Dedicated, seeded generator so the noisy data set is identical on every
# Restart & Run All.
noise_rng = np.random.default_rng(0)

# Regular resolution x resolution grid over model.range, flattened to rows of
# (first input, second input) with the first input varying slowest — the same
# row order itertools.product would produce.
grid_axes = [np.linspace(*model.range[dim, :], resolution) for dim in range(2)]
inputs = np.stack(np.meshgrid(*grid_axes, indexing="ij"), axis=-1).reshape(-1, 2)

# Transpose so foo sees x[0] / x[1] as whole columns and runs vectorised.
reference_outputs = foo(inputs.T)
# Additive uniform noise in [-0.05, 0.05).
reference_outputs_noisy = reference_outputs + (noise_rng.random(reference_outputs.shape) - 0.5) * 0.1

# Add a trailing singleton axis to the inputs -> (n_samples, 2, 1); targets
# become (n_samples, 1, 1).
inputs = inputs[..., None]
reference_outputs = reference_outputs.reshape(-1, 1, 1)
reference_outputs_noisy = reference_outputs_noisy.reshape(-1, 1, 1)

In [9]:
# Fit on the noise-free targets, then evaluate the fitted model on the same grid.
# NOTE(review): the 1-D section trains on reference_outputs_noisy, whereas here
# the noisy targets are built but not used — confirm this is intentional.
model.train(inputs, reference_outputs)
output = model.evaluate(inputs)

In [10]:
# Plot the first two columns of model.weights_c against each other.
# NOTE(review): weights_c presumably holds the RBF centres with one column per
# input dimension — confirm against the model class.
fig, ax = plt.subplots()
ax.plot(model.weights_c[:, 0], model.weights_c[:, 1], "x")
ax.set_xlabel("weights_c[:, 0]")
ax.set_ylabel("weights_c[:, 1]")
# Trailing semicolon suppresses the Text repr in the cell output.
ax.set_title("model.weights_c");

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7fa33b5d3a90>]

In [11]:
# Show the reference surface as an image. After the reshape, rows follow the
# first input and columns the second (the grid was generated with the first
# input varying slowest), so the second input runs along the image's x-axis.
fig, ax = plt.subplots()
pos = ax.imshow(
    reference_outputs.reshape((resolution, resolution)),
    origin="lower",  # row 0 (smallest first-input value) at the bottom
    # Label the axes in input units rather than pixel indices; the sampled
    # range is used as the image extent, so ticks are approximate to within
    # half a pixel.
    extent=(*model.range[1, :], *model.range[0, :]),
    aspect="auto",  # the two input ranges differ in length
)
ax.set_xlabel("input 1")
ax.set_ylabel("input 0")
ax.set_title("2D reference output")
# Trailing semicolon suppresses the Colorbar repr in the cell output.
fig.colorbar(pos, ax=ax, label="reference output");

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.colorbar.Colorbar at 0x7fa33b641c40>

In [12]:
# 3-D point clouds of the reference surface and the model output over the
# same input grid.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(inputs[:, 0, 0], inputs[:, 1, 0], reference_outputs.squeeze(), s=0.1, label="reference")
ax.scatter(inputs[:, 0, 0], inputs[:, 1, 0], output.squeeze(), s=0.1, label="model output")
ax.set_xlabel("input 0")
ax.set_ylabel("input 1")
ax.set_zlabel("output")
ax.set_title("2D: reference vs. model output")
# The markers are tiny (s=0.1), so enlarge them in the legend; the trailing
# semicolon suppresses the Legend repr in the cell output.
ax.legend(markerscale=20);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7fa33b555490>