The goal of this notebook is to build a simple, end-to-end prototype for how we will be modeling the V1-MT motion perception pipeline:
- Generate a drifting sinusodial stimulus
- Compute local motion energy signals with a bank of Gabor Filters to simulate V1 complex cell outputs
- Generate a synthentic spike train from a Poisson process (this will be replaced with real MT data in the actual model)
- Fit a Poisson GLM, using the motion energy features, to the single, synthetic MT neuron
- Show that we can recover the synthetic spike train

In [117]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import fftconvolve
from numpy.typing import NDArray
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
%matplotlib inline



# Drifting sinusodial grating

In [110]:
def sinusoidal_2d(
    x: NDArray[np.floating],
    y: NDArray[np.floating],
    theta: float,
    A: float,
    f: float,
    phase: float,
) -> NDArray[np.floating]:
    # TODO: What is a typical domain used in real world stimuli?
    # TODO: What is an intuitive way to understand when frequency, time buckets and time constants
    #       in this context? 
    X, Y = np.meshgrid(x, y)
    X_rot = X * np.cos(theta) + Y * np.sin(theta) 
    return A * np.cos((2 * np.pi * f * X_rot) + phase)

def render_sinusodial_2d(sinusodial: NDArray[np.floating], title: str):
    plt.figure()
    plt.imshow(sinusodial, cmap='gray', origin='lower')
    plt.colorbar()
    plt.title(title)
    plt.xlabel('X')
    plt.ylabel('Intensity')
    plt.show()

def test_sinusoidal_2d():
    L = 1 
    x, y = np.linspace(0, L, 100), np.linspace(0, L, 100)
    theta = np.pi / 4
    A = 1.0
    f = 4
    phase = 0.0
    sinusodial = sinusoidal_2d(x, y, theta, A, f, phase)
    render_sinusodial_2d(sinusodial, f'A basic 2D sinusodial oriented at {theta}')

def sinusoidal_3d(
    x: NDArray[np.floating],
    y: NDArray[np.floating],
    theta: float,
    A: float,
    spatial_f: float,
    temporal_f: float,
    frames: int,
) -> NDArray[np.floating]:
    # TODO: Internalize a deeper mental model for the relationship b/w fps, # of frames, temporal frequency
    dt = 1 / frames # 1 second to go through all the frames (frames per second)
    t = np.arange(0, frames, 1) * dt
    # TODO: I'd like to better internalize a mental model for how to use meshgrid with the 'indexing'
    # parameter.
    T, X, Y = np.meshgrid(t, x, y, indexing='ij')
    X_rot = X * np.cos(theta) + Y * np.sin(theta)
    sinusoidal = A * np.cos(2 * np.pi * spatial_f * X_rot - (2 * np.pi * temporal_f * T))
    return sinusoidal

def animate_sinusodial_3d(drifting_sinusodial: NDArray[np.floating], frames: int, title: str):
    # TODO: There's some detail that I think I'm missing around how to use frames and frequency
    # correctly
    fig, ax = plt.subplots()
    im = ax.imshow(drifting_sinusodial[40], cmap='gray', origin='lower')
    fig.colorbar(im)
    plt.title(title)
    plt.xlabel('X')
    plt.ylabel('Intensity')

    def update(frame):
        next_frame = drifting_sinusodial[frame]
        im.set_array(next_frame)
        return (im,)

    ani = FuncAnimation(
        fig,
        update,
        frames = frames,
        interval = 100,
        blit=True, # redraw changed parts
    )
    plt.close()
    return HTML(ani.to_jshtml())

def render_first_n_frames(drifting_sinusodial, n):
    print(f"sinusodial shape: {drifting_sinusodial.shape}")
    plots_per_row = 4
    num_rows = n // plots_per_row
    last_row = n % plots_per_row
    if last_row:
        num_rows += 1
    fig, ax = plt.subplots(num_rows, plots_per_row)
    t = 0
    step = 5
    for i in range(num_rows):
        for j in range(plots_per_row):
            sinusodial = drifting_sinusodial[t]
            t = min(t+ step, drifting_sinusodial.shape[0] - 1)
            ax[i][j].imshow(sinusodial, cmap='gray', origin='lower')
    plt.show()

def test_sinusoidal_3d():
    L = 2 
    x, y = np.linspace(0, L, 200), np.linspace(0, L, 200)
    theta = np.pi / 4
    A = 1.0
    f = 2
    phase = 0.0
    frames = 50
    temporal_f = 3
    sinusodial = sinusoidal_3d(x, y, theta, A, f, temporal_f, frames)
    # render_first_n_frames(sinusodial, n=10)
    return animate_sinusodial_3d(sinusodial, frames, f'A drifting sinusodial oriented at {theta}') 

test_sinusoidal_3d()

# Compute local motion energy with a bank of filters

In [120]:

def create_temporal_gabor_filter(phase: float, frequency: float):
    pass

def create_spatial_gabor_filter(theta: float, phase: float, frequency: float, pixel_pitch: float = 0.02):
    # TODO: How big is this filter? What are the dimensions? Where is the origin of the filter? For now, let's pass in dimensions, and assume the origin is the center
    # TODO: Are the Gabor filters anchored to a point in space? I believe so - the Gaussian envelope should do that

    # TODO: Understand how bandwidth plays a role in this sigma calculation. I'm currently using a rule of thumb
    sigma = 0.5 / frequency
    # 3*sigma will cover 95%+ of the gaussian's spread

    # TODO: Better understand the relationship between pixel_pitch, spatial frequency and sigma - how do they work together?
    radius_px    = int(np.ceil(3 * sigma / pixel_pitch))
    coords       = np.arange(-radius_px, radius_px + 1) * pixel_pitch
    X, Y = np.meshgrid(coords, coords) 

    # TODO: Is there a cleaner way to do this rotation in numpy?
    X_rot = X * np.cos(theta) + Y * np.sin(theta)
    Y_rot = -X * np.sin(theta) + Y * np.cos(theta)

    gaussian = np.exp(-(X_rot**2 + Y_rot**2) / (2 * sigma**2))
    sinusoid = np.cos(2 * np.pi * frequency * X_rot + phase)
    kernel = gaussian * sinusoid
    return kernel / np.linalg.norm(kernel)

def test_single_spatial_filter():
    pixel_pitch = 0.02     # deg/px
    f_s = 1.0      # cyc/deg  → σ = 0.5
    k = create_spatial_gabor_filter(0, 0, f_s, pixel_pitch)
    print(k.shape)  # expect ≈ (151, 151)  because 6σ / 0.02 ≈ 150 px


def test_spatial_gabor_filter():
    thetas = [0, np.pi / 4, np.pi / 2, 3*np.pi/4, np.pi]
    frequencies = [0.25, 0.5, 1, 2, 4] # cycles/deg
    phases = [0, np.pi]

    gabors = []
    for theta in thetas:
        for f in frequencies:
            for phase in phases:
                filter = create_spatial_gabor_filter(theta, phase, f) 
                gabors.append(filter)

    # gabors = np.array(gabors)
    print(len(gabors))
    plot_spatial_gabors(gabors, 'Spatial Gabors')

# test_single_spatial_filter()
# test_spatial_gabor_filter()

def plot_spatial_gabors(gabors: list[NDArray], title: str):
    num_gabors = len(gabors)
    print(f"--> plotting {num_gabors} kernels")
    num_per_row = min(5, num_gabors)
    num_rows = (num_gabors // num_per_row) + int(bool(num_gabors % num_per_row)) # add an extra row if remainder

    fig, axs = plt.subplots(num_rows, num_per_row, squeeze=True, figsize=(12,12))
    plt.tight_layout(pad=0.0, h_pad=0.0, w_pad=0.0)
    # plt.suptitle(title)
    plt.xlabel("X")
    plt.xlabel("Y")
    idx = 0
    for i in range(num_rows):
        for j in range(num_per_row):
            ax = axs[i][j]
            plot_gabor(gabors[idx], ax) 
            idx += 1

    plt.show()

def plot_gabor(gabor: NDArray, ax):
    ax.imshow(gabor, cmap='gray', origin = 'lower')

def create_filter_bank(frequencies: list[float], thetas: NDArray[np.floating]):
    """
    Returns a numpy array of dimension 2 x F, where F is the cartesian product of |frequencies x thetas|
    This function constructs a spatio-temporal Gabor filter for each spatial frequency, orientation, 
    and organizes the returned array in terms of the two phases: 0 and np.pi/2 to make downstream
    motion energy computation on the quadrature pair easy.
    """
    phases = [0, np.pi/2]
    kernels = [[], []] # two phases - quadrature pairs
    channels = []
    for (i, phase) in enumerate(phases):
        for (j, f) in enumerate(frequencies):
            for (k, theta) in enumerate(thetas):
                filter = create_spatial_gabor_filter(theta, phase, f)
                kernels[i].append(filter) 
                channels.append((phase, f, theta))
    return kernels, channels


def create_motion_energy_features(stimulus: NDArray[np.floating]):
    # energy 
    # 1/ Create spatial Gabor filter bank (different orientations) - must be in quadrature pairs (phase shifted)
    # TODO: How does one create a Gabor filter bank? What are the various combinations (phase shifted, orientations, points in space)
    # TODO: How do the temporal filters fit into the bank?
    # TODO: What is the shape of the energy matrix (2 x F x T x X x Y)
    #
    # 2/ Convolve Gabor filters with stimulus 
    #
    # 3/ Compute motion energy
    # For each quadrature pair, compute the non-linear square + sum to compute the local motion energy for that filter pair
    # Do the above for each time slice

    thetas = np.arange(0, np.pi, np.deg2rad(22.5)) # 0 to 180 every 22.5deg
    frequencies = [1.0, 2.0, 4.0] # cycles/deg
    (bank_even, bank_odd), channels = create_filter_bank(frequencies, thetas)
    print(f"number of even filters: {len(bank_even)}, number of odd filters: {len(bank_odd)}")

    max_kernel_size = max([kernel.shape[0] for kernel in bank_even])
    pad = max_kernel_size // 2
    print(pad)
    print(stimulus.shape)
    padded_stimulus = np.pad(stimulus, [(0, 0), (pad, pad), (pad, pad)], mode='reflect')
    print(padded_stimulus.shape)

    energy = np.zeros((stimulus.shape[0], len(channels)))

    for kernel_idx in range(len(bank_even)):
        even_filter = bank_even[kernel_idx][::-1, ::-1]
        odd_filter = bank_odd[kernel_idx][::-1, ::-1]
        for (frame_idx, frame) in enumerate(stimulus):
            # TODO: How exactly should I think about the dimensions of the output (how does this change with mode='valid')
            even = fftconvolve(frame, even_filter, mode="valid")
            odd = fftconvolve(frame, odd_filter, mode="valid")

            local_energy = even**2 + odd**2

            # TODO: Is the spatial poolin necessary?
            energy[frame_idx, kernel_idx] = local_energy.mean()


    return energy


In [123]:
def test_motion_energy():
    L = 1 
    x, y = np.linspace(0, L, 200), np.linspace(0, L, 200)
    theta = np.pi / 4
    A = 1.0
    f = 2
    phase = 0.0
    frames = 50
    temporal_f = 3
    stimulus = sinusoidal_3d(x, y, theta, A, f, temporal_f, frames)
    energy = create_motion_energy_features(stimulus)
    print(energy)

test_motion_energy()

number of even filters: 24, number of odd filters: 24
75
(50, 200, 200)
(50, 350, 350)
[[ 18.22180567 156.39080835 346.76051964 ...   0.           0.
    0.        ]
 [ 18.23376142 156.40299651 346.75866747 ...   0.           0.
    0.        ]
 [ 18.24342864 156.41285165 346.75716984 ...   0.           0.
    0.        ]
 ...
 [ 18.22337141 156.39240453 346.76027707 ...   0.           0.
    0.        ]
 [ 18.21467915 156.3835433  346.76162366 ...   0.           0.
    0.        ]
 [ 18.21404214 156.3828939  346.76172235 ...   0.           0.
    0.        ]]


# Generate synthetic spikes from a Poisson process with hard-coded weights 


# Fit a Poisson GLM to the synthetic spikes to recover the real weights 

In [8]:
x = np.arange(0, 2, 1)
y = np.arange(4, 8, 1)
t = np.arange(9, 20, 1)

T, X, Y = np.meshgrid(t, x, y, indexing = 'ij')

print(f'T: {T}, {T.shape}')
print(f'X: {X}, {X.shape}')
print(f'Y: {Y}, {Y.shape}')

T: [[[ 9  9  9  9]
  [ 9  9  9  9]]

 [[10 10 10 10]
  [10 10 10 10]]

 [[11 11 11 11]
  [11 11 11 11]]

 [[12 12 12 12]
  [12 12 12 12]]

 [[13 13 13 13]
  [13 13 13 13]]

 [[14 14 14 14]
  [14 14 14 14]]

 [[15 15 15 15]
  [15 15 15 15]]

 [[16 16 16 16]
  [16 16 16 16]]

 [[17 17 17 17]
  [17 17 17 17]]

 [[18 18 18 18]
  [18 18 18 18]]

 [[19 19 19 19]
  [19 19 19 19]]], (11, 2, 4)
X: [[[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]], (11, 2, 4)
Y: [[[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]], (11, 2, 4)
