Goals
- Build a convolutional architecture for BagNet using 1x1 convolutions.
    - Verify the receptive-field computation, etc.
- Implement functions to draw bounding boxes.

In [None]:
import os

# Enumerate GPUs in PCI bus order (the order nvidia-smi uses) so the index in
# CUDA_VISIBLE_DEVICES refers to the intended physical device.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
# Verbose logging for TensorFlow's BFC allocator; the format is module=level.
os.environ['TF_CPP_VMODULE'] = 'bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda'
# Let JAX allocate GPU memory on demand instead of preallocating a large pool.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# Python does not expand "${...}" inside strings, so append to the existing
# value explicitly. NOTE: the dynamic loader reads LD_LIBRARY_PATH at process
# start, so this only affects child processes launched from here.
_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
os.environ['LD_LIBRARY_PATH'] = ':'.join(
    p for p in (_ld_path, '/usr/local/cuda/lib64') if p)

from functools import partial


import jax
import jax.numpy as np
from jax import random, vjp, vmap
from flax import linen as nn
from flax.core import freeze, unfreeze


import matplotlib.pyplot as plt
import matplotlib.patches as patches


In [None]:
from bagnet import bagnet33



# Measure the numerical receptive field of BagNet-33 on an all-ones probe.
# NOTE(review): compute_RF_numerical is defined in a later cell of this
# notebook; run that cell first or this raises NameError.
model = bagnet33(); model.eval()
h, w = 224, 224
img_np = np.ones((1, 3, h, w))  # all-ones 3-channel probe (np is jax.numpy here)
# NOTE(review): out_cnn_idx=1 indexes the model output with [1]; presumably the
# bagnet model returns a tuple/list whose element 1 is a feature map — confirm,
# since on a single (1, ...) tensor this would index batch 1 and fail.
rf, gx = compute_RF_numerical(model, img_np, out_cnn_idx=1)
print(rf)

fig, ax = plt.subplots(1,1,figsize=(10,10))
gx[gx!=0] = 1  # binarise: mark every input pixel that received gradient
ax.imshow(gx.squeeze(), cmap='Greys')


In [None]:
import numpy as onp
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

class CNN(nn.Module):
    """Four [Conv2d(k=4, s=2, p=1) -> BatchNorm2d -> ReLU] stages.

    Channel widths go 1 -> n_filters -> 2n -> 4n -> 8n; each stage halves
    the spatial size, so a (N, 1, 224, 224) input becomes
    (N, 8 * n_filters, 14, 14). Receptive field of the stack:
    1 + 3 * (1 + 2 + 4 + 8) = 46 pixels per side.

    ``num_classes`` is accepted but currently unused.
    """

    def __init__(self,
                 num_classes=1,
                 n_filters=16):
        super().__init__()
        widths = [1] + [n_filters * 2 ** k for k in range(4)]
        modules = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            modules += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        self.conv_blocks = nn.Sequential(*modules)

    def forward(self, x, output_feat=True):
        """Return the final feature map; ``output_feat`` is ignored."""
        return self.conv_blocks(x)


def cnn16(pretrained=False, **kwargs):
    """Build a CNN with 16 base filters.

    Any ``n_filters`` passed in ``kwargs`` is overridden with 16.

    Raises:
        ValueError: if ``pretrained`` is truthy (no pretrained weights exist).
    """
    if not pretrained:
        return CNN(**dict(kwargs, n_filters=16))
    raise ValueError('No pretrained model for CNN')

def compute_RF_numerical(net, img_np, out_cnn_idx=None):
    """Numerically measure the receptive field of a PyTorch CNN.

    Conv2d weights are set to 1 and biases to 0. BatchNorm2d layers get
    affine params 1/0 and freshly reset running statistics, and the whole
    network is switched to eval mode. As a result, every input pixel that can
    influence an output unit receives a non-zero gradient. A unit gradient is
    then pulled back from the spatial centre of the output (batch 0,
    channel 0), and the extent of the non-zero input gradient is measured.

    NOTE: this overwrites ``net``'s parameters and BatchNorm statistics in
    place and leaves ``net`` in eval mode.

    Args:
        net: a ``torch.nn.Module``.
        img_np: array-like input, converted to a float32 tensor; presumably
            (N, C, H, W) — confirm for non-image models.
        out_cnn_idx: if not None, ``net(img)[out_cnn_idx]`` is used as the
            output (e.g. when the network returns a tuple).

    Returns:
        (RF, grad_np): ``RF`` is a list with the receptive-field extent along
        each axis of ``grad_np``, or ``[0] * grad_np.ndim`` if no input pixel
        received gradient. ``grad_np`` is the input gradient for batch 0,
        channel 0, as a numpy array.
    """
    def weights_init(m):
        # Refer to torch.nn explicitly: the notebook later rebinds ``nn`` to
        # flax.linen, which would break isinstance checks resolved at call time.
        if isinstance(m, torch.nn.Conv2d):
            m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.fill_(0)
        elif isinstance(m, torch.nn.BatchNorm2d):
            if m.affine:
                m.weight.data.fill_(1)
                m.bias.data.fill_(0)
            # Pretrained running stats (e.g. a large mean) can push activations
            # negative, so ReLU blocks the gradient and the RF is underestimated.
            if m.track_running_stats:
                m.reset_running_stats()

    net.apply(weights_init)
    net.eval()  # BN uses (reset) running stats; Dropout becomes deterministic
    img = torch.from_numpy(onp.array(img_np)).float().requires_grad_(True)
    out_cnn = net(img)
    if out_cnn_idx is not None:
        out_cnn = out_cnn[out_cnn_idx]
    out_shape = out_cnn.size()
    # Index 0 for the batch and channel axes, the centre for every other axis.
    centre = tuple(0 if d < 2 else n // 2 for d, n in enumerate(out_shape))
    grad_out = torch.zeros_like(out_cnn)
    grad_out[centre] = 1
    print('outshape: ', out_shape)
    # autograd.grad returns only d(out)/d(img); unlike backward() it does not
    # accumulate into the parameters' .grad buffers.
    (grad_img,) = torch.autograd.grad(out_cnn, img, grad_outputs=grad_out)
    grad_np = grad_img[0, 0].detach().numpy()
    nonzero = onp.nonzero(grad_np)
    if nonzero[0].size == 0:
        return [0] * grad_np.ndim, grad_np
    RF = [int(idx.max() - idx.min() + 1) for idx in nonzero]
    return RF, grad_np



# Probe the receptive field of the 16-filter PyTorch CNN and plot the raw
# input-gradient map of the centre output unit.
model = cnn16()
h = w = 224
img_np = np.ones((1, 1, h, w))  # single-channel all-ones probe image
rf, gx = compute_RF_numerical(model, img_np)
print(rf)

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(gx.squeeze(), cmap='Greys')  # raw gradient magnitudes, not binarised

model

In [None]:

import jax
import jax.numpy as np
from jax import random, vjp, vmap
from flax import linen as nn
from flax.core import freeze, unfreeze


def compute_receptive_fields(model_def, in_shape, spike_loc=None):
    """Compute receptive fields of a Flax-style model via a VJP.

    The model is initialised on an all-ones input and every parameter leaf is
    replaced by ones. As a result, each input pixel that can reach an output
    unit gets a non-zero gradient. A unit cotangent ("spike") is placed on the
    output and pulled back to the input with ``jax.vjp``.

    Args:
        model_def: zero-argument callable returning an object with
            ``init(rng, x)`` / ``apply(params, x)`` (e.g. a Flax module class).
        in_shape: input shape; assumed (N, H, W, C) — axes 1 and 2 are
            treated as (y, x).
        spike_loc: optional (K, 2) integer array of output (y, x) positions
            (negative indices allowed). Defaults to the output centre. Spikes
            are placed in batch 0 across all output channels.

    Returns:
        (rf, gx): ``rf`` is an int array [height, width] of the bounding box
        of all input positions with non-zero gradient, or zeros if there are
        none. With several spikes this is the extent of their union, not a
        per-spike RF. ``gx`` is the input gradient, with shape ``in_shape``.
    """
    x = np.ones(in_shape)
    model = model_def()
    params = model.init(random.PRNGKey(0), x)
    # Replace every parameter by ones; tree_map keeps the container type.
    params = jax.tree_util.tree_map(np.ones_like, params)
    # vjp: (f, x) -> (f(x), vjp_fn) with vjp_fn(v) = (∂f(x)ᵀ v,)
    y, vjp_fn = vjp(lambda inp: model.apply(params, inp), x)
    S = y.shape
    gy = np.zeros(S)
    if spike_loc is not None:
        gy = gy.at[0, spike_loc[:, 0], spike_loc[:, 1]].set(1)
    else:
        gy = gy.at[0, S[1] // 2, S[2] // 2].set(1)
    (gx,) = vjp_fn(gy)
    rows, cols = np.nonzero(gx)[1:3]
    if rows.size == 0:
        return np.zeros(2, dtype=np.int32), gx
    rf = np.array([rows.max() - rows.min() + 1,
                   cols.max() - cols.min() + 1])  # (y, x)
    return rf, gx


class CNNmnist(nn.Module):
    """Small Flax CNN: two [3x3 conv -> ReLU -> 2x2 average pool] stages
    with 32 and then 64 features."""

    @nn.compact
    def __call__(self, x):
        for features in (32, 64):
            h = nn.Conv(features=features, kernel_size=(3, 3))(x)
            x = nn.avg_pool(nn.relu(h), window_shape=(2, 2), strides=(2, 2))
        return x
    

class CNNcxr(nn.Module):
    """Flax CNN: four stride-2, 4x4 convolutions (16 -> 32 -> 64 -> 128
    features), each followed by ReLU.

    Intended for channels-last input such as (1, 224, 224, 1); with Flax's
    default 'SAME' padding each stage halves the spatial size.
    """

    @nn.compact
    def __call__(self, x):
        for features in (16, 32, 64, 128):
            conv = nn.Conv(features=features, kernel_size=(4, 4), strides=(2, 2))
            x = nn.relu(conv(x))
        return x

    
# Pick the model to probe; uncomment the first line to try CNNmnist instead.
# model_def, (h, w) = CNNmnist, (28, 28)
model_def = CNNcxr; h, w = 224, 224
in_shape = (1, h, w, 1)  # NHWC, single channel
# Example corner spikes (output coords; negative counts from the end) for
# compute_receptive_fields(..., spike_loc); not used in this cell.
spike_loc = np.array([[1, 1], [1, -2], [-2, 1], [-2, -2]])
# A single call yields both the RF size (y, x) and the input-gradient map.
rf, gx = compute_receptive_fields(model_def, in_shape)

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(gx.squeeze(), cmap='Greys')
r = rf // 2
# matplotlib anchors a Rectangle at (x, y), with width along x and height
# along y. rf is (y, x), so x uses w/r[1] and y uses h/r[0]. The -.5 aligns
# the box to pixel edges. Assumes the RF is roughly centred at (h//2, w//2).
xy = (w // 2 - r[1] - .5, h // 2 - r[0] - .5)
rect = patches.Rectangle(xy, rf[1], rf[0],
                         linewidth=1, edgecolor='r', facecolor='none')
# ax.add_patch(rect)

rf


In [None]:
# Probe a 46x46 input. That is the RF size of a 4-layer k=4/s=2 stack,
# 1 + 3 * (1 + 2 + 4 + 8) = 46, assuming model_def is CNNcxr from the
# previous cell.
in_shape = (1, 46, 46, 1)
x = np.ones(in_shape)
model = model_def()
key = random.PRNGKey(0)
params = model.init(key, x)
print(model.apply(params, x).shape)

# Spike a single output unit near the top-left corner.
spike_loc = np.array([[1, 1]])
_, gx = compute_receptive_fields(model_def, in_shape, spike_loc)
# Binarise: set every non-zero gradient entry to 1. jax.ops.index_update was
# removed from JAX; np.where is the equivalent functional form.
gx = np.where(gx != 0, 1, gx)

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(gx.squeeze(), cmap='Greys')
ax.grid()


In [None]:
# Shape of the binarised input-gradient map from the previous cell.
gx.shape