# ***Notebook to interactively generate Kaplan-Meier plots***

**Notebook made by:** *Eduardo Reyes-Alvarez (Ph.D. candidate)*

**Affiliation:** *Dr. Lois Mulligan's lab, Queen's University.*

**Contact:** *eduardo_reyes09@hotmail.com*

**Date of latest version:** August 05, 2023.

## **Instructions**

***Summary:***

This notebook allows the user to easily and interactively create Kaplan-Meier probability curves (survival, progression, recurrence) by using the `KaplanMeierFitter` class from the lifelines library for Python.
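
For orientation, each curve is a Kaplan-Meier (product-limit) estimate: at every time where an event occurs, the running survival probability is multiplied by (1 - events / patients still at risk). The plain-Python sketch below only illustrates that calculation; the helper `km_estimate` and the toy data are hypothetical and not part of the notebook, which relies on lifelines for the actual fitting.

```python
from collections import Counter

def km_estimate(durations, events):
    """Return [(time, survival_probability)] at each time where an event occurred.

    durations: follow-up time per patient
    events:    1 if the event was observed, 0 if the patient was censored
    """
    deaths = Counter(t for t, e in zip(durations, events) if e == 1)
    leaving = Counter(durations)  # every patient leaves the risk set at their own time
    at_risk = len(durations)
    survival = 1.0
    curve = []
    for t in sorted(leaving):
        d = deaths.get(t, 0)
        if d:
            # Patients censored at time t still count as at risk at time t
            survival *= 1 - d / at_risk
            curve.append((t, survival))
        at_risk -= leaving[t]
    return curve

# Toy data: five patients, times in months (made up for illustration)
toy_durations = [5, 8, 8, 12, 20]
toy_events = [1, 0, 1, 1, 0]
toy_curve = km_estimate(toy_durations, toy_events)
for month, prob in toy_curve:
    print(f"{month:>3} months: S(t) = {prob:.3f}")
```

Censored patients (event = 0) never lower the curve themselves; they only shrink the number of patients at risk for later time points.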

***Note on the current version (V03):***

* This is the third working version, built on the METABRIC dataset available on the cBioPortal website (https://www.cbioportal.org/study/summary?id=brca_metabric).
* This version generates a KM plot using either the whole dataset or multiple subgroups defined by any column/variable contained in the clinical file or in the RNA Sequencing file from the same study.
* This version has not yet been tested with other datasets from cBioPortal or elsewhere, nor with files other than those mentioned above.
* Future releases will address the following: fixing some interaction issues; generalizing the pre-processing of files to allow the use of other studies from cBioPortal (first) and other databases (second); and making a version of this notebook for Google Colab to run online.

***How to use this notebook:***

1. **Getting the files:** Go to the cBioPortal link above and download the whole METABRIC study (click on the arrow pointing down, on the right side of the study name). This will download a *.tar.gz* compressed file, from which you can extract the files ***data_clinical_patient.txt*** and any of the ***data_mrna_illumina...*** files. You can extract them anywhere (desktop or downloads), but preferably into the same folder as the notebook (if not, don't worry, the notebook will ask you to upload the file(s)).
2. ***Rename*** the clinical file to ***clinical.txt*** and the mRNA one to ***RNA.txt***.
3. ***Open the notebook*** with JupyterLab, select "Run" in the top-left menu, then "Run all cells" and wait (~1min). Due to widget layout and compatibility issues, it is best to use JupyterLab.
4. ***Look under*** the code block on "Start plotting here!".
5. ***Plot!***: Start by selecting a time variable (the METABRIC study has only Overall Survival) and the column with the event to observe, then fill the 0 and 1 widgets with the event labels (0 = No event/Living, 1 = Event/Died). After that, you can plot the whole dataset by clicking the green button at the bottom, or select "Use variable(s)" to explore different variables and make subgroups. Once subgroups are defined by filling the widgets with tags or ranges for each variable, click the green button to plot the survival probabilities of each subgroup. Please avoid overlapping ranges, even if the numbers are integers; use 0.00-**30.00**, **30.01**-60.00 instead.
6. ***Save!***: Once you have generated a valid KM plot, the button below the green button will be enabled, so you can save a jpg with the plot and an Excel file with the relevant information in case you wish to recreate the plot in other software.
   
**NOTES:** This notebook automatically generates a ***MyLog.txt*** file in the working directory to help with troubleshooting (the file is overwritten every time the notebook is re-run, so move or rename it if you want to keep it). Currently, the "Use variable(s)" button works well the first time, but if you plot once and then increase or decrease the number of variables you may run into issues; in that case, click the "No" button first and then the "Use variable(s)" button again (this issue will be addressed in the next version). Also, if you do not need the RNA Seq data, the notebook works with the clinical file alone.
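
The advice in step 5 about non-overlapping ranges can be checked before plotting. The sketch below is only an illustration: the helpers `parse_range` and `find_overlaps` are hypothetical (they are not defined in the notebook), and it assumes that both ends of a range are inclusive and that bounds are non-negative numbers separated by a single hyphen.

```python
def parse_range(text):
    """Turn a string such as '0.00-30.00' into a (low, high) tuple of floats.

    Assumes non-negative bounds separated by a single hyphen.
    """
    low, high = (float(part) for part in text.split("-"))
    if low > high:
        raise ValueError(f"Lower bound exceeds upper bound in {text!r}")
    return low, high

def find_overlaps(range_texts):
    """Return pairs of neighbouring ranges (after sorting) that share at least one value."""
    ranges = sorted(parse_range(text) for text in range_texts)
    return [(a, b) for a, b in zip(ranges, ranges[1:]) if b[0] <= a[1]]

print(find_overlaps(["0.00-30.00", "30.00-60.00"]))  # the two ranges share 30.00
print(find_overlaps(["0.00-30.00", "30.01-60.00"]))  # no shared values
```

An empty list means the ranges can be used as subgroups without any patient falling into two of them.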

## **Code**

### Install and import required libraries

Some of the required libraries for this notebook may need to be installed before we can import them.

In [None]:
# General libraries for data handling
import os
!pip install numpy
!pip install pandas
# collections is part of the Python standard library, so it does not need to be installed
!pip install openpyxl
import numpy as np
import pandas as pd
from collections import OrderedDict
import openpyxl 

# To edit and handle the output of each code cell
from IPython.display import display, clear_output, HTML

# To make plots and figures
!pip install matplotlib
import matplotlib.pyplot as plt

# To do KM survival analysis
!pip install lifelines
from lifelines import KaplanMeierFitter

# To generate interactive widgets
!pip install ipywidgets
import ipywidgets as widgets
from ipywidgets import Output
from ipywidgets import interact, interactive, fixed, interact_manual, Dropdown, HBox, VBox, Layout, Label

# To log relevant information during the runtime
import logging

# Clear the output of the cell
clear_output()

The following steps are to set up the logging file, so we can register some relevant information in case we need to troubleshoot. After each session with the notebook, a file called ***MyLog.txt*** will be created in the same directory where the notebook is stored. It is recommended that the user renames this file if a problem arises because it gets rewritten every time the notebook is re-run. 

In [None]:
# Configure the logging settings
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Create a logger
logger = logging.getLogger()

# Clear the existing log file or create a new empty one
open('MyLog.txt', 'w').close()

# Create a file handler
file_handler = logging.FileHandler('MyLog.txt')
file_handler.setLevel(logging.INFO)

# Create a formatter and add it to the file handler
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)

# Add the file handler to the logger
logger.addHandler(file_handler)

# Log an initial message
logger.info(f"Log file created or cleared. \n")
clear_output()
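
As a possible simplification of the cell above, `logging.FileHandler` accepts `mode="w"`, which truncates an existing file when the handler opens it, so the separate `open(...).close()` step is not strictly needed. The sketch below is an assumption-laden illustration: it writes to a temporary folder and uses a separately named logger (`km_sketch`, an arbitrary name) so that it does not interfere with the notebook's own log.

```python
import logging
import os
import tempfile

# Write to a temporary folder so this sketch never touches a real MyLog.txt
log_path = os.path.join(tempfile.mkdtemp(), "MyLog.txt")

# A named logger (the name "km_sketch" is arbitrary) keeps the root logger untouched
sketch_logger = logging.getLogger("km_sketch")
sketch_logger.setLevel(logging.INFO)

# mode="w" truncates an existing file when the handler opens it
sketch_handler = logging.FileHandler(log_path, mode="w")
sketch_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
sketch_logger.addHandler(sketch_handler)

sketch_logger.info("Log file created or cleared.")
sketch_handler.close()  # flush and release the file

with open(log_path) as log_file:
    log_text = log_file.read()
print(log_text)
```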

### Upload datasets

Request the user to upload:

**1.** File with clinical data (named: ***clinical.txt***).

**2.** -Optional- file with RNA Seq data (named: ***RNA.txt***).

**NOTE:**  If the kernel is restarted and the file(s) already exist in the working directory (where the notebook is stored), the notebook will locate the files automatically and won't request their upload (to save time).
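
The file check described in the note can be sketched in isolation as follows. The helper `locate_inputs` is hypothetical (the notebook performs the check inline with `os.listdir()`), and the demonstration runs in a throwaway folder so that no real files are touched.

```python
import tempfile
from pathlib import Path

def locate_inputs(folder):
    """Report which of the two expected input files exist in `folder`."""
    folder = Path(folder)
    return {name: (folder / name).is_file() for name in ("clinical.txt", "RNA.txt")}

# Demonstration in a throwaway folder that only contains the clinical file
with tempfile.TemporaryDirectory() as demo_dir:
    (Path(demo_dir) / "clinical.txt").write_text("PATIENT_ID\tOS_MONTHS\n")
    found = locate_inputs(demo_dir)

print(found)
```

Here the clinical file is reported as present and the RNA file as missing, which corresponds to the case where the notebook continues with the clinical data alone.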

In [None]:
# Function that checks if the files required are in the directory or prompts the user to upload them
def upload_input_files():

    # Check if the clinical file is already in the working directory
    if "clinical.txt" in os.listdir():
        # Read the clinical file and log it
        df_clinical = pd.read_csv("clinical.txt", sep="\t", comment="#")
        logger.info(f"Uploaded: clinical.txt \n")
        clear_output()
        
        # Check if there is also an RNA file in the directory
        if "RNA.txt" in os.listdir():
            # Read the RNA file
            df_RNA = pd.read_csv("RNA.txt", sep="\t")
            logger.info(f"Uploaded: RNA.txt \n")
            clear_output()
        else:
            df_RNA = None           
    else:
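        # Ask the user to select the files to upload
        # NOTE: files.upload() comes from google.colab, so this branch only works on Google Colab;
        # in JupyterLab, place the file(s) in the notebook folder instead
        from google.colab import files
        print("Select the file(s) to upload (all at once):")
        uploaded_files = files.upload()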

        # Read the clinical file
        df_clinical = pd.read_csv("clinical.txt", sep="\t", comment="#")
        logger.info(f"Uploaded: clinical.txt \n")

        # Check if an RNA file was uploaded, and read it if so
        if "RNA.txt" in uploaded_files:
            df_RNA = pd.read_csv("RNA.txt", sep="\t")
            logger.info(f"Uploaded: RNA.txt \n")
        else:
            df_RNA = None
        
    clear_output()

    # Return the dataframes for the next steps
    return df_clinical, df_RNA

### Pre-process uploaded files

This function currently processes a clinical file and an RNA file (optional) in the following way:

* Some columns of the clinical dataset are searched and re-ordered based on relevant information that could be used for a KM plot, including: ***VITAL_STATUS, OS_STATUS + OS_MONTHS, PFS_STATUS + PFS_MONTHS, RFS_STATUS + RFS_MONTHS***. These columns appear in the METABRIC clinical dataset and likely have similar names in other studies (otherwise adjustments may be needed).

* The METABRIC RNA dataset has all the measured genes as rows and patient IDs as columns, so the df is transposed and sorted on both axes so that it can later be merged with the clinical df. Before that, an unnecessary column with the gene ID number (Entrez_Gene_Id) is removed.
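
The transposition described in the second bullet can be tried on a toy table first. This is a minimal sketch under stated assumptions: the gene names, patient IDs and values are made up, and only the column names `Hugo_Symbol`, `Entrez_Gene_Id` and `PATIENT_ID` are taken from the notebook.

```python
import pandas as pd

# Toy stand-in for an RNA file: genes as rows, patient IDs as columns
# (gene names, IDs and values are made up for illustration)
toy_rna = pd.DataFrame({
    "Hugo_Symbol": ["TP53", "BRCA1"],
    "Entrez_Gene_Id": [7157, 672],
    "MB-0002": [6.1, 5.4],
    "MB-0000": [5.8, 5.9],
})

# Drop the numeric gene ID, then transpose so that patients become rows
toy_rna = toy_rna.drop(columns="Entrez_Gene_Id").set_index("Hugo_Symbol").T

# Sort genes (columns) and patients (rows) alphabetically
toy_rna = toy_rna.sort_index(axis=1).sort_index(axis=0)

# Expose the patient IDs as a regular PATIENT_ID column
toy_rna.index.name = "PATIENT_ID"
toy_rna.columns.name = None
toy_rna = toy_rna.reset_index()

print(toy_rna)
```

After this step each row is one patient, so the table has the same orientation as the clinical data and shares the `PATIENT_ID` key with it.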

In [None]:
# This function searches OS/PFS/RFS _STATUS and _MONTHS columns in the clinical data
# This function transposes the RNA dataset to have gene names as columns and patient IDs as rows (like the clinical data)
def file_preprocessing(df_clinical, df_RNA):

    ################### Processing for the clinical dataframe ##################

    # Log the original dataframe
    logger.info(f"Preview of the original clinical dataset: \n {df_clinical.iloc[:15, :10].to_string()} \n")
    logger.info(f"Data types of columns in the original clinical dataset: \n {df_clinical.dtypes.to_string()} \n\n")
    clear_output()

    # Prepare the variable with the reordered column names
    clinical_columns_main = ["PATIENT_ID"]
    time_to_event = []
    event_observation = []

    # Search for a column of Vital status (this is optional and may provide more information)
    if "VITAL_STATUS" in df_clinical.columns:
        clinical_columns_main.append("VITAL_STATUS")
        event_observation.append("VITAL_STATUS")

    # Search for Overall Survival columns
    if "OS_MONTHS" in df_clinical.columns and "OS_STATUS" in df_clinical.columns:
        clinical_columns_main.append("OS_MONTHS")
        time_to_event.append("OS_MONTHS")
        clinical_columns_main.append("OS_STATUS")
        event_observation.append("OS_STATUS")

    # Search for Recurrence-Free Survival columns
    if "RFS_MONTHS" in df_clinical.columns and "RFS_STATUS" in df_clinical.columns:
        clinical_columns_main.append("RFS_MONTHS")
        time_to_event.append("RFS_MONTHS")
        clinical_columns_main.append("RFS_STATUS")
        event_observation.append("RFS_STATUS")

    # Search for Progression-Free Survival columns
    if "PFS_MONTHS" in df_clinical.columns and "PFS_STATUS" in df_clinical.columns:
        clinical_columns_main.append("PFS_MONTHS")
        time_to_event.append("PFS_MONTHS")
        clinical_columns_main.append("PFS_STATUS")
        event_observation.append("PFS_STATUS")

    # Order alphabetically the remaining columns
    clinical_columns_extra = [col for col in df_clinical.columns if col not in clinical_columns_main]
    clinical_columns_extra.sort()

    # Apply the re-ordering to the df
    clinical_columns_ordered = clinical_columns_main + clinical_columns_extra
    df_clinical = df_clinical[clinical_columns_ordered] 

    # Log the re-arranged dataframe
    logger.info(f"Preview of the pre-processed clinical dataset: \n {df_clinical.iloc[:15, :10].to_string()} \n")
    logger.info(f"Data types of columns in the pre-processed clinical dataset: \n {df_clinical.dtypes.to_string()} \n\n")
    clear_output()

    ##################### Processing for the RNA dataframe #####################

    # If an RNA file was uploaded, then the df is not empty
    if df_RNA is not None:
        # Log the original dataframe
        logger.info(f"Preview of the original RNA dataset: \n {df_RNA.iloc[:15, :10].to_string()} \n")
        logger.info(f"Data types of some columns in the original RNA dataset: \n {df_RNA.iloc[:, :10].dtypes.to_string()} \n\n")
        clear_output()
        
        # Drop the "Entrez_Gene_Id" column
        df_RNA.drop("Entrez_Gene_Id", axis=1, inplace=True)

        # Rename the "Hugo_Symbol" column to "PATIENT_ID" as it appears in the clinical df
        df_RNA.rename(columns={"Hugo_Symbol": "PATIENT_ID"}, inplace=True)

        # Transpose the dataframe, making the content of the "PATIENT_ID" column the new column names
        df_RNA = df_RNA.set_index("PATIENT_ID").T

        # Sort the gene names alphabetically
        df_RNA.sort_index(axis=1, inplace=True)

        # Reset the index to a numerical index
        df_RNA = df_RNA.reset_index().rename_axis("", axis="columns")

        # Rename the "index" column to "PATIENT_ID"
        df_RNA.rename(columns={"index": "PATIENT_ID"}, inplace=True)

        # Sort Patient IDs and reset the index
        df_RNA = df_RNA.sort_values("PATIENT_ID").reset_index(drop=True)

        # Log the re-arranged dataframe
        logger.info(f"Preview of the pre-processed RNA dataset: \n {df_RNA.iloc[:15, :10].to_string()} \n")
        logger.info(f"Data types of some columns in the pre-processed RNA dataset: \n {df_RNA.iloc[:, :10].dtypes.to_string()} \n\n")
        clear_output()

    ############################################################################

    # Return the variables with the survival/progression/recurrence labels available in the dataset
    return df_clinical, df_RNA, time_to_event, event_observation

### Create interactive widgets

The general way to create interactive widgets is the following:

**1.** Create the ***widget***.

**2.** Create a ***widget output*** object (can be initially empty).

**3.** Declare a ***handler function*** that is called when changes in the widget are made (here things can be shown in the widget output object).

**4.** Call the ***observe*** method on the widget, passing its handler function.

**5.** ***Display*** the initial widget and initial state of the widget output.

* Additionally, we can ***nest more subwidgets*** that appear when certain values/options in the main widget are selected, and ***each can have its own handler function and output object***. This is ideal for fine-tuning complex interactions; however, the code can get very confusing when handler functions and observe/display/output calls are nested inside others at multiple levels.

* An alternative approach, used in this notebook, is to first declare all the widget handler functions in the order they are called (with the exception of 1). Next, declare a general ***widget_preparation*** function that creates the widgets and their output objects, calls observe on each widget, and then displays all of them at once in a specific configuration. For this to work, we had to set several variables as global and pass others through lambda functions when making the observe calls (since observe calls are not regular function calls with input/output parameters).

#### Time to Event handler

Once one of the options of the time to event dropdown is selected (***OS_MONTHS, PFS_MONTHS, RFS_MONTHS***), this function ***creates and displays a histogram*** so we can have a quick look at the information for all the patients.

In [None]:
# Function to display the output of time_to_event_dropdown (histogram)
def time_to_event_selection_handler(change):
    
    # Clear the output and return early if the default option is selected back 
    if change["new"] == "Click here to select...":
        with time_to_event_selection_info:
            time_to_event_selection_info.clear_output()
        return

    # Check if the name selected is a column on the clinical dataframe
    column_name = change["new"]
    if column_name in df_clinical.columns:
        time_column = df_clinical[column_name].dropna()
        logger.info(f"The user selected: {column_name}     Widget: time_to_event_dropdown. \n")
        logger.info(f"Original dtype of {column_name}: {df_clinical[column_name].dtype}     Dtype once removing NANs: {time_column.dtype} \n")
        
        # Make histogram of values and handle exceptions
        if time_column.dtype != "object":
            plt.figure(figsize=(4, 2))
            plt.hist(time_column, bins=12, color="darkgoldenrod", ec="white")
            plt.xlabel(column_name)
            plt.ylabel('Number of patients')
            plt.title(f'Histogram of {column_name}')
            logger.info(f"A histogram was successfully made and displayed for: {column_name} \n")
        else:
            info_str = "Warning: Column type is not numeric."
            logger.warning("User attention required: The time to event column may not be numerical. \n")
    else:
        info_str = "Warning: Column not found in the dataframe."
        logger.error("User attention required: The time to event column name was not found in the df. \n")

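    # Clear the output widget and display the additional information
    with time_to_event_selection_info:
        time_to_event_selection_info.clear_output()
        # Show the histogram only if the column was found and is numeric (avoids a NameError
        # on time_column when the column is missing); otherwise print the warning
        if column_name in df_clinical.columns and time_column.dtype != "object":
            plt.show()
        else:
            print(info_str)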

#### Event Observation handler

Once a selection has been made in the event observation dropdown (***VITAL_STATUS, OS_STATUS, PFS_STATUS, RFS_STATUS***), this function ***creates and displays a bar chart*** of the unique values found in the selected column. In addition, this function shows ***two other widgets*** (tags input) that let the user assign which unique values should be ***labeled as 0 (no event was observed) and 1 (an event was observed)***. These selections are required and are applied before passing the data to the KaplanMeierFitter objects.
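
The effect of the 0/1 labels can be sketched on a toy column. This is an illustration only: the status labels, the `STATUS`/`EVENT` column names and the tag choices are made up, and it uses `Series.map` as a compact alternative to the replace-then-filter steps applied later in the notebook.

```python
import pandas as pd

# Toy event column; the labels are illustrative, not taken from a real file
toy = pd.DataFrame({
    "PATIENT_ID": ["P1", "P2", "P3", "P4"],
    "STATUS": ["Living", "Died of Disease", "Died of Other Causes", None],
})

# Hypothetical user choices for the 0 and 1 tag widgets
zero_tags = ["Living"]
one_tags = ["Died of Disease"]

# Map every chosen label to its code; anything left unassigned becomes NaN
codes = {tag: 0 for tag in zero_tags}
codes.update({tag: 1 for tag in one_tags})
toy["EVENT"] = toy["STATUS"].map(codes)

# Keep only rows with an assigned label and store the codes as integers
encoded = toy.dropna(subset=["EVENT"]).astype({"EVENT": int})
print(encoded)
```

Rows whose label was assigned to neither widget (here P3 and P4) are left out of the analysis, which is why both widgets need to be filled before plotting.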

In [None]:
# Function to display the output of event_observation_dropdown (bar chart)
def event_observation_selection_handler(change):

    # Clear the output and return early if the default option is selected back 
    if change["new"] == "Click here to select...":
        event_observation_selection_info.clear_output()
        event_observation_tagsinput_info.clear_output()
        return
        
    # Check if the name selected is a column on the clinical dataframe
    column_name = change["new"]
    if column_name in df_clinical.columns:
        event_column = df_clinical[column_name]
        logger.info(f"The user selected: {column_name}     Widget: event_observation_dropdown. \n")
        logger.info(f"Dtype of {column_name}: {df_clinical[column_name].dtype}     Unique value counts: \n\t {event_column.value_counts(dropna=False).to_string()} \n")
        
        # Make a bar chart for unique values and handle exceptions
        if event_column.dtype == "object":
            value_counts = event_column.value_counts(dropna=False)
            
            if df_clinical[column_name].nunique() > 15:
                logger.warning("User attention required: There may be something wrong with the event observation column as there are more than 15 unique values. \n")
            
            plt.figure(figsize=(4, 2))
            value_counts.plot(kind='bar', color=['maroon', 'purple', 'green', 'crimson', 'navy', 'coral', 'lavender'])
            plt.xlabel(column_name)
            plt.xticks(rotation=45)
            plt.ylabel('Number of patients')
            plt.title(f'Bar Chart of {column_name}')
            logger.info(f"A bar chart was successfully made and displayed for: {column_name} \n")
        else:
            info_str = "Warning: Column type is not categorical."
            logger.warning("User attention required: The event observation column may not be text-based. \n")
    else:
        info_str = "Warning: Column not found in the dataframe."
        logger.error("User attention required: The event observation column name was not found in the df. \n")

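    # Clear the output widget and display the additional information
    with event_observation_selection_info:
        event_observation_selection_info.clear_output()
        # Show the bar chart only if the column was found and is text-based (avoids a NameError
        # on event_column when the column is missing); otherwise print the warning
        if column_name in df_clinical.columns and event_column.dtype == "object":
            plt.show()
        else:
            print(info_str)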
    
    # Following selection and plotting, get the unique values from the current column selection
    event_values = np.ndarray.tolist(df_clinical[event_observation_dropdown.value].unique())

    # Make subwidgets to specify the event to be observed so we can encode it in binary for kmf.fit
    global event_observation_tagsinput0, event_observation_tagsinput1
    event_observation_tagsinput0 = widgets.TagsInput(allowed_tags=event_values)
    event_observation_tagsinput1 = widgets.TagsInput(allowed_tags=event_values)
    event_observation_tagsinput = [widgets.HTML("<br/>"), HBox([widgets.Label("0:"), event_observation_tagsinput0]),
                                   widgets.HTML("<br/>"), HBox([widgets.Label("1:"), event_observation_tagsinput1])]
    event_observation_tagsinput_box = VBox(event_observation_tagsinput)
    
    with event_observation_tagsinput_info:
        event_observation_tagsinput_info.clear_output()
        display(event_observation_tagsinput_box)

#### Subgroup Options handler

When the button default value (***No***) is changed to ***Use variable(s)***, this function ***displays a slider*** that sets how many variables to use, with one widget area shown per variable. When the default value is selected again, the ***outputs*** produced by the other button are all ***cleared***.

In [None]:
# Function to display the output of subgroup_buttons (slider and widget areas based on the slider)
def subgroup_selection_handler(change):

    # The slider needs to be dynamically set/reset
    global variable_number_slider, variable_repeats
    
    # If the user wants to make subgroups request a dataset, variable and number of groups (per variable)
    if change['new'] == 'Use variable(s)':
        logger.info(f"The user selected: Use variable(s)     Widget: subgroup_buttons \n")
        
        # Show and observe changes in the multiple variable slider and add/remove widget areas as needed
        variable_number_slider = widgets.IntSlider(min=1, max=5, value=1)
        variable_number_slider.observe(variable_number_selection_handler, 'value')
        display(HBox([widgets.HTML('\u2003' * 4), widgets.Label("Make subgroups?:"), subgroup_buttons, widgets.HTML('\u2003' * 4), widgets.Label("Number of variables:"), variable_number_slider]), display_id='variable_number_slider')

        # Set the initial/minimum number of variable repeats and call the handler function (because 1 is the default)
        variable_repeats = 1
        variable_number_selection_handler({"new": variable_repeats})
        
    else:
        logger.info(f"The user selected: No     Widget: subgroup_buttons \n")
        
        # Clear the widget area(s) if the other button was selected previously
        for i in reversed(range(5)):
            subgroup_output_areas[i]['subgroup_maker1_info'].clear_output()
            subgroup_output_areas[i]['subgroup_maker2_info'].clear_output()
            subgroup_output_areas[i]['subgroup_options_selection_info'].clear_output()

        # We do not need to show the multiple variable slider when No is selected
        display(HBox([widgets.HTML('\u2003' * 4), widgets.Label("Make subgroups?:"), subgroup_buttons]), display_id='variable_number_slider')

#### Variable Number handler

This function is first triggered manually with the default number of variables (***1***). Since the maximum number of widget areas and output objects (***5***) already exist, this function ***calls the observe method and displays only the widget areas that are missing*** for the newly selected number. In addition, if a lower number than the previous one is selected (for example, going from 3 variables to 2), this function ***removes the extra widget areas*** from the display (they are not deleted, just no longer shown).
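
The add/remove logic comes down to two index ranges. The sketch below isolates it; the helper `areas_to_update` is hypothetical and only mirrors the `range(...)` expressions used in the handler.

```python
def areas_to_update(previous, new):
    """Return (areas_to_show, areas_to_clear) as lists of zero-based area indices."""
    to_show = list(range(previous, new))               # empty unless new > previous
    to_clear = list(range(previous - 1, new - 1, -1))  # empty unless new < previous
    return to_show, to_clear

print(areas_to_update(1, 3))  # going from 1 to 3 variables
print(areas_to_update(3, 2))  # going from 3 to 2 variables
```

Going from 1 to 3 variables shows areas 1 and 2 (area 0 is already visible), while going from 3 to 2 clears only area 2.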

In [None]:
# Function to display the output of variable_number_slider (widget areas)
def variable_number_selection_handler(change):
    
    logger.info(f"Number of variables to make subgroups selected: {change['new']} \n")
    global variable_repeats

    # Add more widget areas if the user increased the number of variables desired
    if change['new'] > variable_repeats:
    
        # The number of areas to add is the new number minus the previous (so the missing are added instead of creating all again)
        for repeat in range(variable_repeats, change['new']):
            
            # Reset and initialize the first widget in the widget area for each new variable
            subgroup_widget_areas[repeat]['dataset_dropdown'].value='Click here to select...'
            subgroup_widget_areas[repeat]['dataset_dropdown'].observe(lambda value, repeat=repeat: dataset_selection_handler(value, repeat), 'value')

            # Clear the outputs of the new widget area and show the initial widget
            subgroup_output_areas[repeat]['subgroup_maker1_info'].clear_output()
            subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
            with subgroup_output_areas[repeat]['subgroup_options_selection_info']:
                subgroup_output_areas[repeat]['subgroup_options_selection_info'].clear_output()
                display(widgets.HTML("<hr>"))
                display(subgroup_widget_areas[repeat]['dataset_dropdown'])
   
    # Remove widget areas if the user decreased the number of variables desired
    elif change["new"] < variable_repeats:
        
        # Clear out any previous outputs in the widget areas to remove
        for repeat in range(variable_repeats - 1, change['new'] - 1, -1):
            subgroup_output_areas[repeat]['subgroup_maker1_info'].clear_output()
            subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
            subgroup_output_areas[repeat]['subgroup_options_selection_info'].clear_output()

    # Show the minimum widget repeats number (1) when the button is selected for the first time
    else:
        # Reset and initialize the first widget in the widget area
        subgroup_widget_areas[0]['dataset_dropdown'].value='Click here to select...'
        subgroup_widget_areas[0]['dataset_dropdown'].observe(lambda value, repeat=0: dataset_selection_handler(value, repeat), 'value')

        # Clear the outputs of the new widget area and show the initial widget
        subgroup_output_areas[0]['subgroup_maker1_info'].clear_output()
        subgroup_output_areas[0]['subgroup_maker2_info'].clear_output()
        with subgroup_output_areas[0]['subgroup_options_selection_info']:
            subgroup_output_areas[0]['subgroup_options_selection_info'].clear_output()
            display(widgets.HTML("<hr>"))
            display(subgroup_widget_areas[0]['dataset_dropdown'])
        
    # Get the current number of widget areas to create
    variable_repeats = change["new"]        

#### Dataset handler

Once a widget area is created, a dropdown is shown to select the dataset (***clinical or RNA***) where the variable of interest is. Once a selection is made, this function ***shows two other widgets*** to the right of the original, one is a dropdown or combobox to search for the variable/column of interest, and the other is a slider with the number of subgroups desired.

**NOTE:** The observe call for this function uses a lambda function to also pass the repeat number, which indicates where exactly the two subwidgets need to be shown (area 1-5). This is needed because the same function handles all 5 possible areas.
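
The `repeat=repeat` default argument in those lambdas matters whenever callbacks are created in a loop. The plain-Python sketch below illustrates the general behaviour (the function names are hypothetical and no widgets are involved): without the default argument, every callback sees the last value of the loop variable.

```python
def make_callbacks_late():
    # Each lambda looks up `repeat` when it is called, after the loop has finished
    return [lambda value: (value, repeat) for repeat in range(3)]

def make_callbacks_bound():
    # The default argument stores the current `repeat` when the lambda is created
    return [lambda value, repeat=repeat: (value, repeat) for repeat in range(3)]

late = [callback("x") for callback in make_callbacks_late()]
bound = [callback("x") for callback in make_callbacks_bound()]
print(late)   # every callback reports the same area
print(bound)  # each callback reports its own area
```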

In [None]:
# Function to display the output of dataset_dropdown (two subwidgets) 
def dataset_selection_handler(change, repeat):

    # Reset the subgroup_number_slider
    subgroup_widget_areas[repeat]['subgroup_number_slider'].value = 1
    subgroup_widget_areas[repeat]['variables_dropdown'].value = 'Click here to select...'
    if df_RNA is not None:
        subgroup_widget_areas[repeat]['variables_combobox'].value = ''
        logger.info(f"There are RNA and clinical dataframes available to make subgroups. \n")

    # Clear the previous outputs
    subgroup_output_areas[repeat]['subgroup_maker1_info'].clear_output()
    subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
    subgroup_output_areas[repeat]['subgroup_options_selection_info'].clear_output()
    
    # Display a dropdown and slider for clinical variables (excluding PATIENT_ID column)
    if change["new"] == "clinical":
        logger.info(f"The {change['new']} dataset was selected to make subgroups. \n")
        
        with subgroup_output_areas[repeat]['subgroup_options_selection_info']:
            subgroup_widget_areas[repeat]['variables_dropdown'].observe(lambda value, repeat=repeat: variables_selection_handler(value, repeat), 'value')
            subgroup_widget_areas[repeat]['subgroup_number_slider'].observe(lambda value, repeat=repeat: group_number_selection_handler(value, repeat), 'value')
            display(widgets.HTML("<hr>"))
            display(HBox([subgroup_widget_areas[repeat]['dataset_dropdown'], subgroup_widget_areas[repeat]['variables_dropdown'], widgets.HTML('\u2003' * 3), subgroup_widget_areas[repeat]['subgroup_number_slider']]))
    
    # Display a combobox and slider for RNA variables (excluding PATIENT_ID column)
    elif change["new"] == "RNA":
        logger.info(f"The {change['new']} dataset was selected to make subgroups. \n")
        
        with subgroup_output_areas[repeat]['subgroup_options_selection_info']:
            subgroup_widget_areas[repeat]['variables_combobox'].observe(lambda value, repeat=repeat: variables_selection_handler(value, repeat), 'value')
            subgroup_widget_areas[repeat]['subgroup_number_slider'].observe(lambda value, repeat=repeat: group_number_selection_handler(value, repeat), 'value')
            display(widgets.HTML("<hr>"))
            display(HBox([subgroup_widget_areas[repeat]['dataset_dropdown'], subgroup_widget_areas[repeat]['variables_combobox'], widgets.HTML('\u2003' * 3), subgroup_widget_areas[repeat]['subgroup_number_slider']]))
    
    # Clear the output and display the original dropdown if the default option is selected back again
    else:
        logger.info(f"The previous dataset to make subgroups was de-selected. \n")

        with subgroup_output_areas[repeat]['subgroup_options_selection_info']:
            display(widgets.HTML("<hr>"))
            display(subgroup_widget_areas[repeat]['dataset_dropdown'])

#### Variable Selection handler

Once a dataset selection triggers the appearance of two other widgets, this function is called when a selection is made in the first of those. This function ***creates and displays a bar chart*** when the variable selected (column name) contains ***text*** data, or a ***histogram*** when the variable contains ***numbers***. Depending on the dataset selected, this function handles two possible widgets, either a dropdown (clinical dataset) or a combobox (RNA dataset).

**NOTE:** The observe call for this function uses a lambda function to also pass the repeat number, which indicates where exactly the plots need to be shown (area 1-5). This is needed because the same function handles all 5 possible areas.

In [None]:
# Function to display the output of variables_dropdown and variables_combobox (plots)
# Reminder that this function has to work for both df_clinical and df_RNA
def variables_selection_handler(change, repeat):
    
    # Reset the subgroup_number_slider to remove any previous widget boxes
    subgroup_widget_areas[repeat]['subgroup_number_slider'].value = 1
    
    # Display empty space if the default value on the dropdown or combobox is selected
    if change["new"] == 'Click here to select...' or change["new"] == '':
        logger.info(f"The previous variable to make subgroups was de-selected. \n")
        subgroup_output_areas[repeat]['subgroup_maker1_info'].clear_output()
        subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
        return
    # If the user has selected a variable, do 4 steps: Apply event labels, find the column, plot, capture all columns of interest
    else:
        ##### Preparation part
        # Make these variables global so they can be accessed from other subwidgets
        global column_data, KM_data_1var, KM_data_all
        KM_data_1var = df_clinical.copy()
        
        # This is to avoid plotting while incomplete strings are being searched in the combobox
        if subgroup_widget_areas[repeat]['variables_combobox'] is not None and subgroup_widget_areas[repeat]['dataset_dropdown'].value == "RNA": 
            if change["new"] not in subgroup_widget_areas[repeat]['variables_combobox'].options:
                subgroup_output_areas[repeat]['subgroup_maker1_info'].clear_output()
                subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
                return

        ##### Step 01 
        # Apply the 0 and 1 labels to the event observed column and filter those values 
        if event_observation_tagsinput0.value and event_observation_tagsinput1.value:
            for tag in event_observation_tagsinput0.value:
                KM_data_1var[event_observation_dropdown.value] = KM_data_1var[event_observation_dropdown.value].replace(tag, "0")
            for tag in event_observation_tagsinput1.value:
                KM_data_1var[event_observation_dropdown.value] = KM_data_1var[event_observation_dropdown.value].replace(tag, "1")

            # Filter the selected 0/1 event labels, transform column to integers and keep only the common first 3 columns 
            KM_data_1var = KM_data_1var.loc[KM_data_1var[event_observation_dropdown.value].isin(["0", "1"])]
            KM_data_1var[event_observation_dropdown.value] = KM_data_1var[event_observation_dropdown.value].astype(int)

        # Log the current status of KM_data_1var
        logger.info(f"[Subgrouping 1st step] The user selected to label -{str(event_observation_tagsinput0.value)}- as 0, and -{str(event_observation_tagsinput1.value)}- as 1. \n")
        logger.info(f"[Subgrouping 1st step] Apply 0/1 labels to column {event_observation_dropdown.value} on KM_data_1var: \n {KM_data_1var.iloc[:15, :10].to_string()} \n")
        logger.info(f"[Subgrouping 1st step] Data types of KM_data_1var columns: \n {KM_data_1var.dtypes.to_string()} \n\n")

        ##### Step 02
        # Look for the selected column in either df, as it is not specified within this function
        if change["new"] in df_clinical.columns:
            # Keep only the working columns, log it and extract the column to plot the values
            KM_data_1var = KM_data_1var[['PATIENT_ID', time_to_event_dropdown.value, event_observation_dropdown.value, change["new"]]]          
            logger.info(f"[Subgrouping 2nd step] The column {change['new']} -{KM_data_1var.dtypes[change['new']]} dtype- from df_clinical was selected to make subgroups. \n")
            column_data[repeat] = KM_data_1var[change["new"]].copy()
            
        # If the column is in df_RNA, joining is needed to combine it with the clinical columns
        elif df_RNA is not None and change["new"] in df_RNA.columns:
            # Keep only the working columns from both dfs, log it and extract the column to plot the values
            KM_data_1var = KM_data_1var[['PATIENT_ID', time_to_event_dropdown.value, event_observation_dropdown.value]]
            df_RNA2 = df_RNA[['PATIENT_ID', change["new"]]]
            KM_data_1var = KM_data_1var.merge(df_RNA2, on='PATIENT_ID', how='inner')
            logger.info(f"[Subgrouping 2nd step] The column {change['new']} -{KM_data_1var.dtypes[change['new']]} dtype- from df_RNA was selected to make subgroups. \n")
            column_data[repeat] = KM_data_1var[change["new"]].copy()

        # Log the current status of KM_data_1var 
        logger.info(f"[Subgrouping 2nd step] Keep relevant columns of KM_data_1var and only rows with 0/1 event labels: \n {KM_data_1var.iloc[:15, :10].to_string()} \n")
        logger.info(f"[Subgrouping 2nd step] Data types of KM_data_1var columns: \n {KM_data_1var.dtypes.to_string()} \n\n")

        ##### Step 03
        # Make and display a bar chart for text columns showing the counts for unique values
        if column_data[repeat].dtype == 'object':
            value_counts = column_data[repeat].value_counts(dropna=False)
            fig, ax = plt.subplots(figsize=(5, 3))
            value_counts.plot(kind='bar', color=['indigo', 'khaki', 'lightblue', 'salmon', 'sienna', 'silver', 'aquamarine', 'coral', 'teal', 'olive'])
            ax.set_xlabel(change["new"])
            ax.set_ylabel('Count')
            ax.set_title(f'Unique Value Counts for {change["new"]}')
            plt.xticks(rotation=45)

            # Clear the specific space to show the new plot
            subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
            with subgroup_output_areas[repeat]['subgroup_maker1_info']:
                subgroup_output_areas[repeat]['subgroup_maker1_info'].clear_output()
                plt.show()
                display(HTML('<span style="color: red;">This plot shows rows with the 0 and 1 events specified above!</span>'))
        
        # Make and display a histogram of frequencies for numerical columns
        else:
            fig, ax = plt.subplots(figsize=(5, 3))
            ax.hist(column_data[repeat], bins='auto', color="darkblue", ec="white")
            ax.set_xlabel(change["new"])
            ax.set_ylabel('Frequency')
            ax.set_title(f'Histogram for {change["new"]}')

            # Clear the specific space to show the new plot
            subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
            with subgroup_output_areas[repeat]['subgroup_maker1_info']:
                subgroup_output_areas[repeat]['subgroup_maker1_info'].clear_output()
                plt.show()
                display(HTML('<span style="color: red;">This plot shows rows with the 0 and 1 events specified above!</span>'))

        ##### Step 04
        # Pass the information of the current variable KM_data_1var so we can have all variables of interest in KM_data_all
        if repeat == 0:
            # The first variable gets just passed as it is after the steps above
            KM_data_all = KM_data_1var.copy()
        else:
            # Check by position whether KM_data_all already has a column for this variable
            # (column labels are strings, so testing 'repeat + 3 in columns' never matches)
            if repeat + 3 < KM_data_all.shape[1]:
                # Replace the values in the existing column and rename it to the new variable
                old_column_name = KM_data_all.columns[repeat + 3]
                KM_data_all[old_column_name] = KM_data_1var[change["new"]].copy()
                KM_data_all.rename(columns={old_column_name: change["new"]}, inplace=True)
            else:
                # Create a new column and assign the values
                KM_data_all[change["new"]] = KM_data_1var[change["new"]].copy()

        # Log the KM_data_all as we add/replace variables/columns of interest 
        logger.info(f"[Subgrouping 2nd step] Updated KM_data_all with columns of interest: \n {KM_data_all.iloc[:15, :10].to_string()} \n")
        
        # After the plot is made trigger the slider function to show two (the minimum) widgets in the box 
        subgroup_widget_areas[repeat]['subgroup_number_slider'].value = 2

#### Group Number handler

Once a dataset selection triggers the appearance of two other widgets, this function is called when a selection is made in the second of those (***int slider***). This function ***creates and displays a widget box*** containing repeats of either ***tags or float range*** widgets according to the number of groups selected in the slider. 

**NOTE:** The observe call for this function uses a lambda function to also pass the repeat number, which indicates where exactly the widget boxes need to be shown (area 1-5). This is needed because the same function does the same job for all 5 possible areas.
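
To make the role of the range sliders concrete, the sketch below shows how slider limits can be taken from the non-missing values of a column and how a value is later matched to a group. It follows the half-open comparison (`>= lower` and `< upper`) used in `pass_KM_parameters`. The helper names `slider_bounds` and `assign_group` and the example numbers are hypothetical, and plain Python lists stand in for the pandas column.

```python
import math


def slider_bounds(values):
    """Min and max of the non-missing values, usable as slider limits."""
    clean = [v for v in values if not math.isnan(v)]
    return min(clean), max(clean)


def assign_group(value, ranges):
    """Return the label whose range satisfies low <= value < high, or '' if none does."""
    label_found = ""
    for label, (low, high) in ranges.items():
        if low <= value < high:
            label_found = label  # a later range overwrites an earlier one if they overlap
    return label_found


# Hypothetical expression values, one of them missing
values = [5.1, 7.4, float("nan"), 9.8]
low, high = slider_bounds(values)

# Two groups covering the whole slider range
ranges = {"Group 1": (low, 7.0), "Group 2": (7.0, high)}
labels = [assign_group(v, ranges) for v in values]
print(labels)  # ['Group 1', 'Group 2', '', '']
```

Because the upper limit is exclusive, a value equal to the top of the last range (9.8 here) receives no label, so such rows would be dropped together with the missing ones.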

In [None]:
# Function to display the output of group_number slider (widget boxes)
def group_number_selection_handler(change, repeat):
     
    # This function uses the global variable column_data[repeat] created in variables_selection_handler (above)
    global subgroup_boxes, subgroup_tagsinput, subgroup_floatrangeslider

    # When the default value on the variable dropdown or combobox is selected, do not display subgrouping options
    if subgroup_widget_areas[repeat]['variables_dropdown'].value == 'Click here to select...':
        if df_RNA is None or subgroup_widget_areas[repeat]['variables_combobox'].value == '':
            with subgroup_output_areas[repeat]['subgroup_maker2_info']:
                subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
                display(HTML('<span style="color: red;">Choose a variable first!</span>'))
            return
    
    # Use tags to specify the desired groups for text columns (min 2, max 10 groups)
    if change["new"]>1 and column_data[repeat].dtype == 'object':
        logger.info(f"[Subgrouping 3rd step] The user selected to make {change['new']} subgroups with tags input labels. \n")

        # Make as many tags input widgets and labels as selected and put them in their corresponding box
        unique_values = column_data[repeat].unique().tolist()
        subgroup_tagsinput = [widgets.TagsInput(allowed_tags=unique_values, description=f'Group {i+1}') for i in range(change["new"])]
        subgroup_tagsinput_labels = [widgets.Label(value=f'Group {i+1}') for i in range(change["new"])]
        subgroup_boxes[repeat] = VBox([HBox([label, tagsinput]) for label, tagsinput in zip(subgroup_tagsinput_labels, subgroup_tagsinput)])

        # Clear the specific space to show the widget box
        with subgroup_output_areas[repeat]['subgroup_maker2_info']:
            subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
            display(subgroup_boxes[repeat])
            
    # Use range sliders to specify the desired groups for numerical columns (min 2, max 10 groups)
    elif change["new"]>1:
        logger.info(f"[Subgrouping 3rd step] The user selected to make {change['new']} subgroups with float range sliders. \n")
        
        # Make as many float range slider widgets as selected and put them in their corresponding box
        cleaned_column_data = column_data[repeat].dropna()
        min_value = cleaned_column_data.min()
        max_value = cleaned_column_data.max()
        subgroup_floatrangeslider = [widgets.FloatRangeSlider(min=min_value, max=max_value, step=0.01, description=f'Group {i + 1}') for i in range(change["new"])]
        subgroup_boxes[repeat] = VBox(subgroup_floatrangeslider)

        # Clear the specific space to show the widget box
        with subgroup_output_areas[repeat]['subgroup_maker2_info']:
            subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
            display(subgroup_boxes[repeat]) 
    
    # Tell the user to remove the current variable if they want one group
    else:
        with subgroup_output_areas[repeat]['subgroup_maker2_info']:
            subgroup_output_areas[repeat]['subgroup_maker2_info'].clear_output()
            display(HTML('<span style="color: red;">For 1 group just remove this variable!</span>'))

#### Displaying Widget Areas

Once the file(s) are uploaded by the user and pre-processed by the other function above, this function is called with the dataset(s) provided and the time+event columns found in the clinical dataset. Then, it ***creates all the main widgets, their output objects, makes the observe calls to their respective handlers, and makes the layout to display everything (defined by the author)***.

***NOTE:*** Most widgets are straightforward: each one is created with its output, observed with its handler function (above), and displayed. The gray buttons widget used to make subgroups is the most complex one. It requires 5 pre-defined widget areas and 5 sets of widget output areas, which are always displayed but only fill and update upon user interaction. Each set has three output areas. The first one shows 3 widgets horizontally upon specific selections. The other two sit below that area and split the screen in two: the left side shows plots that help decide on the ranges or labels of the subgroups to make, and the right side shows the widgets used to select the specific values for each subgroup.

In [None]:
# Main function to prepare and display the interactive widgets and subwidgets
def widget_preparation(df_clinical, df_RNA, time_to_event, event_observation):

    logger.info(f"---------------User interaction with the widgets starts here--------------- \n")
    clear_output()
    
    #################################### First widget - Time to event ########################################

    global time_to_event_dropdown, time_to_event_selection_info
    
    # Create the widget 
    time_to_event_dropdown = widgets.Dropdown(options=["Click here to select..."] + time_to_event)

    # Create the widget output (plot)
    time_to_event_selection_info = Output()

    # Observe changes in the widget to call the function above            
    time_to_event_dropdown.observe(time_to_event_selection_handler, names="value")
    
    #################################### Second widget - Event observation ####################################

    # Variables accessed by multiple functions
    global event_observation_dropdown, event_observation_selection_info, event_observation_tagsinput_info
    
    # Create the widget 
    event_observation_dropdown = widgets.Dropdown(options=["Click here to select..."] + event_observation)

    # Create the widget outputs (plots and subwidgets)
    event_observation_selection_info = Output()
    event_observation_tagsinput_info = Output()
    
    # Observe changes in the widget to call the function above 
    event_observation_dropdown.observe(event_observation_selection_handler, names="value")

    #################### Third widget - Handle time and event columns that have other names ###################

    # Create the main widget
    event_observation_checkbox = widgets.Checkbox(description="Can't find your columns?", value=False)

    # This situation will be included in later releases, for now...
    event_observation_checkbox = widgets.HTML("")
    
    #################################### Fourth widget - Making subgroups #####################################

    # Variables accessed by multiple functions
    global subgroup_buttons, variable_number_slider, subgroup_widget_areas, subgroup_output_areas
    global event_observation_tagsinput0, event_observation_tagsinput1, column_data, subgroup_boxes, KM_data_all
    
    # Create the main subgrouping widgets 
    subgroup_buttons = widgets.ToggleButtons(options=["No", "Use variable(s)"])
    variable_number_slider = widgets.HTML("")

    # Create 5 subwidget areas (more subwidgets) to use a maximum of 5 variables with this button (min 1)
    subgroup_widget_areas = [{ 'dataset_dropdown' : widgets.Dropdown(options=['Click here to select...', 'clinical'] + (['RNA'] if df_RNA is not None else []),
                                                                     value='Click here to select...', description='Dataset:'),
                              'variables_dropdown' : widgets.Dropdown(options=['Click here to select...'] + list(df_clinical.columns[1:]), 
                                                                      value='Click here to select...', description='Variables:'),
                              'variables_combobox' : widgets.Combobox(options=list(df_RNA.columns[1:]), placeholder='Type gene of interest here', 
                                                                      description='Genes:') if df_RNA is not None else None,
                              'subgroup_number_slider' : widgets.IntSlider(min=1, max=10, description='Groups:', value=1) 
                              } for i in range(5)]      

    # Create 5 subwidget output areas (one per area above)
    subgroup_output_areas = [{'subgroup_options_selection_info' : Output(), 
                              'subgroup_maker1_info' : Output(),
                              'subgroup_maker2_info' : Output()
                             } for i in range(5)]

    # Variables that need to be initialized with specific values/number of items
    event_observation_tagsinput0 = None
    event_observation_tagsinput1 = None
    column_data = ["0", "1", "2", "3", "4"]
    subgroup_boxes = ["0", "1", "2", "3", "4"]
    KM_data_all = pd.DataFrame(columns=["0", "1", "2"])
    
    # Observe changes in the main widget to call the rest if a change is detected  
    subgroup_buttons.observe(subgroup_selection_handler, 'value') 

    #################################### Fifth widget - Calling KM Fitter #####################################
    
    # Variables accessed by multiple functions
    global generate_plot_button, KM_plot_area
    
    # Create the widget 
    generate_plot_button = widgets.Button(description='Generate plot', disabled=False, button_style='success', 
                                          tooltip='Click me and wait for the plot to be displayed below!',
                                          icon='chart-line') 

    # Create the widget output (plot)
    KM_plot_area = Output()
            
    # Call the KM Fitter when the button is clicked
    generate_plot_button.on_click(pass_KM_parameters)

    ################################# Sixth widget - KM plot customization tools ##############################
    
    # Variables accessed by multiple functions
    global CI_checkbox, plot_labels_checkbox
    
    # Create the main widget
    CI_checkbox = widgets.Checkbox(description="Show Confidence Intervals", value=True)
    plot_labels_checkbox = widgets.Checkbox(description="Move legend to right side", value=False)

    #################################### Seventh widget - Saving results ######################################
    
    # Variables accessed by multiple functions
    global save_button
    
    # Once any of the plots is made and displayed, we update the saving button to show it
    save_button = widgets.Button(description='Save results', disabled=True, button_style='', 
                                 tooltip='Click and look for a jpg and an excel file in the current directory!',
                                 icon='download') 
            
    # Call the saving function when there is a plot and the save button appears
    save_button.on_click(save_KM_results)
                
    ########################################## Display initial widgets ########################################

    # First, second and third widgets go together in the first row
    display(widgets.HTML("<br/>"))
    display(HBox([widgets.Label("Time to Event:"), time_to_event_dropdown, widgets.HTML('\u2003' * 7), widgets.Label("Event Observation:"), event_observation_dropdown, event_observation_checkbox]))
    
    # The output of the first two widgets is displayed in the second row
    time_to_event_selection_info.layout.width = '41%'
    event_observation_selection_info.layout.width = '41%'
    event_observation_tagsinput_info.layout.width = '18%'
    display(HBox([time_to_event_selection_info, event_observation_selection_info, event_observation_tagsinput_info]))
    display(widgets.HTML("<br/>"))

    # The fourth widget is displayed in the third row
    display(HBox([widgets.HTML('\u2003' * 4), widgets.Label("Make subgroups?:"), subgroup_buttons]), display_id='variable_number_slider')
    display(widgets.HTML("<br/>"))

    # Make room for the maximum widget areas (5)
    for i in range(5):
        display(subgroup_output_areas[i]['subgroup_options_selection_info'])
        display(widgets.HTML("<br/>"))
        display(HBox([subgroup_output_areas[i]['subgroup_maker1_info'], subgroup_output_areas[i]['subgroup_maker2_info']]))
    
    # Finally, the buttons to start the KM Fitter and to save the plot are displayed in the sixth and seventh rows
    display(widgets.HTML("<br/>"))
    display(HBox([generate_plot_button, CI_checkbox, plot_labels_checkbox]))
    display(save_button)
    display(KM_plot_area)

### Perform the KM analysis

This function ***prepares and passes the data to the KM fitter*** whenever the green button is clicked, and ***retrieves the KM estimates***. Ideally, the green button should be clicked once all the other widgets have a selection. This function handles the case when no subgroups are needed, and the cases when multiple subgroups are selected based on one or multiple variables (columns).

To pass the data in the right format, when ***No*** subgroups are needed this function applies the labels selected for the observed event and uses all the rows of the clinical dataset that carry those labels to calculate the estimates. On the other hand, when ***Use variable(s)*** is selected and one or more variables are chosen with 2 or more subgroups each, this function applies the subgrouping labels on each column (the event labels were already applied), filters out rows without the selected labels, and makes subsets of the dataset (which may combine clinical and RNA columns). The ***KM_analysis*** function is then called and can work with any of these inputs, providing the estimates for each subset of the dataset, which are retrieved in this function and plotted.
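
The subsetting step can be pictured as taking every combination of group labels across the selected variables and collecting the rows that match each one. The sketch below does this with `itertools.product` on plain tuples. The rows, the column positions, and the names `extra_columns` and `subgroups` are hypothetical, and the function below builds its combinations with nested list comprehensions on a DataFrame instead.

```python
from itertools import product

# Hypothetical rows after labelling: (patient, time, event, variable 1, variable 2)
rows = [
    ("P1", 10.0, 1, "Group 1", "Group 1"),
    ("P2", 25.0, 0, "Group 1", "Group 2"),
    ("P3", 31.5, 1, "Group 2", "Group 1"),
    ("P4", 40.0, 0, "Group 1", "Group 1"),
]
extra_columns = [3, 4]  # positions of the subgrouping variables

# Unique labels per variable, in order of first appearance
unique_values = [list(dict.fromkeys(row[i] for row in rows)) for i in extra_columns]

# One subset per combination of labels; the key is the tuple of labels
subgroups = {
    combo: [row for row in rows if all(row[i] == value for i, value in zip(extra_columns, combo))]
    for combo in product(*unique_values)
}

for combo, members in subgroups.items():
    print(combo, [row[0] for row in members])
```

Note that a combination with no matching rows, such as `('Group 2', 'Group 2')` in this example, still produces an (empty) subset, which is worth keeping in mind before fitting one curve per subset.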

***NOTES:*** When plotting the KM estimates, this function shows only the first 95% of the points in each estimated curve, to prevent the big drops or long stretches that occur when few samples are left at the end of the curve (the percentage can be easily adjusted).
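
The 95% cut is positional: it keeps the first 95% of the rows of the estimated curve, not 95% of the patients or of the time range. A minimal sketch of the same slice on a plain list, where the helper name `truncate_timeline` and the 100-point timeline are hypothetical:

```python
def truncate_timeline(timeline, fraction=0.95):
    """Keep the first `fraction` of the points (by position, not by time value)."""
    return timeline[slice(0, int(len(timeline) * fraction))]


# Hypothetical curve with 100 time points (0, 1, ..., 99)
timeline = list(range(100))
kept = truncate_timeline(timeline)
print(len(kept), kept[-1])  # 95 94
```

Changing `fraction` is the equivalent of adjusting the `0.95` factor in the plotting calls below.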

In [None]:
# Function to feed the current variable selections to the KM_analysis function (in the next subsection)
def pass_KM_parameters(change):

    # These variables will get the KMF objects and plots in all scenarios
    global KM_analysis_output
    
    # If no subgrouping is required, apply the event observed tags and pass the data to KM_analysis
    if subgroup_buttons.value == 'No':
        
        # Apply the selected labels on the event observation column 
        KM_data = df_clinical.copy()
        if event_observation_tagsinput0.value and event_observation_tagsinput1.value:
            for tag in event_observation_tagsinput0.value:
                KM_data[event_observation_dropdown.value] = KM_data[event_observation_dropdown.value].replace(tag, "0")
            for tag in event_observation_tagsinput1.value:
                KM_data[event_observation_dropdown.value] = KM_data[event_observation_dropdown.value].replace(tag, "1")            
        else:
            with KM_plot_area:
                display(HTML('<span style="color: red;">Select first the values to label as 0 and 1 (No event, event)</span>'))
            return

        # Log the current status of KM_data
        logger.info(f"[No subgroups 1st step] The user selected to label -{str(event_observation_tagsinput0.value)}- as 0, and -{str(event_observation_tagsinput1.value)}- as 1. \n")
        logger.info(f"[No subgroups 1st step] Apply 0/1 labels to column {event_observation_dropdown.value} on KM_data: \n {KM_data.iloc[:15, :10].to_string()} \n")
        logger.info(f"[No subgroups 1st step] Data types of KM_data columns: \n {KM_data.dtypes.to_string()} \n\n")
                
        # Filter out non-desired values and convert column to numbers for the KM Fitter
        KM_data = KM_data[['PATIENT_ID', time_to_event_dropdown.value, event_observation_dropdown.value]]
        KM_data = KM_data.loc[KM_data[event_observation_dropdown.value].isin(["0", "1"])]
        KM_data[event_observation_dropdown.value] = KM_data[event_observation_dropdown.value].astype(int)

        # Log the current status of KM_data
        logger.info(f"[No subgroups 2nd step] Keep relevant columns of KM_data and only rows with 0/1 event labels: \n {KM_data.head(15).to_string()} \n")
        logger.info(f"[No subgroups 2nd step] Data types of KM_data columns: \n {KM_data.dtypes.to_string()} \n\n")
        
        # Pass the input parameters to the KM_analysis function and get back the KM object
        KM_subgroups = []           
        KM_analysis_output = KM_analysis(KM_data, KM_subgroups)

        # Plot the estimate from the KMF object
        with KM_plot_area:
            KM_plot_area.clear_output()
            plt.figure(figsize=(10, 6))
            KM_analysis_output.plot(ci_show=CI_checkbox.value, legend=False, iloc=slice(0, int(len(KM_analysis_output.survival_function_) * 0.95)))
            plt.xlabel("Time")
            plt.ylabel("Probability")
            plt.title("Kaplan-Meier Estimate")
            plt.show()
            
    ####################################################################
    # If subgroups were selected, apply the corresponding tags or ranges
    else:
        # The dfs with all the information needed are these
        global KM_data_all, subgroup_boxes

        # Log the current status of KM_data_all (before the working copy is made)
        logger.info(f"[Subgrouping 3rd step] Dataset KM_data_all before applying subgrouping labels: \n {KM_data_all.head(15).to_string()} \n")
        logger.info(f"[Subgrouping 3rd step] Data types of KM_data_all before applying subgrouping labels: \n {KM_data_all.dtypes.to_string()} \n\n")
        
        # If the subgrouping changes are made multiple times, apply them to a copy of the original df
        KM_data_working = KM_data_all.copy()

        # Create one empty dictionary per variable to store the mapping used to reassign real group names
        correct_group_labels = [{} for i in range(5)]

        # Only iterate through the last number of variables in the slider (in case the user had more before)
        for repeat in range(variable_repeats):

            # To apply the tags in tagsinput we check if the VBoxes have repeats of label+tagsinput or floatrangesliders alone
            if isinstance(subgroup_boxes[repeat].children[0], widgets.FloatRangeSlider):
                
                # Create a new column to store the group labels (original is numbers, new one will be text and next to the original)
                KM_data_working.insert(repeat+4, 'TextSubgroup', '')

                # Iterate through the float range sliders
                for i, floatslider in enumerate(subgroup_boxes[repeat].children):
                    # Retrieve the range selection and corresponding label
                    subgroup_range = floatslider.value
                    subgroup_label = floatslider.description
                   
                    # Get the indices of rows within the range selected
                    subgroup_rows = (KM_data_working.iloc[:, repeat+3] >= subgroup_range[0]) & (KM_data_working.iloc[:, repeat+3] < subgroup_range[1])
                
                    # Assign the subgroup label to the matching rows
                    KM_data_working.loc[subgroup_rows, 'TextSubgroup'] = subgroup_label

                    # Add the correct label to the dictionary
                    correct_group_labels[repeat][subgroup_label] = f"{subgroup_range[0]:.2f} to {subgroup_range[1]:.2f}"
                
                # Remove rows where the subgroup label is not assigned
                KM_data_working = KM_data_working[KM_data_working['TextSubgroup'] != '']
                variable_column_name = KM_data_working.columns[repeat+3]
                KM_data_working.drop(variable_column_name, axis=1, inplace=True)
                KM_data_working.rename(columns={'TextSubgroup': variable_column_name}, inplace=True)

                # Log the ranges corresponding to each subgroup
                log_string = " - ".join([f"Group {i+1}: {slider.value[0]:.2f} to {slider.value[1]:.2f}" for i, slider in enumerate(subgroup_boxes[repeat].children)])
                logger.info(f"[Subgrouping 3rd step] Subgrouping labels applied to variable {repeat+1}---> {log_string} \n")

            # Otherwise the VBox contains HBoxes, each holding a label and a tagsinput widget
            else:
                
                # Retrieve the labels and elements to label
                subgroup_selections = [tagsHBox.children[1].value for tagsHBox in subgroup_boxes[repeat].children]
                subgroup_labels = [tagsHBox.children[0].value for tagsHBox in subgroup_boxes[repeat].children]
                label_content_pairs = []
                
                # Iterate through the subgroup_selections list
                for i, tags_list in enumerate(subgroup_selections):
                    subgroup_elements = tags_list
    
                    # Generate a mapping of unique values to group labels
                    element_to_label = {element: subgroup_labels[i] for element in subgroup_elements}

                    # Add the group label to its list and the correct label to the dictionary
                    label_content_pairs.extend([f"{element}: {subgroup_labels[i]}" for element in subgroup_elements])
                    correct_group_labels[repeat][subgroup_labels[i]] = '+'.join(subgroup_elements)
                    
                    # Replace the subgroup elements with the new labels
                    variable_column_name = KM_data_working.columns[repeat+3]
                    KM_data_working[variable_column_name] = KM_data_working[variable_column_name].replace(element_to_label)
            
                # Keep only the rows with the new subgroup labels and log the labels selected
                KM_data_working = KM_data_working[KM_data_working[variable_column_name].isin(subgroup_labels)]
                log_string = "  -  ".join(label_content_pairs)
                logger.info(f"[Subgrouping 3rd step] Subgrouping labels applied to variable {repeat+1}---> {log_string} \n")
            
            # Log the updated df
            logger.info(f"[Subgrouping 3rd step] Dataset KM_data_working after applying subgrouping labels: \n {KM_data_working.head(15).to_string()} \n")
        ########
        # Once all labels have been applied to each column, make the subgroups

        # Get the column indices for the extra columns
        extra_column_indices = range(3, 3 + variable_repeats)
        
        # Get the unique values for each extra column
        extra_column_unique_values = [KM_data_working.iloc[:, i].unique() for i in extra_column_indices]
        
        # Create an empty dictionary to store the subsets
        KM_subgroups = {}
        
        # Generate all possible combinations of unique values from the extra columns
        combinations = [[]]
        for values in extra_column_unique_values:
            combinations = [sublist + [value] for sublist in combinations for value in values]
        
        # Iterate through each combination of unique values
        for combination in combinations:
            # Create a subset of KM_data_working for the current combination
            subset = KM_data_working.copy()
            
            # Filter the rows based on the current combination
            for i, index in enumerate(extra_column_indices):
                subset = subset[subset.iloc[:, index] == combination[i]]
            
            # Add the subset to the KM_subgroups dictionary
            KM_subgroups[tuple(combination)] = subset

        # Log the subgroups created
        logger.info(f"[Subgrouping 3rd step] Subgroups made from the dataset:\n")
        for combination, subgroup in KM_subgroups.items():
            logger.info(f"Subgroup label: {combination}")
            logger.info(f"\n{subgroup.head(10)}\n")
        
        ########
                
        # Finally, pass the input parameters to the KM_analysis function and get back the KM object
        KM_analysis_output = KM_analysis(KM_data_working, KM_subgroups)

        # Reassign the real/correct labels (they are as Group X, and we will correct to the actual tag or range)
        # Iterate through each key-value pair in the KM_analysis_output dictionary
        for old_key, KM_object in list(KM_analysis_output.items()):
            # Create a list to store the new key for this combination
            new_key = []
        
            # Iterate through each element of the old key (a tuple of strings)
            for i, label in enumerate(old_key):
                # Retrieve the correct label from the corresponding dictionary in correct_group_labels
                new_key.append(correct_group_labels[i].get(label, label))
        
            # Convert the new key (list) to a single string; str() guards against non-string labels
            new_key = ', '.join(str(part) for part in new_key)
        
            # Replace the old key with the new key in the KM_analysis_output dictionary
            KM_analysis_output[new_key] = KM_analysis_output.pop(old_key)

        
        # Plot the estimates of all KMF objects (only the first 95% of each curve's points are drawn)
        with KM_plot_area:
            KM_plot_area.clear_output()
            plt.figure(figsize=(10, 6))
            for label, KM_object in KM_analysis_output.items():
                KM_object.plot(label=label, ci_show=CI_checkbox.value, iloc=slice(0, int(len(KM_object.survival_function_) * 0.95)))
            plt.xlabel('Time')
            plt.ylabel('Probability')
            plt.title('Kaplan-Meier Estimates')
            # Place the legend outside the axes when the labels checkbox is ticked
            if plot_labels_checkbox.value:
                plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            else:
                plt.legend()
            plt.show()

    # Once a plot has been made and displayed, enable the save button
    save_button.description = 'Save results'
    save_button.button_style = 'info'
    save_button.disabled = False 

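The subgroup keys built in the function above are every combination of one value per extra column. As a minimal sketch of that step (the helper name `subgroup_combinations` and the example labels are hypothetical, not part of the notebook), the loop can be checked against the Cartesian product from the standard library:

```python
from itertools import product

def subgroup_combinations(unique_values_per_column):
    """Return every combination that takes one value from each column."""
    combinations = [[]]
    for values in unique_values_per_column:
        # Extend each partial combination with every value of the next column
        combinations = [partial + [value] for partial in combinations for value in values]
    return [tuple(combination) for combination in combinations]

# Hypothetical group labels for two extra columns
example_values = [["Group 1", "Group 2"], ["Group 1", "Group 2", "Group 3"]]
combos = subgroup_combinations(example_values)

# The loop yields the same keys, in the same order, as itertools.product
same_as_product = combos == list(product(*example_values))
print(len(combos), same_as_product)
```

With two and three labels respectively, this gives 2 x 3 = 6 subgroup keys, so the number of curves grows quickly as more columns are added.
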
The function below (KM_analysis) ***fits the input dataset(s) with KaplanMeierFitter***. It returns ***KMF_object***, which is either a single fitted KMF object (whole dataset) or a dictionary holding one fitted KMF object per subset of the dataset provided. The result contains the estimates and is plotted by the function above (pass_KM_parameters).

***NOTES:*** The KM analysis currently performs a simple .fit(time, event) on the data, but the resulting KMF objects expose several properties that can be used to extract more insights, and they can be complemented with other tools from the lifelines library (which provides KaplanMeierFitter).
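
The `.fit(time, event)` call mentioned in the notes computes the Kaplan-Meier (product-limit) estimate. As a rough illustration of that calculation only, and not of how the library implements it, here is a minimal pure-Python sketch. The function name `kaplan_meier` and the toy data are hypothetical.

```python
def kaplan_meier(durations, events):
    """Return [(time, survival probability)] using the product-limit formula.

    events[i] is 1 (or True) when the event was observed and 0 when censored.
    """
    at_risk = len(durations)
    survival = 1.0
    curve = [(0.0, 1.0)]
    for t in sorted(set(durations)):
        # Count observed events and all subjects leaving the risk set at time t
        deaths = sum(1 for d, e in zip(durations, events) if d == t and e)
        leaving = sum(1 for d in durations if d == t)
        if deaths:
            survival *= 1 - deaths / at_risk
        curve.append((t, survival))
        at_risk -= leaving
    return curve

# Hypothetical toy data: times in months, 1 = event observed, 0 = censored
times = [5, 8, 8, 12, 15]
observed = [1, 1, 0, 1, 0]
curve = kaplan_meier(times, observed)
for time, probability in curve:
    print(time, round(probability, 3))
```

At each event time the running probability is multiplied by (1 - events / number at risk). Censored rows do not lower the curve; they only shrink the number at risk for later times.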

In [None]:
def KM_analysis(KM_data, KM_subgroups):

    # Unpack the input parameters provided
    current_time_column = KM_data.columns[1]
    current_event_column = KM_data.columns[2]
    
    global KMF_object

    # Use the whole dataset when no groups were made
    if subgroup_buttons.value == 'No':

        # Create a single KaplanMeierFitter object
        KMF_object = KaplanMeierFitter()

        # Fit the estimator using the specified time and event columns
        KMF_object.fit(durations=KM_data[current_time_column], event_observed=KM_data[current_event_column])

        # Log part of the curve to verify the data was passed correctly
        logger.info(f"[No Subgroups 3rd step] The KM Fitter successfully calculated the probabilities and made the plot. \n")
        logger.info(f"[No Subgroups 3rd step] Calculated survival function: \n {KMF_object.survival_function_.head(7).to_string()} \n ... \n {KMF_object.survival_function_.tail(7).to_string()} \n\n")

    # Make a fit for every subset provided (based on the number of groups and subgroups made)
    else:

        # Sort the subgroups in alphabetical order to plot them in the same order and colour
        KM_subgroups = OrderedDict(sorted(KM_subgroups.items()))
        
        # Create an empty dictionary to store the KaplanMeierFitter objects
        KMF_object = {}
        
        # Create KaplanMeierFitter objects for each subgroup in KM_subgroups
        for label, subset in KM_subgroups.items():
            kmf = KaplanMeierFitter()
            kmf.fit(durations=subset[current_time_column], event_observed=subset[current_event_column])
            KMF_object[label] = kmf

            # Log part of the curve to verify the data was passed correctly
            logger.info(f"[Subgrouping 4th step] Calculated survival function of: {label}")
            logger.info(f"\n {kmf.survival_function_.head(7).to_string()} \n ... \n {kmf.survival_function_.tail(7).to_string()} \n\n")

        # Log success only after every subgroup has been fitted
        logger.info("[Subgrouping 4th step] The KM Fitter successfully calculated the probabilities. \n")
        
    return KMF_object

### Save results

This function first ***calls the helper function below to save the plot of interest as a .jpg image***, and then ***retrieves and saves the event table, survival function, confidence intervals, and median survival time of each KMF object***, writing all the information for each object to a formatted Excel worksheet. The workbook is then saved as a .xlsx file in the current directory.
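
A minimal sketch of two details of the saving step, assuming the naming scheme and worksheet layout used in the code below: the zero-padded counter in the file names, and how a table's rows map to worksheet cells starting at a given column. `result_filenames` and `table_cells` are hypothetical helpers for illustration, not part of the notebook.

```python
def result_filenames(count):
    """Build the paired Excel and plot file names for a given save count."""
    stem = f"KM_results_{str(count).zfill(2)}"
    return f"{stem}.xlsx", f"{stem}.jpg"

def table_cells(rows, start_column, first_row=6):
    """Yield (row, column, value) for each cell of a table placed on the sheet."""
    for row_offset, row in enumerate(rows):
        for col_offset, value in enumerate(row):
            yield first_row + row_offset, start_column + col_offset, value

# Third click on the save button
names = result_filenames(3)
print(names)

# Hypothetical two-row survival function (time, probability) placed at column 7 (G)
survival_rows = [(0.0, 1.0), (5.0, 0.8)]
placements = list(table_cells(survival_rows, start_column=7))
print(placements)
```

Keeping each table at a fixed start column (1, 7, 10, and 13 in the code below) leaves an empty spacer column between tables, which is why the headers can be merged and centred independently.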

In [None]:
# Global counter to keep track of the number of times the button is clicked
file_count = 1

def save_KM_results(save_button):
    global file_count

    # Increment the file count to avoid overwriting previous plots and excel files
    file_count_str = str(file_count).zfill(2)  # Convert the counter to a 2-digit zero-padded string
    file_count += 1

    # File names for Excel and plot
    excel_filename = f"KM_results_{file_count_str}.xlsx"
    plot_filename = f"KM_results_{file_count_str}.jpg"

    # Save the plot image with the function below
    save_plot_image(KM_analysis_output, plot_filename)

    # Create a new Excel workbook and remove the default Sheet
    workbook = openpyxl.Workbook()
    workbook.remove(workbook['Sheet'])
    
    # Prepare the data to be processed (single KM object or dictionary of KM objects)
    if isinstance(KM_analysis_output, dict):
        KM_objects_to_process = [{"label": f"KM_Subgroup_{i+1}", "KM_object": KM_object} for i, (label, KM_object) in enumerate(KM_analysis_output.items())]
        real_labels = [f"KM_Subgroup_{i+1}: {label}" for i, (label, KM_object) in enumerate(KM_analysis_output.items())]
    else:
        # If KM_analysis_output is a single KM object, add it to the list as a dictionary with a general label
        KM_objects_to_process = [{"label": "KM_Dataset", "KM_object": KM_analysis_output}]
        real_labels = ["KM_Dataset: Whole dataset - No subgroups"]

    # Process all KM curves/objects the same way 
    for index, data in enumerate(KM_objects_to_process):
        
        # Create a sheet per KM object
        sheet = workbook.create_sheet(title=data["label"])

        # Write what the curve/object corresponds to
        sheet.merge_cells(start_row=2, start_column=1, end_row=2, end_column=13)
        sheet.cell(row=2, column=1).alignment = openpyxl.styles.Alignment(horizontal='center', vertical='center')
        sheet.cell(row=2, column=1, value=real_labels[index]).font = openpyxl.styles.Font(bold=True, size=16)

        # Get the tables from the KMF object
        event_table = data["KM_object"].event_table
        survival_function = pd.DataFrame({"Time": data["KM_object"].survival_function_.index,
                                          "Survival Probability": np.ravel(data["KM_object"].survival_function_.values)})
        confidence_interval = data["KM_object"].confidence_interval_
        median_survival_time = data["KM_object"].median_survival_time_

        # Write the tables to the Excel sheet
        tables = [event_table, survival_function, confidence_interval, median_survival_time]
        table_names = ["Event Table", "Survival Function", "Confidence Intervals", "Median Survival Time"]
        table_column_numbers = [1, 7, 10, 13]
        
        # Write all tables to the sheet
        for col_index, (table, table_name) in enumerate(zip(tables, table_names)):
            # Define the current column number for the table
            current_column = table_column_numbers[col_index]
    
            # Set the header for the current table
            sheet.cell(row=4, column=current_column, value=table_name).font = openpyxl.styles.Font(bold=True)
    
            if isinstance(table, pd.DataFrame):
                # If the table is a DataFrame, convert it to a NumPy array
                rows = table.to_numpy()
                num_cols = len(table.columns)
    
                for row_index, row in enumerate(rows):
                    # Write the data from the DataFrame to the Excel sheet
                    for col_offset, value in enumerate(row):
                        sheet.cell(row=row_index + 6, column=current_column + col_offset, value=value)  
            else:
                # If the table is not a DataFrame, write the single value to the Excel sheet
                sheet.cell(row=5, column=current_column, value=table)

        ##### Extra worksheet formatting 
        # Merge and center table titles
        sheet.merge_cells(start_row=4, start_column=1, end_row=4, end_column=5)
        sheet.cell(row=4, column=1).alignment = openpyxl.styles.Alignment(horizontal='center', vertical='center')
        sheet.merge_cells(start_row=4, start_column=7, end_row=4, end_column=8)
        sheet.cell(row=4, column=7).alignment = openpyxl.styles.Alignment(horizontal='center', vertical='center')
        sheet.merge_cells(start_row=4, start_column=10, end_row=4, end_column=11)
        sheet.cell(row=4, column=10).alignment = openpyxl.styles.Alignment(horizontal='center', vertical='center')
        
        # Write column titles and center them
        sheet.cell(row=5, column=1, value="Removed")
        sheet.cell(row=5, column=2, value="Observed")
        sheet.cell(row=5, column=3, value="Censored")
        sheet.cell(row=5, column=4, value="Entrance")
        sheet.cell(row=5, column=5, value="At Risk")
        sheet.cell(row=5, column=7, value="Time")
        sheet.cell(row=5, column=8, value="Probability")
        sheet.cell(row=5, column=10, value="Lower Bound")
        sheet.cell(row=5, column=11, value="Upper Bound")
        for cell in ["A5", "B5", "C5", "D5", "E5", "G5", "H5", "J5", "K5"]:
            sheet[cell].alignment = openpyxl.styles.Alignment(horizontal='center')
        
        # Adjust some column widths
        for column, width in zip(['H', 'J', 'K', 'M'], [10, 12, 12, 22]):
            sheet.column_dimensions[column].width = width
        #####
    
    # Save the Excel file and log it
    workbook.save(excel_filename)
    logger.info(f"An Excel file containing the results has been saved to the current directory with the name {excel_filename} \n")
    
    # Update the button description
    save_button.description = 'Results Saved!'
    save_button.button_style = ''
    save_button.disabled = True

This is a helper for the function above. It ***re-creates the KM plot in order to save it***, because the user may not want to save every plot generated (for example, during exploratory analysis). For that reason, one button generates the plot, which can contain a single curve (survival function) or several, and another button saves the data of interest.
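
Both the on-screen plot and the saved one pass `iloc=slice(0, int(len(...) * 0.95))`, so only the first 95% of each curve's points are drawn. A minimal sketch of that slicing on a plain list, with a hypothetical helper name `trimmed_slice`:

```python
def trimmed_slice(n_points, fraction=0.95):
    """Return a slice that keeps the first `fraction` of n_points positions."""
    return slice(0, int(n_points * fraction))

# Hypothetical timeline with 200 points standing in for a survival function index
timeline = list(range(200))
kept = timeline[trimmed_slice(len(timeline))]
print(len(kept), kept[-1])
```

Because `int()` truncates, short curves lose at least their last point; a curve with 10 points, for example, is drawn with 9.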

In [None]:
# Function to save the plot image
def save_plot_image(KM_analysis_output, filename):
    
    # Make the figure to fill
    plt.figure(figsize=(10, 6))

    if isinstance(KM_analysis_output, dict):
        # Plot the estimates of all KMF objects (only the first 95% of each curve's points are drawn)
        for label, KM_object in KM_analysis_output.items():
            KM_object.plot(label=label, ci_show=CI_checkbox.value, iloc=slice(0, int(len(KM_object.survival_function_) * 0.95)))
    else:
        # Plot the estimate from the single KMF object
        KM_analysis_output.plot(ci_show=CI_checkbox.value, iloc=slice(0, int(len(KM_analysis_output.survival_function_) * 0.95)))
    
    # Customize the plot
    plt.xlabel('Time')
    plt.ylabel('Probability')
    plt.title('Kaplan-Meier Estimates')
    # Place the legend outside the axes when the labels checkbox is ticked
    if plot_labels_checkbox.value:
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    else:
        plt.legend()

    # Save, log and close the plot to release memory
    plt.savefig(filename, dpi=600, bbox_inches='tight')
    logger.info(f"A jpg file containing the KM plot has been saved to the current directory with the name {filename} \n")
    plt.close()  

### Function for flow control

This function ***calls the functions that load the files, pre-process them, and show the widgets***, so all the steps run and all the outputs appear in the last code cell, where this function is called with a single line. This also ensures that, other than installing and importing libraries, nothing runs until the very end. That avoids errors from undefined names, since the functions above are not defined in strict order and call each other.
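
The reason this ordering works is that Python looks up a function name when the call runs, not when the calling function is defined. A minimal sketch with hypothetical names (`start` and `helper` are not part of the notebook):

```python
def start():
    # helper is resolved when start() runs, not when start() is defined
    return helper()

def helper():
    return "widgets ready"

# Calling only after every definition has been executed avoids a NameError
result = start()
print(result)
```

If `start()` were called between the two definitions, it would raise `NameError` because `helper` would not exist yet.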

In [None]:
def start_plotting():

    # Declare variables that need to be used by multiple functions
    global df_clinical, df_RNA, time_to_event, event_observation
    
    # Load the input files
    df_clinical, df_RNA = upload_input_files()

    # Preprocess the uploaded files
    df_clinical, df_RNA, time_to_event, event_observation = file_preprocessing(df_clinical, df_RNA)

    # Generate and display the interactive widgets, which let the user select the data, make the KM plot(s), and save the results
    widget_preparation(df_clinical, df_RNA, time_to_event, event_observation)


## **Start KM-plotting here!!!**

In [None]:
# To begin, select: "Run all cells" and look under this cell
start_plotting()