In [15]:
from config import project_config as config
from utils.sleep_wake_filter import filter_sleep_series
import pandas as pd
import os
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from utils.data_utils import read_sleep_dairies_v2
import matplotlib.dates as md

In [16]:
def mask_dates_of_sleep_periods(sleep_periods):
    """Move every timestamp onto a shared dummy date so that different days align on one x-axis.

    Each plot row covers 12:00 (noon) to 12:00 on the next day. Afternoon/evening times
    (hour >= 12) are therefore placed on 1970-01-01 and morning times (hour < 12) on 1970-01-02;
    the time of day itself is left untouched.

    Parameters
    ----------
    sleep_periods : list of lists
        Each inner list holds timestamps (objects with ``.hour`` and ``.replace``), e.g. [start, end].
        The inner lists are modified in place, which is why callers pass lists rather than tuples.

    Returns
    -------
    The same outer list object that was passed in.
    """
    for period in sleep_periods:
        for pos, stamp in enumerate(period):
            dummy_day = 1 if stamp.hour >= 12 else 2
            period[pos] = stamp.replace(year=1970, month=1, day=dummy_day)
    return sleep_periods

In [17]:
def get_sleep_periods(df, sleep_col, time_col, mask_dates=False):
    """Find contiguous runs of sleep (value == 1) in ``df[sleep_col]``.

    A sleeping row starts a run when the row before it is exactly 1 lower (i.e. 0) or is missing
    (NaN, or there is no previous row). It ends a run when the row after it is exactly 1 lower or
    is missing. A single isolated sleeping row is both a start and an end.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows in time order; the runs are detected positionally.
    sleep_col : str
        Column holding the sleep indicator (1 = sleep; 0 or NaN = not sleep).
    time_col : str
        Column whose values are reported as the start/end of each run.
    mask_dates : bool
        Accepted for call compatibility but currently ignored (no masking happens here).

    Returns
    -------
    list of [start, end] lists
        Lists rather than tuples so callers can rewrite the timestamps in place.
    """
    indicator = df[sleep_col]
    sleeping = indicator == 1

    # NaN differences (first/last row, or a missing neighbour) are treated as edges
    step_from_prev = indicator.diff().fillna(1)
    step_to_next = indicator.diff(-1).fillna(1)

    run_starts = df.loc[sleeping & (step_from_prev == 1), time_col]
    run_ends = df.loc[sleeping & (step_to_next == 1), time_col]

    periods = []
    for begin, finish in zip(run_starts, run_ends):
        periods.append([begin, finish])
    return periods

In [18]:
def get_daily_sleep_periods_from_diary(subject_diary, day_start, day_end):
    """Return the diary sleep episodes that fall on one plot day, clipped to that day.

    An episode is kept when it starts within [day_start, day_end], ends within it, or starts
    before ``day_start`` and ends after ``day_end`` (i.e. covers the whole day). Because each day
    is plotted on its own row, any part of a kept episode outside the day is cut off here; that
    part is drawn on the neighbouring day's row instead.

    Parameters
    ----------
    subject_diary : pandas.DataFrame
        Diary rows for one subject. ``lights_off`` / ``lights_on`` hold the start / end of each
        episode (assumed to be datetimes comparable with ``day_start``/``day_end`` — TODO confirm
        against ``read_sleep_dairies_v2``).
    day_start, day_end : timestamp
        First and last timestamp of the plot day (noon to noon).

    Returns
    -------
    list of [start, end] lists
        Lists rather than tuples so callers can rewrite the timestamps in place.
    """
    day_sleep_periods = []
    for sleep_start, sleep_end in zip(subject_diary['lights_off'], subject_diary['lights_on']):
        starts_today = day_start <= sleep_start <= day_end
        ends_today = day_start <= sleep_end <= day_end
        # Neither endpoint lies inside the day, yet the episode still covers all of it
        spans_day = sleep_start < day_start and sleep_end > day_end
        if not (starts_today or ends_today or spans_day):
            continue

        # Clip to the day: earlier/later parts belong to other rows of the plot
        clipped_start = day_start if sleep_start < day_start else sleep_start
        clipped_end = day_end if sleep_end > day_end else sleep_end
        day_sleep_periods.append([clipped_start, clipped_end])

    return day_sleep_periods

In [19]:
def plot_daily_sleep_indicators(sleep_periods_dict, ax, labels, color=None, **kwargs):
    """Draw each source's sleep periods as horizontal bars, one source per row of ``ax``.

    The ``i``-th entry of ``sleep_periods_dict`` is drawn in the band ``y = i .. i + 1`` and
    labelled ``labels[i]``. Rows are fixed by position, so an empty source is skipped but its row
    stays reserved (nothing else is drawn there). Timestamps are first moved onto dummy dates by
    ``mask_dates_of_sleep_periods`` (which rewrites the period lists in place) so every day
    shares the same x-axis. Each bar is ``end - start`` plus one epoch long, so the epoch whose
    timestamp is ``end`` is covered as well.

    Parameters
    ----------
    sleep_periods_dict : dict
        Source name -> list of [start, end] lists.
    ax : matplotlib Axes
        Axes to draw on.
    labels : sequence of str
        Legend label for each entry of ``sleep_periods_dict``, in the same order.
    color : optional
        Colour for every bar; when None, row ``i`` uses tab10 slot ``palette_slot[i + 1]``.
    **kwargs
        Forwarded to ``ax.broken_barh`` (e.g. ``alpha``).

    Returns
    -------
    The same ``ax``.
    """
    epoch_length = np.timedelta64(config['seconds_per_epoch'], 's')
    # Key -> tab10 slot; rows use keys 1..6. Slots 3 and 6 are skipped (presumably to avoid
    # those colours — TODO confirm).
    palette_slot = {0: 0, 1: 1, 2: 2, 3: 4, 4: 5, 5: 7, 6: 8}
    palette = matplotlib.colormaps['tab10']

    for row, periods in enumerate(sleep_periods_dict.values()):
        # Nothing to draw, but the row index is still consumed so later sources don't shift up
        if len(periods) == 0:
            continue

        masked = mask_dates_of_sleep_periods(periods)
        bars = []
        for start, end in masked:
            bars.append((start, pd.Timedelta(end - start) + epoch_length))

        if color is None:
            bar_color = palette(palette_slot[row + 1])
        else:
            bar_color = color
        ax.broken_barh(bars, (row, 1), color=bar_color, label=labels[row], **kwargs)

    ax.set_yticks([])
    ax.set_xlabel('Time')
    return ax

In [20]:
def collect_plot_data(df, all_plot_sources, keys_to_collect):
    """Build the {plot name: sleep periods} mapping for one plot layer.

    Every name in ``all_plot_sources`` gets an entry, in the same order, so each source keeps its
    own row on the plot. Names listed in ``keys_to_collect`` get the periods extracted from their
    column (with ``epoch_ts`` as the time column). All other names get an empty list, which
    reserves the row without drawing on it.

    Parameters
    ----------
    df : pandas.DataFrame
        Epoch-level data for one plot day.
    all_plot_sources : dict
        Plot name -> column name in ``df``.
    keys_to_collect : container of str
        Plot names to extract for this layer.

    Returns
    -------
    dict
        Plot name -> list of [start, end] lists (possibly empty).
    """
    return {
        plot_name: (
            get_sleep_periods(df, data_col, 'epoch_ts', mask_dates=True)
            if plot_name in keys_to_collect
            else []
        )
        for plot_name, data_col in all_plot_sources.items()
    }

In [21]:
# For every subject, draw one figure with one row per noon-to-noon day. Each row shows the sleep
# periods of every source plus the sleep diary, the missing-data (NA) spans of some sources, and
# the model's cross-validation (CV) predictions as translucent overlays.
sleep_diaries_path = 'data/Sleep diaries'
plot_path = 'Results/plots/all-five'
# plot_path = 'Results/plots/AWS-PSG-Biobank'
# plot_path = 'Results/plots/AWS-PSG'
# plot_path = 'Results/plots/AWS-Model'
os.makedirs(plot_path, exist_ok=True)

cv_source = 'pred_PSG-CNN'  # This is the one that needs to be plotted differently between "normal" and "cv"
all_plot_sources = {  # name to df column mapping
    'AX3 Model': 'pred_PSG-CNN',
    'AWS': 'AWS Sleep',
    'PSG': 'PSG Sleep',
    'Biobank': 'Biobank Sleep'
}

plot_grps = {
    'normal': ['AX3 Model', 'AWS', 'PSG', 'Biobank'],
    'na': ['AWS', 'Biobank'],
    'cv': ['AX3 Model']
}

n_sources = len(all_plot_sources)
VERT_PAD = 0.25

# Canonical legend order: sources, diary, then the NA and CV overlays
legend_order = (
    list(all_plot_sources)
    + ['Diary']
    + [k + ' (NA)' for k in all_plot_sources]
    + [k + ' (CV)' for k in all_plot_sources]
)
legend_rank = {label: rank for rank, label in enumerate(legend_order)}

# These inputs are shared by all subjects, so read them once instead of once per subject.
# Participation dates help remove days with little data at the beginning and the end.
valid_days = pd.read_csv('data/participation_dates.csv')
diary_df = read_sleep_dairies_v2(sleep_diaries_path, include_naps=True)

# `subject_id` (not `id`) so the builtin is not shadowed in the kernel after this cell runs
for subject_id in config['subject_ids']:
    df = pd.read_csv(f'Results/merged_indicators/sub_{subject_id:02d}.csv')
    df['epoch_ts'] = pd.to_datetime(df['epoch_ts'])
    df = df.sort_values('epoch_ts')

    subject_dates = valid_days.loc[valid_days['subject_id'] == subject_id]
    start_timestamp = subject_dates['start_timestamp'].values[0]
    end_timestamp = subject_dates['end_timestamp'].values[0]

    # Sleep diary
    subject_diary = diary_df[diary_df['subject_id'] == subject_id]

    # Put the data on a complete 30 s epoch grid; epochs without data become NaN
    all_epochs = pd.DataFrame({
        'epoch_ts': pd.date_range(start_timestamp, end_timestamp, freq='30s')
    })
    df = pd.merge(left=all_epochs, right=df, on='epoch_ts', how='left')

    # Each plot row runs noon to noon, so N calendar dates give N - 1 rows (the first and last
    # dates only contribute half a row each).
    # NOTE(review): this assumes the participation window is noon-aligned; any partial day beyond
    # n_days rows is silently dropped by the zip below.
    n_days = max(len(pd.unique(df['epoch_ts'].dt.date)) - 1, 1)

    # Filter sleep predictions as described in this paper: https://www.pnas.org/doi/epdf/10.1073/pnas.2116729119
    # df['pred_PSG-CNN'] = filter_sleep_series(df['pred_PSG-CNN'])

    # Each day starts at noon (12:00:00). To group rows into such days, use their time difference
    # from a dummy noon timestamp in the distant past, then shift so the first day is 0.
    df = df.assign(day=(df['epoch_ts'] - pd.to_datetime('1970-01-01 12:00:00')).dt.days)
    df['day'] = df['day'] - df['day'].min()

    # squeeze=False keeps `axes` an array even when there is only one day
    fig, axes = plt.subplots(
        n_days, 1, sharex=True, squeeze=False,
        gridspec_kw={'hspace': 0}, figsize=(16, n_days * n_sources * 0.25)
    )
    axes = axes[:, 0]

    for day, ax in zip(pd.unique(df['day']), axes):
        day_df = df[df['day'] == day]
        day_start = day_df['epoch_ts'].min()
        day_end = day_df['epoch_ts'].max()

        # Normal layer: every source, with the model's CV predictions hidden
        plot_df = day_df.copy()
        plot_df.loc[plot_df['is_cv_prediction'] == 1, cv_source] = np.nan
        plot_sources = collect_plot_data(plot_df, all_plot_sources, plot_grps['normal'])

        # Diary episodes come as start/end times rather than epoch-by-epoch (binary) values
        plot_sources['Diary'] = get_daily_sleep_periods_from_diary(
            subject_diary=subject_diary,
            day_start=day_start,
            day_end=day_end
        )
        ax = plot_daily_sleep_indicators(plot_sources, ax, labels=list(plot_sources))

        # NA layer: missing epochs become 1 ("on") and everything else 0
        plot_df = day_df.copy()
        for col in all_plot_sources.values():
            plot_df[col] = plot_df[col].isna().astype(int)
        plot_sources = collect_plot_data(plot_df, all_plot_sources, plot_grps['na'])
        ax = plot_daily_sleep_indicators(
            plot_sources, ax, alpha=0.5, labels=[k + ' (NA)' for k in plot_sources]
        )

        # CV layer: only the model's cross-validation predictions
        plot_df = day_df.copy()
        plot_df.loc[plot_df['is_cv_prediction'] == 0, cv_source] = np.nan
        plot_sources = collect_plot_data(plot_df, all_plot_sources, plot_grps['cv'])
        ax = plot_daily_sleep_indicators(
            plot_sources, ax, alpha=0.5, labels=[k + ' (CV)' for k in plot_sources]
        )

        ax.xaxis.set_major_locator(md.HourLocator(interval=2))
        ax.xaxis.set_major_formatter(md.DateFormatter('%H:%M'))

        # The number of rows passed to the plotting function differs between layers (the diary
        # is only in the normal layer), so fix the y-limits here; + 1 is the diary row.
        ax.set_ylim((-VERT_PAD, n_sources + 1 + VERT_PAD))
        # NOTE(review): for noon-aligned data this labels rows 1 - n_days ... 0; confirm intended.
        ax.set_ylabel(f'Day {day + 1 - n_days}\n{day_start.strftime("%b %d")}')

    axes[0].set_title(f'Subject {subject_id}', fontsize=24)

    # Gather legend entries from every row, not only the last one: a source that has no data on
    # the last day (e.g. a single PSG night) would otherwise be missing from the legend.
    legend_entries = {}
    for row_ax in axes:
        for handle, label in zip(*row_ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    legend_labels = sorted(legend_entries, key=lambda lbl: legend_rank.get(lbl, len(legend_rank)))
    fig.legend(
        [legend_entries[lbl] for lbl in legend_labels], legend_labels,
        loc='upper center', ncol=n_sources * 2
    )

    fig.savefig(f'{plot_path}/sub_{subject_id:02d}.png', dpi=200)
    plt.close(fig)
    