In [None]:
%reload_ext autoreload
%autoreload 2

import acr
import kdephys as kde
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
import plotly.express as px

import warnings
warnings.filterwarnings('ignore')

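# FUNCTIONS
# NOTE(review): helpers used further down (load_nor_actigraphy, thresh_df, load_probe_diff_df) are not
# defined anywhere in this notebook; restore their definitions here or import them, otherwise
# Restart & Run All fails with NameError.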

# Analysis

In [None]:
# Full subject list from acr.nor; filtered by subject type below.
all_subjects = acr.nor.get_all_subjects()

In [None]:
# Keep only the subjects whose acr.nor subject type is 'sleep', preserving the order of all_subjects.
sleep_subjects = [
    subj
    for subj in all_subjects
    if acr.nor.get_subject_type(subj) == 'sleep'
]
sleep_subjects

In [None]:
# Load the DLC actigraphy frame for every sleep subject, keyed by subject.
vel_dfs = {subj: acr.dlc.load_nor_actigraphy(subj) for subj in sleep_subjects}

In [None]:
# For each sleep subject: average 'speed' across all nodes per timestamp, robust-z-score that trace
# over the full recording, then score sleep below a threshold of c * (mean/std) of the z-scored trace
# (c = -1, i.e. -(mean/std), per validation on ephys mice). Only the recovery window is scored, but
# the z-scoring and threshold use the whole recording.
c = -1  # multiplier applied to mean/std to form the sleep threshold
records = []
for subject, v in vel_dfs.items():
    # The acquisition-day start is not needed here; only the recovery window bounds are used.
    _, recovery_start, recovery_end = acr.dlc.load_sleep_recovery_info(
        v, subject, recovery_duration='1h', buffer='5min'
    )
    spd = v.group_by('datetime').agg(pl.col('speed').mean()).sort(['datetime'])
    spd = acr.dlc.rob_z_col(spd, 'speed')
    mean = spd['speed_robust_z'].mean()
    std = spd['speed_robust_z'].std()
    base = mean / std
    thresh_sleep = base * c
    thresh_sleep_pct = acr.dlc.threshold_sleep_on_diff_df(
        spd.ts(recovery_start, recovery_end), thresh_sleep, col='speed_robust_z'
    )
    records.append({'subject': subject, 'thresh_sleep': thresh_sleep_pct, 'mean': mean, 'std': std})

# Build the summary once. This gives a clean 0..n-1 index; concatenating one-row frames that were
# each built with index=[0] would label every row 0.
sdf = pd.DataFrame(records)

In [None]:
# Per-subject thresh_sleep for the recovery window, with dashed horizontal reference levels.
# NOTE(review): where these reference levels come from is not recorded here — TODO document.
REF_LEVELS = [(0.33, 'k'), (0.3, 'r'), (0.25, 'green'), (0.20, 'pink')]

f, ax = plt.subplots(1, 1, figsize=(28, 8))
sns.barplot(x='subject', y='thresh_sleep', data=sdf, ax=ax)
for level, color in REF_LEVELS:
    ax.axhline(y=level, color=color, linestyle='--', label=f'{level:.2f}')
ax.set(xlabel='subject', ylabel='thresh_sleep (recovery window)', title='thresh_sleep per subject')
ax.legend(title='reference');

In [None]:
# Mean thresh_sleep across subjects (sdf holds one row per subject).
sdf['thresh_sleep'].mean()

In [None]:
# Across-subject spread of thresh_sleep (pandas Series.std, sample std with ddof=1).
sdf['thresh_sleep'].std()

In [None]:
# FIXME(review): orphaned fragment of a deleted per-recording threshold-validation loop
# (coefficient sweep vs. hypnogram sleep). As written it raises IndentationError, and it uses
# names that are never defined at this point (thresh_sleeps, sub, rec, h_sleep, coeff_tests).
# It is disabled so Restart & Run All works and `sdf` from the speed-threshold cell is not
# overwritten. Restore the enclosing loop before re-enabling.
# thresh_sleep = base*c
#         thresh_sleep_pct = acr.dlc.threshold_sleep_on_diff_df(spd, thresh_sleep, col='speed_robust_z')
#         thresh_sleeps.append(thresh_sleep_pct)
#     sub_df = pd.DataFrame({'subject': sub, 'rec': rec, 'thresh_sleep': thresh_sleeps, 'h_sleep': h_sleep, 'coeffs': coeff_tests, 'mean': mean, 'std': std})
#     sub_dfs.append(sub_df)
#
# sdf = pd.concat(sub_dfs)
# sdf['diff'] = sdf['thresh_sleep'] - sdf['h_sleep']

In [None]:
# For each sleep subject: threshold the per-frame mean and max node diffs into sleep/wake with
# thresh_df, take a 1 h scoring window that starts `window_offset` after the first frame more than
# 10 min past acqday_start, report both sleep fractions, and plot the mean-diff trace from 1 h
# before to 3 h after that window.
# NOTE(review): load_nor_actigraphy and thresh_df are not defined in this notebook — confirm their source.
mean_thresh = 0.12
max_thresh = 0.4
window_offset = pd.Timedelta('15min')
for subject in sleep_subjects:
    box_time, acq_day = acr.nor.get_sub_timing(subject)
    acqday_start = pd.Timestamp(f'{acq_day} {box_time}')
    diffs = load_nor_actigraphy(subject)
    diff_mean = diffs.group_by(['frame', 'datetime']).agg(pl.col('diff').mean()).sort(['datetime'])
    diff_max = diffs.group_by(['frame', 'datetime']).agg(pl.col('diff').max()).sort(['datetime'])
    diff_mean = thresh_df(diff_mean, mean_thresh)
    diff_max = thresh_df(diff_max, max_thresh)

    min_sleep_start = acqday_start + pd.Timedelta('10min')
    sleep_period_border = diff_mean.filter(pl.col('datetime') > min_sleep_start)['datetime'].min()
    if sleep_period_border is None:
        # No frames after min_sleep_start: there is no window to score, so skip instead of
        # crashing on None + Timedelta.
        print(f'{subject}: no frames after {min_sleep_start}, skipping')
        continue
    sleep_period_start = pd.Timestamp(sleep_period_border + window_offset)
    sleep_period_end = sleep_period_start + pd.Timedelta('1h')

    mean_frac = diff_mean.ts(sleep_period_start, sleep_period_end).frac_sleep()
    max_frac = diff_max.ts(sleep_period_start, sleep_period_end).frac_sleep()

    f, ax = plt.subplots(1, 1, figsize=(28, 8))
    # Plot only the context around the scoring window, not the full recording.
    diff_to_plot = diff_mean.ts(sleep_period_start - pd.Timedelta('1h'), sleep_period_end + pd.Timedelta('3h'))
    sns.lineplot(x='datetime', y='diff', data=diff_to_plot, ax=ax, color='k')
    ax.axvline(x=sleep_period_start, color='r', linestyle='--')
    ax.axvline(x=sleep_period_end, color='r', linestyle='--')

    ax.set_title(f'{subject} | mean-sleep-frac: {mean_frac*100:.2f}% | max-sleep-frac: {max_frac*100:.2f}%')
    plt.show()
    plt.close(f)  # release the figure; many subjects are plotted in this loop

In [None]:
# Same per-subject scoring as the 15-min variant, but the 1 h scoring window starts exactly at the
# first frame more than 10 min past acqday_start (window_offset = 0). Mean/max node diffs are
# thresholded with thresh_df, both sleep fractions are reported, and the mean-diff trace is plotted
# from 1 h before to 3 h after the window.
# NOTE(review): load_nor_actigraphy and thresh_df are not defined in this notebook — confirm their source.
mean_thresh = 0.12
max_thresh = 0.4
window_offset = pd.Timedelta('0min')
for subject in sleep_subjects:
    box_time, acq_day = acr.nor.get_sub_timing(subject)
    acqday_start = pd.Timestamp(f'{acq_day} {box_time}')
    diffs = load_nor_actigraphy(subject)
    diff_mean = diffs.group_by(['frame', 'datetime']).agg(pl.col('diff').mean()).sort(['datetime'])
    diff_max = diffs.group_by(['frame', 'datetime']).agg(pl.col('diff').max()).sort(['datetime'])
    diff_mean = thresh_df(diff_mean, mean_thresh)
    diff_max = thresh_df(diff_max, max_thresh)

    min_sleep_start = acqday_start + pd.Timedelta('10min')
    sleep_period_border = diff_mean.filter(pl.col('datetime') > min_sleep_start)['datetime'].min()
    if sleep_period_border is None:
        # No frames after min_sleep_start: there is no window to score, so skip instead of
        # crashing on None + Timedelta.
        print(f'{subject}: no frames after {min_sleep_start}, skipping')
        continue
    sleep_period_start = pd.Timestamp(sleep_period_border + window_offset)
    sleep_period_end = sleep_period_start + pd.Timedelta('1h')

    mean_frac = diff_mean.ts(sleep_period_start, sleep_period_end).frac_sleep()
    max_frac = diff_max.ts(sleep_period_start, sleep_period_end).frac_sleep()

    f, ax = plt.subplots(1, 1, figsize=(28, 8))
    # Plot only the context around the scoring window, not the full recording.
    diff_to_plot = diff_mean.ts(sleep_period_start - pd.Timedelta('1h'), sleep_period_end + pd.Timedelta('3h'))
    sns.lineplot(x='datetime', y='diff', data=diff_to_plot, ax=ax, color='k')
    ax.axvline(x=sleep_period_start, color='r', linestyle='--')
    ax.axvline(x=sleep_period_end, color='r', linestyle='--')

    ax.set_title(f'{subject} | mean-sleep-frac: {mean_frac*100:.2f}% | max-sleep-frac: {max_frac*100:.2f}%')
    plt.show()
    plt.close(f)  # release the figure; many subjects are plotted in this loop

In [None]:
# Mean diff across all nodes, with the scoring-window edges marked in red.
# NOTE(review): diff_mean, sleep_period_start and sleep_period_end are whatever the last
# per-subject loop iteration left behind — this depends on execution order.
f, ax = plt.subplots(1, 1, figsize=(28, 8))
sns.lineplot(data=diff_mean, x='datetime', y='diff', color='k', ax=ax)
ax.set_title('mean diffs, all nodes')
for window_edge in (sleep_period_start, sleep_period_end):
    ax.axvline(x=window_edge, color='r', linestyle='--')

In [None]:
# Max diff across all nodes, shaded with a hypnogram when one is available.
# `h` is only assigned further down (the probe/ephys validation cells), so on a fresh kernel the
# unguarded call raised NameError. Shade only when `h` exists.
# NOTE(review): even when `h` exists it may belong to a different subject/recording than diff_max — confirm.
f, ax = plt.subplots(1, 1, figsize=(28, 8))
sns.lineplot(x='datetime', y='diff', data=diff_max, ax=ax, color='k')
hypno = globals().get('h')
if hypno is not None:
    kde.plot.main.shade_hypno_for_me(hypno, ax=ax, alpha=0.25)
ax.set_title('max diffs, all nodes')

In [None]:
# NOTE(review): `diff_total` is never created in this notebook (the next cell only reassigns it from
# itself), so this raises NameError on Restart & Run All. It presumably refers to the per-frame mean
# diff (its threshold of 0.12 matches mean_thresh) — confirm and rename.
px.line(diff_total, x='frame', y='diff')

In [None]:
# Score each frame of diff_total as 'sleep' when its diff is below `thresh` (otherwise 'wake'),
# then display the fraction of frames scored as sleep.
# NOTE(review): diff_total is never created in this notebook — confirm where it comes from.
thresh = 0.12
diff_total = diff_total.with_columns(
    state=pl.when(pl.col('diff') < thresh).then(pl.lit('sleep')).otherwise(pl.lit('wake'))
)
# Count sleep frames explicitly. Row 0 of the state-sorted group_by counts is the sleep count only
# when at least one frame is sleep; otherwise it is the wake count, which reports a fraction of 1.0.
(diff_total['state'] == 'sleep').sum() / len(diff_total)

In [None]:
# Interactive per-frame max node diff. diff_max is left over from the last per-subject loop
# iteration above — confirm which subject is shown.
px.line(diff_max, x='frame', y='diff')

In [None]:
# Lower quantiles (10%, 20%, ..., 50%) of the per-frame max node diff; presumably used to pick max_thresh.
np.quantile(diff_max['diff'], np.arange(1, 6) / 10)

In [None]:
# Score each frame of diff_max as 'sleep' when its max node diff is below `thresh` (otherwise
# 'wake'), then display the fraction of frames scored as sleep.
thresh = 0.4
diff_max = diff_max.with_columns(
    state=pl.when(pl.col('diff') < thresh).then(pl.lit('sleep')).otherwise(pl.lit('wake'))
)
# Count sleep frames explicitly. Row 0 of the state-sorted group_by counts is the sleep count only
# when at least one frame is sleep; otherwise it is the wake count, which reports a fraction of 1.0.
(diff_max['state'] == 'sleep').sum() / len(diff_max)

In [None]:
# Validation recording ACR_45 / swisin: hypnogram sleep fraction vs. actigraphy (node-diff) sleep
# fraction from the per-frame mean and max diffs across nodes.
subject = 'ACR_45'
rec = 'swisin'
diffs, h = load_probe_diff_df(subject, rec)
total_hyp_duration = h.duration.sum().total_seconds()
print(f'total hypno duration: {total_hyp_duration}')

# Clip the video diffs to the time span the hypnogram covers, so both fractions describe the same window.
ht1 = h['start_time'].min()
ht2 = h['end_time'].max()
diffs = diffs.filter((pl.col('datetime') > ht1) & (pl.col('datetime') < ht2))

sleep_states = ['NREM', 'REM', 'Transition-to-REM', 'Transition-to-NREM', 'Transition-to-Wake']
hypno_frac = h.keep_states(sleep_states).duration.sum().total_seconds() / total_hyp_duration
print(f'hypno sleep fraction: {hypno_frac}')

# Collapse the per-node diffs to one value per frame: mean and max across nodes.
diff_mean = diffs.group_by(['frame', 'datetime']).agg(pl.col('diff').mean()).sort(['datetime'])
diff_max = diffs.group_by(['frame', 'datetime']).agg(pl.col('diff').max()).sort(['datetime'])

# NOTE(review): the per-subject loops call .ts(...).frac_sleep() on thresh_df's result (i.e. it returns
# a frame); confirm thresh_df returns a fraction here before trusting these printouts.
mean_thresh = 0.12
mean_frac = thresh_df(diff_mean, mean_thresh)
print(f'mean diff sleep fraction: {mean_frac}')

max_thresh = 0.4
max_frac = thresh_df(diff_max, max_thresh)
print(f'max diff sleep fraction: {max_frac}')

In [None]:
# Validation recording ACR_40 / swinap: hypnogram sleep fraction vs. actigraphy (node-diff) sleep
# fraction from the per-frame mean and max diffs across nodes.
subject = 'ACR_40'
rec = 'swinap'
diffs, h = load_probe_diff_df(subject, rec)
total_hyp_duration = h.duration.sum().total_seconds()
print(f'total hypno duration: {total_hyp_duration}')

# Clip the video diffs to the time span the hypnogram covers, so both fractions describe the same window.
ht1 = h['start_time'].min()
ht2 = h['end_time'].max()
diffs = diffs.filter((pl.col('datetime') > ht1) & (pl.col('datetime') < ht2))

sleep_states = ['NREM', 'REM', 'Transition-to-REM', 'Transition-to-NREM', 'Transition-to-Wake']
hypno_frac = h.keep_states(sleep_states).duration.sum().total_seconds() / total_hyp_duration
print(f'hypno sleep fraction: {hypno_frac}')

# Collapse the per-node diffs to one value per frame: mean and max across nodes.
diff_mean = diffs.group_by(['frame', 'datetime']).agg(pl.col('diff').mean()).sort(['datetime'])
diff_max = diffs.group_by(['frame', 'datetime']).agg(pl.col('diff').max()).sort(['datetime'])

# NOTE(review): the per-subject loops call .ts(...).frac_sleep() on thresh_df's result (i.e. it returns
# a frame); confirm thresh_df returns a fraction here before trusting these printouts.
mean_thresh = 0.12
mean_frac = thresh_df(diff_mean, mean_thresh)
print(f'mean diff sleep fraction: {mean_frac}')

max_thresh = 0.4
max_frac = thresh_df(diff_max, max_thresh)
print(f'max diff sleep fraction: {max_frac}')

In [None]:
# Validation recording ACR_41 / swi2: hypnogram sleep fraction vs. actigraphy (node-diff) sleep
# fraction, with the video diffs clipped to the hypnogram's time span.
subject, rec = 'ACR_41', 'swi2'
diffs, h = load_probe_diff_df(subject, rec)
total_hyp_duration = h.duration.sum().total_seconds()
print(f'total hypno duration: {total_hyp_duration}')

ht1, ht2 = h['start_time'].min(), h['end_time'].max()
in_hypno_span = (pl.col('datetime') > ht1) & (pl.col('datetime') < ht2)
diffs = diffs.filter(in_hypno_span)

sleep_states = ['NREM', 'REM', 'Transition-to-REM', 'Transition-to-NREM', 'Transition-to-Wake']
sleep_seconds = h.keep_states(sleep_states).duration.sum().total_seconds()
hypno_frac = sleep_seconds / total_hyp_duration
print(f'hypno sleep fraction: {hypno_frac}')

# Collapse the per-node diffs to one value per frame: mean and max across nodes.
per_frame = diffs.group_by(['frame', 'datetime'])
diff_mean = per_frame.agg(pl.col('diff').mean()).sort(['datetime'])
diff_max = per_frame.agg(pl.col('diff').max()).sort(['datetime'])

# NOTE(review): the per-subject loops call .ts(...).frac_sleep() on thresh_df's result (i.e. it returns
# a frame); confirm thresh_df returns a fraction here before trusting these printouts.
mean_thresh = 0.12
mean_frac = thresh_df(diff_mean, mean_thresh)
print(f'mean diff sleep fraction: {mean_frac}')

max_thresh = 0.4
max_frac = thresh_df(diff_max, max_thresh)
print(f'max diff sleep fraction: {max_frac}')

In [None]:
# Validation recording ACR_35 / swi2-bl: hypnogram sleep fraction vs. actigraphy (node-diff) sleep
# fraction from the per-frame mean and max diffs across nodes.
subject = 'ACR_35'
rec = 'swi2-bl'
diffs, h = load_probe_diff_df(subject, rec)
total_hyp_duration = h.duration.sum().total_seconds()
print(f'total hypno duration: {total_hyp_duration}')

# Clip the video diffs to the time span the hypnogram covers, so both fractions describe the same window.
ht1 = h['start_time'].min()
ht2 = h['end_time'].max()
diffs = diffs.filter((pl.col('datetime') > ht1) & (pl.col('datetime') < ht2))

sleep_states = ['NREM', 'REM', 'Transition-to-REM', 'Transition-to-NREM', 'Transition-to-Wake']
hypno_frac = h.keep_states(sleep_states).duration.sum().total_seconds() / total_hyp_duration
print(f'hypno sleep fraction: {hypno_frac}')

# Collapse the per-node diffs to one value per frame: mean and max across nodes.
diff_mean = diffs.group_by(['frame', 'datetime']).agg(pl.col('diff').mean()).sort(['datetime'])
diff_max = diffs.group_by(['frame', 'datetime']).agg(pl.col('diff').max()).sort(['datetime'])

# NOTE(review): the per-subject loops call .ts(...).frac_sleep() on thresh_df's result (i.e. it returns
# a frame); confirm thresh_df returns a fraction here before trusting these printouts.
mean_thresh = 0.12
mean_frac = thresh_df(diff_mean, mean_thresh)
print(f'mean diff sleep fraction: {mean_frac}')

max_thresh = 0.4
max_frac = thresh_df(diff_max, max_thresh)
print(f'max diff sleep fraction: {max_frac}')