In [None]:
import retinanalysis as ra
import retinanalysis.schema as schema
import numpy as np
from matplotlib.patches import Ellipse
import matplotlib.pyplot as plt

# NOTE(review): presumably opens the DataJoint connection that the `ra` queries below rely on — confirm.
ra.djconnect()

## Step 1: Query the protocols of interest across all experiments

In [None]:
# Protocol names to look up; every experiment/data file that ran one of them is listed.
protocols_to_query = [
    'movingletters',
    'presentmatfiles',
]
experiment_query = ra.get_datasets_from_protocol_names(protocols_to_query)
display(experiment_query)

## Step 2: Select the experiment and data file of interest

In [None]:
# Index label (not position) of the row of `experiment_query` to analyse. Both the
# experiment and the data file are read from this one row so they cannot drift apart.
DATASET_ROW = 1
exp_name = experiment_query.loc[DATASET_ROW, 'exp_name']
datafile_name = experiment_query.loc[DATASET_ROW, 'datafile_name']

# Create the stimulus, response, analysis-chunk and pipeline objects
stim = ra.StimBlock(exp_name, datafile_name)
response = ra.ResponseBlock(exp_name, datafile_name)
# The analysis chunk is selected via `stim.nearest_noise_chunk`.
analysis_chunk = ra.AnalysisChunk(exp_name, stim.nearest_noise_chunk)
pipeline = ra.MEAPipeline(stim, response, analysis_chunk)


In [None]:
# Choose which cells' receptive fields to draw. By default the first N_RF_CELLS
# rows of the spike-time table are used; set RF_CELL_TYPES to a list of type
# labels (e.g. ['OnP', 'OffP']) to select cells by type instead.
N_RF_CELLS = 50
RF_CELL_TYPES = None

if RF_CELL_TYPES is None:
    cell_ids = response.df_spike_times.iloc[0:N_RF_CELLS].index.values
else:
    rf_cells_df = response.df_spike_times.query('cell_type == @RF_CELL_TYPES')
    display(rf_cells_df)
    cell_ids = list(rf_cells_df.index.values)

axes = pipeline.plot_rfs(cell_ids = cell_ids, units = 'microns');

In [None]:
# Per-epoch image order and epoch counts from the stimulus metadata.
epoch_params = ['imageOrder']
image_order = np.array(list(stim.df_epochs['imageOrder']))

images_per_epoch = stim.df_epochs.loc[0, 'epoch_parameters']['imagesPerEpoch']
num_epochs = len(stim.df_epochs)

# epochs_to_see = [1, 2, 5, 245, 6, 232]
# display(stim.df_epochs[['imageOrder', 'matFile']].query('epoch_index == @epochs_to_see'))

cell_of_interest = 6
cell_1_times = response.df_spike_times.query('cell_id == 1')['spike_times'].values

# Map OffP cells typed on the noise run to their IDs in this protocol run via
# pipeline.match_dict (noise ID -> protocol ID); unmatched noise cells are dropped.
OffP_noise_ids = analysis_chunk.cell_params_df.query('typing_file_0 == "OffP"')['cell_id'].values
OffP_protocol_ids = [pipeline.match_dict[cell] for cell in OffP_noise_ids if cell in pipeline.match_dict]
OffP_responses = response.df_spike_times.query('cell_id == @OffP_protocol_ids')
OffP_spike_times = OffP_responses['spike_times'].values
assert len(OffP_spike_times) > 0, "no OffP cells matched between the noise and protocol runs"

# Pack into an (n_cells, n_epochs) object array: row = OffP cell, column = epoch.
# The epoch count is taken from the first cell; the row assignment fails if a cell differs.
test = np.empty((OffP_spike_times.shape[0], OffP_spike_times[0].shape[0]), dtype=object)
for i, cell_spikes in enumerate(OffP_spike_times):
    test[i, :] = cell_spikes

print(test.shape)

epochs_to_plot = np.arange(0, 9)
epoch_to_plot = 0

fig, ax = plt.subplots()
# One raster row per OffP cell, all taken from the same epoch.
ax.eventplot(test[:, epoch_to_plot], color='k', linewidths=0.8, linelengths=0.8)
ax.set_xlabel("Time (ms)")  # NOTE(review): confirm spike_times are in milliseconds
ax.set_ylabel("OffP cell")
ax.set_title(f"OffP spike rasters, epoch {epoch_to_plot}");


In [None]:
cell_types = ['OnP', 'OffP']
cell_ids = [1, 3, 9, 10, 38, 42, 121]

# Cell-type labels present in this protocol run.
cts = list(response.df_spike_times['cell_type'].unique())
display(cts)

# Keys of match_dict are noise-run cell IDs; values are the matching protocol-run IDs.
test = list(pipeline.match_dict)
print(test)

protocol_ids = response.df_spike_times.query('cell_type == @cell_types').index.values

# Keep matches whose protocol ID is both of a requested type and explicitly listed.
ids_to_plot = [
    noise_id
    for noise_id, protocol_id in pipeline.match_dict.items()
    if protocol_id in protocol_ids and protocol_id in cell_ids
]
corr_ids = [
    protocol_id
    for protocol_id in pipeline.match_dict.values()
    if protocol_id in protocol_ids and protocol_id in cell_ids
]
print(ids_to_plot)
print(corr_ids)

In [None]:

def get_ells(analysis_chunk, cell_types, std_scaler = 1.6, units = None):
    """Build matplotlib Ellipse patches for the receptive fields of each cell type.

    Parameters
    ----------
    analysis_chunk
        Object exposing ``cell_params_df`` (with ``cell_id`` and ``typing_file_0``
        columns), ``rf_params`` (indexable by cell ID, each entry holding
        ``center_x``, ``center_y``, ``std_x``, ``std_y`` and ``rot``), and the
        ``microns_per_stixel`` / ``pixels_per_stixel`` scale factors.
    cell_types : list of str
        Type labels matched against ``typing_file_0``. The i-th type is drawn in
        matplotlib cycle colour ``C{i}``.
    std_scaler : float, default 1.6
        Multiplier applied to ``std_x`` / ``std_y`` to get ellipse width / height.
    units : str or None
        Case-insensitive. If it contains 'microns', coordinates are scaled by
        ``microns_per_stixel``; else if it contains 'pixels', by
        ``pixels_per_stixel``; otherwise (or when None) no scaling is applied.

    Returns
    -------
    ells : list of dict
        One dict per entry of ``cell_types``, mapping cell ID -> Ellipse.
    scale_factor : float
        The factor applied to the ellipse centres and sizes.
    """
    units_lower = units.lower() if units is not None else ''
    if 'microns' in units_lower:
        scale_factor = analysis_chunk.microns_per_stixel
    elif 'pixels' in units_lower:
        scale_factor = analysis_chunk.pixels_per_stixel
    else:
        scale_factor = 1

    rf_params = analysis_chunk.rf_params  # loop-invariant: fetch once
    ells = []
    for idx, ct in enumerate(cell_types):
        cell_df = analysis_chunk.cell_params_df.query('typing_file_0 == @ct')
        color = f'C{idx}'
        ell_dict = {}
        for cell_id in cell_df['cell_id'].values:
            params = rf_params[cell_id]
            ell_dict[cell_id] = Ellipse(
                xy=(params['center_x'] * scale_factor,
                    params['center_y'] * scale_factor),
                width=params['std_x'] * std_scaler * scale_factor,
                height=params['std_y'] * std_scaler * scale_factor,
                # NOTE(review): matplotlib expects degrees here; confirm 'rot' is not radians.
                angle=params['rot'],
                facecolor=color, edgecolor=color,
                alpha=0.7)
        ells.append(ell_dict)

    return ells, scale_factor

units = 'pixels'
cell_types = ['OffP', 'OnP', 'OffM', 'OnM', 'SBC']
all_ells, scale_factor = get_ells(analysis_chunk, cell_types, units = units)

# squeeze=False keeps `ax` a 1-D array of Axes even when only one cell type is plotted.
fig, ax = plt.subplots(nrows=1, ncols=len(cell_types),
                       figsize=(18, 12 / len(cell_types)), squeeze=False)
ax = ax[0]

x_max = analysis_chunk.numXChecks * scale_factor
y_max = analysis_chunk.numYChecks * scale_factor
for panel, ct, ells_by_cell in zip(ax, cell_types, all_ells):
    for ellipse in ells_by_cell.values():
        panel.add_patch(ellipse)

    panel.set_xlim(0, x_max)
    panel.set_ylim(0, y_max)
    panel.set_title(ct)  # label each mosaic with its cell type
    if units is not None:
        panel.set_ylabel(units.lower())
        panel.set_xlabel(units.lower())

fig.tight_layout()

