Goal
- Conv architecture for BagNet using 1x1 convolutions.
    - Verify the receptive field computation, etc.
- Implement functions to draw bounding boxes.

In [None]:
import os
# Runtime configuration for CUDA / TensorFlow / XLA. These assignments are
# placed before the `import jax` below so they are already in the environment
# when JAX is imported.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
# NOTE(review): the leading '=' looks like a typo for 'bfc_allocator=1';
# left unchanged -- confirm against the intended logging setup.
os.environ['TF_CPP_VMODULE'] = '=bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# Python does not expand shell-style '${VAR}' inside string literals, so the
# current value is read explicitly and the CUDA directory is appended to it.
# The membership test keeps re-running this cell from appending duplicates.
_cuda_lib_dir = '/usr/local/cuda/lib64'
_ld_library_path = os.environ.get('LD_LIBRARY_PATH', '')
if _cuda_lib_dir not in _ld_library_path.split(os.pathsep):
    os.environ['LD_LIBRARY_PATH'] = (
        _ld_library_path + os.pathsep + _cuda_lib_dir
        if _ld_library_path else _cuda_lib_dir)

import jax
import jax.numpy as np
from jax import random, vjp, vmap
from flax import linen as nn
from flax.core import freeze, unfreeze

import matplotlib.pyplot as plt
import matplotlib.patches as patches


In [None]:

def compute_receptive_fields(model_def, in_shape, spike_loc=None):
    """Measure a model's receptive field by back-propagating output spikes.

    The input and every parameter are set to one. A cotangent that is one at
    the selected output location(s) and zero elsewhere is pulled back through
    the model, and the extent of the non-zero input gradient is measured.

    Args:
        model_def: zero-argument callable returning an object that exposes
            ``init(rng, x)`` and ``apply(params, x)`` (e.g. a flax module
            class).
        in_shape: shape of the dummy input. Axis 0 is indexed at 0 and
            axes 1 and 2 are treated as (h, w).
        spike_loc: optional integer array-like of shape (n, 2) holding the
            (row, col) output positions to spike. Negative values index from
            the end. Defaults to the centre of the output.

    Returns:
        Tuple ``(rf, gx)``. ``rf`` is a length-2 array with the (h, w)
        extent of the non-zero region of ``gx``; ``gx`` is the gradient with
        respect to the input and has shape ``in_shape``.

    Raises:
        ValueError: if the input gradient is zero everywhere, so no extent
            can be measured.
    """
    x = np.ones(in_shape)
    model = model_def()
    params = model.init(random.PRNGKey(0), x)
    # tree_map keeps the container type returned by init, so no
    # freeze/unfreeze round-trip is needed to overwrite the leaves.
    params = jax.tree_util.tree_map(lambda w: np.ones(w.shape), params)

    # vjp: (f, x) -> (f(x), vjp_fn), where vjp_fn maps an output cotangent
    # to the corresponding input cotangent.
    y, vjp_fn = vjp(lambda inp: model.apply(params, inp), x)
    out_shape = y.shape

    if spike_loc is not None:
        spike_loc = np.asarray(spike_loc)
        ind = (0, spike_loc[:, 0], spike_loc[:, 1], Ellipsis)
    else:
        ind = (0, out_shape[1] // 2, out_shape[2] // 2, Ellipsis)
    # zeros_like keeps the cotangent dtype equal to the output dtype.
    gy = np.zeros_like(y).at[ind].set(1)

    gx = vjp_fn(gy)[0]
    nonzero = np.where(gx != 0)
    if nonzero[0].size == 0:
        raise ValueError(
            "input gradient is zero everywhere; cannot measure receptive field")
    # Only the spatial axes (1, 2) are reported: (y, x)
    rf = np.array([np.max(idx) - np.min(idx) + 1 for idx in nonzero[1:3]])
    return rf, gx


class CNN(nn.Module):
    """Two stacked stages of 3x3 conv -> ReLU -> 2x2 average pool (stride 2).

    The first stage produces 32 feature channels, the second 64.
    """

    @nn.compact
    def __call__(self, x):
        # Stages are built in the same order as the unrolled form, so the
        # submodules are created in the same sequence.
        for width in (32, 64):
            x = nn.Conv(features=width, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        return x
    

h, w = 28, 28
in_shape = (1, h, w, 1)
# (row, col) output positions one step in from each corner; negative values
# index from the end.
spike_loc = np.array([[1, 1], [1, -2], [-2, 1], [-2, -2]])
# Receptive-field size comes from the default (centre) spike; the gradient
# image that is plotted comes from the four corner spikes.
rf, _ = compute_receptive_fields(CNN, in_shape)
_, gx = compute_receptive_fields(CNN, in_shape, spike_loc)

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(gx.squeeze(), cmap='Greys')
r = rf // 2
# Rectangle takes xy = (x, y), i.e. (column, row): x pairs the image width
# with the RF width rf[1], y pairs the image height with the RF height rf[0].
# The -.5 shifts the box onto pixel edges (pixel centres sit on integers).
xy = (w // 2 - r[1] - .5, h // 2 - r[0] - .5)
rect = patches.Rectangle(xy, rf[1], rf[0],
    linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rect)

# Last expression of the cell: displays the measured (h, w) receptive field.
rf


In [None]:
# Quick check: initialise the CNN on a dummy 10x10 single-channel input and
# display the shape of its output.
key = random.PRNGKey(0)
model = CNN()
x = np.ones((1, 10, 10, 1))
params = model.init(key, x)
model.apply(params, x).shape