### Low-Dimensional Dynamics in SSP Space

This notebook demonstrates learning a 2D harmonic oscillator as a linear-nonlinear transformation in SSP space. Each state dimension is represented by its own SSP, and the nonlinear operation binds these two SSPs together. The linear transformation is applied to the concatenation of three vectors: the two SSPs and their binding. After each step, the two SSPs are rescaled to unit length, and the nonlinearity (their binding) is recomputed from them.

In [None]:
%matplotlib inline

In [None]:
import nengo
import numpy as np
import nengo_spa as spa
from scipy.linalg import expm

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns

from IPython.display import HTML, display

from ssp.pointers import BaseVectors
from ssp.plots import create_gif

In [None]:
# model / simulation parameters
dim = 256     # dimensionality of each SSP
scale = 10    # scales the state before it is used as an SSP exponent

freq = 1      # oscillator frequency (Hz)
dt = 0.0005   # simulation time-step (s)
T = 1         # total simulated duration (s)

# round before truncating: T / dt is a float quotient that may land just
# below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996), which int()
# alone would silently truncate to one step fewer than intended
n_steps = int(round(T / dt))
t = np.arange(n_steps) * dt

In [None]:
# angular frequency (rad/s) of the oscillator
omega = 2 * np.pi * freq

# continuous-time dynamics: dx/dt = Acont @ x
Acont = np.array([
    [0, omega],
    [-omega, 0],
])

# discretization over one time-step via the matrix exponential:
# x[t + dt] = A @ x[t]
A = expm(Acont * dt)
print(A)

In [None]:
# roll out the ideal discrete-time system x[k+1] = A x[k] from x[0] = (1, 0)
x = np.empty((n_steps, 2))
x[0] = [1., 0.]
for i in range(n_steps - 1):
    x[i + 1] = A.dot(x[i])

# plot both state components against time
plt.figure(figsize=(9, 6))
plt.plot(t, x)
plt.legend(['$x_0$', '$x_1$'])
plt.xlabel("Time (s)")
plt.show()

In [None]:
# sanity check: regressing x[k+1] on x[k] should recover A, which shows the
# trajectory is rich enough to identify the underlying state-space map
solver = nengo.solvers.Lstsq()
AhatT, info = solver(x[:-1], x[1:])
Ahat = AhatT.T
assert np.allclose(Ahat, A)
print(info)

In [None]:
# define a number of methods for representing the vectors in SSP
# space that will be used to learn the linear transformation
# note the cleanups are dead simple (optional normalization,
# with a recompute of the nonlinearity)


def normalize(v):
    """Return ``v`` rescaled to unit Euclidean (L2) length."""
    length = np.linalg.norm(v)
    return v / length


class LinearTransform:
    """Represent the state as the plain concatenation of the two SSPs.

    No nonlinear feature is appended here; subclasses override ``size``,
    ``encode`` and ``decode`` to add one.

    Parameters
    ----------
    dim : int
        Dimensionality of each individual SSP.
    renormalize : bool
        If true, ``cleanup`` rescales each decoded SSP to unit length
        before re-encoding it.
    """

    def __init__(self, dim, renormalize=True):
        self.dim = dim
        self.renormalize = renormalize

    def __repr__(self):
        cls_name = type(self).__name__
        return "%s(%d, renormalize=%s)" % (
            cls_name, self.dim, self.renormalize)

    @property
    def size(self):
        """Length of an encoded vector (two stacked SSPs)."""
        return self.dim * 2

    def encode(self, x0, x1):
        """Stack the ``.v`` arrays of ``x0`` and ``x1`` into one vector."""
        return np.hstack((x0.v, x1.v))

    def decode(self, encoded):
        """Split an encoded vector back into its (x0, x1) parts."""
        d = self.dim
        return encoded[:d], encoded[d:]

    def cleanup(self, encoded):
        """Project an arbitrary vector back onto a valid encoding.

        The two SSP components are extracted with ``decode``, optionally
        rescaled to unit length, wrapped as semantic pointers, and passed
        through ``encode`` again. In subclasses, that last step recomputes
        the nonlinear feature from the cleaned SSPs.
        """
        x0, x1 = self.decode(encoded)
        if self.renormalize:
            x0, x1 = normalize(x0), normalize(x1)
        return self.encode(spa.SemanticPointer(x0), spa.SemanticPointer(x1))
    

class BindingTransform(LinearTransform):
    """Append the binding ``x0 * x1`` of the two SSPs as a third component.

    Encodings have layout ``[x0.v, x1.v, (x0 * x1).v]``, where ``*`` is the
    pointers' own binding operator.
    """

    @property
    def size(self):
        """Length of an encoded vector: x0, x1 and their binding."""
        return self.dim * 3

    def encode(self, x0, x1):
        """Stack ``x0``, ``x1`` and the bound pointer ``x0 * x1``."""
        bound = x0 * x1
        return np.hstack([x0.v, x1.v, bound.v])

    def decode(self, encoded):
        """Return only the x0 and x1 slices, discarding the binding slice."""
        d = self.dim
        return encoded[:d], encoded[d:2 * d]


class MultiplyTransform(LinearTransform):
    """Append the elementwise product of the two SSP vectors.

    Encodings have layout ``[x0.v, x1.v, x0.v * x1.v]``.
    """

    @property
    def size(self):
        """Length of an encoded vector: x0, x1 and their product."""
        return self.dim * 3

    def encode(self, x0, x1):
        """Stack ``x0.v``, ``x1.v`` and their elementwise product."""
        a, b = x0.v, x1.v
        return np.hstack([a, b, a * b])

    def decode(self, encoded):
        """Return the x0 and x1 slices, dropping the product slice."""
        d = self.dim
        return encoded[:d], encoded[d:2 * d]

In [None]:
# fixed seed so the generated base vectors are reproducible across runs
rng = np.random.RandomState(seed=0)
# vocabulary whose pointers are drawn by the project's BaseVectors generator
voc = spa.Vocabulary(dim, pointer_gen=BaseVectors(dim, rng=rng))
voc.populate('X0; X1')
# base SSPs for the two state dimensions; they are later raised to
# (scaled) state values to encode x0 and x1
X0, X1 = voc['X0'], voc['X1']

In [None]:
# encode every state (x0, x1) of the trajectory as a pair of SSPs, plus the
# optional nonlinear feature that the chosen transform appends
transform = BindingTransform(dim, renormalize=True)

# the regression below is fit on ssps[:-1] -> ssps[1:], i.e. on
# n_steps - 1 (input, target) pairs, so compare against that count
n_train = n_steps - 1
if transform.size > n_train:
    print("Warning: the rank of the linear transformation (%d) "
          "exceeds the number of training pairs (%d), and so in "
          "theory a least-squares solver can just memorize the "
          "targets." % (transform.size, n_train))

ssps = np.empty((n_steps, transform.size), dtype=np.float64)
for i in range(n_steps):
    # raise each base SSP to the power scale * coordinate
    x0 = X0**(scale*x[i, 0])
    x1 = X1**(scale*x[i, 1])
    ssps[i, :] = transform.encode(x0, x1)

In [None]:
# least-squares fit of the one-step map in SSP space; LstsqL2 adds
# L2 (ridge) regularization, which is important for this problem
solver = nengo.solvers.LstsqL2(reg=5e-3)
AsspT, info = solver(ssps[:-1], ssps[1:])
Assp = AsspT.T
print(np.mean(info['rmses']))

In [None]:
# display the (transposed) learned matrix shape next to the data shape
AsspT.shape, ssps.shape

In [None]:
# roll the learned linear system forward starting from only the first
# encoded state, applying the transform's cleanup after every step

# if the transform's cleanup ignores some dimensions of its input (e.g. the
# binding slice, which decode discards and encode recomputes), the matching
# rows of the learned matrix are effectively unused and could be pruned
ssps_hat = np.empty_like(ssps)
ssps_hat[0] = ssps[0]
for i in range(n_steps - 1):
    predicted = Assp.dot(ssps_hat[i])
    ssps_hat[i + 1] = transform.cleanup(predicted)

In [None]:
# decode the (x0, x1) SSPs back out and visualize them by similarity
name = repr(transform)

# evenly spaced 1D grid on [-size, size] at which each decoded SSP's
# similarity is evaluated
size = 2
space = np.linspace(-size, +size, num=501)

def make_tiling(base, scale, space):
    """Encode every point of a 1D grid as a power of ``base``.

    Returns an array of shape ``(len(space), len(base))`` whose row ``k`` is
    ``(base ** (scale * space[k])).v``.
    """
    tiling = np.zeros((len(space), len(base)))
    for row, point in zip(tiling, space):
        # `row` is a view into `tiling`, so this fills the output in place
        row[:] = (base ** (scale * point)).v
    return tiling

def plot(name, space, p_sims, v_sims, downsample=25):
    """Build an animation of the x0 / x1 similarity curves over time.

    Parameters
    ----------
    name : str
        Title shown on the figure.
    space : array-like
        Grid of points used as the x-axis of every frame.
    p_sims, v_sims : array-like
        Per-time-step similarity curves; ``p_sims[step]`` is drawn in blue
        (legend ``$x_0$``) and ``v_sims[step]`` in red (legend ``$x_1$``).
    downsample : int
        Only every ``downsample``-th time-step becomes a frame.

    Returns
    -------
    matplotlib.animation.ArtistAnimation
        The frame-by-frame animation. The current figure is closed before
        returning (presumably so it is not also displayed as a static image).
    """
    fig, ax = plt.subplots()
    images = []
    for step in range(0, len(p_sims), downsample):
        # one frame = both similarity curves at this time-step ...
        lines = ax.plot(
            space,
            p_sims[step], 'b',
            space,
            v_sims[step], 'r',
            animated=True)
        # ... plus the step index, placed at x = space[0], y = 1 (data coords)
        text = ax.text(space[0], 1, str(step),
                       verticalalignment='center', animated=True)
        images.append(lines + [text])

    # ax.set_xticks([])
    plt.title(name)
    # the legend labels the first two plotted lines, i.e. the first frame's
    # blue and red curves
    plt.legend(['$x_0$', '$x_1$'])
    ani = animation.ArtistAnimation(fig, images, interval=80, blit=True)
    plt.close()
    return ani

# split every rolled-out encoding into its (x0, x1) parts; each decode
# returns two dim-length slices, so this has shape (n_steps, 2, dim)
decoded = np.asarray([transform.decode(encoded) for encoded in ssps_hat])
# dot products of each decoded SSP with the SSP of every grid point,
# giving arrays of shape (n_steps, len(space))
x0_sims = np.dot(decoded[:, 0], make_tiling(X0, scale, space).T)
x1_sims = np.dot(decoded[:, 1], make_tiling(X1, scale, space).T)

# render the animation via create_gif (which presumably also writes
# "<name>.gif" -- TODO confirm) and embed the returned base64 data inline
HTML('<img src="data:image/gif;base64,{0}" />'.format(
    create_gif(
        plot(name, space, x0_sims, x1_sims),
        fname="%s.gif" % name)))

In [None]:
# scatter the decoded trajectory: at each time-step the grid point with the
# highest similarity is taken as the estimate of x0 (resp. x1), and each
# point is colored by the mean of the two peak similarities on a fixed
# [-1, 1] color scale
plt.title(name)

plt.scatter(space[np.argmax(x0_sims, axis=1)],
            space[np.argmax(x1_sims, axis=1)],
            c=(np.max(x0_sims, axis=1)+np.max(x1_sims, axis=1))/2,
            vmin=-1, vmax=1, s=1, cmap='RdYlBu')
plt.colorbar()
plt.xlabel('$x_0$')
plt.ylabel('$x_1$')
# equal axis scaling so the shape of the orbit is not visually distorted
plt.axis('equal')

plt.show()