In [None]:
import os
import sys
import pandas as pd
import numpy as np
import seaborn
import matplotlib
import matplotlib.pyplot as plt 
import scipy
import IPython

sys.path.append('../../')
sys.path.append('../')

import utils
import readers
from readers.describe_bs_output_reader import DescribeBsOutputReader
from readers.patient_info import PatientInfo

In [None]:
# Directory handed to DescribeBsOutputReader. NOTE: this is an absolute,
# machine-specific path; edit it when running the notebook elsewhere.
output_dir = '/afs/csail.mit.edu/u/t/tzhan/NFS/script_output/describe_bs/'
# Reader queried below (get_bsr / get_global_zs) for the per-EDF values to plot.
outputReader = DescribeBsOutputReader(output_dir)
# Patient accessor queried below (get_all_sids, has_good_outcome,
# get_edfs_and_indices). The path is relative, so it depends on the
# notebook's working directory.
patientInfo = PatientInfo('../../../../patient_outcome_info/')

In [None]:
def get_zs_bsr_over_time_heatmap(zs_or_bsr, outcome_lambda, ds_rate, max_num_hours, plot_in_parts=False):
    """Draw a patients-by-time heatmap of 'bsr' or 'zs' values.

    One row is drawn per selected patient and one column per retained
    (downsampled) sample. Time spans not covered by any EDF of a patient
    are masked out of the heatmap.

    Relies on the notebook-level ``patientInfo``, ``outputReader`` and
    ``utils`` objects.

    Parameters
    ----------
    zs_or_bsr : str
        Either 'zs' (values from ``outputReader.get_global_zs``) or
        'bsr' (values from ``outputReader.get_bsr``).
    outcome_lambda : callable
        Called with a patient sid; the patient is plotted only when the
        result is truthy. Patients for which
        ``patientInfo.has_good_outcome`` returns -1 are skipped before
        this is called.
    ds_rate : int
        Downsampling step: every ``ds_rate``-th sample is kept. If it is
        too small the arrays may not fit in memory.
    max_num_hours : number
        Length of the plotted time window, in hours.
    plot_in_parts : bool, optional
        If False (default), make one plot with all patients. If True,
        make several plots, each sized to fit a pdf page.

    Returns
    -------
    The object returned by ``seaborn.heatmap`` (a matplotlib Axes) when
    ``plot_in_parts`` is False, otherwise a list of such objects, one
    per part, in row order.

    Raises
    ------
    AssertionError
        If ``zs_or_bsr`` is neither 'zs' nor 'bsr'.
    """
    assert(zs_or_bsr=='bsr' or zs_or_bsr=='zs'), "zs_or_bsr must be either 'zs' or 'bsr'"

    # int() keeps the value usable as an array length / slice bound.
    total_samples = int(utils.hour_to_samples(max_num_hours))
    # Positions kept by the ``[::ds_rate]`` slices below. Deriving the column
    # count from this array (rather than ``total / ds_rate``) keeps it an
    # integer and equal to the slice length even when ``total_samples`` is
    # not a multiple of ``ds_rate``.
    index_in_samples = np.arange(0, total_samples, ds_rate)
    num_samples = len(index_in_samples)

    # Select the patients to plot.
    sids_used = []
    for new_sid in patientInfo.get_all_sids():
        if patientInfo.has_good_outcome(new_sid)==-1:
            print('sid without outcome! this is not expected! sid: {}'.format(new_sid))
            continue
        if outcome_lambda(new_sid):
            sids_used.append(new_sid)

    # Row i holds the downsampled values of sids_used[i]; a mask value of 1
    # means "no data here, hide the cell".
    zs_bsrs = np.zeros((len(sids_used), num_samples))
    masks = np.ones((len(sids_used), num_samples))

    for i, new_sid in enumerate(sids_used):
        print(i, new_sid)
        edfs_and_indices = patientInfo.get_edfs_and_indices(new_sid, max_num_hours=max_num_hours)
        patient_zs_bsrs = np.zeros(total_samples)
        patient_masks = np.ones(total_samples)
        for edf, start_index, end_index in edfs_and_indices:
            if zs_or_bsr=='bsr':
                zs_bsr = outputReader.get_bsr(edf)
            else:
                zs_bsr = outputReader.get_global_zs(edf)
            # Place this EDF's values at its position on the patient timeline
            # and unmask that span.
            patient_zs_bsrs[start_index:end_index] = zs_bsr[:(end_index-start_index)]
            patient_masks[start_index:end_index] = 0
        zs_bsrs[i] = patient_zs_bsrs[::ds_rate]
        masks[i] = patient_masks[::ds_rate]
        # Keep the cell output down to the most recent progress line.
        IPython.display.clear_output()
    masks = masks.astype(bool)

    # Column labels: elapsed time of every retained sample as "H:MM".
    # The 0.0+ forces float division on Python 2 as well.
    index_in_seconds = index_in_samples/(0.0+utils.sec_to_samples(1))
    hours, minutes = np.divmod(index_in_seconds // 60, 60)
    index_formatted = ['{}:{:02d}'.format(int(hour), int(minute))
                       for hour, minute in zip(hours, minutes)]
    zs_bsr_df = pd.DataFrame(zs_bsrs, index=sids_used, columns=index_formatted)

    # The two quantities use the same palette in opposite directions.
    if zs_or_bsr=='bsr':
        cmap = "YlGnBu"
    else:
        cmap = "YlGnBu_r"

    if not plot_in_parts:
        fig, ax = plt.subplots(figsize=(6,0.5+0.18*zs_bsr_df.shape[0]))
        return seaborn.heatmap(zs_bsr_df, cmap=cmap, mask=masks, ax=ax)

    ## Make a bunch of plots, of size befitting a pdf
    num_rows = zs_bsr_df.shape[0]
    num_rows_per_plot = 45
    plots = []
    for row_start_ind in range(0, num_rows, num_rows_per_plot):
        row_end_ind = min(row_start_ind + num_rows_per_plot, num_rows)
        # .iloc makes the positional row slice explicit regardless of sid dtype.
        zs_bsr_df_part = zs_bsr_df.iloc[row_start_ind:row_end_ind]
        masks_part = masks[row_start_ind:row_end_ind, :]
        fig, ax = plt.subplots(figsize=(6,0.5+int(0.2*zs_bsr_df_part.shape[0])))
        plots.append(seaborn.heatmap(zs_bsr_df_part, cmap=cmap, mask=masks_part, ax=ax))
    return plots

In [None]:
def save_heatmap_plot_parts(plots, output_file_name):
    """Write every heatmap part to its own PNG file.

    The part at 0-based position ``k`` in ``plots`` is saved through its
    ``figure.savefig`` to ``<output_file_name>_part<k>.png``.
    """
    part_number = 0
    for part in plots:
        target_path = '{}_part{}.png'.format(output_file_name, part_number)
        part.figure.savefig(target_path)
        part_number += 1

### Define the parameters

In [None]:
# Downsampling step: only every ds_rate-th sample is kept for plotting.
# Too small a value may cause memory issues.
ds_rate = 100
# Maximum number of hours to plot
max_num_hours = 72
# If plot_in_parts is True, the plot is broken up into multiple standard-page-sized plots.
# If plot_in_parts is False, only one plot containing all the patients is created.
plot_in_parts = True
# zs_bsr is either 'zs' or 'bsr', depending on which one you want to plot
zs_bsr = 'bsr'
# outcome_target ('good' or 'bad') names the patient group being plotted and is
# used as the prefix of the saved file names.
# NOTE(review): the patient filter itself is the outcome_lambda chosen in the
# plotting cell -- confirm that it matches outcome_target before running.
outcome_target = 'good'

### Create the plot

In [None]:
# Pick the patient filter that matches outcome_target, so that the plotted
# group and the output file name cannot disagree.
if outcome_target == 'good':
    outcome_lambda = patientInfo.has_good_outcome
elif outcome_target == 'bad':
    # Patients without an outcome (has_good_outcome == -1) are dropped inside
    # get_zs_bsr_over_time_heatmap before this filter is called, so negating
    # the result selects the bad-outcome patients only.
    outcome_lambda = lambda sid: not patientInfo.has_good_outcome(sid)
else:
    raise ValueError("outcome_target must be either 'good' or 'bad', got {!r}".format(outcome_target))

heatmap_plot = get_zs_bsr_over_time_heatmap(zs_bsr, outcome_lambda = outcome_lambda,
                                        ds_rate=ds_rate, max_num_hours=max_num_hours, plot_in_parts=plot_in_parts)

output_file_name = '{}_{}_ds{}'.format(outcome_target, zs_bsr, ds_rate)
if plot_in_parts:
    # heatmap_plot is a list of plots, one per page-sized part.
    save_heatmap_plot_parts(heatmap_plot, output_file_name)
else:
    # heatmap_plot is a single plot; it cannot be iterated over as parts.
    heatmap_plot.figure.savefig('{}.png'.format(output_file_name))