In [None]:
import retinanalysis as ra
import retinanalysis.schema as schema
import numpy as np
import matplotlib.pyplot as plt

## Step 1: Query the protocols of interest across all experiments

In [None]:
# Protocol names to search for; every dataset that ran any of them is returned.
protocols_to_query = ['movingletters', 'presentmatfiles']

# Table of matching datasets (later cells read its 'exp_name' and 'datafile_name' columns).
experiment_query = ra.get_datasets_from_protocol_names(
    protocols_to_query
)
display(experiment_query)

## Step 2: Select the experiment and data file of interest

In [None]:
# Row label of the dataset to analyse. `.loc` is label-based, so this assumes
# `experiment_query` has a default integer index — NOTE(review): confirm.
selected_row = 0
exp_name = experiment_query.loc[selected_row, 'exp_name']
datafile_name = experiment_query.loc[selected_row, 'datafile_name']

# Stimulus and response blocks for the chosen experiment / data file.
stim = ra.StimBlock(exp_name, datafile_name)
response = ra.ResponseBlock(exp_name, datafile_name)

# Analysis chunk identified by the stimulus block's `nearest_noise_chunk`,
# then a pipeline object combining all three.
analysis_chunk = ra.AnalysisChunk(exp_name, stim.nearest_noise_chunk)
pipeline = ra.MEAPipeline(stim, response, analysis_chunk)

In [None]:
# Cell-type labels to include in the receptive-field and timecourse plots.
cell_types = ['OnP', 'OffP', 'OnM', 'OffM', 'SBC']

# Receptive fields are drawn in pixel units; the trailing ';' suppresses the repr.
rf_axes = pipeline.plot_rfs(cell_types=cell_types, units='pixels')
timecourse_axes = pipeline.plot_timecourses(cell_types=cell_types);

In [None]:
# --- Stimulus bookkeeping -------------------------------------------------
epoch_params = ['imageOrder']
# One entry per epoch from the 'imageOrder' column, collected into a single array.
# NOTE(review): yields a 2-D array only if every epoch's imageOrder has the same length.
image_order = np.array(stim.df_epochs['imageOrder'].tolist())

images_per_epoch = stim.df_epochs.loc[0, 'epoch_parameters']['imagesPerEpoch']
num_epochs = len(stim.df_epochs)

# NOTE(review): `cell_of_interest` is not used in this cell, and `cell_1_times`
# is hard-coded to cell_id 1 rather than `cell_of_interest` — confirm intent.
cell_of_interest = 6
cell_1_times = response.df_spike_times.query('cell_id == 1')['spike_times'].values

# --- OffP cells: noise-chunk typing -> this data file's cell IDs -------------
OffP_noise_ids = analysis_chunk.df_cell_params.query('typing_file_0 == "OffP"')['cell_id'].values
# Noise-chunk cells with no entry in pipeline.match_dict are skipped.
OffP_protocol_ids = [pipeline.match_dict[cell] for cell in OffP_noise_ids if cell in pipeline.match_dict]
OffP_responses = response.df_spike_times.query('cell_id == @OffP_protocol_ids')
OffP_spike_times = OffP_responses['spike_times'].values

# Fail with a clear message instead of an IndexError on OffP_spike_times[0] below.
assert len(OffP_spike_times) > 0, (
    f"No OffP cells from the noise chunk matched cells in {exp_name}/{datafile_name}"
)

# Object array: row i holds OffP_spike_times[i]; the column count comes from the
# first cell's entry. Presumably columns are epochs — assumes every cell has the
# same number of entries (TODO confirm).
st_cell_by_epoch = np.empty((len(OffP_spike_times), OffP_spike_times[0].shape[0]), dtype=object)
for i, cell_spike_times in enumerate(OffP_spike_times):
    st_cell_by_epoch[i, :] = cell_spike_times

print(st_cell_by_epoch.shape)

# --- Raster: every OffP cell, a single epoch column --------------------------
# NOTE(review): `epochs_to_plot` is currently unused; the raster shows one column only.
epochs_to_plot = np.arange(0, 9)
raster_epoch = 0

fig, ax = plt.subplots()
# Each eventplot row is one OffP cell (a row of st_cell_by_epoch), so the
# y-axis indexes cells, not epochs.
ax.eventplot(st_cell_by_epoch[:, raster_epoch], color='k', linewidths=0.8, linelengths=0.8)
ax.set_xlabel("Time (ms)")  # NOTE(review): confirm spike_times are in ms
ax.set_ylabel("OffP cell")
ax.set_title(f"OffP spike raster, epoch index {raster_epoch}");
