The goal of this notebook is to build a simple, end-to-end prototype of the V1-MT motion perception pipeline we will be modeling:
- Generate a drifting sinusoidal grating stimulus
- Compute local motion energy signals with a bank of Gabor filters to simulate V1 complex-cell outputs
- Generate a synthetic spike train from a Poisson process (this will be replaced with real MT data in the actual model)
- Fit a Poisson GLM to the single synthetic MT neuron, using the motion energy signals as features
- Show that the fitted GLM recovers the ground-truth weights and reproduces the synthetic spike train

In [3]:
import numpy as np
import matplotlib.pyplot as plt
from numpy.typing import NDArray
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
%matplotlib inline



# Drifting sinusoidal grating

In [18]:
def sinusoidal_2d(
    x: NDArray[np.floating],
    y: NDArray[np.floating],
    theta: float,
    A: float,
    f: float,
    phase: float,
) -> NDArray[np.floating]:
    """Sample a static, oriented 2D sinusoidal grating on the grid spanned by ``x`` and ``y``.

    The grating is ``A * cos(2*pi*f*x_rot + phase)`` with
    ``x_rot = x*cos(theta) + y*sin(theta)``, i.e. luminance is modulated along
    the direction ``theta`` and constant along the perpendicular.

    Args:
        x: 1D horizontal sample coordinates.
        y: 1D vertical sample coordinates.
        theta: Direction of modulation, in radians.
        A: Amplitude (peak deviation from zero).
        f: Spatial frequency in cycles per unit of x/y. One period spans ``1 / f``
            coordinate units, so with x in [0, 1] and f = 4 the image contains 4 cycles.
        phase: Phase offset in radians.

    Returns:
        Array of shape ``(len(y), len(x))``. ``np.meshgrid``'s default ``'xy'``
        indexing puts y on the rows and x on the columns, which is the layout
        ``imshow`` expects.
    """
    # TODO: What coordinate domain (e.g. degrees of visual angle) do real-world stimuli use?
    grid_x, grid_y = np.meshgrid(x, y)
    # Project every grid point onto the grating's direction of modulation.
    along_theta = grid_x * np.cos(theta) + grid_y * np.sin(theta)
    angle = 2 * np.pi * f * along_theta + phase
    return A * np.cos(angle)

def render_sinusodial_2d(sinusodial: NDArray[np.floating], title: str):
    """Display a single 2D grating frame as a grayscale image in a new figure.

    Args:
        sinusodial: 2D array laid out with rows indexing y and columns indexing x
            (the layout returned by ``sinusoidal_2d``).
        title: Figure title.
    """
    fig, ax = plt.subplots()
    # origin='lower' puts row 0 (smallest y) at the bottom so the image reads like an x/y plot.
    # Tick values are pixel indices, not the x/y coordinates the grating was sampled at.
    im = ax.imshow(sinusodial, cmap='gray', origin='lower')
    # Pixel value is encoded by the colormap, so intensity belongs on the colorbar,
    # not on the vertical image axis (which is Y).
    fig.colorbar(im, ax=ax, label='Intensity')
    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    plt.show()

def test_sinusoidal_2d():
    """Render a 45-degree, 4-cycle static grating as a visual sanity check."""
    L = 1  # side length of the square sampling window
    x, y = np.linspace(0, L, 100), np.linspace(0, L, 100)
    theta = np.pi / 4
    A = 1.0
    f = 4  # cycles per unit length -> 4 cycles across the window when L = 1
    phase = 0.0
    sinusoidal = sinusoidal_2d(x, y, theta, A, f, phase)
    # Report the orientation in degrees; the raw radian value (0.785...) is hard to read.
    render_sinusodial_2d(
        sinusoidal,
        f'A basic 2D sinusoid oriented at {np.degrees(theta):.0f} degrees',
    )

def sinusoidal_3d(
    x: NDArray[np.floating],
    y: NDArray[np.floating],
    theta: float,
    A: float,
    spatial_f: float,
    temporal_f: float,
    frames: int,
    duration: float = 1.0,
    phase: float = 0.0,
) -> NDArray[np.floating]:
    """Sample a drifting, oriented sinusoidal grating as a stack of frames.

    The grating is ``A * cos(2*pi*spatial_f*x_rot - 2*pi*temporal_f*t + phase)``
    with ``x_rot = x*cos(theta) + y*sin(theta)``. Phase is constant where
    ``x_rot = (temporal_f / spatial_f) * t``, so the bars drift in the +theta
    direction at ``temporal_f / spatial_f`` coordinate units per second.

    Args:
        x: 1D horizontal sample coordinates.
        y: 1D vertical sample coordinates.
        theta: Direction of modulation/drift, in radians.
        A: Amplitude.
        spatial_f: Spatial frequency in cycles per unit of x/y.
        temporal_f: Temporal frequency in cycles per second (Hz).
        frames: Number of frames to sample; must be positive.
        duration: Total time in seconds covered by the frames. Frames are sampled
            at ``fps = frames / duration`` (frame period ``duration / frames``).
            The default of 1.0 spreads all frames over one second.
        phase: Phase offset in radians, matching ``sinusoidal_2d``.

    Returns:
        Array of shape ``(frames, len(y), len(x))``. Each frame uses the same layout
        as ``sinusoidal_2d`` (rows index y, columns index x), so
        ``imshow(result[k], origin='lower')`` shows x horizontally and y vertically.

    Raises:
        ValueError: If ``frames`` is not positive.
    """
    if frames <= 0:
        raise ValueError(f"frames must be positive, got {frames}")
    dt = duration / frames  # seconds per frame, i.e. 1 / fps
    t = np.arange(frames) * dt
    # With indexing='ij', the output axes follow the argument order: (t, y, x).
    # y must come before x so that each frame has rows = y, as imshow and
    # sinusoidal_2d expect. Passing (t, x, y) would transpose every frame, which
    # only goes unnoticed for square grids with a symmetric orientation like pi/4.
    T, Y, X = np.meshgrid(t, y, x, indexing='ij')
    X_rot = X * np.cos(theta) + Y * np.sin(theta)
    return A * np.cos(2 * np.pi * spatial_f * X_rot - (2 * np.pi * temporal_f * T) + phase)

def animate_sinusodial_3d(
    drifting_sinusodial: NDArray[np.floating],
    frames: int,
    title: str,
    interval: float = 100,
):
    """Animate a stack of grating frames and return it as an embeddable HTML widget.

    Args:
        drifting_sinusodial: Array indexed as ``[frame, row, column]`` (e.g. the output
            of ``sinusoidal_3d``). Each frame is drawn with ``imshow``.
        frames: Number of frames to play, starting at frame 0. Values larger than the
            stack length are clamped to it.
        title: Figure title.
        interval: Delay between displayed frames, in milliseconds. Playback is real-time
            only when this equals the stimulus frame period: for N frames spanning
            D seconds, use ``interval = 1000 * D / N``. The default of 100 ms (10 fps)
            plays a 50-frame, 1-second stimulus 5x slower than real time, so its
            apparent temporal frequency is 5x lower.

    Returns:
        ``IPython.display.HTML`` wrapping the JavaScript animation.
    """
    n_frames = min(frames, len(drifting_sinusodial))
    fig, ax = plt.subplots()
    # Start at frame 0 (a hard-coded later frame breaks short stacks). Pin the color
    # limits to the whole stack: imshow autoscales from the first frame only, and
    # set_array() does not rescale, so later frames could otherwise be clipped.
    im = ax.imshow(
        drifting_sinusodial[0],
        cmap='gray',
        origin='lower',
        vmin=float(np.min(drifting_sinusodial)),
        vmax=float(np.max(drifting_sinusodial)),
    )
    fig.colorbar(im, ax=ax, label='Intensity')
    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')

    def update(frame):
        im.set_array(drifting_sinusodial[frame])
        return (im,)

    ani = FuncAnimation(
        fig,
        update,
        frames=n_frames,
        interval=interval,
        blit=True,  # only redraw the artists returned by update()
    )
    # Close this specific figure so the inline backend does not also show a static copy.
    plt.close(fig)
    return HTML(ani.to_jshtml())

def render_first_n_frames(drifting_sinusodial, n, step=5):
    """Show ``n`` frames of a grating stack in a 4-column grid, ``step`` frames apart.

    Panel ``k`` shows ``drifting_sinusodial[min(k * step, last)]``. Once the stride
    runs past the end of the stack, the remaining panels repeat the last frame.

    Args:
        drifting_sinusodial: Array indexed as ``[frame, row, column]``.
        n: Number of panels to draw. Unused slots in the last row are hidden.
            Nothing is plotted when ``n <= 0``.
        step: Frame stride between consecutive panels.
    """
    print(f"sinusodial shape: {drifting_sinusodial.shape}")
    if n <= 0:
        return
    plots_per_row = 4
    num_rows = -(-n // plots_per_row)  # ceiling division
    last = drifting_sinusodial.shape[0] - 1
    # squeeze=False always yields a 2D Axes array. The default would return a 1D
    # array when num_rows == 1 (n <= 4) and break 2D indexing.
    _, axes = plt.subplots(num_rows, plots_per_row, squeeze=False)
    for k, ax in enumerate(axes.flat):
        if k >= n:
            ax.axis('off')  # leftover slot in the last row
            continue
        ax.imshow(drifting_sinusodial[min(k * step, last)], cmap='gray', origin='lower')
    plt.show()

def test_sinusoidal_3d():
    """Build a 45-degree drifting grating and return its animation for notebook display."""
    L = 1  # side length of the square sampling window
    x, y = np.linspace(0, L, 100), np.linspace(0, L, 100)
    theta = np.pi / 4
    A = 1.0
    f = 2  # spatial frequency: cycles per unit length
    frames = 50  # sinusoidal_3d spreads these over 1 second -> 50 frames per second
    temporal_f = 3  # temporal frequency: cycles per second
    sinusoidal = sinusoidal_3d(x, y, theta, A, f, temporal_f, frames)
    # render_first_n_frames(sinusoidal, n=10)  # static grid alternative to the animation
    return animate_sinusodial_3d(
        sinusoidal,
        frames,
        f'A drifting sinusoid oriented at {np.degrees(theta):.0f} degrees',
    )

# Build and display the drifting-grating animation. Keep this as the cell's final bare
# expression so Jupyter renders the returned HTML object.
test_sinusoidal_3d()

# Compute local motion energy with a bank of filters

In [None]:
def create_gabor_filter(theta: float, sigma: float):
    """Placeholder for one Gabor filter of the motion-energy filter bank.

    Not implemented yet. ``theta`` is presumably the preferred orientation (radians)
    and ``sigma`` the width of the Gaussian envelope. Confirm both when implementing.

    Raises:
        NotImplementedError: Always. This makes an accidental call fail loudly
            instead of silently returning None.
    """
    raise NotImplementedError("create_gabor_filter is not implemented yet")

def motion_energy():
    """Placeholder for computing local motion-energy signals from the Gabor filter bank.

    Not implemented yet; the signature will need the stimulus and filters as inputs.

    Raises:
        NotImplementedError: Always. This makes an accidental call fail loudly
            instead of silently returning None.
    """
    raise NotImplementedError("motion_energy is not implemented yet")

# Generate synthetic spikes from a Poisson process with hard-coded weights 


# Fit a Poisson GLM to the synthetic spikes to recover the ground-truth weights

In [8]:
# Scratch check of np.meshgrid(..., indexing='ij'). With 'ij', every output has shape
# (len(t), len(x), len(y)): one axis per input, in argument order. Each output varies
# only along its own input's axis. The default indexing='xy' would instead swap the
# first two axes, giving shape (len(x), len(t), len(y)).
x = np.arange(2)
y = np.arange(4, 8)
t = np.arange(9, 20)

T, X, Y = np.meshgrid(t, x, y, indexing='ij')

for label, grid in (('T', T), ('X', X), ('Y', Y)):
    print(f'{label}: {grid}, {grid.shape}')

T: [[[ 9  9  9  9]
  [ 9  9  9  9]]

 [[10 10 10 10]
  [10 10 10 10]]

 [[11 11 11 11]
  [11 11 11 11]]

 [[12 12 12 12]
  [12 12 12 12]]

 [[13 13 13 13]
  [13 13 13 13]]

 [[14 14 14 14]
  [14 14 14 14]]

 [[15 15 15 15]
  [15 15 15 15]]

 [[16 16 16 16]
  [16 16 16 16]]

 [[17 17 17 17]
  [17 17 17 17]]

 [[18 18 18 18]
  [18 18 18 18]]

 [[19 19 19 19]
  [19 19 19 19]]], (11, 2, 4)
X: [[[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]

 [[0 0 0 0]
  [1 1 1 1]]], (11, 2, 4)
Y: [[[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]

 [[4 5 6 7]
  [4 5 6 7]]], (11, 2, 4)
