# Original code copied from the CNRG (Computational Neuroscience Research Group, University of Waterloo) GitHub repository — primarily the work of Aaron Voelker.
    
## Delay Network on BrainDrop
This notebook explores the implementation of a "delay network" (DN), i.e., a dynamical system optimized to represent a rolling window of input history, on the physical BrainDrop neuromorphic chip. For technical details, see:

http://compneuro.uwaterloo.ca/publications/voelker2018.html
https://arvoelke.github.io/nengolib-docs/nengolib.networks.RollingWindow.html
https://github.com/arvoelke/cosyne2018/raw/master/abstract.pdf

In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline


In [2]:
import shutil

import seaborn as sns  # since seaborn 0.8, importing it no longer changes the matplotlib style

import matplotlib.pyplot as plt
# Optional style sheet; the file is not available locally, so it stays disabled.
# plt.style.use('~/Downloads/ieee_tran.mplstyle')

SMALL_SIZE = 8
MEDIUM_SIZE = 10  # not used by the rc settings below
BIGGER_SIZE = 12

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=SMALL_SIZE)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

# Render text with LaTeX only when a `latex` executable is on the PATH: with
# text.usetex enabled and no LaTeX installation, matplotlib raises an error as
# soon as a figure's text is drawn.
plt.rc('text', usetex=shutil.which('latex') is not None)

In [3]:
import numpy as np

import nengo
import nengo_brainstorm as brd
from nengolib import Lowpass
from nengolib.networks import readout
from nengolib.signal import Balanced
from nengolib.synapses import PadeDelay, pade_delay_error, ss2sim



ModuleNotFoundError: No module named 'pystorm.PyDriver'

## Numerical Simulation
First, let's look at the delay network's LTI system on its own and simulate it directly in discrete time using the state-space equations:

$${\bf x}[t+1] = \bar{A} {\bf x}[t] + \bar{B} u[t] \text{.}$$
For debugging purposes, we compute the Padé approximation error for the given cut-off input frequency.

We also compute the optimal linear transformation that maps from the state-space back to the rolling window of input history. This gives us a way to efficiently compute functions across time within the lower-dimensional (temporally-compressed) space. Note the transformation accounts for the fact that the state-space has been "balanced".

In [None]:

theta = .1   # window length, i.e. the longest delay represented (s)
order = 15   # order of the Padé approximant (= state dimension)
freq = 30    # cut-off frequency of the white-noise input (Hz)
power = 1.5  # chosen to keep radii normalized to [-1, 1]

print("PadeDelay(%s, %s) => %f%% error @ %sHz" % (
    theta, order, 100*abs(pade_delay_error(theta*freq, order=order)), freq))
pd = PadeDelay(theta=theta, order=order)

# Heuristic for normalizing state so that each dimension is ~[-1, +1]
rz = Balanced()(pd, radii=1./(np.arange(len(pd))+1))
sys = rz.realization

# Compute matrix to transform from state (x) -> sampled window (u).
# Window positions are relative delays r in [0, 1]; r maps to a delay of r*theta.
t_samples = 1
# np.linspace(0, 1, 1) is [0.], which would read out delay 0 (the undelayed
# input). With a single sample, read out the full delay theta (r = 1) instead.
r_samples = np.linspace(0, 1, t_samples) if t_samples > 1 else np.ones(1)
C = np.asarray([readout(len(pd), r) for r in r_samples]).dot(rz.T)
assert C.shape == (t_samples, len(sys))
# Transposed to (len(sys), t_samples) so that np.matmul(x, C) decodes the
# window from a state trajectory with one column per state dimension.
C = C.transpose()

In [None]:
length = 2000  # number of simulation steps
dt = 0.001     # time step (s)
seed = 1       # chosen just to get a pretty sample

# Band-limited white-noise input: `freq` is the cut-off, `power` the RMS level.
process = nengo.processes.WhiteSignal(
    period=length * dt,
    high=freq,
    rms=power,
    y0=0,
    seed=seed,
)
u = process.run_steps(length, dt=dt)
t = process.ntrange(length, dt=dt)

# Ideal state trajectory of the delay network (discretizes sys using ZOH).
x = sys.X.filt(u, dt=dt)


In [None]:
# Decode the window from the simulated state. C was transposed in the setup
# cell to (len(sys), t_samples), so y has one column per window sample.
y = np.matmul(x, C)

# Show the first 0.5 s of simulated time.
n_show = int(round(0.5 / dt))
fig, ax = plt.subplots()
ax.plot(t[:n_show], u[:n_show])
ax.plot(t[:n_show], y[:n_show])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Signal Value')
ax.legend(['Input', 'Output'])
plt.show()



## nengo_brainstorm (NEF on BrainDrop) Implementation
Now we apply Principle 3 of the Neural Engineering Framework (NEF) — the dynamics principle — to map $\dot{{\bf x}}(t) = A{\bf x}(t) + Bu(t)$ onto a model of the synapse used by BrainDrop. The input signal $u(t)$ is drawn from the same band-limited white-noise process used above.

Each state dimension is represented by a separate neural ensemble. We then probe the decoded value of each dimension and linearly map the resulting state onto the rolling window.

In [None]:
n_neurons = 128  # per dimension
tau = 0.018329807108324356  # guess from Terry's notebook

# Principle 3: rewrite the system so it can be implemented through a lowpass
# synapse of time constant tau. The asserts check the expected mapping
# A' = tau*A + I and B' = tau*B.
map_hw = ss2sim(sys, synapse=Lowpass(tau), dt=None)
assert np.allclose(map_hw.A, tau*sys.A + np.eye(len(sys)))
assert np.allclose(map_hw.B, tau*sys.B)

from nengo_brainstorm import solvers

with nengo.Network() as model:
    brd.add_params(model)

    # Named u_node (not u) so that this cell does not overwrite the NumPy
    # input signal `u` from the numerical-simulation cell, which later cells
    # still use as an array.
    u_node = nengo.Node(output=process, label='u')
    p_u = nengo.Probe(u_node, synapse=None)

    # This is needed because a single node can't connect to multiple
    # different ensembles. We need a separate node for each ensemble.
    # Node i scales the input by row i of the mapped input matrix B'.
    Bu = [nengo.Node(output=lambda _, u_in, b_i=map_hw.B[i].squeeze(): b_i*u_in,
                     size_in=1, label='Bu[%d]' % i)
          for i in range(len(sys))]

    # One 1-D ensemble per state dimension, each with a fallback decoder solver.
    X = []
    for i in range(len(sys)):
        ens = nengo.Ensemble(
            n_neurons=n_neurons, dimensions=1, label='X[%d]' % i)
        solver = solvers.FallbackSolver([nengo.solvers.LstsqL2(reg=0.01),
                                         solvers.CVXSolver(reg=0.01)])
        model.config[ens].solver = solver
        X.append(ens)

    # Wire x_i <- B'_i u + sum_j A'_ij x_j, with every term filtered by tau,
    # and probe each dimension unfiltered.
    P = []
    for i in range(len(sys)):
        nengo.Connection(u_node, Bu[i], synapse=None)
        nengo.Connection(Bu[i], X[i], synapse=tau)
        for j in range(len(sys)):
            # Default argument binds this A'_ij now rather than at call time.
            nengo.Connection(X[j], X[i], synapse=tau,
                             function=lambda x_j, a_ij=map_hw.A[i, j]: a_ij*x_j)
        P.append(nengo.Probe(X[i], synapse=None))

In [None]:
# Run the network on BrainDrop for the same duration as the numerical
# simulation (length steps of dt).
with brd.Simulator(
        model,
        dt=dt,
        compute_stats=False,
        precompute_inputs=True,
        precompute_offset=1.0,
        generate_offset=1.0,
) as sim:
    sim.run(length * dt)

In [None]:
# NOTE(review): this cell repeats the previous cell verbatim and re-runs the
# whole hardware simulation; afterwards `sim` (and everything later cells read
# from it) comes from this second run. Unless the first run is intentionally a
# warm-up, one of the two cells can be removed.
with brd.Simulator(model, dt=dt,
                   precompute_inputs=True,
                   compute_stats=False,
                   generate_offset=1.0,
                   precompute_offset=1.0) as sim:
    sim.run(length*dt)

In [None]:
from matplotlib.collections import LineCollection
from matplotlib.legend_handler import HandlerLineCollection
from matplotlib.lines import Line2D


class HandlerDashedLines(HandlerLineCollection):
    """Legend handler drawing one short line per segment of a LineCollection.

    Adapted from http://matplotlib.org/examples/pylab_examples/legend_demo5.html

    The legend handle is split horizontally into ``len(segments)`` equal
    pieces; piece ``i`` takes the colour, dash pattern and linewidth of entry
    ``i`` of the collection, falling back to entry 0 when the collection has
    fewer property entries than segments.
    """  # noqa: E501

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        # figure out how many lines there are
        numlines = len(orig_handle.get_segments())
        xdata, xdata_marker = self.get_xdata(
            legend, xdescent, ydescent, width / numlines, height, fontsize)
        leglines = []
        for i in range(numlines):
            # Shift piece i right by i slots; y sits at the vertical centre.
            legline = Line2D(
                xdata + i * width / numlines,
                np.zeros_like(xdata) - ydescent + height / 2)
            self.update_prop(legline, orig_handle, legend)
            # set color, dash pattern, and linewidth to that
            # of the lines in linecollection
            try:
                color = orig_handle.get_colors()[i]
            except IndexError:
                color = orig_handle.get_colors()[0]
            try:
                dashes = orig_handle.get_dashes()[i]
            except IndexError:
                dashes = orig_handle.get_dashes()[0]
            try:
                lw = orig_handle.get_linewidths()[i]
            except IndexError:
                lw = orig_handle.get_linewidths()[0]
            # get_dashes() yields (offset, dash_sequence) pairs. For a solid
            # line the sequence is None while the offset may be None or 0
            # depending on the matplotlib version, so test the sequence.
            if dashes[1] is not None:
                legline.set_dashes(dashes[1])
            legline.set_color(color)
            legline.set_transform(trans)
            legline.set_linewidth(lw)
            leglines.append(legline)
        return leglines


def plot(name, t, u, x_hat, C=C, x_ideal=None, delay=None):
    """Plot the decoded rolling window and the state vector, then save the figure.

    Top panel: the window decoded from ``x_hat`` (one line per window sample)
    against the input ``u``. An inset shows the NRMSE of each window sample
    relative to the window decoded from the ideal state. Bottom panel: every
    state dimension of ``x_hat`` (solid) against the ideal state (dashed).
    The figure is saved as ``<name>.pdf`` and ``<name>.png``.

    Parameters
    ----------
    name : str
        File stem for the saved figures.
    t : array
        Time points, one per row of ``x_hat``.
    u : array
        Input signal, drawn as the dashed target in the top panel.
    x_hat : array, shape (n_steps, n_dims)
        Decoded state, one column per state dimension.
    C : array
        Readout matrix from state to window samples, either
        (n_samples, n_dims) or its transpose (n_dims, n_samples).
    x_ideal : array, shape (n_steps, n_dims), optional
        Ideal state; defaults to the global ``x`` from the numerical simulation.
    delay : float, optional
        Window length (s) for the inset's x-axis; defaults to the global ``theta``.
    """
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
    from nengolib.signal import nrmse

    if x_ideal is None:
        x_ideal = x
    if delay is None:
        delay = theta
    x_hat = np.asarray(x_hat)
    x_ideal = np.asarray(x_ideal)
    n_dims = x_hat.shape[1]

    print("Radii:", np.max(np.abs(x_hat), axis=0))

    # Work with one row per window sample. The setup cell stores C transposed
    # for np.matmul(x, C), so accept either orientation here.
    C = np.atleast_2d(np.asarray(C))
    if C.shape[1] != n_dims and C.shape[0] == n_dims:
        C = C.T
    assert C.shape[1] == n_dims, C.shape
    assert x_ideal.shape == x_hat.shape, (x_ideal.shape, x_hat.shape)
    n_samples = C.shape[0]

    w = C.dot(x_hat.T)          # (n_samples, n_steps) window from x_hat
    w_ideal = C.dot(x_ideal.T)  # same, from the ideal state

    top_cmap = sns.color_palette('GnBu_d', n_samples)[::-1]
    fig, ax = plt.subplots(2, 1, sharex=True, figsize=(3.5, 3.5))
    # Draw in reverse so the first window sample ends up on top.
    for c, w_i in list(zip(top_cmap, w))[::-1]:
        ax[0].plot(t, w_i, c=c, alpha=0.7)
    target_line, = ax[0].plot(t, u, c='green', linestyle='--', lw=1)
    # Extra headroom above the data (presumably to leave room for the inset).
    ax[0].set_ylim(np.min(w), np.max(w) + 1)

    # Inset: NRMSE of each window sample as a function of its delay.
    inset = inset_axes(ax[0], width="25%", height=0.3, loc='upper right')
    inset.patch.set_alpha(0.8)
    inset.xaxis.tick_top()
    inset.tick_params(axis='x', labelsize=4)
    inset.tick_params(axis='y', labelsize=4)
    inset.xaxis.set_label_position('top')
    t_window = np.linspace(0, delay, n_samples)
    e_window = nrmse(w, target=w_ideal, axis=1)
    for i in range(1, n_samples):
        inset.plot([t_window[i-1], t_window[i]],
                   [e_window[i-1], e_window[i]],
                   c=top_cmap[i])
    inset.set_xlabel("Delay Length (s)", size=4)
    inset.set_ylabel("NRMSE", size=4)

    bot_cmap = sns.color_palette('bright', n_dims)
    for i in range(n_dims):
        ax[1].plot(t, x_hat[:, i], c=bot_cmap[i], alpha=0.9)
        ax[1].plot(t, x_ideal[:, i], c=bot_cmap[i], linestyle='--', lw=1)

    ax[0].set_title("Delay Network")
    ax[1].set_title("State Vector")
    ax[-1].set_xlabel("Time (s)")

    # Proxy artists so each legend entry shows the full colour range.
    top_lc = LineCollection(
        n_samples * [[(0, 0)]], lw=1, colors=top_cmap)
    ax[0].legend([target_line, top_lc], ["Input", "Output"],
                 handlelength=3.2, loc='lower right',
                 handler_map={LineCollection: HandlerDashedLines()})

    bot_lc_ideal = LineCollection(
        n_dims * [[(0, 0)]], lw=1, colors=bot_cmap, linestyle='--')
    bot_lc_actual = LineCollection(
        n_dims * [[(0, 0)]], lw=1, colors=bot_cmap)
    ax[1].legend([bot_lc_ideal, bot_lc_actual], ["Ideal", "Actual"],
                 handlelength=3.2, loc='lower right',
                 handler_map={LineCollection: HandlerDashedLines()})

    for fmt in ('pdf', 'png'):
        fig.savefig('%s.%s' % (name, fmt), dpi=600, bbox_inches='tight')

    plt.show()
    
# Stack the per-dimension probes into the decoded state (one column per
# dimension) and plot it against the probed input on the simulator's time axis.
x_hat = np.hstack([sim.data[p] for p in P])
plot("delay_network", sim.trange(), sim.data[p_u], x_hat)

In [None]:
# Ensure no overflows were reported (via sim.hal) during the hardware run;
# the failure message shows the reported counts.
overflows = sim.hal.get_overflow_counts()
assert overflows == 0, overflows

In [None]:
import time

# Read the input and decoded state directly from the probes, so this cell does
# not depend on the global `u` (which the model-building cell may rebind to a
# nengo.Node) or on another cell having defined `x_hat`.
u_hw = sim.data[p_u]
x_hat = np.hstack([sim.data[p] for p in P])

# Store C with one row per window sample, shape (t_samples, len(sys)). That is
# the layout the reload cell's C.dot(x_hat.T) expects; the setup cell keeps it
# transposed for np.matmul(x, C).
C_rows = C.T if C.shape == (len(sys), t_samples) else C
# The file name is stamped with the current time.
np.savez("delay_network_%s" % time.time(),
         t=sim.trange(), u=u_hw, x_hat=x_hat, x=x, C=C_rows)

In [None]:
# Reload a previously saved run (the file name is that run's time stamp).
# Using the NpzFile as a context manager closes the archive's file handle;
# indexing it reads each array into memory, so the values remain valid.
with np.load("delay_network_1538398994.3587925.npz") as d:
    t = d['t']
    u = d['u']
    x_hat = d['x_hat']
    x = d['x']
    C = d['C']

In [None]:
from nengolib.signal import nrmse

# Decode the window from the hardware state and from the ideal state, then
# score the hardware against the ideal over every sample and window position.
w = C.dot(x_hat.T)
w_ideal = C.dot(x.T)
e = nrmse(np.ravel(w), target=np.ravel(w_ideal))
print(e)