In [5]:
from citalopram_project.load import load_spikes, load_neurons
from citalopram_project.correlations import pairwise_correlation_spikes
from citalopram_project.ensemble.humphries import humphries_ensemble, communities_test
from citalopram_project.ensemble.ensembles import  _create_ensemble_stats, _create_ensemble_df, get_ensemble_id, get_ensemble_sig, drop_non_sig_ensembles
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
from citalopram_project.load import get_data_dir


# Neuron metadata table; later cells merge its neuron_id, session_name,
# cluster and group columns onto the spike data.
neurons = load_neurons()


In [2]:
class EnsembleError(Exception):
    """Raised by detect_ensembles_single when every run has a NaN modularity."""


def detect_ensembles_single(df, bin_width, sigma, n_runs=20, n_boot_coms=1000):
    """Detect ensembles for one set of spikes, keeping the best of ``n_runs`` runs.

    Each run computes a pairwise correlation matrix from ``df`` and passes it to
    ``humphries_ensemble``. The run with the largest non-NaN modularity is kept,
    and its communities are scored with ``communities_test`` against the
    correlation matrix of that same run.

    Args:
        df: Spike data; a copy is forwarded to ``pairwise_correlation_spikes``
            on every run.
        bin_width: Forwarded to ``pairwise_correlation_spikes``.
        sigma: Forwarded to ``pairwise_correlation_spikes``.
        n_runs: Number of detection runs to compare.
        n_boot_coms: Forwarded to ``communities_test`` as ``n_boot``.

    Returns:
        Tuple ``(ensemble_df, stats_df)`` built by ``_create_ensemble_df`` and
        ``_create_ensemble_stats`` from the selected run.

    Raises:
        EnsembleError: If no run produced a non-NaN modularity (this includes
            ``n_runs == 0``).
    """
    mod_list, communities_list, corr_list = [], [], []
    for _ in range(n_runs):
        df_corr = pairwise_correlation_spikes(
            df.copy(), bin_width=bin_width, sigma=sigma, fillna=0, rectify=True
        )
        mod, communities, _cluster_idx = humphries_ensemble(df_corr)
        mod_list.append(mod)
        communities_list.append(communities)
        corr_list.append(df_corr)

    if np.isnan(mod_list).all():
        raise EnsembleError()

    # np.argmax would return the first NaN entry when only some runs failed;
    # nanargmax ignores NaN modularities so a valid run is always selected.
    best_run_idx = int(np.nanargmax(mod_list))
    mod = mod_list[best_run_idx]
    communities = communities_list[best_run_idx]
    # Score the communities against the matrix they were detected on, not
    # whichever matrix the final loop iteration happened to leave behind.
    best_corr = corr_list[best_run_idx]

    com_scores, com_score_p_values, com_similarities = communities_test(
        best_corr, communities, n_boot=n_boot_coms
    )
    stats_df = _create_ensemble_stats(
        mod, communities, com_scores, com_score_p_values, com_similarities
    )
    ensemble_df = _create_ensemble_df(communities)
    return ensemble_df, stats_df



def detect_ensembles(df_spikes, session_col, n_runs=20, bin_width=1, sigma=0, n_boot_coms=1000):
    """Run ``detect_ensembles_single`` on every session and collate the results.

    Sessions for which ``detect_ensembles_single`` raises ``EnsembleError`` are
    reported with a printed message and skipped. The per-session results are
    tagged with a ``session_name`` column (regardless of ``session_col``),
    concatenated, and passed through ``get_ensemble_sig``, ``get_ensemble_id``
    and ``drop_non_sig_ensembles``.

    Args:
        df_spikes: Spike table containing the ``session_col`` column.
        session_col: Name of the column identifying the session of each row.
        n_runs: Forwarded to ``detect_ensembles_single``.
        bin_width: Forwarded to ``detect_ensembles_single``.
        sigma: Forwarded to ``detect_ensembles_single``.
        n_boot_coms: Forwarded to ``detect_ensembles_single``.

    Returns:
        Tuple ``(df_stats, df_ensembles)``.

    Raises:
        ValueError: If no session yielded a result (every session raised
            ``EnsembleError``, or ``df_spikes`` contains no sessions).
    """
    session_names = df_spikes[session_col].unique()
    stats_dfs = []
    ensembles_dfs = []
    failed_sessions = []
    for session in tqdm(session_names):
        session_spikes = df_spikes[df_spikes[session_col] == session]
        try:
            ensemble_df, stats_df = detect_ensembles_single(
                session_spikes.copy(),
                n_runs=n_runs,
                sigma=sigma,
                bin_width=bin_width,
                n_boot_coms=n_boot_coms,
            )
        except EnsembleError:
            print(f"Error in {session}")
            failed_sessions.append(session)
            continue
        stats_dfs.append(stats_df.assign(session_name=session))
        ensembles_dfs.append(ensemble_df.assign(session_name=session))

    if not stats_dfs:
        # pd.concat([]) would fail with an unhelpful "No objects to concatenate";
        # raise the same exception type with the actual cause instead.
        raise ValueError(
            f"No ensembles could be detected in any session (failed sessions: {failed_sessions})"
        )

    df_stats = pd.concat(stats_dfs)
    df_ensembles = pd.concat(ensembles_dfs)
    df_stats = get_ensemble_sig(df_stats)
    df_stats, df_ensembles = get_ensemble_id(df_stats, df_ensembles)
    df_ensembles = drop_non_sig_ensembles(df_stats, df_ensembles)
    return df_stats, df_ensembles


In [4]:
# "pre"-block spikes joined with neuron metadata, keeping spiketimes below
# 60 * 30 (presumably seconds, i.e. the first 30 minutes; confirm units) and
# dropping the "discontinuation" group.
df_spikes = (
    load_spikes(block_name="pre")
    .merge(neurons[["neuron_id", "session_name", "cluster", "group"]])
    .loc[
        lambda spikes: (spikes["spiketimes"] < (60 * 30))
        & (spikes["group"] != "discontinuation")
    ]
)

# Ensemble detection, one session at a time.
df_stats, df_ensembles = detect_ensembles(
    df_spikes,
    session_col="session_name",
    n_runs=20,
    bin_width=1,
    sigma=0,
    n_boot_coms=1000,
)

  0%|          | 0/12 [00:00<?, ?it/s]

In [6]:

# Output folder for derived tables; also used by the cells below.
derived_data_dir = get_data_dir() / "derived"

# Persist the ensemble statistics and memberships as gzip-compressed parquet.
df_stats.to_parquet(
    derived_data_dir / "spont_ensemble_stats.parquet.gzip",
    compression="gzip",
)
df_ensembles.to_parquet(
    derived_data_dir / "spont_ensembles.parquet.gzip",
    compression="gzip",
)


In [7]:
# Same pipeline restricted to spiketimes below 60 * 15 (presumably seconds,
# i.e. the first 15 minutes; confirm units), again without the
# "discontinuation" group.
df_spikes = (
    load_spikes(block_name="pre")
    .merge(neurons[["neuron_id", "session_name", "cluster", "group"]])
    .loc[
        lambda spikes: (spikes["spiketimes"] < (60 * 15))
        & (spikes["group"] != "discontinuation")
    ]
)

df_stats, df_ensembles = detect_ensembles(
    df_spikes,
    session_col="session_name",
    n_runs=20,
    bin_width=1,
    sigma=0,
    n_boot_coms=1000,
)

# Save under the FIRST15 file names.
df_stats.to_parquet(
    derived_data_dir / "spont_ensemble_stats_FIRST15.parquet.gzip",
    compression="gzip",
)
df_ensembles.to_parquet(
    derived_data_dir / "spont_ensembles_FIRST15.parquet.gzip",
    compression="gzip",
)

  0%|          | 0/12 [00:00<?, ?it/s]

In [9]:
# Same pipeline restricted to 60 * 15 <= spiketimes < 60 * 30 (presumably
# seconds, i.e. the second 15 minutes; confirm units), without the
# "discontinuation" group.
# The lower bound is inclusive so that this window and the FIRST15 window
# (spiketimes < 60 * 15) partition the 30 minutes: with a strict ">" a spike at
# exactly 60 * 15 would fall in neither half.
df_spikes = (
    load_spikes(block_name="pre")
    .merge(neurons[["neuron_id", "session_name", "cluster", "group"]])
    .loc[lambda x: (x.spiketimes >= (60 * 15)) & (x.spiketimes < (60 * 30))]
    .loc[lambda x: x.group != "discontinuation"]
)

df_stats, df_ensembles = detect_ensembles(
    df_spikes,
    session_col="session_name",
    n_runs=20,
    bin_width=1,
    sigma=0,
    n_boot_coms=1000,
)

# Save under the LAST15 file names.
df_stats.to_parquet(derived_data_dir / "spont_ensemble_stats_LAST15.parquet.gzip", compression="gzip")
df_ensembles.to_parquet(derived_data_dir / "spont_ensembles_LAST15.parquet.gzip", compression="gzip")

  0%|          | 0/12 [00:00<?, ?it/s]

Error in hamilton_14
Error in chronic_09
Error in chronic_01
