### Representing Geometric Changes in the Julia Set Fractal

This notebook investigates the use of place-cell spatial semantic pointers (SSPs) to represent a particular complex dynamical system: the Julia set of $z \mapsto z^2 + c$ with $c = -0.1 + 0.65i$. This set is a fractal with infinitely fine, self-similar detail, making it an interesting candidate to stress-test the ability of a single SSP to capture details at increasingly fine scales.

After demonstrating that an SSP can represent the Julia set at increasingly fine scales, we use an MLP to learn the transformation from each SSP representation to the next. We find that it is possible for an MLP with fewer parameters than there are training values to accurately and autonomously model all of the transformations from one SSP to the next by recurrently feeding its prediction back to itself.

In [None]:
%matplotlib inline

In [None]:
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from IPython.display import HTML

import nengo
import nengo_spa as spa
from ssp.maps import Spatial2D
from ssp.plots import heatmap_animation, create_gif

In [None]:
# Adapted from ssp_grid_cell_utils.py and ssp_grid_cell_examples.ipynb
# Accurate representation for spatial cognition using grid cells
# Nicole Sandra-Yaffa Dumont & Chris Eliasmith

def ssp_plane_basis(K):
    """Create the SSP basis vectors X, Y from a matrix of 2-D wave vectors.

    Each row ``k_i = (u_i, v_i)`` of ``K`` contributes the Fourier coefficient
    ``exp(1j * u_i)`` to X and ``exp(1j * v_i)`` to Y (plus their complex
    conjugates, so the vectors are real). To get hexagonal patterns use 3 K
    vectors 120 degrees apart; to get multiple scales/orientations, stack many
    such sets of 3 K vectors.

    Parameters
    ----------
    K : ndarray of shape (d, 2)
        Wave vectors, one per row.

    Returns
    -------
    X, Y : spa.SemanticPointer
        Real-valued vectors of dimension ``2 * d + 1``.
    """
    def _real_basis_vector(phases):
        # Build the spectrum [e^{i p_1}, ..., e^{i p_d}, 1, conj(e^{i p_d}), ..., conj(e^{i p_1})].
        # After ifftshift the constant 1 lands in the DC bin and the remaining
        # entries are conjugate-symmetric, so the inverse FFT is real up to round-off.
        n = len(phases)
        spectrum = np.ones((2 * n + 1,), dtype="complex")
        spectrum[:n] = np.exp(1.j * phases)
        # Slice from n + 1 (not -n) so that n == 0 does not select the whole array.
        spectrum[n + 1:] = np.flip(np.conj(spectrum[:n]))
        # Take the real part explicitly instead of handing complex data to
        # SemanticPointer (which presumably casts it to float with a ComplexWarning).
        return np.fft.ifft(np.fft.ifftshift(spectrum)).real

    X = spa.SemanticPointer(data=_real_basis_vector(K[:, 0]))
    Y = spa.SemanticPointer(data=_real_basis_vector(K[:, 1]))
    return X, Y


def generate_grid_cell_basis(n_scales, n_rotates, scale_min=0.5, scale_max=1.8):
    """Generate grid-cell basis vectors with ``d = n_scales * n_rotates * 6 + 1``.

    A set of 3 wave vectors 120 degrees apart (which gives hexagonal grid
    interference patterns) is scaled by ``n_scales`` values evenly spaced in
    ``[scale_min, scale_max]`` and rotated by ``n_rotates`` angles evenly
    spaced in ``[0, pi/3)``. The resulting ``3 * n_scales * n_rotates`` wave
    vectors are passed to ``ssp_plane_basis``.

    Returns
    -------
    X, Y : spa.SemanticPointer
        The basis vectors produced by ``ssp_plane_basis``.
    d : int
        Their dimensionality.
    """
    K_hex = np.array(
        [[0, 1],
         [np.sqrt(3) / 2, -0.5],
         [-np.sqrt(3) / 2, -0.5]]
    )

    # Combine n_scales sets of 3 hexagonal wave vectors; each set gives a
    # different grid resolution. Shape: (3 * n_scales, 2).
    scales = np.linspace(scale_min, scale_max, n_scales)
    K_scales = np.vstack([K_hex * s for s in scales])

    # One 2x2 rotation matrix per grid orientation. Shape: (n_rotates, 2, 2).
    thetas = np.arange(n_rotates) * np.pi / (3 * n_rotates)
    cos_t, sin_t = np.cos(thetas), np.sin(thetas)
    R_mats = np.stack([np.stack([cos_t, -sin_t], axis=1),
                       np.stack([sin_t, cos_t], axis=1)],
                      axis=1)

    # Rotate every scaled wave vector by every angle:
    # (n_rotates, 2, 3 * n_scales) -> (n_rotates, 3 * n_scales, 2) -> (-1, 2).
    # A single transpose(0, 2, 1) equals the former .transpose(1, 2, 0).T.
    K_scale_rotates = (R_mats @ K_scales.T).transpose(0, 2, 1).reshape(-1, 2)

    # Generate the (X, Y) basis vectors
    X, Y = ssp_plane_basis(K_scale_rotates)
    d = n_scales * n_rotates * 6 + 1
    assert len(X) == len(Y) == d
    return X, Y, d

In [None]:
try:
    from numba import njit  # optional

except ImportError:
    print("`pip install numba` to significantly speed this up!")
    njit = lambda x: x  # this decorator does nothing


@njit
def julia_set(c, n_iter=1000, R=10,
              resolution=(500, 500), x=(-1.5, 1.5), y=(-1.5, 1.5)):
    """Escape-time image of the Julia set for the map ``z -> z**2 + c``.

    Adapted from https://scipython.com/book/chapter-7-matplotlib/problems/p72/the-julia-set/

    Parameters
    ----------
    c : complex
        Julia constant.
    n_iter : int
        Maximum number of iterations per pixel.
    R : float
        Escape radius; iteration stops as soon as ``abs(z) > R``.
    resolution : (int, int)
        ``(im_width, im_height)``; also the shape of the returned array.
    x, y : (float, float)
        ``(min, max)`` extent in the complex plane. Pixel ``ix`` maps to
        ``xmin + ix / im_width * (xmax - xmin)``, so ``xmax``/``ymax`` are
        never sampled.

    Returns
    -------
    ndarray of shape ``(im_width, im_height)``
        ``nit / n_iter`` per pixel, where ``nit`` is the number of iterations
        performed; 1.0 means the orbit never escaped. The first axis indexes
        the real part, so ``imshow`` draws the real axis vertically.
    """
    im_width, im_height = resolution
    xmin, xmax = x
    xwidth = xmax - xmin
    ymin, ymax = y
    yheight = ymax - ymin

    julia = np.zeros((im_width, im_height))
    for ix in range(im_width):
        for iy in range(im_height):
            nit = 0
            # Map pixel position to a point in the complex plane
            z = complex(ix / im_width * xwidth + xmin,
                        iy / im_height * yheight + ymin)
            # Iterate until the orbit leaves the disc of radius R or n_iter is hit
            while abs(z) <= R and nit < n_iter:
                z = z**2 + c
                nit += 1
            julia[ix, iy] = nit / n_iter

    return julia

In [None]:
import matplotlib.cm as cm

# Colormap shared by the plots and animations in this notebook.
cmap = cm.hot
# Alternative diverging colormap:
# cmap = sns.diverging_palette(220, 20, sep=20, as_cmap=True)

def make_plot(sims, ax, cmap=cmap, vmin=-1, vmax=1):
    """Render ``sims`` as an image on ``ax``, hiding ticks and axis spines."""
    ax.imshow(sims, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set(xticks=[], yticks=[])
    ax.axis('equal')
    sns.despine(ax=ax, left=True, bottom=True)

In [None]:
zoom = 1.45  # images cover the square [-zoom, zoom] x [-zoom, zoom] of the complex plane
res = 201  # pixels per side
# Alternative Julia constant: c = complex(-0.4, 0.6)
c = complex(-0.1, 0.65)

# Range of escape radii to sweep. The upper end is the positive root of
# R^2 - R = |c| (so R^2 - R >= |c| holds), written in closed form rather than
# taking np.roots(...)[0], whose root ordering is not guaranteed.
R_range = (
    0.52,
    (1 + np.sqrt(1 + 4 * np.abs(c))) / 2,
)

fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# One panel per escape radius in R_range (first entry left, second right).
for panel, escape_radius in zip(ax, R_range):
    julia = julia_set(
        c=c,
        R=escape_radius,
        resolution=(res, res),
        x=(-zoom, zoom),
        y=(-zoom, zoom),
    )
    make_plot(julia, ax=panel, vmin=0, vmax=None)

plt.show()

# The min/max below describe the last (right-hand) image only.
print("|z| Range:", R_range)
print("Min, Max:", julia.min(), julia.max())

In [None]:
# 30 scales x 21 rotations x 6 + 1 gives d = 3781 dimensions.
# (A previously tried range was scale_min=0.8, scale_max=3.6.)
X, Y, d = generate_grid_cell_basis(
    n_scales=30,
    n_rotates=21,
    scale_min=1,
    scale_max=100,
)
dim = d

print("Dimensionality:", dim)

In [None]:
class IterSpatial2D(Spatial2D):
    """Spatial2D whose effective base vectors vary with the point's squared radius."""

    def encode_point(self, x, y, name=None):
        """Encode the point ``(x, y)``, bound to the vocabulary entry ``name``.

        When ``name`` is None the ``"Identity"`` vocabulary entry is used.
        The result is ``tag * (X ** e) ** (x * scale) * (Y ** e) ** (y * scale)``
        with ``e = x**2 + y**2 + 2``.
        """
        key = "Identity" if name is None else name
        tag = self.voc[key]

        # Changing the base vectors continuously with x^2 + y^2 is meant to
        # reduce spatial artifacts in large-scale complex objects. The offset
        # keeps the exponent >= 2, so the effective base vectors never reduce
        # to X**0 / Y**0 near the origin.
        # NOTE(review): the total offset is 2 (+1 in `radial`, +1 in
        # `exponent`); confirm this double offset is intended.
        radial = x**2 + y**2 + 1
        exponent = radial + 1

        x_shift = x * self.scale
        y_shift = y * self.scale
        IX = (self.X ** exponent) ** x_shift
        IY = (self.Y ** exponent) ** y_shift

        return tag * IX * IY

In [None]:
ssp_map = IterSpatial2D(
    dim=dim,
    scale=20,
    X=X,
    Y=Y,
    rng=np.random.RandomState(seed=0),
)
ssp_map.build_grid(x_len=1, y_len=1, x_spaces=res, y_spaces=res, centered=True)

# Row i of A is the flattened heatmap of the i-th standard basis vector
# (row i of the identity matrix). Assuming compute_heatmap is linear in the
# SSP (TODO confirm), any SSP's heatmap is a linear combination of these rows.
# Based on https://github.com/ctn-waterloo/metric-representation/blob/b822b49ade9aca81f564182bedb5e35a78761367/metric_representation/regions/region_utils.py#L9
# which in turn is adapted from Terry's code for solving from region vectors from way back
I = np.eye(dim)
A = np.asarray([
    ssp_map.compute_heatmap(ssp=spa.SemanticPointer(basis_vector)).flatten()
    for basis_vector in I
])

In [None]:
from nengo.utils.progress import Progress, ProgressTracker

n_steps = 50
solver = nengo.solvers.LstsqL2(reg=1e-3)

heatmaps = []  # per step: [target Julia image, heatmap decoded from its fitted SSP]
ssps = []

with ProgressTracker(
    True, Progress("Computing", "Computed", n_steps)
) as progress_bar:
    for R in np.linspace(R_range[0], R_range[1], n_steps):
        heatmap = julia_set(
            c=c,
            R=R,
            resolution=(res, res),
            x=(-zoom, zoom),
            y=(-zoom, zoom),
        )

        # Regularized least squares for SSP coefficients with A.T @ v ~= flattened image.
        solution, _ = solver(A.T, heatmap.reshape(-1, 1))
        ssp = spa.SemanticPointer(solution.squeeze(axis=1))
        sims = ssp_map.compute_heatmap(ssp)

        # Store the image fitted at this R, not a stale global from another cell.
        heatmaps.append([heatmap, sims])
        ssps.append(ssp)

        progress_bar.total_progress.step()

In [None]:
# Animate each target image next to the heatmap decoded from its fitted SSP.
frames = list(zip(*heatmaps))  # (all target images, all SSP heatmaps)
ani = heatmap_animation(
    frames,
    figsize=(8, 4),
    interval=50,
    titles=['Ideal', 'SSP'],
    cmap=cmap,
    vmin=0,
    vmax=None,  # vmax=None will adjust the colors to the maximum per subplot
)

# create_gif's result is embedded as a base64 GIF data URL.
gif_data = create_gif(ani, fname="place-cells-juliaset.gif")
HTML('<img src="data:image/gif;base64,{0}" />'.format(gif_data))

In [None]:
from ssp.models import MLP

# Presumably: input size dim, hidden-layer widths [5, 128, 5], output size dim.
mlp = MLP(dim, [5, 128, 5], dim)
mlp.model.summary()

# Row k holds the normalized SSP of step k; the MLP learns step k -> step k + 1.
vectors = np.asarray([ssp.normalized().v for ssp in ssps])
print("Number of values to memorize (approx):", vectors.size)

inputs, targets = vectors[:-1], vectors[1:]
mlp.train(inputs, targets, n_steps=2000)

fig_loss, ax_loss = plt.subplots()
ax_loss.plot(mlp.costs)
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Training Loss (MSE)")
plt.show()

In [None]:
# Closed-loop rollout: start from the first SSP and feed every MLP prediction
# back in as the next input, collecting [target image, target SSP map, MLP map].
test_heatmaps = []

ssp = ssps[0].normalized()
for step in range(1, n_steps):
    ideal_map, target_sims = heatmaps[step][0], heatmaps[step][1]
    ssp = mlp(ssp)
    sims = ssp_map.compute_heatmap(ssp)
    test_heatmaps.append([ideal_map, target_sims, sims])

frames = list(zip(*test_heatmaps))
ani = heatmap_animation(
    frames,
    figsize=(12, 4),
    interval=50,
    titles=['Ideal Map', 'Target SSP', 'MLP'],
    cmap=cmap,
    vmin=0,
    vmax=None,  # vmax=None will adjust the colors to the maximum per subplot
)

gif_data = create_gif(ani, fname="place-cells-juliaset-test.gif")
HTML('<img src="data:image/gif;base64,{0}" />'.format(gif_data))

In [None]:
# Show each row of the final layer's weight matrix as a spatial heatmap.
# NOTE(review): presumably Wout is the (n_hidden, dim) kernel of a Keras Dense
# layer, so each row is the SSP-space vector contributed by one hidden unit.
Wout, _ = mlp.model.layers[-1].get_weights()
n_rows = Wout.shape[0]

# squeeze=False keeps `axes` 2-D, so indexing also works when n_rows == 1.
fig, axes = plt.subplots(1, n_rows, figsize=(4 * n_rows, 4), squeeze=False)

for ax, weights in zip(axes[0], Wout):
    heatmap = ssp_map.compute_heatmap(spa.SemanticPointer(weights))
    ax.imshow(heatmap, cmap=cmap, vmin=0, vmax=None)
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()