In [2]:
from config import project_config as config
from utils.sleep_wake_filter import filter_sleep_series
import pandas as pd
import os
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

In [3]:
def mask_dates_of_sleep_periods(sleep_periods):
    """Overwrite the calendar date of every timestamp with a dummy date, in place.

    Each plot row covers noon to noon on the next day. So that all rows can
    share one x-axis, the real date is thrown away and only the time of day is
    kept: hours 12-23 are moved to 1970-01-01 and hours 0-11 to 1970-01-02.

    Parameters
    ----------
    sleep_periods : list of mutable sequences of timestamps
        Typically ``[[start, end], ...]``. Every element must expose ``.hour``
        and ``.replace(year=, month=, day=)``. The inner sequences are
        modified in place.

    Returns
    -------
    The same ``sleep_periods`` object that was passed in.
    """
    for period in sleep_periods:
        for position in range(len(period)):
            stamp = period[position]
            # Afternoon/evening belongs to the first dummy day, morning to the second
            if 12 <= stamp.hour <= 23:
                dummy_day = 1
            else:
                dummy_day = 2
            period[position] = stamp.replace(year=1970, month=1, day=dummy_day)

    return sleep_periods

In [4]:
def get_sleep_periods(df, sleep_col, time_col, mask_dates=False):
    """Convert a binary epoch-by-epoch column into a list of ``[start, end]`` pairs.

    A period is a run of consecutive rows in which ``df[sleep_col] == 1``. For
    each run, the ``time_col`` value of its first row and of its last row are
    returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows are expected to be in time order; ``df`` itself is not modified.
    sleep_col : str
        Name of the column holding the 0/1 indicator (may contain NaN).
    time_col : str
        Name of the column whose values are reported as period boundaries.
    mask_dates : bool, optional
        Accepted for interface compatibility; this function does not use it.

    Returns
    -------
    list of [start, end] lists
        Inner lists (not tuples) so that callers can overwrite the endpoints.
    """
    indicator = df[sleep_col]
    is_on = indicator == 1

    # diff() yields NaN on the boundary row and wherever a neighbour is NaN;
    # fillna(1) makes every such row count as an edge.
    rises_from_previous = indicator.diff().fillna(1) == 1
    drops_to_next = indicator.diff(-1).fillna(1) == 1

    # Only edges that sit on a row where the indicator is on are of interest
    period_starts = df.loc[rises_from_previous & is_on, time_col]
    period_ends = df.loc[drops_to_next & is_on, time_col]

    return [[first, last] for first, last in zip(period_starts, period_ends)]

In [5]:
def read_sleep_dairies(path, nap_filename='SRCDRI001_Sleep Diary 019-036_nap.csv'):
    """Load every sleep-diary CSV in ``path`` plus the nap diary into one table.

    Regular diary files are all ``.csv`` files in ``path`` whose name does not
    contain ``'nap'``. The nap diary is read separately and its columns are
    renamed to the regular diary column names so both can be stacked.

    Parameters
    ----------
    path : str
        Directory holding the diary CSV files.
    nap_filename : str, optional
        Name of the nap diary file inside ``path``.

    Returns
    -------
    pandas.DataFrame
        Columns ``participantNo``, ``sleep_start``, ``sleep_end`` (the latter
        two parsed to datetimes), sorted by participant and go-to-sleep date,
        with a fresh 0..n-1 index.

    Raises
    ------
    FileNotFoundError
        If ``path`` or the nap diary file does not exist.
    """
    # Collect the frames and concatenate once; sorted() makes the row order
    # independent of the file system's listing order.
    frames = []
    for filename in sorted(os.listdir(path)):
        if not filename.endswith('.csv'):
            continue
        if filename == nap_filename or 'nap' in filename:
            continue  # nap diaries have a different layout and are handled below
        frames.append(pd.read_csv(os.path.join(path, filename)))

    # Reading the extra nap diary and aligning its column names with the regular diaries
    nap_df = pd.read_csv(os.path.join(path, nap_filename))
    nap_df = nap_df.rename(columns={
        'date_startnap': 'date_gotosleep',
        'date_endnap': 'date_finalawake',
        'nap_start': 'gotosleep',
        'nap_end': 'finalawake'
    }).drop(columns=['nap times'])
    frames.append(nap_df)

    sleep_diary_df = pd.concat(frames)
    sleep_diary_df = sleep_diary_df.sort_values(['participantNo', 'date_gotosleep']).reset_index(drop=True)

    # Date and clock time are stored in separate text columns; join them before parsing
    sleep_diary_df['sleep_start'] = pd.to_datetime(sleep_diary_df['date_gotosleep'] + ' ' + sleep_diary_df['gotosleep'])
    sleep_diary_df['sleep_end'] = pd.to_datetime(sleep_diary_df['date_finalawake'] + ' ' + sleep_diary_df['finalawake'])
    sleep_diary_df = sleep_diary_df[['participantNo', 'sleep_start', 'sleep_end']]
    return sleep_diary_df


In [6]:
def get_daily_sleep_periods_from_diary(subject_diary, day_start, day_end):
    """Return the diary sleep episodes that fall inside one plotting day, clipped to it.

    Each day is plotted on its own row, so an episode that crosses a day
    boundary is cut at that boundary; the remainder shows up when the
    neighbouring day is processed.

    Parameters
    ----------
    subject_diary : mapping / DataFrame
        Must provide equally long ``'sleep_start'`` and ``'sleep_end'`` sequences.
    day_start, day_end : timestamp
        Inclusive bounds of the current plotting day.

    Returns
    -------
    list of [start, end] lists
        One entry per episode that overlaps the day, with ``start`` no earlier
        than ``day_start`` and ``end`` no later than ``day_end``.
    """
    day_sleep_periods = []
    for sleep_start, sleep_end in zip(subject_diary['sleep_start'], subject_diary['sleep_end']):
        starts_today = day_start <= sleep_start <= day_end
        ends_today = day_start <= sleep_end <= day_end
        # An episode that begins before the day and ends after it has neither
        # endpoint inside the day but still covers all of it.
        spans_today = sleep_start < day_start and sleep_end > day_end
        if not (starts_today or ends_today or spans_today):
            continue

        # Clip to the day; the part outside belongs to a different row of the plot
        clipped_start = day_start if sleep_start < day_start else sleep_start
        clipped_end = day_end if sleep_end > day_end else sleep_end
        day_sleep_periods.append([clipped_start, clipped_end])

    return day_sleep_periods

In [7]:
def plot_daily_sleep_indicators(sleep_periods_dict, ax, **kwargs):
    """Draw one horizontal band of bars per source on ``ax``.

    Parameters
    ----------
    sleep_periods_dict : dict
        Maps a legend label to a list of ``[start, end]`` timestamp pairs. The
        dict order decides the vertical order: the n-th entry occupies the band
        from y = n to y = n + 1 and uses the n-th ``tab10`` colour. The pairs
        are passed through ``mask_dates_of_sleep_periods``, which rewrites
        their dates in place.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    **kwargs
        Forwarded to ``ax.broken_barh`` (e.g. ``alpha``).

    Returns
    -------
    The same ``ax``.
    """
    epoch_length = np.timedelta64(config['seconds_per_epoch'], 's')
    palette = matplotlib.colormaps['tab10']

    row = 0
    for label, periods in sleep_periods_dict.items():
        # Replacing the real dates with dummy ones lets the hours line up across days
        periods = mask_dates_of_sleep_periods(periods)

        # broken_barh wants (left edge, width); the width is end - start plus one epoch length
        bars = []
        for begin, finish in periods:
            bars.append((begin, pd.Timedelta(finish - begin) + epoch_length))

        ax.broken_barh(bars, (row, 1), color=palette(row), label=label, **kwargs)
        row += 1

    ax.set_yticks([])
    ax.set_xlabel('Time')

    return ax

In [11]:
# Plot, for each subject, one figure with one row per noon-to-noon day. Every row
# shows the binary sleep sources, their missing-data spans (semi-transparent) and
# the sleep-diary episodes.
sleep_diaries_path = 'data/Sleep diaries'
plot_path = 'Results/plots/v1'
os.makedirs(plot_path, exist_ok=True)

# This determines the order of the series within each day's row
# and the names in the legend.
# (Alternative column set used previously: ['pred', 'Aux AWS Sleep', 'Biobank Sleep'])
binary_sources = {
    'df_columns': ['pred', 'AWS Sleep', 'Biobank Sleep'],
    'plot_labels': ['Pred', 'AWS', 'Bio']  # Legend
}
n_sources = len(binary_sources['df_columns'])

# These inputs are the same for every subject, so read them once up front.
# Participation dates help remove days with little data at the beginning and the end.
valid_days = pd.read_csv('data/participation_dates.csv')
diary_df = read_sleep_dairies(sleep_diaries_path)

for id in config['subject_ids'][:15]:

    df = pd.read_csv(f'Results/merged_sources/sub_{id:02d}.csv')
    df['epoch_ts'] = pd.to_datetime(df['epoch_ts'])
    df = df.sort_values('epoch_ts')

    start_timestamp = valid_days.loc[valid_days['subject_id'] == id, 'start_timestamp'].values[0]
    end_timestamp = valid_days.loc[valid_days['subject_id'] == id, 'end_timestamp'].values[0]

    # Sleep diary
    subject_diary = diary_df[diary_df['participantNo'] == id]

    # Build the full epoch grid for the participation window; epochs with no data
    # become NaN rows after the left merge.
    # NOTE(review): the 30 s step is hard-coded here; presumably it should match
    # config['seconds_per_epoch'] - confirm.
    all_epochs = pd.DataFrame({
        'epoch_ts': pd.date_range(start_timestamp, end_timestamp, freq='30s')
        })

    df = pd.merge(
        left=all_epochs,
        right=df,
        on='epoch_ts',
        how='left'
    )

    # Each calendar date shows up in two rows of the plot (before and after noon),
    # so the number of rows is taken as the number of distinct dates minus one.
    n_days = len(pd.unique(df['epoch_ts'].dt.date)) - 1

    # Filter sleep predictions as described in this paper: https://www.pnas.org/doi/epdf/10.1073/pnas.2116729119
    # df['pred'] = filter_sleep_series(df['pred'])

    # Each day starts at noon (12:00:00). To group rows in such days, I use their time difference
    # with a dummy noon timestamp in distant past.
    df = df.assign(day=(df['epoch_ts'] - pd.to_datetime('1970-01-01 12:00:00')).dt.days)
    df['day'] = df['day'] - df['day'].min()  # And shift so the first day is 0

    # squeeze=False keeps `axes` an array even when there is a single row
    fig, axes = plt.subplots(
        n_days, 1,
        sharex=True,
        squeeze=False,
        gridspec_kw={'hspace': 0},
        figsize=(16, n_days * n_sources * 0.25)
    )
    axes = axes.ravel()

    # zip stops at the shorter input, so any day beyond the n_days rows is not drawn
    for day, ax in zip(pd.unique(df['day']), axes):
        # Explicit copy: the columns are overwritten below for the 'na' pass
        day_df = df[df['day'] == day].copy()

        # processing sleep diary episodes
        # These are in a different format and don't have epoch-by epoch (binary) values
        # instead we have sleep start and end times.
        diary_sleep_periods = get_daily_sleep_periods_from_diary(
            subject_diary=subject_diary,
            day_start=day_df['epoch_ts'].min(),
            day_end=day_df['epoch_ts'].max()
            )

        # binary sources may have missing values. We plot twice: first the original data
        # then we mark the nan values. This doesn't apply to sleep diary (below)
        for plot_type in ['normal', 'na']:
            if plot_type == 'na':
                for col in binary_sources['df_columns']:
                    day_df.loc[:, col] = day_df[col].map({1: 0, 0: 0, np.nan: 1})  # mask values and turn on nans

            # One call per source (label -> list of [start, end])
            sleep_periods_dict = {
                source: get_sleep_periods(day_df, col, 'epoch_ts', mask_dates=True)
                for source, col in zip(binary_sources['plot_labels'], binary_sources['df_columns'])
            }
            if plot_type != 'na':  # No sleep diary in na plot
                sleep_periods_dict['Diary'] = diary_sleep_periods  # Add the sleep diary

            ax = plot_daily_sleep_indicators(
                sleep_periods_dict,
                ax,
                alpha=1 if plot_type=='normal' else 0.4,
                )

        # Because we plot nans without sleep diary, number of series passed to the above function
        # changes between the two call. So, it's better to set the ylim here
        # + 1 below is for sleep diary
        VERT_PAD = 0.25
        ax.set_ylim((-VERT_PAD, len(binary_sources['df_columns']) + 1 + VERT_PAD))
        ax.set_ylabel(f'Day {n_days - (day + 1)}')
    # # # # # # # # # # # # # # # # # # # # #

    axes[0].set_title(f'Subject {id}')
    handles, labels = axes[-1].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=n_sources * 2)

    fig.savefig(f'{plot_path}/sub_{id:02d}.png', dpi=200)
    plt.close(fig)  # Close this figure explicitly so figures do not pile up across subjects
    