# Feed-Forward Neural Network

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

In [3]:
from itertools import product
from pathlib import Path
from tqdm.auto import trange, tqdm
import pickle

import numpy as np
import xarray as xr
import xarray.ufuncs as xf
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

from system_identification.ffnn import FeedForwardNeuralNetwork, TrainingParameters
from system_identification.load_assignment_data import load_net_example_ff

## Define network

In [4]:
data = xr.open_dataset("data/data_smoothed.nc")
# data = data.isel(t=slice(None, None, 50))
display(data)

inputs = np.hstack((
    data.alpha_estimate.values[:, None],
    data.beta_m.values[:, None]
))
reference_outputs = data.c_m.values[:, None]

In [5]:
nn.saves_idxs(nn.log_dir)[-1]

NameError: name 'nn' is not defined

In [15]:
n_hidden_list = [2**i for i in [11]]
nns = []

for n_hidden in tqdm(n_hidden_list):
    nn = FeedForwardNeuralNetwork.load(f"./ffnn_{n_hidden}")
    if nn is None:
        print("Creating new FeedForwardNeuralNetwork")
        nn = FeedForwardNeuralNetwork.new(
            n_inputs=2,
            n_outputs=1,
            n_hidden=n_hidden,
            range=[[-1, 1], [-1, 1]],
            log_dir=f"./ffnn_{n_hidden}",
            training_parameters=TrainingParameters(
                epochs=1000,
                goal=0,
                min_grad=1e-10,
                mu=0.1,
            ),
        )
    nns.append(nn)
    nn.back_propagation(inputs, reference_outputs, epochs=int(40000 / 50 * 12))

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/9600 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [20]:
nn.bias_weights[1].shape

(2048, 1)

In [None]:
with open("training_logs_2.pickle", "wb") as f:
    pickle.dump(nns, f)

## Plotting

In [13]:
# 8.60445130660727

plt.figure()
for nn in nns:
    print(min(nn.training_log.error))
    nn.training_log.error.plot()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

1.3576746029642734


In [66]:
coords = np.linspace(-5, 5, 100)
nn_eval = np.empty((100,))
fn_eval = np.empty((100,))
plt.figure()

for y in np.linspace(-1, 1, 4):
    for xi, x in enumerate(coords):
        nn_eval[xi] = nn.evaluate([x, y])

    plt.plot(coords, fn_eval, "b")
    plt.plot(coords, nn_eval, "r")


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
alpha_grid = np.linspace(-0.2, 0.8, 1000)
beta_grid = np.linspace(-0.3, 0.3, 1000)
clim = min(reference_outputs), max(reference_outputs)
nn = nns[0]


nn_eval = np.empty(reference_outputs.shape)
for ii, input in enumerate(inputs):
    nn_eval[ii] = nn.evaluate(input)

fig = plt.figure()
ax = plt.gca()
zi = griddata((inputs[:, 0], inputs[:, 1]), nn_eval, (alpha_grid[None, :], beta_grid[:, None]), method='linear')
pos = plt.imshow(zi, cmap='viridis')
fig.colorbar(pos)
plt.clim(*clim)

fig = plt.figure()
ax = plt.gca()
zi = griddata((inputs[:, 0], inputs[:, 1]), reference_outputs, (alpha_grid[None, :], beta_grid[:, None]), method='linear')
pos = plt.imshow(zi, cmap='viridis')
fig.colorbar(pos)
plt.clim(*clim)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(inputs[:, 0], inputs[:, 1], reference_outputs)

# ax = fig.add_subplot(111, projection='3d')
ax.scatter(inputs[:, 0], inputs[:, 1], nn_eval)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7fb8a75e7df0>

## Evaluate error

In [20]:
errors = []
for idx in tqdm(FeedForwardNeuralNetwork.saves_idxs(nn.log_dir)):
    ffnn = FeedForwardNeuralNetwork.load(nn.log_dir, idx)
    errors.append(ffnn.evaluate_error(inputs, reference_outputs))

plt.figure()
plt.plot(errors)

  0%|          | 0/97 [00:00<?, ?it/s]

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7f406654da60>]