# Topo MWTA Cifar Color

## Intro
* **Date**: 1/17/2021
* **What**: Literally just standard topo mwta, but I'm training it on Cifar10 colored 10x10 sections.  I simply yearn to see what the topology of the colored prototypes looks like.
* **Why**: Because I'm a child and I want to see pretty pictures.
* **Hopes**: This is incredibly qualitative.  I literally want to see the pretty pictures.  And hey, mother fucker?  You do to.  Trust me.  Whose been staring at topologically organized feature-prototype spaces, you or me?  Yeah, that's right.  Mama knows best.
* **Limitations**: Oh lord only knows.  Given how experiments have been going, a demi-gorgon might literally erupt from my screen at any second and eat my face of, thus ruining the experiment.

## Code

In [1]:
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from tensorflow.keras.datasets import cifar10
from tqdm import tqdm

(tr_x, _), (te_x, _) = cifar10.load_data()

tr_x = np.moveaxis(tr_x, 1, 3)
te_x = np.moveaxis(te_x, 1, 3)

In [2]:
s_len = 10
slh = s_len // 2
sl2 = s_len * s_len

data = tr_x[:, 16 - slh: 16 + slh, 16 - slh: 16 + slh]
flat_data = data.reshape(-1, sl2 * 3) / 255.0
fd_len = flat_data.shape[0]

te_data = te_x[:, 16 - slh: 16 + slh, 16 - slh: 16 + slh]
te_flat_data = te_data.reshape(-1, sl2 * 3) / 255.0

In [3]:
gp_data = cp.asarray(flat_data)

In [None]:
def draw_color_weights(w, Kx, Ky, fig, s_len):
    tapestry = np.zeros((s_len * Kx, s_len * Ky, 3))
    
    w_i = 0
    for y in range(Ky):
        for x in range(Kx):
            tapestry[y * s_len: (y + 1) * s_len, x * s_len: (x + 1) * s_len] = w[w_i].reshape(s_len, s_len, 3)
            w_i += 1
            
    plt.clf()        
    im = plt.imshow(tapestry / np.maximum(1e-10, np.max(tapestry)), vmin=0)
    plt.axis("off")
    fig.canvas.draw()

In [16]:
fig = plt.figure(figsize=(6, 6))
offset = 100
draw_color_weights(gp_data[offset:offset + 100].reshape(100, -1).get(), 10, 10, fig, 10)

<IPython.core.display.Javascript object>

In [8]:
def gen_lat_ex(Kx, Ky, ex, coeff):

    lat_ex = []

    for y in range(Ky):
        for x in range(Kx):
            curr_ex = np.zeros((Ky, Kx))
            curr_ex[max(0, y - ex): min(Ky, y + ex + 1), max(0, x - ex): min(Kx, x + ex + 1)] = 1

            if y - ex < 0:
                curr_ex[y - ex:, max(0, x - ex): min(Kx, x + ex + 1)] = 1
            if y + ex + 1 > Ky:
                curr_ex[:y + ex + 1 - Ky, max(0, x - ex): min(Kx, x + ex + 1)] = 1

            if x - ex < 0:
                curr_ex[max(0, y - ex): min(Ky, y + ex + 1), x - ex:] = 1
            if x + ex + 1 > Kx:
                curr_ex[max(0, y - ex): min(Ky, y + ex + 1), :x + ex + 1 - Kx] = 1

            if y - ex < 0 and x - ex < 0:
                curr_ex[y - ex:, x - ex:] = 1

            if x + ex + 1 > Kx and y + ex + 1 > Ky:
                curr_ex[:y + ex + 1 - Ky, :x + ex + 1 - Kx] = 1

            if y - ex < 0 and x + ex + 1 > Kx:
                curr_ex[y - ex:, :x + ex + 1 - Kx] = 1


            if x - ex < 0 and y + ex + 1 > Ky:
                curr_ex[:y + ex + 1 - Ky, x - ex:] = 1

            lat_ex.append(curr_ex.reshape(-1))

    lat_ex = cp.array(lat_ex)
    np.fill_diagonal(lat_ex, 0)

    lat_ex = (lat_ex / cp.linalg.norm(lat_ex, axis=1).reshape(-1, 1)) * coeff
    np.fill_diagonal(lat_ex, 1)
    
    return lat_ex

# Analysis Dialog

For some unholy reason, the axes were switched on the cifar10 data, so I just spent too much time (~3 min) getting that sorted out.  Aight bois, here we go.

In [9]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 20
T_s = 10_000

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl2 * 3
tr_len = 60_000
ex = 2
coeff = 0.5

prec = 1e-10

n_w = 10 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.05

lat_ex = gen_lat_ex(Kx, Ky, ex, coeff)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
        cp.clip(w, 0, 1, out=w)
        
    draw_color_weights(w.get(), Kx, Ky, fig, s_len)

<IPython.core.display.Javascript object>

100%|██████████| 10000/10000 [00:06<00:00, 1504.99it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1501.82it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1503.24it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1509.28it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1506.38it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1501.89it/s]
 45%|████▌     | 4522/10000 [00:03<00:03, 1472.65it/s]


KeyboardInterrupt: 

Yug.  I'm not seeing the pretty pictures I wanted to see.  I'm going to increase the number of winners to 30.

In [12]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 20
T_s = 10_000

Kx = 40
Ky = 40
N = Kx * Ky
m_len = sl2 * 3
tr_len = 60_000
ex = 2
coeff = 0.5

prec = 1e-10

n_w = 30 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.005

lat_ex = gen_lat_ex(Kx, Ky, ex, coeff)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
        cp.clip(w, 0, 1, out=w)
        
    draw_color_weights(w.get(), Kx, Ky, fig, s_len)

<IPython.core.display.Javascript object>

100%|██████████| 10000/10000 [00:06<00:00, 1545.62it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1554.01it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1495.71it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1483.28it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1542.96it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1485.69it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1486.61it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1480.76it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1483.85it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1479.17it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1480.82it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1478.61it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1553.90it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1551.52it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1552.54it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1564.13it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1563.08it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1561.

I think it's reasonable for me to want more than that.  I'm going to mess with `draw_color_weights`.

In [19]:
%matplotlib notebook
fig = plt.figure(figsize=(8, 8))

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
        cp.clip(w, 0, 1, out=w)
        
    draw_color_weights(w.get(), Kx, Ky, fig, s_len)

<IPython.core.display.Javascript object>

100%|██████████| 10000/10000 [00:06<00:00, 1559.12it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1552.05it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1554.17it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1555.70it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1557.98it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1547.97it/s]
 58%|█████▊    | 5765/10000 [00:03<00:02, 1530.66it/s]


KeyboardInterrupt: 

Yeah.  That's what I'm talking about, baby.  Ok, now that I've got that one looking pretty, I'm going to try 10 winners again, see how that goes.  Also, more neurons because why the fuck not?

In [21]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 10))

Nep = 20
T_s = 10_000

Kx = 50
Ky = 50
N = Kx * Ky
m_len = sl2 * 3
tr_len = 60_000
ex = 2
coeff = 0.5

prec = 1e-10

n_w = 10 #Number of winners

# Feedforward
w = cp.random.uniform(.15, .16, (N, m_len))
xi = 0.05

lat_ex = gen_lat_ex(Kx, Ky, ex, coeff)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
        cp.clip(w, 0, 1, out=w)
        
    draw_color_weights(w.get(), Kx, Ky, fig, s_len)

<IPython.core.display.Javascript object>

100%|██████████| 10000/10000 [00:07<00:00, 1342.76it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1346.24it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1351.16it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1355.00it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1347.28it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1349.67it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1349.46it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1349.80it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1350.19it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1349.46it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1346.47it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1354.05it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1346.77it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1346.07it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1343.06it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1343.93it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1349.10it/s]
  9%|▉         | 943/10000 [00:00<00:07, 1234.92

KeyboardInterrupt: 

You know what?  I might honestly like the feature space with 30 winners better.  I'm going to do a couple reconstructions though, see what we're looking at.

In [30]:
img_i = 100

v = gp_data[img_i].reshape(-1, 1)
p = w @ v
c = p + (lat_ex @ p)
winners = cp.argsort(c, axis=0)[-n_w:]
mask = cp.zeros((N, 1))
mask[winners] = 1
o = mask * p
r = w.T @ o

plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.xticks([])
plt.yticks([])
plt.imshow(v.reshape(s_len, s_len, 3).get())

plt.subplot(1, 2, 2)
plt.xticks([])
plt.yticks([])
plt.imshow(r.reshape(s_len, s_len, 3).get())

plt.show()

<IPython.core.display.Javascript object>

Fam, that's not fantastic.  I'll tell you that much.  I'm going to do 30 winners again, and see what we can get.  Also these feature spaces are looking so good I might send them right on over to Olshausen.

In [37]:
%matplotlib notebook
fig = plt.figure(figsize=(8, 8))

Nep = 20
T_s = 10_000

Kx = 50
Ky = 50
N = Kx * Ky
m_len = sl2 * 3
tr_len = 50_000
ex = 2
coeff = 0.5

prec = 1e-10

n_w = 30 #Number of winners

# Feedforward
w = cp.random.uniform(.21, .22, (N, m_len))
xi = 0.05

lat_ex = gen_lat_ex(Kx, Ky, ex, coeff)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
        cp.clip(w, 0, 1, out=w)
        
    draw_color_weights(w.get(), Kx, Ky, fig, s_len)

<IPython.core.display.Javascript object>

100%|██████████| 10000/10000 [00:07<00:00, 1286.38it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1285.43it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1289.26it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1282.75it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1342.24it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1346.45it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1345.01it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1345.24it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1342.00it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1347.14it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1345.02it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1347.86it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1347.91it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1348.09it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1345.66it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1350.53it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1348.01it/s]
 22%|██▏       | 2160/10000 [00:01<00:05, 1311.4

KeyboardInterrupt: 

In [43]:
img_i = 5

v = gp_data[img_i].reshape(-1, 1)
p = w @ v
c = p + (lat_ex @ p)
winners = cp.argsort(c, axis=0)[-n_w:]
mask = cp.zeros((N, 1))
mask[winners] = 1
o = mask * p
r = w.T @ o

plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.xticks([])
plt.yticks([])
plt.imshow(v.reshape(s_len, s_len, 3).get())

plt.subplot(1, 2, 2)
plt.xticks([])
plt.yticks([])
plt.imshow(r.reshape(s_len, s_len, 3).get())

plt.show()

<IPython.core.display.Javascript object>

You know what?  I'm going to train on 50 winners.

In [44]:
%matplotlib notebook
fig = plt.figure(figsize=(8, 8))

Nep = 20
T_s = 10_000

Kx = 50
Ky = 50
N = Kx * Ky
m_len = sl2 * 3
tr_len = 50_000
ex = 2
coeff = 0.5

prec = 1e-10

n_w = 50 #Number of winners

# Feedforward
w = cp.random.uniform(.21, .22, (N, m_len))
xi = 0.05

lat_ex = gen_lat_ex(Kx, Ky, ex, coeff)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
        cp.clip(w, 0, 1, out=w)
        
    draw_color_weights(w.get(), Kx, Ky, fig, s_len)

<IPython.core.display.Javascript object>

100%|██████████| 10000/10000 [00:07<00:00, 1296.86it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1295.01it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1296.39it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1295.10it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1297.57it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1301.82it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1294.80it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1298.73it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1299.69it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1300.54it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1299.19it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1301.09it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1298.24it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1296.33it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1301.35it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1293.58it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1302.35it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1304.

Hmm.  Ok, let's see what we can see with 40 winners.

In [45]:
%matplotlib notebook
fig = plt.figure(figsize=(8, 8))

Nep = 10
T_s = 10_000

Kx = 50
Ky = 50
N = Kx * Ky
m_len = sl2 * 3
tr_len = 50_000
ex = 2
coeff = 0.5

prec = 1e-10

n_w = 40 #Number of winners

# Feedforward
w = cp.random.uniform(.21, .22, (N, m_len))
xi = 0.05

lat_ex = gen_lat_ex(Kx, Ky, ex, coeff)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
        cp.clip(w, 0, 1, out=w)
        
    draw_color_weights(w.get(), Kx, Ky, fig, s_len)

<IPython.core.display.Javascript object>

100%|██████████| 10000/10000 [00:07<00:00, 1293.81it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1293.61it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1296.95it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1296.46it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1299.53it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1295.58it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1332.83it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1289.66it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1291.52it/s]
100%|██████████| 10000/10000 [00:07<00:00, 1291.56it/s]


In [54]:
img_i = 55

v = gp_data[img_i].reshape(-1, 1)
p = w @ v
c = p + (lat_ex @ p)
winners = cp.argsort(c, axis=0)[-n_w:]
mask = cp.zeros((N, 1))
mask[winners] = 1
o = mask * p
r = w.T @ o

plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.xticks([])
plt.yticks([])
plt.imshow(v.reshape(s_len, s_len, 3).get())

plt.subplot(1, 2, 2)
plt.xticks([])
plt.yticks([])
plt.imshow(r.reshape(s_len, s_len, 3).get())

plt.show()

<IPython.core.display.Javascript object>

The reconstructions still aren't that great.  I think maybe training this on HSV might be better.  It looks like it's getting a good portion of the structure correct, just not the colors.

Ok, as a final measure, I'm going to train this on 10,000 neurons.  My fav.

In [55]:
%matplotlib notebook
fig = plt.figure(figsize=(8, 8))

Nep = 10
T_s = 10_000

Kx = 100
Ky = 100
N = Kx * Ky
m_len = sl2 * 3
tr_len = 50_000
ex = 2
coeff = 0.5

prec = 1e-10

n_w = 40 #Number of winners

# Feedforward
w = cp.random.uniform(.21, .22, (N, m_len))
xi = 0.1

lat_ex = gen_lat_ex(Kx, Ky, ex, coeff)

for ep in range(Nep):
    inputs = gp_data[np.random.permutation(tr_len)]
    for i in tqdm(range(T_s)):
        v = inputs[i].reshape(-1, 1)
        p = w @ v
        c = p + (lat_ex @ p)
        winners = cp.argsort(c, axis=0)[-n_w:]
        mask = cp.zeros((N, 1))
        mask[winners] = 1
        o = mask * p
        r = w.T @ o
        mod_r = cp.maximum(r, prec)
        e = cp.clip(v - r, -1, 1)
        
        w += w * o * (e / mod_r).T * xi
        
        cp.clip(w, 0, 1, out=w)
        
    draw_color_weights(w.get(), Kx, Ky, fig, s_len)

<IPython.core.display.Javascript object>

100%|██████████| 10000/10000 [00:24<00:00, 411.37it/s]
100%|██████████| 10000/10000 [00:24<00:00, 415.40it/s]
100%|██████████| 10000/10000 [00:24<00:00, 415.57it/s]
100%|██████████| 10000/10000 [00:24<00:00, 415.44it/s]
100%|██████████| 10000/10000 [00:24<00:00, 415.77it/s]
100%|██████████| 10000/10000 [00:24<00:00, 415.43it/s]
100%|██████████| 10000/10000 [00:24<00:00, 415.31it/s]
100%|██████████| 10000/10000 [00:24<00:00, 415.08it/s]
100%|██████████| 10000/10000 [00:24<00:00, 411.97it/s]
100%|██████████| 10000/10000 [00:24<00:00, 411.71it/s]


In [56]:
fig = plt.figure(figsize=(10, 10))
draw_color_weights(w.get(), Kx, Ky, fig, s_len)

<IPython.core.display.Javascript object>

RECONSTRUCTIONS!!

In [79]:
img_i = 13

v = cp.asarray(te_flat_data[img_i].reshape(-1, 1))
p = w @ v
c = p + (lat_ex @ p)
winners = cp.argsort(c, axis=0)[-n_w:]
mask = cp.zeros((N, 1))
mask[winners] = 1
o = mask * p
r = w.T @ o

plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.xticks([])
plt.yticks([])
plt.imshow(v.reshape(s_len, s_len, 3).get())

plt.subplot(1, 2, 2)
plt.xticks([])
plt.yticks([])
plt.imshow(r.reshape(s_len, s_len, 3).get())

plt.show()

<IPython.core.display.Javascript object>

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Ok, that's pretty good. Unsurprisingly, the reconstructions get better as you add neurons.  :)

## Conclusions

Well dope.  They look cool.  In fact, the feature spaces look so good I might just send them over to Olshausen, see what he thinks.

Yeah that's what I'm going to do.

## Next steps

Email pretty pictures to Olshausen.  Also I need to implement the new sparsity model.  Threshold sparsity, baby.  Basically, we get sparsity, but there isn't a hard guarentee on the number of neurons that win.