# Supervised learning of heterogeneous delays in a single layer of spiking neurons for ultrafast motion detection


We design a model based on heterogeneous delays. We will define in the assembly of neurons where each synapse
is defined by a weight *ws* at different delays *τs*.


In [6]:
datetag = '2023-09-22_FastMotionDetection'

## Initialization

Let's first initialize the framework.

In [7]:
# to install all dependencies use, uncomment the following line and restart the kernel
# %pip install -U -r requirements.txt

In [8]:
import os
HOST = os.uname()[1]
HOST

'CONEC-LID-001'

In [9]:
import matplotlib
import matplotlib.pyplot as plt
import torch
import numpy as np
%matplotlib inline

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# if HOST in ['obiwan.local', 'fortytwo']:
#     # device = 'cpu'    
#     #  pyTorch has not implementation for conv3d on mps yet
#     if torch.backends.mps.is_available():
#         device = torch.device('mps')
#         print('going 🤘')
#     else:
#         device = 'cpu'    
# elif torch.cuda.is_available():
#     device = torch.device('cuda')
# else:
#     device = 'cpu'
print(device)

cuda


In [11]:
DEBUG = 8
DEBUG = 4
DEBUG = 2
DEBUG = 1

In [12]:
seed = None
seed = 1973
seed_show = 2018 + 1973 + 42
np.random.seed(seed)
# size of the stimulus:
N_X, N_Y, N_T = 128//DEBUG, 128//DEBUG, 200//DEBUG
N_pola = 2
# average length of a block with constant speed
block_length = 24 # make it short like kernel_size?
# number of different motion directions possible
N_V_phi = 12
N_V_speed = 3
N_PGs = N_V_speed * N_V_phi # todo: change N_PGs in N_K

noise = .001 # Gaussian noise added to natural images
# variation of the speed is kept small for the motion detection task we want: motion detection
V_speed_0 = 0.5 # median speed
V_speed_base = 2.
V_max = 2.5 # V_speed_0 * V_speed_base # median speed

In [13]:
# inversely related to the threshold used for triggering events
selectivity = 1.0

In [14]:
# p_B = 1.5e-1
p_B = 1.
p_B_test = 1.5e-1
p_B_test_MC = 1.5e-1

In [15]:
lr = 2.0e-4 # learning rate
weight_init = 1.e-2
weight_init_center = 4.000
N_epochs = 200
N_train = 80000//DEBUG
N_epochs_scan = 50
N_train_scan = (N_epochs_scan * N_train)//N_epochs
seed_train = 2023
N_test = 200//DEBUG
loss_samples = 30 # number of last samples used to compute the loss
seed_test = seed_train + (N_train//N_epochs)
# dimensions of kernel_size along D, H, W
kernel_size = (21, 17, 17)
# do_wta = True
do_mask = True
do_adam = True
do_bias = True

In [16]:
N_X, N_Y, N_T, p_B * N_X * N_Y * N_T, p_B * np.prod(kernel_size)

(128, 128, 200, 3276800.0, 6069.0)

In [17]:
# meta-parameters for scanning parameters

# dimensions of kernel_size along D, H, W
# note: k_T has to be uneven
kernel_sizes = [(5, 5, 5), (11, 21, 21), (7, 7, 7),  (21, 21, 21)]
kernel_sizes = [(5, 5, 5), (7, 7, 7), (7, 11, 11), (11, 21, 21)]
if DEBUG>4: kernel_sizes = []

if DEBUG>4: k_Ts = []
else: k_Ts = [5, 9, 13, 21] # has to be uneven

N_scan = 3 if DEBUG>4 else 6
lrs = lr * np.logspace(-1, 1, N_scan, base=10)
# p_Bs = p_B * np.logspace(-1, 1, N_scan, base=10)
p_Bs = np.logspace(-2, 0, N_scan, base=10)
# p_B_inits = np.geomspace(1, p_B, N_scan)
weight_inits = weight_init * np.logspace(-1, 1, N_scan, base=10)
weight_init_centers = weight_init_center * np.logspace(-1, 1, N_scan, base=10)

In [18]:
timestamp = 1
i_y = 16
N_T_show = N_X
N_show_every = max((N_train//100, 1))

In [19]:
import os
import time
import pandas as pd
# to create empty files for parallel processing
def touch(fname): open(fname, 'w').close()

In [20]:
%ls data/

Olshausen_IMAGES.mat  vanhateren.npy


In [21]:
# # db_path = 'data/serre07_targets_whitening_contrastnorm.npy'
# db_path = 'data/serre07_targets.npy'
# # db_path = 'data/Olshausen_IMAGES.npy'
db_path = 'data/vanhateren.npy'
IMAGES = np.load(db_path)
N_X_data, N_Y_data, N_data = IMAGES.shape
N_X_data, N_Y_data, N_data

(1024, 1024, 100)

In [22]:
fig_width, phi = 15, np.sqrt(5)/2 + 1/2

In [23]:
do_figures = False
if HOST in ['obiwan.local', 'fortytwo']: do_figures = True

In [24]:
if do_figures:
    index_image = 3
    fig, ax = plt.subplots(figsize=(fig_width, fig_width))
    ax.imshow(IMAGES[:, :, index_image], cmap=plt.gray(), interpolation='nearest')
    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.95, top=0.9)

In [25]:
figpath = None

In [26]:
figpath = '../../2023-07-05_Grimaldi-etal-BiologicalCybernetics_630f9044c38e7a3cea81a7b2/figures/'

In [27]:
figpath = 'figures/'

In [28]:
# %ls -ltr {figpath}

In [29]:
if not figpath is None:
    if not os.path.isdir(figpath): figpath = None
figpath

'figures/'

In [30]:
data_cache, figures = 'cache', 'figures'

In [31]:
%mkdir -p {data_cache} {figures}

In [32]:
# if DEBUG == 8: 
#     %rm -fr {os.path.join(data_cache, datetag)}
# %rm -fr {os.path.join(data_cache, datetag)}
# %rm -fr {os.path.join(data_cache, datetag)}/*lock
#%rm -fr {os.path.join(figures, datetag)}*mp4

In [33]:
%mkdir -p {os.path.join(data_cache, datetag)}

In [34]:
def cachepath(data_cache, datetag, DEBUG):
    return os.path.join(data_cache, datetag,  f'SL_DEBUG={DEBUG}')

## Natural World model

We will define the world model as the full generative model from the input image to the classes that generated them.

### motion domain

In [35]:
def logpol_speed(label):
    """
    label can be a single index or a list
    
    """
    V_phis = 2 * np.pi * ((label % N_V_phi)  + .5*((label // N_V_phi)%2)) / N_V_phi
    V_speeds = V_speed_0 * np.exp( np.log(V_speed_base) * (2* (label // N_V_phi) / N_V_speed -1))
    # V_speeds = V_speed_0 * np.geomspace(V_speed_0/V_speed_base, V_speed_0*V_speed_base, N_V_phi * N_V_speed, endpoint=True)
    return V_phis, V_speeds

V_phis_line, V_speeds_line = logpol_speed(np.arange(N_PGs))
V_phis_line, V_speeds_line

(array([0.        , 0.52359878, 1.04719755, 1.57079633, 2.0943951 ,
        2.61799388, 3.14159265, 3.66519143, 4.1887902 , 4.71238898,
        5.23598776, 5.75958653, 0.26179939, 0.78539816, 1.30899694,
        1.83259571, 2.35619449, 2.87979327, 3.40339204, 3.92699082,
        4.45058959, 4.97418837, 5.49778714, 6.02138592, 0.        ,
        0.52359878, 1.04719755, 1.57079633, 2.0943951 , 2.61799388,
        3.14159265, 3.66519143, 4.1887902 , 4.71238898, 5.23598776,
        5.75958653]),
 array([0.25      , 0.25      , 0.25      , 0.25      , 0.25      ,
        0.25      , 0.25      , 0.25      , 0.25      , 0.25      ,
        0.25      , 0.25      , 0.39685026, 0.39685026, 0.39685026,
        0.39685026, 0.39685026, 0.39685026, 0.39685026, 0.39685026,
        0.39685026, 0.39685026, 0.39685026, 0.39685026, 0.62996052,
        0.62996052, 0.62996052, 0.62996052, 0.62996052, 0.62996052,
        0.62996052, 0.62996052, 0.62996052, 0.62996052, 0.62996052,
        0.62996052]))

In [36]:
V_phis_line.shape

(36,)

In [37]:
label_down_by_one = N_V_phi * (N_V_speed-1)
label_down_by_one, V_phis_line[label_down_by_one], V_speeds_line[label_down_by_one]

(24, 0.0, 0.6299605249474365)

In [38]:
label_up_by_one = N_V_phi * (N_V_speed-1) + N_V_phi//2
label_up_by_one, V_phis_line[label_up_by_one], V_speeds_line[label_up_by_one]

(30, 3.141592653589793, 0.6299605249474365)

In [39]:
# https://github.com/matplotlib/cmocean
import cmocean


In [40]:
cmocean.cm.phase(0)

(0.6583083928922511, 0.4699391690315134, 0.049412882039880514, 1.0)

In [41]:
from matplotlib.colors import hsv_to_rgb
color_bar= np.empty((N_V_phi*N_V_speed, 3))
for i_h_speed in range(N_V_speed):
    for i_h_phi in range(N_V_phi):
        # H for direction, S for speed
        # color_bar[i_h_phi+i_h_speed*N_V_phi, :]= hsv_to_rgb((V_phis_line[i_h_phi]/2/np.pi, V_speeds_line[i_h_speed]/V_speed_0/V_speed_base, 1))
        color_bar[i_h_phi+i_h_speed*N_V_phi, :]= cmocean.cm.phase(V_phis_line[i_h_phi]/2/np.pi)[:-1]

def plot_speedspace(V_speed, V_phi, arrow_width=.04, dot_width=200, alpha=1., 
                    fig=None, ax=None, do_labels=False, color_bar=color_bar, fig_width=fig_width):
    if fig is None: fig, ax = plt.subplots(1, 1, figsize=(fig_width, fig_width))

    V_phis_line, V_speeds_line = logpol_speed(np.arange(N_PGs))
    ax.scatter(V_speeds_line*np.sin(V_phis_line), V_speeds_line*np.cos(V_phis_line), 
               s=dot_width*(V_speeds_line/np.median(V_speeds_line))**2, marker='.', color=color_bar, alpha=alpha)
    ax.plot([0], [0], 'k+')
    ax.arrow(0, 0, V_speed*np.sin(V_phi), V_speed*np.cos(V_phi), width=arrow_width, 
             color='red', length_includes_head=True)
    if do_labels:
        ax.set_ylabel('V_X')
        ax.set_xlabel('V_Y')
    else:
        ax.set_xticks([])
        ax.set_yticks([])
    V_speed_max = np.abs(V_speeds_line).max()
    margin = 1.15
    ax.set_xlim(-V_speed_max*margin, V_speed_max*margin)
    ax.set_ylim(-V_speed_max*margin, V_speed_max*margin)
    ax.invert_yaxis()
    #ax.axis('scaled')
    return fig, ax

if do_figures:
    fig, ax = plot_speedspace(V_speed=V_speed_0, V_phi=+.75*np.pi, do_labels=True, dot_width=500, alpha=1)
    ax.plot(V_speeds_line*np.sin(V_phis_line), V_speeds_line*np.cos(V_phis_line), 'orange', lw=.5);

### parameters of the "world model"

In [42]:
do_event = True

In [43]:
class NatWorld:
    def __init__(self, block_length=block_length,N_X=N_X, N_Y=N_Y, N_T=N_T, seed=seed,
                 N_data=N_data, i_image=None, noise=noise, selectivity=selectivity,
                 N_V_phi=N_V_phi, N_V_speed=N_V_speed,
                 do_event=do_event, label=None,
                ):
        self.block_length = block_length
        self.N_X = N_X
        self.N_Y = N_Y
        self.N_data = N_data
        self.N_T = N_T
        self.do_event = do_event
        self.seed = seed
        self.N_V_speed = N_V_speed
        self.N_V_phi = N_V_phi
        self.N_PGs = N_V_speed * N_V_phi
        self.selectivity = selectivity
        self.i_image = i_image
        self.noise = noise
        self.draw(seed, label)

    def draw(self, seed=None, label=None):
        if not(seed is None): self.seed = seed
        np.random.seed(seed=seed)
        # draw a random image from the database
        if self.i_image is None: self.i_image = np.random.randint(N_data)

        # we define the trajectory like a brownian motion (over a discretized set of motions)
        # inspired by https://github.com/laurentperrinet/bayesianchangepoint/blob/master/bayesianchangepoint/bcp.py#L35
        if label is None:
            self.label = np.zeros((self.N_T,), dtype=int)
            self.label[0] = np.random.randint(self.N_PGs)
            for i_T in range(1, self.N_T):
                if np.random.rand() < 1/self.block_length : #switch
                    self.label[i_T] = np.random.randint(self.N_PGs)
                else: # no switch
                    self.label[i_T] = self.label[i_T-1]
        elif type(6) == int:
            self.label = label * np.ones((self.N_T,), dtype=int)
        else:
            self.label = label

        self.V_phis, self.V_speeds = logpol_speed(self.label)

    def get_input(self, do_cache=False, do_show=False):
        if do_cache:
            tensor_fname = cachepath(data_cache, datetag, DEBUG) + f'_It_bool_seed={self.seed}.pt'
            if os.path.isfile(tensor_fname):
                It_bool = torch.load(tensor_fname)
            else:
                It_bool = self.get_input(do_cache=False, do_show=do_show)
                torch.save(It_bool, tensor_fname)
            return It_bool
        else:
            return make_natmovie_events(self, do_show=do_show)

one instance of the world

In [44]:
w = NatWorld(seed=seed)
w.label, w.V_phis

(array([22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 11, 11, 11, 11, 11,
        11, 11, 11, 11, 11, 11,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25]),
 array([5.49778714, 5.49778714, 5.49778714, 5.49778714, 5.49778714,
        5.49778714, 5.49778714, 5.

In [45]:
assert((w.V_phis==V_phis_line[w.label]).all())

In [46]:
w.draw(123456)
w.label

array([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  5,  5,  5,  5,  5,  5,  5,  5,  5, 10, 10, 12, 12, 12, 12,
       12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
       12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  4,  4,  4,
        4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
        4,  4,  4,  4, 21, 21, 21, 21, 21, 21, 21, 21, 21, 10, 10, 10, 10,
       10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
       10, 10, 10, 10, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31,  2,  2,
        2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
        2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
        2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 22, 22])

It is possible to design a block with no switch:

In [47]:
NatWorld(seed=seed, block_length=np.inf).label

array([22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
       22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22])

This makes it possible to draw random world in a reproducible way

In [48]:
NatWorld(seed=seed).label, NatWorld(seed=seed+1).label, NatWorld(seed=seed+2).label

(array([22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 11, 11, 11, 11, 11,
        11, 11, 11, 11, 11, 11,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25]),
 array([24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
        24, 24, 24, 24, 24

### natural images

In [49]:
if do_figures:
    index_image = NatWorld(seed=seed).i_image
    fig, ax = plt.subplots(figsize=(fig_width, fig_width))
    ax.imshow(IMAGES[:, :, index_image], cmap=plt.gray(), interpolation='nearest')
    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.95, top=0.9)

Extracting the tools to perform subpixel translations from https://github.com/bicv/SLIP/blob/master/SLIP/SLIP.py#L507 :

In [50]:
from numpy.fft import fft2, fftshift, ifft2, ifftshift

class Image:
    def __init__(self, N_X, N_Y):
        fx, fy = np.mgrid[(-N_X//2):(N_X+1)//2, (-N_Y//2):(N_Y+1)//2]
        fx, fy = fx*1./N_X, fy*1./N_Y
        self.f_x = fx
        self.f_y = fy
        
    # Fourier number crunching
    def invert(self, FT_image, full=False):
        if full:
            return ifft2(ifftshift(FT_image))
        else:
            return ifft2(ifftshift(FT_image)).real
    def fourier(self, image, full=True):
        """
        Using the ``fourierr`` function, it is easy to retieve its Fourier transformation.
        """
        FT = fftshift(fft2(image))
        if full:
            return FT
        else:
            return np.absolute(FT)

    def FTfilter(self, image, FT_filter, full=False):
        """
        Using the ``FTfilter`` function, it is easy to filter an image with a filter defined in Fourier space.
        """
        FT_image = self.fourier(image, full=True) * FT_filter
        return self.invert(FT_image, full=full)

    def trans(self, u, v):
        return np.exp(-1j*2*np.pi*(u*self.f_x + v*self.f_y))

    def translate(self, image, vec, preshift=True):
        """
        Translate image by vec (in pixels)
        Note that the convention for coordinates follows that of matrices: the origin is at the top left of the image, and coordinates are first the rows (vertical axis, going down) then the columns (horizontal axis, going right).
        """
        u, v = vec
        u, v = u * 1., v * 1.

        if preshift:
            # first translate by the integer value
            image = np.roll(np.roll(image, int(u), axis=0), int(v), axis=1)
            u -= int(u)
            v -= int(v)

        # sub-pixel translation
        return self.FTfilter(image, self.trans(u, v))
# define object
slip = Image(N_X_data, N_Y_data)

In [51]:
if do_figures:
    # move the image up (<0) and right (>0)
    im_translated = slip.translate(IMAGES[:, :, index_image], (-10.21, 34.5))
    # im_translated = slip.translate(IMAGES[:, :, index_image], (0., 0.))

    fig, ax = plt.subplots(figsize=(fig_width, fig_width))
    ax.imshow(im_translated, cmap=plt.gray(), interpolation='nearest')
    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.95, top=0.9)

In [52]:
def rectif(z_in, contrast=1., method='Michelson', verbose=False):
    """
    Transforms an image (can be 1, 2 or 3D) with normal histogram into
    a 0.5 centered image of determined contrast
    method is either 'Michelson' or 'Energy'
    Phase randomization takes any image and turns it into Gaussian-distributed
    noise of the same power (or, equivalently, variance).
    # See: Peter J. Bex J. Opt. Soc. Am. A/Vol. 19, No. 6/June 2002 Spatial
    frequency, phase, and the contrast of natural images

    adapted from 
    https://github.com/NeuralEnsemble/MotionClouds/blob/bf266726ea44bb70efe953fd764ec16744201b5f/MotionClouds/MotionClouds.py#L73
    """
    z = z_in.copy()
    # Final rectification
    if verbose:
        print('Before Rectification of the frames')
        print( 'Mean=', np.mean(z[:]), ', std=', np.std(z[:]), ', Min=', np.min(z[:]), ', Max=', np.max(z[:]), ' Abs(Max)=', np.max(np.abs(z[:])))

    z -= np.mean(z[:]) # this should be true *on average* in MotionClouds

    if (method == 'Michelson'):
        z = (.5* z/np.max(np.abs(z[:]))* contrast + .5)
    else:
        z = (.5* z/np.std(z[:])  * contrast + .5)

    if verbose:
        print('After Rectification of the frames')
        print('Mean=', np.mean(z[:]), ', std=', np.std(z[:]), ', Min=', np.min(z[:]), ', Max=', np.max(z[:]))
        print('percentage pixels clipped=', np.sum(np.abs(z[:])>1.)*100/z.size)
    return z

### make it move

In [53]:
def make_trajectory(w, margin=5):
    """
    
    w: world model containing parameters and functions
    margin: an additional margin to avoid including the border of an image

    this function assumes the trajectory will not cross borders. 

    """

    # the sequence of speeds
    V_Xs = -w.V_speeds * np.cos(w.V_phis)
    V_Ys = -w.V_speeds * np.sin(w.V_phis)

    # the random walk computed as the accumulation of steps
    xs = np.cumsum(V_Xs)
    ys = np.cumsum(V_Ys)

    # drawing a random position for the upper left corner
    # the range of starting position is such that the ending position is in the image
    np.random.seed(w.seed)
    x_0 = np.random.randint(xs[0]-xs.min() + margin, N_X_data - w.N_X - xs.max() - margin)
    y_0 = np.random.randint(ys[0]-ys.min() + margin, N_Y_data - w.N_Y - ys.max() - margin)

    # coordinates of the top, left of the bounding box
    return x_0 + xs, y_0 + ys

In [54]:
for seed_ in range(8601, 18720):
    w = NatWorld(seed=seed_)
    xs, ys = make_trajectory(w)

In [55]:
w.N_Y, w.N_X

(128, 128)

In [56]:
from matplotlib.colors import hsv_to_rgb
hsv_to_rgb((.1, 1, 1))

array([1. , 0.6, 0. ])

In [57]:
def plot_trajectory(w, color='blue', alpha=.6, linewidth=1.,
                    fig=None, ax=None, do_labels=False, fig_width=fig_width):
    if fig is None: fig, ax = plt.subplots(1, 1, figsize=(fig_width, fig_width))

    # coordinates of the top, left of the bounding box
    xs, ys = make_trajectory(w)
    if (xs.max()+w.N_X>N_X_data) or (ys.max()+w.N_Y>N_Y_data):
        print(xs.max(), ys.max())
    if (xs.min()<0) or (ys.min()<0):
        print(xs.min(), ys.min())

    ax.plot(ys, xs, lw=linewidth, color=color, alpha=alpha)
    ax.plot(ys+w.N_Y, xs+w.N_X, lw=linewidth, color=color, alpha=alpha)

    ax.plot([N_Y_data/2, N_Y_data/2+w.N_Y], [N_X_data/2, N_X_data/2], 'r', lw=2)
    ax.plot([N_Y_data/2, N_Y_data/2], [N_X_data/2, N_X_data/2+w.N_X], 'b', lw=2)
    ax.plot([N_Y_data/2], [N_X_data/2], 'w+', lw=6)
    if do_labels:
        ax.set_ylabel('X')
        ax.set_xlabel('Y')
    ax.set_xlim(0, N_X_data)
    ax.set_ylim(0, N_Y_data)
    ax.invert_yaxis()
    #ax.axis('scaled')
    return fig, ax

if do_figures:

    fig, ax = plot_trajectory(w, do_labels=True)
    ax.set_facecolor((.0, 0., 0.))
    for _ in range(180):
        w = NatWorld(seed=seed+_)
        fig, ax = plot_trajectory(w, color=hsv_to_rgb((np.random.rand(), 1, 1)), fig=fig, ax=ax, do_labels=True)


In [58]:
xs

array([512.54556182, 512.54556182, 512.54556182, 512.54556182,
       512.54556182, 512.54556182, 512.54556182, 512.54556182,
       512.54556182, 512.54556182, 512.54556182, 512.54556182,
       512.54556182, 512.54556182, 512.54556182, 512.54556182,
       512.42056182, 512.29556182, 512.17056182, 512.04556182,
       511.92056182, 511.79556182, 511.67056182, 511.54556182,
       511.42056182, 511.73554208, 512.05052234, 512.36550261,
       512.68048287, 512.99546313, 513.31044339, 513.62542366,
       513.94040392, 514.25538418, 514.57036444, 514.88534471,
       515.20032497, 515.51530523, 515.83028549, 516.14526576,
       516.46024602, 516.77522628, 517.09020654, 517.4051868 ,
       517.72016707, 518.03514733, 518.35012759, 518.66510785,
       518.98008812, 519.29506838, 519.61004864, 519.9250289 ,
       520.24000917, 520.55498943, 520.86996969, 521.18494995,
       521.49993022, 521.81491048, 522.12989074, 522.444871  ,
       522.75985127, 523.07483153, 523.38981179, 523.70

In [59]:
def make_nat_movie(w):

    # coordinates of the top, left of the bounding box
    xs, ys = make_trajectory(w=w)
    
    # draw the image from the database
    image = IMAGES[:, :, w.i_image]

    #image += w.noise * np.random.randn(N_X_data, N_X_data)

    movie = np.zeros((w.N_X, w.N_Y, w.N_T))
    for i_T in range(w.N_T):
        movie[:, :, i_T] = slip.translate(image, 
                                          (-xs[i_T], -ys[i_T]))[:w.N_X, :w.N_Y]

    movie = 2 * rectif(movie) - 1
    movie += w.noise * np.random.randn(w.N_X, w.N_Y, w.N_T)
    movie = np.clip(movie, -1, 1)

    return movie

In [60]:
w = NatWorld(seed=seed)
movie = make_nat_movie(w)
movie.shape, movie.min(),  movie.mean(), movie.max()#, movie.min(axis=(0,1)), movie.max(axis=(0,1))

((128, 128, 200),
 -0.7847192158239998,
 -1.4763959152841936e-08,
 0.998922494720195)

In [61]:
if do_figures:
    N_show, N_repet = 5, 14 #* DEBUG

    for i_repet in range(N_repet):
        fig, axs = plt.subplots(1, N_show, figsize=(fig_width, fig_width/N_show))

        w = NatWorld(seed=seed_show+i_repet)
        movie = make_nat_movie(w)
        vmax = np.abs(movie).max()

        i_step = 4
        for i_T in range(N_show):
            ax = axs[i_T]
            ax.imshow(movie[:, :, timestamp+i_T*i_step], cmap=plt.gray(), interpolation='nearest', vmin=-vmax, vmax=vmax)
            ax.text(0, -5, f'@t={timestamp+i_T*i_step}/{w.label[timestamp+i_T*i_step]}: V_phi={w.V_phis[timestamp+i_T*i_step]*180/np.pi:.1f}, V_speed={w.V_speeds[timestamp+i_T*i_step]:.1f}', fontsize=7)
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()


### generative model from movies to events

In [62]:
# np.diff?

In [63]:
movie_ones = np.ones((4, 4, 10))
movie_ones[:, :, 3] = 0
It = np.diff(movie_ones, axis=-1)
movie_ones[0, 0, :], It[0, 0, :], movie_ones.shape, It.shape

(array([1., 1., 1., 0., 1., 1., 1., 1., 1., 1.]),
 array([ 0.,  0., -1.,  1.,  0.,  0.,  0.,  0.,  0.]),
 (4, 4, 10),
 (4, 4, 9))

In [64]:
movie_ones = np.ones((4, 4, 10))
movie_ones[:, :, 3] = 0
It = np.diff(movie_ones, axis=-1, prepend=np.ones((4, 4, 1)))
movie_ones[0, 0, :], It[0, 0, :], movie_ones.shape, It.shape

(array([1., 1., 1., 0., 1., 1., 1., 1., 1., 1.]),
 array([ 0.,  0.,  0., -1.,  1.,  0.,  0.,  0.,  0.,  0.]),
 (4, 4, 10),
 (4, 4, 10))

In [65]:
movie_ones = np.ones((4, 4, 10))
movie_ones[:, :, 3] = 0
It = np.diff(movie_ones, axis=-1, n=2)
It[0, 0, :]

array([ 0., -1.,  2., -1.,  0.,  0.,  0.,  0.])

In [66]:
# It = np.diff(movie, axis=-1)
# threshold = selectivity * It.std()
# threshold, selectivity, It.min(), It.max()

In [67]:
def make_events(movie, selectivity=selectivity, do_event=do_event):
    if do_event:
        N_X, N_Y, N_T = movie.shape
        It = np.diff(movie, axis=-1)
        threshold = selectivity * It.std()

        It_full = np.zeros((N_X, N_Y, N_T))
        It_full[:, :, 1:] = It.copy()

        movie_spike = np.zeros((N_X, N_Y, N_T))
        for i_T in range(1, N_T):
            It_full[:, :, i_T] -= threshold * movie_spike[:, :, i_T-1]
            movie_spike[:, :, i_T] = 1. * (It_full[:, :, i_T] > threshold) - 1. * (It_full[:, :, i_T] < -threshold)
        return movie_spike
    else:
        return movie

### sum up and plot

In [68]:
def plot_cloud_events(timestamp, movie, It_bool, V_speeds, V_phis, N_show=5, fig=None, axs=None, fig_width=fig_width, type='Natural'):

    # two, N_T, N_X, N_Y = It_bool.shape
    N_columns = 3 if N_show==1 else N_show+1
    if fig is None: fig, axs = plt.subplots(1, N_columns, figsize=(fig_width, fig_width/(N_columns+.3)))

    # draw image
    axs[0].imshow(movie[:, :, timestamp], cmap=plt.gray(), vmin=-1, vmax=1, interpolation='nearest')
    if N_show==1: 
        axs[0].set_title(f'{type} stimulus at t={timestamp}')
        fig, axs[1] = plot_speedspace(V_speeds[timestamp], V_phis[timestamp], fig=fig, ax=axs[1])
        axs[1].axis('equal')
        axs[1].set_title('Direction of motion (red arrow)')
        #axs[1].text(1, 5, f't={timestamp}')
        axs[2].imshow(-It_bool[0, timestamp, :, :]+It_bool[1, timestamp, :, :], 
                      cmap=plt.cm.seismic, vmin=-1, vmax=1, interpolation='nearest')
        axs[2].set_title('ON (red) and OFF (blue) events')

    else:
        for i_T in range(N_show):
            ax = axs[i_T+1]
            ax.imshow(-It_bool[0, timestamp+i_T, :, :]+It_bool[1, timestamp+i_T, :, :], 
                      cmap=plt.cm.seismic, vmin=-1, vmax=1, interpolation='nearest')
            ax.text(1, 5, f't={timestamp+i_T}', backgroundcolor='white')

            ax_inset = ax.inset_axes([0., 0., 0.2, 0.2])
            fig, ax_inset = plot_speedspace(V_speeds[timestamp+i_T], V_phis[timestamp+i_T],
                                            dot_width=25/N_show, fig=fig, ax=ax_inset)
            ax_inset.set_xticks([])
            ax_inset.set_yticks([])

    for ax in axs.ravel():
        ax.set_xticks([])
        ax.set_yticks([])

    margin = 0.1 if N_show==1 else 0.0
    fig.subplots_adjust(left=margin, bottom=margin/N_columns, right=1-margin, top=1-margin, hspace=.01, wspace=0.)#-margin/N_columns)
    return fig, axs

In [69]:
w.V_speeds, w.V_phis, len(w.V_phis)

(array([0.39685026, 0.39685026, 0.39685026, 0.39685026, 0.39685026,
        0.39685026, 0.39685026, 0.39685026, 0.39685026, 0.39685026,
        0.39685026, 0.39685026, 0.39685026, 0.39685026, 0.39685026,
        0.39685026, 0.39685026, 0.25      , 0.25      , 0.25      ,
        0.25      , 0.25      , 0.25      , 0.25      , 0.25      ,
        0.25      , 0.25      , 0.25      , 0.25      , 0.25      ,
        0.62996052, 0.62996052, 0.62996052, 0.62996052, 0.62996052,
        0.62996052, 0.62996052, 0.62996052, 0.62996052, 0.62996052,
        0.62996052, 0.62996052, 0.62996052, 0.62996052, 0.62996052,
        0.62996052, 0.25      , 0.25      , 0.25      , 0.25      ,
        0.25      , 0.25      , 0.25      , 0.25      , 0.25      ,
        0.25      , 0.25      , 0.25      , 0.25      , 0.25      ,
        0.25      , 0.25      , 0.25      , 0.25      , 0.25      ,
        0.25      , 0.25      , 0.25      , 0.25      , 0.25      ,
        0.25      , 0.25      , 0.25      , 0.25

In [70]:
def make_natmovie_events(w, do_show=False, N_show=5, fname=None, timestamp=timestamp):

    movie = make_nat_movie(w)
    movie_spike = make_events(movie, selectivity=w.selectivity, do_event=w.do_event)

    # convert to pytorch format
    It_bool = np.zeros((w.N_X, w.N_Y, w.N_T, 2))
    It_bool[movie_spike==-1, 0] = 1
    It_bool[movie_spike==1, 1] = 1

    It_bool = torch.from_numpy(It_bool)
    It_bool = torch.swapaxes(It_bool, 2, 1)
    It_bool = torch.swapaxes(It_bool, 3, 2)
    It_bool = torch.swapaxes(It_bool, 2, 0)

    if not fname is None:
        from matplotlib.animation import FFMpegWriter
        writer = FFMpegWriter(fps=15)
        fig, axs = plot_cloud_events(0, movie, It_bool, w.V_speeds, w.V_phis, N_show=1)
        with writer.saving(fig, fname, w.N_T-1):
            for timestamp_ in range(1, w.N_T-1):
                for ax in axs: ax.cla()
                fig, axs = plot_cloud_events(timestamp_, movie, It_bool, w.V_speeds, w.V_phis, N_show=1, fig=fig, axs=axs)
                writer.grab_frame()

    if do_show:
        fig, axs = plot_cloud_events(timestamp, movie, It_bool, w.V_speeds, w.V_phis, N_show=N_show)
        return fig, axs
    else:
        return It_bool.float()

In [71]:
if do_figures:
    w = NatWorld(seed=seed_show)
    fig, ax = make_natmovie_events(w, N_show=1, do_show=True);

In [72]:
w.i_image

41

In [73]:
if do_figures:
    w = NatWorld(seed=seed_show, block_length=np.Inf)
    w.draw(seed=seed_show, label=label_down_by_one * np.ones_like(w.label))
    print(w.label)
    fig, axs = make_natmovie_events(w, N_show=5, do_show=True);

In [74]:
w.i_image

41

In [75]:
if do_figures:
    if DEBUG<8: fig.savefig(os.path.join(figures, datetag + '_input.png'))


In [76]:
w = NatWorld(seed=seed_show)
# w.label[timestamp:], N_show

In [77]:
w.label[58:]

array([16, 16, 16, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8, 17, 17, 17, 17,
       17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
       17, 17, 17, 17, 17, 17])

In [78]:
if do_figures:
    # label = np.zeros(N_T, dtype=int)
    # label[:20] = label_down_by_one
    # label[20:] = 30
    # w = NatWorld(seed=seed_show, selectivity=2.0, label=label)    
    w = NatWorld(seed=seed_show, selectivity=1.0)
    N_show = 5
    timestamp_show = 58
    print(w.label[timestamp_show:(timestamp_show+N_show)], w.label, w.label.shape, N_show)
    xs, ys = make_trajectory(w)
    xs_ = -xs#[timestamp_show:]
    ys_ = -ys#[timestamp_show:]
    xs_ = xs_ - xs_.mean()
    ys_ = ys_ - ys_.mean()
    xs_ = xs_/np.abs(xs_).max()
    ys_ = ys_/np.abs(ys_).max()
    xs_ = (xs_/1.+2.5)/2.5
    ys_ = (ys_/1.+1.5)/1.5
    # xs_ = N_X - xs_* N_X/2 #+ 15
    # ys_ = N_Y - ys_* N_Y/2 # + 22
    xs_ = xs_* N_X/2 #+ 15
    ys_ = ys_* N_Y/2 # + 22
    #label = w.label.copy()
    #label[6:12] = 112
    #label[13:15] = 118
    #w.draw(seed=201100, label=label)
    # w.label[18:35]

    fig, axs = make_natmovie_events(w, N_show=N_show, do_show=True, timestamp=timestamp_show)
    axs[0].plot(ys_, xs_, lw=7, color='w', alpha=1.)
    axs[0].plot(ys_, xs_, lw=4, color='g', alpha=1.)
    axs[0].plot(ys_, xs_, lw=4, color='g', alpha=1.)
    axs[0].plot(ys_[timestamp_show-5], xs_[timestamp_show-5], '.', ms=4, color='w', alpha=1.)
    axs[0].plot(ys_[timestamp_show+5], xs_[timestamp_show+5], '.', ms=4, color='k', alpha=1.)
    fig.set_figwidth(fig.get_figwidth()*1.1)
    fig.set_figheight(fig.get_figheight()*1.3)
    axs[1].annotate('time', xy=(0.5, -0.1), xycoords='axes fraction', xytext=(4.5, -0.1), 
                arrowprops=dict(arrowstyle="<-", color='b'), fontsize=16)
    fig


In [79]:
if do_figures:
    # figpath
    if figpath != None: fig.savefig(os.path.join(figpath, 'motion_task.pdf'), bbox_inches='tight')

https://matplotlib.org/stable/gallery/animation/dynamic_image.html

In [80]:
w.V_speeds.shape

(200,)

In [81]:
if do_figures:
    fname = os.path.join(figures, datetag + '_input.mp4')
    if not os.path.isfile(fname):
        make_natmovie_events(w, do_show=False, fname=fname)

In [82]:
if do_figures:

    N_seq = 8
    np.random.seed(2022)
    for i_seq in range(N_seq):
        w = NatWorld(seed=seed+i_seq)
        # to debug: w.draw(seed=seed, label=label_down_by_one * np.ones_like(w.label))

        fname = os.path.join(figures, datetag + '_input_' + str(i_seq) + '.png')

        if not os.path.isfile(fname):
            fig, axs = make_natmovie_events(w, N_show=5, do_show=True, fname=None);
            if DEBUG<8: fig.savefig(fname, bbox_inches='tight')

        fname_mp4 = os.path.join(figures, datetag + '_input_' + str(i_seq) + '.mp4')
        if not os.path.isfile(fname_mp4):
            make_natmovie_events(w, N_show=1, do_show=False, fname=fname_mp4);


### Control the input firing frequency:

In [83]:
movie_spike = make_natmovie_events(w)
np.prod(movie_spike.shape), N_X * N_Y * N_T, (movie_spike!= 0).sum()

(6553600, 3276800, tensor(366955))

In [84]:
N_test_fr = 30 if DEBUG<8 else 10
path_results = cachepath(data_cache, datetag, DEBUG) + '_firing_frequency.npy'
try:
    moduls = np.logspace(-1, 1, N_scan, base=10, endpoint=True)
    firing_frequency = np.load(path_results)
except:
    firing_frequency = np.zeros((N_scan, N_test_fr))
    for i_test in range(N_test_fr):
        w = NatWorld(seed=seed+i_test)
        movie = make_nat_movie(w)
        for i_modul, modul in enumerate(moduls):
            w.selectivity = modul*selectivity
            movie_spike = make_natmovie_events(w)
            firing_frequency[i_modul, i_test] = (movie_spike!= 0).sum() / np.prod(movie_spike.shape)
    np.save(path_results, firing_frequency)

Firing frequency per voxel, and average input spikes in a kernel

In [85]:
firing_frequency.mean(axis=-1), firing_frequency.mean(axis=-1)*np.prod(kernel_size), np.prod(kernel_size)

(array([3.94208887e-01, 2.79820303e-01, 1.33690165e-01, 3.24902292e-02,
        2.84411113e-03, 8.14361569e-05]),
 array([2.39245374e+03, 1.69822942e+03, 8.11365613e+02, 1.97183201e+02,
        1.72609105e+01, 4.94236036e-01]),
 6069)

In [86]:
if do_figures:
    fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi/2))
    ax.plot(moduls*selectivity, firing_frequency.mean(axis=-1), lw=2, marker='.', markersize=10)
    ax.spines['left'].set_position(('axes', -0.01))
    ax.grid(which='both')
    for side in ['top', 'right'] :ax.spines[side].set_visible(False)
    ax.set_xlabel('selectivity')
    ax.set_ylabel('mean firing_frequency')
    ax.set_xscale('log')
    ax.set_yscale('log');
    #ax.set_ylim(0.);

## Supervised motion detection

### Detection model


We can easily describe this model with a 3D convolution layer defined in PyTorch. Dimensions of the 3D kernels correspond to the spatial position of the synapse and the temporal dimension to represent the delays. 


#### masking kernels

In [87]:
torch.linspace(0, 1, steps=kernel_size[0])

tensor([0.0000, 0.0500, 0.1000, 0.1500, 0.2000, 0.2500, 0.3000, 0.3500, 0.4000,
        0.4500, 0.5000, 0.5500, 0.6000, 0.6500, 0.7000, 0.7500, 0.8000, 0.8500,
        0.9000, 0.9500, 1.0000])

In [88]:
def create_mask(kernel_size, mask_exponent=4, r_0=.05, V_max=V_max):
    # dimensions of kernel_size along D, H, W
    ts = torch.linspace(0, 1, steps=kernel_size[0])
    xs = torch.linspace(-1, 1, steps=kernel_size[1]+2)[1:-1]
    ys = torch.linspace(-1, 1, steps=kernel_size[2]+2)[1:-1]
    # https://pytorch.org/docs/stable/generated/torch.meshgrid.html#torch.meshgrid
    t, x, y = torch.meshgrid(ts, xs, ys, indexing='ij')
    r = torch.sqrt(x **2 + y **2)

    # circular mask 
    # https://github.com/bicv/SLIP/blob/master/SLIP/SLIP.py#L220
    filter_mask = ((np.cos(np.pi*r)+1)/2  * (r < 1.))**(1./mask_exponent)

    # cone
    filter_mask *= np.exp(-.5 * r**2 / ( V_max * (1-t) + r_0)**2 )
    return filter_mask

filter_mask = create_mask(kernel_size)
#filter_mask = torch.ones(kernel_size)
filter_mask.shape

torch.Size([21, 17, 17])

In [89]:
if do_figures:
    fig, axs = plt.subplots(1, 2, figsize=(fig_width, fig_width/phi))
    axs[0].matshow(filter_mask.mean(axis=0))
    axs[0].set_xlabel('X')
    axs[0].set_ylabel('Y')
    axs[1].plot(filter_mask.mean(axis=(1, 2)))
    axs[1].set_ylim(0)
    axs[1].set_box_aspect(1)
    axs[1].set_xlabel('T')
    plt.tight_layout();

In [90]:
filter_mask[0, :, :]

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0914, 0.2631, 0.3395, 0.3795, 0.3921,
         0.3795, 0.3395, 0.2631, 0.0914, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.2455, 0.3795, 0.4629, 0.5171, 0.5481, 0.5582,
         0.5481, 0.5171, 0.4629, 0.3795, 0.2455, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.2797, 0.4285, 0.5275, 0.5979, 0.6459, 0.6741, 0.6833,
         0.6741, 0.6459, 0.5979, 0.5275, 0.4285, 0.2797, 0.0000, 0.0000],
        [0.0000, 0.2455, 0.4285, 0.5481, 0.6365, 0.7018, 0.7472, 0.7740, 0.7829,
         0.7740, 0.7472, 0.7018, 0.6365, 0.5481, 0.4285, 0.2455, 0.0000],
        [0.0914, 0.3795, 0.5275, 0.6365, 0.7200, 0.7829, 0.8271, 0.8533, 0.8620,
         0.8533, 0.8271, 0.7829, 0.7200, 0.6365, 0.5275, 0.3795, 0.0914],
        [0.2631, 0.4629, 0.5979, 0.7018, 0.7829, 0.8446, 0.8881, 0.9141, 0.9227,
         0.9141, 0.8881, 0.8446, 0.7829, 0.7018, 0.5979, 0.4629, 0.2631],
        [0.3395, 0.5171, 0.6459, 0.7472, 0.8271, 0.8881, 0.9313, 0.957

In [91]:
if do_figures:
    fig, axs = plt.subplots(1, 2, figsize=(fig_width, fig_width/phi))
    axs[0].matshow(filter_mask[0, :, :])
    axs[0].set_xlabel('X')
    axs[0].set_ylabel('Y')
    axs[1].matshow(filter_mask[-1, :, :])
    axs[1].set_xlabel('X')
    axs[1].set_ylabel('Y')
    plt.tight_layout();

In [92]:
if do_figures:
    fig, axs = plt.subplots(1, kernel_size[0], figsize=(fig_width, fig_width/kernel_size[0]))
    for i_t, ax in enumerate(axs):
        ax.matshow(filter_mask[i_t, :, :])

In [93]:
filter_mask.shape

torch.Size([21, 17, 17])

In [94]:
filter_mask[-1, kernel_size[1]//2, kernel_size[2]//2]

tensor(1.)

### Training

We then perform a supervised training on motion detection with the previously created dataset. 

#### detection model as a logistic regression

https://pytorch.org/docs/stable/generated/torch.nn.functional.conv3d.html

In [95]:
movie_ones = torch.zeros((1, 2, 17, 4, 4))
movie_ones[0, 0, 8, :, :] = 1.
kernel = torch.zeros((5, 2, kernel_size[0], kernel_size[1], kernel_size[2]))
kernel[0, 0, -1, kernel_size[1]//2, kernel_size[2]//2] = 1

It = torch.nn.functional.conv3d(movie_ones, kernel, padding=(kernel_size[0]//2, kernel_size[1]//2, kernel_size[2]//2))
movie_ones[0, 0, :, 0, 0], It[0, 0, :, 0, 0], movie_ones.shape, It.shape

(tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 torch.Size([1, 2, 17, 4, 4]),
 torch.Size([1, 5, 17, 4, 4]))

In [96]:
It = torch.nn.functional.conv3d(movie_ones, kernel, padding=(kernel_size[0]//2, kernel_size[1]//2, kernel_size[2]//2))
movie_ones[0, 0, :, 0, 0], It[0, 0, :, 0, 0], movie_ones.shape, It.shape

(tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 torch.Size([1, 2, 17, 4, 4]),
 torch.Size([1, 5, 17, 4, 4]))

In [97]:
kernel_size[0]//2

10

In [98]:
It = torch.roll(It, kernel_size[0]//2, dims=2) # fix the roll
movie_ones[0, 0, :, 0, 0], It[0, 0, :, 0, 0], movie_ones.shape, It.shape

(tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 torch.Size([1, 2, 17, 4, 4]),
 torch.Size([1, 5, 17, 4, 4]))

In [99]:
class Net(torch.nn.Module):
    def __init__(self, kernel_size, N_PGs, N_pola=N_pola, do_mask=do_mask, #do_softmax=do_softmax, 
                 do_bias=do_bias, weight_init=weight_init, weight_init_center=weight_init_center): 
        super(Net, self).__init__()
        self.do_mask = do_mask
        filter_mask = create_mask(kernel_size)
        padding = (kernel_size[0]//2, kernel_size[1]//2, kernel_size[2]//2)
        # https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html?highlight=conv3d#torch.nn.Conv3d
        # dimensions of kernel_size along D=time, H, W
        self.conv_layer = torch.nn.Conv3d(N_pola, N_PGs, kernel_size=kernel_size, 
                                          padding=padding, padding_mode='zeros', bias=do_bias)
        weight_init = weight_init * torch.randn_like(self.conv_layer.weight) * filter_mask
        weight_init[:, 1, -1, kernel_size[1]//2, kernel_size[2]//2] = weight_init_center
        weight_init[:, 0, -1, kernel_size[1]//2, kernel_size[2]//2] = -weight_init_center 
        
        self.conv_layer.weight = torch.nn.Parameter(weight_init, requires_grad=True)

        # if not(do_softmax is None):
        #     if do_softmax:
        #         self.sigma = torch.nn.Softmax(dim=1)
        #     else:
        #         self.sigma = torch.nn.Sigmoid()
        # https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.register_buffer
        self.register_buffer('filter_mask', filter_mask)
        self.kernel_size = kernel_size

    def forward(self, a):
        if self.do_mask: self._mask_conv_filter()
        logit_b = self.conv_layer(a)
        logit_b = torch.roll(logit_b, self.kernel_size[0]//2, dims=2)
        return logit_b
        # if do_softmax is None:
        #     return logit_b
        # else:
        #     return self.sigma(logit_b)

    def _mask_conv_filter(self):
        with torch.no_grad():
            # https://pytorch.org/docs/master/generated/torch.clamp.html
            max_weight = torch.abs(self.conv_layer.weight).max()
            self.conv_layer.weight.data.clamp_(min=-self.filter_mask*max_weight, 
                                               max=self.filter_mask*max_weight)

#### learning routine

In [100]:
# https://pytorch.org/docs/master/generated/torch.nn.functional.cross_entropy.html
cross_entropy = torch.nn.functional.cross_entropy

def learn_model(path=None, w=w, 
                kernel_size=kernel_size, N_train=N_train, N_epochs=N_epochs,
                lr=lr, N_show_every=N_show_every, 
                p_B=p_B, weight_init=weight_init, max_quant=10000000, 
                weight_init_center=weight_init_center, do_cache=True,
                do_adam=do_adam, do_bias=do_bias, do_mask=do_mask, verbose=False):

                    
    model = Net(kernel_size, w.N_PGs, N_pola=N_pola, do_mask=do_mask, do_bias=do_bias, 
                weight_init=weight_init, weight_init_center=weight_init_center)
    model = model.to(device)
    
    if os.path.isfile(path + '.pth'):
        model.load_state_dict(torch.load(path + '.pth', map_location=torch.device(device)))
        df_train = pd.read_csv(path + '.csv')

    elif os.path.isfile(path + '.lock'):
        # we want to have a file but it's locked
        print(f'Path {path} is locked')
        return None, None 

    else:
        if verbose: print('Starting the learning at', path + '.pth')
        # either we do not need a file or it does not exist (or it's locked)
        touch(path + '.lock') # we want to have a file let's lock it        

        df_train = pd.DataFrame([], columns=['correct', 'loss', 'time', 'ETA'], dtype=float)

        if do_adam:
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        else:
            # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
            optimizer = torch.optim.SGD(model.parameters(), lr=lr)

        tic = time.time()
        seed_train_ = w.seed
        # transform events in the loader to dense tensor
        for i_train in range(N_train):
            w.draw(seed=seed_train_ + i_train % (N_train//N_epochs))
    
            # one-hot encoding of the labels
            b_target = torch.zeros((w.N_PGs, w.N_T, w.N_X, w.N_Y), device=device, requires_grad=False)
            for i_T in range(w.N_T):
                b_target[w.label[i_T], i_T, :, :] = 1

            # generate the input
            It_bool = w.get_input(do_cache=do_cache)
            It_bool = It_bool.to(device)

            # flip the ON/OFF bit
            It_bool_flip = It_bool[[1, 0], ...]

            # learn
            optimizer.zero_grad()
            logit_b_hat = model(It_bool.unsqueeze(0)).squeeze(0)
            # TODO: remonter cette partie dans le modele ?
            logit_b_hat_flip = model(It_bool_flip.unsqueeze(0)).squeeze(0) # flipping ON and OFF events
            logit_b_hat_max = torch.maximum(logit_b_hat, logit_b_hat_flip) # element-wise maximum

            # sparse selection of coefficients which should give a detection
            logit_b_hat_top, ind_top = torch.max(logit_b_hat_max, dim=0)
            if p_B<1:
                # using a winner takes all on the dimension of hypothesis
                logit_b_hat_top = logit_b_hat_top * torch.ones((w.N_PGs, 1, 1, 1), device=device)
                # b_target[0, logit_b_hat_top < torch.quantile(logit_b_hat_top, 1-p_B)] = 1. / w.N_PGs

                if len(logit_b_hat_top.ravel()) > max_quant:
                    ind_quant = torch.randperm(len(logit_b_hat_top.ravel()))[:max_quant]
                    b_threshold = torch.quantile(logit_b_hat_top.ravel()[ind_quant], 1-p_B)
                else:
                    b_threshold = torch.quantile(logit_b_hat_top, 1.-p_B)

                top_of_the_top = logit_b_hat_top > b_threshold
                b_target = torch.where(top_of_the_top, b_target, 1./w.N_PGs * torch.ones_like(b_target))

            loss = cross_entropy(logit_b_hat_max.unsqueeze(0), b_target.unsqueeze(0), reduction='mean')

            loss.backward()
            optimizer.step()

            correct = (ind_top == b_target.squeeze(0)).float().mean().item()
                
            ETA = (time.time()-tic)/(i_train+1)*(N_train-i_train)
            df_train.loc[i_train] = {'correct':correct, 'loss':loss.item(), 'time':time.time()-tic, 'ETA':ETA}
            if verbose and (i_train % N_show_every == 0):
                print(f'Trial [{i_train:06d}/{N_train:06d}]\t correct:{correct:.3f}\t Loss: {loss.item():.3e}\t ETA {ETA:.3f} s')
            df_train.to_csv(path + '.csv')

        torch.save(model.state_dict(), path + '.pth')

    if os.path.isfile(path + '.lock'): os.remove(path + '.lock')
    return model, df_train

#### do the learning

In [101]:
w = NatWorld(seed=seed_train)
model, df_train = learn_model(path=cachepath(data_cache, datetag, DEBUG), w=w, kernel_size=kernel_size, N_train=N_train, N_epochs=N_epochs, lr=lr, verbose=True)

Starting the learning at cache/2023-09-22_FastMotionDetection/SL_DEBUG=1.pth
Trial [000000/080000]	 correct:0.010	 Loss: 3.571e+00	 ETA 945974.236 s
Trial [000800/080000]	 correct:0.001	 Loss: 3.264e+00	 ETA 3416394.223 s
Trial [001600/080000]	 correct:0.001	 Loss: 3.109e+00	 ETA 3495749.586 s
Trial [002400/080000]	 correct:0.001	 Loss: 3.041e+00	 ETA 3500660.490 s
Trial [003200/080000]	 correct:0.001	 Loss: 3.000e+00	 ETA 3476610.051 s
Trial [004000/080000]	 correct:0.001	 Loss: 2.972e+00	 ETA 3456222.960 s
Trial [004800/080000]	 correct:0.001	 Loss: 2.950e+00	 ETA 3429746.712 s
Trial [005600/080000]	 correct:0.001	 Loss: 2.933e+00	 ETA 3328919.703 s
Trial [006400/080000]	 correct:0.001	 Loss: 2.918e+00	 ETA 3198508.473 s
Trial [007200/080000]	 correct:0.001	 Loss: 2.906e+00	 ETA 3073780.281 s
Trial [008000/080000]	 correct:0.001	 Loss: 2.896e+00	 ETA 2901196.323 s
Trial [008800/080000]	 correct:0.001	 Loss: 2.887e+00	 ETA 2756830.973 s
Trial [009600/080000]	 correct:0.001	 Loss: 2.88

#### analyzing kernels

In [None]:
done_learning = not(model is None)

In [None]:
def plot_loss(df_train):
    fig_width, phi = 15, np.sqrt(5)/2 + 1/2
    fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi/2))
    #ax = df_train['loss'].plot(lw=0, marker='.', markersize=1, ax=ax)
    #ax = df_train['loss'].rolling(N_train//64).median().plot(lw=10, alpha=.4, ax=ax)
    ax = df_train['loss'].rolling(N_train//64).median().plot(lw=1, alpha=.9, ax=ax)
    ax.set_xlabel("Steps")
    ax.set_ylabel("Loss")
    ax.spines['left'].set_position(('axes', -0.01))
    ax.grid(which='both')
    for side in ['top', 'right']: ax.spines[side].set_visible(False)
    #ax.set_ylim(0.)
    ax.set_yscale('log');
    return fig, ax

if done_learning:
    if do_figures:
        plot_loss(df_train);

In [None]:
def plot_weight_distribution(kernels, alpha=.3):
    kernel_max =np.abs(kernels).max()
    bins=np.linspace(-kernel_max, kernel_max, 50)
    K, two, k_T, k_X, k_Y = kernels.shape
    fig, axs = plt.subplots(1, 2, figsize=(fig_width, fig_width/phi))
    for k in range(K):
        pos_kernels = kernels[k, 1, :, :, :]
        neg_kernels = kernels[k, 0, :, :, :]
        axs[0].hist(neg_kernels.ravel(), bins=bins, alpha=alpha)
        axs[0].set_xlabel('OFF polarities')
        axs[1].hist(pos_kernels.ravel(), bins=bins, alpha=alpha)
        axs[1].set_xlabel('ON polarities')
    return fig, axs

if done_learning:

    learned_kernels = model.conv_layer.weight.data.cpu().numpy()

    if do_figures:
        plot_weight_distribution(learned_kernels);

In [None]:
#import seaborn as sns
def plot_joint_weight_distribution(kernels, alpha=.3):
    kernel_max =np.abs(kernels).max()
    #bins=np.linspace(-kernel_max, kernel_max, 50)
    K, two, k_T, k_X, k_Y = kernels.shape
    fig, ax = plt.subplots(1, 1, figsize=(fig_width/phi, fig_width/phi))
    for k in range(K):
        pos_kernels = kernels[k, 1, :, :, :]
        neg_kernels = kernels[k, 0, :, :, :]
        #sns.histplot(x=pos_kernels.ravel(), y=neg_kernels.ravel(), bins=30,  pthresh=.05, pmax=.9, 
                     #    cbar=True, cbar_kws=dict(shrink=.75), 
        #             ax=ax)
        ax.scatter(pos_kernels.ravel(), neg_kernels.ravel(), alpha=.01)
        ax.set_xlabel('OFF polarities')
        ax.set_ylabel('ON polarities')
    return fig, ax

if done_learning and do_figures:
    fig, ax= plot_joint_weight_distribution(learned_kernels);

In [None]:
if done_learning and do_figures:
    if DEBUG<8: fig.savefig(os.path.join(figures, datetag + '_joint_ON-OFF.png'), bbox_inches='tight')

#### plot kernels

In [None]:
K_min = N_V_phi #+  N_V_phi * ((3*N_V_speed) // 4)
K_max = K_min + N_V_phi
K_min, K_max, N_V_phi, N_V_speed

In [None]:
def plot_kernels(kernels, space=0.05, K_min=K_min, K_max=K_max, k_T_skip=0,
                 skip_last=1, do_norm=True, do_ON_OFF=True):
    # https://matplotlib.org/stable/gallery/color/colormap_reference.html
    cmap_pos = 'RdBu_r'
    cmap_neg = 'RdGy_r'
    # cmap_pos = cmocean.cm.balance
    # cmap_neg = cmocean.cm.balance
    #cmap_pos = 'autumn_r'
    #cmap_neg = 'winter_r'
    kernels = kernels[:, :, :-skip_last, :, :]
    # kernels -= kernels.mean()
    K, N_pol, k_T, k_X, k_Y = kernels.shape
    
    N_pol = 2 if do_ON_OFF else 1
    fig, axs = plt.subplots((K_max-K_min)*N_pol, 1+k_T-k_T_skip, figsize=(fig_width, N_pol*(K_max-K_min)/(k_T-k_T_skip+1)*fig_width))
    #vmin, vmax = 0, 1
    # vmin, vmax = -np.abs(kernels).max(), np.abs(kernels).max()
    # for k in range(K):
    for k in range(K_min, K_max):
        for pol in range(N_pol):
            ax = axs[(k-K_min)*N_pol+pol][0]
            # vmin, vmax = -np.abs(kernels[k, pol, k_T_skip:k_T, ...]).max(), np.abs(kernels[k, pol, k_T_skip:k_T, ...]).max()
            kernel_ = kernels[k, 1-pol, ...]
            # vmax = np.abs(kernels[k, 1-pol, k_T_skip:k_T, ...]).max()
            kernel_[kernel_>0] /= 2.0* kernel_.max()
            kernel_[kernel_<0] /= -kernel_.min()

            # print(vmin, vmax)
            fig, ax = plot_speedspace(V_speeds_line[k], V_phis_line[k], fig=fig, ax=ax, alpha=.3)
            for t in range(k_T_skip, k_T):
                ax = axs[(k-K_min)*N_pol+pol][t-k_T_skip+1]
                #print(kernels[k, pol, t, :, :].min(), kernels[k, pol, t, :, :].max())
                # https://matplotlib.org/stable/tutorials/colors/colormaps.html#diverging
                # ax.imshow(kernels[k, 1-pol, t, :, :], interpolation='nearest', vmin=-vmax, vmax=vmax, cmap=cmap_neg if pol else cmap_pos)
                ax.imshow(kernel_[t, :, :], interpolation='nearest', vmin=-1, vmax=1, cmap=cmap_neg if pol else cmap_pos)
                # ax.contourf(kernels[k, 1-pol, t, :, :], vmin=vmin, vmax=vmax, cmap=cmap_neg if pol else cmap_pos)
                ax.axis('off')
                ax.set_xticklabels([])
                ax.set_yticklabels([])
    plt.subplots_adjust(wspace=.001*k_T/K*space, hspace=space)
    # plt.tight_layout()
    return fig, axs

learned_kernels = model.conv_layer.weight.data.cpu().numpy()
print(learned_kernels.shape)
if done_learning:
    if do_figures:
        fig, axs = plot_kernels(learned_kernels);

In [None]:
if done_learning:
    if do_figures:
        fig, axs = plot_kernels(learned_kernels, K_min=0, K_max=N_PGs);

In [None]:
if do_figures:
    fig, axs = plot_kernels(learned_kernels,  k_T_skip=7, K_min=K_min, K_max=K_min+8, do_ON_OFF=False);

In [None]:
if done_learning:
    if do_figures:
        k_T_skip = 8
        fig, axs = plot_kernels(learned_kernels,  k_T_skip=k_T_skip, K_min=K_min, K_max=K_min+8, do_ON_OFF=False);
        axs[-1][1].annotate('delay', xy=(0, -0.25), xycoords='axes fraction', 
                            xytext=((kernel_size[0]-k_T_skip)*0.95 - .5, -0.25), 
                    arrowprops=dict(arrowstyle="->", color='b'), fontsize=16)
        fig

In [None]:
if done_learning:
    if do_figures:
        if DEBUG<8: fig.savefig(os.path.join(figures, datetag + '_kernel.png'), bbox_inches='tight')
        if figpath != None: fig.savefig(os.path.join(figpath, 'motion_kernels.pdf'), bbox_inches='tight')

#### Analyse activity

In [None]:
if done_learning:
    # label = np.zeros(N_T, dtype=int)
    # label[:32] = label_down_by_one
    # label[32:56] = 30
    # label[56:167] = label_up_by_one
    # label[167:] = 29
    # w = NatWorld(seed=seed+4, noise=0.02, label=label)
    w = NatWorld(seed=seed_show, selectivity=1., noise=0.)
    print(w.label)

    It_bool = w.get_input().to(device)
    
    # i_y = 6 #N_Y//4

    if do_figures:
        import matplotlib.patches as patches

        fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi))
        ax.imshow((It_bool[0, :, :, i_y]-It_bool[1, :, :, i_y]).numpy().T, cmap = 'RdBu_r')
        ax.set_xlabel("Time")
        ax.set_ylabel("X");

In [None]:
if done_learning:
    print(It_bool.shape)
    if do_figures:
        fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi))

        for i_pol, color in enumerate(['b', 'r']) :
            ax.eventplot([np.where(It_bool[i_pol, :, x, i_y] == 1.)[0] for x in range(0, N_X)], 
                          color=color, lineoffsets=1, linelengths=0.9)
        ax.set_xlabel("Time")
        ax.set_ylim(0, N_X)
        ax.invert_yaxis()
        ax.set_xlim(32, 96)
        ax.set_xlabel("Time")
        ax.set_yticklabels([])
        # ax.axis('equal')
        # ax.set_aspect('equal', 'box')

In [None]:
if done_learning:
    It_bool_flip = torch.roll(It_bool, 1, dims=0)
    logit_b_hat = model(It_bool.unsqueeze(0)).squeeze(0)
    logit_b_hat_flip = model(It_bool_flip.unsqueeze(0)).squeeze(0) # flipping ON and OFF events
    logit_b_hat_max = torch.maximum(logit_b_hat, logit_b_hat_flip) # element-wise maximum
    # b_hat_max = torch.softmax(logit_b_hat_max, dim=1)
    # logit_b_hat_max = torch.logit(b_hat_max)-torch.logit(torch.tensor(1/w.N_PGs))


In [None]:

if done_learning:
    logit_b_hat_max_np = logit_b_hat_max.cpu().detach().numpy()
    print(logit_b_hat_max_np.max(), logit_b_hat_max_np.min(), logit_b_hat_max_np.shape, logit_b_hat_max_np.max(axis=-1).max(axis=-1).shape)
    # b_hat_max_np = b_hat_max.squeeze(0).detach().numpy()
    # print(b_hat_max_np.max(), b_hat_max_np.min(), b_hat_max_np.shape, b_hat_max_np.max(axis=-1).max(axis=-1).shape)


In [None]:
It_bool.shape, It_bool_flip.shape, logit_b_hat.shape

In [None]:
logit_b_hat = model(It_bool.unsqueeze(0)).squeeze(0)

In [None]:
torch.cat((It_bool.unsqueeze(0), It_bool_flip.unsqueeze(0), It_bool.unsqueeze(0))).shape

In [None]:
if done_learning:
    if do_figures:
        idx = label_up_by_one
        
        fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi/2))
        ax.plot(logit_b_hat_max_np.max(axis=-1).max(axis=-1)[idx, :])
        ax.plot(np.hstack((np.diff(w.label)!=0, 0)))
        ax.plot(w.V_speeds)
        ax.set_xlabel("Time")
        ax.set_ylabel("Max activity");

In [None]:
if done_learning:
    if do_figures:

        fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi/2))
        ax.plot(logit_b_hat_max_np.max(axis=-1).max(axis=-1).T)
        ax.set_xlabel("Time")
        ax.set_ylabel("Max activity");

In [None]:
if done_learning:
    if do_figures:

        fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi/2))
        ax.plot(logit_b_hat_max_np.mean(axis=-1).mean(axis=-1).T)
        ax.set_xlabel("Time")
        ax.set_ylabel("Max activity")
        ax.legend();

In [None]:
if done_learning:
    if do_figures:
        fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi/2))
        ax.hist(logit_b_hat_max_np.ravel(), bins=100)
        ax.set_ylabel("distribution")
        ax.set_xlabel("output predicted logit probability")
        ax.set_yscale('log');

In [None]:
# if done_learning:
#     if do_figures:
#         # b_hat_np = 

#         np.tensordot(logit_b_hat_max_np[:, :, :, i_y].swapaxes(0, -1), color_bar, axes=1).shape
#         # logit_b_hat_max_np[:, :, :, i_y]
#         fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi))
#         ax.imshow(np.tensordot(logit_b_hat_max_np.swapaxes(0, -1), color_bar, axes=1))
#         ax.set_xlabel("Time")
#         ax.set_ylabel("X");

In [None]:
if done_learning:
    if do_figures:
        i_y = N_Y//2 + 4
        t_min, t_max = N_T//4, 3*N_T//4
        t_min, t_max = 0, N_T
        x_min, x_max = N_X//4, N_X//2
        x_min, x_max = 0, N_X


        fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi))
        logit_max = logit_b_hat_max_np[idx, :, :, i_y].max()
        ax.imshow(logit_b_hat_max_np[idx, :, :, i_y].T, vmin=-logit_max, vmax=logit_max, cmap = 'RdBu_r')
        ax.set_xlabel("Time")
        ax.set_xlim(t_min, t_max)
        ax.set_ylabel("X");


In [None]:
# logit_b_hat_top.shape, color_bar.shape, np.tensordot(logit_b_hat_top.swapaxes(0, -1), color_bar, axes=1).shape, np.tensordot(logit_b_hat_top.swapaxes(0, -1), color_bar, axes=1)[i_y, :, :, :].swapaxes(0, 1).shape

In [None]:
p_B_show = 1.e-2
p_B_show = 2.e-2
if done_learning:
    if do_figures:
        threshold = np.quantile(logit_b_hat_max_np, 1-p_B_show)
        logit_b_hat_top = logit_b_hat_max_np > threshold

        fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi))
        ax.imshow(np.tensordot(logit_b_hat_top.swapaxes(0, -1), color_bar, axes=1)[i_y, :, :, :].swapaxes(0, 1))
        # ax.imshow(logit_b_hat_top[idx, :, :, i_y].T, vmin=-logit_max, vmax=logit_max, cmap = 'RdBu_r')
        ax.set_xlim(t_min, t_max)
        ax.set_xlabel("Time")
        ax.set_ylabel("X");

In [None]:
# if done_learning:
#     logit_b_hat_top.swapaxes(0, -1)[i_y, :, :, :].shape, N_PGs, N_X

In [None]:
color_bar.shape, color_bar[label_down_by_one], color_bar[label_up_by_one]

In [None]:
if done_learning:
    if do_figures:
        cm = plt.get_cmap('brg')
        fig, ax = plt.subplots(figsize=(fig_width, fig_width))
        for label, color in zip([label_down_by_one, label_up_by_one], [color_bar[label_down_by_one], color_bar[label_up_by_one]]):
            for i_x in range(x_min, x_max):
                ax.plot(logit_b_hat_max_np[label, :, i_x, i_y]/logit_b_hat_max_np.max()+i_x*1.05, lw=0.3, c=color)
        ax.set_xlabel("Time")
        ax.set_ylabel("X")
        # ax.set_xlim(x_min, x_max)
        # ax.set_ylim(0, N_X)
        ax.invert_yaxis()

In [None]:
if done_learning:
    if do_figures:
        threshold = np.quantile(logit_b_hat_max_np, 1-p_B_show)
        logit_b_hat_top = logit_b_hat_max_np > threshold

        cm = plt.get_cmap('brg')
        fig, ax = plt.subplots(figsize=(fig_width, fig_width))
        slice = logit_b_hat_top.swapaxes(0, -1)[i_y, :, :, :]
        for i_class in range(0, N_PGs):
            ax.eventplot([np.where(slice[:, x, i_class] == 1.)[0] for x in range(x_min, x_max)], 
                          color=cm(1.*i_class/N_PGs), lineoffsets=1, linelengths=0.9)
        ax.set_xlabel("Time")
        ax.set_ylabel("X")
        ax.set_xlim(t_min, t_max)
        ax.set_ylim(0, N_X)
        ax.invert_yaxis()
        
        # colors = ['r', 'g' , 'b' , 'k']
        # ax.eventplot([np.where(b[0, i, :] == 1.)[0] for i in range(0, N_PGs)], 
        #       colors=colors, lineoffsets=1, linelengths=0.9);

In [None]:
#slice /= slice.max()

# slice.shape, logit_b_hat_top.shape

In [None]:
if done_learning:
    if do_figures:
        slice = logit_b_hat_top.swapaxes(0, -1)[i_y, :, :, :]
        cm = plt.get_cmap('brg')
        fig, ax = plt.subplots(figsize=(fig_width, fig_width))
        for label, color in zip([label_down_by_one, label_up_by_one], ['r', 'b']):
            for i_x in range(x_min, x_max):
                ax.plot(-logit_b_hat_max_np[label, :, i_x, i_y]/logit_b_hat_max_np.max()+i_x-x_min+.5, lw=0.3, c=color)

            ax.eventplot([np.where(slice[:, x, label] == 1.)[0] for x in range(x_min, x_max)], 
                          color=color, lineoffsets=1, linelengths=0.8, lw=.5)

        ax.set_yticklabels(range(x_min, x_max))
        ax.set_xlabel("Time")
        ax.set_ylabel("X")
        ax.invert_yaxis()

In [None]:
if done_learning:
    if do_figures:
        fig, ax = plt.subplots(figsize=(fig_width, fig_width))
        slice = logit_b_hat_top.swapaxes(0, -1)[i_y, :, :, :]
        x_min, x_max = N_X//4, N_X//2
        for label in range(N_PGs): #K_min, K_min+12):
            color = color_bar[label]
            for i_x in range(x_min, x_max):
                ax.plot(-logit_b_hat_max_np[label, :, i_x, i_y]/logit_b_hat_max_np.max()+i_x-x_min+.6, lw=0.3, c=color)

            ax.eventplot([np.where(slice[:, x, label] == 1.)[0] for x in range(x_min, x_max)], 
                          color=color, lineoffsets=1, linelengths=0.8, lw=.5)

        ax.set_yticklabels(range(x_min, x_max))
        ax.set_xlabel("Time")
        ax.set_ylabel("X")
        ax.invert_yaxis()

In [None]:
if done_learning:
    if do_figures:
        fig, ax = plt.subplots(figsize=(fig_width, fig_width))
        slice = logit_b_hat_top.swapaxes(0, -1)[i_y, :, :, :]
        x_min, x_max = N_X//4, N_X//2
        for label in range(N_PGs): #K_min, K_min+12):
            color = color_bar[label]
            for i_x in range(x_min, x_max):
                ax.plot(-logit_b_hat_max_np[label, :, i_x, i_y]/logit_b_hat_max_np.max()+i_x-x_min+.6, lw=0.3, c=color)

            ax.eventplot([np.where(slice[:, x, label] == 1.)[0] for x in range(x_min, x_max)], 
                          color=color, lineoffsets=1, linelengths=0.8, lw=.5)

        ax.set_yticklabels(range(x_min, x_max))
        ax.set_xlabel("Time")
        ax.set_ylabel("X")
        ax.invert_yaxis()

Synthesis:

In [None]:
kernel_size

In [None]:
if done_learning:
    
    if do_figures:
        t_0, x_0 = 92/DEBUG, 64/DEBUG
        
        subplotpars = matplotlib.figure.SubplotParams(left=0.15, right=.975, bottom=0.175, top=.975, wspace=0.05, hspace=0.00,)

        fig, axs = plt.subplots(1, 3, figsize=(fig_width, fig_width/phi), subplotpars=subplotpars)
        ax = axs[0]

        for i_pol, color in enumerate(['b', 'r']) :
            ax.eventplot([np.where(It_bool[i_pol, :, x, i_y] == 1.)[0] for x in range(0, N_X)], 
                          color=color, lineoffsets=1, linelengths=0.9)

        ax.set_ylim(0, N_X)
        ax.invert_yaxis()
        # ax.imshow((It_bool[0, :, :, i_y]-It_bool[1, :, :, i_y]).numpy().T, cmap = 'RdBu_r')
        ax.set_xlim(t_min, t_max)
        ax.set_xlabel("Time")
        ax.set_ylabel("X (vertical position)")
        ax.set_aspect('equal', 'box')

        ax.add_patch(patches.Rectangle((t_0, x_0), kernel_size[0], kernel_size[1], 
                                       linewidth=5, edgecolor='g', ls='--', facecolor="none"))
        ax.add_patch(patches.Circle((t_0+kernel_size[0], x_0+kernel_size[1]//2), 1.5, edgecolor='none', facecolor='orange'))

        ax = axs[1]
        logit_max = logit_b_hat_max_np[idx, :, :, i_y].max()
        ax.imshow(logit_b_hat_max_np[idx, :, :, i_y].T, vmin=-logit_max, vmax=logit_max, cmap = 'RdBu_r')
        ax.set_xlabel("Time")
        
        ax.set_xlim(t_min, t_max)
        # ax.set_ylabel("X");
        ax.set_yticklabels([])
        ax.add_patch(patches.Circle((t_0+kernel_size[0], x_0+kernel_size[1]//2), 1.5, edgecolor='none', facecolor='orange'))

        ax = axs[2]

        cm = plt.get_cmap('brg')
        slice = logit_b_hat_top.swapaxes(0, -1)[i_y, :, :, :]
        for i_class in range(0, N_PGs):
            ax.eventplot([np.where(slice[:, x, i_class] == 1.)[0] for x in range(0, N_X)], 
                          color=cm(1.*i_class/N_PGs), lineoffsets=1, linelengths=0.9)
        ax.set_xlabel("Time")
        ax.set_ylim(0, N_X)
        ax.invert_yaxis()
        ax.set_xlim(t_min, t_max)
        ax.set_xlabel("Time")
        ax.set_yticklabels([])
        # ax.axis('equal')
        ax.set_aspect('equal', 'box')

        # ax.set_ylabel("X");
        ax.add_patch(patches.Circle((t_0+kernel_size[0], x_0+kernel_size[1]//2), 1.5, edgecolor='none', facecolor='orange'))

        # Add line from one subplot to the other
        # ConnectionPatch handles the transform internally so no need to get fig.transFigure
        # https://stackoverflow.com/a/67531807
        for dx_0 in [kernel_size[1], 0]:
            arrow = patches.ConnectionPatch(
                (t_0+kernel_size[0], x_0+dx_0),
                (t_0+kernel_size[0], x_0+kernel_size[1]//2),
                coordsA=axs[0].transData,
                coordsB=axs[1].transData,
                # Default shrink parameter is 0 so can be omitted
                color="black",
                arrowstyle="->",  # "normal" arrow
                mutation_scale=30,  # controls arrow head size
                linewidth=2,
                linestyle='-',
            )
            fig.patches.append(arrow)
        arrow = patches.ConnectionPatch(
            (t_0+kernel_size[0], x_0+kernel_size[1]//2),
            (t_0+kernel_size[0], x_0+kernel_size[1]//2),
            coordsA=axs[1].transData,
            coordsB=axs[2].transData,
            # Default shrink parameter is 0 so can be omitted
            color="black",
            arrowstyle="->",  # "normal" arrow
            mutation_scale=30,  # controls arrow head size
            linewidth=2,
            linestyle='-',
        )
        fig.patches.append(arrow)

        if DEBUG<8: fig.savefig(os.path.join(figures, datetag + '_conv_HD-SNN.png'), bbox_inches='tight')
        if figpath != None: fig.savefig(os.path.join(figpath, 'conv_HD-SNN.pdf'), bbox_inches='tight')        

In [None]:
It_bool.shape

In [None]:
if done_learning:
#   for i_seed   in range(1000, 1012):
    # seed_ = i_seed
    seed_ = 1011 # good accuracy
    # print(seed_)
    # print(seed_show+i_seed)
    # i_y = 16
    # print(i_y)
    # w = NatWorld(seed=seed_show+i_seed, selectivity=1., noise=0.)
    w = NatWorld(seed=seed_, noise=0.)
    It_bool = w.get_input().to(device)
    
    # i_y = N_Y//4

    It_bool_flip = torch.roll(It_bool, 1, dims=0)
    logit_b_hat = model(It_bool.unsqueeze(0)).squeeze(0)
    logit_b_hat_flip = model(It_bool_flip.unsqueeze(0)).squeeze(0) # flipping ON and OFF events
    logit_b_hat_max = torch.maximum(logit_b_hat, logit_b_hat_flip) # element-wise maximum
    logit_b_hat_max_np = logit_b_hat_max.cpu().detach().numpy()
    # print(logit_b_hat_max_np.max(), logit_b_hat_max_np.min(), logit_b_hat_max_np.shape, logit_b_hat_max_np.max(axis=-1).max(axis=-1).shape)
    threshold = np.quantile(logit_b_hat_max_np, 1-p_B_show)
    logit_b_hat_top = logit_b_hat_max_np > threshold

    logit_b_hat_mean = logit_b_hat_top.mean(axis=(2, 3))
    label_b_hat = logit_b_hat_mean.argmax(axis=0)

    def sigmoid (x):
        return 1 - 1 / (1 + np.exp (-x) ) 



In [None]:

# for i_x in range(0, N_X-10, 7):
#     print(i_x, i_y)
if True:
    slice = logit_b_hat_top.swapaxes(0, -1)[i_y, :, :, :]

    # b_hat_max_np = b_hat_max.squeeze(0).detach().numpy()
    #     
    if do_figures:
        x_min, x_max = N_X//4, int(N_X/2.5)
        x_min, x_max = N_X//4, N_X//2
        x_min, x_max = N_X//2, N_X
        x_min, x_max = 0, 10
        x_min, x_max = 0, N_X//4
        x_min, x_max = i_x + 6, i_x + 16
        x_min, x_max = 35, 45
        t_min, t_max = N_T//2, N_T
        t_min, t_max = N_T//4, N_T
        t_min, t_max = 0, N_T
        K_min_, K_max_ = 8, 28
        K_min_, K_max_ = K_min, K_min + 20
        K_min_, K_max_ = 0, N_PGs
        subplotpars = matplotlib.figure.SubplotParams(left=0.15, right=.975, bottom=0.175, top=.975, wspace=0.05, hspace=0.00,)

        fig, axs = plt.subplots(1, 2, figsize=(fig_width, fig_width/phi**2), subplotpars=subplotpars)
        ax = axs[0]

        for i_x in range(x_min, x_max):
            x_ = (i_x-x_min) -.5
            ax.plot([t_min, t_max], [x_, x_], lw=0.2, c='k', alpha=.4)


        for i_pol, color in enumerate(['b', 'r']) :
            ax.eventplot([np.where(It_bool[i_pol, :, x, i_y] == 1.)[0] for x in range(x_min, x_max)], 
                          color=color, lineoffsets=1, lw=.8, linelengths=0.8)


        ax.scatter(np.arange(t_min, t_max), (x_max-x_min)*np.ones(t_max-t_min), s=20, c=[color_bar[w.label[t]] for t in range(t_min, t_max)])
        # ax.invert_yaxis()
        ax.set_xlim(t_min-.5, t_max-.5)
        ax.set_ylim(-.75, x_max-x_min + .35)
        ax.set_yticklabels(range(x_min, x_max))
        ax.set_yticks(np.arange(-.5, x_max-x_min-.5))
        ax.set_xlabel("Time")
        ax.set_ylabel("X (vertical position)")


        ax = axs[1]

        for label in range(K_min_, K_max_): #
            color = color_bar[label]
            for i_x in range(x_min, x_max):
                # ax.plot(2.0* logit_b_hat_max_np[label, :, i_x, i_y]/logit_b_hat_max_np.max()+i_x-x_min+.2, lw=0.08, c=color)
                # ax.plot(-2.0* logit_b_hat_max_np[label, :, i_x, i_y]/logit_b_hat_max_np.min()+i_x-x_min+.7, lw=0.5, c=color, alpha=.5)
                # ax.plot(.5* logit_b_hat_max_np[label, :, i_x, i_y]/logit_b_hat_max_np[:, :, i_x, i_y].max()+i_x-x_min+.5, lw=0.5, c=color, alpha=.5)
                ax.plot((i_x-x_min) + .6*logit_b_hat_max_np[label, :, i_x, i_y]/logit_b_hat_max_np.max()-.35, lw=0.2, c=color, alpha=.4)
                # ax.plot(.5* sigmoid(.2*logit_b_hat_max_np[label, :, i_x, i_y])+i_x-x_min+.17, lw=0.5, c=color, alpha=.4)

            ax.eventplot([np.where(slice[:, x, label] == 1.)[0] for x in range(x_min, x_max)], 
                          color=color, lineoffsets=1., linelengths=0.45, lw=.8, alpha=.4)
            
        ax.scatter(np.arange(t_min, t_max), (x_max-x_min)*np.ones(t_max-t_min), s=20, c=[color_bar[label_b_hat[t]] for t in range(t_min, t_max)])

        ax.set_yticklabels([])#range(x_min, x_max))
        ax.set_yticks(np.arange(-.5, x_max-x_min-.5))
        ax.set_xlim(t_min-.5, t_max-.5)
        ax.set_ylim(-.75, x_max-x_min + .35)
        ax.set_xlabel("Time")
        # ax.invert_yaxis()
        plt.show()


In [None]:
if do_figures:
    if DEBUG<8: fig.savefig(os.path.join(figures, datetag + '_conv_HD-SNN.png'), bbox_inches='tight')
    if figpath != None: fig.savefig(os.path.join(figpath, 'conv_HD-SNN.pdf'), bbox_inches='tight')        

### Testing

#### test function

In [None]:

def test_model(model, path, w=w, do_cache=True, p_B=p_B_test, verbose=True):
    
    if os.path.isfile(path + '.npz'):
        correct = np.load(path + '.npz')['correct']

    elif os.path.isfile(path + '.lock'):
        # we want to have a file but it's locked
        print(f'Path {path} is locked')
        return np.nan * np.zeros(N_test)

    else:
        # either we do not need a file or it does not exist (or it's locked)
        touch(path + '.lock') # we want to have a file let's lock it
        model.eval()
        with torch.no_grad():
            model = model.to(device)
            full_kernel_size = model.conv_layer.weight.data.shape
            correct = np.zeros(N_test)
            seed_test_ = w.seed
            for i_test in range(N_test):
                w.draw(seed=seed_test_+i_test)
                # generate the input
                It_bool = w.get_input(do_cache=do_cache)
                It_bool = It_bool.to(device)

                # predict
                logit_b_hat = model(It_bool.unsqueeze(0)).squeeze(0)
                # flip the ON/OFF bit
                It_bool_flip = torch.roll(It_bool, 1, dims=0)
                # flipping ON and OFF events
                logit_b_hat_flip = model(It_bool_flip.unsqueeze(0)).squeeze(0) # element-wise maximum 
                logit_b_hat_max = torch.maximum(logit_b_hat, logit_b_hat_flip) 

                # removing the first instants 
                logit_b_hat_max = logit_b_hat_max[:, full_kernel_size[2]:, :, :]
                 
                if p_B<1:
                    # using a winner takes all on the dimension of hypothesis
                    # logit_b_hat_top, ind = torch.max(logit_b_hat_max, dim=0)
                    # logit_b_hat_top = logit_b_hat_top * torch.ones((w.N_PGs, 1, 1, 1), device=device)
                    # top_of_the_top = logit_b_hat_top > torch.quantile(logit_b_hat_top, 1.-p_B)

                    weight, ind = torch.max(logit_b_hat_max, dim=0)
                    weight = weight > torch.quantile(weight, 1-p_B) # select that with maximal     
                    weight = weight * torch.ones((N_PGs, 1, 1, 1), device=device)
                    logit_b_hat_mean = (logit_b_hat_max*weight).mean(dim=(1, 2, 3)) # average over space and time
                else:
                    logit_b_hat_mean = logit_b_hat_max.mean(dim=(1, 2, 3)) # average over space and time
                                
                decision = logit_b_hat_mean.argmax().item()
                # if verbose: print(f'GT: {V_phis_line[w.label[0]]} {V_speeds_line[w.label[0]]}, pred: {V_phis_line[decision]} {V_speeds_line[decision]}')
                # if verbose: print(f'GT: {V_phis_line[w.label[0]]} , pred: {V_phis_line[decision]}, {np.mod(V_phis_line[w.label[0]]-V_phis_line[decision], 2*np.pi)}')
                correct[i_test] = (decision == w.label[0]) # assumes block_length=np.inf

        np.savez_compressed(path + '.npz', correct=correct)

    if os.path.isfile(path + '.lock'): os.remove(path + '.lock')
    
    return correct

In [None]:
if done_learning:
    w = NatWorld(seed=seed_test, block_length=np.inf)
    path_results = cachepath(data_cache, datetag, DEBUG) + '_results'
    correct = test_model(model, path=path_results, w=w)
    print(f'Accuracy for detection: {np.mean(correct)*100:.1f}%')
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.binom.html
    from scipy.stats import binom
    p = np.mean(correct)
    n = correct.shape[0]
    print('mean', p, ', quantiles ', binom.ppf([0.05, 0.95], n, p)/n) # get quantiles from beta distribution

#### test with different `p_B_test` values

In [None]:
if done_learning:
    results_p_Bs = np.zeros_like(p_Bs)
    for i_p_B, p_B_ in enumerate(p_Bs):
        path_p_B_test = cachepath(data_cache, datetag, DEBUG) + f'_p_B_test={p_B_:.3e}'
        w = NatWorld(seed=seed_test, block_length=np.inf)
        correct = test_model(model, path=path_p_B_test, w=w, p_B=p_B_)
        print(f'For p_B_test={p_B_:.3e}, accuracy = {np.mean(correct)*100:.1f}%')                

In [None]:
if done_learning:
    results_p_Bs = np.zeros_like(p_Bs)
    for i_p_B, p_B_ in enumerate(p_B_test*np.logspace(-1, 1, N_scan, base=10)):
        path_p_B_test = cachepath(data_cache, datetag, DEBUG) + f'_p_B_test={p_B_:.3e}'
        w = NatWorld(seed=seed_test, block_length=np.inf)
        correct = test_model(model, path=path_p_B_test, w=w, p_B=p_B_)
        print(f'For p_B_test={p_B_:.3e}, accuracy = {np.mean(correct)*100:.1f}%')                

In [None]:
if done_learning:
    results_p_Bs = np.zeros_like(p_Bs)
    for i_p_B, p_B_ in enumerate(p_B_test*np.logspace(-1, 1, N_scan, base=2)):
        path_p_B_test = cachepath(data_cache, datetag, DEBUG) + f'_p_B_test={p_B_:.3e}'
        w = NatWorld(seed=seed_test, block_length=np.inf)
        correct = test_model(model, path=path_p_B_test, w=w, p_B=p_B_)
        print(f'For p_B_test={p_B_:.3e}, accuracy = {np.mean(correct)*100:.1f}%')                

#### Testing with a reduced number of active weights in the kernels

In [None]:
N_quantif = max(32//DEBUG, 4)
qs = 1 - np.geomspace(.0001, 1, N_quantif, endpoint=True)
def test_sparse(model, qs=qs):
    N_q = len(qs)
    model = Net(kernel_size, N_PGs)
    path = cachepath(data_cache, datetag, DEBUG)

    sparseness_K, correct_K = np.zeros(N_q), np.zeros((N_test, N_q))
    for i_q in range(N_q):
        w = NatWorld(seed=seed_test, block_length=np.inf)
        
        model.load_state_dict(torch.load(path + '.pth', map_location=torch.device(device)))
        size = torch.numel(model.conv_layer.weight.data) # total number of synapses
        threshold = torch.abs(model.conv_layer.weight.data).quantile(qs[i_q])
        model.conv_layer.weight.data[torch.abs(model.conv_layer.weight.data) < threshold] = 0.
        sparseness_K[i_q] = 1-(model.conv_layer.weight.data==0.).sum().item()/size
        path_results = cachepath(data_cache, datetag, DEBUG) + f'_results_K_{i_q}_{N_q}'
        correct_ = test_model(model, path=path_results, w=w)
        if not correct_ is None: correct_K[:, i_q] = correct_

    return sparseness_K, correct_K

if done_learning:
    sparseness_K, correct_K = test_sparse(model)

In [None]:
def test_short(model):
    
    model = Net(kernel_size, N_PGs)
    path = cachepath(data_cache, datetag, DEBUG)

    # dim model.conv_layer.weight.data.shape = IN OUT T X Y
    N_kernel_T = model.conv_layer.weight.data.shape[2]
    sparseness_short, correct_short = np.zeros(N_kernel_T), np.zeros((N_test, N_kernel_T))
    for i_kernel_T in range(N_kernel_T):
        w = NatWorld(seed=seed_test, block_length=np.inf)
        
        model.load_state_dict(torch.load(path + '.pth', map_location=torch.device(device)))
        model.conv_layer.weight.data[:, :, :i_kernel_T, :, :] = 0.

        sparseness_short[i_kernel_T] = 1-i_kernel_T/N_kernel_T

        path_results = cachepath(data_cache, datetag, DEBUG) + f'_results_short_{i_kernel_T}_{N_kernel_T}'
        correct_ = test_model(model, path=path_results, w=w)
        if not correct_ is None: correct_short[:, i_kernel_T] = correct_

    return sparseness_short, correct_short

if done_learning:
    sparseness_short, correct_short = test_short(model)

In [None]:
sparseness_short, np.mean(correct_short, axis=0), sparseness_K, np.mean(correct_K, axis=0)

In [None]:
w = NatWorld(seed=seed_test, block_length=np.inf)
path_results = cachepath(data_cache, datetag, DEBUG) + '_results'
correct = test_model(model, path=path_results, w=w)
p_mean = np.mean(correct)
print(f'Accuracy for detection: {np.mean(correct)*100:.1f}%')

In [None]:
if done_learning:
    if do_figures:
        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.binom.html
        from scipy.stats import binom
        path_results = cachepath(data_cache, datetag, DEBUG) + '_results'
        p_mean = np.mean(correct)
        n = correct.shape[0]
        # p_mean = .87


        from scipy.optimize import curve_fit

        def sigmoid (x, h, slope):
            C = 1 / N_PGs
            # return A / (1 + np.exp ((x - h) / slope)) + C #np.exp(C)
            return C + (p_mean-C) / (1 + np.exp (-(x - h) / slope)) #np.exp(C)
            # return C + (p_mean-C) / (1 + (x - 10**h) ** -slope) #np.exp(C)
            # return A / (1 + (x - h)**slope) + C #np.exp(C) # Naka-rushton

        fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi))

        p_mean_K = np.mean(correct_K, axis=0)
        p_low_K = binom.ppf(0.05, n, p_mean_K)/n
        p_high_K = binom.ppf(0.95, n, p_mean_K)/n
        _ = ax.fill_between(sparseness_K, p_low_K*100, p_high_K*100, alpha = .2)
        _ = ax.plot(sparseness_K, p_mean_K*100, 'o', label='pruning')

        p_mean_short = np.mean(correct_short, axis=0)
        p_low_short = binom.ppf(0.05, n, p_mean_short)/n
        p_high_short = binom.ppf(0.95, n, p_mean_short)/n
        _ = ax.fill_between(sparseness_short, p_low_short*100, p_high_short*100, color='orange', alpha = .2)
        _ = ax.plot(sparseness_short, p_mean_short*100, 'o', label='shortening')

        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html
        # opts_curvefit = {'p0': [-2, -.3],'maxfev':100000}
        opts_curvefit = {'p0': [-1., .5], 'maxfev':100000}
        # opts_curvefit = {'p0': [1., -2, -1.], 'maxfev':100000}
        params, _ = curve_fit(sigmoid, xdata=np.log10(sparseness_K), ydata=p_mean_K, **opts_curvefit)
        # params, _ = curve_fit(sigmoid, xdata=sparseness_K, ydata=p_mean, **opts_curvefit)
        x = np.logspace(-4, 0, 100)
        _ = ax.plot([], [], color = 'k', label='sigmoid fit')
        _ = ax.plot(x, 100*sigmoid(np.log10(x), *params), color = '#1f77b4')
        # _ = ax.plot(x, 100*sigmoid((x), *params), color = '#1f77b4')
        print('Fit for pruning: h, slope = ', params, ', cutoff gain', 10**-params[0])
        
        params, _ = curve_fit(sigmoid, xdata=np.log10(sparseness_short), ydata=p_mean_short, **opts_curvefit)
        # params, _ = curve_fit(sigmoid, xdata=(sparseness_short), ydata=p_mean_short, **opts_curvefit)
        _ = ax.plot(x, 100*sigmoid(np.log10(x), *params), color = 'orange')
        # _ = ax.plot(x, 100*sigmoid((x), *params), color = 'orange')
        print('Fit for shortening: h, slope = ', params, ', cutoff gain', 10**-params[0])
        
        _ = ax.hlines(1/N_PGs*100,sparseness_K[0], sparseness_K[-1], linestyles='dashed', color='k', label='chance level')
        
        ax.set_xscale('log')
        ax.axis([sparseness_K[0], sparseness_K[-1], 0, 100]);
        ax.set_xlabel('Number of computations as a percentage of the total \namount of weights in the 3D kernel (in %)', fontsize=16)
        ax.set_ylabel('Accuracy (in %)', fontsize=16)
        plt.legend(fontsize=16, loc='upper left');

In [None]:
if do_figures:
    if DEBUG<8: fig.savefig(os.path.join(figures, datetag + '_accuracy.png'), bbox_inches='tight')
    if figpath != None: fig.savefig(os.path.join(figpath, 'quant_accuracy.pdf'), bbox_inches='tight')

#### todo: Testing with different retinal preprocessing

* gamma
* center surround

#### todo: Testing with less contrast

* change `selectivity` when generating events
* change `slope` when generating events

### Testing on Motion Clouds

We use a set of synthetic stimuli, [Motion Clouds](https://neuralensemble.github.io/MotionClouds/) which are natural-like random textures for which we can control for velocity, among other parameters. We choose N_PGs for the different possible velocities that will correspond to the labels for the supervised training of our model.


In [None]:
# random choice of theta
theta = None #np.pi * np.random.rand()
B_theta = 15.*np.pi/180
do_aperture = False

# default parameters, see
# https://github.com/NeuralEnsemble/MotionClouds/blob/master/MotionClouds/MotionClouds.py#L46
# spatial frequency of the clouds
B_sf = .15
sf_0 = .15
B_V = .05
# pi_sf_0 = .20
# pi_B_sf = .20
# pi_B_theta = 1.0
# pi_B_V = .20
# slope for the contrast of luminance within the stimuli
slope = 1.

In [None]:
import MotionClouds as mc

def make_cloud(N_X, N_Y, N_T, V_X=V_speed_0, V_Y=0., B_V=B_V, 
               sf_0=sf_0, B_sf=B_sf, theta=theta, B_theta=B_theta, seed=seed):
    fx, fy, ft = mc.get_grids(N_X, N_Y, N_T)
    env = mc.envelope_gabor(fx, fy, ft, V_X=V_X, V_Y=V_Y, B_V=B_V,
                            sf_0=sf_0, B_sf=B_sf, theta=theta, B_theta=B_theta)
    # normalize movie between a minimum of -1 and a maximum of +1
    z = mc.random_cloud(env, seed=seed)
    if z.max() == 0.:
        print(f'sf_0={sf_0}, B_V={B_V}, B_sf={B_sf}, theta={theta}, B_theta={B_theta}, seed={seed}')
    z = 2 * mc.rectif(z) - 1
    return z

In [None]:
def make_cloud_movie(N_X, N_Y, N_T, V_phi, V_speed, theta, B_V=B_V, B_theta=B_theta, 
                     sf_0=sf_0, B_sf=B_sf, slope=slope, seed=seed):

    V_X = V_speed*np.cos(V_phi)
    V_Y = V_speed*np.sin(V_phi)

    z = make_cloud(N_X, N_Y, N_T, V_X=V_X, V_Y=V_Y, B_V=B_V, 
                   sf_0=sf_0, B_sf=B_sf, 
                   theta=theta, B_theta=B_theta, seed=seed)

    movie = np.tanh(slope*z)
    return movie

def make_mcmovie_events(w, B_V=B_V, B_theta=B_theta, sf_0=sf_0, B_sf=B_sf, selectivity=selectivity, slope=slope, 
                        do_show=False, timestamp=timestamp, fname=None):

    movie = make_cloud_movie(N_X, N_Y, N_T, V_phi=w.V_phis[0], V_speed=w.V_speeds[0], theta=w.theta, B_V=B_V, B_theta=B_theta, 
                             sf_0=sf_0, B_sf=B_sf, slope=slope, seed=seed)

    movie_spike = make_events(movie, selectivity=selectivity)

    #print(movie.min(), movie.max(), np.abs(movie_spike).sum())

    It_bool = np.zeros((N_X, N_Y, N_T, 2))
    It_bool[movie_spike==-1, 0] = 1
    It_bool[movie_spike==1, 1] = 1

    It_bool = torch.from_numpy(It_bool)
    It_bool = torch.swapaxes(It_bool, 2, 1)
    It_bool = torch.swapaxes(It_bool, 3, 2)
    It_bool = torch.swapaxes(It_bool, 2, 0)

    if do_show:
        fig, axs = plot_cloud_events(timestamp, movie, It_bool, w.V_speeds, w.V_phis, type='Textured')
        return fig, axs
    elif not fname is None:
        from matplotlib.animation import FFMpegWriter
        writer = FFMpegWriter(fps=15)
        fig, axs = plot_cloud_events(0, movie, It_bool, w.V_speeds, w.V_phis, N_show=1, type='Textured')
        with writer.saving(fig, fname, N_T):
            for timestamp_ in range(1, N_T):
                for ax in axs: ax.cla()
                fig, axs = plot_cloud_events(timestamp_, movie, It_bool, w.V_speeds, w.V_phis, N_show=1, fig=fig, axs=axs, type='Textured')
                writer.grab_frame()
    else:
        return It_bool.float()

In [None]:
class MCWorld:
    def __init__(self, seed=seed, 
                 sf_0=sf_0,
                #  pi_sf_0=pi_sf_0,
                 B_V=B_V,
                 B_sf=B_sf,
                 #pi_B_sf=pi_B_sf,
                 B_theta=B_theta,
                 selectivity=selectivity, slope=slope,
                 N_V_phi=N_V_phi, N_V_speed=N_V_speed,
                 do_aperture=do_aperture, label=None
                ):
        self.seed = seed
        self.N_V_speed = N_V_speed
        self.N_V_phi = N_V_phi
        self.N_PGs = N_V_speed * N_V_phi
        # self.pi_B_sf = pi_B_sf
        # self.pi_sf_0 = pi_sf_0
        self.args0 = dict(B_V=B_V, sf_0=sf_0, B_sf=B_sf, B_theta=B_theta, 
                          selectivity = selectivity, slope = slope)
        self.do_aperture = do_aperture
        self.draw(seed, label)

    def draw(self, seed, label=None):
        self.seed = seed
        np.random.seed(seed=seed)

        if label is None:
            self.label = np.random.randint(self.N_PGs, size=1)*np.ones(N_T, dtype=int)
        else:
            self.label = np.array([label], dtype=int)

        self.V_phis, self.V_speeds = logpol_speed(self.label*np.ones(N_T, dtype=int))
        
        B_V = self.args0['B_V']

        if self.do_aperture:
            self.theta = self.V_phis[0]
        else:
            self.theta = np.pi * np.random.rand()

        sf_0 = self.args0['sf_0'] # * np.random.lognormal(0, self.pi_sf_0)
        B_sf = self.args0['B_sf'] #* np.random.lognormal(0, self.pi_B_sf)

        if self.args0['B_theta'] is None: 
            B_theta = 2 * np.pi / 2**np.random.randint(0, 7)
        else:
            B_theta = self.args0['B_theta']
        
        B_V = self.args0['B_V']
        selectivity = self.args0['selectivity']
        slope = self.args0['slope']

        self.args = dict(selectivity=selectivity, slope=slope, 
                         B_V=B_V, sf_0=sf_0, B_sf=B_sf, B_theta=B_theta, 
                        )

    def get_input(self, do_cache=False, do_show=False):
        if do_cache:
            tensor_fname = cachepath(data_cache, datetag, DEBUG) + f'_It_bool_MC_seed={self.seed}.pt'
            if os.path.isfile(tensor_fname):
                It_bool = torch.load(tensor_fname)
            else:
                It_bool = self.get_input(do_cache=False, do_show=do_show)
                torch.save(It_bool, tensor_fname)
            return It_bool
        else:
            return make_mcmovie_events(self, do_show=do_show, **self.args)

In [None]:
w = MCWorld(seed=seed+997799, label=label_down_by_one)
w.args, w.label

In [None]:
w = MCWorld(seed=seed+997799)
w.args, w.label

In [None]:
w.V_phis

In [None]:

# movie = It_bool = np.zeros((N_X, N_Y, 0))
movie = make_mcmovie_events(w, B_V=B_V, B_theta=B_theta, selectivity=selectivity, 
                            sf_0=sf_0, B_sf=B_sf, slope=slope)
movie.shape

In [None]:
w.get_input().shape

In [None]:
w = MCWorld(seed=seed_show, label=label_down_by_one, B_theta = 0.1)
w.theta = 0
fig, axs = w.get_input(do_show=True, do_cache=False)

In [None]:
if do_figures:
    N_show = 5

    for i_show in range(N_show):
        w = MCWorld(seed=seed_show+i_show, do_aperture=True)

        fname_mp4 = os.path.join(figures, datetag + '_input-MC_' + str(i_show) + '.mp4')
        if not os.path.isfile(fname_mp4):
            make_mcmovie_events(w, B_V=B_V, B_theta=B_theta, selectivity=selectivity, 
                            sf_0=sf_0, B_sf=B_sf, slope=slope, do_show=False, fname=fname_mp4)

In [None]:
if do_figures:
    N_show = 1

    for i_show in range(N_show):
        w = MCWorld(seed=seed_show+i_show, label=label_down_by_one)
        fig, axs = w.get_input(do_show=True, do_cache=False)
        plt.show()


##### loading the model

We load the default model and test it on different motion clouds textures:

In [None]:
# w = NatWorld(seed=seed)
model, df_train = learn_model(path=cachepath(data_cache, datetag, DEBUG))

##### test with different `p_B_test_MC` values

In [None]:
if done_learning:
    for i_p_B, p_B_ in enumerate(p_Bs):
        path_p_B_test_MC = cachepath(data_cache, datetag, DEBUG) + f'_p_B_test_MC={p_B_:.3e}'
        w = MCWorld(seed=seed_test)
        correct = test_model(model, path=path_p_B_test_MC, w=w, p_B=p_B_)
        print(f'For p_B_test_MC={p_B_:.3e}, accuracy = {np.mean(correct)*100:.1f}%')                

In [None]:
if done_learning:
    for i_p_B, p_B_ in enumerate(p_B_test_MC*np.logspace(-1, 1, N_scan, base=2)):
        path_p_B_test_MC = cachepath(data_cache, datetag, DEBUG) + f'_p_B_test_MC={p_B_:.3e}'
        w = MCWorld(seed=seed_test)
        correct = test_model(model, path=path_p_B_test_MC, w=w, p_B=p_B_)
        print(f'For p_B_test_MC={p_B_:.3e}, accuracy = {np.mean(correct)*100:.1f}%')                

##### testing different sf_0

In [None]:
w = MCWorld(seed=seed)
w.args

In [None]:
sf_0s = sf_0 * np.logspace(-.7, 1.3, N_scan, base=8)
w = MCWorld(sf_0=sf_0s[0])
w.args

In [None]:
if done_learning:
    results_sf_0s = np.zeros_like(sf_0s)
    for i_sf_0, sf_0_ in enumerate(sf_0s):
        w = MCWorld(sf_0=sf_0_, B_sf=sf_0_)
        path_results = cachepath(data_cache, datetag, DEBUG) + f'_MC-sf_0={sf_0_/np.pi*180:.1f}'
        correct = test_model(model, path=path_results, w=w, do_cache=False, p_B=p_B_test_MC)
        results_sf_0s[i_sf_0] = np.mean(correct)
        print(f'sf_0={sf_0_/np.pi*180:.1f}°: {np.mean(correct)*100:.1f}%')

##### testing with or without independence between orientation and direction

In [None]:
if done_learning:
    results_sf_0s_aperture = np.zeros_like(sf_0s)
    for i_sf_0, sf_0_ in enumerate(sf_0s):
        path_results = cachepath(data_cache, datetag, DEBUG) + f'_MC-sf_0={sf_0_/np.pi*180:.1f}-do_aperture={not do_aperture}'
        w = MCWorld(sf_0=sf_0_, B_sf=sf_0_, do_aperture=not do_aperture)
        correct = test_model(model, path=path_results, w=w, do_cache=False, p_B=p_B_test_MC)
        results_sf_0s_aperture[i_sf_0] = np.mean(correct)
        print(f'sf_0={sf_0_/np.pi*180:.1f}°: {np.mean(correct)*100:.1f}%')

##### testing different B_thetas

In [None]:
w = MCWorld(seed=seed)
w.args

In [None]:
B_thetas = B_theta * np.logspace(-1.5, .5, N_scan, base=4)
w = MCWorld(B_theta=B_thetas[0])
w.args, w.theta

In [None]:
if done_learning:
    results_B_thetas = np.zeros_like(B_thetas)
    for i_B_theta, B_theta_ in enumerate(B_thetas):
        w = MCWorld(B_theta=B_theta_)
        path_results = cachepath(data_cache, datetag, DEBUG) + f'_MC-B_theta={B_theta_/np.pi*180:.1f}'
        correct = test_model(model, path=path_results, w=w, do_cache=False, p_B=p_B_test_MC)
        results_B_thetas[i_B_theta] = np.mean(correct)
        print(f'B_theta={B_theta_/np.pi*180:.1f}°: {np.mean(correct)*100:.1f}%')

##### testing with or without independence between orientation and direction

In [None]:
if done_learning:
    results_B_thetas_aperture = np.zeros_like(B_thetas)
    for i_B_theta, B_theta_ in enumerate(B_thetas):
        path_results = cachepath(data_cache, datetag, DEBUG) + f'_MC-B_theta={B_theta_/np.pi*180:.1f}-do_aperture={not do_aperture}'
        w = MCWorld(B_theta=B_theta_, do_aperture=not do_aperture)
        correct = test_model(model, path=path_results, w=w, do_cache=False, p_B=p_B_test_MC)
        results_B_thetas_aperture[i_B_theta] = np.mean(correct)
        print(f'B_theta={B_theta_/np.pi*180:.1f}°: {np.mean(correct)*100:.1f}%')

##### testing with different B_sf

In [None]:
B_sfs = B_sf * np.logspace(-1, 1, N_scan, base=10)
w = MCWorld(B_sf=B_sfs[0])
w.args

In [None]:
if done_learning:
    results_B_sfs = np.zeros_like(B_sfs)
    for i_B_sf, B_sf_ in enumerate(B_sfs):
        path_results = cachepath(data_cache, datetag, DEBUG) + f'_MC-B_sf={B_sf_:.3e}'
        w = MCWorld(B_sf=B_sf_)
        correct = test_model(model, path=path_results, w=w, do_cache=False, p_B=p_B_test_MC)
        results_B_sfs[i_B_sf] = np.mean(correct)
        print(f'B_sf={B_sf_:.3e}: {np.mean(correct)*100:.1f}%')

In [None]:
if done_learning:
    results_B_sfs_aperture = np.zeros_like(B_sfs)
    for i_B_sf, B_sf_ in enumerate(B_sfs):
        path_results = cachepath(data_cache, datetag, DEBUG) + f'_MC-B_sf={B_sf_:.3e}-do_aperture={not do_aperture}'
        w = MCWorld(B_sf=B_sf_, do_aperture=not do_aperture)
        correct = test_model(model, path=path_results, w=w, do_cache=False, p_B=p_B_test_MC)
        results_B_sfs_aperture[i_B_sf] = np.mean(correct)
        print(f'B_sf={B_sf_:.3e}: {np.mean(correct)*100:.1f}%')

##### testing with different B_V

In [None]:
B_Vs = B_V * np.logspace(-1, 1, N_scan, base=10)
w = MCWorld(B_V=B_Vs[0])
w.args

In [None]:
if done_learning:
    results_B_Vs = np.zeros_like(B_Vs)
    for i_B_V, B_V_ in enumerate(B_Vs):
        path_results = cachepath(data_cache, datetag, DEBUG) + f'_MC-B_V={B_V_:.3e}'
        w = MCWorld(B_V=B_V_)
        correct = test_model(model, path=path_results, w=w, do_cache=False, p_B=p_B_test_MC)
        results_B_Vs[i_B_V] = np.mean(correct)
        print(f'B_V={B_V_:.3e}: {np.mean(correct)*100:.1f}%')

In [None]:
if done_learning:
    results_B_Vs_aperture = np.zeros_like(B_Vs)
    for i_B_V, B_V_ in enumerate(B_Vs):
        path_results = cachepath(data_cache, datetag, DEBUG) + f'_MC-B_V={B_V_:.3e}-do_aperture={not do_aperture}'
        w = MCWorld(B_V=B_V_, do_aperture=not do_aperture)
        correct = test_model(model, path=path_results, w=w, do_cache=False, p_B=p_B_test_MC)
        results_B_Vs_aperture[i_B_V] = np.mean(correct)
        print(f'B_V={B_V_:.3e}: {np.mean(correct)*100:.1f}%')

##### synthesis for MCs

In [None]:
if done_learning:
    if do_figures:
        acc_max = 100
        fig, axs = plt.subplots(1, 4, figsize=(fig_width, fig_width/phi**2))
        w = MCWorld()

        ax = axs[0]
        _ = ax.hlines(1/N_PGs*100, sf_0s[0], sf_0s[-1], linestyles='dashed', color='k', label='chance level')
        _ = ax.plot(sf_0s, results_sf_0s*100, label='aperture')
        _ = ax.plot(sf_0s, results_sf_0s_aperture*100, label='perpend')

        for i_sf_0, sf_0_ in enumerate(sf_0s):
            movie = make_cloud_movie(N_X, N_Y, N_T, V_phi=w.V_phis[0], V_speed=w.V_speeds[0], theta=w.theta, B_V=B_V, B_theta=B_theta, sf_0=sf_0_, B_sf=B_sf, slope=slope, seed=seed)
            axins = ax.inset_axes([i_sf_0/N_scan, 1-1/N_scan/1.25, 1/N_scan, 1/N_scan])
            axins.imshow(movie[:, :, 0], cmap='gray', vmin=movie[:, :, 0].min(), vmax=movie[:, :, 0].max())
            axins.set_xticks([])
            axins.set_yticks([])

        ax.set_xscale('log')
        ax.axis([sf_0s[0], sf_0s[-1], 0, acc_max]);
        ax.set_xlabel('Spatial frequency', fontsize=16)
        ax.set_ylabel('Accuracy (in %)', fontsize=16)

        ax = axs[1]
        _ = ax.hlines(1/N_PGs*100, B_sfs[0], B_sfs[-1], linestyles='dashed', color='k', label='chance level')
        _ = ax.plot(B_sfs, results_B_sfs*100, label='aperture')
        _ = ax.plot(B_sfs, results_B_sfs_aperture*100, label='perpend')

        for i_B_sf, B_sf_ in enumerate(B_sfs):
            movie = make_cloud_movie(N_X, N_Y, N_T, V_phi=w.V_phis[0], V_speed=w.V_speeds[0], theta=w.theta, B_V=B_V, B_theta=B_theta, sf_0=sf_0, B_sf=B_sf_, slope=slope, seed=seed)
            axins = ax.inset_axes([i_B_sf/N_scan, 1-1/N_scan/1.25, 1/N_scan, 1/N_scan])
            axins.imshow(movie[:, :, 0], cmap='gray', vmin=movie[:, :, 0].min(), vmax=movie[:, :, 0].max())
            axins.set_xticks([])
            axins.set_yticks([])

        ax.set_xscale('log')
        ax.axis([B_sfs[0], B_sfs[-1], 0, acc_max]);
        ax.set_xlabel('Spatial freq bandwidth', fontsize=16)
        ax.legend(fontsize=14, loc='right', bbox_to_anchor=(0.75, 0., 0.25, 0.25), frameon=False)

        ax = axs[2]
        _ = ax.hlines(1/N_PGs*100, B_thetas[0], B_thetas[-1], linestyles='dashed', color='k', label='chance level')
        _ = ax.plot(B_thetas, results_B_thetas*100, label='no aperture')
        _ = ax.plot(B_thetas, results_B_thetas_aperture*100, label='perpend')


        for i_B_theta, B_theta_ in enumerate(B_thetas):
            movie = make_cloud_movie(N_X, N_Y, N_T, V_phi=w.V_phis[0], V_speed=w.V_speeds[0], theta=w.theta, B_V=B_V, B_theta=B_theta_, sf_0=sf_0, B_sf=B_sf, slope=slope, seed=seed)
            axins = ax.inset_axes([i_B_theta/N_scan, 1-1/N_scan/1.25, 1/N_scan, 1/N_scan])
            axins.imshow(movie[:, :, 0], cmap='gray', vmin=movie[:, :, 0].min(), vmax=movie[:, :, 0].max())
            axins.set_xticks([])
            axins.set_yticks([])

        ax.set_xscale('log')
        ax.axis([B_thetas[0], B_thetas[-1], 0, acc_max]);
        ax.set_xlabel('Orientation bandwidth', fontsize=16)

        ax = axs[3]
        _ = ax.hlines(1/N_PGs*100, B_Vs[0], B_Vs[-1], linestyles='dashed', color='k', label='chance level')
        _ = ax.plot(B_Vs, results_B_Vs*100, label='aperture')
        _ = ax.plot(B_Vs, results_B_Vs_aperture*100, label='perpend')

        for i_B_V, B_V_ in enumerate(B_Vs):
            movie = make_cloud_movie(N_X, N_Y, N_T, V_phi=w.V_phis[0], V_speed=w.V_speeds[0], theta=w.theta, B_V=B_V_, B_theta=B_theta, sf_0=sf_0, B_sf=B_sf, slope=slope, seed=seed)
            axins = ax.inset_axes([i_B_V/N_scan, 1-1/N_scan/1.25, 1/N_scan, 1/N_scan])
            axins.imshow(movie[:, 0, :N_X], cmap='gray', vmin=movie[:, 0, :N_X].min(), vmax=movie[:, 0, :N_X].max())
            axins.set_xticks([])
            axins.set_yticks([])

        ax.set_xscale('log')
        ax.axis([B_Vs[0], B_Vs[-1], 0, acc_max]);
        ax.set_xlabel('Speed precision ', fontsize=16)

        for ax in axs:
           ax.spines[['right', 'top']].set_visible(False);

In [None]:
if done_learning:
    if do_figures:
        if DEBUG<8: fig.savefig(os.path.join(figures, datetag + '_motion_clouds.png'), bbox_inches='tight')
        if figpath != None: fig.savefig(os.path.join(figpath, 'motion_clouds.pdf'), bbox_inches='tight')

In [None]:
movie.shape

### Validation of the learning

#### scan `lr` values

In [None]:
lrs_, loss, correct = [], [], []
for i_lr, lr_ in enumerate(lrs):
    w = NatWorld(seed=seed_train)
    path = cachepath(data_cache, datetag, DEBUG) + f'_lr={lr_:.3e}'
    model, df_train = learn_model(path=path, w=w, lr=lr_, N_train=N_train_scan, N_epochs=N_epochs_scan)

    path_results = cachepath(data_cache, datetag, DEBUG) + f'_lr={lr_:.3e}_results'
    if not df_train is None:
        lrs_.append(lr_)
        loss_lr = df_train['loss'][-loss_samples:].mean()
        print(f"lr={lr_:.3e} - loss {loss_lr:.3e}")
        loss.append(loss_lr)
        if not(os.path.isfile(path_results + '.lock')):
            w = NatWorld(seed=seed_test, block_length=np.inf)
            correct_ = test_model(model, path=path_results, w=w).mean()
            correct.append(correct_)

    if do_figures:
        learned_kernels = model.conv_layer.weight.data.cpu().numpy()
        fig, axs = plot_kernels(learned_kernels, K_min=K_min, K_max=K_min+3)
        plt.show()


In [None]:
print('Summary:')
for lr_ in lrs:
    path_results = cachepath(data_cache, datetag, DEBUG) + f'_lr={lr_:.3e}_results'
    correct = test_model(None, path=path_results)
    print(f'For lr={lr_:.3e}, accuracy = {np.mean(correct)*100:.1f}%')    

In [None]:
if do_figures:
    fig, ax = plt.subplots(figsize=(fig_width, fig_width/phi/2))
    ax.plot(lrs_, loss, lw=2, marker='.', markersize=10)
    ax.spines['left'].set_position(('axes', -0.01))
    ax.grid(which='both')
    for side in ['top', 'right'] :ax.spines[side].set_visible(False)
    ax.set_xlabel('learning rate')
    ax.set_ylabel('final loss')
    ax.set_xscale('log');
    #ax.set_ylim(0.);

#### check alternatives : use a bias in model?

In [None]:
do_bias

In [None]:
w = NatWorld(seed=seed_train)
path = cachepath(data_cache, datetag, DEBUG) + f'_do_bias={not do_bias}'
model, df_train = learn_model(path=path, w=w, do_bias=not do_bias, N_train=N_train_scan, N_epochs=N_epochs_scan)

In [None]:
if do_figures:
    learned_kernels = model.conv_layer.weight.data.cpu().numpy()
    fig, axs = plot_kernels(learned_kernels);

In [None]:
w = NatWorld(seed=seed_test, block_length=np.inf)
path_results = cachepath(data_cache, datetag, DEBUG) + f'_do_bias={not do_bias}'
correct = test_model(model, path=path_results, w=w)
print(f'Accuracy for detection: {np.mean(correct)*100:.1f}%')

#### check alternatives : use mask?

In [None]:
do_mask

In [None]:
w = NatWorld(seed=seed_train)
path = cachepath(data_cache, datetag, DEBUG) + f'_do_mask={not do_mask}'

model, df_train = learn_model(path=path, w=w, do_mask=not do_mask, N_train=N_train_scan, N_epochs=N_epochs_scan)

In [None]:
if do_figures:
    learned_kernels = model.conv_layer.weight.data.cpu().numpy()
    fig, axs = plot_kernels(learned_kernels);

In [None]:
w = NatWorld(seed=seed_test, block_length=np.inf)
path_results = cachepath(data_cache, datetag, DEBUG) + f'_do_mask={not do_mask}'
correct = test_model(model, path=path_results, w=w)
print(f'Accuracy for detection: {np.mean(correct)*100:.1f}%')

 #### learning different kernels sizes

In [None]:
for kernel_size_ in kernel_sizes:
    w = NatWorld(seed=seed_train)
    path = cachepath(data_cache, datetag, DEBUG) + f'_kernel_size={kernel_size_}'
    model, df_train = learn_model(path=path, w=w,  kernel_size=kernel_size_, N_train=N_train_scan, N_epochs=N_epochs_scan)
    if not df_train is None:
        print(f"{path}|Mean final loss {df_train['loss'][-loss_samples:].mean():.3e}")
        if do_figures:
            learned_kernels = model.conv_layer.weight.data.cpu().numpy()
            plot_kernels(learned_kernels)
            plt.show()
        w = NatWorld(seed=seed_test, block_length=np.inf)
        correct = test_model(model, path=path, w=w)
        print(f'Accuracy for detection: {np.mean(correct)*100:.1f}%')

In [None]:
print('Summary:')
for kernel_size_ in kernel_sizes:
    path = cachepath(data_cache, datetag, DEBUG) + f'_kernel_size={kernel_size_}'
    if not df_train is None:
        correct = test_model(None, path=path)
        print(f'For kernel_size={kernel_size_}, accuracy = {np.mean(correct)*100:.1f}%')    

#### learning different kernel's temporal size

In [None]:
results_k_Ts = np.zeros_like(k_Ts)
for i_k_T, k_T_ in enumerate(k_Ts):
    w = NatWorld(seed=seed_train)
    kernel_size_ = (k_T_, kernel_size[1], kernel_size[2])
    path = cachepath(data_cache, datetag, DEBUG) + f'_k_T={k_T_}'
    model, df_train = learn_model(path=path, w=w, kernel_size=kernel_size_, N_train=N_train_scan, N_epochs=N_epochs_scan)
    if not df_train is None:
        print(f"{path}|Mean final loss {df_train['loss'][-loss_samples:].mean():.3e}")
        if do_figures:
            learned_kernels = model.conv_layer.weight.data.cpu().numpy()
            plot_kernels(learned_kernels)
            plt.show()
        w = NatWorld(seed=seed_test, block_length=np.inf)
        correct = test_model(model, path=path, w=w)
        print(f'Accuracy for detection: {np.mean(correct)*100:.1f}%')
        results_k_Ts[i_k_T] = np.mean(correct)

In [None]:

print('Summary:')
for k_T_ in k_Ts:
    path = cachepath(data_cache, datetag, DEBUG) + f'_k_T={k_T_}'
    if not df_train is None:
        correct = test_model(None, path=path)
        print(f'For k_T={k_T_}, accuracy = {np.mean(correct)*100:.1f}%')

#### learning with different weight_init_center values

In [None]:
results_weight_init_centers = np.zeros_like(weight_init_centers)
for i_weight_init_center, weight_init_center_ in enumerate(weight_init_centers):
    path = cachepath(data_cache, datetag, DEBUG) + f'_weight_init_center={weight_init_center_:.3e}'
    w = NatWorld(seed=seed_train)
    model, df_train = learn_model(path=path, w=w, weight_init_center=weight_init_center_, N_train=N_train_scan, N_epochs=N_epochs_scan)
    if not df_train is None:
        print(f"{path}|Mean final loss {df_train['loss'][-loss_samples:].mean():.3e}")
        if do_figures:
            learned_kernels = model.conv_layer.weight.data.cpu().numpy()
            plot_kernels(learned_kernels)
            plt.show()
        w = NatWorld(seed=seed_test, block_length=np.inf)
        correct = test_model(model, path=path, w=w)
        print(f'Accuracy for detection: {np.mean(correct)*100:.1f}%')
        results_weight_init_centers[i_weight_init_center] = np.mean(correct)
print('Summary:')
for weight_init_center_ in weight_init_centers:
    path = cachepath(data_cache, datetag, DEBUG) + f'_weight_init_center={weight_init_center_:.3e}'
    if not df_train is None:
        correct = test_model(None, path=path)
        print(f'For weight_init_center={weight_init_center_:.3e}, accuracy = {np.mean(correct)*100:.1f}%')                

#### learning with different weight_init values

In [None]:
results_weight_inits = np.zeros_like(weight_inits)
for i_weight_init, weight_init_ in enumerate(weight_inits):
    path = cachepath(data_cache, datetag, DEBUG) + f'_weight_init={weight_init_:.3e}'
    w = NatWorld(seed=seed_train)
    model, df_train = learn_model(path=path, w=w, weight_init=weight_init_, N_train=N_train_scan, N_epochs=N_epochs_scan)
    if not df_train is None:
        print(f"{path}|Mean final loss {df_train['loss'][-loss_samples:].mean():.3e}")
        if do_figures:
            learned_kernels = model.conv_layer.weight.data.cpu().numpy()
            plot_kernels(learned_kernels)
            plt.show()
        w = NatWorld(seed=seed_test, block_length=np.inf)
        correct = test_model(model, path=path, w=w)
        print(f'Accuracy for detection: {np.mean(correct)*100:.1f}%')
        results_weight_inits[i_weight_init] = np.mean(correct)
print('Summary:')
for weight_init_ in weight_inits:
    path = cachepath(data_cache, datetag, DEBUG) + f'_weight_init={weight_init_:.3e}'
    if not df_train is None:
        correct = test_model(None, path=path)
        print(f'For weight_init={weight_init_:.3e}, accuracy = {np.mean(correct)*100:.1f}%')                

#### learning with different p_B values

In [None]:
results_p_Bs = np.zeros_like(p_Bs)
for i_p_B, p_B_ in enumerate(p_Bs):
    path = cachepath(data_cache, datetag, DEBUG) + f'_p_B={p_B_:.3e}'
    w = NatWorld(seed=seed_train)
    model, df_train = learn_model(path=path, w=w, p_B=p_B_, N_train=N_train_scan, N_epochs=N_epochs_scan)
    if not df_train is None:
        print(f"{path}|Mean final loss {df_train['loss'][-loss_samples:].mean():.3e}")
        if do_figures:
            learned_kernels = model.conv_layer.weight.data.cpu().numpy()
            plot_kernels(learned_kernels)
            plt.show()
        w = NatWorld(seed=seed_test, block_length=np.inf)
        correct = test_model(model, path=path, w=w)
        print(f'Accuracy for detection: {np.mean(correct)*100:.1f}%')
        results_p_Bs[i_p_B] = np.mean(correct)
print('Summary:')
for p_B_ in p_Bs:
    path = cachepath(data_cache, datetag, DEBUG) + f'_p_B={p_B_:.3e}'
    if not df_train is None:
        correct = test_model(None, path=path)
        print(f'For p_B={p_B_:.3e}, accuracy = {np.mean(correct)*100:.1f}%')                

## Voilà!