In [None]:
# Re-import edited project modules (e.g. utils, ecephys_analyses) automatically before executing code.
%load_ext autoreload
%autoreload 2

In [None]:
from ecephys_analyses.on_off import get_on_off_df_filename
from ecephys_analyses.data.paths import get_datapath
import pandas as pd

from ecephys_analyses.data.channel_groups import region_depths
from ecephys_analyses.data import paths
import ecephys.units

import itertools 

from utils import *

from joblib import Parallel, delayed

In [None]:
import altair as alt

# Lift Altair's default cap on the number of rows embedded in a chart; the OFF-period
# tables charted below may exceed it.
alt.data_transformers.disable_max_rows()

## USER INPUT BELOW

In [None]:
# DATASETS
# One tuple per recording to analyse -- presumably (subject, condition/probe, sorting);
# TODO confirm against load_analysis_data. Example entry:
#     (
#         'Segundo',
#         'sleep-homeostasis-2h_imec0',
#         'ks2_5_catgt_Th=12-10_lam=50_8s-batches_postpro_1_metrics_all_isi',
#     ),
data_conditions = list()

In [None]:
# PARAMETERS USED IN COMPUTE_...
# (presumably these must match the settings the ON/OFF detection was run with)

region_orig = 'all'

# 'noise_contam' has to be listed here, or the cluster selection must be applied afterwards.
selected_groups_orig = ['good', 'mua', 'noise_contam']
selection_intervals_orig = {'fr': (0.0, float('Inf'))}

# Don't pool: every downstream cell analyses units one cluster_id at a time.
pool = False
assert not pool

state = 'N2'
detection_condition = 'on_off_threshold_single_unit_2'
root_key = 'sleep-homeostasis'

In [None]:
# PARAMETERS FOR CURRENT ANALYSIS

# Cluster subselection applied on top of the clusters the ON/OFF data was computed for.
region = 'cortex_base'
selected_groups = ['good', 'mua']
# metric name -> presumably (min, max) bounds a cluster's metric must satisfy; TODO confirm.
selection_intervals = {
    'fr': (0.5, 5.0),
    'isi_viol_2.0': (0.0, 0.25),
    # Other metrics that can be used, e.g.:
    # 'isi_viol_1_0': (0.0, 0.25),
    # 'contam_rate_1_0': (0.0, float('Inf')),
    # 'contam_rate_2_0': (0.0, float('Inf')),
}

# Outlier removal (baseline vs recovery FR), in units of median absolute deviation (MAD):
# 3 is very conservative, 2.5 moderately conservative, 2 poorly conservative.
MAD_THRESHOLD = 3

## END USER INPUT

## LOAD DATA AND SUBSET CLUSTERS OF INTEREST

In [None]:
# Load the ON/OFF period tables and cluster info. Per the config-cell headers, the *_orig
# arguments describe what the ON/OFF detection was computed on, while region /
# selected_groups / selection_intervals subset clusters for this analysis.
# NOTE(review): root_key is hard-coded to 'SD' below, so the `root_key` value set in the
# config cell ('sleep-homeostasis') is never used -- confirm which one is intended.
load_args = (
    data_conditions,
    detection_condition,
    state,
    region_orig,
    selected_groups_orig,
    selection_intervals_orig,
    region,
    selected_groups,
    selection_intervals,
    pool,
)
data, off_dat, cluster_info = load_analysis_data(*load_args, root_key='SD')

## Inspect the loaded data

In [None]:
# Number of distinct clusters per dataset and condition
data.groupby(['dataset', 'condition'])['cluster_id'].nunique()

In [None]:
# Summary statistics of ON and OFF period durations per dataset, condition and state
data.groupby(['dataset', 'condition', 'state'])['duration'].describe()

In [None]:
# Total time spent in the state of interest for each condition
data.groupby(['dataset', 'condition'])['condition_state_time'].describe()

# Remove outlier clusters (firing rate differs between recovery and baseline)

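Exclude clusters whose within-state firing rate changes too much between baseline and recovery: per dataset, a cluster is dropped when its recovery/baseline FR ratio, or the inverse (baseline/recovery) ratio, exceeds the median by more than `MAD_THRESHOLD` median absolute deviations.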

In [None]:
# ON periods from baseline and recovery only (interbout excluded).
on_dat = data.loc[
    data['state'].eq('on') & data['condition'].ne('interbout')
]

# Rows per cluster and condition, divided by the time spent in the state -> rate.
# NOTE(review): this counts ON-period rows; naming it n_spikes / FR presumes each ON row
# corresponds to one spike in this single-unit detection -- confirm.
FR_df = (
    on_dat
    .groupby(['dataset', 'cluster_id', 'condition', 'condition_state_time'])['duration']
    .count()
    .reset_index()
    .rename(columns={'duration': 'n_spikes'})
)
FR_df['condition_FR'] = FR_df['n_spikes'] / FR_df['condition_state_time']

In [None]:
FR_df.head()

In [None]:
# Firing-rate summary, within the state of interest, for each dataset and condition
FR_df.groupby(['dataset', 'condition'])['condition_FR'].describe()

In [None]:
# Wide table: one row per (dataset, cluster), one FR column per condition.
# Clusters absent from a condition get NaN in that column.
FR_df_wide = FR_df.pivot_table(
    index=['dataset', 'cluster_id'],
    columns='condition',
    values='condition_FR',
).reset_index()

recovery_fr = FR_df_wide['recovery']
baseline_fr = FR_df_wide['baseline']
FR_df_wide['diff'] = recovery_fr - baseline_fr
FR_df_wide['ratio'] = recovery_fr / baseline_fr
FR_df_wide['ratio_inv'] = 1 / FR_df_wide['ratio']

FR_df_wide.head()

In [None]:
# Summary of recovery/baseline FR ratios per dataset
FR_df_wide.groupby('dataset')['ratio'].describe()

In [None]:
# Summary of baseline/recovery FR ratios (ratio_inv) per dataset
FR_df_wide.groupby('dataset')['ratio_inv'].describe()

In [None]:
# Distribution of recovery - baseline FR differences, one panel per dataset
fr_diff = FR_df_wide['diff']
axes = fr_diff.hist(by=FR_df_wide['dataset'], bins=100, bottom=0.01)

In [None]:
# Distribution of recovery/baseline FR ratios, one panel per dataset
fr_ratio = FR_df_wide['ratio']
axes = fr_ratio.hist(by=FR_df_wide['dataset'], bins=50, bottom=0.01)
# Optional log-scaled x axis (`axes` is either a single Axes or an array of them):
# try:
#     for a in axes.ravel(): a.set_xscale('log')
# except AttributeError:
#     axes.set_xscale('log')

In [None]:
# Distribution of baseline/recovery FR ratios, one panel per dataset
fr_ratio_inv = FR_df_wide['ratio_inv']
axes = fr_ratio_inv.hist(by=FR_df_wide['dataset'], bins=50, bottom=0.01)
# Optional log-scaled x axis (`axes` is either a single Axes or an array of them):
# try:
#     for a in axes.ravel(): a.set_xscale('log')
# except AttributeError:
#     axes.set_xscale('log')

In [None]:
# NaN count per column -- a NaN ratio means the cluster is missing from one of the conditions
FR_df_wide.isnull().sum()


In [None]:
# A cluster missing from one condition (i.e. its rate there is 0) has NaN ratios after the
# pivot; treat those ratios as infinite so they stand out as outliers in the MAD filter.
ratio_columns = ['ratio', 'ratio_inv']
FR_df_wide[ratio_columns] = FR_df_wide[ratio_columns].fillna(float('Inf'))


In [None]:
from scipy import stats


def _above_mad_threshold(values):
    """True where `values` exceed their median by more than MAD_THRESHOLD median absolute deviations."""
    return (values - values.median()) > MAD_THRESHOLD * stats.median_abs_deviation(values)


def _keep_clusters(frame, dataset, keep_ids):
    """Drop rows of `dataset` whose cluster_id is not in keep_ids; rows of other datasets are kept."""
    return frame[(frame['dataset'] != dataset) | frame['cluster_id'].isin(keep_ids)]


# Per dataset, flag clusters whose recovery/baseline FR ratio (or its inverse) is a MAD
# outlier, print them, and remove them from every table used further down.
# NOTE: not idempotent -- re-running this cell re-estimates median/MAD on the already
# filtered FR_df_wide and may exclude more clusters; re-run from the FR cells above instead.
for dataset in data.dataset.unique():
    print(dataset, end=': ')
    df = FR_df_wide[FR_df_wide['dataset'] == dataset]

    outlier_mask = _above_mad_threshold(df['ratio']) | _above_mad_threshold(df['ratio_inv'])
    exclude_clusters = df[outlier_mask].cluster_id.unique()
    dataset_clusters = df.cluster_id.unique()
    print(f'exclude N={len(exclude_clusters)}/{len(dataset_clusters)}: {exclude_clusters}')

    cluster_ids = sorted(c for c in dataset_clusters if c not in exclude_clusters)

    off_dat = _keep_clusters(off_dat, dataset, cluster_ids).copy()
    data = _keep_clusters(data, dataset, cluster_ids).copy()
    # The FR tables are filtered too, so the histograms below are drawn without outliers.
    FR_df_wide = _keep_clusters(FR_df_wide, dataset, cluster_ids)
    FR_df = _keep_clusters(FR_df, dataset, cluster_ids)

In [None]:
# recovery - baseline FR differences after outlier removal
fr_diff = FR_df_wide['diff']
axes = fr_diff.hist(by=FR_df_wide['dataset'], bins=50, bottom=0.01)

In [None]:
# recovery/baseline FR ratios after outlier removal
fr_ratio = FR_df_wide['ratio']
axes = fr_ratio.hist(by=FR_df_wide['dataset'], bins=50, bottom=0.01)

In [None]:
# baseline/recovery FR ratios after outlier removal
fr_ratio_inv = FR_df_wide['ratio_inv']
axes = fr_ratio_inv.hist(by=FR_df_wide['dataset'], bins=50, bottom=0.01)

## Inspect the data again after outlier removal

In [None]:
# Number of distinct clusters left per dataset and condition
off_dat.groupby(['dataset', 'condition'])['cluster_id'].nunique()

In [None]:
# Firing-rate summary, within the state of interest, after outlier removal
FR_df.groupby(['dataset', 'condition'])['condition_FR'].describe()

# PLOTS


# TODO: Split by dataset

# DON'T RUN THIS SECTION IF THERE ARE MULTIPLE DATASETS (the next cell asserts a single dataset)

In [None]:
# The per-unit plots below assume exactly one dataset.
assert off_dat['dataset'].unique().size == 1

In [None]:
# Randomly pick N_clust clusters to plot; the fixed seed makes the selection reproducible.
N_clust = 40
cluster_offset = 0  # set to 40, 80, ... to page through the next N_clust clusters

np.random.seed(0)
cluster_ids = sorted(data.cluster_id.unique())
np.random.shuffle(cluster_ids)
cluster_select = cluster_ids[cluster_offset:cluster_offset + N_clust]

# .copy() so that adding columns in the next cell doesn't write into a view of the
# unfiltered frame (SettingWithCopyWarning).
off_dat = off_dat[off_dat['cluster_id'].isin(cluster_select)].copy()



In [None]:
# Facet title for each unit ("cluster_id=..., FR=...Hz"), then order rows by cumulative FR.
# assign() returns a new frame instead of setting a column on a possibly filtered view
# of off_dat (SettingWithCopyWarning); the comprehension also avoids a row-wise apply.
off_dat = off_dat.assign(
    unit=[
        f"cluster_id={cluster_id}, FR={round(cum_fr, 2)}Hz"
        for cluster_id, cum_fr in zip(off_dat['cluster_id'], off_dat['cumFR'])
    ]
).sort_values(by='cumFR')

In [None]:
# DENSITY OF OFF-PERIOD DURATIONS

In [None]:

# Per-unit kernel density of OFF-period durations, baseline vs recovery overlaid.
# NOTE(review): groupby uses a 'subject' column -- confirm off_dat provides it.
density_base = (
    alt.Chart(off_dat)
    .transform_density(
        'duration',
        as_=['duration', 'density'],
        groupby=['subject', 'condition', 'unit'],
    )
    .mark_area(opacity=0.3)
    .encode(
        x='duration:Q',
        y='density:Q',
        color='condition:N',
    )
    .properties(width=150, height=100)
)

duration_density = density_base.facet(
    facet=alt.Facet('unit:N'),
    columns=5,
).resolve_scale(x='independent', y='independent')

duration_density

In [None]:
# FREQUENCY OF OFF-PERIOD DURATIONS

In [None]:
# Histogram bin edges, `binwidth` apart, from the shortest OFF period up to
# max_binned_duration; longer OFF periods fall outside the bins
# (use off_dat.duration.max() as the upper limit to include them all).
binwidth = 0.25
max_binned_duration = 10
bins = np.arange(off_dat.duration.min(), max_binned_duration + binwidth, binwidth)

In [None]:
# Count OFF periods per (unit, condition, duration bin), then turn counts into a rate by
# normalizing by the time spent in the state in each condition.

# pd.cut intervals are right-closed and bins[0] == off_dat.duration.min(), so without
# include_lowest the shortest OFF period(s) fall outside every bin and are silently dropped.
duration_bins = pd.cut(off_dat.duration, bins=bins, include_lowest=True)

binned_durations = off_dat.groupby(
    ['unit', 'condition', duration_bins],
    observed=False,  # keep empty bins so every histogram has zero-count bars
).count().loc[:, 'state'].reset_index()
binned_durations['duration_count'] = binned_durations['state']

# The 'duration' column now holds each row's bin Interval.
bin_intervals = list(binned_durations['duration'])
binned_durations['bin_min'] = [interval.left for interval in bin_intervals]
binned_durations['bin_max'] = [interval.right for interval in bin_intervals]
binned_durations['bin_center'] = (binned_durations['bin_min'] + binned_durations['bin_max']) / 2

# Time spent in the state per condition (first value per condition; this section assumes
# a single dataset, asserted above).
cond_durations = {
    cond: off_dat[off_dat.condition == cond].condition_state_time.unique()[0]
    for cond in off_dat.condition.unique()
}
# Occurrences per minute -- assumes condition_state_time is in seconds (TODO confirm).
state_times = [cond_durations[cond] for cond in binned_durations['condition']]
binned_durations['frequency'] = 60 * binned_durations['duration_count'] / state_times


In [None]:
binned_durations.head()

In [None]:

# Per-unit histogram of OFF-period durations as occurrences per minute in each bin,
# baseline vs recovery overlaid. The Interval-valued 'duration' column is dropped before
# charting (presumably because Altair cannot serialize pandas Interval objects).
frequency_bars = (
    alt.Chart(binned_durations.drop(columns='duration'))
    .mark_bar(opacity=0.3)
    .encode(
        x=alt.X('bin_min:Q'),
        x2=alt.X2('bin_max:Q'),
        y=alt.Y('frequency:Q', axis=alt.Axis(title="Occurrence (per min)")),
        color='condition:N',
    )
    .properties(width=150, height=100)
)

duration_frequency = frequency_bars.facet(
    facet=alt.Facet('unit:N'),
    columns=5,
).resolve_scale(x='independent', y='independent')

duration_frequency

# Statistics

In [None]:
# Per-unit summary statistics of OFF-period duration in each condition.
funcs = ['mean', 'median', 'skew']

stats_df = (
    off_dat
    .groupby(['dataset', 'condition', 'cluster_id'])['duration']
    .agg(funcs)
    .reset_index()
)

stats_df.head()

# Boxplots with mean/median/skew

In [None]:
# Per-dataset boxplots of each unit's mean / median / skew of OFF-period duration,
# baseline vs recovery, with one column (and an independent y scale) per measure.
charts = []
for dataset in stats_df.dataset.unique():
    dataset_stats = stats_df[stats_df['dataset'] == dataset]
    n_units = len(dataset_stats.cluster_id.unique())
    charts.append(
        alt.Chart(
            dataset_stats
        ).transform_fold(
            fold=funcs,  # long format: one row per (unit, measure) in `key` / `value`
        ).mark_boxplot(
            color='black',
            extent=0,
        ).encode(
            x=alt.X('condition:N'),
            y=alt.Y(
                'value:Q',
                axis=alt.Axis(title='Value'),
                scale=alt.Scale(zero=False),
            ),
            column=alt.Column(
                'key:N',
                header=alt.Header(title="Measure"),
            ),
            color=alt.Color('condition:N'),
        ).properties(
            title=f"{dataset}, N={n_units}",
            width=50,
            height=300,
        ).resolve_scale(
            y='independent'
        )
    )

# One flat hconcat of all dataset panels; repeatedly wrapping an initially empty
# alt.hconcat() would nest the layouts and add an empty leading slot.
alt.hconcat(*charts, spacing=70).configure_title(
    fontSize=15,
).configure_axis(
    labelFontSize=12,
    titleFontSize=14,
).configure_header(
    labelFontSize=14,
    titleFontSize=14,
).configure_legend(
    labelFontSize=14,
    titleFontSize=14,
)

# Boxplots with median only

In [None]:
# Per-dataset boxplot of each unit's median OFF-period duration, baseline vs recovery.
# stats_df already holds one `median` per unit and condition, so it is plotted directly:
# folding over `funcs` first would replicate every unit's row once per measure and feed
# each median into the boxplot three times.
charts = []
for dataset in stats_df.dataset.unique():
    dataset_stats = stats_df[stats_df['dataset'] == dataset]
    n_units = len(dataset_stats.cluster_id.unique())
    charts.append(
        alt.Chart(
            dataset_stats
        ).mark_boxplot(
            color='black',
            extent=0,
        ).encode(
            x=alt.X('condition:N'),
            y=alt.Y(
                'median:Q',
                axis=alt.Axis(title="Units' median off period duration"),
                scale=alt.Scale(zero=False),
            ),
            color=alt.Color('condition:N'),
        ).properties(
            title=f"{dataset}, N={n_units}",
            width=50,
            height=300,
        )
    )

# One flat hconcat of all dataset panels (no nested, initially empty hconcat).
alt.hconcat(*charts, spacing=70).configure_title(
    fontSize=15,
).configure_axis(
    labelFontSize=12,
    titleFontSize=14,
).configure_header(
    labelFontSize=14,
    titleFontSize=14,
).configure_legend(
    labelFontSize=14,
    titleFontSize=14,
)

In [None]:
import scipy.stats

# Paired t-test of per-unit OFF-period duration skewness, baseline vs recovery.
# Pairs are matched on (dataset, cluster_id) rather than by position in two independently
# sorted selections, which misaligns (or errors) when a unit is missing from one condition
# or cluster_ids repeat across datasets. Units lacking either value are left out.
skew_pairs = stats_df.pivot(
    index=['dataset', 'cluster_id'], columns='condition', values='skew'
).dropna(subset=['baseline', 'recovery'])

scipy.stats.ttest_rel(skew_pairs['baseline'], skew_pairs['recovery'], axis=0)

In [None]:
# Paired t-test of per-unit mean OFF-period duration, baseline vs recovery.
# Pairs are matched on (dataset, cluster_id) rather than by position in two independently
# sorted selections; units lacking either value are left out.
mean_pairs = stats_df.pivot(
    index=['dataset', 'cluster_id'], columns='condition', values='mean'
).dropna(subset=['baseline', 'recovery'])

scipy.stats.ttest_rel(mean_pairs['baseline'], mean_pairs['recovery'], axis=0)

In [None]:
# Paired t-test of per-unit median OFF-period duration, baseline vs recovery.
# Pairs are matched on (dataset, cluster_id) rather than by position in two independently
# sorted selections; units lacking either value are left out.
median_pairs = stats_df.pivot(
    index=['dataset', 'cluster_id'], columns='condition', values='median'
).dropna(subset=['baseline', 'recovery'])

scipy.stats.ttest_rel(median_pairs['baseline'], median_pairs['recovery'], axis=0)