In [None]:
import math
import os
import sys
import pandas as pd
import numpy as np
import seaborn
import matplotlib
import matplotlib.pyplot as plt 
import scipy
import IPython

import similarity_scaling
from similarity_scaling import scale_arr

sys.path.append('../../')
sys.path.append('../')

import utils
import readers
from readers.similarity_output_reader import SimilarityOutputReader
from readers.patient_info import PatientInfo

### Define the directory containing the similarity output

In [None]:
# Short label for the run below; it is embedded in the names of saved figures.
output_dir_tag = 'clip'
# Directory holding the similarity output that the readers load from.
similarity_output_dir = '/afs/csail.mit.edu/u/t/tzhan/NFS/script_output/similarity_clip/'

### Create the required readers

In [None]:
# Location of the patient outcome files, relative to the notebook's working directory.
outcomes_path = '../../../../patient_outcome_info/'

# Reader for the similarity output directory configured above.
similarityReader = SimilarityOutputReader(similarity_output_dir)
# Reader for patient-level information (outcomes, EDF durations, EDF offsets).
patientInfo = PatientInfo(outcomes_path)

# Similarity over time

In [None]:
def get_sims_per_burst_vector(sims, n_bursts, neighbor_size=3):
    """Summarize every burst by its mean similarity to its nearest neighbours.

    Parameters
    ----------
    sims : sequence of float
        Pairwise similarities for all ``n_bursts * (n_bursts - 1) / 2`` burst
        pairs, flattened in the order produced by ``np.triu_indices_from(..., k=1)``
        (strict upper triangle: (0, 1), (0, 2), ..., (n_bursts - 2, n_bursts - 1)).
    n_bursts : int
        Number of bursts the similarities were computed over.
    neighbor_size : int, optional
        How many bursts on each side of a burst are averaged. The window is
        truncated at the first and last burst.

    Returns
    -------
    numpy.ndarray
        Array of length ``n_bursts``; entry ``i`` is the mean similarity between
        burst ``i`` and the bursts in its window. A burst with no neighbours
        (a single burst, or ``neighbor_size == 0``) gets NaN.

    Raises
    ------
    AssertionError
        If ``len(sims)`` does not match the number of burst pairs.
    """
    # Floor division keeps the pair count an integer on Python 2 and 3.
    n_pairs = n_bursts * (n_bursts - 1) // 2
    assert(len(sims)==n_pairs), 'npairs {} != len(sims) {}'.format(n_pairs, len(sims))

    # Rebuild the full symmetric matrix (zero diagonal) from the flattened upper triangle.
    sims_matrix = np.zeros((n_bursts, n_bursts))
    sims_matrix[np.triu_indices_from(sims_matrix, k=1)] = sims
    sims_matrix = sims_matrix + sims_matrix.transpose()

    sims_per_burst = np.full(n_bursts, np.nan)
    for i in range(n_bursts):
        row = sims_matrix[i]
        # Neighbours before and after burst i, excluding the burst itself.
        before = row[max(0, i - neighbor_size):i]
        after = row[i + 1:min(n_bursts, i + 1 + neighbor_size)]
        neighbors = np.concatenate((before, after))
        if neighbors.size:
            sims_per_burst[i] = neighbors.mean()
    return sims_per_burst

In [None]:
def get_sim_over_time(edf, similarity_fn, min_episode_length, scale_fn, scale_param, neighbor_size):
    """Build a per-sample similarity trace and a missing-data mask for one EDF.

    For every episode returned by ``similarityReader`` the pairwise burst
    similarities are rescaled with ``scale_arr``, reduced to one value per
    burst with ``get_sims_per_burst_vector``, and linearly interpolated
    between the first and last burst start index of the episode.

    Parameters
    ----------
    edf
        EDF identifier understood by ``patientInfo`` and ``similarityReader``.
    similarity_fn
        Similarity measure name, passed to the reader and to ``scale_arr``.
    min_episode_length
        Passed to ``similarityReader.get_similarities``.
    scale_fn, scale_param
        Scaling function name and parameter, passed to ``scale_arr``.
    neighbor_size : int
        Neighbour window passed to ``get_sims_per_burst_vector``.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``edf_sim_over_time`` and ``edf_mask``, both of length
        ``int(patientInfo.get_edf_duration(edf))``. The mask is 1 where no
        similarity was written and 0 where the trace holds interpolated values.
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee
    # that `scipy.interpolate` has been loaded.
    from scipy.interpolate import interp1d

    nsamples = int(patientInfo.get_edf_duration(edf))
    edf_sim_over_time = np.zeros(nsamples)
    edf_mask = np.ones(nsamples)
    episode_to_sims_list = similarityReader.get_similarities(edf, similarity_fn,
                                                             convert_to_np=True,
                                                             min_episode_length=min_episode_length)
    episode_to_burst_ranges = similarityReader.get_burst_ranges(edf)
    for episode_dict in episode_to_sims_list:
        (episode_start_index, episode_end_index) = episode_dict['episode_start_index'], episode_dict['episode_end_index']
        sims = episode_dict['similarities']
        if len(sims)==0:
            # No burst pairs in this episode, so there is nothing to interpolate.
            continue
        sims = scale_arr(sims, similarity_fn, scale_fn, scale_param)
        burst_ranges = episode_to_burst_ranges[(episode_start_index, episode_end_index)]
        n_bursts = len(burst_ranges)

        # NOTE(review): assumes burst_ranges is a 2-D array whose first column is
        # the start sample index of each burst -- confirm against the reader.
        burst_start_indices = burst_ranges[:, 0]
        sims_per_burst_vector = get_sims_per_burst_vector(sims, n_bursts, neighbor_size=neighbor_size)
        # sims over time, interpolated from first burst to last burst
        interpolate_fn = interp1d(burst_start_indices, sims_per_burst_vector)
        min_burst_index, max_burst_index = min(burst_start_indices), max(burst_start_indices)
        interpolated_sims_over_time = interpolate_fn(np.arange(min_burst_index, max_burst_index))
        edf_sim_over_time[min_burst_index:max_burst_index] = interpolated_sims_over_time
        # Mark the interpolated span as containing data.
        edf_mask[min_burst_index:max_burst_index] = 0
    return edf_sim_over_time, edf_mask

In [None]:
def get_sim_over_time_plot(similarity_fn, scale_fn, scale_param, min_episode_length, max_num_hours, ds_rate,
                          neighbor_size, outcome_lambda, plot_in_parts):
    """Draw heatmaps of similarity over time, one row per patient.

    Parameters
    ----------
    similarity_fn, scale_fn, scale_param, min_episode_length, neighbor_size
        Forwarded unchanged to ``get_sim_over_time``.
    max_num_hours
        Length of the time axis; converted to samples with
        ``utils.hour_to_samples`` and passed to ``patientInfo.get_edfs_and_indices``.
    ds_rate : int
        Downsampling stride: every ``ds_rate``-th sample becomes one heatmap column.
    outcome_lambda : callable
        Predicate taking a patient id; only patients for which it is truthy are used.
    plot_in_parts : bool
        If true, split the patients over several figures of at most 45 rows each.

    Returns
    -------
    list of matplotlib Axes if ``plot_in_parts`` is true (one per figure,
    possibly empty), otherwise a single matplotlib Axes.
    """
    total_samples = int(utils.hour_to_samples(max_num_hours))
    # Sample index of every retained column. Its length is by construction the
    # length of `array[::ds_rate]`, including when total_samples is not a
    # multiple of ds_rate, and it is always an integer.
    index_in_samples = np.arange(0, total_samples, ds_rate)
    num_samples = len(index_in_samples)

    # Select the patients that have a known outcome and match the requested one.
    sids_used = []
    for new_sid in patientInfo.get_all_sids():
        if patientInfo.has_good_outcome(new_sid)==-1:
            print('sid without outcome! this is not expected! sid: {}'.format(new_sid))
            continue
        if not outcome_lambda(new_sid):
            continue
        sids_used.append(new_sid)

    sim_over_time_matrix = np.zeros((len(sids_used), num_samples))
    masks = np.ones((len(sids_used), num_samples))
    sim_iter_num = 0

    sids_with_sims = []
    for i, new_sid in enumerate(sids_used):
        print(i, new_sid, ' out of ', len(sids_used))
        edfs_and_indices = patientInfo.get_edfs_and_indices(new_sid, max_num_hours=max_num_hours)
        patient_sim_over_time = np.zeros(total_samples)
        patient_masks = np.ones(total_samples)
        for edf, start_index, end_index in edfs_and_indices:
            edf_sim_over_time, edf_mask = get_sim_over_time(edf, similarity_fn, min_episode_length, scale_fn, scale_param, neighbor_size)
            # Place this EDF's trace at its offset within the patient's timeline.
            span = end_index - start_index
            patient_sim_over_time[start_index:end_index] = edf_sim_over_time[:span]
            patient_masks[start_index:end_index] = edf_mask[:span]
        if not np.all(patient_masks==1):
            # At least one sample is unmasked, so this patient has data: keep the row.
            # Patients that are entirely masked are skipped.
            sim_over_time_matrix[sim_iter_num] = patient_sim_over_time[::ds_rate]
            masks[sim_iter_num] = patient_masks[::ds_rate]
            sim_iter_num +=1
            sids_with_sims.append(new_sid)
        IPython.display.clear_output()

    # Column labels: elapsed time of each retained sample as 'H:MM'.
    index_in_seconds = index_in_samples/(0.0+utils.sec_to_samples(1))
    total_minutes, _ = np.divmod(index_in_seconds, 60)
    hours, minutes = np.divmod(total_minutes, 60)
    index_formatted = ['{}:{:02d}'.format(int(hour), int(minute))
                       for hour, minute in zip(hours, minutes)]

    # Drop the preallocated rows that were never filled.
    sim_over_time_matrix = sim_over_time_matrix[0:sim_iter_num, :]
    masks = masks.astype(bool)[0:sim_iter_num, :]
    sim_over_time_df = pd.DataFrame(sim_over_time_matrix, index=sids_with_sims, columns=index_formatted)

    num_rows = sim_over_time_df.shape[0]
    if plot_in_parts:
        ## Make a bunch of plots, of size befitting a pdf
        num_rows_per_plot = 45
        plots = []
        for row_start_ind in range(0, num_rows, num_rows_per_plot):
            row_end_ind = min(row_start_ind + num_rows_per_plot, num_rows)
            # .iloc makes the positional row slice explicit regardless of the sid index type.
            sim_over_time_df_part = sim_over_time_df.iloc[row_start_ind:row_end_ind]
            masks_part = masks[row_start_ind:row_end_ind, :]
            fig, ax = plt.subplots(figsize=(6*1.39,1.39*(0.5+0.2*sim_over_time_df_part.shape[0])))
            heatmap_plot = seaborn.heatmap(sim_over_time_df_part, vmin=0.0, vmax=1.0, cmap="YlGnBu", mask=masks_part, ax=ax)
            plots.append(heatmap_plot)
        return plots
    else:
        fig, ax = plt.subplots(figsize=(6,0.5+0.2*num_rows))
        heatmap_plot = seaborn.heatmap(sim_over_time_df, cmap="YlGnBu", vmin=0.0, vmax=1.0, mask=masks, ax=ax)
        return heatmap_plot

In [None]:
def save_heatmap_plot_parts(plots, output_file_name):
    """Write each plot's figure to '<output_file_name>_part<k>.png'.

    ``plots`` is an iterable of objects exposing ``.figure.savefig``; ``k`` is
    the zero-based position of the plot in that iterable.
    """
    part_number = 0
    for heatmap in plots:
        target_path = '{}_part{}.png'.format(output_file_name, part_number)
        heatmap.figure.savefig(target_path)
        part_number += 1

### Define parameters

In [None]:
# Which patients to plot: 'good' selects good-outcome patients, anything else bad-outcome.
target_outcome='bad'
# Length of the time axis, in hours.
max_num_hours = 72

# Similarity measure to read, and how its values are rescaled by scale_arr.
similarity_fn = 'xcorr'
scale_fn = 'exp'
scale_param = 50.0

# Passed to the similarity reader as min_episode_length.
min_episode_length = 30
# Number of neighbouring bursts on each side averaged into a burst's similarity.
neighbor_size = 10

# Keep every ds_rate-th sample as a heatmap column.
ds_rate = 100
# True: split the heatmap into several figures; False: one figure for all patients.
plot_in_parts = True

### Create derived parameters from the defined parameters

In [None]:
# Tag describing the scaling, used in the saved figure names.
# NOTE(review): the parameter is only appended when it equals 50, so other
# 'exp' parameters all map to the plain 'exp' tag -- confirm this is intended.
if scale_fn=='exp' and scale_param==50:
    scale_tag = '{}_{}'.format(scale_fn, scale_param)
else:
    scale_tag = scale_fn
image_info = '{}_{}'.format(scale_tag, output_dir_tag)

# Predicate selecting the patients to plot. An unrecognised target_outcome is
# rejected instead of silently falling back to bad-outcome patients, which
# would save figures under a label that does not match their content.
if target_outcome=='good':
    outcome_lambda = lambda sid: patientInfo.has_good_outcome(sid)
elif target_outcome=='bad':
    outcome_lambda = lambda sid: patientInfo.has_bad_outcome(sid)
else:
    raise ValueError("target_outcome must be 'good' or 'bad', got {!r}".format(target_outcome))

In [None]:
# Base name (without extension) of the saved figure(s).
save_path = 'similarity_over_time_{}_{}_{}_{}'.format(target_outcome, similarity_fn, image_info, min_episode_length)
heatmap_plot = get_sim_over_time_plot(similarity_fn, scale_fn, scale_param, min_episode_length, max_num_hours, 
                                      ds_rate, neighbor_size, outcome_lambda, plot_in_parts)
# get_sim_over_time_plot returns a list of plots only when plot_in_parts is set;
# otherwise it returns a single plot, which cannot be enumerated part by part.
if plot_in_parts:
    save_heatmap_plot_parts(heatmap_plot, save_path)
else:
    heatmap_plot.figure.savefig('{}.png'.format(save_path))