# MEDcohort Visualization

## Here is a list of what you can do with the interactive figure:

- Display/Hide point values by left-clicking near the corresponding point.
- Display/Hide the patient ID corresponding to a point by right-clicking near that point. This action also selects the whole profile and highlights it in red.
- Make time relative to a specific class by pressing "r" on your keyboard while hovering over that class. Repeat the same action to return to absolute time.
- Open profile figures of the selected patients by pressing "p" on your keyboard. These figures can be manipulated as described in the *MEDprofiles_semi_front* tutorial.
- Delete the selected profiles from the cohort by pressing "delete" on your keyboard.
- Group the cohort into monthly bins by pressing "m" on your keyboard. This opens a new figure in which each point represents a group of patients. You can display/hide the patient IDs related to a point by clicking near it.
- Group the cohort into yearly bins by pressing "y" on your keyboard. This opens a new figure in which each point represents a group of patients. You can display/hide the patient IDs related to a point by clicking near it.
- Assign time points to a class, or remove them, by pressing "t" while hovering over that class.



## Import modules and patient data

In [None]:
import datetime
import pickle

from MEDclasses import *
from src.semi_front.MEDcohort_utils import *
from src.semi_front.MEDprofiles_utils import *
import matplotlib as mpl
import numpy as np

%matplotlib qt

In [None]:
# Get MEDprofiles
# Path to the pickled list of MEDprofiles.
# Alternative data file: '../data/meningioma/MEDprofileData'
medprofile_data_path = '../data/mimic/MEDprofileData'
# Maximum number of profiles kept from the file (set to None to keep them all).
max_profiles = 10
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this project.
# The context manager closes the file even if unpickling or slicing raises.
with open(medprofile_data_path, 'rb') as data_file:
    MEDprofile_list = pickle.load(data_file)[:max_profiles]
cohort = MEDcohort(list_MEDprofile=MEDprofile_list)
df_cohort = cohort.profile_list_to_df()

In [None]:
# Show the cohort dataframe built from the loaded profiles (display is provided by IPython).
display(df_cohort)

## Define plot functions

In [None]:
# Define global attributes
# Module-level state shared by the matplotlib callbacks defined below: the callbacks only
# receive the event, so they read and mutate these containers directly.
dict_selected_points = {}  # selected profiles; passed to display_patient_id_and_data, delete_pressed, p_pressed and update_cohort
dict_selected_annotations = {}  # annotations attached to the selected profiles
dict_points = {}  # points of the cohort figure (presumably filled by display_cohort -- confirm)
dict_annotations = {}  # annotations of the cohort figure
dict_class_time_points = {}  # time points assigned to classes with the "t" key (passed to t_pressed)
dict_bin_points = {}  # points of the year/month bin figures (passed to bin_pressed)
dict_bin_annotations = {}  # annotations of the year/month bin figures
dict_figure_profile = {}  # profile figures opened with "p", keyed by figure number; entry [1] holds that figure's axes/df/xaxis/points/annotations
legend_points = []  # passed to t_pressed, presumably for the time-point legend -- confirm
xaxis = [FIXED_COLUMNS[1]]  # x-axis column; wrapped in a list presumably so handlers such as r_pressed can update it in place -- confirm

In [None]:
def button_pressed(event):
    """
    Handle a mouse click on the main cohort figure.

    A left click shows the data of the point nearest to the click. A right click shows the
    nearest patient's ID and highlights that patient's points in every subplot. The
    figure is redrawn after every click, whatever the button.
    :param event: matplotlib mouse button event.
    :return: None
    """
    mouse_buttons = mpl.backend_bases.MouseButton
    clicked = event.button
    if clicked == mouse_buttons.RIGHT:
        # Select the nearest patient and highlight its points across all subplots.
        display_patient_id_and_data(
            event, axes, df_cohort, classes_attributes_dict,
            dict_points, dict_annotations,
            dict_selected_points, dict_selected_annotations, xaxis,
        )
    elif clicked == mouse_buttons.LEFT:
        # Show the attribute values of the nearest point in the clicked axis.
        display_attributes_values(event, df_cohort)
    fig.canvas.draw()

In [None]:
def button_pressed_in_simple_figure(event):
    """
    Called when a mouse button is pressed on a figure with one annotation per point.

    This callback is registered by bin_pressed and p_pressed. A left or right click toggles
    the annotations through display_annotations; other buttons are ignored.
    :param event: matplotlib mouse button event.
    :return: None
    """
    mouse_buttons = mpl.backend_bases.MouseButton
    if event.button not in (mouse_buttons.LEFT, mouse_buttons.RIGHT):
        return
    display_annotations(event)
    # Redraw the canvas that actually received the click rather than relying on
    # pyplot's global "current figure" state.
    event.canvas.draw_idle()

In [None]:
def key_pressed_in_profile(event):
    """
    Called when a key is pressed on a profile figure.

    Pressing "r" over a class axis sets the time relative to that class (via
    r_pressed_in_profile), then recenters the profile plot and redraws it. Other keys are
    ignored, as is any figure that is not registered in dict_figure_profile.
    :param event: matplotlib key press event.
    :return: None
    """
    if event.key != 'r':
        return
    # Use the figure that received the key press. pyplot's current figure (plt.gcf()) only
    # changes on mouse clicks, so with several profile windows open it may point to a
    # different profile than the one under the cursor.
    canvas = event.canvas
    figure_entry = dict_figure_profile.get(canvas.figure.number)
    if figure_entry is None:
        # Figure not (or no longer) registered: nothing to update.
        return
    figure_dict = figure_entry[1]
    r_pressed_in_profile(event, figure_dict['profile_axes'], figure_dict['profile_df'], classes_attributes_dict,
                         figure_dict['profile_xaxis'], figure_dict['profile_points'],
                         figure_dict['profile_annotations'])
    center_data_in_profile_plot(figure_dict['profile_points'])
    canvas.draw_idle()

In [None]:
def on_close_figure(event):
    """
    Remove a closed profile figure's data from dict_figure_profile.

    :param event: matplotlib close event of the profile figure.
    :return: None
    """
    # Take the figure from the event. plt.gcf() returns pyplot's *current* figure, which is
    # not necessarily the one being closed, and it creates a new empty figure when there is
    # no current one.
    dict_figure_profile.pop(event.canvas.figure.number, None)

In [None]:
def key_pressed(event):
    """
    Dispatch key presses on the main cohort figure.

    Bindings:
      - "delete": remove the selected profiles from the cohort.
      - "r": set relative time at the class under the cursor.
      - "t": assign/remove time points on the class under the cursor.
      - "y" / "m": open a figure binning the cohort by year / month.
      - "p": open profile figures for the selected patients.
    The main figure is redrawn after every key press, bound or not.
    :param event: matplotlib key press event.
    :return: None
    """
    key = event.key
    bin_frequencies = {'y': 'year', 'm': 'month'}

    if key == 'delete':
        delete_pressed(dict_points, dict_selected_points, dict_selected_annotations, title)
    elif key == 'r':
        r_pressed(event, axes, df_cohort, classes_attributes_dict, dict_points, dict_annotations,
                  dict_selected_points, dict_selected_annotations, xaxis)
    elif key == 't':
        t_pressed(event, axes, df_cohort, classes_attributes_dict, dict_points, dict_annotations,
                  dict_selected_points, dict_selected_annotations, xaxis, dict_class_time_points,
                  legend_points, fig)
    elif key in bin_frequencies:
        bin_pressed(df_cohort, classes_attributes_dict, bin_frequencies[key], subplot_height, plot_width,
                    colors, dict_bin_points, dict_bin_annotations, button_pressed_in_simple_figure)
    elif key == 'p':
        p_pressed(df_cohort, classes_attributes_dict, subplot_height, plot_width, xaxis, colors,
                  dict_selected_points, dict_figure_profile, button_pressed_in_simple_figure,
                  key_pressed_in_profile, on_close_figure)

    fig.canvas.draw()

In [None]:
def on_close(event):
    """
    Write the interactive state of the main cohort figure back into df_cohort on close.

    :param event: matplotlib close event (not used).
    :return: None
    """
    # Persist changes made in the figure (e.g. deleted profiles) into the cohort dataframe.
    update_cohort(df_cohort, dict_points, dict_selected_points)

## Display Cohort

In [None]:
# Define figure parameters
subplot_height = 1  # passed to set_plot; presumably the height of each class subplot
plot_width = 10
# Class name -> (attributes to display, mode), where mode is 'compact' or 'complete'.
classes_attributes_dict = {
    'demographic': ([], 'compact'),
    'labevent': (['sodium_max', 'sodium_min', 'sodium_trend'], 'complete'),
    'nrad': (['attr_0', 'attr_1', 'attr_2'], 'compact'),
    'vp': ([], 'compact'),
}
# Alternative configuration (presumably for the meningioma data file):
# classes_attributes_dict = {'Demographic': ([], 'compact'), 'Therapy': ([], 'compact'), 'Pathology': ([], 'complete'), 'Event': ([], 'complete')}
# Evenly spaced shades of the 'Blues' colormap: one per class plus one extra.
n_colors = len(classes_attributes_dict) + 1
colors = mpl.colormaps['Blues'](np.linspace(0, 1, n_colors))

In [None]:
# Create and display figure
fig, axes = set_plot(subplot_height, plot_width, classes_attributes_dict, colors)
# Wire the interactive callbacks defined above to the main cohort figure.
for event_name, handler in (
    ('button_press_event', button_pressed),
    ('key_press_event', key_pressed),
    ('close_event', on_close),
):
    fig.canvas.mpl_connect(event_name, handler)
plt.xlabel("Date", fontsize=12)
display_cohort(axes, df_cohort, classes_attributes_dict, dict_points, dict_annotations, xaxis)
# Kept as a global because key_pressed passes it to delete_pressed.
title = fig.suptitle(f'MEDcohort composed by {len(set(df_cohort.index))} patients', fontsize=16)

## Generate output CSV files from time points

In [None]:
import os

# Write one CSV per time point, keeping only the columns that hold data for that time point
# and dropping the time-point column itself.
time_point_column = FIXED_COLUMNS[2]
output_dir = '../output'
# pandas' to_csv does not create missing directories, so make sure the folder exists.
os.makedirs(output_dir, exist_ok=True)
# groupby skips rows whose time point is NaN (like the previous dropna) and visits each
# time point once, instead of re-filtering the whole dataframe for every time point.
for time_point, df_time_point in df_cohort.groupby(time_point_column):
    df_time_point.dropna(axis=1, how='all').drop(time_point_column, axis=1).to_csv(
        os.path.join(output_dir, 'time_point_' + str(int(time_point)) + '.csv'))