In [None]:
# Imports and setup for the deduplication analysis notebook.
import importlib
import retinanalysis as ra
# importlib.reload(ra)
# importlib.reload(ra.analysis_chunk)
import pandas as pd
import numpy as np
# NOTE(review): bare attribute access that is not the cell's last expression, so
# nothing is displayed — presumably left over from inspecting the MEA config.
ra.settings.mea_config
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse

# Uncomment the lines below to update the database
# ra.database_pop.reload_celltypefiles()
# ra.database_utils.populate_database()

In [None]:
# Import the dedup helpers and reload them so edits to the module are picked up
# without restarting the kernel.
import retinanalysis.classes.dedup as dd
importlib.reload(dd)

In [None]:
# Look up the datasets matching the protocol name 'matfiles' and display them.
df = ra.get_datasets_from_protocol_names('matfiles')
df

In [None]:
# Load the first dataset listed in `df` and the noise chunk closest to it.
first_row = 0  # row label in df
exp_name_1 = df.at[first_row, 'exp_name']
s1 = ra.MEAStimBlock(exp_name_1, df.at[first_row, 'datafile_name'])
ac1 = ra.AnalysisChunk(exp_name_1, s1.nearest_noise_chunk)

In [None]:
# Show the shape of one cell's spatial map and plot its first slice along the last axis.
cell_idx = 10
cell_id = ac1.cell_ids[cell_idx]
spatial_map = ac1.d_spatial_maps[cell_id]
display(spatial_map.shape)
plt.imshow(spatial_map[:, :, 0])

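Correlation matrix computation: pairwise Pearson correlations between cells' flattened spatial maps, and separately between their flattened EIs.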

In [None]:
# Flatten each cell's spatial map and EI into one row per cell. Both matrices are
# built in ac1.cell_ids order so that row i of sm_corr and ei_corr refers to the
# same cell (iterating d_spatial_maps.values() would rely on the dict's insertion
# order happening to match cell_ids).
sm_flat = np.array([ac1.d_spatial_maps[cell].flatten() for cell in ac1.cell_ids])
ei_flat = np.array([ac1.vcd.get_ei_for_cell(cell).ei.flatten() for cell in ac1.cell_ids])

# Row-wise Pearson correlation; zero-variance rows yield NaN, which is mapped to 0 in place.
sm_corr = np.corrcoef(sm_flat)
ei_corr = np.corrcoef(ei_flat)

np.nan_to_num(sm_corr, copy=False, nan=0, posinf=0, neginf=0)
np.nan_to_num(ei_corr, copy=False, nan=0, posinf=0, neginf=0);

In [None]:
# Side-by-side heatmaps of the two correlation matrices on a shared [-1, 1] scale.
f, axs = plt.subplots(1, 2, figsize=(12, 6))
panels = [(sm_corr, 'Spatial Map Correlation'), (ei_corr, 'EI Correlation')]
for panel_ax, (corr_mat, panel_title) in zip(axs, panels):
    im = panel_ax.imshow(corr_mat, cmap='viridis', vmin=-1, vmax=1)
    panel_ax.set_title(panel_title)
    plt.colorbar(im, ax=panel_ax)

In [None]:
# Start from the spatial-noise datasets and show the rows for experiment 20250514C.
df_1 = ra.get_datasets_from_protocol_names('.SpatialNoise')
df_1.loc[df_1['exp_name'] == '20250514C']

In [None]:
import retinanalysis.classes.dedup as dd


In [None]:
# Row label in df_1 of the noise run analysed below — presumably one of the
# 20250514C rows shown by the filter above; confirm.
noise_row = 177
sb = ra.MEAStimBlock(df_1.at[noise_row, 'exp_name'], df_1.at[noise_row, 'datafile_name'])
ac = ra.AnalysisChunk(df_1.at[noise_row, 'exp_name'], sb.nearest_noise_chunk)

In [None]:
# Reload dd to pick up source edits, then compare its EI methods on `ac`
# (trailing ';' suppresses the return value's repr).
importlib.reload(dd)
dd.compare_ei_methods(ac);

In [None]:
# DedupBlock for the same experiment/chunk as `ac`, using the kilosort2.5 sort, is_noise=True and the 'space' EI method.
dup = dd.DedupBlock(exp_name=ac.exp_name, chunk_name=ac.chunk_name, ss_version='kilosort2.5', is_noise=True, ei_method='space')

In [None]:
# Confirm that spatial-map padding worked as expected: overlay each sampled cell's
# fitted RF ellipse (from ac.df_cell_params) on its spatial map.
test_ids = ac.cell_ids[::30]  # every 30th cell
n_panels = len(test_ids)
n_cols = min(4, n_panels)
n_rows = int(np.ceil(n_panels / 4))
# squeeze=False always returns a 2-D array of Axes, so .flatten() also works when a
# single panel is requested (plt.subplots(1, 1) would otherwise return a bare Axes).
fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 20), constrained_layout=True, squeeze=False)
ax = ax.flatten()
for idx, cell in enumerate(test_ids):
    ax[idx].imshow(ac.d_spatial_maps[cell][:, :, 0], cmap='gray')
    subset = ac.df_cell_params[ac.df_cell_params['cell_id'] == cell]
    # [center_x, center_y, std_x, std_y, rot] of the first matching row
    rf_params = [subset[col].values[0] for col in ('center_x', 'center_y', 'std_x', 'std_y', 'rot')]
    ellipse = Ellipse(xy=(rf_params[0], rf_params[1]),
                      width=rf_params[2] * 2, height=rf_params[3] * 2,
                      angle=np.rad2deg(rf_params[4]),  # rot is converted from radians
                      edgecolor='red', facecolor='none', lw=2)
    ax[idx].add_patch(ellipse)
    ax[idx].set_title(f'Cell ID: {cell}')
# Hide grid slots left over when the number of cells is not a multiple of 4.
for spare in ax[n_panels:]:
    spare.axis('off')



In [None]:
# Flatten each cell's spatial map and EI into one row per cell. Both matrices are
# built in ac.cell_ids order so that row i of sm_corr and ei_corr refers to the
# same cell. Later cells index them via cluster_to_index / ac.cell_ids, which
# assumes exactly this order; iterating d_spatial_maps.values() would rely on
# dict insertion order matching cell_ids.
sm_flat = np.array([ac.d_spatial_maps[cell].flatten() for cell in ac.cell_ids])
ei_flat = np.array([ac.vcd.get_ei_for_cell(cell).ei.flatten() for cell in ac.cell_ids])

# Row-wise Pearson correlation; zero-variance rows yield NaN, which is mapped to 0 in place.
sm_corr = np.corrcoef(sm_flat)
ei_corr = np.corrcoef(ei_flat)

np.nan_to_num(sm_corr, copy=False, nan=0, posinf=0, neginf=0)
np.nan_to_num(ei_corr, copy=False, nan=0, posinf=0, neginf=0);

In [None]:
# Replace ei_corr with ra.ei_corr's 'full' method (ac compared against itself) and show it.
# NOTE(review): unlike the previous cell, NaN/inf are not zeroed here, and this is the
# ei_corr the thresholding cells below operate on.
ei_corr = ra.ei_corr(ac, ac, method='full')
plt.imshow(ei_corr, cmap='viridis', vmin=-1, vmax=1)

In [None]:
# Map each cell id to its row/column position in the correlation matrices.
cluster_to_index = {cell: position for position, cell in enumerate(ac.cell_ids)}

In [None]:
# Side-by-side heatmaps of the spatial-map and EI correlation matrices, shared [-1, 1] scale.
f, axs = plt.subplots(1, 2, figsize=(12, 6))
heatmaps = [(sm_corr, 'Spatial Map Correlation'), (ei_corr, 'EI Correlation')]
for target_ax, (matrix, label) in zip(axs, heatmaps):
    im = target_ax.imshow(matrix, cmap='viridis', vmin=-1, vmax=1)
    target_ax.set_title(label)
    plt.colorbar(im, ax=target_ax)

In [None]:
import copy
threshold_sm = 0.80
threshold_ei = 0.80


def _high_corr_pairs(corr, threshold, cell_ids):
    """Find cell pairs whose correlation exceeds ``threshold``.

    Only the strict upper triangle (k=1) is searched, which drops the
    self-correlations on the diagonal and the mirrored (j, i) copy of each pair.

    Returns
    -------
    upper_tri : ``corr`` with the diagonal and lower triangle set to 0 (new array)
    idx : (row_indices, col_indices) as returned by ``np.where``
    id_pairs : set of ``(cell_ids[row], cell_ids[col])`` tuples
    idx_pairs : set of ``(row, col)`` index tuples
    """
    upper_tri = np.triu(corr, k=1)  # np.triu returns a new array; corr is left untouched
    idx = np.where(upper_tri > threshold)
    id_pairs = {(cell_ids[r], cell_ids[c]) for r, c in zip(*idx)}
    idx_pairs = set(zip(*idx))
    return upper_tri, idx, id_pairs, idx_pairs


# Pairs of cells with high spatial-map (sm) and EI correlation.
sm_upper_tri, high_sm_idx, high_sm_corr_cells, high_sm_idx_set = _high_corr_pairs(sm_corr, threshold_sm, ac.cell_ids)
ei_upper_tri, high_ei_idx, high_ei_corr_cells, high_ei_idx_set = _high_corr_pairs(ei_corr, threshold_ei, ac.cell_ids)

fig, ax = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)
hist_panels = [(ax[0], sm_corr, threshold_sm, 'Spatial Map Correlation Histogram'),
               (ax[1], ei_corr, threshold_ei, 'EI Correlation Histogram')]
for panel, corr, threshold, title in hist_panels:
    # Histogram only the strict upper triangle: the zeros np.triu places on and below
    # the diagonal are placeholders, not correlations, and would add a spike at 0.
    pair_vals = corr[np.triu_indices_from(corr, k=1)]
    panel.hist(pair_vals, bins=100, alpha=0.7)
    panel.semilogy()
    panel.set_title(title)
    panel.set_xlabel('Correlation Coefficient')
    panel.set_ylabel('Count')
    panel.axvline(threshold, color='red', linestyle='--', label='Threshold')
    panel.legend()
print(f'number of high ei correlation pairs: {len(high_ei_corr_cells)}')
print(f'number of high spatial map correlation pairs: {len(high_sm_corr_cells)}')
print(f'intersection: {len(high_ei_corr_cells.intersection(high_sm_corr_cells))}')


In [None]:
#index ac.cell_ids by target clusters
all_problem_cells = {item for tuple in high_sm_corr_cells.union(high_ei_corr_cells) for item in tuple}
all_sm_cells = {item for tuple in high_sm_corr_cells for item in tuple}
all_ei_cells = {item for tuple in high_ei_corr_cells for item in tuple}

In [None]:
from matplotlib.lines import Line2D

# Overview figure. Top row: RF ellipses of all cells (left; colour shows which
# correlation flagged the cell) and of flagged cells only, labelled with their ids
# (right). Bottom row: histograms of the pairwise correlations with thresholds.
rf_params = ac.rf_params


def _flag_color(cell):
    """Edge colour for ``cell``: 'red' if flagged by both EI and spatial-map
    correlation, 'orange' for EI only, 'magenta' for spatial map only, None otherwise."""
    in_ei, in_sm = cell in all_ei_cells, cell in all_sm_cells
    if in_ei and in_sm:
        return 'red'
    if in_ei:
        return 'orange'
    if in_sm:
        return 'magenta'
    return None


def _rf_ellipse(cell, color, lw):
    """Unfilled Ellipse patch for ``cell``'s RF fit (``rot`` converted from radians)."""
    p = rf_params[cell]
    return Ellipse(xy=(p['center_x'], p['center_y']),
                   width=p['std_x'] * 2, height=p['std_y'] * 2,
                   angle=np.rad2deg(p['rot']),
                   edgecolor=color, facecolor='None', lw=lw, alpha=0.8)


fig = plt.figure(figsize=(15, 14), layout='compressed')
fig.suptitle('Overview of RFs and Correlations', fontsize=16)
subfigs = fig.subfigures(2, 1)  # only two rows are populated

ax00 = subfigs[0].subplots(1, 2)
for cell in ac.cell_ids:
    color = _flag_color(cell)
    ax00[0].add_patch(_rf_ellipse(cell, color or 'black', lw=1))
    if color is not None:
        ax00[1].add_patch(_rf_ellipse(cell, color, lw=2))
        ax00[1].annotate(f'{cell}', xy=(rf_params[cell]['center_x'], rf_params[cell]['center_y']), fontsize=12)
for panel in ax00:
    panel.set_xlim(0, ac.numXChecks)
    panel.set_ylim(ac.numYChecks, 0)  # Invert y-axis to match spatial map orientation
ax00[0].set_title('All Cells')
ax00[1].set_title('Cells with High Correlation')
custom_lines = [Line2D([0], [0], color='red', lw=1, label='Both'),
                Line2D([0], [0], color='orange', lw=1, label='EI only'),
                Line2D([0], [0], color='magenta', lw=1, label='RF only')]
ax00[1].legend(handles=custom_lines, loc='upper right')

ax01 = subfigs[1].subplots(1, 2)
hist_panels = [(ax01[0], sm_corr, threshold_sm, 'Spatial Map Correlation Histogram'),
               (ax01[1], ei_corr, threshold_ei, 'EI Correlation Histogram')]
for panel, corr, threshold, title in hist_panels:
    # Histogram only the strict upper triangle: the diagonal is self-correlation, the
    # lower triangle mirrors the upper one, and zero-filling them (as np.triu does)
    # would add a spurious spike at 0.
    pair_vals = corr[np.triu_indices_from(corr, k=1)]
    panel.hist(pair_vals, bins=100, alpha=0.7)
    panel.semilogy()
    panel.set_title(title)
    panel.set_xlabel('Correlation Coefficient')
    panel.set_ylabel('Count')
    panel.axvline(threshold, color='red', linestyle='--', label='Threshold')
    panel.legend()


In [None]:
# Display the high-EI-correlation pairs: a set of (cell_a, cell_b) id tuples.
high_ei_corr_cells

In [None]:
# Display the high-spatial-map-correlation pairs: a set of (cell_a, cell_b) id tuples.
high_sm_corr_cells

In [None]:
# Side-by-side comparison of one candidate duplicate pair: spatial maps in row 0 and
# EI maps (drawn by ra.ei_utils.plot_ei_map into rows 1-6) for each cell, with the
# pair's EI and spatial-map correlation in the title.
cell_a = 3
cell_b = 5

fig, ax = plt.subplots(7, 2, figsize=(12, 20))
for col, cell in enumerate((cell_a, cell_b)):
    ax[0, col].imshow(ac.d_spatial_maps[cell][:, :, 0], cmap='gray')
    ax[0, col].set_title(f'Cell {cell} Spatial Map')
    ra.ei_utils.plot_ei_map(cell, ac.vcd, axs=ax[1:, col])

i_a, i_b = cluster_to_index[cell_a], cluster_to_index[cell_b]
fig.suptitle(f'EI Corr: {ei_corr[i_a, i_b]:.2f}, SM Corr: {sm_corr[i_a, i_b]:.2f}', fontsize=16)
plt.tight_layout()

In [None]:
# EI map for cell_a; no axes are passed, so plot_ei_map presumably creates its own figure.
ra.ei_utils.plot_ei_map(cell_a, ac.vcd)

In [None]:
# Inspect the AnalysisChunk object.
ac

In [None]:
# Kilosort 2.5 output for 20250514C chunk1: amplitude and template id of each spike.
ks_dir_20250514c = '/Volumes/data/data/sorted/20250514C/chunk1/kilosort2.5'
amps, templates = (np.load(f'{ks_dir_20250514c}/{name}.npy') for name in ('amplitudes', 'spike_templates'))

In [None]:
# Spike times (in samples) for the same 20250514C sort.
# NOTE(review): a later cell reassigns spike_times from experiment 20250429C.
spike_times_path = '/Volumes/data/data/sorted/20250514C/chunk1/kilosort2.5/spike_times.npy'
spike_times = np.load(spike_times_path)
print(spike_times.shape)

In [None]:
# Memory-map the Kilosort PC features (mmap_mode='r' avoids reading the whole file into RAM).
# NOTE(review): this is experiment 20250429C, whereas `ac` and the amplitude/template arrays
# loaded above come from 20250514C — confirm which experiment is intended.
pc_templates = np.load('/Volumes/data/data/sorted/20250429C/chunk1/kilosort2.5/pc_features.npy',mmap_mode='r')

In [None]:
# Per-template PC feature-channel indices, per-spike template ids and spike times for 20250429C.
# NOTE(review): this reassigns `spike_times`, previously loaded from 20250514C; any later
# cell that pairs spike_times with the 20250514C `amplitudes` array will then mix sorts.
pc_feature_inds = np.load('/Volumes/data/data/sorted/20250429C/chunk1/kilosort2.5/pc_feature_ind.npy')
spike_templ = np.load('/Volumes/data/data/sorted/20250429C/chunk1/kilosort2.5/spike_templates.npy')
spike_times = np.load('/Volumes/data/data/sorted/20250429C/chunk1/kilosort2.5/spike_times.npy')

In [None]:
# Ids 199 and 212 shifted to 0-based template ids (presumably 1-based cell ids; confirm).
cluster_id_a = 199 - 1
cluster_id_b = 212 - 1

# Indices of the spikes assigned to each template
spike_indx_a, spike_indx_b = (np.where(spike_templ == cid)[0] for cid in (cluster_id_a, cluster_id_b))

# Channel indices used for each template's PC features
feat_chan_ids_a, feat_chan_ids_b = (pc_feature_inds[cid, :].astype(int) for cid in (cluster_id_a, cluster_id_b))


In [None]:
# First two feature channels of each template
chan_0_a, chan_1_a = feat_chan_ids_a[0], feat_chan_ids_a[1]
chan_0_b, chan_1_b = feat_chan_ids_b[0], feat_chan_ids_b[1]

# Spike times (in samples) of each cluster's spikes
times_spikes_a, times_spikes_b = spike_times[spike_indx_a], spike_times[spike_indx_b]

num_chan_pc_combos = 2 * 2  # 2 channel slots x 2 PCs

In [None]:
# Scatter matrix of PC features for the two clusters. Feature k is PC (k % n_pcs) on
# feature-channel slot (k // n_pcs), read from pc_templates[spike, pc, slot].
# Diagonal panels plot the feature against spike time; off-diagonal panels plot
# feature i (x) against feature j (y).
SAMPLE_RATE_HZ = 20000  # spike_times are in samples
n_pcs = 2

# NOTE(review): slot k presumably refers to electrode pc_feature_inds[cluster, k], which
# can differ between the two clusters; if so, the panels overlay different electrodes.
if not np.array_equal(feat_chan_ids_a[:2], feat_chan_ids_b[:2]):
    print(f'Warning: feature-channel slots differ between clusters: '
          f'{feat_chan_ids_a[:2]} vs {feat_chan_ids_b[:2]}')


def _pc_feature(spike_idx, k):
    """Feature k for the given spikes: PC (k % n_pcs) on channel slot (k // n_pcs)."""
    return pc_templates[spike_idx, k % n_pcs, k // n_pcs]


def _pc_label(k):
    """Axis label for feature k."""
    return f'Ch{k // n_pcs} PC{k % n_pcs}'


# Grid size follows num_chan_pc_combos instead of a hard-coded 4x4.
fig, axs = plt.subplots(num_chan_pc_combos, num_chan_pc_combos, figsize=(12, 12), squeeze=False)
for i in range(num_chan_pc_combos):
    for j in range(num_chan_pc_combos):
        ax = axs[i, j]
        y_a = _pc_feature(spike_indx_a, j)
        y_b = _pc_feature(spike_indx_b, j)
        if i == j:
            x_a = times_spikes_a / SAMPLE_RATE_HZ
            x_b = times_spikes_b / SAMPLE_RATE_HZ
            ax.set_xlabel('Time (s)')
            ax.set_ylabel(_pc_label(i))
        else:
            x_a = _pc_feature(spike_indx_a, i)
            x_b = _pc_feature(spike_indx_b, i)
            ax.set_xlabel(_pc_label(i))
            ax.set_ylabel(_pc_label(j))
            # Symmetric x-limits centred on 0, wide enough for both clusters
            x_max = max(np.max(np.abs(x_a)), np.max(np.abs(x_b)))
            ax.set_xlim(-x_max, x_max)
            ax.axhline(0, color='gray', linewidth=0.5)
            ax.axvline(0, color='gray', linewidth=0.5)
        y_max = max(np.max(np.abs(y_a)), np.max(np.abs(y_b)))
        ax.set_ylim(-y_max, y_max)
        ax.scatter(x_a, y_a, s=1, alpha=0.5, color='blue', label=f'Cell {cluster_id_a}')
        ax.scatter(x_b, y_b, s=1, alpha=0.5, color='green', label=f'Cell {cluster_id_b}')
axs[0, 0].legend(markerscale=5)
plt.tight_layout()


In [None]:
# NOTE(review): `i` is whatever value the loop variable of an earlier cell last held,
# so this output depends on execution history — leftover debugging.
pc_templates[spike_indx_a, i%2, i//2] 

In [None]:
# Display the 0-based template id of cluster a.
cluster_id_a

In [None]:
# Shift the 0-based template ids by 1 (presumably to match the 1-based cell ids used by ac;
# confirm) and stack: row 0 = spike amplitude, row 1 = shifted template id.
templates_vision = templates + 1
amplitudes = np.vstack([amps.squeeze(), templates_vision.squeeze()])


In [None]:
# Columns of `amplitudes` (row 0 = amplitude, row 1 = id) whose id equals cell_a
mask_a = amplitudes[1, :] == cell_a
a = amplitudes[:, mask_a]
print(a.shape)

In [None]:
# Compare the amplitude distributions of two cells over the recording: amplitude vs.
# spike time (main axes) and a shared-bin amplitude histogram (inset on the right).
cell_a = 3
cell_b = 5

# `amplitudes` was built from the 20250514C sort, but `spike_times` may since have been
# overwritten by the 20250429C PC-feature cell, so reload the matching spike times here.
ks_dir = '/Volumes/data/data/sorted/20250514C/chunk1/kilosort2.5'
spike_times_amp = np.squeeze(np.load(f'{ks_dir}/spike_times.npy'))
assert spike_times_amp.shape[0] == amplitudes.shape[1], 'spike times and amplitudes come from different sorts'

mask_a = amplitudes[1, :] == cell_a
mask_b = amplitudes[1, :] == cell_b
amps1, times1 = amplitudes[0, mask_a], spike_times_amp[mask_a]
amps2, times2 = amplitudes[0, mask_b], spike_times_amp[mask_b]

fig, ax = plt.subplots(1, 1, figsize=(12, 6), constrained_layout=True)
ax.plot(times1, amps1, 'o', color='C0', label=f'Cell {cell_a}', alpha=0.5)
ax.plot(times2, amps2, 'o', color='C1', label=f'Cell {cell_b}', alpha=0.5)
ax_histy = ax.inset_axes([1.05, 0, 0.25, 1], sharey=ax)
ax_histy.tick_params(axis='y', labelleft=False)

# Shared bin edges over both cells' amplitudes so the histograms are comparable;
# concatenating also tolerates one of the cells having no spikes.
both_amps = np.concatenate([amps1, amps2])
num_bins = 50
bin_edges = np.linspace(both_amps.min(), both_amps.max(), num_bins + 1)
hist1, _ = np.histogram(amps1, bins=bin_edges)
hist2, _ = np.histogram(amps2, bins=bin_edges)

ax_histy.stairs(hist1, bin_edges, fill=True, alpha=0.5, color='C0', orientation='horizontal', label=f'Cell {cell_a}')
ax_histy.stairs(hist2, bin_edges, fill=True, alpha=0.5, color='C1', orientation='horizontal', label=f'Cell {cell_b}')

# Histogram intersection normalised by cell_a's spike count (0 if cell_a has no spikes).
intersection = np.sum(np.minimum(hist1, hist2))
total_sum_hist1 = np.sum(hist1)
overlap_fraction = intersection / total_sum_hist1 if total_sum_hist1 > 0 else 0

ax.set_xlabel('Spike Time (samples)')
ax.set_ylabel('Amplitude')
# The '%' format type already multiplies by 100 and appends the '%' sign.
ax.set_title(f'Amplitude Comparison for Cells {cell_a} and {cell_b}, Overlap Percentage: {overlap_fraction:.2%}')
ax.legend()

In [None]:
# Template ids of cell_a's spikes (should all equal cell_a - 1, given the +1 shift above)
index = np.flatnonzero(amplitudes[1, :] == cell_a)
templates[index]

In [None]:
# Re-import and reload the dedup module so source edits are picked up.
import retinanalysis.classes.dedup as dd
importlib.reload(dd)

# Check that dd.get_ei_autocorrelation reproduces the notebook's ei_corr / pairs.
# NOTE(review): `.all() == .all()` compares the truthiness of the two matrices, not their
# values; np.allclose(ei_autocorr, ei_corr) would be the element-wise check.
# ei_autocorr, high_ei_pairs = dd.get_ei_autocorrelation(ac)
# print(ei_autocorr.all() == ei_corr.all())
# print(high_ei_pairs.difference(high_ei_corr_cells))

In [None]:
# dd's version of the problem-cell sets; names suggest all / EI-flagged / SM-flagged — confirm.
# The commented lines compare them with the notebook's own all_problem_cells.
pcs, ei_pcs, sm_pcs = dd.isolate_problem_cells(block=ac)
# print(pcs.difference(all_problem_cells))
# print(ei_pcs.difference(all_problem_cells))
# print(sm_pcs.difference(all_problem_cells))

In [None]:
# RF plot from the dedup module.
fig, ax = dd.plotRFs_dedup(ac)

In [None]:
# Histogram plot from the dedup module.
fig, ax = dd.plot_histograms(ac)

In [None]:
# Now, using EI and SM pairs, build groups of cells connected by either EI or SM correlation.
# NOTE(review): this replaces the notebook's earlier sm_corr / ei_corr with dd's matrices;
# later cells index them via cluster_to_index (ac.cell_ids order) — confirm dd uses that order.
ei_corr, high_ei_pairs = dd.get_ei_autocorrelation(ac)
sm_corr, high_sm_pairs = dd.get_sm_autocorrelation(ac)

In [None]:
# Adjacency map: cell -> set of cells it is highly spatial-map-correlated with
pairs_dict = {}
for a, b in high_sm_pairs:
    pairs_dict.setdefault(a, set()).add(b)
    pairs_dict.setdefault(b, set()).add(a)

# Each pair's extended group is the union of both members' neighbours, as a sorted tuple.
extended = {
    tuple(sorted(pairs_dict.get(a, set()) | pairs_dict.get(b, set())))
    for a, b in high_sm_pairs
}

# Check the hand-rolled version against the dedup implementation.
ext_test = dd.generate_extended_pairings(high_sm_pairs)
print(ext_test == extended)




In [None]:
# Extended groups from spatial-map pairs (ext_1) and EI pairs (ext_2), and their union.
ext_1, ext_2 = (dd.generate_extended_pairings(pairs) for pairs in (high_sm_pairs, high_ei_pairs))
ext_all = ext_1.union(ext_2)

In [None]:
# Inspect one candidate duplicate group (cell ids) in detail.
group = (625, 628, 1405)
dd.visualize_groups(group, ac, detailed=True);

In [None]:
# Amplitude histograms for the cells in `group`, using the spike amplitude/id array built above.
dd.plot_amplitude_histograms(amplitudes, templates, group)

In [None]:
# Every pair flagged by either spatial-map or EI correlation.
all_pairs = high_sm_pairs | high_ei_pairs


In [None]:
# Display dd's extended pairings for the spatial-map pairs.
ext_test

In [None]:
# Summarise the deduplication candidates per pair rather than per group, since most
# metrics are pairwise: every group in ext_all is expanded into its pairs.
# Overlapping groups share pairs, so pairs are collected in a set to avoid duplicate
# rows. Comprehension variables are used so the global cell_a / cell_b are untouched.
from itertools import combinations

header = ['cell_a', 'cell_b', 'sm_corr', 'ei_corr']
# TODO: add an amplitude-histogram overlap fraction column.
unique_pairs = {tuple(sorted(pair)) for group_ids in ext_all for pair in combinations(group_ids, 2)}
stats = [
    [a, b,
     sm_corr[cluster_to_index[a], cluster_to_index[b]],
     ei_corr[cluster_to_index[a], cluster_to_index[b]]]
    for a, b in sorted(unique_pairs)
]

df_stats = pd.DataFrame(stats, columns=header)
df_stats




In [None]:
# Per-pair summary computed by the dedup module (presumably the library version of df_stats — confirm).
summary_stats = dd.get_summary_stats(ac, amplitudes)
summary_stats