In [355]:
import pandas as pd
import dill
import numpy as np
from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache
from rich import print
from tqdm.auto import tqdm
from time import perf_counter
from operator import itemgetter
from scipy.stats import zscore
from Utils.Settings import window_spike_hist, output_folder_calculations, neuropixel_dataset, var_thr, minimum_ripples_count_spike_analysis, minimum_ripples_count_generated_in_lateral_or_medial_spike_analysis
from Utils.Utils import acronym_to_main_area, clean_ripples_calculations, find_ripples_clusters_new, \
    batch_process_spike_hists_by_seed_location, process_spike_hists
import matplotlib.pyplot as plt
from multiprocessing import Pool, Manager

In [160]:
# Lift pandas' display truncation so DataFrames are shown with all rows and all columns.
for display_option in ("display.max_rows", "display.max_columns"):
    pd.set_option(display_option, None)

In [39]:

# NOTE: needs `sessions` to exist in the kernel already; in this notebook it is assigned
# in a cell that appears further down, so this cell fails on a fresh top-to-bottom run.
# Flatten the per-session lists of structure acronyms into one flat list.
all_areas_recorded = []
for session_areas in sessions["ecephys_structure_acronyms"].to_list():
    all_areas_recorded.extend(session_areas)

In [41]:
#pd.Series(all_areas_recorded).value_counts()#.plot.hist()


In [33]:
# Timestamp taken before any data is loaded.
t1_start = perf_counter()

# Open the Allen ecephys project cache through the dataset's manifest and fetch the session table.
manifest_path = f"{neuropixel_dataset}/manifest.json"
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)
sessions = cache.get_session_table()

# Load the previously computed ripple calculations; other cells index this object by session id.
ripples_pickle_path = f'{output_folder_calculations}/clean_ripples_calculations.pkl'
with open(ripples_pickle_path, 'rb') as pickle_file:
    ripples_calcs = dill.load(pickle_file)


spike_hists = {}

# Offset (µm) subtracted from the L-R limits to obtain the limits that `l_m_classifier`
# compares against "Source M-L (µm)".
# NOTE(review): presumably the L-R coordinate of the midline — confirm the origin of this value.
LR_TO_ML_OFFSET_UM = 5691.510009765625

# Collect, over all sessions, the mean L-R position of every "Probe number-area" group,
# keeping only the groups whose "∫Ripple" variance exceeds `var_thr`.
input_rip = []
for session_id in ripples_calcs.keys():
    # `groupby(...).filter` returns a new frame, so the stored frame is not modified.
    ripples = ripples_calcs[session_id][3]
    ripples = ripples.groupby("Probe number-area").filter(lambda group: group["∫Ripple"].var() > var_thr)
    # Select the column *before* aggregating: averaging the whole frame is wasted work and
    # raises TypeError on pandas >= 2.0 as soon as the frame contains a non-numeric column.
    input_rip.append(ripples.groupby("Probe number-area")["L-R (µm)"].mean())

lr_space = pd.concat(input_rip)

# Approximate tercile boundaries and the median of the pooled L-R positions.
medial_lim = lr_space.quantile(.33333)
lateral_lim = lr_space.quantile(.666666)
center = lr_space.median()
medial_lim_lm = medial_lim - LR_TO_ML_OFFSET_UM
lateral_lim_lm = lateral_lim - LR_TO_ML_OFFSET_UM

def l_m_classifier(row):
    """Label a row as "Medial", "Lateral" or "Central" from its "Source M-L (µm)" value.

    Values strictly below the module-level `medial_lim_lm` are "Medial", values strictly
    above `lateral_lim_lm` are "Lateral", and everything else (both limits included) is
    "Central".
    """
    source_position = row["Source M-L (µm)"]
    if source_position < medial_lim_lm:
        return "Medial"
    if source_position > lateral_lim_lm:
        return "Lateral"
    return "Central"

In [None]:
# Show the keys of `ripples_calcs`; other cells in this notebook iterate over them as session ids.
ripples_calcs.keys()

dict_keys([715093703, 719161530, 721123822, 743475441, 744228101, 750332458, 750749662, 751348571, 754312389, 754829445, 755434585, 757216464, 757970808, 758798717, 759883607, 760345702, 761418226, 762602078, 763673393, 766640955, 767871931, 768515987, 771160300, 771990200, 773418906, 774875821, 778240327, 778998620, 779839471, 781842082, 786091066, 787025148, 789848216, 791319847, 793224716, 794812542, 797828357, 798911424, 799864342, 816200189, 819186360, 819701982, 821695405, 829720705, 831882777, 835479236, 839068429, 840012044, 847657808])

In [351]:
# Module-level scratch values: an empty result dict and one session id
# (presumably for calling `process` by hand on a single session — confirm before relying on them).
spikes_summary = dict()
session_id = 771_990_200

In [359]:
def _ripple_time_centers(ripple_summary):
    """Return one reference time per row: "Start (s)" plus the minimum of the row's 'lag' entries."""
    return [row["Start (s)"] + row[row.index.str.contains('lag')].min()  # either sum zero or sum a negative value
            for _, row in ripple_summary.iterrows()]


def _ripple_modulation(spike_times_by_unit, time_centers, window):
    """Compute a post/pre spike-count ratio around the given reference times for every unit.

    For each unit, spikes falling strictly inside (t - window[0], t + window[1]) are
    collected for every reference time t and expressed relative to t. The returned value is

        (spikes after t per ripple / (window[1] / window[0])) / (spikes before t per ripple)

    where the denominator is replaced by 1 when the unit has no spike before any t.
    Spikes exactly at t count on neither side.

    Parameters
    ----------
    spike_times_by_unit : mapping of unit id -> numpy array of spike times
    time_centers : sequence of reference times, on the same clock as the spike times
    window : (before, after) durations around each reference time

    Returns
    -------
    pandas.Series indexed by unit id (all NaN when `time_centers` is empty).
    """
    before_s, after_s = window
    n_ripples = len(time_centers)
    if n_ripples == 0:
        # Nothing to average over: avoid np.concatenate([]) and a division by zero.
        return pd.Series(np.nan, index=list(spike_times_by_unit), dtype=float)

    # The two sides of the window have different lengths; scale the "after" count accordingly.
    length_ratio = after_s / before_s
    modulation = {}
    for unit_id, spikes in tqdm(spike_times_by_unit.items()):
        relative_times = np.concatenate(
            [spikes[(spikes > t - before_s) & (spikes < t + after_s)] - t for t in time_centers])
        spikes_before = np.count_nonzero(relative_times < 0) / n_ripples
        spikes_after = np.count_nonzero(relative_times > 0) / n_ripples
        baseline = spikes_before if spikes_before > 0 else 1
        modulation[unit_id] = (spikes_after / length_ratio) / baseline
    return pd.Series(modulation, dtype=float)


def process(session_id, spikes_summary):
    """Summarise, per brain structure, how units fire around medial vs. lateral seed ripples.

    Reads the module-level `cache`, `ripples_calcs`, `medial_lim`, `lateral_lim`, `center`
    and the thresholds imported from `Utils.Settings`. The session is skipped (the function
    returns None) when any of the following holds:

    * no unit maps to the parent area "HPF";
    * the probe stored at `ripples_calcs[session_id][5]` has fewer than
      `minimum_ripples_count_spike_analysis` ripples;
    * the CA1 ripples do not cover all three L-R sections delimited by `medial_lim` and
      `lateral_lim`;
    * fewer than `minimum_ripples_count_generated_in_lateral_or_medial_spike_analysis`
      ripples are classified as "Medial" or as "Lateral".

    Parameters
    ----------
    session_id : key into `ripples_calcs`, also passed to `cache.get_session_data`.
    spikes_summary : dict-like; the per-structure result is stored under `session_id`.

    Returns
    -------
    `spikes_summary` after the result has been stored, or None when the session is skipped.
    """
    print(session_id)

    session = cache.get_session_data(session_id)  # , amplitude_cutoff_maximum = np.inf, presence_ratio_minimum = -np.inf, isi_violations_maximum = np.inf)

    units = session.units
    units["parent area"] = units["ecephys_structure_acronym"].apply(lambda area: acronym_to_main_area(area))

    # Only sessions with at least one unit in the hippocampal formation are analysed.
    if "HPF" not in units["parent area"].unique():
        return
    areas = ["HPF"]

    print(f"In session {session_id} areas recorded: {areas}")

    ripples = ripples_calcs[session_id][3].copy()

    sel_probe = ripples_calcs[session_id][5]

    ripples_on_best_probe = ripples[ripples['Probe number'] == sel_probe].shape[0]
    print(f"number ripples on best probe: {ripples_on_best_probe}")

    if ripples_on_best_probe < minimum_ripples_count_spike_analysis:
        return

    ripples = ripples.groupby("Probe number-area").filter(lambda group: group["∫Ripple"].var() > var_thr)

    # "Ripple number" is the rank by start time among all retained ripples, assigned
    # before restricting to CA1.
    ripples = ripples.sort_values(by="Start (s)").reset_index(drop=True)
    ripples = ripples[ripples["Area"] == "CA1"]
    ripples = ripples.reset_index().rename(columns={'index': 'Ripple number'})

    # Require at least one recording position in each of the three L-R sections.
    lr_positions = ripples["L-R (µm)"].unique()
    recorded_in_each_section = (np.any(lr_positions < medial_lim)
                                & np.any(lr_positions > lateral_lim)
                                & np.any((lr_positions > medial_lim) & (lr_positions < lateral_lim)))

    print(session_id, "Recording in each ML section:", recorded_in_each_section)

    if not recorded_in_each_section:
        return

    # Flag ripples above the 90th percentile of "∫Ripple" on their own probe. The threshold
    # is broadcast per probe and compared index-aligned rather than positionally.
    strength_p90 = ripples.groupby("Probe number")["∫Ripple"].transform(
        lambda strengths: strengths.quantile(.9))
    ripples["Local strong"] = ripples["∫Ripple"] > strength_p90

    # Z-score "∫Ripple" within each probe/area; `transform` keeps the original row index
    # whatever the number of groups, so no fallback branch is needed.
    ripples["Z-scored ∫Ripple"] = ripples.groupby("Probe number-area")["∫Ripple"].transform(
        lambda strengths: zscore(strengths, ddof=1))

    # Seed: the probe/area whose mean L-R position is closest to `center`.
    source_area = ripples.groupby("Probe number-area")["L-R (µm)"].mean().sub(center).abs().idxmin()
    to_loop = [(source_area, "central")]

    print(f"In session {session_id} process:{to_loop}")

    real_ripple_summary = find_ripples_clusters_new(ripples, source_area)
    # Work on a copy so adding a column does not write into a filtered view.
    real_ripple_summary = real_ripple_summary[real_ripple_summary["Spatial engagement"] > .5].copy()
    real_ripple_summary["Location seed"] = [l_m_classifier(row) for _, row in real_ripple_summary.iterrows()]

    medial_ripples = real_ripple_summary[real_ripple_summary["Location seed"] == "Medial"]
    lateral_ripples = real_ripple_summary[real_ripple_summary["Location seed"] == "Lateral"]

    print(f"in {session_id}, medial ripples number: {medial_ripples.shape[0]}, " \
          f"lateral ripples number: {lateral_ripples.shape[0]}")
    minimum_count = minimum_ripples_count_generated_in_lateral_or_medial_spike_analysis
    if medial_ripples.shape[0] < minimum_count or lateral_ripples.shape[0] < minimum_count:
        return

    # Spike times are only fetched once the session has passed every check above.
    spike_times = session.spike_times
    space_sub_spike_times = {unit_id: spike_times[unit_id] for unit_id in units.index}

    # Durations kept before / after each reference time (same unit as the spike times).
    window = [.10, .12]

    modulation_medial = _ripple_modulation(space_sub_spike_times, _ripple_time_centers(medial_ripples), window)
    modulation_lateral = _ripple_modulation(space_sub_spike_times, _ripple_time_centers(lateral_ripples), window)

    mod_df = pd.DataFrame({"modulation lat": modulation_lateral, "modulation med": modulation_medial})
    mod_df.index.name = "unit_id"

    # Average the per-unit modulation within each recorded structure.
    out = units.join(mod_df).groupby("ecephys_structure_acronym")[["modulation lat", "modulation med"]].mean()
    spikes_summary[session_id] = out
    print(out)
    return spikes_summary

In [None]:
%%time 
#%%capture --no-stdout
# Dict proxy shared with the worker processes; `process` stores one entry per analysed session.
manager = Manager()

spikes_summary = manager.dict()

# One (session id, shared result dict) argument tuple per session.
input_multiprocessing = [(ecephys_session_id, spikes_summary) for ecephys_session_id in ripples_calcs.keys()]

pool = Pool(processes=30) # Instantiate the pool here

# Keep the AsyncResult: it is the only place where an exception raised in a worker surfaces.
async_result = pool.starmap_async(process, input_multiprocessing)
pool.close()
pool.join()

# Report a worker failure instead of dropping it silently; results already stored in
# `spikes_summary` stay available either way.
try:
    async_result.get()
except Exception as worker_error:
    print("At least one session failed in a worker:", type(worker_error).__name__, worker_error)

print(len(spikes_summary))



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1819/1819 [00:23<00:00, 77.72it/s]


  0%|          | 0/546 [00:00<?, ?it/s]

  0%|          | 0/546 [00:00<?, ?it/s]