# **CellTracksColab - Distance to ROI**
---
<font size = 4> This notebook analyzes movement tracks in relation to designated Regions of Interest (ROIs). It computes the distance between each tracked object and the nearest ROI, which may itself change over time. Evaluating these distances over time shows how the tracked objects behave spatially relative to the ROIs. The notebook combines data processing, distance calculations, and statistical analysis so that researchers can study movement patterns and interactions around areas of interest.

---

# **Part 0. Before getting started**
---

<font size = 5>**Important notes**

---

## Data Requirements for Analysis

<font size = 4>Be advised of one significant limitation inherent to this notebook.

<font size = 4 color="red">**This notebook only supports 2D + t datasets**</font>.

---
# Prerequisites for Using This Notebook

<font size = 4>To effectively utilize this notebook for analyzing track distances to Regions of Interest (ROIs), the following prerequisites are essential:
<font size = 4>
1. **DataFrames from CellTracksColab**:
   - Ensure you have the `spots` and `tracks` DataFrames compiled by CellTracksColab.
<font size = 4>
2. **ROI Images in TIF Format**:
   - The ROI images or movies should be in TIF (`.tif`) file format. These images should be mask or label images where the background has a pixel value of 0.
<font size = 4>
3. **Proper Naming of ROI Images**:
   - Adopt a consistent and unique naming convention for your ROIs. This naming should be reflected in both your data analysis and image files.
   - Follow the file naming format: `File_name_ROI_name.tif`. For example, if your tracking file is named 'sample' and the ROI is 'nuclei', the corresponding image file should be named 'sample_nuclei.tif'.
   - Place all ROI image files in the same folder for streamlined access and analysis.
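
The naming convention above can be checked before running the analysis. The following is a minimal sketch, not part of the notebook itself: the helper names `roi_image_path` and `find_missing_roi_images`, and the folder name `ROI_images`, are hypothetical and only illustrate how a path of the form `File_name_ROI_name.tif` is assembled.

```python
from pathlib import Path

def roi_image_path(roi_folder, file_name, roi_name):
    """Build the expected ROI image path: <roi_folder>/<File_name>_<ROI_name>.tif"""
    return Path(roi_folder) / f"{file_name}_{roi_name}.tif"

def find_missing_roi_images(roi_folder, file_names, roi_name):
    """Return the file names whose ROI image does not exist in roi_folder."""
    return [name for name in file_names
            if not roi_image_path(roi_folder, name, roi_name).exists()]

# Hypothetical example values matching the 'sample' / 'nuclei' example above
example_path = roi_image_path("ROI_images", "sample", "nuclei")
print(example_path.name)  # sample_nuclei.tif
```

Calling `find_missing_roi_images` with the unique `File_name` values of your spots table would list the tracking files that have no matching ROI image.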



In [None]:
# @title #MIT License

print("""
**MIT License**

Copyright (c) 2023 Guillaume Jacquemet

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.""")

--------------------------------------------------------
# **Part 1: Prepare the session and load your data**
--------------------------------------------------------


## **1.1. Install key dependencies**
---
<font size = 4>

In [None]:
#@markdown ##Play to install
%pip -q install pandas scikit-learn
%pip -q install hdbscan
%pip -q install umap-learn
%pip -q install plotly
%pip -q install tqdm


In [None]:
#@markdown ##Play to load the dependencies

import os
import ipywidgets as widgets
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
import itertools
from matplotlib.gridspec import GridSpec
import requests
from tqdm.notebook import tqdm  # used by save_dataframe_with_progress below

!pip freeze > requirements.txt

# Current version of the notebook the user is running
current_version = "0.9.1"
Notebook_name = 'Distance_to_ROI'

# URL to the raw content of the version file in the repository
version_url = "https://raw.githubusercontent.com/guijacquemet/CellTracksColab/main/Notebook/latest_version.txt"

# Function to define colors for formatting messages
class bcolors:
    WARNING = '\033[91m'  # Red color for warning messages
    ENDC = '\033[0m'      # Reset color to default

# Check if this is the latest version of the notebook
try:
    All_notebook_versions = pd.read_csv(version_url, dtype=str)
    print('Notebook version: ' + current_version)

    # Check if 'Version' column exists in the DataFrame
    if 'Version' in All_notebook_versions.columns:
        Latest_Notebook_version = All_notebook_versions[All_notebook_versions["Notebook"] == Notebook_name]['Version'].iloc[0]
        print('Latest notebook version: ' + Latest_Notebook_version)

        if current_version == Latest_Notebook_version:
            print("This notebook is up-to-date.")
        else:
            print(bcolors.WARNING + "A new version of this notebook has been released. We recommend that you download it at https://github.com/guijacquemet/CellTracksColab" + bcolors.ENDC)
    else:
        print("The 'Version' column is not present in the version file.")
except requests.exceptions.RequestException as e:
    print("Unable to fetch the latest version information. Please check your internet connection.")
except Exception as e:
    print("An error occurred:", str(e))

#----------------------- Key functions -----------------------------#

# Function to calculate Cohen's d
def cohen_d(group1, group2):
    diff = group1.mean() - group2.mean()
    n1, n2 = len(group1), len(group2)
    var1 = group1.var()
    var2 = group2.var()
    pooled_var = ((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 2)
    d = diff / np.sqrt(pooled_var)
    return d


def save_dataframe_with_progress(df, path, desc="Saving", chunk_size=50000):
    """Save a DataFrame to CSV in chunks, with a progress bar."""

    # Create a tqdm instance for progress tracking
    with tqdm(total=len(df), unit="rows", desc=desc) as pbar:
        # Open the file for writing
        with open(path, "w") as f:
            # Write the header once at the beginning
            df.head(0).to_csv(f, index=False)

            # Write the rows in consecutive slices of at most chunk_size rows
            for start in range(0, len(df), chunk_size):
                chunk = df.iloc[start:start + chunk_size]
                chunk.to_csv(f, header=False, index=False)
                pbar.update(len(chunk))

def check_for_nans(df, df_name):
    """
    Checks the given DataFrame for NaN values and prints the count for each column containing NaNs.

    Args:
    df (pd.DataFrame): DataFrame to be checked for NaN values.
    df_name (str): The name of the DataFrame as a string, used for printing.
    """
    # Check if the DataFrame has any NaN values and print a warning if it does.
    nan_columns = df.columns[df.isna().any()].tolist()

    if nan_columns:
        for col in nan_columns:
            nan_count = df[col].isna().sum()
            print(f"Column '{col}' in {df_name} contains {nan_count} NaN values.")
    else:
        print(f"No NaN values found in {df_name}.")

def save_parameters(params, file_path, param_type):
    # Convert params dictionary to a DataFrame for human readability
    new_params_df = pd.DataFrame(list(params.items()), columns=['Parameter', 'Value'])
    new_params_df['Type'] = param_type

    if os.path.exists(file_path):
        # Read existing file
        existing_params_df = pd.read_csv(file_path)

        # Merge the new parameters with the existing ones
        # Update existing parameters or append new ones
        updated_params_df = pd.merge(existing_params_df, new_params_df,
                                     on=['Type', 'Parameter'],
                                     how='outer',
                                     suffixes=('', '_new'))

        # If there's a new value, update it, otherwise keep the old value
        updated_params_df['Value'] = updated_params_df['Value_new'].combine_first(updated_params_df['Value'])

        # Drop the temporary new value column
        updated_params_df.drop(columns='Value_new', inplace=True)
    else:
        # Use new parameters DataFrame directly if file doesn't exist
        updated_params_df = new_params_df

    # Save the updated DataFrame to CSV
    updated_params_df.to_csv(file_path, index=False)

## **1.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the instructions.

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of the notebook.

In [None]:
#@markdown ##Play the cell to connect your Google Drive to Colab

from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive



## **1.3. Load your CellTracksColab dataset**
---

<font size="4"> Before proceeding, please ensure that your data has been properly processed using CellTracksColab. Typically, your Track table should be named `merged_Tracks.csv`, and your Spot table should be named `merged_Spots.csv`.

For the `Results_Folder` parameter, you can choose the same folder that already contains all the results associated with your dataset. Any results generated by this notebook will be saved in the `Distance_to_ROI` subfolder.



In [None]:
#@markdown ##Provide the path to your dataset:

import os
import re
import glob
import pandas as pd
from tqdm.notebook import tqdm
import numpy as np
import requests
import zipfile

#@markdown ###You have existing dataframes, provide the path to your:

Track_table = ''  # @param {type: "string"}
Spot_table = ''  # @param {type: "string"}

#@markdown ###Provide the path to your Result folder

Results_Folder = ""  # @param {type: "string"}

if not Results_Folder:
    Results_Folder = '/content/Results'  # Default Results_Folder path if not defined

if not os.path.exists(Results_Folder):
    os.makedirs(Results_Folder)  # Create Results_Folder if it doesn't exist

# Print the location of the result folder
print(f"Result folder is located at: {Results_Folder}")

def validate_tracks_df(df):
    """Validate the tracks dataframe for necessary columns and data types."""
    required_columns = ['TRACK_ID']
    for col in required_columns:
        if col not in df.columns:
            print(f"Error: Column '{col}' missing in tracks dataframe.")
            return False

    # Additional data type checks or value ranges can be added here
    return True

def validate_spots_df(df):
    """Validate the spots dataframe for necessary columns and data types."""
    required_columns = ['TRACK_ID', 'POSITION_X', 'POSITION_Y', 'POSITION_T']
    for col in required_columns:
        if col not in df.columns:
            print(f"Error: Column '{col}' missing in spots dataframe.")
            return False

    # Additional data type checks or value ranges can be added here
    return True

def check_unique_id_match(df1, df2):
    df1_ids = set(df1['Unique_ID'])
    df2_ids = set(df2['Unique_ID'])

    # Check if the IDs in the two dataframes match
    if df1_ids == df2_ids:
        print("The Unique_ID values in both dataframes match perfectly!")
    else:
        missing_in_df1 = df2_ids - df1_ids
        missing_in_df2 = df1_ids - df2_ids

        if missing_in_df1:
            print(f"There are {len(missing_in_df1)} Unique_ID values present in the second dataframe but missing in the first.")
            print("Examples of these IDs are:", list(missing_in_df1)[:5])

        if missing_in_df2:
            print(f"There are {len(missing_in_df2)} Unique_ID values present in the first dataframe but missing in the second.")
            print("Examples of these IDs are:", list(missing_in_df2)[:5])

# For existing dataframes
if Track_table:
    print("Loading track table file....")
    merged_tracks_df = pd.read_csv(Track_table, low_memory=False)
    if not validate_tracks_df(merged_tracks_df):
        print("Error: Validation failed for loaded tracks dataframe.")

if Spot_table:
    print("Loading spot table file....")
    merged_spots_df = pd.read_csv(Spot_table, low_memory=False)
    if not validate_spots_df(merged_spots_df):
        print("Error: Validation failed for loaded spots dataframe.")

check_unique_id_match(merged_spots_df, merged_tracks_df)

check_for_nans(merged_spots_df, "merged_spots_df")
check_for_nans(merged_tracks_df, "merged_tracks_df")

# Define the metadata columns that are expected to have identical values for each filename
metadata_columns = ['Condition', 'experiment_nb', 'Repeat']

def check_metadata(df, df_name="DataFrame"):
    consistent_metadata = True
    for name, group in df.groupby('File_name'):
        for col in metadata_columns:
            if not group[col].nunique() == 1:
                consistent_metadata = False
                print(f"Inconsistency found in {df_name} for file: {name} in column: {col}")
                break  # Stop checking other columns for this group
        if not consistent_metadata:
            break  # Stop the entire process if any inconsistency is found

    if consistent_metadata:
        print(f"{df_name} has consistent metadata.")
    else:
        print(f"{df_name} has inconsistencies in the metadata. Please check the output for details.")
    return

check_metadata(merged_tracks_df, "merged_tracks_df")
check_metadata(merged_spots_df, "merged_spots_df")


# **Part 2: Compute the distance to the nearest ROI**

<font size = 4>
This process can be repeated for different types of ROIs, such as 'Nuclei', 'Junctions', etc.
Change the `ROI_name` parameter to match the ROI type you wish to analyze: for instance, set `ROI_name` to 'Nuclei' to analyze nuclei, then change it to 'Junctions' to analyze junctions. Each ROI type is analyzed separately with the same notebook.





## **2.1. Input Parameters**
<font size="4">

**ROI Folder Path**: Specify the directory containing your Region of Interest (ROI) images. These images should be masks or label images where the background has a pixel value of 0, and objects of interest are represented by values greater than 0.

**ROI Name**: Define the unique name for your region of interest (ROI). This name must match the naming convention used in your image files. You can analyze multiple ROIs one at a time by adjusting the `ROI_name` parameter for each ROI image. Image files should adhere to the following format: `File_name_ROI_name.tif`. For example, if your tracking file is named 'sample' and the ROI is 'nuclei', the corresponding image file should be named 'sample_nuclei.tif'.

**Pixel Calibration**: Enter the pixel size in real-world units per pixel (e.g., micrometers per pixel). The notebook multiplies pixel distances by this value to report them in physical units, and divides spot coordinates by it to locate them in the ROI image.
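
As an illustration of how the calibration is applied, here is a minimal sketch. The helper names and the value 0.5 are hypothetical; the arithmetic mirrors what the code cells in this notebook do (coordinates divided by `Pixel_calibration`, pixel distances multiplied by it).

```python
def physical_to_pixel(position, pixel_calibration):
    """Convert a physical coordinate (e.g. micrometers) to an integer pixel index."""
    return int(position / pixel_calibration)

def pixel_to_physical(distance_px, pixel_calibration):
    """Convert a distance in pixels to physical units."""
    return distance_px * pixel_calibration

pixel_calibration = 0.5  # hypothetical value: 0.5 micrometers per pixel
print(physical_to_pixel(12.3, pixel_calibration))  # 24
print(pixel_to_physical(10, pixel_calibration))    # 5.0
```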




In [None]:
ROI_folder = ''  # @param {type: "string"}

ROI_name = ''# @param {type: "string"}

Pixel_calibration = None# @param {type: "number"}

if not os.path.exists(Results_Folder+"/Distance_to_ROI"):
    os.makedirs(Results_Folder+"/Distance_to_ROI")

if not os.path.exists(Results_Folder+"/Distance_to_ROI/"+ROI_name):
    os.makedirs(Results_Folder+"/Distance_to_ROI/"+ROI_name)

distance_column_name = f'DistanceTo{ROI_name}'

if distance_column_name in merged_spots_df.columns:
    print(f"The column '{distance_column_name}' already exists in the DataFrame.")
    print("This indicates that the distance to ROI calculation might have been done previously.")
    print("Computing the distance to ROI again may not be required unless the data has changed.")

print("Done.")

## **2.2. Check that your dataset matches your image files**

1. **Checking and Correcting Coordinates**: This process assesses and corrects the coordinates in the dataset to ensure they fall within the bounds of the associated images or videos.

2. **Zero Pixel Percentage**: For each file in the dataset, the code calculates the percentage of background (zero) pixels in the associated Region of Interest (ROI) image or video. This metric helps evaluate the dataset's quality and whether the coordinates need adjustment.

3. **Correction Process**: If the coordinates are found outside the image or video bounds, the code corrects them, ensuring they are within acceptable limits. Any corrections made are reported in the console.

4. **Supported Data Types**: The code can handle image and video data. It determines the type and dimensions of the data automatically.

5. **Pixel Calibration**: The coordinates are adjusted for the specified pixel calibration, ensuring accurate positioning in real-world units.

By running this code, you validate and, if necessary, correct the coordinates in your dataset to match the dimensions of your ROI images or videos.
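
The bounds correction can be pictured with a small vectorized sketch. This is an illustration under stated assumptions rather than the notebook's own implementation (which loops over rows): the function name `clamp_to_image` and the toy values are hypothetical, and the image extent is taken as `(size - 1) * pixel_calibration`, as in the cell below.

```python
import pandas as pd

def clamp_to_image(spots, width_px, height_px, pixel_calibration):
    """Return a copy of spots with POSITION_X/POSITION_Y clipped to the image extent."""
    max_x = (width_px - 1) * pixel_calibration
    max_y = (height_px - 1) * pixel_calibration
    clamped = spots.copy()
    clamped["POSITION_X"] = clamped["POSITION_X"].clip(lower=0, upper=max_x)
    clamped["POSITION_Y"] = clamped["POSITION_Y"].clip(lower=0, upper=max_y)
    return clamped

# Toy spots: one negative x, one y beyond the image, one x beyond the image
spots = pd.DataFrame({"POSITION_X": [-1.0, 5.2, 60.0],
                      "POSITION_Y": [3.0, 49.9, 10.0]})
clamped = clamp_to_image(spots, width_px=100, height_px=100, pixel_calibration=0.5)
print(clamped)
```

With a 100 x 100 pixel image and a calibration of 0.5, the largest valid coordinate is 49.5, so the out-of-range values are pulled back to 0 or 49.5 while valid coordinates keep their sub-pixel precision.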


In [None]:

# @title ##Check that your dataset matches your image files

from tqdm.notebook import tqdm
import pandas as pd
from skimage import io
import matplotlib.pyplot as plt
from tifffile import imread
from skimage.measure import label, regionprops, find_contours
from scipy.ndimage import distance_transform_edt

def check_zero_pixel_percentage(ROI_img):
    zero_pixel_count = np.sum(ROI_img == 0)
    total_pixel_count = ROI_img.size
    zero_pixel_percentage = (zero_pixel_count / total_pixel_count) * 100
    return zero_pixel_percentage


def check_and_correct_coordinates(df, file_name, image_dir, ROI_name):
    """
    Checks and corrects the coordinates in the DataFrame for a given file to ensure they are within the bounds
    of the associated image or video.

    Parameters:
    df (DataFrame): DataFrame containing the spots' data.
    file_name (str): The name of the file to check and correct.
    image_dir (str): Directory where the images or videos are stored.
    ROI_name (str): Suffix or identifier for the image or video file name.

    Returns:
    DataFrame: Updated DataFrame with corrected coordinates.
    """
    # Load the image or video
    ROI_img_path = f"{image_dir}/{file_name}_{ROI_name}.tif"
    try:
        ROI_img = io.imread(ROI_img_path)
    except FileNotFoundError:
        print(f"Image or video for {file_name} not found.")
        return df
    # Check the percentage of zero pixels
    zero_percentage = check_zero_pixel_percentage(ROI_img)
    print(f"File {file_name}: Percentage of background pixels is {zero_percentage:.2f}%.")

    # Determine if it's an image or a video
    if ROI_img.ndim == 3:  # Video
        max_x, max_y = ROI_img.shape[2] - 1, ROI_img.shape[1] - 1
    elif ROI_img.ndim == 2:  # Image
        max_x, max_y = ROI_img.shape[1] - 1, ROI_img.shape[0] - 1
    else:
        print(f"Unsupported number of dimensions ({ROI_img.ndim}) in the file {file_name}.")
        return df

    # Apply pixel calibration
    max_x, max_y = max_x * Pixel_calibration, max_y * Pixel_calibration

    # Filter dataframe for the current file
    file_df = df[df['File_name'] == file_name]

    # Correct each coordinate with tqdm for progress
    for idx in tqdm(file_df.index, desc=f"Processing {file_name}"):
        # Keep sub-pixel precision: do not truncate the coordinates to integers
        x, y = float(df.at[idx, 'POSITION_X']), float(df.at[idx, 'POSITION_Y'])
        corrected_x = max(0, min(x, max_x))
        corrected_y = max(0, min(y, max_y))
        # Only write back the coordinates that actually needed correcting
        if corrected_x != x or corrected_y != y:
            print(f"Corrected coordinates for index {idx} from (x={x}, y={y}) to (x={corrected_x}, y={corrected_y})")
            df.at[idx, 'POSITION_X'] = corrected_x
            df.at[idx, 'POSITION_Y'] = corrected_y

    return df

# Apply the function to each file in the DataFrame
for file_name in tqdm(merged_spots_df['File_name'].unique(), desc="Checking and correcting coordinates"):
    merged_spots_df = check_and_correct_coordinates(merged_spots_df, file_name, ROI_folder, ROI_name)


## **2.3. Visualise your tracks**
---

In [None]:
# @title ##Run the cell and choose the file you want to inspect

import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt

if not os.path.exists(Results_Folder+"/Tracks"):
    os.makedirs(Results_Folder+"/Tracks")  # Create Results_Folder if it doesn't exist

# Extract unique filenames from the dataframe
filenames = merged_spots_df['File_name'].unique()

# Create a Dropdown widget with the filenames
filename_dropdown = widgets.Dropdown(
    options=filenames,
    value=filenames[0] if len(filenames) > 0 else None,  # Default selected value
    description='File Name:',
)

def plot_coordinates(filename):
    if filename:
        # Filter the DataFrame based on the selected filename
        filtered_df = merged_spots_df[merged_spots_df['File_name'] == filename]

        plt.figure(figsize=(10, 8))
        for unique_id in filtered_df['Unique_ID'].unique():
            unique_df = filtered_df[filtered_df['Unique_ID'] == unique_id].sort_values(by='POSITION_T')
            plt.plot(unique_df['POSITION_X'], unique_df['POSITION_Y'], marker='o', linestyle='-', markersize=2)

        plt.xlabel('POSITION_X')
        plt.ylabel('POSITION_Y')
        plt.title(f'Coordinates for {filename}')
        plt.show()
    else:
        print("No valid filename selected")

# Link the Dropdown widget to the plotting function
interact(plot_coordinates, filename=filename_dropdown)


## **2.4. Compute the distance to the nearest ROI**

By running this code, you compute and record the distances from each data point to the nearest ROI pixel, facilitating further analysis and interpretation of your dataset.
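
To make the approach concrete, the sketch below applies SciPy's `distance_transform_edt` to a toy mask, the same function the cell below relies on. The 5 x 5 mask, the calibration value, and the spot position are hypothetical illustration values.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

# Toy 5x5 mask: background is 0, a single ROI pixel sits at row 2, column 2
roi_mask = np.zeros((5, 5), dtype=np.uint8)
roi_mask[2, 2] = 1

pixel_calibration = 0.5  # hypothetical value, physical units per pixel

# For every background pixel, the Euclidean distance to the nearest ROI pixel,
# converted from pixels to physical units
distance_map = distance_transform_edt(roi_mask == 0) * pixel_calibration

# Look up a spot given in physical units: convert to pixel indices,
# then index the map as [row, column], i.e. [y, x]
x_phys, y_phys = 2.0, 1.0
col, row = int(x_phys / pixel_calibration), int(y_phys / pixel_calibration)
print(distance_map[row, col])
```

The spot lands on row 2, column 4, two pixels away from the ROI pixel, so its distance is 2 x 0.5 = 1.0 in physical units. Spots located on an ROI pixel get a distance of 0.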




In [None]:
# @title ##Compute the distance to the nearest ROI


from tqdm.notebook import tqdm
import pandas as pd
from skimage import io
import matplotlib.pyplot as plt
from tifffile import imread
from skimage.measure import label, regionprops, find_contours
from scipy.ndimage import distance_transform_edt


def compute_distances_using_distance_transform(df, image_dir):
    """
    Compute distances to the nearest labeled pixel for each spot using the distance transform method.
    Automatically detects if the file is a single image or a video sequence and checks if the frame
    number corresponds to the actual number of frames in the video.

    Parameters:
    df (DataFrame): The dataframe containing the spots' data.
    image_dir (str): The directory where the ROI images or videos are stored.
    """
    for file_name in tqdm(df['File_name'].unique(), desc="Processing files"):
        # Paths to the label images or video
        ROI_img_path = f"{image_dir}/{file_name}_{ROI_name}.tif"

        try:
            ROI_img = io.imread(ROI_img_path)
            # Ensure the image dimensions are 3 or below
            if ROI_img.ndim > 3:
                raise ValueError(f"Image file {file_name} has more than 3 dimensions, which is not supported.")

            # Determine if the file is a video by checking if it has more than two dimensions
            is_video = ROI_img.ndim == 3

            # Verify that the 'FRAME' number in the dataframe does not exceed the number of frames in the video
            if is_video and 'FRAME' in df.columns:
                file_df = df[df['File_name'] == file_name]
                # Compute max_frame_num for the current file_name
                max_frame_num = file_df['FRAME'].max()
                num_frames = ROI_img.shape[0]
                # 'FRAME' is matched against 0-based frame indices below, so valid values run from 0 to num_frames - 1
                if max_frame_num >= num_frames:
                    print(f"Error: max_frame_num ({max_frame_num}) is outside the {num_frames} frames available in file {file_name}.")
                    raise ValueError(f"DataFrame contains 'FRAME' numbers that exceed the number of frames in the video for file {file_name}.")
                for frame_idx in range(num_frames):
                    # Process each frame with matching spots
                    process_frame(ROI_img[frame_idx], df, file_name, frame_idx)
            else:
                # Process a single image
                process_frame(ROI_img, df, file_name)

        except FileNotFoundError:
            print(f"Error: Image for {file_name} not found. Skipping...")
            continue
        except ValueError as e:
            print(e)
            break

    return df

def process_frame(ROI_img, df, file_name, frame_idx=None):
    """
    Process a single frame or image and update the dataframe with distance values.

    Parameters:
    ROI_img (ndarray): The ROI image or a single frame from a video.
    df (DataFrame): The dataframe to update.
    file_name (str): The name of the file being processed.
    frame_idx (int, optional): The index of the frame in the video.
    """
    # Compute distance transform
    distance_transform_ROI = distance_transform_edt(ROI_img == 0) * Pixel_calibration

    # Filter dataframe for the current file and frame
    file_df = df[df['File_name'] == file_name]
    if frame_idx is not None:
        file_df = file_df[file_df['FRAME'] == frame_idx]

    for idx, row in tqdm(file_df.iterrows(), total=file_df.shape[0], desc=f"Processing coordinates for {file_name}", leave=False):
        y, x = int(row['POSITION_Y'] / Pixel_calibration), int(row['POSITION_X'] / Pixel_calibration)

        # Check if x and y are within the bounds of the image
        if 0 <= x < distance_transform_ROI.shape[1] and 0 <= y < distance_transform_ROI.shape[0]:
            # idx is an index label returned by iterrows(), so address the row by label
            df.loc[idx, f'DistanceTo{ROI_name}'] = distance_transform_ROI[y, x]
        else:
            print(f"Warning: Coordinates (x={x}, y={y}) out of bounds for {file_name}")

compute_distances_using_distance_transform(merged_spots_df, ROI_folder)

save_dataframe_with_progress(merged_spots_df, Results_Folder + '/' + 'merged_Spots.csv', desc="Saving Spots")

## **2.5. Check for missing values**

By running this code, you can quickly identify which files in your dataset contain spots with missing distance values. Missing values most likely indicate that the corresponding ROI image was not found, so this check is a useful data quality control.
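
The same check can be sketched on a toy table. The function name `files_with_missing_distances` and the data are hypothetical; the column name follows the `DistanceTo{ROI_name}` pattern used in this notebook.

```python
import numpy as np
import pandas as pd

def files_with_missing_distances(spots, distance_column):
    """Return the sorted File_name values with at least one NaN in distance_column."""
    has_nan = spots[distance_column].isna()
    return sorted(spots.loc[has_nan, "File_name"].unique())

# Toy data: only file 'b' has a spot without a distance value
spots = pd.DataFrame({
    "File_name": ["a", "a", "b", "c"],
    "DistanceTonuclei": [1.0, 2.5, np.nan, 0.0],
})
print(files_with_missing_distances(spots, "DistanceTonuclei"))  # ['b']
```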



In [None]:

# @title ##Check for NaN


def print_filenames_with_nan_distances(spots_df, ROI_name):
    nan_filenames = []

    grouped_spots = spots_df.groupby('Unique_ID')

    for unique_id, group in grouped_spots:
        if group[f'DistanceTo{ROI_name}'].isna().any():
            # Store filenames associated with NaN distances
            nan_filenames.extend(group['File_name'].unique())

    # Print unique filenames with NaN distances
    unique_nan_filenames = set(nan_filenames)
    print(f"Filenames with NaN distances: {unique_nan_filenames}")

# Usage
print_filenames_with_nan_distances(merged_spots_df, ROI_name)


## **2.6. Visual validation**

1. **Filename Dropdown**: Select a specific filename from the dropdown list. This corresponds to the tracked data associated with the ROI you want to analyze.

2. **Display Option**: Choose between 'All' or 'Specific number' to control the number of points displayed on the plot.

3. **Number of Points**: If you select 'Specific number' in the display option, enter the desired number of points to display.

4. **Save as PDF**: Check this box if you want to save the generated plot as a PDF file.

5. **Visualize Distances Button**: Click this button to visualize the distances measured for the selected ROI and filename.

The resulting plot will display the ROI image with distances represented by red circles around the points of interest. Each circle's radius corresponds to the distance measurement in real-world units (e.g., micrometers), taking into account the specified pixel calibration.

If the 'Save as PDF' option is checked, the plot will be saved as a PDF file in the designated results folder.

Use this interface to visually inspect and validate the correctness of the measured distances for your ROI data.



In [None]:
# @title ##Run to check visually that the distances measured are correct.
from ipywidgets import Button, interactive, IntSlider, widgets, fixed
from ipywidgets import Output
from IPython.display import clear_output
import numpy as np
import matplotlib.pyplot as plt
from tifffile import imread
import matplotlib.backends.backend_pdf


error_output = Output()

def display_error_message(message):
    with error_output:
        clear_output(wait=False)
        print(f"Error: {message}")

filename_dropdown = widgets.Dropdown(
    options=merged_spots_df['File_name'].unique(),
    description='Filename:',
    disabled=False
)

# Additional widgets for user input
display_option = widgets.Dropdown(
    options=['All', 'Specific number'],
    value='All',
    description='Display:',
    disabled=False
)

number_of_points = widgets.BoundedIntText(
    value=1,
    min=1,
    max=1,  # Updated to the per-file maximum once a file is selected
    step=1,
    description='Number of points:',
    disabled=False
)

# Checkbox for saving plot as PDF
save_as_pdf_checkbox = widgets.Checkbox(
    value=False,
    description='Save as PDF',
    disabled=False
)

# Display these widgets immediately
display(filename_dropdown)

def update_display(frame_number, ROI_img, data_for_frame, filename, display_mode, point_count, save_pdf):
    plt.figure(figsize=(8, 8))

    frame = ROI_img.copy()
    if frame.ndim == 3:  # If the image is a video
        frame = frame[frame_number]

    # The slider is built from the 'FRAME' column, so filter on the same column
    coords_for_frame = data_for_frame[data_for_frame['FRAME'] == frame_number]
    if coords_for_frame.empty:
        print(f"No data available for frame {frame_number} in file {filename}.")
    else:
        plt.imshow(frame, cmap='gray')

        plt.xlim(0, frame.shape[1])
        plt.ylim(frame.shape[0], 0)

        if display_mode == 'All':
            points_to_display = coords_for_frame
        else:
            points_to_display = coords_for_frame.head(point_count)

        for idx, row in points_to_display.iterrows():
            x, y = int(row['POSITION_X']/Pixel_calibration), int(row['POSITION_Y']/Pixel_calibration)
            distance_to_edge = row[f'DistanceTo{ROI_name}']/Pixel_calibration
            circle_edge = plt.Circle((x, y), distance_to_edge, color='red', fill=False, linewidth=1)
            plt.gca().add_patch(circle_edge)

        plt.scatter(points_to_display['POSITION_X']/Pixel_calibration, points_to_display['POSITION_Y']/Pixel_calibration, c='yellow', s=50, zorder=3)


    plt.title(f"Frame {frame_number} for {filename}")

    if save_pdf:
        pdf_filename = f"{Results_Folder}/Distance_to_ROI/{ROI_name}/{filename}_frame_{frame_number}.pdf"
        plt.savefig(pdf_filename, format='pdf')
        print(f"Plot saved as '{pdf_filename}'")
    plt.show()

def visualize_precomputed_distances_for_filename(filename):
    ROI_img_path = f"{ROI_folder}/{filename}_{ROI_name}.tif"
    # Calculate the maximum number of points for the selected file
    max_points_for_file = merged_spots_df[merged_spots_df['File_name'] == filename].groupby('FRAME').size().max()

    # Update the maximum value for the number_of_points widget
    number_of_points.max = max_points_for_file

    try:
        ROI_img = imread(ROI_img_path)
    except FileNotFoundError:
        display_error_message(f"Image for {filename} not found.")
        return

    data_for_frame = merged_spots_df[merged_spots_df['File_name'] == filename]

    max_frame = data_for_frame['FRAME'].max()
    frame_slider = widgets.IntSlider(min=0, max=max_frame, description='Frame')

    # Modified call to include new widgets
    w = interactive(update_display,
                    frame_number=frame_slider,
                    ROI_img=fixed(ROI_img),
                    data_for_frame=fixed(data_for_frame),
                    filename=fixed(filename),
                    display_mode=display_option,
                    point_count=number_of_points,
                    save_pdf=save_as_pdf_checkbox)

    display(w)

# Button to trigger visualization
plot_button_filename = Button(description="Visualize Distances", button_style='info')

# Function to handle button click for filename visualization
def on_plot_button_filename_click(b):
    filename = filename_dropdown.value
    print("In progress...")
    # Clear the previous output
    clear_output()

    # Call the visualization function without redisplaying the widgets
    visualize_precomputed_distances_for_filename(filename)

# Bind the function to the button click event
plot_button_filename.on_click(on_plot_button_filename_click)

# Display the button and error output
display(plot_button_filename)
display(error_output)


## **2.5. Compute general track metrics associated with the ROI**


1. **MaxDistance_{ROI_name}**:
   - The maximum distance of the track from the ROI during the tracking period.
   - Indicates the farthest point reached relative to the ROI.

2. **MinDistance_{ROI_name}**:
   - The minimum distance of the track from the ROI during the tracking period.
   - Represents the closest approach to the ROI.

3. **StartDistance_{ROI_name}** and **EndDistance_{ROI_name}**:
   - Distances from the ROI at the start and end of the tracking period, respectively.
   - Useful for understanding initial and final positioning relative to the ROI.

4. **MedianDistance_{ROI_name}**:
   - The median of all recorded distances to the ROI.
   - Provides a central tendency measure, less affected by outliers than the mean.

5. **StdDevDistance_{ROI_name}**:
   - Standard deviation of the distances.
   - Indicates the variability or consistency of the track's distance from the ROI.

6. **DirectionMovement_{ROI_name}**:
   - Calculated as `EndDistance - StartDistance`.
   - A positive value indicates moving away from the ROI over time, and a negative value suggests moving closer.

7. **AvgRateChange_{ROI_name}**:
   - Average rate of change in distance per frame.
   - Helps assess the speed of movement towards or away from the ROI.

8. **PercentageChange_{ROI_name}**:
   - Percentage change in distance from the start to the end of the track.
   - Normalizes the movement relative to the initial distance.

9. **TrendSlope_{ROI_name}**:
   - Slope of a linear regression line fitted to the distance values over time.
   - Indicates the general trend of movement (increasing or decreasing distance).

## Interpretation
- These metrics are calculated considering the distance to the closest ROI at each time point. If the ROI is moving, the metrics reflect the relative motion between the track and the ROI.
- The closest ROI to a track at each time point may change if there are multiple ROIs. This factor is inherently considered in the distance calculations.
- It is essential to consider the movement of both the track and the ROI when interpreting these metrics. For example, a decreasing distance over time could mean the track is moving towards the ROI, the ROI is moving towards the track, or both.
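
As an illustration of how the trend-related metrics relate to each other, the sketch below computes them for a single hypothetical track. The distance values are invented for the example, and the rate of change is divided by the number of recorded points, which is the normalization used in the compute cell of this section.

```python
import numpy as np
import pandas as pd
from scipy.stats import linregress

# Hypothetical distances (calibrated units) from one track to the ROI, one value per frame
distances = pd.Series([10.0, 8.0, 7.0, 5.0, 4.0])

start_distance = distances.iloc[0]
end_distance = distances.iloc[-1]

# Negative value: the track ends closer to the ROI than it started
direction_of_movement = end_distance - start_distance
# Net change divided by the number of recorded points
avg_rate_change = direction_of_movement / len(distances)
# Net change relative to the starting distance
percentage_change = direction_of_movement / start_distance * 100
# Slope of a straight line fitted to distance versus frame index
trend_slope = linregress(np.arange(len(distances)), distances).slope

print(direction_of_movement, avg_rate_change, percentage_change, trend_slope)
```

Here all four values are negative, which is consistent with a track that approaches the ROI (or an ROI that approaches the track) over the recording.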


In [None]:
# @title ##Run to compute the metrics.


from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
from scipy.stats import linregress

def get_distances_and_metrics(track_df, spots_df, ROI_name):
    results = []

    grouped_spots = spots_df.groupby('Unique_ID')

    for _, track in tqdm(track_df.iterrows(), total=track_df.shape[0], desc="Processing Tracks"):
        unique_id = track['Unique_ID']

        if unique_id in grouped_spots.groups:
            track_spots = grouped_spots.get_group(unique_id)
            distances = track_spots[f'DistanceTo{ROI_name}']

            if distances.empty or distances.isna().all():
                # No usable distances: set every metric (including std dev and slope) to NaN
                max_distance = min_distance = start_distance = end_distance = median_distance = std_dev_distance = average_rate_of_change = percentage_change = direction_of_movement = slope = np.nan

            else:
              # Basic metrics
              max_distance = distances.max(skipna=True)
              min_distance = distances.min(skipna=True)
              start_distance = distances.iloc[0] if not distances.empty else np.nan
              end_distance = distances.iloc[-1] if not distances.empty else np.nan
              median_distance = distances.median(skipna=True)
              std_dev_distance = distances.std(skipna=True)

              # Advanced metrics
              direction_of_movement = end_distance - start_distance
              average_rate_of_change = direction_of_movement / len(distances) if len(distances) > 0 else np.nan
              percentage_change = (direction_of_movement / start_distance * 100) if start_distance != 0 else np.nan

              # Linear regression to determine trend
              slope, _, _, _, _ = linregress(range(len(distances)), distances) if not distances.empty else (np.nan,)*5

            results.append({
                'Unique_ID': unique_id,
                f'MaxDistance_{ROI_name}': max_distance,
                f'MinDistance_{ROI_name}': min_distance,
                f'StartDistance_{ROI_name}': start_distance,
                f'EndDistance_{ROI_name}': end_distance,
                f'MedianDistance_{ROI_name}': median_distance,
                f'StdDevDistance_{ROI_name}': std_dev_distance,
                f'DirectionMovement_{ROI_name}': direction_of_movement,
                f'AvgRateChange_{ROI_name}': average_rate_of_change,
                f'PercentageChange_{ROI_name}': percentage_change,
                f'TrendSlope_{ROI_name}': slope
            })

    return pd.DataFrame(results)

# Using the function with your DataFrames
distances_metrics_df = get_distances_and_metrics(merged_tracks_df, merged_spots_df, ROI_name)

# Merging Process
overlapping_columns = merged_tracks_df.columns.intersection(distances_metrics_df.columns).drop('Unique_ID')
merged_tracks_df.drop(columns=overlapping_columns, inplace=True)
merged_tracks_df = pd.merge(merged_tracks_df, distances_metrics_df, on='Unique_ID', how='left')

# Save the updated DataFrame
save_dataframe_with_progress(merged_tracks_df, Results_Folder + '/' + 'merged_Tracks.csv')


--------
# **Part 3. Quality Control**
--------

      



## **3.1. Assess if your dataset is balanced**
---

In cell tracking and similar biological analyses, the balance of the dataset is important, particularly in ensuring that each biological repeat carries equal weight. Here's why this balance is essential:

### Accurate Representation of Biological Variability

- **Capturing True Biological Variation**: Biological repeats are crucial for capturing the natural variability inherent in biological systems. Equal weighting ensures that this variability is accurately represented.
- **Reducing Sampling Bias**: By balancing the dataset, we avoid overemphasizing the characteristics of any single repeat, which might not be representative of the broader biological context.

If your dataset is strongly imbalanced, check that the imbalance does not bias your results, for example by repeating the analysis on the downsampled (balanced) dataset produced in section 4.2.
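
A quick way to judge the balance is to count the unique tracks in each condition and repeat and compare the largest group to the smallest. The sketch below does this on a small invented table; the column names match those used in this notebook, but the data and the `imbalance_ratio` summary are assumptions made for illustration.

```python
import pandas as pd

# Hypothetical track table: one row per track
toy_tracks = pd.DataFrame({
    'Condition': ['Control', 'Control', 'Control', 'Treated', 'Treated', 'Treated', 'Treated'],
    'Repeat':    [1, 1, 2, 1, 2, 2, 2],
    'Unique_ID': ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
})

# Number of unique tracks in each (Condition, Repeat) group
counts = toy_tracks.groupby(['Condition', 'Repeat'])['Unique_ID'].nunique()

# Ratio between the largest and the smallest group (1.0 means perfectly balanced)
imbalance_ratio = counts.max() / counts.min()

print(counts)
print(f"Largest group is {imbalance_ratio:.1f}x the smallest")
```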



In [None]:
# @title ##Check the number of tracks per condition and repeat

import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# Create the QC folder if it does not exist yet
if not os.path.exists(f"{Results_Folder}/QC"):
    os.makedirs(f"{Results_Folder}/QC")

def count_tracks_by_condition_and_repeat(df, Results_Folder, condition_col='Condition', repeat_col='Repeat', track_id_col='Unique_ID'):
    """
    Counts the number of unique tracks for each combination of condition and repeat in the given DataFrame and
    saves a stacked histogram plot as a PDF in the QC folder with annotations for each stack.

    Parameters:
    df (pandas.DataFrame): The DataFrame containing the data.
    Results_Folder (str): The base folder where the results will be saved.
    condition_col (str): The name of the column representing the condition. Default is 'Condition'.
    repeat_col (str): The name of the column representing the repeat. Default is 'Repeat'.
    track_id_col (str): The name of the column representing the track ID. Default is 'Unique_ID'.
    """
    track_counts = df.groupby([condition_col, repeat_col])[track_id_col].nunique()
    track_counts_df = track_counts.reset_index()
    track_counts_df.rename(columns={track_id_col: 'Number_of_Tracks'}, inplace=True)

    # Pivot the data for plotting
    pivot_df = track_counts_df.pivot(index=condition_col, columns=repeat_col, values='Number_of_Tracks').fillna(0)

    # Plotting
    fig, ax = plt.subplots(figsize=(12, 6))
    bars = pivot_df.plot(kind='bar', stacked=True, ax=ax)
    ax.set_xlabel('Condition')
    ax.set_ylabel('Number of Tracks')
    ax.set_title('Stacked Histogram of Track Counts per Condition and Repeat')
    ax.legend(title=repeat_col)
    ax.grid(axis='y', linestyle='--')

    # Hide horizontal grid lines
    ax.yaxis.grid(False)

    # Add number annotations on each stack
    for bar in bars.patches:
        ax.text(bar.get_x() + bar.get_width() / 2,
                bar.get_y() + bar.get_height() / 2,
                int(bar.get_height()),
                ha='center', va='center', color='black', fontweight='bold', fontsize=8)

    # Save the plot as a PDF
    pdf_file = os.path.join(Results_Folder, 'Track_Counts_Histogram.pdf')
    plt.savefig(pdf_file, bbox_inches='tight')
    print(f"Saved histogram to {pdf_file}")

    plt.show()

    return track_counts_df

result_df = count_tracks_by_condition_and_repeat(merged_tracks_df, f"{Results_Folder}/QC")




## **3.2. Compute Similarity Metrics between Fields of View (FOVs) and between Conditions and Repeats**
---

<font size = 4>**Purpose**:

<font size = 4>This section provides a set of tools to compute and visualize similarities between different fields of view (FOVs) based on selected track parameters. By leveraging hierarchical clustering, the resulting dendrogram offers a clear visualization of how different FOVs, conditions, or repeats relate to one another. This tool is essential for:

<font size = 4>1. **Quality Control**:
    - Ensuring that FOVs from the same condition or experimental setup are more similar to each other than to FOVs from different conditions.
    - Confirming that repeats of the same experiment yield consistent results and cluster together.
    
<font size = 4>2. **Data Integrity**:
    - Identifying potential outliers or anomalies in the dataset.
    - Assessing the overall consistency of the experiment and ensuring reproducibility.

<font size = 4>**How to Use**:

<font size = 4>1. **Track Parameters Selection**:
    - A list of checkboxes allows users to select which track parameters they want to consider for similarity calculations. By default, all parameters are selected. Users can deselect parameters that they believe might not contribute significantly to the similarity.

<font size = 4>2. **Similarity Metric**:
    - Users can choose a similarity metric from a dropdown list. Options include cosine, euclidean, cityblock, jaccard, and correlation. The choice of similarity metric can influence the clustering results, so users might need to experiment with different metrics to see which one provides the most meaningful results.

<font size = 4>3. **Linkage Method**:
    - Determines how the distance between clusters is calculated in the hierarchical clustering process. Different linkage methods can produce different dendrograms, so users might want to try various methods.

<font size = 4>4. **Visualization**:
    - Once the parameters are selected, users can click on the "Select the track parameters and visualize similarity" button. This will compute the hierarchical clustering and display two dendrograms:
        - One dendrogram displays similarities between individual FOVs.
        - Another dendrogram aggregates the data based on conditions and repeats, providing a higher-level view of the similarities.
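
<font size = 4>The steps above can be sketched on a toy example. The per-FOV mean values, FOV names, and the choice of the `euclidean` metric with `average` linkage are assumptions made for illustration; the cell in this section applies the same SciPy functions to your own data with the options you select.

```python
import pandas as pd
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, dendrogram

# Hypothetical per-FOV mean values of two track parameters
fov_means = pd.DataFrame(
    {'speed': [1.0, 1.1, 3.0, 3.2], 'directionality': [0.2, 0.25, 0.8, 0.75]},
    index=['ctrl_fov1', 'ctrl_fov2', 'treated_fov1', 'treated_fov2'],
)

# Condensed pairwise distance matrix between FOVs
condensed = pdist(fov_means.values, metric='euclidean')

# Hierarchical clustering; each row of `linked` records one merge
linked = linkage(condensed, method='average')

# no_plot=True returns the leaf order without drawing the figure
tree = dendrogram(linked, labels=fov_means.index.tolist(), no_plot=True)
print(tree['ivl'])
```

<font size = 4>In this toy dataset the two control FOVs merge first and the two treated FOVs merge next, which is the pattern expected when FOVs from the same condition are more similar to each other than to FOVs from another condition.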
      


In [None]:
# @title ##Compute similarity metrics between FOV and between conditions and repeats

import os
import pandas as pd
import numpy as np
from scipy.spatial.distance import cosine
from scipy.cluster.hierarchy import linkage, dendrogram
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
import warnings
import ipywidgets as widgets
from sklearn.metrics import pairwise_distances
from scipy.spatial.distance import pdist

# Check and create "QC" folder
if not os.path.exists(f"{Results_Folder}/QC"):
    os.makedirs(f"{Results_Folder}/QC")

# Columns to exclude
excluded_columns = ['Condition', 'experiment_nb', 'File_name', 'Repeat', 'Unique_ID', 'LABEL', 'TRACK_INDEX', 'TRACK_ID', 'TRACK_X_LOCATION', 'TRACK_Y_LOCATION', 'TRACK_Z_LOCATION', 'Exemplar','TRACK_STOP', 'TRACK_START', 'Cluster_UMAP', 'Cluster_tsne']

selected_df = pd.DataFrame()

# Filter out non-numeric columns but keep 'File_name'
numeric_df = merged_tracks_df.select_dtypes(include=['float64', 'int64']).copy()
numeric_df['File_name'] = merged_tracks_df['File_name']

# Create a list of column names excluding 'File_name'
column_names = [col for col in numeric_df.columns if col not in excluded_columns]

# Create a checkbox for each column
checkboxes = [widgets.Checkbox(value=True, description=col, indent=False) for col in column_names]

# Dropdown for similarity metrics
similarity_dropdown = widgets.Dropdown(
    options=['cosine', 'euclidean', 'cityblock', 'jaccard', 'correlation'],
    value='cosine',
    description='Similarity Metric:'
)

# Dropdown for linkage methods
linkage_dropdown = widgets.Dropdown(
    options=['single', 'complete', 'average', 'ward'],
    value='single',
    description='Linkage Method:'
)

# Arrange checkboxes in a 2x grid
grid = widgets.GridBox(checkboxes, layout=widgets.Layout(grid_template_columns="repeat(2, 300px)"))

# Create a button to trigger the selection and visualization
button = widgets.Button(description="Select the track parameters and visualize similarity", layout=widgets.Layout(width='400px'), button_style='info')

# Define the button click event handler
def on_button_click(b):
    global selected_df  # Declare selected_df as global

    # Get the selected columns from the checkboxes
    selected_columns = [box.description for box in checkboxes if box.value]
    selected_columns.append('File_name')  # Always include 'File_name'

    # Extract the selected columns from the DataFrame
    selected_df = numeric_df[selected_columns]

    # Check and print the percentage of NaNs for each selected column
    any_nan = False
    for column in selected_columns:
        if selected_df[column].isna().any():
            if not any_nan:
                print("Warning: NaN values found in the selected data.")
            any_nan = True
            nan_percentage = selected_df[column].isna().mean() * 100
            print(f"{column}: {nan_percentage:.2f}%")

    if any_nan:
        # Drop every row that contains a NaN in any selected column
        print("Proceeding to handle NaN values.")
        selected_df = selected_df.dropna()
    else:
        print("No NaN values found in the selected columns.")

    # Aggregate the data by filename
    aggregated_by_filename = selected_df.groupby('File_name').mean(numeric_only=True)

    # Aggregate the data by condition and repeat
    aggregated_by_condition_repeat = merged_tracks_df.groupby(['Condition', 'Repeat'])[selected_columns].mean(numeric_only=True)

    # Compute condensed distance matrices
    distance_matrix_filename = pdist(aggregated_by_filename, metric=similarity_dropdown.value)
    distance_matrix_condition_repeat = pdist(aggregated_by_condition_repeat, metric=similarity_dropdown.value)

    # Perform hierarchical clustering
    linked_filename = linkage(distance_matrix_filename, method=linkage_dropdown.value)
    linked_condition_repeat = linkage(distance_matrix_condition_repeat, method=linkage_dropdown.value)

    annotation_text = f"Similarity Method: {similarity_dropdown.value}, Linkage Method: {linkage_dropdown.value}"

    # Prepare the parameters dictionary
    similarity_params = {
        'Similarity Metric': similarity_dropdown.value,
        'Linkage Method': linkage_dropdown.value,
        'Selected Columns': ', '.join(selected_columns)
    }

    # Save the parameters
    params_file_path = os.path.join(Results_Folder, "QC/analysis_parameters.csv")
    save_parameters(similarity_params, params_file_path, 'Similarity Metrics')

    # Plot the dendrograms one under the other
    plt.figure(figsize=(10, 10))

    # Dendrogram for individual filenames
    plt.subplot(2, 1, 1)
    dendrogram(linked_filename, labels=aggregated_by_filename.index, orientation='top', distance_sort='descending', leaf_rotation=90)
    plt.title(f'Dendrogram of Field of view Similarities\n{annotation_text}')

    # Dendrogram for aggregated data based on condition and repeat
    plt.subplot(2, 1, 2)
    dendrogram(linked_condition_repeat, labels=aggregated_by_condition_repeat.index, orientation='top', distance_sort='descending', leaf_rotation=90)
    plt.title(f'Dendrogram of Aggregated Similarities by Condition and Repeat\n{annotation_text}')

    plt.tight_layout()

    # Save the dendrogram to a PDF
    pdf_pages = PdfPages(f"{Results_Folder}/QC/Dendrogram_Similarities.pdf")

    # Save the current figure to the PDF
    pdf_pages.savefig()

    # Close the PdfPages object to finalize the document
    pdf_pages.close()

    plt.show()

# Set the button click event handler
button.on_click(on_button_click)

# Display the widgets
display(grid, similarity_dropdown, linkage_dropdown, button)


-------------------------------------------

# **Part 4. Plot track parameters**
-------------------------------------------

<font size = 4> In this section you can plot all the track parameters previously computed. Data and graphs are automatically saved in your result folder.

<b>Note on Units:</b> The parameters plotted are in the unit of measurement you used when tracking your data.
</font>

<font size="4" color="red">
<b>Results Storage:</b>
Results generated in this section are saved in the sub-folder named `track_parameters_plots` within your `Results_Folder`.


##**Statistical analyses**
### Cohen's d (Effect Size):
<font size = 4>Cohen's d measures the size of the difference between two groups, normalized by their pooled standard deviation. By Cohen's conventional benchmarks, values around 0.2, 0.5, and 0.8 correspond to small, medium, and large effects, respectively. It quantifies how large the observed difference is, independently of whether it is statistically significant.
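
<font size = 4>As a minimal sketch of this definition, the function below computes Cohen's d from the pooled standard deviation of two samples. The name `cohen_d_sketch` and the toy values are assumptions for illustration; the `cohen_d` function used by the plotting cells of this notebook may be implemented differently.

```python
import numpy as np

def cohen_d_sketch(group1, group2):
    """Difference of means divided by the pooled standard deviation (illustrative)."""
    g1 = np.asarray(group1, dtype=float)
    g2 = np.asarray(group2, dtype=float)
    n1, n2 = len(g1), len(g2)
    # Pooled variance weights each sample variance by its degrees of freedom
    pooled_var = ((n1 - 1) * g1.var(ddof=1) + (n2 - 1) * g2.var(ddof=1)) / (n1 + n2 - 2)
    return (g1.mean() - g2.mean()) / np.sqrt(pooled_var)

# Two toy samples whose means differ by 1 and whose pooled standard deviation is 2
d = cohen_d_sketch([2.0, 4.0, 6.0], [1.0, 3.0, 5.0])
print(d)
```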

### Randomization Test:
<font size = 4>This non-parametric test evaluates whether observed differences between conditions could have arisen by random chance. It shuffles the condition labels multiple times, recalculating Cohen's d each time. The resulting p-value is the fraction of shuffles that produce an effect size at least as large as the observed one; a smaller p-value implies stronger evidence against the null hypothesis of no difference between conditions.
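
<font size = 4>The shuffling procedure can be sketched as follows. This is an illustrative version under stated assumptions: the function name `randomization_p_value`, the fixed `seed`, and the simulated groups are not part of the notebook, and the effect size is computed with a simple pooled-standard-deviation formula.

```python
import numpy as np

def randomization_p_value(group1, group2, n_iterations=1000, seed=0):
    """Fraction of label shuffles whose |Cohen's d| is at least the observed one."""
    rng = np.random.default_rng(seed)
    g1 = np.asarray(group1, dtype=float)
    g2 = np.asarray(group2, dtype=float)

    def abs_d(a, b):
        # Absolute Cohen's d using the pooled standard deviation
        pooled_var = ((len(a) - 1) * a.var(ddof=1) + (len(b) - 1) * b.var(ddof=1)) / (len(a) + len(b) - 2)
        return abs(a.mean() - b.mean()) / np.sqrt(pooled_var)

    observed = abs_d(g1, g2)
    combined = np.concatenate([g1, g2])

    count_extreme = 0
    for _ in range(n_iterations):
        # Shuffle the pooled values and split them into two groups of the original sizes
        shuffled = rng.permutation(combined)
        if abs_d(shuffled[:len(g1)], shuffled[len(g1):]) >= observed:
            count_extreme += 1
    return count_extreme / n_iterations

# Simulated groups with clearly different means
data_rng = np.random.default_rng(1)
group_a = data_rng.normal(0.0, 1.0, 30)
group_b = data_rng.normal(2.0, 1.0, 30)
p_different = randomization_p_value(group_a, group_b)
print(p_different)
```

<font size = 4>Note that with 1000 iterations the smallest non-zero p-value that can be reported is 0.001.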

### Bonferroni Correction:
<font size = 4>Given multiple comparisons, the Bonferroni Correction adjusts significance thresholds to mitigate the risk of false positives. By dividing the standard significance level (alpha) by the number of tests, it ensures that only robust findings are considered significant. However, it's worth noting that this method can be conservative, sometimes overlooking genuine effects.
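
<font size = 4>A small worked example, using invented condition names and p-values, shows the two equivalent ways of applying the correction: comparing each raw p-value to `alpha / num_comparisons`, or multiplying each p-value by the number of comparisons (capped at 1) and comparing it to `alpha`.

```python
from itertools import combinations

# Hypothetical conditions and raw p-values for each pairwise comparison
conditions = ['Control', 'DrugA', 'DrugB']
raw_p = {
    ('Control', 'DrugA'): 0.004,
    ('Control', 'DrugB'): 0.03,
    ('DrugA', 'DrugB'): 0.40,
}

num_comparisons = len(list(combinations(conditions, 2)))
alpha = 0.05
corrected_alpha = alpha / num_comparisons

# Option 1: compare raw p-values to the corrected threshold
significant = {pair: p < corrected_alpha for pair, p in raw_p.items()}

# Option 2: adjust the p-values and compare them to the original alpha
adjusted_p = {pair: min(p * num_comparisons, 1.0) for pair, p in raw_p.items()}

print(significant)
print(adjusted_p)
```

<font size = 4>In this example only the first comparison remains significant after correction, although the second one has a raw p-value below 0.05.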

## **4.1. Plot your entire dataset**
--------

In [None]:
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import pandas as pd
from scipy.stats import zscore

# @title ##Plot normalized track parameters by condition as a heatmap (entire dataset)

# Parameters to adapt depending on the notebook section
base_folder = f"{Results_Folder}/track_parameters_plots"
Conditions = 'Condition'
df_to_plot = merged_tracks_df

# Check and create necessary directories
folders = ["pdf", "csv"]
for folder in folders:
    dir_path = os.path.join(base_folder, folder)
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)


def get_selectable_columns(df):
    # Exclude certain columns from being plotted
    exclude_cols = ['Condition', 'experiment_nb', 'File_name', 'Repeat', 'Unique_ID', 'LABEL', 'TRACK_INDEX', 'TRACK_ID', 'TRACK_X_LOCATION', 'TRACK_Y_LOCATION', 'TRACK_Z_LOCATION', 'Exemplar','TRACK_STOP', 'TRACK_START', 'Cluster_UMAP', 'Cluster_tsne']
    # Select only numerical columns
    return [col for col in df.columns if (df[col].dtype.kind in 'biufc') and (col not in exclude_cols)]


def heatmap_comparison(df, Results_Folder, Conditions, variables_per_page=40):
    # Get all the selectable columns
    variables_to_plot = get_selectable_columns(df)

    # Drop rows where all elements are NaNs in the variables_to_plot columns
    df = df.dropna(subset=variables_to_plot, how='all')

    # Compute median for each variable across Conditions
    median_values = df.groupby(Conditions)[variables_to_plot].median().transpose()

    # Normalize the median values using Z-score
    normalized_values = median_values.apply(zscore, axis=1)

    # Number of pages
    total_variables = len(variables_to_plot)
    num_pages = int(np.ceil(total_variables / variables_per_page))

    # Initialize an empty DataFrame to store all pages' data
    all_pages_data = pd.DataFrame()

    # Create a PDF file to save the heatmaps
    with PdfPages(f"{Results_Folder}/Heatmaps_Normalized_Median_Values_by_Condition.pdf") as pdf:
        for page in range(num_pages):
            start = page * variables_per_page
            end = min(start + variables_per_page, total_variables)
            page_data = normalized_values.iloc[start:end]

            # Append this page's data to the all_pages_data DataFrame
            all_pages_data = pd.concat([all_pages_data, page_data])

            plt.figure(figsize=(16, 10))
            sns.heatmap(page_data, cmap='coolwarm', annot=True, linewidths=.1)
            plt.title(f"Z-score Normalized Median Values of Variables by Condition (Page {page + 1})")
            plt.tight_layout()

            pdf.savefig()  # saves the current figure into a pdf page
            plt.show()
            plt.close()

    # Save all pages data to a single CSV file
    all_pages_data.to_csv(f"{Results_Folder}/Normalized_Median_Values_by_Condition.csv")

    print(f"Heatmaps saved to {Results_Folder}/Heatmaps_Normalized_Median_Values_by_Condition.pdf")
    print(f"All data saved to {Results_Folder}/Normalized_Median_Values_by_Condition.csv")

# Example usage
heatmap_comparison(merged_tracks_df, base_folder, Conditions)


In [None]:
# @title ##Plot track parameters (entire dataset)

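import itertools
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import Layout, VBox, Button, Accordion, SelectMultiple, IntText
from matplotlib.gridspec import GridSpec
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.ticker import FixedLocator

# Note: `cohen_d` is not defined in this cell and is assumed to come from an earlier cell

# Parameters to adapt depending on the notebook section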
base_folder = f"{Results_Folder}/track_parameters_plots"
Conditions = 'Condition'
df_to_plot = merged_tracks_df

# Check and create necessary directories
folders = ["pdf", "csv"]
for folder in folders:
    dir_path = os.path.join(base_folder, folder)
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

def get_selectable_columns(df):
    # Exclude certain columns from being plotted
    exclude_cols = ['Condition', 'experiment_nb', 'File_name', 'Repeat', 'Unique_ID', 'LABEL', 'TRACK_INDEX', 'TRACK_ID', 'TRACK_X_LOCATION', 'TRACK_Y_LOCATION', 'TRACK_Z_LOCATION', 'Exemplar','TRACK_STOP', 'TRACK_START', 'Cluster_UMAP', 'Cluster_tsne']
    # Select only numerical columns
    return [col for col in df.columns if (df[col].dtype.kind in 'biufc') and (col not in exclude_cols)]


def display_variable_checkboxes(selectable_columns):
    # Create checkboxes for selectable columns
    variable_checkboxes = [widgets.Checkbox(value=False, description=col) for col in selectable_columns]

    # Display checkboxes in the notebook
    display(widgets.VBox([
        widgets.Label('Variables to Plot:'),
        widgets.GridBox(variable_checkboxes, layout=widgets.Layout(grid_template_columns="repeat(%d, 300px)" % 3)),
    ]))
    return variable_checkboxes

def create_condition_selector(df, column_name):
    conditions = df[column_name].unique()
    condition_selector = SelectMultiple(
        options=conditions,
        description='Conditions:',
        disabled=False,
        layout=Layout(width='100%')  # Adjusting the layout width
    )
    return condition_selector

def display_condition_selection(df, column_name):
    condition_selector = create_condition_selector(df, column_name)

    condition_accordion = Accordion(children=[VBox([condition_selector])])
    condition_accordion.set_title(0, 'Select Conditions')
    display(condition_accordion)
    return condition_selector


def plot_selected_vars(button, variable_checkboxes, df, Conditions, Results_Folder, condition_selector):

    plt.clf()  # Clear the current figure before creating a new plot
    print("Plotting in progress...")

  # Get selected variables
    variables_to_plot = [box.description for box in variable_checkboxes if box.value]
    n_plots = len(variables_to_plot)

    if n_plots == 0:
        print("No variables selected for plotting")
        return

  # Get selected conditions
    selected_conditions = condition_selector.value
    n_selected_conditions = len(selected_conditions)

    if n_selected_conditions == 0:
        print("No conditions selected for plotting")
        return

# Use only selected and ordered conditions
    filtered_df = df[df[Conditions].isin(selected_conditions)].copy()

# Initialize matrices to store effect sizes and p-values for each variable
    effect_size_matrices = {}
    p_value_matrices = {}
    bonferroni_matrices = {}

    unique_conditions = filtered_df[Conditions].unique().tolist()
    num_comparisons = len(unique_conditions) * (len(unique_conditions) - 1) // 2
    alpha = 0.05
    # Guard against a division by zero when a single condition is selected
    corrected_alpha = alpha / num_comparisons if num_comparisons > 0 else alpha
    n_iterations = 1000

# Loop through each variable to plot
    for var in variables_to_plot:

      pdf_pages = PdfPages(f"{Results_Folder}/pdf/{var}_Boxplots_and_Statistics.pdf")
      effect_size_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)
      p_value_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)
      bonferroni_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)

      for cond1, cond2 in itertools.combinations(unique_conditions, 2):
        group1 = df[df[Conditions] == cond1][var]
        group2 = df[df[Conditions] == cond2][var]

        original_d = abs(cohen_d(group1, group2))
        effect_size_matrix.loc[cond1, cond2] = original_d
        effect_size_matrix.loc[cond2, cond1] = original_d  # Mirroring

        count_extreme = 0
        for i in range(n_iterations):
            combined = pd.concat([group1, group2])
            shuffled = combined.sample(frac=1, replace=False).reset_index(drop=True)
            new_group1 = shuffled[:len(group1)]
            new_group2 = shuffled[len(group1):]

            new_d = abs(cohen_d(new_group1, new_group2))
            if np.abs(new_d) >= np.abs(original_d):
                count_extreme += 1

        p_value = count_extreme / n_iterations
        p_value_matrix.loc[cond1, cond2] = p_value
        p_value_matrix.loc[cond2, cond1] = p_value  # Mirroring

        # Apply Bonferroni correction
        bonferroni_corrected_p_value = min(p_value * num_comparisons, 1.0)
        bonferroni_matrix.loc[cond1, cond2] = bonferroni_corrected_p_value
        bonferroni_matrix.loc[cond2, cond1] = bonferroni_corrected_p_value  # Mirroring

      effect_size_matrices[var] = effect_size_matrix
      p_value_matrices[var] = p_value_matrix
      bonferroni_matrices[var] = bonferroni_matrix

    # Concatenate the three matrices side-by-side
      combined_df = pd.concat(
        [
            effect_size_matrices[var].rename(columns={col: f"{col} (Effect Size)" for col in effect_size_matrices[var].columns}),
            p_value_matrices[var].rename(columns={col: f"{col} (P-Value)" for col in p_value_matrices[var].columns}),
            bonferroni_matrices[var].rename(columns={col: f"{col} (Bonferroni-corrected P-Value)" for col in bonferroni_matrices[var].columns})
        ], axis=1
    )

    # Save the combined DataFrame to a CSV file
      combined_df.to_csv(f"{Results_Folder}/csv/{var}_statistics_combined.csv")

    # Create a new figure
      fig = plt.figure(figsize=(16, 10))

    # Create a gridspec for 2 rows and 3 columns
      gs = GridSpec(2, 3, height_ratios=[1.5, 1])

    # Create the ax for boxplot using the gridspec
      ax_box = fig.add_subplot(gs[0, :])

    # Extract the data for this variable
      data_for_var = df[[Conditions, var, 'Repeat', 'File_name' ]]

    # Save the data_for_var to a CSV for replotting
      data_for_var.to_csv(f"{Results_Folder}/csv/{var}_boxplot_data.csv", index=False)

    # Calculate the Interquartile Range (IQR) using the 25th and 75th percentiles
      Q1 = df[var].quantile(0.25)
      Q3 = df[var].quantile(0.75)
      IQR = Q3 - Q1

    # Define bounds for the outliers
      multiplier = 10
      lower_bound = Q1 - multiplier * IQR
      upper_bound = Q3 + multiplier * IQR

    # Plotting
      sns.boxplot(x=Conditions, y=var, data=filtered_df, ax=ax_box, color='lightgray')  # Boxplot
      sns.stripplot(x=Conditions, y=var, data=filtered_df, ax=ax_box, hue='Repeat', dodge=True, jitter=True, alpha=0.2)  # Individual data points
      ax_box.set_ylim([max(min(filtered_df[var]), lower_bound), min(max(filtered_df[var]), upper_bound)])
      ax_box.set_title(f"{var}")
      ax_box.set_xlabel('Condition')
      ax_box.set_ylabel(var)
      tick_labels = ax_box.get_xticklabels()
      tick_locations = ax_box.get_xticks()
      ax_box.xaxis.set_major_locator(FixedLocator(tick_locations))
      ax_box.set_xticklabels(tick_labels, rotation=90)
      ax_box.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='Repeat')

    # Statistical Analyses and Heatmaps

    # Effect Size heatmap ax
      ax_d = fig.add_subplot(gs[1, 0])
      sns.heatmap(effect_size_matrices[var].fillna(0), annot=True, cmap="viridis", cbar=True, square=True, ax=ax_d, vmax=1)
      ax_d.set_title(f"Effect Size (Cohen's d) for {var}")

    # p-value heatmap ax
      ax_p = fig.add_subplot(gs[1, 1])
      sns.heatmap(p_value_matrices[var].fillna(1), annot=True, cmap="viridis_r", cbar=True, square=True, ax=ax_p, vmax=0.1)
      ax_p.set_title(f"Randomization Test p-value for {var}")

    # Bonferroni corrected p-value heatmap ax
      ax_bonf = fig.add_subplot(gs[1, 2])
      sns.heatmap(bonferroni_matrices[var].fillna(1), annot=True, cmap="viridis_r", cbar=True, square=True, ax=ax_bonf, vmax=0.1)
      ax_bonf.set_title(f"Bonferroni-corrected p-value for {var}")

      plt.tight_layout()
      pdf_pages.savefig(fig)

    # Close the PDF
      pdf_pages.close()

condition_selector = display_condition_selection(df_to_plot, Conditions)
selectable_columns = get_selectable_columns(df_to_plot)
variable_checkboxes = display_variable_checkboxes(selectable_columns)

button = Button(description="Plot Selected Variables", layout=Layout(width='400px'), button_style='info')
button.on_click(lambda b: plot_selected_vars(b, variable_checkboxes, df_to_plot, Conditions, base_folder, condition_selector))
display(button)

## **4.2. Plot a balanced dataset**
--------

## **4.2.1. Downsample your dataset to ensure that it is balanced**
--------

### Downsampling and Balancing Dataset

This section addresses imbalances in the dataset, which can otherwise bias comparisons between conditions. The cell below downsamples the dataset so that every condition and repeat combination contains the same number of tracks (the size of the smallest group). For reproducibility, the sampling uses a `random_seed` parameter, which is set to 42 by default and can be adjusted as needed.

All results from this section will be saved in the Balanced Dataset Directory created in your `Results_Folder`.
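
The balancing idea can be sketched on toy data as follows. This is a minimal, library-free illustration rather than the notebook's own implementation: the helper name `balance_tracks` and the toy records are hypothetical, and plain Python lists stand in for the pandas DataFrame used in the cell below.

```python
import random
from collections import defaultdict

def balance_tracks(tracks, seed=42):
    """Downsample every (condition, repeat) group to the smallest group's size."""
    groups = defaultdict(list)
    for track in tracks:
        groups[(track["Condition"], track["Repeat"])].append(track)

    # The smallest group sets the number of tracks kept in every group
    min_count = min(len(members) for members in groups.values())
    rng = random.Random(seed)  # a fixed seed makes the selection reproducible

    balanced = []
    for key in sorted(groups):
        balanced.extend(rng.sample(groups[key], min_count))
    return balanced

# Hypothetical toy data: 5 control tracks and 2 treated tracks in repeat 1
toy_tracks = (
    [{"Unique_ID": f"ctrl_{i}", "Condition": "control", "Repeat": 1} for i in range(5)]
    + [{"Unique_ID": f"drug_{i}", "Condition": "treated", "Repeat": 1} for i in range(2)]
)
balanced_toy = balance_tracks(toy_tracks, seed=42)
print(len(balanced_toy))  # 4: two tracks kept per group
```

Because the smallest group determines how many tracks every other group keeps, a single sparse condition or repeat can discard a large share of the data.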




In [None]:
# @title ##Run this cell to downsample and balance your dataset

random_seed = 42  # @param {type: "number"}

if not os.path.exists(f"{Results_Folder}/Balanced_dataset"):
    os.makedirs(f"{Results_Folder}/Balanced_dataset")


def count_tracks_by_condition_and_repeat(df, Results_Folder, condition_col='Condition', repeat_col='Repeat', track_id_col='Unique_ID'):
    """
    Counts the number of unique tracks for each combination of condition and repeat in the given DataFrame and
    saves a stacked histogram plot, annotated with the count in each stack, as a PDF in `Results_Folder`.

    Parameters:
    df (pandas.DataFrame): The DataFrame containing the data.
    Results_Folder (str): The base folder where the results will be saved.
    condition_col (str): The name of the column representing the condition. Default is 'Condition'.
    repeat_col (str): The name of the column representing the repeat. Default is 'Repeat'.
    track_id_col (str): The name of the column representing the track ID. Default is 'Unique_ID'.
    """
    track_counts = df.groupby([condition_col, repeat_col])[track_id_col].nunique()
    track_counts_df = track_counts.reset_index()
    track_counts_df.rename(columns={track_id_col: 'Number_of_Tracks'}, inplace=True)

    # Pivot the data for plotting
    pivot_df = track_counts_df.pivot(index=condition_col, columns=repeat_col, values='Number_of_Tracks').fillna(0)

    # Plotting
    fig, ax = plt.subplots(figsize=(12, 6))
    bars = pivot_df.plot(kind='bar', stacked=True, ax=ax)
    ax.set_xlabel('Condition')
    ax.set_ylabel('Number of Tracks')
    ax.set_title('Stacked Histogram of Track Counts per Condition and Repeat')
    ax.legend(title=repeat_col)
    ax.grid(axis='y', linestyle='--')

    # Hide horizontal grid lines
    ax.yaxis.grid(False)

    # Add number annotations on each stack
    for bar in bars.patches:
        ax.text(bar.get_x() + bar.get_width() / 2,
                bar.get_y() + bar.get_height() / 2,
                int(bar.get_height()),
                ha='center', va='center', color='black', fontweight='bold', fontsize=8)

    # Save the plot as a PDF
    pdf_file = os.path.join(Results_Folder, 'Track_Counts_Histogram.pdf')
    plt.savefig(pdf_file, bbox_inches='tight')
    print(f"Saved histogram to {pdf_file}")

    plt.show()

    return track_counts_df

def balance_dataset(df, condition_col='Condition', repeat_col='Repeat', track_id_col='Unique_ID', random_seed=None):
    """
    Balances the dataset by downsampling tracks for each condition and repeat combination.

    Parameters:
    df (pandas.DataFrame): The DataFrame containing the data.
    condition_col (str): The name of the column representing the condition.
    repeat_col (str): The name of the column representing the repeat.
    track_id_col (str): The name of the column representing the track ID.
    random_seed (int, optional): The seed for the random number generator. Default is None.

    Returns:
    pandas.DataFrame: A new DataFrame with balanced track counts.
    """
    # Group by condition and repeat, and find the minimum track count
    min_track_count = df.groupby([condition_col, repeat_col])[track_id_col].nunique().min()

    # Function to sample min_track_count tracks from each group
    def sample_tracks(group):
        return group.sample(n=min_track_count, random_state=random_seed)

    # Apply sampling to each group and concatenate the results
    balanced_merged_tracks_df = df.groupby([condition_col, repeat_col]).apply(sample_tracks).reset_index(drop=True)

    return balanced_merged_tracks_df

balanced_merged_tracks_df = balance_dataset(merged_tracks_df, random_seed=random_seed)
result_df = count_tracks_by_condition_and_repeat(balanced_merged_tracks_df, f"{Results_Folder}/Balanced_dataset")

check_for_nans(balanced_merged_tracks_df, "balanced_merged_tracks_df")
save_dataframe_with_progress(balanced_merged_tracks_df, Results_Folder + '/Balanced_dataset/merged_Tracks_balanced_dataset.csv')


## **4.2.2. Check if the downsampling has affected data distribution**
--------

This section of the notebook generates a heatmap visualizing the Kolmogorov-Smirnov (KS) p-values for each numerical column in the dataset, comparing the distributions before and after downsampling. This heatmap serves as a tool for assessing the impact of downsampling on data quality, guiding decisions on whether the downsampled dataset is suitable for further analysis.

#### Purpose of the Heatmap
- **KS Test:** The KS test is used to determine if two samples are drawn from the same distribution. In this context, it compares the distribution of each numerical column in the original dataset (`merged_tracks_df`) with its counterpart in the downsampled dataset (`balanced_merged_tracks_df`).
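- **P-Value Interpretation:** The p-value is the probability of observing a difference between the two samples at least as large as the one measured, assuming both come from the same distribution. A higher p-value means the test found no evidence that the distributions differ; it does not prove that they are identical.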

#### Interpreting the Heatmap
- **Color Coding:** The heatmap uses the viridis color gradient to represent the range of p-values, with the color scale capped at 0.5. Lighter colors indicate higher p-values.
- **P-Value Thresholds:**
  - **High P-Values (Lighter Areas):** Indicate that the downsampling process likely did not significantly alter the distribution of that numerical column for the specific condition-repeat group.
  - **Low P-Values (Darker Areas):** Suggest that the downsampling process may have affected the distribution significantly.
- **Varying P-Values:** Variations in color across different columns and rows help identify which specific numerical columns and condition-repeat groups are most affected by the downsampling.
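
To make the comparison concrete, the sketch below computes the KS statistic (the largest gap between two empirical cumulative distributions) for a random subsample and for a deliberately shifted copy. It is a library-free illustration under stated assumptions: the function names and toy data are hypothetical, and the p-value uses a common large-sample approximation, which may differ from the value that `scipy.stats.ks_2samp` reports in the cell below.

```python
import math
import random
from bisect import bisect_right

def ks_statistic(sample_a, sample_b):
    """Largest gap between the two empirical cumulative distribution functions."""
    a, b = sorted(sample_a), sorted(sample_b)
    largest_gap = 0.0
    for x in a + b:
        cdf_a = bisect_right(a, x) / len(a)
        cdf_b = bisect_right(b, x) / len(b)
        largest_gap = max(largest_gap, abs(cdf_a - cdf_b))
    return largest_gap

def ks_asymptotic_p_value(d, n, m, terms=100):
    """Large-sample approximation of the two-sided KS p-value (illustration only)."""
    root = math.sqrt(n * m / (n + m))
    lam = (root + 0.12 + 0.11 / root) * d
    if lam < 0.05:
        return 1.0  # the series converges poorly for tiny values; p is close to 1 here
    series = sum((-1) ** (k - 1) * math.exp(-2.0 * (k * lam) ** 2)
                 for k in range(1, terms + 1))
    return min(1.0, max(0.0, 2.0 * series))

# Hypothetical toy data standing in for one numerical column of one group
rng = random.Random(0)
original = [rng.gauss(0.0, 1.0) for _ in range(500)]
subsample = rng.sample(original, 100)            # mimics random downsampling
shifted = [value + 3.0 for value in subsample]   # deliberately altered distribution

for label, other in [("subsample", subsample), ("shifted", shifted)]:
    d = ks_statistic(original, other)
    p = ks_asymptotic_p_value(d, len(original), len(other))
    print(f"{label}: D={d:.3f}, p~{p:.3f}")
```

A random subsample is expected to give a small statistic and a high p-value, whereas the shifted copy gives a large statistic and a p-value near zero; the latter is the pattern that would show up as a dark cell in the heatmap.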




In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp
from matplotlib.backends.backend_pdf import PdfPages  # used below to write the heatmaps to a PDF

# @title ##Check if your downsampling has affected your data distribution

def calculate_ks_p_value(df1, df2, column):
    """
    Calculate the KS p-value for a given column between two dataframes.

    Parameters:
    df1 (pandas.DataFrame): Original DataFrame.
    df2 (pandas.DataFrame): DataFrame after downsampling.
    column (str): Column name to compare.

    Returns:
    float: KS p-value.
    """
    return ks_2samp(df1[column].dropna(), df2[column].dropna())[1]

# Identify numerical columns
numerical_columns = merged_tracks_df.select_dtypes(include=['int64', 'float64']).columns

# Initialize a DataFrame to store KS p-values
ks_p_values = pd.DataFrame(columns=numerical_columns)

# Iterate over each group and numerical column
for group, group_df in merged_tracks_df.groupby(['Condition', 'Repeat']):
    group_p_values = []
    balanced_group_df = balanced_merged_tracks_df[(balanced_merged_tracks_df['Condition'] == group[0]) & (balanced_merged_tracks_df['Repeat'] == group[1])]
    for column in numerical_columns:
        p_value = calculate_ks_p_value(group_df, balanced_group_df, column)
        group_p_values.append(p_value)
    ks_p_values.loc[f'Condition: {group[0]}, Repeat: {group[1]}'] = group_p_values

# Maximum number of columns per heatmap
max_columns_per_heatmap = 20

# Total number of columns
total_columns = len(ks_p_values.columns)

# Calculate the number of heatmaps needed
num_heatmaps = -(-total_columns // max_columns_per_heatmap)  # Ceiling division

# File path for the PDF
pdf_filepath = Results_Folder+'/Balanced_dataset/p-Value Heatmap.pdf'

# Create a PDF file
with PdfPages(pdf_filepath) as pdf:
    # Loop through each subset of columns and create a heatmap
    for i in range(num_heatmaps):
        start_col = i * max_columns_per_heatmap
        end_col = min(start_col + max_columns_per_heatmap, total_columns)

        # Subset of columns for this heatmap
        subset_columns = ks_p_values.columns[start_col:end_col]

        # Create the heatmap for the subset of columns
        plt.figure(figsize=(12, 8))
        sns.heatmap(ks_p_values[subset_columns], cmap='viridis', vmax=0.5, vmin=0)
        plt.title(f'Kolmogorov-Smirnov P-Value Heatmap (Columns {start_col+1} to {end_col})')
        plt.xlabel('Numerical Columns')
        plt.ylabel('Condition-Repeat Groups')
        plt.tight_layout()

        # Save the current figure to the PDF
        pdf.savefig()
        plt.show()
        plt.close()

print(f"Saved all heatmaps to {pdf_filepath}")

# Save the p-values to a CSV file
ks_p_values.to_csv(Results_Folder + '/Balanced_dataset/ks_p_values.csv')
print("Saved KS p-values to ks_p_values.csv")


## **4.2.3. Plot your balanced dataset**
--------

In [None]:
# @title ##Plot track parameters (balanced dataset)

import ipywidgets as widgets
from ipywidgets import Layout, VBox, Button, Accordion, SelectMultiple, IntText
import pandas as pd
import os
from matplotlib.ticker import FixedLocator


# Parameters to adapt depending on the notebook section
base_folder = f"{Results_Folder}/Balanced_dataset/track_parameters_plots"
Conditions = 'Condition'
df_to_plot = balanced_merged_tracks_df

# Check and create necessary directories
folders = ["pdf", "csv"]
for folder in folders:
    dir_path = os.path.join(base_folder, folder)
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

def get_selectable_columns(df):
    # Exclude certain columns from being plotted
    exclude_cols = ['Condition', 'experiment_nb', 'File_name', 'Repeat', 'Unique_ID', 'LABEL', 'TRACK_INDEX', 'TRACK_ID', 'TRACK_X_LOCATION', 'TRACK_Y_LOCATION', 'TRACK_Z_LOCATION', 'Exemplar','TRACK_STOP', 'TRACK_START', 'Cluster_UMAP', 'Cluster_tsne']
    # Select only numerical columns
    return [col for col in df.columns if (df[col].dtype.kind in 'biufc') and (col not in exclude_cols)]


def display_variable_checkboxes(selectable_columns):
    # Create checkboxes for selectable columns
    variable_checkboxes = [widgets.Checkbox(value=False, description=col) for col in selectable_columns]

    # Display checkboxes in the notebook
    display(widgets.VBox([
        widgets.Label('Variables to Plot:'),
        widgets.GridBox(variable_checkboxes, layout=widgets.Layout(grid_template_columns="repeat(%d, 300px)" % 3)),
    ]))
    return variable_checkboxes

def create_condition_selector(df, column_name):
    conditions = df[column_name].unique()
    condition_selector = SelectMultiple(
        options=conditions,
        description='Conditions:',
        disabled=False,
        layout=Layout(width='100%')  # Adjusting the layout width
    )
    return condition_selector

def display_condition_selection(df, column_name):
    condition_selector = create_condition_selector(df, column_name)

    condition_accordion = Accordion(children=[VBox([condition_selector])])
    condition_accordion.set_title(0, 'Select Conditions')
    display(condition_accordion)
    return condition_selector


def plot_selected_vars(button, variable_checkboxes, df, Conditions, Results_Folder, condition_selector):

    plt.clf()  # Clear the current figure before creating a new plot
    print("Plotting in progress...")

  # Get selected variables
    variables_to_plot = [box.description for box in variable_checkboxes if box.value]
    n_plots = len(variables_to_plot)

    if n_plots == 0:
        print("No variables selected for plotting")
        return

  # Get selected conditions
    selected_conditions = condition_selector.value
    n_selected_conditions = len(selected_conditions)

    if n_selected_conditions == 0:
        print("No conditions selected for plotting")
        return

# Use only selected and ordered conditions
    filtered_df = df[df[Conditions].isin(selected_conditions)].copy()

# Initialize matrices to store effect sizes and p-values for each variable
    effect_size_matrices = {}
    p_value_matrices = {}
    bonferroni_matrices = {}

    unique_conditions = filtered_df[Conditions].unique().tolist()
    num_comparisons = len(unique_conditions) * (len(unique_conditions) - 1) // 2
    alpha = 0.05
    corrected_alpha = alpha / num_comparisons
    n_iterations = 1000

# Loop through each variable to plot
    for var in variables_to_plot:

      pdf_pages = PdfPages(f"{Results_Folder}/pdf/{var}_Boxplots_and_Statistics.pdf")
      effect_size_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)
      p_value_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)
      bonferroni_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)

      for cond1, cond2 in itertools.combinations(unique_conditions, 2):
        group1 = df[df[Conditions] == cond1][var]
        group2 = df[df[Conditions] == cond2][var]

        original_d = abs(cohen_d(group1, group2))
        effect_size_matrix.loc[cond1, cond2] = original_d
        effect_size_matrix.loc[cond2, cond1] = original_d  # Mirroring

        count_extreme = 0
        for i in range(n_iterations):
            combined = pd.concat([group1, group2])
            shuffled = combined.sample(frac=1, replace=False).reset_index(drop=True)
            new_group1 = shuffled[:len(group1)]
            new_group2 = shuffled[len(group1):]

            new_d = abs(cohen_d(new_group1, new_group2))
            if np.abs(new_d) >= np.abs(original_d):
                count_extreme += 1

        p_value = count_extreme / n_iterations
        p_value_matrix.loc[cond1, cond2] = p_value
        p_value_matrix.loc[cond2, cond1] = p_value  # Mirroring

        # Apply Bonferroni correction
        bonferroni_corrected_p_value = min(p_value * num_comparisons, 1.0)
        bonferroni_matrix.loc[cond1, cond2] = bonferroni_corrected_p_value
        bonferroni_matrix.loc[cond2, cond1] = bonferroni_corrected_p_value  # Mirroring

      effect_size_matrices[var] = effect_size_matrix
      p_value_matrices[var] = p_value_matrix
      bonferroni_matrices[var] = bonferroni_matrix

    # Concatenate the three matrices side-by-side
      combined_df = pd.concat(
        [
            effect_size_matrices[var].rename(columns={col: f"{col} (Effect Size)" for col in effect_size_matrices[var].columns}),
            p_value_matrices[var].rename(columns={col: f"{col} (P-Value)" for col in p_value_matrices[var].columns}),
            bonferroni_matrices[var].rename(columns={col: f"{col} (Bonferroni-corrected P-Value)" for col in bonferroni_matrices[var].columns})
        ], axis=1
    )

    # Save the combined DataFrame to a CSV file
      combined_df.to_csv(f"{Results_Folder}/csv/{var}_statistics_combined.csv")

    # Create a new figure
      fig = plt.figure(figsize=(16, 10))

    # Create a gridspec for 2 rows and 3 columns
      gs = GridSpec(2, 3, height_ratios=[1.5, 1])

    # Create the ax for boxplot using the gridspec
      ax_box = fig.add_subplot(gs[0, :])

    # Extract the data for this variable
      data_for_var = df[[Conditions, var, 'Repeat', 'File_name' ]]

    # Save the data_for_var to a CSV for replotting
      data_for_var.to_csv(f"{Results_Folder}/csv/{var}_boxplot_data.csv", index=False)

    # Calculate the Interquartile Range (IQR) using the 25th and 75th percentiles
      Q1 = df[var].quantile(0.25)
      Q3 = df[var].quantile(0.75)
      IQR = Q3 - Q1

    # Define bounds for the outliers
      multiplier = 10
      lower_bound = Q1 - multiplier * IQR
      upper_bound = Q3 + multiplier * IQR

    # Plotting
      sns.boxplot(x=Conditions, y=var, data=filtered_df, ax=ax_box, color='lightgray')  # Boxplot
      sns.stripplot(x=Conditions, y=var, data=filtered_df, ax=ax_box, hue='Repeat', dodge=True, jitter=True, alpha=0.2)  # Individual data points
      ax_box.set_ylim([max(min(filtered_df[var]), lower_bound), min(max(filtered_df[var]), upper_bound)])
      ax_box.set_title(f"{var}")
      ax_box.set_xlabel('Condition')
      ax_box.set_ylabel(var)
      tick_labels = ax_box.get_xticklabels()
      tick_locations = ax_box.get_xticks()
      ax_box.xaxis.set_major_locator(FixedLocator(tick_locations))
      ax_box.set_xticklabels(tick_labels, rotation=90)
      ax_box.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='Repeat')

    # Statistical Analyses and Heatmaps

    # Effect Size heatmap ax
      ax_d = fig.add_subplot(gs[1, 0])
      sns.heatmap(effect_size_matrices[var].fillna(0), annot=True, cmap="viridis", cbar=True, square=True, ax=ax_d, vmax=1)
      ax_d.set_title(f"Effect Size (Cohen's d) for {var}")

    # p-value heatmap ax
      ax_p = fig.add_subplot(gs[1, 1])
      sns.heatmap(p_value_matrices[var].fillna(1), annot=True, cmap="viridis_r", cbar=True, square=True, ax=ax_p, vmax=0.1)
      ax_p.set_title(f"Randomization Test p-value for {var}")

    # Bonferroni corrected p-value heatmap ax
      ax_bonf = fig.add_subplot(gs[1, 2])
      sns.heatmap(bonferroni_matrices[var].fillna(1), annot=True, cmap="viridis_r", cbar=True, square=True, ax=ax_bonf, vmax=0.1)
      ax_bonf.set_title(f"Bonferroni-corrected p-value for {var}")

      plt.tight_layout()
      pdf_pages.savefig(fig)

    # Close the PDF
      pdf_pages.close()

condition_selector = display_condition_selection(df_to_plot, Conditions)
selectable_columns = get_selectable_columns(df_to_plot)
variable_checkboxes = display_variable_checkboxes(selectable_columns)

button = Button(description="Plot Selected Variables", layout=Layout(width='400px'), button_style='info')
button.on_click(lambda b: plot_selected_vars(b, variable_checkboxes, df_to_plot, Conditions, base_folder, condition_selector))
display(button)

# **Part 5. Classify your tracks by distance and plot your dataset**
--------
<font size = 4>
<b>Note on Units:</b> The parameters plotted are in the unit of measurement you used when tracking your data.
</font>


<font size="4" color="red">
<b>Results Storage:</b> The results generated in this section are saved in the sub-folder named `Distance_to_ROI` located within your `Results_Folder`.


In [None]:
# @title ##Check the distribution of your dataset

import seaborn as sns
import matplotlib.pyplot as plt

ROI_name = 'edge'  #@param {type:"string"}

def plot_tracks_vs_distance(dataframe, ROI_name):
    sns.set(style="whitegrid")
    plt.figure(figsize=(20, 6))

    # Creating a histogram with a bin size of 10
    ax = sns.histplot(data=dataframe, x=f'MaxDistance_{ROI_name}', bins=range(0, int(dataframe[f'MaxDistance_{ROI_name}'].max()) + 10, 10), kde=False)

    plt.title(f'Number of Tracks vs Max Distance to {ROI_name}')
    plt.xlabel(f'Distance to {ROI_name}')
    plt.ylabel('Number of Tracks')

    # Set x-ticks and rotate the labels for better readability
    plt.xticks(range(0, int(dataframe[f'MaxDistance_{ROI_name}'].max()) + 10, 10), rotation=90, ha='right')

    plt.tight_layout()  # Adjust the layout to accommodate label sizes
    plt.show()

# Example usage
plot_tracks_vs_distance(merged_tracks_df, ROI_name)  # Replace 'ROI_name' with the actual ROI name


## **5.1. Classify your tracks by distance**
--------
This section classifies tracks based on their distance from a specified Region of Interest (ROI).

## Input Parameters
1. `distance_threshold` (type: number): The distance, in the units of your data, that defines the threshold for classification. Tracks whose maximum distance to the ROI (the `MaxDistance_<ROI_name>` value computed previously) is less than or equal to this threshold are classified as 'Close', and all others as 'Far'.
2. `ROI_name` (type: string): The name of the Region of Interest. This is used to label the classification columns in the DataFrame.
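
The thresholding rule described above can be sketched for individual tracks as follows. The helper `classify_track` and the toy distances are hypothetical and only illustrate the rule; the cell below applies the same comparison to the `MaxDistance_<ROI_name>` column of the DataFrame.

```python
def classify_track(max_distance, distance_threshold, roi_name):
    """Return the class label for one track from its maximum distance to the ROI."""
    if max_distance <= distance_threshold:
        return f"Close_{roi_name}"
    return f"Far_{roi_name}"

# Hypothetical maximum distances for three tracks, in the tracking units
toy_max_distances = {"track_a": 12.0, "track_b": 75.0, "track_c": 140.5}
labels = {
    track_id: classify_track(distance, distance_threshold=75, roi_name="edge")
    for track_id, distance in toy_max_distances.items()
}
print(labels)
# {'track_a': 'Close_edge', 'track_b': 'Close_edge', 'track_c': 'Far_edge'}
```

The comparison is inclusive, so a track exactly at the threshold counts as 'Close'. A missing (NaN) distance fails the `<=` comparison and therefore ends up in the 'Far' class, which is worth checking before interpreting the counts.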

In [None]:
# Imports
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np

# Function to classify tracks by distance
def classify_tracks_by_distance(dataframe, distance_threshold, ROI_name):
    classification_column = f'Track_Classification_{ROI_name}'
    dataframe[classification_column] = dataframe.apply(
        lambda row: f'Close_{ROI_name}' if row[f'MaxDistance_{ROI_name}'] <= distance_threshold else f'Far_{ROI_name}', axis=1)
    dataframe[f'Track_Classification_Condition_{ROI_name}'] = dataframe['Condition']+'_'+dataframe[classification_column]
    return dataframe

# Google Colab form fields
distance_threshold = 75  #@param {type:"number"}
ROI_name = 'edge'  #@param {type:"string"}

# Classifying tracks
classified_tracks_df = classify_tracks_by_distance(merged_tracks_df, distance_threshold, ROI_name)

# Displaying and saving results
close_tracks_count = len(classified_tracks_df[classified_tracks_df[f'Track_Classification_{ROI_name}'] == f'Close_{ROI_name}'])
far_tracks_count = len(classified_tracks_df) - close_tracks_count

print(f"Classification of tracks based on distance to {ROI_name}:")
print(f"Close to {ROI_name}: {close_tracks_count}")
print(f"Far from {ROI_name}: {far_tracks_count}")

# Save the updated DataFrame
save_dataframe_with_progress(merged_tracks_df, Results_Folder + '/' + 'merged_Tracks.csv')

## **5.2. Plot your entire dataset**
--------

In [None]:
# @title ##Plot track parameters (entire dataset)

import ipywidgets as widgets
from ipywidgets import Layout, VBox, Button, Accordion, SelectMultiple, IntText
import pandas as pd
import os
from matplotlib.ticker import FixedLocator

# Parameters to adapt depending on the notebook section
base_folder = f"{Results_Folder}/Distance_to_ROI/{ROI_name}/Classified_track_parameters_plots"
Conditions = f'Track_Classification_Condition_{ROI_name}'
df_to_plot = merged_tracks_df

# Check and create necessary directories
folders = ["pdf", "csv"]
for folder in folders:
    dir_path = os.path.join(base_folder, folder)
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

def get_selectable_columns(df):
    # Exclude certain columns from being plotted
    exclude_cols = ['Condition', 'experiment_nb', 'File_name', 'Repeat', 'Unique_ID', 'LABEL', 'TRACK_INDEX', 'TRACK_ID', 'TRACK_X_LOCATION', 'TRACK_Y_LOCATION', 'TRACK_Z_LOCATION', 'Exemplar','TRACK_STOP', 'TRACK_START', 'Cluster_UMAP', 'Cluster_tsne']
    # Select only numerical columns
    return [col for col in df.columns if (df[col].dtype.kind in 'biufc') and (col not in exclude_cols)]

def display_variable_checkboxes(selectable_columns):
    # Create checkboxes for selectable columns
    variable_checkboxes = [widgets.Checkbox(value=False, description=col) for col in selectable_columns]

    # Display checkboxes in the notebook
    display(widgets.VBox([
        widgets.Label('Variables to Plot:'),
        widgets.GridBox(variable_checkboxes, layout=widgets.Layout(grid_template_columns="repeat(%d, 300px)" % 3)),
    ]))
    return variable_checkboxes

def create_condition_selector(df, column_name):
    conditions = df[column_name].unique()
    condition_selector = SelectMultiple(
        options=conditions,
        description='Conditions:',
        disabled=False,
        layout=Layout(width='100%')  # Adjusting the layout width
    )
    return condition_selector

def display_condition_selection(df, column_name):
    condition_selector = create_condition_selector(df, column_name)

    condition_accordion = Accordion(children=[VBox([condition_selector])])
    condition_accordion.set_title(0, 'Select Conditions')
    display(condition_accordion)
    return condition_selector


def plot_selected_vars(button, variable_checkboxes, df, Conditions, Results_Folder, condition_selector):

    plt.clf()  # Clear the current figure before creating a new plot
    print("Plotting in progress...")

  # Get selected variables
    variables_to_plot = [box.description for box in variable_checkboxes if box.value]
    n_plots = len(variables_to_plot)

    if n_plots == 0:
        print("No variables selected for plotting")
        return

  # Get selected conditions
    selected_conditions = condition_selector.value
    n_selected_conditions = len(selected_conditions)

    if n_selected_conditions == 0:
        print("No conditions selected for plotting")
        return

# Use only selected and ordered conditions
    filtered_df = df[df[Conditions].isin(selected_conditions)].copy()

# Initialize matrices to store effect sizes and p-values for each variable
    effect_size_matrices = {}
    p_value_matrices = {}
    bonferroni_matrices = {}

    unique_conditions = filtered_df[Conditions].unique().tolist()
    num_comparisons = len(unique_conditions) * (len(unique_conditions) - 1) // 2
    alpha = 0.05
    corrected_alpha = alpha / num_comparisons
    n_iterations = 1000

# Loop through each variable to plot
    for var in variables_to_plot:

      pdf_pages = PdfPages(f"{Results_Folder}/pdf/{var}_Boxplots_and_Statistics.pdf")
      effect_size_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)
      p_value_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)
      bonferroni_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)

      for cond1, cond2 in itertools.combinations(unique_conditions, 2):
        group1 = df[df[Conditions] == cond1][var]
        group2 = df[df[Conditions] == cond2][var]

        original_d = abs(cohen_d(group1, group2))
        effect_size_matrix.loc[cond1, cond2] = original_d
        effect_size_matrix.loc[cond2, cond1] = original_d  # Mirroring

        count_extreme = 0
        for i in range(n_iterations):
            combined = pd.concat([group1, group2])
            shuffled = combined.sample(frac=1, replace=False).reset_index(drop=True)
            new_group1 = shuffled[:len(group1)]
            new_group2 = shuffled[len(group1):]

            new_d = abs(cohen_d(new_group1, new_group2))
            if np.abs(new_d) >= np.abs(original_d):
                count_extreme += 1

        p_value = count_extreme / n_iterations
        p_value_matrix.loc[cond1, cond2] = p_value
        p_value_matrix.loc[cond2, cond1] = p_value  # Mirroring

        # Apply Bonferroni correction
        bonferroni_corrected_p_value = min(p_value * num_comparisons, 1.0)
        bonferroni_matrix.loc[cond1, cond2] = bonferroni_corrected_p_value
        bonferroni_matrix.loc[cond2, cond1] = bonferroni_corrected_p_value  # Mirroring

      effect_size_matrices[var] = effect_size_matrix
      p_value_matrices[var] = p_value_matrix
      bonferroni_matrices[var] = bonferroni_matrix

    # Concatenate the three matrices side-by-side
      combined_df = pd.concat(
        [
            effect_size_matrices[var].rename(columns={col: f"{col} (Effect Size)" for col in effect_size_matrices[var].columns}),
            p_value_matrices[var].rename(columns={col: f"{col} (P-Value)" for col in p_value_matrices[var].columns}),
            bonferroni_matrices[var].rename(columns={col: f"{col} (Bonferroni-corrected P-Value)" for col in bonferroni_matrices[var].columns})
        ], axis=1
    )

    # Save the combined DataFrame to a CSV file
      combined_df.to_csv(f"{Results_Folder}/csv/{var}_statistics_combined.csv")

    # Create a new figure
      fig = plt.figure(figsize=(16, 10))

    # Create a gridspec for 2 rows and 3 columns
      gs = GridSpec(2, 3, height_ratios=[1.5, 1])

    # Create the ax for boxplot using the gridspec
      ax_box = fig.add_subplot(gs[0, :])

    # Extract the data for this variable
      data_for_var = df[[Conditions, var, 'Repeat', 'File_name' ]]

    # Save the data_for_var to a CSV for replotting
      data_for_var.to_csv(f"{Results_Folder}/csv/{var}_boxplot_data.csv", index=False)

    # Calculate the Interquartile Range (IQR) using the 25th and 75th percentiles
      Q1 = df[var].quantile(0.25)
      Q3 = df[var].quantile(0.75)
      IQR = Q3 - Q1

    # Define bounds for the outliers
      multiplier = 10
      lower_bound = Q1 - multiplier * IQR
      upper_bound = Q3 + multiplier * IQR

    # Plotting
      sns.boxplot(x=Conditions, y=var, data=filtered_df, ax=ax_box, color='lightgray')  # Boxplot
      sns.stripplot(x=Conditions, y=var, data=filtered_df, ax=ax_box, hue='Repeat', dodge=True, jitter=True, alpha=0.2)  # Individual data points
      ax_box.set_ylim([max(min(filtered_df[var]), lower_bound), min(max(filtered_df[var]), upper_bound)])
      ax_box.set_title(f"{var}")
      ax_box.set_xlabel('Condition')
      ax_box.set_ylabel(var)
      tick_labels = ax_box.get_xticklabels()
      tick_locations = ax_box.get_xticks()
      ax_box.xaxis.set_major_locator(FixedLocator(tick_locations))
      ax_box.set_xticklabels(tick_labels, rotation=90)
      ax_box.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='Repeat')

    # Statistical Analyses and Heatmaps

    # Effect Size heatmap ax
      ax_d = fig.add_subplot(gs[1, 0])
      sns.heatmap(effect_size_matrices[var].fillna(0), annot=True, cmap="viridis", cbar=True, square=True, ax=ax_d, vmax=1)
      ax_d.set_title(f"Effect Size (Cohen's d) for {var}")

    # p-value heatmap ax
      ax_p = fig.add_subplot(gs[1, 1])
      sns.heatmap(p_value_matrices[var].fillna(1), annot=True, cmap="viridis_r", cbar=True, square=True, ax=ax_p, vmax=0.1)
      ax_p.set_title(f"Randomization Test p-value for {var}")

    # Bonferroni corrected p-value heatmap ax
      ax_bonf = fig.add_subplot(gs[1, 2])
      sns.heatmap(bonferroni_matrices[var].fillna(1), annot=True, cmap="viridis_r", cbar=True, square=True, ax=ax_bonf, vmax=0.1)
      ax_bonf.set_title(f"Bonferroni-corrected p-value for {var}")

      plt.tight_layout()
      pdf_pages.savefig(fig)

    # Close the PDF
      pdf_pages.close()

condition_selector = display_condition_selection(df_to_plot, Conditions)
selectable_columns = get_selectable_columns(df_to_plot)
variable_checkboxes = display_variable_checkboxes(selectable_columns)

button = Button(description="Plot Selected Variables", layout=Layout(width='400px'), button_style='info')
button.on_click(lambda b: plot_selected_vars(b, variable_checkboxes, df_to_plot, Conditions, base_folder, condition_selector))
display(button)

## **5.3. Plot a balanced dataset**
--------

## **5.3.1. Downsample your dataset to ensure that it is balanced**
--------

### Downsampling and Balancing Dataset

This section addresses imbalances in the classified dataset, which can otherwise bias comparisons between conditions. The cell below downsamples the dataset so that every classified condition and repeat combination contains the same number of tracks (the size of the smallest group). For reproducibility, the sampling uses a `random_seed` parameter, which is set to 42 by default and can be adjusted as needed.


<font size="4" color="red">
<b>Results Storage:</b> The results generated in this section are saved in the sub-folder named `Distance_to_ROI` located within your `Results_Folder`.



In [None]:
# @title ##Run this cell to downsample and balance your dataset

random_seed = 42  # @param {type: "number"}

if not os.path.exists(f"{Results_Folder}/Distance_to_ROI/{ROI_name}/Classified_Balanced_dataset"):
    os.makedirs(f"{Results_Folder}/Distance_to_ROI/{ROI_name}/Classified_Balanced_dataset")


def count_tracks_by_condition_and_repeat(df, Results_Folder, condition_col=f'Track_Classification_Condition_{ROI_name}', repeat_col='Repeat', track_id_col='Unique_ID'):
    """
    Counts the number of unique tracks for each combination of condition and repeat in the given DataFrame and
    saves a stacked histogram plot, annotated with the count in each stack, as a PDF in `Results_Folder`.

    Parameters:
    df (pandas.DataFrame): The DataFrame containing the data.
    Results_Folder (str): The base folder where the results will be saved.
    condition_col (str): The name of the column representing the condition. Default is f'Track_Classification_Condition_{ROI_name}'.
    repeat_col (str): The name of the column representing the repeat. Default is 'Repeat'.
    track_id_col (str): The name of the column representing the track ID. Default is 'Unique_ID'.
    """
    track_counts = df.groupby([condition_col, repeat_col])[track_id_col].nunique()
    track_counts_df = track_counts.reset_index()
    track_counts_df.rename(columns={track_id_col: 'Number_of_Tracks'}, inplace=True)

    # Pivot the data for plotting
    pivot_df = track_counts_df.pivot(index=condition_col, columns=repeat_col, values='Number_of_Tracks').fillna(0)

    # Plotting
    fig, ax = plt.subplots(figsize=(12, 6))
    bars = pivot_df.plot(kind='bar', stacked=True, ax=ax)
    ax.set_xlabel('Condition')
    ax.set_ylabel('Number of Tracks')
    ax.set_title('Stacked Histogram of Track Counts per Condition and Repeat')
    ax.legend(title=repeat_col)
    ax.grid(axis='y', linestyle='--')

    # Hide horizontal grid lines
    ax.yaxis.grid(False)

    # Add number annotations on each stack
    for bar in bars.patches:
        ax.text(bar.get_x() + bar.get_width() / 2,
                bar.get_y() + bar.get_height() / 2,
                int(bar.get_height()),
                ha='center', va='center', color='black', fontweight='bold', fontsize=8)

    # Save the plot as a PDF
    pdf_file = os.path.join(Results_Folder, 'Track_Counts_Histogram.pdf')
    plt.savefig(pdf_file, bbox_inches='tight')
    print(f"Saved histogram to {pdf_file}")

    plt.show()

    return track_counts_df


def balance_dataset(df, condition_col=f'Track_Classification_Condition_{ROI_name}', repeat_col='Repeat', track_id_col='Unique_ID', random_seed=None):
    """
    Balances the dataset by downsampling tracks for each condition and repeat combination.

    Parameters:
    df (pandas.DataFrame): The DataFrame containing the data.
    condition_col (str): The name of the column representing the condition.
    repeat_col (str): The name of the column representing the repeat.
    track_id_col (str): The name of the column representing the track ID.
    random_seed (int, optional): The seed for the random number generator. Default is None.

    Returns:
    pandas.DataFrame: A new DataFrame with balanced track counts.
    """
    # Group by condition and repeat, and find the minimum track count
    min_track_count = df.groupby([condition_col, repeat_col])[track_id_col].nunique().min()

    # Function to sample min_track_count tracks from each group
    def sample_tracks(group):
        return group.sample(n=min_track_count, random_state=random_seed)

    # Apply sampling to each group and concatenate the results
    balanced_merged_tracks_df = df.groupby([condition_col, repeat_col]).apply(sample_tracks).reset_index(drop=True)

    return balanced_merged_tracks_df

balanced_merged_tracks_df = balance_dataset(merged_tracks_df, random_seed=random_seed)
result_df = count_tracks_by_condition_and_repeat(balanced_merged_tracks_df, f"{Results_Folder}/Distance_to_ROI/{ROI_name}/Classified_Balanced_dataset")

check_for_nans(balanced_merged_tracks_df, "balanced_merged_tracks_df")
save_dataframe_with_progress(balanced_merged_tracks_df,f"{Results_Folder}/Distance_to_ROI/{ROI_name}/Classified_Balanced_dataset/merged_Tracks_balanced_dataset.csv")


## **5.3.2. Check if the downsampling has affected data distribution**
--------

This section of the notebook generates a heatmap visualizing the Kolmogorov-Smirnov (KS) p-values for each numerical column in the dataset, comparing the distributions before and after downsampling. This heatmap serves as a tool for assessing the impact of downsampling on data quality, guiding decisions on whether the downsampled dataset is suitable for further analysis.

#### Purpose of the Heatmap
- **KS Test:** The KS test is used to determine if two samples are drawn from the same distribution. In this context, it compares the distribution of each numerical column in the original dataset (`merged_tracks_df`) with its counterpart in the downsampled dataset (`balanced_merged_tracks_df`).
- **P-Value Interpretation:** The p-value is the probability of observing a difference at least as large as the one measured, assuming both samples come from the same distribution. A higher p-value means the test found no evidence that the distributions differ; a low p-value suggests that they do.

#### Interpreting the Heatmap
- **Color Coding:** The heatmap uses the viridis color gradient to represent the range of p-values, with the color scale capped at 0.5. Darker colors indicate lower p-values and lighter colors indicate higher p-values.
- **P-Value Thresholds:**
  - **High P-Values (Lighter Areas):** Indicate that the downsampling process likely did not significantly alter the distribution of that numerical column for the specific condition-repeat group.
  - **Low P-Values (Darker Areas):** Suggest that the downsampling process may have affected the distribution significantly.
- **Varying P-Values:** Variations in color across different columns and rows help identify which specific numerical columns and condition-repeat groups are most affected by the downsampling.
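
As a minimal sketch of how these p-values behave, the example below compares a synthetic "original" sample with a random subsample and with a deliberately biased subsample. The data and variable names are made up for illustration; only `scipy.stats.ks_2samp` is taken from the cell below.

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(0)
original = rng.normal(loc=0.0, scale=1.0, size=2000)

# Random subsample: expected to follow the same distribution as the original.
random_subset = rng.choice(original, size=200, replace=False)

# Biased subsample: keeps only the largest values, so the distribution shifts.
biased_subset = np.sort(original)[-200:]

p_random = ks_2samp(original, random_subset).pvalue
p_biased = ks_2samp(original, biased_subset).pvalue

print(f"random subsample p = {p_random:.3f}")
print(f"biased subsample p = {p_biased:.3g}")
```

A random subsample should generally yield a much larger p-value than the biased one, which is the pattern to look for in the heatmap: a dark cell flags a column whose distribution may have shifted after downsampling.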




In [None]:
# @title ##Check if your downsampling has affected your data distribution

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages  # used below to write the multi-page PDF
from scipy.stats import ks_2samp

def calculate_ks_p_value(df1, df2, column):
    """
    Calculate the KS p-value for a given column between two dataframes.

    Parameters:
    df1 (pandas.DataFrame): Original DataFrame.
    df2 (pandas.DataFrame): DataFrame after downsampling.
    column (str): Column name to compare.

    Returns:
    float: KS p-value.
    """
    return ks_2samp(df1[column].dropna(), df2[column].dropna())[1]

# Identify numerical columns
numerical_columns = merged_tracks_df.select_dtypes(include=['int64', 'float64']).columns

# Initialize a DataFrame to store KS p-values
ks_p_values = pd.DataFrame(columns=numerical_columns)

# Iterate over each group and numerical column
for group, group_df in merged_tracks_df.groupby([f'Track_Classification_Condition_{ROI_name}', 'Repeat']):
    group_p_values = []
    balanced_group_df = balanced_merged_tracks_df[(balanced_merged_tracks_df[f'Track_Classification_Condition_{ROI_name}'] == group[0]) & (balanced_merged_tracks_df['Repeat'] == group[1])]
    for column in numerical_columns:
        p_value = calculate_ks_p_value(group_df, balanced_group_df, column)
        group_p_values.append(p_value)
    ks_p_values.loc[f'{group[0]}, R: {group[1]}'] = group_p_values

# Maximum number of columns per heatmap
max_columns_per_heatmap = 20

# Total number of columns
total_columns = len(ks_p_values.columns)

# Calculate the number of heatmaps needed
num_heatmaps = -(-total_columns // max_columns_per_heatmap)  # Ceiling division

# File path for the PDF
pdf_filepath = f"{Results_Folder}/Distance_to_ROI/{ROI_name}/Classified_Balanced_dataset/KS_p_values_heatmaps.pdf"

# Create a PDF file
with PdfPages(pdf_filepath) as pdf:
    # Loop through each subset of columns and create a heatmap
    for i in range(num_heatmaps):
        start_col = i * max_columns_per_heatmap
        end_col = min(start_col + max_columns_per_heatmap, total_columns)

        # Subset of columns for this heatmap
        subset_columns = ks_p_values.columns[start_col:end_col]

        # Create the heatmap for the subset of columns
        plt.figure(figsize=(12, 8))
        sns.heatmap(ks_p_values[subset_columns], cmap='viridis', vmax=0.5, vmin=0)
        plt.title(f'Kolmogorov-Smirnov P-Value Heatmap (Columns {start_col+1} to {end_col})')
        plt.xlabel('Numerical Columns')
        plt.ylabel('Condition-Repeat Groups')
        plt.tight_layout()

        # Save the current figure to the PDF
        pdf.savefig()
        plt.show()
        plt.close()

print(f"Saved all heatmaps to {pdf_filepath}")

# Save the p-values to a CSV file
ks_p_values.to_csv(f"{Results_Folder}/Distance_to_ROI/{ROI_name}/Classified_Balanced_dataset/ks_p_values.csv")
print("Saved KS p-values to ks_p_values.csv")


## **5.3.3. Plot your balanced dataset**
--------
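
The cell below compares the selected conditions pairwise: for each pair it computes an absolute Cohen's d, estimates a p-value with a randomization test (1000 label shuffles), and applies a Bonferroni correction for the number of pairs. The sketch below illustrates that procedure on synthetic data. It is not the notebook's implementation: `cohen_d_sketch` is a hypothetical stand-in that assumes a pooled standard deviation, and the notebook's own `cohen_d` function may be defined differently.

```python
import itertools
import numpy as np

def cohen_d_sketch(a, b):
    """Cohen's d with a pooled standard deviation (hypothetical helper)."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    pooled_var = ((len(a) - 1) * a.var(ddof=1) + (len(b) - 1) * b.var(ddof=1)) / (len(a) + len(b) - 2)
    return (a.mean() - b.mean()) / np.sqrt(pooled_var)

def randomization_p_value(a, b, n_iterations=1000, seed=0):
    """Fraction of label shuffles whose |d| is at least the observed |d|."""
    rng = np.random.default_rng(seed)
    observed = abs(cohen_d_sketch(a, b))
    combined = np.concatenate([a, b])
    count_extreme = 0
    for _ in range(n_iterations):
        shuffled = rng.permutation(combined)
        if abs(cohen_d_sketch(shuffled[:len(a)], shuffled[len(a):])) >= observed:
            count_extreme += 1
    return count_extreme / n_iterations

# Synthetic groups: two drawn from the same distribution, one shifted.
rng = np.random.default_rng(1)
groups = {
    "inside": rng.normal(0.0, 1.0, 60),
    "edge": rng.normal(0.0, 1.0, 60),
    "outside": rng.normal(2.0, 1.0, 60),
}

pairs = list(itertools.combinations(groups, 2))
for name1, name2 in pairs:
    p = randomization_p_value(groups[name1], groups[name2])
    p_bonferroni = min(p * len(pairs), 1.0)  # Bonferroni correction, capped at 1
    print(f"{name1} vs {name2}: p = {p:.3f}, Bonferroni-corrected p = {p_bonferroni:.3f}")
```

With 1000 shuffles the smallest non-zero p-value that can be reported is 0.001, so a reported value of 0 means that no shuffle matched the observed effect size.
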

In [None]:
# @title ##Plot track parameters (balanced dataset)

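import itertools
import os

import ipywidgets as widgets
from ipywidgets import Layout, VBox, Button, Accordion, SelectMultiple, IntText
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FixedLocator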


# Parameters to adapt in function of the notebook section
base_folder = f"{Results_Folder}/Distance_to_ROI/{ROI_name}/Classified_Balanced_dataset/Classified_track_parameters_plots"
Conditions = f'Track_Classification_Condition_{ROI_name}'
df_to_plot = balanced_merged_tracks_df

# Check and create necessary directories
folders = ["pdf", "csv"]
for folder in folders:
    dir_path = os.path.join(base_folder, folder)
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

def get_selectable_columns(df):
    # Exclude certain columns from being plotted
    exclude_cols = ['Condition', 'experiment_nb', 'File_name', 'Repeat', 'Unique_ID', 'LABEL', 'TRACK_INDEX', 'TRACK_ID', 'TRACK_X_LOCATION', 'TRACK_Y_LOCATION', 'TRACK_Z_LOCATION', 'Exemplar','TRACK_STOP', 'TRACK_START', 'Cluster_UMAP', 'Cluster_tsne']
    # Select only numerical columns
    return [col for col in df.columns if (df[col].dtype.kind in 'biufc') and (col not in exclude_cols)]

def display_variable_checkboxes(selectable_columns):
    # Create checkboxes for selectable columns
    variable_checkboxes = [widgets.Checkbox(value=False, description=col) for col in selectable_columns]

    # Display checkboxes in the notebook
    display(widgets.VBox([
        widgets.Label('Variables to Plot:'),
        widgets.GridBox(variable_checkboxes, layout=widgets.Layout(grid_template_columns="repeat(%d, 300px)" % 3)),
    ]))
    return variable_checkboxes

def create_condition_selector(df, column_name):
    conditions = df[column_name].unique()
    condition_selector = SelectMultiple(
        options=conditions,
        description='Conditions:',
        disabled=False,
        layout=Layout(width='100%')  # Adjusting the layout width
    )
    return condition_selector

def display_condition_selection(df, column_name):
    condition_selector = create_condition_selector(df, column_name)

    condition_accordion = Accordion(children=[VBox([condition_selector])])
    condition_accordion.set_title(0, 'Select Conditions')
    display(condition_accordion)
    return condition_selector


def plot_selected_vars(button, variable_checkboxes, df, Conditions, Results_Folder, condition_selector):

    plt.clf()  # Clear the current figure before creating a new plot
    print("Plotting in progress...")

    # Get selected variables
    variables_to_plot = [box.description for box in variable_checkboxes if box.value]
    n_plots = len(variables_to_plot)

    if n_plots == 0:
        print("No variables selected for plotting")
        return

    # Get selected conditions
    selected_conditions = condition_selector.value
    n_selected_conditions = len(selected_conditions)

    if n_selected_conditions == 0:
        print("No conditions selected for plotting")
        return

    # Use only selected and ordered conditions
    filtered_df = df[df[Conditions].isin(selected_conditions)].copy()

    # Initialize matrices to store effect sizes and p-values for each variable
    effect_size_matrices = {}
    p_value_matrices = {}
    bonferroni_matrices = {}

    unique_conditions = filtered_df[Conditions].unique().tolist()
    num_comparisons = len(unique_conditions) * (len(unique_conditions) - 1) // 2
    alpha = 0.05
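    # Avoid a division by zero when only one condition is selected (no pairwise comparisons)
    corrected_alpha = alpha / num_comparisons if num_comparisons > 0 else alpha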
    n_iterations = 1000

    # Loop through each variable to plot
    for var in variables_to_plot:

      pdf_pages = PdfPages(f"{Results_Folder}/pdf/{var}_Boxplots_and_Statistics.pdf")
      effect_size_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)
      p_value_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)
      bonferroni_matrix = pd.DataFrame(index=unique_conditions, columns=unique_conditions)

      for cond1, cond2 in itertools.combinations(unique_conditions, 2):
        group1 = df[df[Conditions] == cond1][var]
        group2 = df[df[Conditions] == cond2][var]

        original_d = abs(cohen_d(group1, group2))
        effect_size_matrix.loc[cond1, cond2] = original_d
        effect_size_matrix.loc[cond2, cond1] = original_d  # Mirroring

        count_extreme = 0
        for i in range(n_iterations):
            combined = pd.concat([group1, group2])
            shuffled = combined.sample(frac=1, replace=False).reset_index(drop=True)
            new_group1 = shuffled[:len(group1)]
            new_group2 = shuffled[len(group1):]

            new_d = abs(cohen_d(new_group1, new_group2))
            if np.abs(new_d) >= np.abs(original_d):
                count_extreme += 1

        p_value = count_extreme / n_iterations
        p_value_matrix.loc[cond1, cond2] = p_value
        p_value_matrix.loc[cond2, cond1] = p_value  # Mirroring

        # Apply Bonferroni correction
        bonferroni_corrected_p_value = min(p_value * num_comparisons, 1.0)
        bonferroni_matrix.loc[cond1, cond2] = bonferroni_corrected_p_value
        bonferroni_matrix.loc[cond2, cond1] = bonferroni_corrected_p_value  # Mirroring

      effect_size_matrices[var] = effect_size_matrix
      p_value_matrices[var] = p_value_matrix
      bonferroni_matrices[var] = bonferroni_matrix

    # Concatenate the three matrices side-by-side
      combined_df = pd.concat(
        [
            effect_size_matrices[var].rename(columns={col: f"{col} (Effect Size)" for col in effect_size_matrices[var].columns}),
            p_value_matrices[var].rename(columns={col: f"{col} (P-Value)" for col in p_value_matrices[var].columns}),
            bonferroni_matrices[var].rename(columns={col: f"{col} (Bonferroni-corrected P-Value)" for col in bonferroni_matrices[var].columns})
        ], axis=1
    )

    # Save the combined DataFrame to a CSV file
      combined_df.to_csv(f"{Results_Folder}/csv/{var}_statistics_combined.csv")

    # Create a new figure
      fig = plt.figure(figsize=(16, 10))

    # Create a gridspec for 2 rows and 3 columns
      gs = GridSpec(2, 3, height_ratios=[1.5, 1])

    # Create the ax for boxplot using the gridspec
      ax_box = fig.add_subplot(gs[0, :])

    # Extract the data for this variable
      data_for_var = df[[Conditions, var, 'Repeat', 'File_name' ]]

    # Save the data_for_var to a CSV for replotting
      data_for_var.to_csv(f"{Results_Folder}/csv/{var}_boxplot_data.csv", index=False)

    # Calculate the Interquartile Range (IQR) using the 25th and 75th percentiles
      Q1 = df[var].quantile(0.25)
      Q3 = df[var].quantile(0.75)
      IQR = Q3 - Q1

    # Define bounds for the outliers
      multiplier = 10
      lower_bound = Q1 - multiplier * IQR
      upper_bound = Q3 + multiplier * IQR

    # Plotting
      sns.boxplot(x=Conditions, y=var, data=filtered_df, ax=ax_box, color='lightgray')  # Boxplot
      sns.stripplot(x=Conditions, y=var, data=filtered_df, ax=ax_box, hue='Repeat', dodge=True, jitter=True, alpha=0.2)  # Individual data points
      ax_box.set_ylim([max(min(filtered_df[var]), lower_bound), min(max(filtered_df[var]), upper_bound)])
      ax_box.set_title(f"{var}")
      ax_box.set_xlabel('Condition')
      ax_box.set_ylabel(var)
      tick_labels = ax_box.get_xticklabels()
      tick_locations = ax_box.get_xticks()
      ax_box.xaxis.set_major_locator(FixedLocator(tick_locations))
      ax_box.set_xticklabels(tick_labels, rotation=90)
      ax_box.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='Repeat')

    # Statistical Analyses and Heatmaps

    # Effect Size heatmap ax
      ax_d = fig.add_subplot(gs[1, 0])
      sns.heatmap(effect_size_matrices[var].fillna(0), annot=True, cmap="viridis", cbar=True, square=True, ax=ax_d, vmax=1)
      ax_d.set_title(f"Effect Size (Cohen's d) for {var}")

    # p-value heatmap ax
      ax_p = fig.add_subplot(gs[1, 1])
      sns.heatmap(p_value_matrices[var].fillna(1), annot=True, cmap="viridis_r", cbar=True, square=True, ax=ax_p, vmax=0.1)
      ax_p.set_title(f"Randomization Test p-value for {var}")

    # Bonferroni corrected p-value heatmap ax
      ax_bonf = fig.add_subplot(gs[1, 2])
      sns.heatmap(bonferroni_matrices[var].fillna(1), annot=True, cmap="viridis_r", cbar=True, square=True, ax=ax_bonf, vmax=0.1)
      ax_bonf.set_title(f"Bonferroni-corrected p-value for {var}")

      plt.tight_layout()
      pdf_pages.savefig(fig)

    # Close the PDF
      pdf_pages.close()

condition_selector = display_condition_selection(df_to_plot, Conditions)
selectable_columns = get_selectable_columns(df_to_plot)
variable_checkboxes = display_variable_checkboxes(selectable_columns)

button = Button(description="Plot Selected Variables", layout=Layout(width='400px'), button_style='info')
button.on_click(lambda b: plot_selected_vars(b, variable_checkboxes, df_to_plot, Conditions, base_folder, condition_selector))
display(button)

# **Part 6: Version log**
---
<font size = 4>While we strive to provide accurate and helpful information, please be aware that:
  - This notebook may contain bugs.
  - Features are currently limited and will be expanded in future releases.

<font size = 4>We encourage users to report any issues or suggestions for improvement. Please check the [repository](https://github.com/guijacquemet/CellTracksColab) regularly for updates and the latest version of this notebook.


<font size = 4>**Version 0.9.1**
  - Added the pip freeze option to save a requirements text file
  - Added the heatmap visualisation of track parameters
  - Heatmaps can now be displayed on multiple pages
  - Fixed the UserWarning message during plotting (all box plots)

<font size = 4>**Version 0.9**
  - Improved plotting strategy. Specific conditions can be chosen
  - Absolute Cohen's d values are now shown
  - In the QC, the heatmap is automatically divided into subplots when the DataFrame has too many columns

<font size = 4>**Version 0.8**
  - Settings are now saved
  - Order of the sections has been modified to help streamline biological discoveries
  - New section added to Quality Control to check if the dataset is balanced

<font size = 4>**Version 0.7**
  - First release of this notebook

