In [11]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation


In [3]:
def build_circle_image(size=[64,128], offset=[0,0], radius=15,
                       fg_color=[1.,0.,0.], bg_color=[0.,0.,1.], to_int=True):
    """Render a single filled circle on a constant-colour background.

    Args:
        size: (H, W) of the image. Any length-2 sequence works (list, tuple or
            NumPy array); the image shape is always built explicitly as (H, W, 3).
        offset: (dy, dx) shift of the circle centre from the image centre
            ``(H // 2, W // 2)``.
        radius: pixels whose Euclidean distance to the centre is strictly less
            than ``radius`` are foreground.
        fg_color: RGB colour of the circle.
        bg_color: RGB colour of the background.
        to_int: if True, scale colours by 255 and cast to ``uint8``.

    Returns:
        image: (H, W, 3) array; ``uint8`` when ``to_int`` else floating point.
        labels: (H, W) ``uint8`` mask, 1 inside the circle and 0 outside.
    """
    H, W = size
    cy, cx = (np.array([H, W]) // 2) + np.array(offset)

    # Open (broadcastable) row/column grids instead of two full coordinate images.
    ys, xs = np.ogrid[:H, :W]
    dist = np.sqrt((ys - cy) ** 2 + (xs - cx) ** 2)
    labels = (dist < float(radius)).astype(np.uint8)

    # Explicit (H, W, 3) shape: `size + [3]` fails for tuples and silently
    # adds 3 elementwise when `size` is a NumPy array.
    image = np.ones((H, W, 3), dtype=np.float32) * np.array(bg_color).reshape((1, 1, 3))
    image[labels > 0] = fg_color
    if to_int:
        image = (image * 255.).astype(np.uint8)
    return image, labels

def build_toy_inputs(size=[512,1024], batch_size=1, param_dicts=None):
    """Build a batch of single-circle toy scenes via ``build_circle_image``.

    Args:
        size: (H, W) of every image in the batch.
        batch_size: number of scenes to render.
        param_dicts: optional list with one kwargs dict per scene, forwarded to
            ``build_circle_image`` (e.g. ``radius``, ``offset``). Entry ``i`` is
            used for scene ``i``; defaults to empty dicts.

    Returns:
        dict with
          'images':    stacked images cast to float32 (values in 0..255),
          'objects':   stacked uint8 circle masks,
          'semantics': the very same array object as 'objects'.
    """
    if param_dicts is None:
        param_dicts = [{} for _ in range(batch_size)]

    rendered = [build_circle_image(size, to_int=True, **param_dicts[i])
                for i in range(batch_size)]
    images = np.stack([pair[0] for pair in rendered], 0)
    masks = np.stack([pair[1] for pair in rendered], 0)

    return dict(images=images.astype(np.float32), objects=masks, semantics=masks)

In [321]:
# Three toy scenes at the same resolution, each containing one circle.
b = 1             # batch size
h, w = 64, 64     # image height, width
inp = build_toy_inputs(size=[h, w], batch_size=b,
                       param_dicts=[dict(radius=15)])
inp2 = build_toy_inputs(size=[h, w], batch_size=b,
                        param_dicts=[dict(radius=30, offset=[10, 10])])
inp3 = build_toy_inputs(size=[h, w], batch_size=b,
                        param_dicts=[dict(radius=10, offset=[-20, -20])])

# Images are stored as float32 in 0..255; rescale to [0, 1] for display.
plt.imshow(inp3['images'][0] / 255.)
plt.show()

In [412]:
def softmax(x, beta=1.0, axis=-1):
    """Temperature-scaled softmax ``exp(beta*x) / sum(exp(beta*x))`` along ``axis``.

    The max-shift that keeps ``exp`` from overflowing is applied to ``beta * x``
    (after scaling). Shifting ``x`` first and then multiplying by a negative
    ``beta`` would turn the shift into large positive exponents and overflow.

    Args:
        x: array of logits.
        beta: inverse temperature (any sign).
        axis: axis to normalise over.

    Returns:
        Array of the same shape as ``x`` whose entries along ``axis`` sum to 1.
    """
    z = beta * np.asarray(x)
    z = z - np.max(z, axis=axis, keepdims=True)
    probs = np.exp(z)
    return probs / np.sum(probs, axis=axis, keepdims=True)

def sigmoid(x, beta=1.0):
    """Logistic function with slope ``beta``: ``1 / (1 + exp(-beta * x))``."""
    decay = np.exp(-1. * beta * x)
    return 1. / (decay + 1.)

def normalize_pos(x, axis=-1, eps=1e-8):
    """Rectify (clip negatives to 0) and then L2-normalise along ``axis``.

    ``eps`` is added to the norm so an all-zero slice maps to zeros instead
    of dividing by zero.
    """
    rectified = np.maximum(x, 0.)
    length = np.sqrt(np.square(rectified).sum(axis=axis, keepdims=True))
    return rectified / (length + eps)

def get_local_kernel_mask(size=[64,64], k=5):
    """Boolean neighbourhood mask between all pixels of an (H, W) grid.

    Pixels are flattened in row-major order (flat index ``y * W + x``). Entry
    ``[p, q]`` is True iff the Euclidean distance between pixels p and q is
    strictly less than ``k``.

    Args:
        size: (H, W) of the grid.
        k: neighbourhood radius in pixels.

    Returns:
        (H*W, H*W) bool array.
    """
    _h, _w = size
    # Row / column coordinate of every flattened pixel.
    rows, cols = np.divmod(np.arange(_h * _w), _w)
    # Squared distances built from two (HW, HW) outer differences. This avoids
    # materialising the (H, W, H, W, 2) integer difference tensor (~268 MB at
    # 64x64) and its square; the values and the final comparison are identical.
    dist2 = np.subtract.outer(rows, rows) ** 2
    dist2 += np.subtract.outer(cols, cols) ** 2
    return np.sqrt(dist2) < k


# Per-pixel object id as a 3-bit code: bit k is set where the pixel lies inside
# circle k (inp -> 1, inp2 -> 2, inp3 -> 4), so overlaps get their own id.
obj = inp['objects'] + 2 * inp2['objects'] + 4 * inp3['objects']

# Same-id indicator between every pair of flattened pixels -> [b, N, N] float32.
adj = obj.reshape((b, -1))
adj = np.equal(adj[:, :, None], adj[:, None, :]).astype(np.float32)

# Keep only pairs closer than k=10 pixels; [None] adds a broadcastable batch axis.
mask = get_local_kernel_mask([h, w], k=10)[None]
adj_e = adj * mask             # same id and nearby
adj_i = mask * (1. - adj_e)    # different id but nearby
adj_e2 = adj_e @ adj_e         # two-step paths through adj_e
adj_ei = adj_e @ adj_i         # adj_e step followed by adj_i step (NOTE(review): not used by later visible cells)

N, Q = adj_e.shape[1], 8       # number of pixels/neurons, number of label channels

plt.imshow(obj[0])
plt.show()

<IPython.core.display.Javascript object>

In [500]:
from tqdm import tqdm

# Run the local excitation / inhibition label dynamics for `num_iters` steps,
# recording per-iteration histories in xs, avgs, es and acts.

# ---- hyper-parameters ---------------------------------------------------------
beta = 5.0        # inverse temperature passed to softmax()
seed = 2          # seeds every random draw in this cell (start neuron and random init)
num_iters = 25
excite = True     # add mean label vector of active adj_e senders
inhibit = True    # subtract mean label vector of active adj_i senders
push = True       # excite one label channel of inactive receivers that have active adj_i senders
smooth = True     # extra update along two-step adj_e paths (adj_e2)
damp = False
svd_init = False  # initialise label vectors from the SVD of adj_e instead of at random

# ---- initial state --------------------------------------------------------------
rng = np.random.RandomState(seed)
n0 = rng.choice(N)                               # index of the single initially active neuron
activated = np.zeros([b,N,1], dtype=np.float32)  # per-neuron activation count
activated[:,n0,:] = 1.0
act = activated

if svd_init:
    # Leading Q left-singular vectors of adj_e, rectified and L2-normalised.
    # Slice the SVD output `u`: `x` does not exist yet on a fresh kernel.
    u,s,vh = np.linalg.svd(adj_e)
    x = u[...,:Q]
    x = normalize_pos(x)
else:
    # Draw from the seeded `rng` (not the global np.random state) so that
    # `seed` actually makes the run reproducible.
    x = rng.uniform(size=[b,N,Q]).astype(np.float32)
    x = softmax(x, beta, axis=-1)
    x[:,n0] = (0 == np.arange(Q)).astype(np.float32)  # start neuron carries one-hot label 0

xs = [x + 0.]   # `+ 0.` stores a copy, since x is updated in place below

avg = np.sum(x * activated, axis=1, keepdims=True) / np.sum(activated, axis=1, keepdims=True)  # [B,1,Q]
avgs = [avg + 0.]
es = []
acts = [act + 0.]

for it in tqdm(range(num_iters)):
    # Active senders reaching each receiver, clamped to >= 1 so the means below never divide by 0.
    n_senders_e = np.maximum(1., np.sum(adj_e * act, axis=-2, keepdims=True)) # [B,1,N]
    n_senders_i = np.maximum(1., np.sum(adj_i * act, axis=-2, keepdims=True)) # [B,1,N]

    if excite:
        # Add the mean label vector of each receiver's active adj_e senders.
        e_effects = np.matmul(x.transpose(0,2,1), adj_e * act) / n_senders_e # [B,Q,N]
        x += e_effects.transpose(0,2,1)

    if inhibit:
        # Subtract the mean label vector of each receiver's active adj_i senders.
        i_effects = np.matmul(x.transpose(0,2,1), adj_i * act) / n_senders_i # [B,Q,N]
        x -= i_effects.transpose(0,2,1)

    if push:
#         sc = np.minimum(1., np.cumsum(avg[...,0:-1], -1))
#         sc = np.concatenate([np.zeros([b,1,1]), sc], -1)
        sc = 1. - avg
#         sc = np.sum(np.stack([1- a for a in avgs], 0), 0)
        r = np.arange(Q,0,-1).astype(float).reshape((1,1,Q))
        # Label preference: higher for channels with low average use among active
        # neurons (1 - avg), weighted toward lower channel indices (larger r).
        e = softmax(sc * r, beta=beta, axis=-1) # [b,1,Q]
        # With 0/1 act and adj_i this is 1 where a receiver has >= 1 active adj_i sender, else 0.
        p_effects = np.sum(adj_i * act, axis=1, keepdims=True) / n_senders_i # [B,1,N]
        # Only not-yet-active receivers (1 - act) get pushed along e.
        p_effects = p_effects * ((1. - act) * e).transpose(0,2,1)
        es.append(e + 0.)
        x += p_effects.transpose(0,2,1)

#     x = softmax(x, beta, axis=-1)
    x = normalize_pos(x)
    # Neurons connected (via adj_e or adj_i) to an active neuron get their activation count incremented.
    receivers = np.max(np.maximum(adj_e,adj_i) * activated, axis=1, keepdims=False)
    activated[np.where(receivers)] += 1.0
    act = np.minimum(1., activated)

    # smooth by double excitation path among all neurons?
    if smooth:
        s_effects = act * adj_e2 * act.transpose(0,2,1) # [B,S,R] == [B,N,N]
        s_effects = np.matmul(x.transpose(0,2,1), s_effects) / (n_senders_e**2) # [B,Q,N]
        x += s_effects.transpose(0,2,1)
        x = normalize_pos(x)

    avg = np.sum(x * act, axis=1, keepdims=True) / np.sum(act, axis=1, keepdims=True)

    if damp:
        x += (1. - avg) * sigmoid(-activated)
        x = softmax(x, axis=-1)

    xs.append(x + 0.)
    avgs.append(avg + 0.)
    acts.append(act + 0.)   # record the activation mask per iteration, alongside xs / avgs

plateau = np.reshape(xs[-1], (b,h,w,Q))

# plt.imshow(np.argmax(plateau[0], -1))
# plt.colorbar()
# plt.show()

100%|██████████| 25/25 [00:13<00:00,  1.82it/s]


In [501]:
%matplotlib nbagg
SAVE = True

def get_frame(xs, idx=0, b=0):
    """Per-pixel argmax label map of history entry ``xs[idx]`` for batch element ``b``.

    Uses the notebook globals ``h`` and ``w`` to reshape the flat neuron axis.
    """
    fr = xs[idx][b].reshape((h,w,-1))
    return np.argmax(fr, axis=-1)

idx = 0

fig, axes = plt.subplots(figsize=(4,4))
# Fix the colour range to all Q labels so every frame shares one colour map
# (otherwise the normalisation is inferred from the first frame only).
im = plt.imshow(get_frame(xs, idx, 0), origin='upper', vmin=0, vmax=Q - 1)

def animate(frame=None, *args):
    """FuncAnimation callback: draw one history frame and return the updated artist.

    When given a frame number (as FuncAnimation does), that history entry is
    shown; when called without one, the global counter ``idx`` is used. In
    both cases ``idx`` is left pointing at the next entry.
    """
    global idx
    if frame is not None:
        idx = frame % len(xs)
    im.set_array(get_frame(xs, idx))
    idx = (idx + 1) % len(xs)
    return im,

def savegif(ani, directory, file, fps=10):
    """Save ``ani`` as ``<directory>/<file stem>.gif`` using Pillow at ``fps`` frames/s."""
    import os
    # splitext strips only the final extension ('run.v2.gif' -> 'run.v2');
    # split('.')[0] would truncate at the first dot.
    stem = os.path.splitext(file)[0]
    f = os.path.join(directory, stem + ".gif")
    ani.save(f, writer=animation.PillowWriter(fps=fps))

# One frame per stored iteration, so a saved GIF shows each state exactly once,
# starting from iteration 0 regardless of where the live view left `idx`.
ani = animation.FuncAnimation(fig, animate, frames=len(xs), interval=400)
plt.colorbar()
plt.show()

if SAVE:
    # NOTE(review): absolute, user-specific output path; make configurable before sharing.
    savegif(ani, '/Users/db/neuroailab/', 'local_grouping2.gif')

<IPython.core.display.Javascript object>

In [504]:
# Trajectory of all Q label channels at pixel (i, j) of batch element 0,
# one column per stored iteration.
i,j = [60,10]
# Index the batch explicitly: reshaping a [b, N, Q] entry to (h, w, Q) fails for b > 1.
plt.imshow(np.stack([x_t[0].reshape(h, w, Q)[i, j] for x_t in xs], -1))
plt.xlabel('iteration')
plt.ylabel('label channel')
plt.title(f'label vector at pixel ({i}, {j})')
plt.colorbar()
plt.show()

<IPython.core.display.Javascript object>

In [490]:
# Push-label preference e ([b, 1, Q]) for batch element 0, one column per iteration.
if es:
    plt.imshow(np.stack([a[0,0,:] for a in es], -1))
    plt.xlabel('iteration')
    plt.ylabel('label channel')
    plt.title('push label preference e')
    plt.colorbar()
    plt.show()
else:
    # np.stack on an empty list would raise; es stays empty when push=False or num_iters=0.
    print('no push preferences recorded (push=False or num_iters=0)')

<IPython.core.display.Javascript object>

In [491]:
# Average label vector over active neurons ([b, 1, Q]) for batch element 0, per iteration.
plt.imshow(np.stack([a[0,0,:] for a in avgs], -1))
plt.xlabel('iteration')
plt.ylabel('label channel')
plt.title('mean label vector of active neurons')
plt.colorbar()
plt.show()

<IPython.core.display.Javascript object>

In [482]:
# False-colour composite of the final state (batch element 0): label channels cs -> R, G, B.
cs = [0,1,5]
pl = np.stack([plateau[0,...,c] for c in cs], -1)
plt.imshow(pl)
plt.title(f'final state, channels {cs} as RGB')
plt.show()

<IPython.core.display.Javascript object>

In [None]:
### TODO: push effects 
### TODO: initialization
### TODO: compare to hard and soft labelprop
### TODO: more complex images
### TODO: Tensorflow grnn

Push effects should speed up convergence by exciting a particular label channel of a neuron that is not yet activated but is currently being inhibited.

Which label should be excited? The next one down the order from the pushing neurons?

Or the one that is most common among its excitatory "skeleton" and is not the one from the pushing neurons?

Average map of excited neurons?

Check whether coalitions are really coalitions by looking at plateau map 

Make activated tensor be graded rather than binary to favor earlier/more amplified neurons?

let skeleton grow in scale over time due to "feedback"? 

A double inhibitory path should prevent mutually inhibiting nodes from pointing in the same direction — they have to be pushed apart somehow. Break the symmetry by giving some of them a stronger push?

Multiplexed code: run the dynamics more than once (twice for 8 × 8 labels, three times for 6 × 6 × 6 = 216 labels), each run with a different initialization, and then take the tensor product of the resulting channels?

In [379]:
6 ** 3  # number of distinct labels from a 6 x 6 x 6 product code

216

In [117]:
# Scratch check of the cumulative "push" schedule on the average label usage after
# the first iteration. Distinct names avoid clobbering the simulation cell's
# globals (u, s, sc, r, e).
avg1 = avgs[1]                                                        # [b, 1, Q]
sched = np.minimum(1., np.cumsum(avg1[..., 0:-1], -1))                # [b, 1, Q-1]
# Prepend a zero column sized from avg1, so this works for any batch size.
sched = np.concatenate([np.zeros(avg1.shape[:-1] + (1,)), sched], -1) # [b, 1, Q]
ranks = np.arange(Q, 0, -1).astype(float)
probs = softmax(sched * ranks, beta=1., axis=-1)
sched * ranks, probs

(array([[[0.        , 6.15600204, 5.39580309, 4.57645088, 3.72029305,
          2.85329515, 1.93036175, 0.9822765 ]]]),
 array([[[0.001165  , 0.54934183, 0.25685762, 0.11320153, 0.048087  ,
          0.02020669, 0.00802917, 0.00311116]]]))

In [298]:
# Leftover interactive help lookup (IPython `?` syntax, not plain Python); remove before sharing.
np.linalg.svd?