# White-noise (WN) stimulus regeneration demo
Verify that the regenerated white-noise stimulus yields the correct spike-triggered average (STA).
VR 2025-07-08

In [None]:
import retinanalysis as ra
import numpy as np
import matplotlib.pyplot as plt
import tqdm.auto as tqdm

def compute_sta(cell_idx, s1: ra.StimBlock, r1: ra.ResponseBlock, n_depth=15):
    """Compute a spike-triggered sum of stimulus frames for one cell.

    For every epoch in ``s1.stim_data`` and every lag in ``range(n_depth)``,
    each stimulus frame is weighted by the spike count ``lag`` bins after it,
    and the weighted frames are summed. The result is not divided by the
    number of spikes, so it is a spike-triggered *sum*: its shape, not its
    absolute scale, is what should be compared against a spatial map.

    Parameters
    ----------
    cell_idx
        Row label into ``r1.df_spike_times`` selecting the cell.
    s1
        Stimulus block. ``stim_data`` is indexed as
        ``stim_data[epoch, frame, <spatial>, <spatial>, channel]`` and only
        channel 0 is used (spatial axis order presumably (y, x) — confirm).
    r1
        Response block providing ``bin_rate`` and a per-epoch 2-D
        ``binned_spikes`` array for each cell. Assumes one spike bin per
        stimulus frame, so lags are in bins/frames — TODO confirm.
    n_depth
        Number of lags (0 .. n_depth - 1) to compute.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_depth, stim_data.shape[2], stim_data.shape[3])``;
        entry ``i`` holds the spike-triggered sum at lag ``i``.
    """
    lags = np.arange(0, n_depth)

    n_epochs = len(s1.stim_data)
    sta = np.zeros((len(lags), s1.stim_data[0].shape[1], s1.stim_data[0].shape[2]))

    # Number of response bins to skip so spikes line up with the first
    # stimulus frame. preTime is scaled by 1e-3 (presumably ms -> s) and is
    # read from epoch 0 only, i.e. this assumes every epoch has the same
    # preTime.
    pre_bins = r1.bin_rate * s1.df_epochs.at[0, 'epoch_parameters']['preTime'] * 1e-3
    pre_bins = np.round(pre_bins).astype(int)

    binned_spikes = r1.df_spike_times.at[cell_idx, 'binned_spikes']
    for e_idx in tqdm.tqdm(range(n_epochs), desc="STA epochs"):
        # These do not depend on the lag, so fetch them once per epoch.
        epoch_frames = s1.stim_data[e_idx, :, :, :, 0]
        epoch_spikes = binned_spikes[e_idx, pre_bins:pre_bins + len(epoch_frames)]
        if not np.any(epoch_spikes > 0):
            print(f'No spikes found for cell idx {cell_idx} in epoch {e_idx}')
            continue

        for i, lag in enumerate(lags):
            if lag > 0:
                # Pair the spike count at bin t + lag with the frame at t.
                # (lag == 0 must be special-cased: frames[:-0] would be empty.)
                bs = epoch_spikes[lag:]
                frames = epoch_frames[:-lag]
            else:
                bs = epoch_spikes
                frames = epoch_frames

            # Only bins containing spikes contribute to the sum.
            spike_bins = np.flatnonzero(bs > 0)
            if spike_bins.size == 0:
                continue

            # (n_spike_bins, a, b) -> (a, b, n_spike_bins) so matmul sums
            # the spike-weighted frames over time.
            weighted = np.moveaxis(frames[spike_bins], 0, -1)
            sta[i] += np.matmul(weighted, bs[spike_bins])
    return sta

In [None]:
# Select the spatial-noise datasets recorded in one experiment and build the
# MEA pipeline for the first matching data file.
target_exp_name = '20250514C'
df = ra.get_datasets_from_protocol_names('protocols.spatialnoise')
# drop=True keeps the pre-filter row numbers from being added as an extra
# 'index' column.
df = df[df.exp_name == target_exp_name].reset_index(drop=True)
display(df)
if df.empty:
    # Fail with a clear message instead of a KeyError from df.at[0, ...].
    raise ValueError(f'No spatial-noise datasets found for experiment {target_exp_name}')

idx = 0
exp_name = df.at[idx, 'exp_name']
datafile_name = df.at[idx, 'datafile_name']
pp = ra.create_mea_pipeline(exp_name, datafile_name)

First, bin the spike times by stimulus frame.

In [None]:
# Bin each cell's spike times by stimulus frame. This presumably populates the
# 'binned_spikes' column that compute_sta reads — confirm.
pp.response_block.bin_spike_times_by_frames()

Next, regenerate the white-noise (WN) frames for the first epoch only.

In [None]:
# Regenerate the white-noise stimulus frames for epoch 0 only.
# NOTE(review): compute_sta loops over every epoch in stim_block.stim_data;
# confirm stim_data holds only the regenerated epoch(s) after this call.
pp.stim_block.regenerate_stimulus(ls_epochs=[0])

Use the `compute_sta` function defined above to compute the STA for the cell at index 2; the following cells plot it.

In [None]:
# STA for the cell at row index 2, using the default number of lags.
cell_idx = 2
sta = compute_sta(cell_idx, s1=pp.stim_block, r1=pp.response_block, n_depth=15)

In [None]:
# Get the Red channel (index 0 of the last axis) spatial map and its peak pixel.
cell_id = pp.response_block.cell_ids[cell_idx]

sm = pp.analysis_chunk.d_spatial_maps[cell_id][:, :, 0]
# Use the largest absolute value so that a map whose strongest pixel is
# negative is handled too; plain argmax would miss a negative-signed peak.
peak = np.unravel_index(np.argmax(np.abs(sm)), sm.shape)

# STA time course at the peak pixel, one point per lag.
plt.plot(sta[:, peak[0], peak[1]])
plt.xlabel('Lag (bins)')
plt.ylabel('Spike-triggered sum')

In [None]:
# Compare one STA frame with the cell's spatial map. Instead of a hard-coded
# lag, show the lag whose |STA| at the peak pixel is largest.
peak_lag = int(np.argmax(np.abs(sta[:, peak[0], peak[1]])))

plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.imshow(sta[peak_lag, :, :], cmap='gray')
plt.title(f'STA, lag {peak_lag}')
plt.subplot(122)
plt.imshow(sm, cmap='gray')
plt.title('Spatial map')