In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
from joblib import Parallel, delayed

import ecephys.io.load
from ecephys_analyses.data import paths
from ecephys_analyses.data.channel_groups import region_depths
from ecephys_analyses.data.paths import get_datapath
from ecephys_analyses.on_off import get_on_off_df_filename, run_on_off_detection

In [None]:
import altair as alt

# Lift Altair's default cap on the number of rows embedded in a chart; the
# per-OFF-period tables plotted below may exceed it.
alt.data_transformers.disable_max_rows()

In [None]:
# ---- Analysis parameters ----------------------------------------------------

# One (subject, condition, sorting_condition, region) tuple per dataset to load.
data_conditions = [
    ('Doppio',                                             # subject
     'sleep-homeostasis-2h_imec1',                         # recording condition
     'ks2_5_catgt_Th=12-10_lam=50_8s-batches_postpro_1',   # sorting condition
     'cortex'),                                            # region
]

# Firing-rate interval in Hz. NOTE(review): not referenced elsewhere in this notebook.
FR_interval = (0.5, 5)

# Which ON/OFF output variant to load: good units only, per-unit (not pooled).
good_only, pool = True, False

# Sleep state label. NOTE(review): not referenced elsewhere in this notebook.
state = 'N2'

# Name of the ON/OFF detection run whose output CSVs are loaded.
detection_condition = 'on_off_threshold_single_unit_1'

# Parallel job count (presumably for joblib). NOTE(review): currently unused.
n_jobs = 1

In [None]:
# The analysis below groups OFF periods by cluster_id, i.e. it expects the
# per-unit (unpooled) ON/OFF detection output.
assert not pool, "This notebook analyses per-cluster OFF periods; set pool = False."

In [None]:

# Load the ON/OFF detection table for every dataset in `data_conditions`, tag
# each row with its provenance, and concatenate everything into `data`.
assert data_conditions, "data_conditions is empty: nothing to load"

dfs = []
for subject, condition, sorting_condition, region in data_conditions:
    # NOTE(review): `sorting_condition` is unpacked but not used to locate the file.
    filename = get_on_off_df_filename(region, good_only, pool) + '.csv'
    df = pd.read_csv(
        get_datapath(subject, condition, detection_condition) / filename
    )
    df['subject'] = subject
    df['orig_condition'] = condition
    df['region'] = region
    df['good_only'] = good_only
    df['pool'] = pool
    dfs.append(df)

# Each CSV has its own 0..n-1 index; ignore_index keeps row labels unique.
data = pd.concat(dfs, ignore_index=True)

# OFF periods only, excluding rows whose condition is 'interbout'.
# .copy(): later cells add/modify columns, which on a filtered view would
# trigger SettingWithCopyWarning.
off_dat = data[
    (data['state'] == 'off')
    & (data['condition'] != 'interbout')
].copy()
cluster_ids = sorted(data.cluster_id.unique())

assert off_dat.subject.nunique() == 1, "Expected OFF periods from exactly one subject"

In [None]:
# Collapse condition labels: any label containing 'baseline' becomes 'baseline',
# otherwise any label containing 'recovery' becomes 'recovery'; others unchanged.
def _collapse_condition(cond):
    """Map a raw condition label onto 'baseline'/'recovery' when it contains one."""
    if not isinstance(cond, str):
        return cond
    if 'baseline' in cond:
        return 'baseline'
    if 'recovery' in cond:
        return 'recovery'
    return cond

# Only the `condition` column is rewritten (DataFrame.replace would touch every
# column). assign() returns a new frame, so re-running this cell is harmless.
off_dat = off_dat.assign(condition=off_dat['condition'].map(_collapse_condition))

off_dat.condition.unique()

In [None]:
import warnings

# Reproducible random subset of clusters to plot.
N_clust = 40       # number of clusters in the subset
clust_offset = 40  # start position of the subset within the shuffled cluster list

cluster_ids = sorted(data.cluster_id.unique())
# A dedicated RandomState(0) yields the same permutation as
# np.random.seed(0) + np.random.shuffle, without reseeding numpy's global RNG.
np.random.RandomState(0).shuffle(cluster_ids)
cluster_select = cluster_ids[clust_offset:clust_offset + N_clust]
if len(cluster_select) < N_clust:
    warnings.warn(
        f"Only {len(cluster_select)} of {N_clust} requested clusters selected "
        f"({len(cluster_ids)} clusters available, offset {clust_offset})."
    )

# .copy(): the next cell adds a column to this frame.
off_dat = off_dat[off_dat['cluster_id'].isin(cluster_select)].copy()



In [None]:
# Facet title per unit: cluster id plus its cumulative firing rate (2 decimals).
unit_labels = [
    f"cluster_id={cid}, FR={round(fr, 2)}Hz"
    for cid, fr in zip(off_dat['cluster_id'], off_dat['cumFR'])
]
# Attach the labels and sort rows by firing rate. Note: Altair orders facets
# by label unless an explicit `sort=` is given, so row order alone does not
# set the facet order.
off_dat = off_dat.assign(unit=unit_labels).sort_values(by='cumFR')

In [None]:
# Density of OFF-period durations

In [None]:

# Kernel-density estimate of OFF-period durations: one facet per unit,
# conditions distinguished by colour.

# Order facets by firing rate; without an explicit sort Altair orders the
# 'unit' labels alphabetically.
unit_order = off_dat.sort_values(by='cumFR')['unit'].drop_duplicates().tolist()

duration_density = alt.Chart(
    off_dat
).transform_density(
    'duration',
    as_=['duration', 'density'],
    groupby=['subject', 'condition', 'unit'],
).mark_area(
    # NOTE(review): low opacity suggests overlaid areas; confirm Vega-Lite is not
    # stacking the per-condition densities (y stack=None would force overlay).
    opacity=0.3,
).encode(
    x='duration:Q',
    y='density:Q',
    color='condition:N',
).properties(
    width=150,
    height=100,
).facet(
    facet=alt.Facet('unit:N', sort=unit_order),
    columns=5,
).resolve_scale(
    x='independent',
    y='independent',
)

duration_density

In [None]:
# Frequency of OFF-period durations (occurrences per minute)

In [None]:
# Histogram bin edges for OFF-period durations: fixed width, from the shortest
# observed duration up to `max_duration`. Longer OFF periods fall outside the
# last bin (pd.cut gives them NaN) and are not counted in the histogram.
binwidth = 0.25
max_duration = 10  # upper edge of the histogram, same units as `duration`
bins = np.arange(
    off_dat.duration.min(),
    max_duration + binwidth,
    binwidth,
)
assert len(bins) >= 2, "No histogram bins: shortest duration exceeds max_duration"

In [None]:
# Count OFF periods per (unit, condition, duration bin).
# pd.cut intervals are right-closed, so without include_lowest=True the
# shortest OFF period (which lies exactly on bins[0]) would fall outside every
# bin and be silently dropped. observed=False keeps empty bins (count 0),
# pinning the groupby default this cell was written against.
duration_bin = pd.cut(off_dat.duration, bins=bins, include_lowest=True)
binned_durations = off_dat.groupby(
    ['unit', 'condition', duration_bin],
    observed=False,
)['state'].count().reset_index()
binned_durations['duration_count'] = binned_durations['state']

# Bin edges/centre come from the Interval labels, which pd.cut stores rounded
# to its `precision` (3 by default), so they are approximate.
binned_durations['bin_min'] = [iv.left for iv in binned_durations['duration']]
binned_durations['bin_max'] = [iv.right for iv in binned_durations['duration']]
binned_durations['bin_center'] = (binned_durations['bin_min'] + binned_durations['bin_max']) / 2

# Normalize counts to occurrences per minute spent in each condition.
# NOTE(review): assumes condition_state_time is in seconds; confirm.
state_time_by_cond = off_dat.groupby('condition')['condition_state_time']
assert (state_time_by_cond.nunique() == 1).all(), (
    "Expected exactly one condition_state_time per condition; "
    "merged sub-conditions would need their times summed."
)
cond_durations = state_time_by_cond.first().to_dict()
binned_durations['frequency'] = (
    60 * binned_durations['duration_count']
    / binned_durations['condition'].map(cond_durations)
)


In [None]:
# Peek at the first few binned rows.
binned_durations.head(5)

In [None]:

# Histogram of OFF-period durations as occurrences per minute: one facet per
# unit, conditions distinguished by colour.

# Order facets by firing rate; without an explicit sort Altair orders the
# 'unit' labels alphabetically.
unit_order = off_dat.sort_values(by='cumFR')['unit'].drop_duplicates().tolist()

duration_frequency = alt.Chart(
    # The Interval-valued 'duration' column is dropped; presumably because the
    # Interval objects cannot be serialized into the chart's JSON data.
    binned_durations.drop(columns='duration')
).mark_bar(
    # NOTE(review): confirm bars are overlaid rather than stacked per condition.
    opacity=0.3,
).encode(
    x=alt.X('bin_min:Q'),
    x2=alt.X2('bin_max:Q'),
    y=alt.Y(
        'frequency:Q',
        axis=alt.Axis(title="Occurrence (per min)"),
    ),
    color='condition:N',
).properties(
    width=150,
    height=100,
).facet(
    facet=alt.Facet('unit:N', sort=unit_order),
    columns=5,
).resolve_scale(
    x='independent',
    y='independent',
)

duration_frequency

In [None]:
# Per-cluster statistics of OFF-period durations

In [None]:
# Summary statistics of OFF-period duration for each (subject, condition, cluster).
funcs = ['mean', 'median', 'sum', 'skew']

stats_df = (
    off_dat
    .groupby(['subject', 'condition', 'cluster_id'])['duration']
    .agg(funcs)
    .reset_index()
)

stats_df.head(5)

In [None]:
# Box plots across clusters of each duration statistic, split by condition:
# one column per statistic, each with its own y scale.
measure_column = alt.Column('key:N', header=alt.Header(title='Measure'))
value_axis = alt.Y(
    'value:Q',
    axis=alt.Axis(title='Value'),
    scale=alt.Scale(zero=False),
)

alt.Chart(stats_df).transform_fold(
    fold=funcs,
).mark_boxplot(
    color='black',
    extent=0,
).encode(
    x=alt.X('condition:N'),
    y=value_axis,
    column=measure_column,
    color=alt.Color('condition:N'),
).properties(
    width=50,
    height=300,
).resolve_scale(
    y='independent',
)

In [None]:
import scipy.stats

# Paired t-test of per-cluster duration skewness, baseline vs recovery.
# Pivoting puts each cluster's baseline and recovery values on one row, so
# pairs are matched by (subject, cluster_id) rather than by sort position.
# Clusters lacking a value in either condition (e.g. NaN skew for clusters with
# too few OFF periods) are dropped instead of misaligning pairs or yielding NaN.
paired_skew = stats_df.pivot(
    index=['subject', 'cluster_id'], columns='condition', values='skew'
)[['baseline', 'recovery']].dropna()

scipy.stats.ttest_rel(paired_skew['baseline'], paired_skew['recovery'])

In [None]:
import scipy.stats

# Paired t-test of per-cluster mean OFF-period duration, baseline vs recovery.
# Pairs are matched by (subject, cluster_id) via a pivot rather than by sort
# position; clusters missing either condition are dropped.
paired_mean = stats_df.pivot(
    index=['subject', 'cluster_id'], columns='condition', values='mean'
)[['baseline', 'recovery']].dropna()

scipy.stats.ttest_rel(paired_mean['baseline'], paired_mean['recovery'])