In [12]:
import spikeinterface.full as si
from probeinterface.plotting import plot_probe, plot_probegroup
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import gridspec
from scipy import stats
import os
from Helpers import plot_utility
from astropy.nddata import block_reduce
import matplotlib.ticker as ticker

In [13]:

def plot_firing_rate_maps_per_trial_2(cluster_spike_data, track_length, save_path=None, ax=None):
    """
    Draw a trial-by-location heat map of one cluster's smoothed firing rate.

    The map is min-max normalised, clipped at its 95th percentile and drawn
    with ``pcolormesh``. The title shows the 95th percentile of the
    un-normalised map followed by " Hz". Nothing is drawn (and nothing is
    saved) when the cluster has fewer than two entries in "firing_times_vr".

    Parameters
    ----------
    cluster_spike_data : pandas.DataFrame
        Frame whose first row holds "firing_times_vr" and
        "fr_binned_in_space_smoothed" (indexed as [trial][location bin]).
        "session_id" and "cluster_id" are only read when ``save_path`` is given.
    track_length : number
        Upper limit of the x axis.
    save_path : str, optional
        Directory to write the PNG into; no file is written when None.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new 5x5 inch figure is created when None.
    """
    firing_times_cluster = cluster_spike_data["firing_times_vr"].iloc[0]
    if len(firing_times_cluster) <= 1:
        return

    # np.array copies, so zeroing NaN/inf does not touch the caller's data.
    cluster_firing_maps = np.array(cluster_spike_data['fr_binned_in_space_smoothed'].iloc[0])
    cluster_firing_maps[np.isnan(cluster_firing_maps)] = 0
    cluster_firing_maps[np.isinf(cluster_firing_maps)] = 0

    # Taken before normalisation so the title is in the map's original units.
    percentile_95th_display = np.nanpercentile(cluster_firing_maps, 95)

    cluster_firing_maps = min_max_normalize(cluster_firing_maps)
    percentile_95th = np.nanpercentile(cluster_firing_maps, 95)
    cluster_firing_maps = np.clip(cluster_firing_maps, a_min=0, a_max=percentile_95th)
    vmin, vmax = get_vmin_vmax(cluster_firing_maps)

    locations = np.arange(0, len(cluster_firing_maps[0]))
    ordered = np.arange(0, len(cluster_firing_maps), 1)
    X, Y = np.meshgrid(locations, ordered)

    if ax is None:
        fig = plt.figure()
        fig.set_size_inches(5, 5, forward=True)
        ax = fig.add_subplot(1, 1, 1)
    # Work on the figure that owns `ax`, which need not be pyplot's current figure.
    fig = ax.figure

    ax.pcolormesh(X, Y, cluster_firing_maps, shading="auto", vmin=vmin, vmax=vmax)
    ax.set_title(str(np.round(percentile_95th_display, decimals=1))+" Hz", fontsize=20)
    ax.set_ylabel('Trial Number', fontsize=20, labelpad = 20)
    ax.set_xlabel('Location (cm)', fontsize=20, labelpad = 20)
    ax.set_xlim([0, track_length])
    ax.set_ylim([0, len(cluster_firing_maps)-1])
    ax.tick_params(axis='both', which='both', labelsize=20)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    tick_spacing = 100
    ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    fig.subplots_adjust(hspace = .35, wspace = .35,  bottom = 0.2, left = 0.3, right = 0.87, top = 0.92)

    if save_path is not None:
        # Identify the file from the row being plotted rather than from
        # notebook globals, mirroring plot_firing_rate_maps_short.
        # NOTE(review): assumes the frame has a "session_id" column - confirm.
        session_id = cluster_spike_data.session_id.iloc[0]
        cluster_id = cluster_spike_data["cluster_id"].iloc[0]
        fig.savefig(save_path + '/firing_rate_map_trials_' +
                    session_id + '_' +
                    str(int(cluster_id)) + '.png', dpi=300)

In [14]:
def plot_firing_rate_maps_short(cluster_data, track_length=200, ax=None, save_path=None):
    """
    Plot one cluster's across-trial mean firing rate against location, with
    a shaded band of +/- one standard error of the mean.

    Infinite values in the map are treated as missing (NaN) and ignored by
    the mean and SEM. Nothing is drawn when the cluster has fewer than two
    entries in "firing_times_vr".

    Parameters
    ----------
    cluster_data : pandas.DataFrame
        Frame whose first row holds "firing_times_vr", "cluster_id" and
        "fr_binned_in_space_smoothed" (indexed as [trial][location bin]).
        "session_id" is only read when ``save_path`` is given.
    track_length : number, default 200
        Upper limit of the x axis; also passed to
        ``plot_utility.style_track_plot``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new 5 x 5/3 inch figure is created when None.
    save_path : str, optional
        Directory to write the PNG into; no file is written when None.
    """
    firing_times_cluster = cluster_data["firing_times_vr"].iloc[0]
    cluster_id = cluster_data["cluster_id"].iloc[0]

    if len(firing_times_cluster) <= 1:
        return

    # Float copy so that inf can be replaced by NaN without touching the caller's data.
    cluster_firing_maps = np.array(cluster_data['fr_binned_in_space_smoothed'].iloc[0], dtype=float)
    cluster_firing_maps[np.isinf(cluster_firing_maps)] = np.nan

    if ax is None:
        spikes_on_track = plt.figure()
        spikes_on_track.set_size_inches(5, 5/3, forward=True)
        ax = spikes_on_track.add_subplot(1, 1, 1)
    # Work on the figure that owns `ax`, which need not be pyplot's current figure.
    fig = ax.figure

    # Computed once and reused for both the band and the line.
    locations = np.arange(0, len(cluster_firing_maps[0]))
    mean_fr = np.nanmean(cluster_firing_maps, axis=0)
    sem_fr = stats.sem(cluster_firing_maps, axis=0, nan_policy="omit")

    ax.fill_between(locations, mean_fr - sem_fr, mean_fr + sem_fr, color="black", alpha=0.2)
    ax.plot(locations, mean_fr, color="black", linewidth=1)

    ax.tick_params(axis='both', which='both', labelsize=20)
    ax.set_xlim([0, track_length])
    # Two y ticks only: zero and the autoscaled top, read before the bottom is pinned.
    ax.set_yticks([0, np.round(ax.get_ylim()[1], 1)])
    ax.set_ylim(bottom=0)
    plot_utility.style_track_plot(ax, track_length, alpha=0.15)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(100))
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    fig.subplots_adjust(hspace = .35, wspace = .35,  bottom = 0.2, left = 0.3, right = 0.87, top = 0.92)

    if save_path is not None:
        fig.savefig(save_path + '/avg_firing_rate_maps_short_' + cluster_data.session_id.iloc[0] + '_' + str(int(cluster_id)) + '.png', dpi=300)

In [15]:
def min_max_normalize(x):
    """
    Linearly rescale an array so its minimum maps to 0 and its maximum to 1.

    argument
        - x: array-like of numbers, any shape
    return
        - float array of the same shape with values in [0, 1]; all zeros
          when every element of x is equal (the range is zero, so there is
          nothing to scale). NaN in x still propagates to the whole result
          because np.min/np.max return NaN.
    """
    x = np.asarray(x, dtype=float)
    min_val = np.min(x)
    value_range = np.max(x) - min_val
    if value_range == 0:
        # Constant input: avoid 0/0, which would fill the output with NaN.
        return np.zeros_like(x)
    return (x - min_val) / value_range

In [16]:

def get_vmin_vmax(cluster_firing_maps, bin_cm=8):
    """
    Return colour limits ``(vmin, vmax)`` for a per-trial firing-rate map.

    Every row is block-averaged with ``block_reduce`` in blocks of
    ``bin_cm`` samples; ``vmax`` is the largest value of the reduced map
    and ``vmin`` is always 0.
    NOTE(review): the name ``bin_cm`` suggests one sample per cm - confirm.
    """
    coarse_rows = np.array([
        block_reduce(trial_map, bin_cm, func=np.mean)
        for trial_map in cluster_firing_maps
    ])
    return 0, np.max(coarse_rows)

In [17]:
def ramp_score(ramp_data_vr):
    """
    Build one combined outbound+homebound ramp class label per cluster.

    Only rows whose "trial_type" and "hit_miss_try" are missing are used.
    For each cluster the "ramp_class" of the first such row with
    "track_length" == "outbound" is concatenated with that of the first
    row with "track_length" == "homebound" (e.g. "+" and "-" give "+-").

    Parameters
    ----------
    ramp_data_vr : pandas.DataFrame
        Needs the columns "cluster_id", "trial_type", "hit_miss_try",
        "track_length" and "ramp_class".

    Returns
    -------
    pandas.DataFrame
        Columns "cluster_id" and "outbound_homebound_ramp_class", one row
        per unique cluster id in ascending order. The columns are present
        even when the input has no rows.

    Raises
    ------
    IndexError
        If a cluster has no matching outbound or homebound row.
    """
    # Missing values become the string "None" so they can be selected by equality.
    ramp_data_vr = ramp_data_vr.fillna('None')

    # The trial-type / outcome filter is the same for every cluster, so apply it once.
    all_trials = ramp_data_vr[(ramp_data_vr["trial_type"] == "None") &
                              (ramp_data_vr["hit_miss_try"] == "None")]

    records = []
    for cluster_id in np.unique(ramp_data_vr["cluster_id"]):
        cluster_rows = all_trials[all_trials["cluster_id"] == cluster_id]
        outbound = cluster_rows[cluster_rows["track_length"] == "outbound"]["ramp_class"].iloc[0]
        homebound = cluster_rows[cluster_rows["track_length"] == "homebound"]["ramp_class"].iloc[0]
        records.append({'cluster_id': cluster_id,
                        'outbound_homebound_ramp_class': outbound + homebound})

    # Build the frame once instead of concatenating inside the loop.
    return pd.DataFrame(records, columns=['cluster_id', 'outbound_homebound_ramp_class'])


## M21 ramp cells

In [18]:
project_path = "/mnt/datastore/Harry/Cohort11_april2024/derivatives/"
# OF and VR session names for M21; the two lists are paired by position.
of_session_base_names = ["M21_D26_2024-05-28_16-35-31_OF1", "M21_D23_2024-05-25_16-07-17_OF1", "M21_D25_2024-05-27_15-35-57_OF1", "M21_D24_2024-05-26_15-58-23_OF1"]
vr_session_base_names = ["M21_D26_2024-05-28_17-04-41_VR1", "M21_D23_2024-05-25_16-54-12_VR1", "M21_D25_2024-05-27_16-00-30_VR1", "M21_D24_2024-05-26_16-35-19_VR1"]

# One merged frame per recording day, concatenated once after the loop.
session_frames = []
for vr_name, of_name in zip(vr_session_base_names, of_session_base_names):
    # Session names start with "<mouse>_<day>_...".
    mouse, day = vr_name.split("_")[:2]
    # The sorting analyzer path is built but not loaded in this cell.
    sorting_analyzer_path = f"{project_path}{mouse}/{day}/ephys/sorting_analyzer"
    vr_path = f"{project_path}{mouse}/{day}/vr/{vr_name}/processed/kilosort4/spikes.pkl"
    ramp_path = f"{project_path}{mouse}/{day}/vr/{vr_name}/processed/kilosort4/ramp_classifications.pkl"
    of_path = f"{project_path}{mouse}/{day}/of/{of_name}/processed/kilosort4/spikes.pkl"

    # Copy the firing times and session name into VR/OF-specific columns so
    # they keep unambiguous names after the merge below.
    spike_data_vr = pd.read_pickle(vr_path)
    spike_data_vr["firing_times_vr"] = spike_data_vr["firing_times"]
    spike_data_vr["session_id_vr"] = vr_name

    spike_data_of = pd.read_pickle(of_path)
    spike_data_of["firing_times_of"] = spike_data_of["firing_times"]
    spike_data_of["session_id_of"] = of_name

    ramp_data_vr = ramp_score(pd.read_pickle(ramp_path))

    # Inner merges: only clusters present in the VR, OF and ramp tables are kept.
    # Columns shared by VR and OF get the suffix _x (VR) or _y (OF).
    spike_data = pd.merge(spike_data_vr, spike_data_of, on="cluster_id")
    spike_data = pd.merge(spike_data, ramp_data_vr, on="cluster_id")
    session_frames.append(spike_data)

M21_master_data = pd.concat(session_frames)

### curate
# Thresholds are applied to the _x columns, i.e. the VR-side values.
print(f"pre-curation M21: {len(M21_master_data)}")
M21_master_data = M21_master_data[(M21_master_data["snr_x"] > 1) & (M21_master_data["mean_firing_rate_x"] > 0.5) & (M21_master_data["rp_contamination_x"] < 0.9)]
print(f"post-curation M21: {len(M21_master_data)}")



pre-curation M21: 1734
post-curation M21: 1070


In [20]:
# Write two PNGs per curated cluster, grouped by combined ramp class:
# a per-trial heat map ("long") and an across-trial average ("short").
ramp_map_dir = "/mnt/datastore/Harry/plot_viewer/ramp_maps"
os.makedirs(ramp_map_dir, exist_ok=True)  # savefig does not create missing directories

for ramp_class in ["UNUN", "--", "-+", "-UN", "++", "+-", "+UN"]:
    subset = M21_master_data[M21_master_data['outbound_homebound_ramp_class'] == ramp_class]

    for _, cluster_row in subset.iterrows():
        # The plotting helpers expect a one-row DataFrame, not a Series.
        cluster_df = cluster_row.to_frame().T.reset_index(drop=True)
        grid_score = cluster_df["grid_score"].iloc[0]
        cluster_id = cluster_df["cluster_id"].iloc[0]
        session_id_vr = cluster_df["session_id_vr"].iloc[0]

        fig, ax = plt.subplots(figsize=(4, 4))
        plot_firing_rate_maps_per_trial_2(cluster_df, track_length=200, ax=ax, save_path=None)
        fig.savefig(f"{ramp_map_dir}/{ramp_class}_c{cluster_id}_long_{session_id_vr}.png", dpi=100)
        plt.close(fig)  # release the figure; this loop creates two per cluster

        fig, ax = plt.subplots(figsize=(4, 2))
        plot_firing_rate_maps_short(cluster_df, track_length=200, ax=ax, save_path=None)
        ax.set_title(f"RC: {ramp_class}, GS: {np.round(grid_score, decimals=2)}")
        fig.savefig(f"{ramp_map_dir}/{ramp_class}_c{cluster_id}_short_{session_id_vr}.png", dpi=100)
        plt.close(fig)
 
        
        