# Topo Time Series MWTA

## Intro
* **Date**: 1/13/21
* **What**: I'm basically doing time-series mwta with strong lateral excitation to promote the formation of topological connections.  I also think I figured out how to prevent training from going totally haywire by clipping the reconstruction error.
* **Why**: Basically, the more useful connections between different prototypes I can find, the better.  Being able to organize the prototypes and promote certain neurons firing together is super important when moving from an invariant representation that can represent many things to the one particular configuration that corresponds to that representation.
* **Hopes**: I basically just want to see interesting topological connections.
* **Limitations**: This has already worked before, but that obviously doesn't mean I won't run into difficulties.  I just don't know what those difficulties are yet.
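For readers without a GPU, here is a minimal NumPy sketch of one training step in the form the loops below use, including the clipped reconstruction error. The name `mwta_step`, its default arguments, and the toy sizes in the usage lines are my own choices, not part of the original cells.

```python
import numpy as np

def mwta_step(w, lat_ex, v, n_w=5, xi=0.03, prec=1e-10, clip=1.0):
    """One multi-winner step on the CPU (hypothetical helper mirroring the notebook loop).

    w:      (N, m_len) non-negative prototype weights, updated in place
    lat_ex: (N, N) lateral excitation matrix
    v:      (m_len, 1) input column vector
    """
    p = w @ v                                   # feedforward response of every neuron
    p = p + lat_ex @ p                          # add lateral excitation from neighbors
    winners = np.argsort(p, axis=0)[-n_w:]      # indices of the n_w most active neurons
    mask = np.zeros_like(p)
    mask[winners] = 1
    o = mask * p                                # only the winners stay active
    r = w.T @ o                                 # reconstruction of the input
    e = np.clip(v - r, -clip, clip)             # clipped reconstruction error
    w += w * o * (e / np.maximum(r, prec)).T * xi   # multiplicative weight update
    return o, r

# Toy usage: 16 neurons, 3 x 3 inputs, no lateral excitation
rng = np.random.default_rng(0)
w = rng.uniform(0.15, 0.16, (16, 9))
lat_ex = np.zeros((16, 16))
v = rng.uniform(0, 1, (9, 1))
w_before = w.copy()
o, r = mwta_step(w, lat_ex, v, n_w=3)
```

Because the update is multiplied by `w` itself, a weight that reaches zero stays at zero. As long as weights and outputs stay non-negative, each winner's term `w * o` is one of the non-negative terms summed into `r`, so `w * o / mod_r` is at most 1 and the clipped error caps every weight change at `xi` per step.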

## Code

In [1]:
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from tensorflow.keras.datasets import mnist
from tqdm import tqdm

(x_tr, _), _ = mnist.load_data()

x_tr = x_tr / 255.0

In [2]:
t_sl = 30 # Tapestry side length
m_sl = 28 # Side length of each image

tapestry = np.zeros((t_sl * m_sl, t_sl * m_sl))

x_i = 0

for x in range(t_sl):
    for y in range(t_sl):

        tapestry[y * m_sl : (y + 1) * m_sl, x * m_sl : (x + 1) * m_sl] = x_tr[x_i]
        x_i += 1
        
# Copy the first row and column of tiles onto the last ones so the tapestry wraps seamlessly
tapestry[(t_sl - 1) * m_sl:, :] = tapestry[: m_sl, :]
tapestry[:, (t_sl - 1) * m_sl:] = tapestry[:, : m_sl]

In [3]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))
plt.xticks([])
plt.yticks([])
plt.imshow(tapestry, cmap="gray_r")


In [4]:
x_o = 420
y_o = 420

sl = 20

x = x_o
y = y_o

v_x = 0
v_y = 0

v_max = 3

a_x = np.random.uniform(-1, 1)
a_y = np.random.uniform(-1, 1)

img_count = 10_000
imgs = []

del_t = 1

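# Random walk over the tapestry: pick a new random acceleration every 20 frames,
# clip each velocity component to [-v_max, v_max], and crop an sl x sl window at each position
for i in range(img_count):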
    if i % 20 == 0:
        a_x = np.random.uniform(-1, 1)
        a_y = np.random.uniform(-1, 1)
        
    x += v_x * del_t
    y += v_y * del_t
    v_x = np.clip(v_x + (a_x * del_t), -v_max, v_max)
    v_y = np.clip(v_y + (a_y * del_t), -v_max, v_max)
    
    # Wrap around; the duplicated last row and column of tiles make the wrap seamless
    x_f = int(x) % ((t_sl - 1) * m_sl)
    y_f = int(y) % ((t_sl - 1) * m_sl)
    
    imgs.append(tapestry[y_f: y_f + sl, x_f : x_f + sl])

In [5]:
%matplotlib notebook
fig = plt.figure(figsize=(5, 5))

ims = []
for i in range(500):
    im = plt.imshow(imgs[i], cmap="gray_r", animated=True)
    ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=100, blit=True,
                                repeat_delay=500)

plt.xticks([])
plt.yticks([])
plt.show()


In [6]:
img_array = np.array(imgs)
ts_data = img_array.reshape(-1, sl ** 2)
gp_data = cp.asarray(ts_data)

In [7]:
def draw_weights(w, Kx, Ky, s_len, fig):
    """Tile the Kx * Ky weight vectors (each s_len x s_len) into one image and redraw `fig`."""
    tapestry = np.zeros((s_len * Ky, s_len * Kx))
    
    w_i = 0
    for y in range(Ky):
        for x in range(Kx):
            tapestry[y * s_len: (y + 1) * s_len, x * s_len: (x + 1) * s_len] = w[w_i].reshape(s_len, s_len)
            w_i += 1
            
    plt.clf()        
    max_val = np.max(tapestry)
    im = plt.imshow(tapestry, cmap="Greys", vmax=max_val)
    fig.colorbar(im, ticks=[0, max_val])
    plt.axis("off")
    fig.canvas.draw()

## Analysis Dialog

Here we go.

In [12]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 5
T_s = 10_000

Kx = 20
Ky = 20
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.01

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        # Slice ends are exclusive, so clamp to Ky / Kx (not Ky - 1 / Kx - 1) to include the last row and column
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

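for ep in range(Nep):
    # Shuffle the frame order each epoch
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)            # input frame as a column vector
        p = w @ v                               # feedforward response of every neuron
        p += lat_ex @ p                         # add lateral excitation from grid neighbors
        winners = cp.argsort(p, axis=0)[-n_w:]  # indices of the n_w most active neurons
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p                            # only the winners stay active
        r = w.T @ o                             # reconstruction of the input
        mod_r = cp.maximum(r, prec)             # guard against dividing by zero
#         e = cp.clip(v - r, -1, 1)
        e = v - r                               # reconstruction error (not clipped in this run)
        
        # Multiplicative update: scale each winner's weights by its output times the relative error
        w += w * o * (e / mod_r).T * xi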
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:04<00:00, 2119.63it/s]
100%|██████████| 10000/10000 [00:04<00:00, 2028.59it/s]
100%|██████████| 10000/10000 [00:04<00:00, 2086.19it/s]
100%|██████████| 10000/10000 [00:04<00:00, 2090.93it/s]
100%|██████████| 10000/10000 [00:04<00:00, 2090.28it/s]


Wow, even with 400 neurons it's fascinating to watch the topological connections form.  For some reason, everything goes haywire if I make the lateral connections any stronger than 0.1.  Not sure why that is, but obviously we're still totally getting fantastic topological connections.  Ok, I'm going to up this to 1,600 neurons.
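One way to see how strong 0.1 already is: with `ex = 2`, an interior neuron sums the responses of 24 neighbors, so if they respond about as strongly as it does, the lateral term adds 2.4 times its own response. The NumPy sketch below measures this for an interior, an edge, and a corner neuron on the 20 x 20 grid. `build_lat_ex` is a hypothetical helper that mirrors the construction above, with slice ends clamped to `Ky`/`Kx` so the last row and column are included. It is only a guess that this gain is what makes training go haywire above 0.1, but it also shows that edge and corner neurons get noticeably less excitation than interior ones.

```python
import numpy as np

def build_lat_ex(Kx, Ky, ex, strength):
    """Dense (N, N) lateral excitation: each neuron excites every other neuron
    within a (2 * ex + 1) x (2 * ex + 1) square around it on the Kx x Ky grid."""
    rows = []
    for y in range(Ky):
        for x in range(Kx):
            curr = np.zeros((Ky, Kx))
            curr[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = strength
            rows.append(curr.reshape(-1))
    lat = np.array(rows)
    np.fill_diagonal(lat, 0)   # no self-excitation
    return lat

Kx = Ky = 20
lat = build_lat_ex(Kx, Ky, ex=2, strength=0.1)
p = np.ones((Kx * Ky, 1))      # pretend every neuron responds equally
c = p + lat @ p                # response after lateral excitation
grid = c.reshape(Ky, Kx)
print(grid[10, 10], grid[0, 10], grid[0, 0])   # interior, edge, corner: about 3.4, 2.4, 1.8
```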

In [15]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 5
T_s = 10_000

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        p += lat_ex @ p
        winners = cp.argsort(p, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:06<00:00, 1595.47it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1594.43it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1599.69it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1591.85it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1577.80it/s]


Holy snoots.  The strong and the weak patterns almost look like Fourier basis functions.  That's absolutely bonkers.  I'm going to try this with 2,500 neurons and more epochs.

In [16]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 10
T_s = 10_000

Kx = 50
Ky = 50
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        p += lat_ex @ p
        winners = cp.argsort(p, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:07<00:00, 1359.37it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1356.15it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1361.25it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1368.47it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1369.94it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1369.24it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1374.72it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1380.04it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1369.47it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1369.13it/s]


That is absolutely wild.  I've got to try this with just the digits.

But first, here's the animation.

In [22]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 5))

w_np = w.get()
lx_np = lat_ex.get()

ims = []
for i in tqdm(range(500)):
    v = ts_data[i].reshape(-1, 1)
    p = w_np @ v
    p += lx_np @ p
    winners = np.argsort(p, axis=0)[-n_w:]
    mask = np.zeros((N, 1))
    mask[winners] = 1
    o = mask * p
    o = o / max(prec, np.max(o))
    r = w_np.T @ o
    r = r / max(prec, np.max(r))
    
    mini_tap = np.zeros((50, 150))
    
    mini_tap[:20, :20] = v.reshape(20, 20)
    mini_tap[:, 50:100] = o.reshape(50, 50)
    mini_tap[:20, -20:] = r.reshape(20, 20)
    
    im = plt.imshow(mini_tap, cmap="gray_r", animated=True)
    ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=150, blit=True,
                                repeat_delay=500)

plt.xticks([])
plt.yticks([])

plt.show()


100%|██████████| 500/500 [00:02<00:00, 220.10it/s]


Ok, I'm going to try something interesting.  I'm going to pick the winners using a competition value that includes the lateral excitation, but the winners' output (which the network trains on) will be their raw feedforward response.  You'll see what I mean.
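Concretely, the lateral term now only decides who wins; the winners then report, and learn from, their raw feedforward response. A small NumPy sketch with a made-up 5-neuron example (`select_and_respond` is a hypothetical name): neuron 1 has a weaker raw response than neuron 3, but its excited neighbors push it into the top 3.

```python
import numpy as np

def select_and_respond(w, lat_ex, v, n_w):
    """Rank neurons by the laterally excited value c, but output the raw
    feedforward response p of the winners."""
    p = w @ v                        # raw feedforward response
    c = p + lat_ex @ p               # competition value including neighbors
    winners = np.argsort(c, axis=0)[-n_w:]
    mask = np.zeros_like(p)
    mask[winners] = 1
    return mask * p                  # winners keep their raw response

w = np.array([[1.0], [0.5], [1.0], [0.6], [0.0]])   # one-pixel "prototypes"
lat_ex = np.zeros((5, 5))
lat_ex[1, 0] = lat_ex[1, 2] = 0.5   # neuron 1 is excited by neurons 0 and 2
v = np.array([[1.0]])
o = select_and_respond(w, lat_ex, v, n_w=3)
print(o.ravel())                     # neurons 0, 1, 2 win; neuron 1 outputs its raw 0.5
```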

In [23]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 10
T_s = 10_000

Kx = 50
Ky = 50
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:07<00:00, 1368.50it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1370.14it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1373.95it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1374.34it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1373.40it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1373.97it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1373.16it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1379.19it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1377.03it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1378.51it/s]


In [26]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 5))

w_np = w.get()
lx_np = lat_ex.get()

ims = []
for i in tqdm(range(500)):
    v = ts_data[i].reshape(-1, 1)
    
    p = w_np @ v
    c = p + (lx_np @ p)
    winners = np.argsort(c, axis=0)[-n_w:]
    mask = np.zeros((N, 1))
    mask[winners] = 1
    o = mask * p
    r = w_np.T @ o

    o = o / max(prec, np.max(o))
    r = r / max(prec, np.max(r))
    
    mini_tap = np.zeros((50, 150))
    
    mini_tap[:20, :20] = v.reshape(20, 20)
    mini_tap[:, 50:100] = o.reshape(50, 50)
    mini_tap[:20, -20:] = r.reshape(20, 20)
    
    im = plt.imshow(mini_tap, cmap="gray_r", animated=True)
    ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=150, blit=True,
                                repeat_delay=500)

plt.xticks([])
plt.yticks([])

plt.show()


100%|██████████| 500/500 [00:02<00:00, 202.57it/s]


I think we both know what I have to do.  10,000 neurons, baby.

In [27]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 10
T_s = 10_000

Kx = 100
Ky = 100
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.06

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:24<00:00, 407.25it/s]
  8%|▊         | 779/10000 [00:01<00:23, 394.63it/s]


KeyboardInterrupt: 

Nope.  That takes too long.  Let's see...
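One thing that grows quickly here, though I haven't profiled it: `lat_ex` is a dense N x N matrix, so at N = 10,000 it has 10^8 entries (about 800 MB in float64) and `lat_ex @ p` costs 10^8 multiply-adds per frame, compared with 4 x 10^6 for `w @ v`. Because the excitation is just a square neighborhood on the grid, the same product can be computed from shifted copies of the Ky x Kx activity map. A NumPy sketch, assuming a uniform strength and slice ends clamped to `Ky`/`Kx`; `lateral_via_shifts` and `dense_lat_ex` are names I made up, and swapping `np` for `cp` should let it run on the GPU, though I haven't tried that.

```python
import numpy as np

def lateral_via_shifts(p, Kx, Ky, ex, strength):
    """Compute lat_ex @ p for a uniform square neighborhood without building lat_ex.

    p is an (N, 1) column with N = Kx * Ky, in the notebook's row-major order
    (neuron index y * Kx + x). Every other neuron within ex grid steps in x and y
    contributes strength times its response; off-grid neighbors count as zero.
    """
    grid = p.reshape(Ky, Kx)
    padded = np.pad(grid, ex)                  # zero border handles the edges
    out = np.zeros_like(grid)
    for dy in range(-ex, ex + 1):
        for dx in range(-ex, ex + 1):
            if dy == 0 and dx == 0:
                continue                       # no self-excitation (zero diagonal)
            out += padded[ex + dy: ex + dy + Ky, ex + dx: ex + dx + Kx]
    return strength * out.reshape(-1, 1)

def dense_lat_ex(Kx, Ky, ex, strength):
    """Dense reference matrix: same square neighborhood, slice ends clamped to Ky / Kx."""
    rows = []
    for y in range(Ky):
        for x in range(Kx):
            curr = np.zeros((Ky, Kx))
            curr[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = strength
            rows.append(curr.reshape(-1))
    lat = np.array(rows)
    np.fill_diagonal(lat, 0)
    return lat

# Check on a small non-square grid that both give the same lateral input
rng = np.random.default_rng(0)
Kx, Ky, ex = 7, 5, 2
p = rng.uniform(size=(Kx * Ky, 1))
fast = lateral_via_shifts(p, Kx, Ky, ex, 0.1)
slow = dense_lat_ex(Kx, Ky, ex, 0.1) @ p

# Size of the dense matrix at 100 x 100 neurons, in float64
dense_bytes = (100 * 100) ** 2 * 8             # 800,000,000 bytes
```

In the training loop this would replace `lat_ex @ p` with `lateral_via_shifts(p, Kx, Ky, ex, 0.1)`. It only covers the uniform-strength version; the row-normalized variant tried later gives each neuron its own constant weight, which would mean one extra per-neuron scale factor.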

In [29]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 10
T_s = 10_000

Kx = 80
Ky = 80
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.1

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:13<00:00, 729.24it/s]
100%|██████████| 10000/10000 [00:13<00:00, 732.73it/s]
100%|██████████| 10000/10000 [00:13<00:00, 731.87it/s]
100%|██████████| 10000/10000 [00:13<00:00, 731.32it/s]
100%|██████████| 10000/10000 [00:13<00:00, 726.45it/s]
100%|██████████| 10000/10000 [00:13<00:00, 727.98it/s]
100%|██████████| 10000/10000 [00:13<00:00, 726.81it/s]
100%|██████████| 10000/10000 [00:13<00:00, 726.65it/s]
100%|██████████| 10000/10000 [00:13<00:00, 726.80it/s]
100%|██████████| 10000/10000 [00:13<00:00, 726.74it/s]


In [30]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 5))

w_np = w.get()
lx_np = lat_ex.get()

ims = []
for i in tqdm(range(500)):
    v = ts_data[i].reshape(-1, 1)
    
    p = w_np @ v
    c = p + (lx_np @ p)
    winners = np.argsort(c, axis=0)[-n_w:]
    mask = np.zeros((N, 1))
    mask[winners] = 1
    o = mask * p
    r = w_np.T @ o

    o = o / max(prec, np.max(o))
    r = r / max(prec, np.max(r))
    
    mini_tap = np.zeros((80, 120))
    
    mini_tap[:20, :20] = v.reshape(20, 20)
    mini_tap[:, 20:100] = o.reshape(80, 80)
    mini_tap[:20, -20:] = r.reshape(20, 20)
    
    im = plt.imshow(mini_tap, cmap="gray_r", animated=True)
    ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=150, blit=True,
                                repeat_delay=500)

plt.xticks([])
plt.yticks([])

plt.show()


100%|██████████| 500/500 [00:09<00:00, 51.22it/s]


Man.  I sincerely and deeply love the topological organization I'm seeing.  It would be fascinating to see what would happen with an EMA of the output.  Imma try it with 1,600 neurons.
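By EMA I mean an exponential moving average of the output: instead of replacing `o` every frame, blend in the new winners with `o += (mask * p - o) * alpha`, where `alpha = 0.5` in the cell below. Starting from zero, a neuron that wins once with response x and then stops winning holds alpha (1 - alpha)^k x after k more frames, so the output carries a fading trace of recent winners. A toy NumPy sketch of that trace; `ema_trace` is a hypothetical helper and the 4-neuron sequence is made up.

```python
import numpy as np

def ema_trace(responses, alpha=0.5):
    """Apply o <- o + (x - o) * alpha over a sequence of winner responses
    (one row per frame) and return o after every frame."""
    o = np.zeros_like(responses[0], dtype=float)
    history = []
    for x in responses:
        o = o + (x - o) * alpha
        history.append(o.copy())
    return np.array(history)

# The winner moves 0 -> 1 -> 2 -> 3, each responding with 1.0
responses = np.eye(4)
trace = ema_trace(responses)
print(trace[-1])   # most recent winner strongest, each older winner halved once more per frame
```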

In [31]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 10
T_s = 10_000

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_data
    o = cp.zeros((N, 1))
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o += ((mask * p) - o) * 0.5
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:06<00:00, 1456.76it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1465.19it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1464.69it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1475.14it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1410.88it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1456.88it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1463.51it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1462.40it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1459.13it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1456.96it/s]


In [34]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 5))

w_np = w.get()
lx_np = lat_ex.get()

ims = []
o = np.zeros((N, 1))
for i in tqdm(range(500)):
    v = ts_data[i].reshape(-1, 1)
    
    p = w_np @ v
    c = p + (lx_np @ p)
    winners = np.argsort(c, axis=0)[-n_w:]
    mask = np.zeros((N, 1))
    mask[winners] = 1
    o += ((mask * p) - o) * 0.5
    r = w_np.T @ o

    o_oo = o / max(prec, np.max(o))
    r = r / max(prec, np.max(r))
    
    mini_tap = np.zeros((40, 80))
    
    mini_tap[:20, :20] = v.reshape(20, 20)
    mini_tap[:, 20:60] = o_oo.reshape(40, 40)
    mini_tap[:20, -20:] = r.reshape(20, 20)
    
    im = plt.imshow(mini_tap, cmap="gray_r", animated=True)
    ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=150, blit=True,
                                repeat_delay=500)

plt.xticks([])
plt.yticks([])

plt.show()


100%|██████████| 500/500 [00:01<00:00, 464.74it/s]


Holy cow.  It's showing the trajectories.  That's incredibly dope.  Ok, I've got a hike with Mal, so I need to go for now, but that's dope, and it might actually be how I can make guesses at trajectories!! Dope!!!

Ok, we back.  I'm partnering with Mallory, and everything is dope.  Anyway, I think the strength of the topological connections is actually jacking things up a bit.  So before I train on pure digits, I'm going to try turning down the lateral excitation a bit.

In [8]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 10
T_s = 10_000

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.06
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_data
    o = cp.zeros((N, 1))
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o += ((mask * p) - o) * 0.5
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:07<00:00, 1356.92it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1485.37it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1479.33it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1481.21it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1478.61it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1481.03it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1474.66it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1478.25it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1474.20it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1477.63it/s]


In [9]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 5))

w_np = w.get()
lx_np = lat_ex.get()

ims = []
o = np.zeros((N, 1))
for i in tqdm(range(500)):
    v = ts_data[i].reshape(-1, 1)
    
    p = w_np @ v
    c = p + (lx_np @ p)
    winners = np.argsort(c, axis=0)[-n_w:]
    mask = np.zeros((N, 1))
    mask[winners] = 1
    o += ((mask * p) - o) * 0.5
    r = w_np.T @ o

    o_oo = o / max(prec, np.max(o))
    r = r / max(prec, np.max(r))
    
    mini_tap = np.zeros((40, 80))
    
    mini_tap[:20, :20] = v.reshape(20, 20)
    mini_tap[:, 20:60] = o_oo.reshape(40, 40)
    mini_tap[:20, -20:] = r.reshape(20, 20)
    
    im = plt.imshow(mini_tap, cmap="gray_r", animated=True)
    ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=150, blit=True,
                                repeat_delay=500)

plt.xticks([])
plt.yticks([])

plt.show()


100%|██████████| 500/500 [00:01<00:00, 382.58it/s]


Fascinating.  Ok, I'm going to turn it way down and see what I can see.

In [10]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 10
T_s = 10_000

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.01
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_data
    o = cp.zeros((N, 1))
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o += ((mask * p) - o) * 0.5
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:06<00:00, 1477.45it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1482.06it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1479.58it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1477.45it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1483.20it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1483.02it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1475.48it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1480.89it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1480.54it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1476.94it/s]


Ok, so no topology, pretty much. Interesting.  It seems like 0.1 is kinda the optimal number for some reason.

Ok, I'm going to normalize the lateral connections so neurons on the edges get the same amount of lateral excitation as the ones in the middle.
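A note on what "the same amount" means: dividing each row by its L2 norm (as the next cell does, scaled to 0.5) gives every row the same norm but not the same sum, so an edge neuron with fewer neighbors still receives less total excitation when its neighbors are equally active. Dividing by the row sum (L1 norm) instead would equalize the total. A NumPy sketch comparing the two; `build_neighbors` is a hypothetical helper using the same square neighborhood with slice ends clamped to `Ky`/`Kx`, and the L1 target of 2.4 is just the interior total at strength 0.1, not a value from the original runs.

```python
import numpy as np

def build_neighbors(Kx, Ky, ex):
    """Dense (N, N) 0/1 matrix marking each neuron's square neighborhood (zero diagonal)."""
    rows = []
    for y in range(Ky):
        for x in range(Kx):
            curr = np.zeros((Ky, Kx))
            curr[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 1
            rows.append(curr.reshape(-1))
    lat = np.array(rows)
    np.fill_diagonal(lat, 0)
    return lat

Kx = Ky = 40
lat = build_neighbors(Kx, Ky, ex=2)

l2 = lat / np.linalg.norm(lat, axis=1, keepdims=True) * 0.5   # every row has norm 0.5
l1 = lat / lat.sum(axis=1, keepdims=True) * 2.4               # every row sums to 2.4

# Total lateral input when every neighbor responds with 1.0
l2_total = l2.sum(axis=1).reshape(Ky, Kx)
l1_total = l1.sum(axis=1).reshape(Ky, Kx)
print(l2_total[20, 20], l2_total[0, 20], l2_total[0, 0])   # interior, edge, corner: about 2.45, 1.87, 1.41
print(l1_total[20, 20], l1_total[0, 20], l1_total[0, 0])   # 2.4 for all three, up to rounding
```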

In [13]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 10
T_s = 10_000

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

lat_ex = (lat_ex / cp.linalg.norm(lat_ex, axis=1).reshape(-1, 1)) * 0.5

for ep in range(Nep):
    inputs = gp_data
    o = cp.zeros((N, 1))
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o += ((mask * p) - o) * 0.5
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:06<00:00, 1488.68it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1489.11it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1490.91it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1486.01it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1487.72it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1487.71it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1488.22it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1478.79it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1475.74it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1486.93it/s]


I'm going to try something different.  I'm going to decrease `ex` and see how that affects things.
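For reference, here is how `ex` changes things under the L2 row normalization used in these cells: an interior neuron has (2ex + 1)^2 - 1 neighbors, each connection gets weight 0.5 / sqrt(n), and the total input from equally active neighbors is 0.5 sqrt(n). A quick check in plain Python (arithmetic only, not taken from the runs):

```python
import math

# Interior neighbors, per-connection weight, and total lateral input (all
# neighbors responding with 1.0) under the 0.5 * row / ||row|| normalization
results = {}
for ex in (1, 2, 3):
    n = (2 * ex + 1) ** 2 - 1          # square window minus the neuron itself
    weight = 0.5 / math.sqrt(n)        # every entry of an L2-normalized 0/1 row
    total = n * weight                 # equals 0.5 * sqrt(n)
    results[ex] = (n, weight, total)
    print(f"ex={ex}: {n} neighbors, weight {weight:.3f}, total {total:.2f}")
```

This gives 8, 24 and 48 neighbors with per-connection weights of about 0.177, 0.102 and 0.072. At `ex = 2` the normalization lands almost exactly on the hand-tuned 0.1, while the total input still grows with `ex` (about 1.41, 2.45 and 3.46).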

In [14]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 10
T_s = 10_000

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 1

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

lat_ex = (lat_ex / cp.linalg.norm(lat_ex, axis=1).reshape(-1, 1)) * 0.5

for ep in range(Nep):
    inputs = gp_data
    o = cp.zeros((N, 1))
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o += ((mask * p) - o) * 0.5
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:06<00:00, 1486.56it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1484.99it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1490.78it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1481.96it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1486.91it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1488.52it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1477.72it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1478.98it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1480.07it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1471.38it/s]


Ok, obviously still strong topology, but it's definitely more local.  I'm going to set `ex = 3` and see what happens then.

In [15]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 10
T_s = 10_000

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = img_count
ex = 3

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

lat_ex = (lat_ex / cp.linalg.norm(lat_ex, axis=1).reshape(-1, 1)) * 0.5

for ep in range(Nep):
    inputs = gp_data
    o = cp.zeros((N, 1))
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o += ((mask * p) - o) * 0.5
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)


100%|██████████| 10000/10000 [00:06<00:00, 1458.04it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1458.40it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1465.55it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1464.97it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1469.21it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1474.84it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1478.53it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1476.77it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1476.03it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1474.44it/s]


Well, I think what's happening is that the larger field of excitation is causing some neurons to win way more frequently, which forces them to localize to a much smaller prototype.  Interesting.  I'm going to do the animation and see what it looks like.
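One way to test this guess would be to count how often each neuron wins during a pass over the frames and compare runs with `ex = 1`, `2`, and `3`. A NumPy sketch; `win_counts` and `top_share` are hypothetical helpers, the ranking mirrors the cells above (winners chosen by `c = p + lat_ex @ p`; the EMA only smooths the output, so it doesn't change who wins), and the random toy data only checks that the code runs.

```python
import numpy as np

def win_counts(w, lat_ex, data, n_w):
    """Count how many times each neuron is among the n_w winners over `data`
    (shape (T, m_len)), ranking by c = p + lat_ex @ p as in the training loop."""
    counts = np.zeros(w.shape[0], dtype=int)
    for v in data:
        p = w @ v.reshape(-1, 1)
        c = p + lat_ex @ p
        winners = np.argsort(c, axis=0)[-n_w:].ravel()
        counts[winners] += 1
    return counts

def top_share(counts, frac=0.05):
    """Fraction of all wins taken by the most frequent `frac` of neurons."""
    k = max(1, int(len(counts) * frac))
    return np.sort(counts)[-k:].sum() / counts.sum()

# Toy run on random data
rng = np.random.default_rng(0)
N, m_len, T = 100, 16, 500
w = rng.uniform(0.15, 0.16, (N, m_len))
data = rng.uniform(0, 1, (T, m_len))
counts = win_counts(w, np.zeros((N, N)), data, n_w=5)
print(counts.sum(), top_share(counts))
```

With the trained CuPy arrays you would pass `w.get()` and `lat_ex.get()` along with `ts_data`. If the guess is right, `top_share` should grow as `ex` increases.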

In [16]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 5))

w_np = w.get()
lx_np = lat_ex.get()

ims = []
o = np.zeros((N, 1))
for i in tqdm(range(500)):
    v = ts_data[i].reshape(-1, 1)
    
    p = w_np @ v
    c = p + (lx_np @ p)
    winners = np.argsort(c, axis=0)[-n_w:]
    mask = np.zeros((N, 1))
    mask[winners] = 1
    o += ((mask * p) - o) * 0.5
    r = w_np.T @ o

    o_oo = o / max(prec, np.max(o))
    r = r / max(prec, np.max(r))
    
    mini_tap = np.zeros((40, 80))
    
    mini_tap[:20, :20] = v.reshape(20, 20)
    mini_tap[:, 20:60] = o_oo.reshape(40, 40)
    mini_tap[:20, -20:] = r.reshape(20, 20)
    
    im = plt.imshow(mini_tap, cmap="gray_r", animated=True)
    ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=150, blit=True,
                                repeat_delay=500)

plt.xticks([])
plt.yticks([])

plt.show()


100%|██████████| 500/500 [00:01<00:00, 379.07it/s]


This is absolutely fascinating.  Ok, I think the thing to do now is to train this on pure digits.  Here we go.

In [17]:
gp_tr = cp.asarray(x_tr.reshape(-1, 28 * 28))

In [18]:
fig = plt.figure(figsize=(10, 10))
draw_weights(gp_tr.get(), 10, 10, 28, fig)


Awesome.  Here we go.

In [21]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 12
T_s = 10_000

sl = 28

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = 60_000
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_tr[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        p += lat_ex @ p
        winners = cp.argsort(p, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)

100%|██████████| 10000/10000 [00:06<00:00, 1429.42it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1429.00it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1420.51it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1419.31it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1426.13it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1427.59it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1424.04it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1423.48it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1428.80it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1425.63it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1428.63it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1429.94it/s]


In [28]:
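def reconstruct(img_i, w, gp_tr, n_w):
    """Plot digit `img_i` next to its reconstruction from the n_w winners.

    Relies on the globals `lat_ex` and `N` from the training cell. Note that it
    adds the lateral excitation into `p` itself (the In [21] rule), so the
    winners' outputs, and hence the reconstruction, include the excitation even
    for networks trained with a separate competition value.
    """
    v = gp_tr[img_i].reshape(-1, 1)
    p = w @ v
    p += lat_ex @ p
    winners = cp.argsort(p, axis=0)[-n_w:]
    mask = cp.zeros((N, 1))
    mask[winners] = 1
    o = mask * p
    r = w.T @ o

    # Input on the left, reconstruction on the right (one shared color scale)
    tap = np.zeros((28, 2 * 28))
    tap[:, :28] = v.reshape(28, 28).get()
    tap[:, 28:] = r.reshape(28, 28).get()

    plt.figure(figsize=(4, 1))

    plt.xticks([])
    plt.yticks([])

    plt.imshow(tap, cmap='gray_r')
    plt.show()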

In [26]:
offset = 100

for i in range(10):
    reconstruct(offset + i, w, gp_tr, n_w)

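Reconstructions are pretty darn good, but they're too heavy. Luckily, I think I know what the problem is: I wasn't using a designated competition value. The lateral excitation was added straight into `p`, so it decided the winners *and* inflated their outputs, which makes the reconstruction `w.T @ o` heavier. In the next run the excitation only goes into a separate competition value `c`, and the outputs stay `mask * p`.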
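A minimal NumPy sketch of the difference, under toy assumptions (16 neurons, 9 inputs, and a uniform all-to-all excitation of 0.1 instead of the grid neighbourhood; the helper `kwta_outputs` is hypothetical). Both rules pick the same winners; they differ only in whether the excitation also ends up in the outputs, and therefore in the reconstruction.

```python
import numpy as np

def kwta_outputs(w, v, lat_ex, n_w, separate_competition=True):
    """Hypothetical helper: outputs of one k-winners-take-all step.

    Winners are always ranked by c = p + lat_ex @ p. With
    separate_competition=True the winners output their raw drive p
    (the In [27] rule); otherwise they output c (the In [21] rule).
    """
    p = w @ v                          # feedforward drive, (N, 1)
    c = p + lat_ex @ p                 # drive plus lateral excitation
    winners = np.argsort(c, axis=0)[-n_w:]
    mask = np.zeros_like(p)
    mask[winners] = 1
    return mask * (p if separate_competition else c)


# Toy setup: 16 neurons, 9 inputs, uniform all-to-all excitation of 0.1
rng = np.random.default_rng(0)
N, m_len, n_w = 16, 9, 3
w = rng.uniform(0.15, 0.16, (N, m_len))
v = rng.uniform(0, 1, (m_len, 1))
lat_ex = np.full((N, N), 0.1)
np.fill_diagonal(lat_ex, 0)

o_sep = kwta_outputs(w, v, lat_ex, n_w, separate_competition=True)
o_mix = kwta_outputs(w, v, lat_ex, n_w, separate_competition=False)
r_sep = w.T @ o_sep                    # reconstruction, In [27] rule
r_mix = w.T @ o_mix                    # reconstruction, In [21] rule
print(r_mix.sum() / r_sep.sum())       # a ratio above 1
```

Since all weights and inputs are positive, `lat_ex @ p` is positive for every winner, so the mixed rule's reconstruction is heavier pixel by pixel, not just on average.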

In [27]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 12
T_s = 10_000

sl = 28

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = 60_000
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_tr[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)  # Competition value only; the outputs below keep the raw p
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)

100%|██████████| 10000/10000 [00:07<00:00, 1426.70it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1426.84it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1424.15it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1426.29it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1428.29it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1431.89it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1429.27it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1425.50it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1429.77it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1429.19it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1437.21it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1435.37it/s]


In [30]:
offset = 0

for i in range(10):
    reconstruct(offset + i, w, gp_tr, n_w)

Wow, the winners are way too heavy. I wonder why that is. As a final test before I wrap this up, I'm going to try training this for a bunch more epochs (20 instead of 12).

In [31]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 20
T_s = 10_000

sl = 28

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = 60_000
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_tr[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)

100%|██████████| 10000/10000 [00:06<00:00, 1436.31it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1444.45it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1441.12it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1452.45it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1442.55it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1442.88it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1442.64it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1442.26it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1445.27it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1441.21it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1431.13it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1435.17it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1442.71it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1444.78it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1440.16it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1439.72it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1444.11it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1440.

The network seems determined not to move away from the prototypes. Why would that be? I'm going to train the same network with no lateral excitation, just to see what happens.

In [32]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 20
T_s = 10_000

sl = 28

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = 60_000
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03


for ep in range(Nep):
    inputs = gp_tr[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        winners = cp.argsort(p, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)

100%|██████████| 10000/10000 [00:05<00:00, 1730.54it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1737.62it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1733.24it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1735.67it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1738.73it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1737.69it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1733.26it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1731.43it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1731.26it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1727.85it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1727.08it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1727.44it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1727.63it/s]
100%|██████████| 10000/10000 [00:05<00:00, 1731.82it/s]
 79%|███████▉  | 7939/10000 [00:04<00:01, 1707.92it/s]


KeyboardInterrupt: 

OOOOH. So it's not that the lateral excitation is messing stuff up; the network just isn't training fast enough. Well boys, *that's* an easy problem to solve: bump the learning rate `xi` from 0.03 to 0.1.

In [33]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 20
T_s = 10_000

sl = 28

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = 60_000
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.1

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_tr[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
    draw_weights(w.get(), Kx, Ky, sl, fig)

100%|██████████| 10000/10000 [00:07<00:00, 1422.48it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1425.94it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1429.94it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1429.95it/s]
 43%|████▎     | 4316/10000 [00:03<00:03, 1425.27it/s]


KeyboardInterrupt: 

...?! What the flippety flip?

Time to clippety clip! Same learning rate, but now the weights get clamped to [0, 1] after every update.
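For reference, here is the training step pulled out into a single NumPy function, as a sketch (the name `mwta_step`, the NumPy port in place of CuPy, and the toy sizes in the usage lines are assumptions, not part of the original notebook). The multiplicative update only touches rows where `o > 0`, i.e. the winners, and with a large `xi` one step can push a weight outside [0, 1]; the final `np.clip` is what bounds it.

```python
import numpy as np

def mwta_step(w, v, lat_ex, n_w, xi, prec=1e-10, clip=True):
    """One training step (hypothetical NumPy port); updates w in place.

    w: (N, m_len) weights, v: (m_len, 1) input, lat_ex: (N, N) excitation.
    Returns the winners' outputs o (N, 1) and the reconstruction r (m_len, 1).
    """
    p = w @ v                          # feedforward drive
    c = p + lat_ex @ p                 # competition value only
    winners = np.argsort(c, axis=0)[-n_w:]
    mask = np.zeros_like(p)
    mask[winners] = 1
    o = mask * p                       # only the n_w winners stay active
    r = w.T @ o                        # reconstruction of the input
    mod_r = np.maximum(r, prec)        # guard against division by zero
    e = np.clip(v - r, -1, 1)          # clipped reconstruction error
    w += w * o * (e / mod_r).T * xi    # multiplicative update, winner rows only
    if clip:
        np.clip(w, 0, 1, out=w)        # hard bounds on every weight
    return o, r


# Usage on random toy data (no lateral excitation, deliberately large xi)
rng = np.random.default_rng(1)
N, m_len = 16, 9
w = rng.uniform(0.15, 0.16, (N, m_len))
lat_ex = np.zeros((N, N))
for _ in range(200):
    v = rng.uniform(0, 1, (m_len, 1))
    o, r = mwta_step(w, v, lat_ex, n_w=3, xi=0.5)
print(w.min() >= 0, w.max() <= 1)  # True True
```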

In [34]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 20
T_s = 10_000

sl = 28

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = 60_000
ex = 2

prec = 1e-10

n_w = 5 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.1

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_tr[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
        cp.clip(w, 0, 1, out=w)  # Keep every weight in [0, 1]
        
    draw_weights(w.get(), Kx, Ky, sl, fig)

100%|██████████| 10000/10000 [00:07<00:00, 1372.43it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1369.48it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1373.15it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1368.57it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1365.05it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1365.89it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1358.36it/s]
 49%|████▉     | 4941/10000 [00:03<00:03, 1345.12it/s]


KeyboardInterrupt: 

Hmm, I'm just going to try more winners: `n_w` goes from 5 to 10, `xi` goes back down to 0.03, and the weight clipping stays.
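Since `n_w` is the knob being turned here, a side note on winner selection: ranking all N = 1600 neurons with a full `argsort` is more work than needed, because only the top `n_w` matter. Below is a NumPy sketch using `np.argpartition` instead (the helper `winner_mask` is hypothetical). CuPy also provides `cupy.argpartition`, though whether it beats `cupy.argsort` on the GPU is untested here.

```python
import numpy as np

def winner_mask(c, n_w):
    """(N, 1) mask with ones at the n_w largest entries of c (shape (N, 1)).

    np.argpartition only guarantees that the last n_w indices point at the
    n_w largest values (in no particular order), which is all a mask needs.
    """
    flat = c.ravel()
    winners = np.argpartition(flat, -n_w)[-n_w:]
    mask = np.zeros(flat.shape)
    mask[winners] = 1
    return mask.reshape(c.shape)


c = np.array([[0.2], [0.9], [0.1], [0.7], [0.5]])
print(winner_mask(c, 2).ravel())  # [0. 1. 0. 1. 0.]
```

The resulting mask is the same one the `argsort` version builds, so it can be swapped in without changing the learning rule.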

In [35]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 20
T_s = 10_000

sl = 28

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl ** 2
tr_len = 60_000
ex = 2

prec = 1e-10

n_w = 10 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.03

lat_ex = []

for y in range(Ky):
    for x in range(Kx):
        curr_ex = np.zeros((Ky, Kx))
        curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 0.1
        lat_ex.append(curr_ex.reshape(-1))
    
lat_ex = cp.array(lat_ex)
cp.fill_diagonal(lat_ex, 0)

for ep in range(Nep):
    inputs = gp_tr[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
        cp.clip(w, 0, 1, out=w)
        
    draw_weights(w.get(), Kx, Ky, sl, fig)

100%|██████████| 10000/10000 [00:07<00:00, 1364.40it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1366.76it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1365.36it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1349.10it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1363.38it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1362.90it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1365.03it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1361.63it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1366.66it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1366.18it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1363.10it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1343.95it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1358.33it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1367.56it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1370.24it/s]
 34%|███▍      | 3418/10000 [00:02<00:05, 1308.55it/s]


KeyboardInterrupt: 

In [37]:
offset = 10

for i in range(10):
    reconstruct(offset + i, w, gp_tr, n_w)

OK, I'm convinced.

## Conclusions

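Lateral excitation is my best frikin friend. It gives me an excellent way to organize the sparse feature space: neighbouring neurons on the 40x40 grid boost each other during the competition, so neurons that are close on the grid tend to win together and learn related prototypes. I think this will be incredibly helpful for my invariant layer. Before I mess with the invariant layer, however, I straight-up need to see the topology that's learned for CIFAR-10.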
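The mechanism behind this is the `lat_ex` matrix. As a sketch, here is a vectorized NumPy version of the nested loop that builds it (the helper name `build_lat_ex` is hypothetical). Row i holds how much neuron i is excited by every other neuron: 0.1 for any neuron within `ex` steps in both x and y (a (2 * ex + 1)-wide square, clipped at the edges) and 0 for itself. Because the neighbourhoods come straight from grid coordinates, every neuron, including those in the last row and column, gets its full edge-clipped square.

```python
import numpy as np

def build_lat_ex(Kx, Ky, ex, strength=0.1):
    """Hypothetical vectorized version of the lat_ex loop.

    Neurons are indexed row-major, i = y * Kx + x, matching reshape(-1)
    on a (Ky, Kx) grid. Entry [i, j] is `strength` if neurons i and j are
    within `ex` steps in both x and y (Chebyshev distance), else 0.
    """
    ys, xs = np.divmod(np.arange(Kx * Ky), Kx)   # grid coordinates per neuron
    dy = np.abs(ys[:, None] - ys[None, :])
    dx = np.abs(xs[:, None] - xs[None, :])
    lat = np.where((dy <= ex) & (dx <= ex), strength, 0.0)
    np.fill_diagonal(lat, 0)                     # no self-excitation
    return lat


lat = build_lat_ex(40, 40, 2)
print(lat.shape)                   # (1600, 1600)
print(np.count_nonzero(lat[820]))  # 24: neuron (y=20, x=20) has a full 5x5 square
print(np.count_nonzero(lat[0]))    # 8: corner neuron, square clipped to 3x3
print(np.count_nonzero(lat[-1]))   # 8: the opposite corner
```

The matrix is symmetric, so "i excites j" and "j excites i" always come as a pair; that mutual boost is what lets neighbouring neurons win together.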


## Next steps

Train this same architecture on CIFAR-10. That's the only way to know. After I've done that, figure out how this helps the invariant layer.