In [None]:
import codecs
import io
import os

# Optional: Disable warnings
import warnings

warnings.filterwarnings("ignore")

import cv2
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import phenograph
import plotly.express as px
import scipy.stats as stats
import seaborn as sns
import umap
from IPython.display import clear_output, display
from ipywidgets import SelectMultiple, TwoByTwoLayout, interact, widgets
from matplotlib import cm
from numba import jit
from sklearn import preprocessing
from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler

%matplotlib widget

### Select your single cell csv file and declare your working directory

In [None]:
output = widgets.Output()


def on_value_change(change):
    """Store the working directory in ``PATH`` and (re)load the uploaded CSV.

    Registered on the file-upload widget. The parsed table is stored in the global
    ``singlecell_df`` and previewed in ``output``. ``change`` is the traitlets change
    notification and is not used.
    """
    global singlecell_df
    global PATH

    # Extracting the value of working_directory
    PATH = working_directory.value

    # Checking if the file has been uploaded
    if input_file.value:
        uploaded_file = input_file.value[0]
        singlecell_df = pd.read_csv(io.BytesIO(uploaded_file.content))

        # Replace the previous preview instead of appending below it
        with output:
            output.clear_output(wait=True)
            display(singlecell_df)


def on_path_change(change):
    """Update ``PATH`` when the working-directory text changes.

    Kept separate from ``on_value_change`` so that typing a path does not re-parse
    the whole CSV (and reset ``singlecell_df``) on every keystroke.
    """
    global PATH
    PATH = working_directory.value


input_file = widgets.FileUpload(
    accept=".csv",
    multiple=False,
    description="Upload single cell csv",
    layout={"width": "auto"},
)

working_directory = widgets.Text(
    value="",
    placeholder="Enter your working directory path",
    description="Path:",
    style={"description_width": "initial"},
)

# Make PATH defined (current directory) even before the path box is edited
PATH = working_directory.value

# Observing changes in the value of widgets
input_file.observe(on_value_change, names="value")
working_directory.observe(on_path_change, names="value")

# Display the input widgets and the output widget
display(input_file, working_directory, output)

### Plot a raw histogram for the selected channel

In [None]:
# Raw (unprocessed) histogram for a chosen channel

# Channel picker for the raw histogram
histogram_select = widgets.RadioButtons(
    description="Histogram channel",
    options=singlecell_df.columns,
    style={"description_width": "initial"},
)

# Output area the histogram figure is drawn into
plot_output = widgets.Output()


def update_histogram(change, data, channel=None, cluster=None):
    """Redraw a histogram of one channel of ``data`` inside ``plot_output``.

    Parameters
    ----------
    change : dict or None
        Traitlets change notification; not used.
    data : pandas.DataFrame
        Table to plot.
    channel : str, optional
        Column to plot; when falsy, the current ``histogram_select`` value is used.
    cluster : int, optional
        When given, only rows whose ``data["cluster"]`` category code equals it are plotted.
    """
    with plot_output:
        clear_output(wait=True)
        subset = data if cluster is None else data[data["cluster"].cat.codes == cluster]
        column = channel if channel else histogram_select.value
        sns.displot(subset, x=column)
        plt.show()


# Redraw from the raw table whenever another channel is picked
histogram_select.observe(
    lambda _change: update_histogram(_change, singlecell_df), names="value"
)

# Draw the initial histogram for the default channel
update_histogram(None, singlecell_df)

# Show the channel picker next to the plot
widgets.HBox([histogram_select, plot_output])

##### Uncomment the cell below to save the histogram 

In [None]:
# plt.savefig(fname=os.path.join(PATH, str(histogram_select.value + "_histogram_raw.png")), dpi=600)

## Size normalisation (optional)
The 'Area' of each cell is multiplied by the square of the entered pixel size to reflect real metrics in µm² rather than arbitrary values in pixels². \
The 'EquivalentDiameter' of each cell is multiplied by the entered pixel size, so that distances are expressed in µm rather than pixels. \
Each selected channel is divided by the cell's 'Area' (after conversion into µm²). This makes clustering more reliable by removing size-based differences in cellular isotope intensities.

### Select channels to be size-normalised using the specified pixel size (optional)
**Note:** Only select isotopes from the list. Do NOT include morphological features such as 'Area' or 'EquivalentDiameter'.

In [None]:
# Pixel size in µm, used to convert 'Area' and 'EquivalentDiameter' from pixel units
pixelsize_select = widgets.Text(
    description="Pixel size:",
    value="",
    placeholder="Pixel size in µm",
    style={"description_width": "initial"},
)

# Intensity channels that will be divided by each cell's area
normalise_channels = widgets.SelectMultiple(
    description="Channels to normalise by size",
    options=singlecell_df.columns,
    rows=10,
    style={"description_width": "initial"},
)

# Whether the next cell should display the resulting dataframe
binary_select_norm = widgets.RadioButtons(
    description="Do you wish to display the modified dataframe?",
    options=["Yes", "No"],
    value="No",
    style={"description_width": "initial"},
)

# Lay the three inputs out side by side
widgets.HBox(
    [pixelsize_select, normalise_channels, binary_select_norm],
)

**Note**:  
The following cell uses the column names produced by CellProfiler for 'Area' (AreaShape_Area) and 'Diameter' (AreaShape_EquivalentDiameter). If the single-cell CSV data was produced by other software, or the column names were changed manually, the strings shown below (emphasised) must be adapted to match the 'Area' column and, optionally, the 'Diameter' column of your data (the 'Diameter' conversion can otherwise be omitted).  
  
**\["AreaShape_Area"\]**

**\["AreaShape_EquivalentDiameter"\]**

In [None]:
normalised_df = singlecell_df.copy()

# Channels to be divided by cell area
cols_to_divide = list(normalise_channels.value)
pixelsize_text = pixelsize_select.value.strip()

# Validate the pixel size explicitly. The previous bare `except:` also hid real errors
# (e.g. a missing 'AreaShape_Area' column) behind the "no pixel size" message and could
# leave the table half-converted.
pixelsize = None
if not pixelsize_text:
    print("No pixel size for normalisation has been entered. Using original data!")
else:
    try:
        pixelsize = float(pixelsize_text)
    except ValueError:
        pixelsize = None
    if pixelsize is None or not 0 < pixelsize < float("inf"):
        print(
            f"Invalid pixel size '{pixelsize_text}' (expected a positive number in µm). "
            "Using original data!"
        )
        pixelsize = None

if pixelsize is not None:
    # Compute the converted columns first, so a missing column raises before anything
    # in normalised_df has been modified.
    area_um2 = normalised_df["AreaShape_Area"] * pixelsize**2
    diameter_um = normalised_df["AreaShape_EquivalentDiameter"] * pixelsize
    normalised_df["AreaShape_Area"] = area_um2
    normalised_df["AreaShape_EquivalentDiameter"] = diameter_um
    if cols_to_divide:
        # Per-µm² intensities: divide each selected channel row-wise by that cell's area
        normalised_df[cols_to_divide] = normalised_df[cols_to_divide].div(
            area_um2, axis=0
        )

if binary_select_norm.value == "Yes":
    display(normalised_df)

## Outlier filtering (optional)
The following section performs outlier filtering either based on set percentiles or on Z-scores (standard scores).

**Percentiles method**: Outliers are determined by the percentile of the data. Everything below the 'Lower percentile' and above the 'Upper percentile' is considered an outlier.

**Z-score method**: A z-score represents how many standard deviations a data point is from the mean. This function uses z-scores indirectly by calculating the threshold as mean + n_std * standard deviation. Data points above this threshold are considered outliers (only the upper tail is filtered). By setting the 'Number of standard deviations' you can specify how strict or lenient the outlier determination is: a smaller value of n_std is stricter and catches more outliers, whereas a larger value is more lenient. For normally distributed data, a value of 3 retains about 99.9% of the data (only the ≈0.13% of values more than 3 standard deviations above the mean are removed).

**Note**: Z-scores are especially useful if your data is normally distributed.
Percentiles are more data-driven and can be applied without assumptions about the data distribution.

### Select and apply an outlier filtering method

In [None]:
# Function to remove outliers based on percentiles;
# Removes cells with values higher than the 'upper_quant' percentile and/or lower than the 'lower_quant' percentile of cells.
def remove_outliers_percentiles(df, columns, upper_quant, lower_quant):
    """Split ``df`` into rows inside and outside per-column percentile limits.

    A row is kept only if, for every column in ``columns``, its value lies in the
    inclusive range [``lower_quant`` quantile, ``upper_quant`` quantile] of that column.
    NaN values fail the comparison, so rows containing them land in ``outliers``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to filter; its index is preserved in both outputs.
    columns : iterable of str
        Columns to test. An empty selection keeps every row.
    upper_quant, lower_quant : float
        Quantiles in [0, 1], e.g. 0.997 and 0.0.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(non_outliers, outliers)``.
    """
    # Build the mask on df's own index: a default 0..n-1 index would misalign with (and
    # silently reject rows of) any frame whose index is not a plain RangeIndex.
    mask = pd.Series(True, index=df.index)
    for col in columns:
        q_low = df[col].quantile(lower_quant)
        q_hi = df[col].quantile(upper_quant)
        mask = mask & df[col].between(q_low, q_hi)  # inclusive on both ends
    return df[mask], df[~mask]


# Function to remove outliers based on zscores;
# Removes cells with values more than 'n_std' (read: n*σ) above the mean on a per-column basis
def remove_outliers_zscore(df, columns, n_std):
    """Split ``df`` into rows at/below and above ``mean + n_std * std`` per column.

    Only the upper tail is filtered: a row is an outlier if, in any column of
    ``columns``, its value exceeds that column's mean plus ``n_std`` sample standard
    deviations. NaN values fail the comparison, so rows containing them land in
    ``outliers``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to filter; its index is preserved in both outputs.
    columns : iterable of str
        Columns to test. An empty selection keeps every row.
    n_std : float
        Number of standard deviations above the mean (fractions such as 2.5 allowed).

    Returns
    -------
    tuple of pandas.DataFrame
        ``(non_outliers, outliers)``.
    """
    # Build the mask on df's own index so it stays aligned for any index type.
    mask = pd.Series(True, index=df.index)
    for col in columns:
        threshold = df[col].mean() + n_std * df[col].std()
        mask = mask & (df[col] <= threshold)
    return df[mask], df[~mask]


# Prefer the size-normalisation cell's table when it exists. That table is created even
# when no pixel size was entered, so compare it with the raw data to report accurately
# whether any normalisation was actually applied.
if "normalised_df" in locals():
    data_to_filter = normalised_df.copy()
    was_normalised = not normalised_df.equals(singlecell_df)
else:
    data_to_filter = singlecell_df.copy()
    was_normalised = False

if was_normalised:
    print(
        "Data was previously size-normalised. Filtering outliers from normalised data!"
    )
else:
    print("Data was not size-normalised. Filtering outliers from original data!")
# Select channel names you wish to have outlier filtered
filter_channels = widgets.SelectMultiple(
    options=singlecell_df.columns,
    rows=10,
    description="Channels",
    style={"description_width": "initial"},
)
binary_select_filter = widgets.RadioButtons(
    options=["Yes", "No"],
    value="No",
    description="Do you wish to display the modified dataframe?",
    style={"description_width": "initial"},
)
outlier_method = widgets.RadioButtons(
    options=["Percentiles", "Z-Score"], description="Outlier Method:", disabled=False
)
n_std_select = widgets.Text(
    value="",
    placeholder="3",
    description="Number of standard deviations:",
    style={"description_width": "initial"},
)
upper_quant_select = widgets.Text(
    value="",
    placeholder="0.997",
    description="Upper percentile limit:",
    style={"description_width": "initial"},
)
lower_quant_select = widgets.Text(
    value="",
    placeholder="0.000",
    description="Lower percentile limit:",
    style={"description_width": "initial"},
)

row1 = widgets.HBox([filter_channels, outlier_method, binary_select_filter])
row2 = widgets.HBox([n_std_select, upper_quant_select, lower_quant_select])

layout = widgets.VBox([row1, row2])
display(layout)

In [None]:
columns_to_filter = list(filter_channels.value)

# Empty text boxes fall back to the defaults shown as their placeholders
# ("3", "0.997" and "0.000") instead of crashing on float("").
if outlier_method.value == "Percentiles":
    upper_quant = float(
        upper_quant_select.value.strip() or upper_quant_select.placeholder
    )
    lower_quant = float(
        lower_quant_select.value.strip() or lower_quant_select.placeholder
    )
    # Limits are quantile fractions, not percentages (0.997, not 99.7)
    if not 0.0 <= lower_quant <= upper_quant <= 1.0:
        raise ValueError(
            "Percentile limits must satisfy 0 <= lower <= upper <= 1 "
            f"(got lower={lower_quant}, upper={upper_quant}). "
            "Enter fractions, e.g. 0.997 rather than 99.7."
        )
    filtered_df, outliers_df = remove_outliers_percentiles(
        data_to_filter,
        columns_to_filter,
        upper_quant,
        lower_quant,
    )
    print(
        str(len(outliers_df))
        + " cell events were filtered out using percentiles and stored in `outliers_df`!"
    )
elif outlier_method.value == "Z-Score":
    # float rather than int, so fractional thresholds such as 2.5 are accepted
    n_std = float(n_std_select.value.strip() or n_std_select.placeholder)
    filtered_df, outliers_df = remove_outliers_zscore(
        data_to_filter, columns_to_filter, n_std
    )
    print(
        str(len(outliers_df))
        + " cell events were filtered out using z-score and stored in `outliers_df`!"
    )

if binary_select_filter.value == "Yes":
    display(filtered_df)

##### Uncomment cell below to save filtered data as CSV (e.g. for quantification) 

In [None]:
# filtered_df.to_csv(os.path.join(PATH, 'filtered.csv'), index=False)

### Plot normalised and/or filtered histogram for the selected channel

In [None]:
# Collect whichever processed tables exist, normalised first, then filtered
df_dict = {}
if "normalised_df" in locals():
    df_dict["Normalised data"] = normalised_df
if "filtered_df" in locals():
    df_dict["Filtered data"] = filtered_df

if len(df_dict) == 0:
    print(
        "You did not perform size normalisation nor outlier filtering. No data is available to plot a histogram!"
    )

# Channel picker for the normalised and/or filtered histogram
histogram_select = widgets.RadioButtons(
    description="Histogram channel",
    options=singlecell_df.columns,
    style={"description_width": "initial"},
)

# Picker for which processed table to plot
hist_version_select = widgets.RadioButtons(
    description="Histogram version",
    options=list(df_dict),
    style={"description_width": "initial"},
)

# Output area the figure is drawn into
plot_output = widgets.Output()


def update_hist_version(change):
    """Redraw the histogram of the selected channel for the selected data version.

    Falls back to ``singlecell_df`` when the selected version is not in ``df_dict``
    (e.g. no processed version exists). ``change`` is unused.
    """
    frame = df_dict.get(hist_version_select.value, singlecell_df)

    with plot_output:
        clear_output(wait=True)
        sns.displot(frame, x=histogram_select.value)
        plt.show()


# Redraw when either the channel or the data version changes
histogram_select.observe(update_hist_version, names="value")
hist_version_select.observe(update_hist_version, names="value")

# Draw the initial plot
update_hist_version(None)

# Show both pickers next to the plot
widgets.HBox([histogram_select, hist_version_select, plot_output])

##### Uncomment the cell below to save the histogram 

In [None]:
# plt.savefig(
#     fname=os.path.join(
#         PATH,
#         str(
#             histogram_select.value
#             + "_histogram_"
#             + hist_version_select.value[0:-5].lower()
#             + ".png"
#         ),
#     ),
#     dpi=600,
# )

## IMPORTANT: Checking which version of the data is present
*(original/size-normalised/outlier filtered/size-normalised + outlier filtered)* \
This cell needs to be run before continuing!

In [None]:
# Make sure `filtered_df` exists for all downstream cells: keep an existing outlier-filtered
# table, otherwise fall back to the size-normalised table, otherwise to the raw data.
if "filtered_df" not in locals():
    if "normalised_df" in locals():
        print(
            "Data was not outlier filtered but normalised by size. Using size-normalised data for downstream analysis!"
        )
        filtered_df = normalised_df.copy()
    else:
        print(
            "Data was neither size-normalised nor outlier filtered. Using original data for downstream analysis!"
        )
        filtered_df = singlecell_df.copy()
else:
    print(
        "Data was outlier filtered. Filtered data will be used for downstream analysis!"
    )

## Plot a heatmap for the selected channel overlayed on a raw image
When executed, this cell produces an interactive view of channel intensity overlaid on top of a raw image. Colours are mapped to indicate channel intensity. \
Upload a .png segmentation mask of your cells and your raw or processed sample image as a TIFF file (increasing the contrast can greatly improve visibility). Select the channel you wish to visualise and click the 'Load data and visualise' button. \
Move the 'Limits' slider to adjust the upper and lower thresholds for the channel intensity. \
Move the 'Transparency' slider to adjust how strongly the cell colours are blended over the image (higher values make the overlay more opaque).

In [None]:
def read_uploaded_file(upload_widget):
    """Decode the first file of a FileUpload widget into a numpy image array.

    Returns ``None`` when nothing has been uploaded. Otherwise returns whatever
    ``cv2.imdecode`` produces with ``IMREAD_UNCHANGED``, which is also ``None`` if the
    bytes cannot be decoded as an image.
    """
    uploads = upload_widget.value
    if not uploads:
        return None
    # Wrap the uploaded bytes without copying, then let OpenCV decode them
    raw_bytes = np.frombuffer(uploads[0].content, dtype=np.uint8)
    return cv2.imdecode(raw_bytes, cv2.IMREAD_UNCHANGED)


def update_color_mask_numba(img, color_mask, cell_ids, cell_colors):
    """Paint every pixel of each listed object in ``color_mask`` with that object's colour.

    Vectorised NumPy replacement for the former per-object ``np.where`` scan, which cost
    O(objects x pixels) on every slider move. This version costs O(pixels x log(objects)).
    The name is kept so existing callers keep working; it no longer uses numba.

    Parameters
    ----------
    img : numpy.ndarray
        2-D label image; pixels whose value is in ``cell_ids`` are recoloured.
    color_mask : numpy.ndarray
        ``(H, W, 3)`` array modified in place. Pixels whose label is not in ``cell_ids``
        are left untouched.
    cell_ids : array-like
        Object labels to paint. If a label repeats, its last colour wins, as with the
        sequential loop this replaces.
    cell_colors : array-like
        One colour triple per entry of ``cell_ids``. Surplus entries on either side are
        ignored, matching ``zip`` semantics.
    """
    n = min(len(cell_ids), len(cell_colors))
    if n == 0:
        return
    cell_ids = np.asarray(cell_ids)[:n]
    cell_colors = np.asarray(cell_colors)[:n]

    # Unique labels plus, for each, the index of its LAST occurrence in cell_ids
    unique_ids, first_in_reversed = np.unique(cell_ids[::-1], return_index=True)
    last_index = n - 1 - first_in_reversed

    # Locate each pixel's label among the unique labels and keep only exact matches
    pos = np.clip(np.searchsorted(unique_ids, img), 0, len(unique_ids) - 1)
    hit = unique_ids[pos] == img
    color_mask[hit] = cell_colors[last_index[pos[hit]]]


def plot_heatmap(
    data,
    img,
    color_mask,
    im,
    cell_positions,
    vmin,
    vmax,
    overlay_image_resized,
    transparency,
    selected_column=None,
):
    """Recolour the segmented cells for the given limits, blend, and redraw the figure.

    Parameters
    ----------
    data : pandas.DataFrame
        Single-cell table. Rows are looked up via the values of ``cell_positions``
        (assumes unique index labels). It is only read, never modified.
    img : numpy.ndarray
        Segmentation mask with one integer object label per pixel.
    color_mask : numpy.ndarray
        Colour buffer with the mask's height/width and 3 channels; recoloured in place.
    im : matplotlib.image.AxesImage
        Image artist that receives the blended image and the new colour limits.
    cell_positions : dict
        Maps a mask label to the ``data`` row label whose centre falls on that object.
    vmin, vmax : float
        Lower/upper limits of the colour scale.
    overlay_image_resized : numpy.ndarray
        Background image with the same shape and dtype as ``color_mask``.
    transparency : float
        Blend weight of the colour mask; the background gets ``1 - transparency``.
    selected_column : str, optional
        Channel to colour by. Defaults to the current ``column_select`` value.
    """
    if selected_column is None:
        selected_column = column_select.value

    cmap = plt.cm.turbo
    norm = plt.Normalize(vmin, vmax)

    # Label 0 is treated as background (as in the object count in load_data) and is never
    # painted, even if a cell centre happens to fall on a background pixel.
    labelled = [
        (cell_id, row) for cell_id, row in cell_positions.items() if cell_id != 0
    ]
    if labelled:
        cell_ids = np.array([cell_id for cell_id, _ in labelled])
        # Colour each object by the row mapped to it, so colours stay aligned with their
        # objects even when several rows collapsed onto one label in cell_positions.
        values = data.loc[[row for _, row in labelled], selected_column].to_numpy()
        cell_colors = (cmap(norm(values))[:, :3] * 255).astype(np.uint8)
        update_color_mask_numba(img, color_mask, cell_ids, cell_colors)

    # Blend the overlay image with the color_mask
    blended_image = cv2.addWeighted(
        color_mask, transparency, overlay_image_resized, 1 - transparency, 0
    )

    im.set_data(blended_image)
    im.set_clim(vmin, vmax)
    # Redraw the figure that owns `im`; plt.draw() would target whichever figure is current.
    im.figure.tight_layout()
    im.figure.canvas.draw_idle()


def _to_rgb_uint8(image):
    """Convert an image decoded by OpenCV to 8-bit, 3-channel RGB (matplotlib's order).

    Accepts single-channel, BGR or BGRA input of any bit depth. Non-8-bit images are
    min-max stretched to 0-255 rather than cast, which would wrap 16-bit values.
    """
    if image.dtype != np.uint8:
        image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


# Define a function to load data and prepare the heatmap visualization
def load_data(btn):
    """Load the uploaded mask and overlay image and build the interactive heatmap.

    Button callback for 'Load data and visualise' (``btn`` is unused). Reads
    ``mask_upload``, ``overlay_image_upload``, ``column_select`` and ``filtered_df`` from
    the notebook namespace and renders the object count, the sliders and the figure
    into ``out``.
    """
    # Read the segmentation mask
    img = read_uploaded_file(mask_upload)
    if img is None:
        with out:
            clear_output()
            print(
                "Error reading the segmentation mask. Please upload a valid PNG file."
            )
        return

    # Count the number of unique cell IDs in the segmentation mask
    num_cells = len(np.unique(img)) - 1  # subtract 1 to exclude the background
    with out:
        clear_output()
        print(f"Number of segmented objects: {num_cells}")

    # Read the overlay image before creating widgets, so a failed upload leaves no sliders behind
    overlay_image = read_uploaded_file(overlay_image_upload)
    if overlay_image is None:
        with out:
            clear_output()
            print("Error reading the overlay image. Please upload a valid TIFF file.")
        return

    # Resize the overlay to the mask's dimensions and convert it to 8-bit RGB
    overlay_image_resized = _to_rgb_uint8(
        cv2.resize(overlay_image, (img.shape[1], img.shape[0]))
    )

    # Use the already loaded DataFrame 'filtered_df'
    data = filtered_df
    # Freeze the channel for this figure, so later dropdown changes cannot desync title and colours
    selected_column = column_select.value
    cmap = plt.cm.turbo  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9

    # Get the minimum and maximum values of the selected column
    vmin, vmax = data[selected_column].min(), data[selected_column].max()

    # Create an empty color image with the same dimensions as the mask
    color_mask = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)

    # Range slider for the heatmap limits; the step scales with the data range so that
    # small values (e.g. size-normalised intensities) remain adjustable.
    limits_slider = widgets.FloatRangeSlider(
        value=[vmin, vmax],
        min=vmin,
        max=vmax,
        step=(vmax - vmin) / 100 if vmax > vmin else 0.1,
        readout_format=".3g",
        description="Limits:",
        continuous_update=False,  # Only update when user stops dragging the slider
    )

    # Slider for the blend weight of the colour mask
    transparency_slider = widgets.FloatSlider(
        value=0.6,
        min=0,
        max=1.0,
        step=0.1,
        description="Transparency:",
        continuous_update=False,  # Only update when user stops dragging the slider
    )

    # Display the sliders, then the coloured segmentation mask
    with out:
        display(limits_slider, transparency_slider)

        fig, ax = plt.subplots()
        blended_image = cv2.addWeighted(color_mask, 0.6, overlay_image_resized, 0.4, 0)
        im = ax.imshow(blended_image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(f"Heatmap of {selected_column} Marker Signal")
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(selected_column)
        fig.tight_layout()
        plt.show()

    # Map the mask label under each cell centre to that cell's row label (last row wins);
    # astype(int) truncates like int() did in the former row-by-row loop.
    center_x = data["Location_Center_X"].astype(int).to_numpy()
    center_y = data["Location_Center_Y"].astype(int).to_numpy()
    cell_positions = dict(zip(img[center_y, center_x], data.index))

    # Setting up a listener to update the plot when the slider values change
    def on_slider_change(change):
        transparency = transparency_slider.value
        vmin, vmax = limits_slider.value
        plot_heatmap(
            data,
            img,
            color_mask,
            im,
            cell_positions,
            vmin,
            vmax,
            overlay_image_resized,
            transparency,
            selected_column=selected_column,
        )

    transparency_slider.observe(on_slider_change, names="value")
    limits_slider.observe(on_slider_change, names="value")

    # Initial call
    on_slider_change(None)


# Create an Output widget for displaying output
out = widgets.Output()

# Upload widgets for the segmentation mask and overlay image
mask_upload = widgets.FileUpload(
    accept=".png",
    multiple=False,
    description="Upload Mask as PNG",
    layout={"width": "auto"},
)

# Accept both common TIFF extensions: with only ".tiff", files named "*.tif" are
# filtered out of the browser's file picker.
overlay_image_upload = widgets.FileUpload(
    accept=".tif,.tiff",
    multiple=False,
    description="Upload raw image as TIFF",
    layout={"width": "auto"},
)

# Channel whose values are mapped onto the segmented cells
column_select = widgets.Dropdown(
    options=filtered_df.columns,
    description="Select overlay channel:",
    style={"description_width": "initial"},
)

# Load data and visualize button
load_data_button = widgets.Button(
    description="Load data and visualise",
    layout={"width": "auto"},
)

load_data_button.on_click(load_data)

# Combine the widgets into a single layout using HBox
widgets_box = widgets.HBox([mask_upload, column_select, overlay_image_upload])
display(widgets_box, load_data_button, out)

### Select the channels to be included in your clustering
**Note:** If you previously size-normalised your data (see above) it is recommended to leave morphological features unselected

In [None]:
# Pick the filtered_df columns that are scaled and passed to the clustering below
cluster_channels = widgets.SelectMultiple(
    description="Channels",
    options=filtered_df.columns,
    rows=10,
    style={"description_width": "initial"},
)

display(cluster_channels)

### Select a scaling method (comment/uncomment)
Standardize data values to bring them onto a comparable scale. This facilitates accurate clustering results. \
Two methods are available: Robust scaling and MinMax scaling \
For more information on data scaling see [here](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html) and [here](https://scikit-learn.org/stable/modules/preprocessing.html)

In [None]:
clustering_channels = list(cluster_channels.value)
if not clustering_channels:
    raise ValueError(
        "No channels selected for clustering. Select at least one channel in the cell above."
    )

# Robust scaling: centres each channel on its median and scales by the interquartile
# range, so extreme values (outliers/rare cells) barely influence the scaling itself.
r_scaler = RobustScaler()
r_scaled = r_scaler.fit_transform(filtered_df.loc[:, clustering_channels])

scaled_df = pd.DataFrame(r_scaled, columns=clustering_channels)

# MinMax scaling (alternative): maps each channel to [0, 1]. Note that this is sensitive
# to outliers, because the minimum and maximum values define the scale.
# mm_scaler = MinMaxScaler()
# mm_scaled = mm_scaler.fit_transform(filtered_df.loc[:, clustering_channels])

# scaled_df = pd.DataFrame(mm_scaled, columns=clustering_channels)

## Clustering
Run PhenoGraph clustering;<br>
For information on the usage and parameters of PhenoGraph please refer to https://github.com/dpeerlab/phenograph<br>
and/or read the corresponding publications: Jacob H. Levine et al., __[Phenograph](https://doi.org/10.1016/j.cell.2015.05.047)__ and V. A. Traag et al., __[Leiden algorithm](https://doi.org/10.1038/s41598-019-41695-z)__;<br>
Before clustering, make sure you are only including the desired channels (see cell above);<br>
<br>
**Note:** Leaving the 'Initial seed' parameter unset will result in a random starting seed and therefore will lead to slightly different results each run

In [None]:
resolution_parameter_text = widgets.Text(
    value="",
    placeholder="1.0",
    description="resolution parameter:",
    style={"description_width": "initial"},
)

k_text = widgets.Text(
    value="",
    placeholder="30",
    description="k-value:",
    style={"description_width": "initial"},
)

seed_text = widgets.Text(
    value="",
    placeholder="42",
    description="initial seed:",
    style={"description_width": "initial"},
)

update_button = widgets.Button(
    description="Run Clustering",
    layout={"width": "max-content"},
)

output_widget = widgets.Output()


# Callback function to run the clustering algorithm
def run_clustering(button):
    """Run PhenoGraph (Leiden) on ``scaled_df`` with the parameters entered above.

    Empty k / resolution fields fall back to the values shown as placeholders
    (30 and 1.0). An empty seed means a random seed. Results are stored in the
    globals ``communities``, ``graph`` and ``Q``, and the labels are written to
    ``filtered_df["cluster"]``. ``button`` is the clicked Button and is unused.
    """
    global communities, graph, Q

    # Emit everything, including validation messages, inside output_widget: output
    # printed from a widget callback outside an Output context may not appear in the
    # notebook (JupyterLab routes it to the log console).
    with output_widget:
        output_widget.clear_output()

        try:
            k = int(k_text.value.strip() or k_text.placeholder)
        except ValueError:
            print("Invalid k-value. Please enter an integer.")
            return

        try:
            resolution_parameter = float(
                resolution_parameter_text.value.strip()
                or resolution_parameter_text.placeholder
            )
        except ValueError:
            print("Invalid resolution parameter. Please enter a float.")
            return

        try:
            # If seed_text is empty, set seed to None (random seed)
            seed_raw = seed_text.value.strip()
            seed = int(seed_raw) if seed_raw else None
        except ValueError:
            print("Invalid seed. Please enter an integer or leave it empty.")
            return

        communities, graph, Q = phenograph.cluster(
            scaled_df,
            clustering_algo="leiden",
            k=k,
            resolution_parameter=resolution_parameter,
            seed=seed,
        )

        # Store the labels positionally on filtered_df (same row order as scaled_df)
        filtered_df["cluster"] = pd.Categorical(communities)


update_button.on_click(run_clustering)

# Display the widgets
widgets.VBox(
    [
        k_text,
        resolution_parameter_text,
        seed_text,
        update_button,
        output_widget,
    ]
)

## Dimensionality reduction using UMAP
Run and visualise UMAP embedding;<br>
Default settings for UMAP's hyperparameters are used. If you wish to fine-tune UMAP's parameters, uncomment and execute the collapsed cell below. <br>
For information on the usage and parameters of UMAP please refer to the documentation https://umap-learn.readthedocs.io/en/latest/index.html<br>
and/or read the corresponding publication: Leland McInnes et al., https://doi.org/10.48550/arXiv.1802.03426;<br>
<br>
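**Note:** The cell below uses a fixed random state (42), so its embedding is reproducible. In the hyperparameter cell further below, leaving the 'random initial state' parameter unset will result in a random starting state and therefore will lead to slightly different results each run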

In [None]:
# Fit a 2-D UMAP embedding of the scaled channels (fixed random_state for reproducibility)
reducer = umap.UMAP(random_state=42)
embedding = reducer.fit_transform(scaled_df)

# Tabulate the embedding and attach each cell's cluster as an ordered categorical
embedding_df = pd.DataFrame(embedding, columns=["UMAP1", "UMAP2"]).assign(
    cluster=pd.Categorical(
        communities, ordered=True, categories=sorted(set(communities))
    )
)

# Scatter plot of the embedding, one colour per cluster
fig = px.scatter(
    embedding_df,
    title="UMAP plot colored by cluster",
    x="UMAP1",
    y="UMAP2",
    color="cluster",
    hover_data=[embedding_df.index],
    color_continuous_scale="jet",
    width=1000,
    height=800,
)
fig.show()

## Dimensionality reduction using UMAP (with hyperparameters)
Uncomment the cell below if you wish to change UMAP parameters 'on-the-fly'

In [None]:
# # Uncomment and execute this cell if you wish to change UMAP paramters 'on-the-fly'

# n_neighbors_text = widgets.Text(
#     value="",
#     placeholder="15",
#     description="n_neighbors:",
#     style={"description_width": "initial"},
# )

# min_dist_text = widgets.Text(
#     value="",
#     placeholder="0.1",
#     description="Minimum distance between points:",
#     style={"description_width": "initial"},
# )

# random_state_text = widgets.Text(
#     value="",
#     placeholder="42",
#     description="random initial state:",
#     style={"description_width": "initial"},
# )

# update_button = widgets.Button(
#     description="Run UMAP embedding",
#     layout={"width": "max-content"},
# )

# output_widget = widgets.Output()


# # Callback function to run the clustering algorithm
# def run_umap(button):
#     global reducer, embedding, embedding_df

#     try:
#         n_neighbors = int(n_neighbors_text.value)
#     except ValueError:
#         print("Invalid n_neighbors-value. Please enter an integer.")
#         return

#     try:
#         min_dist = float(min_dist_text.value)
#     except ValueError:
#         print("Invalid minimum distance parameter. Please enter a float.")
#         return

#     try:
#         random_state = (
#             None if random_state_text.value == "" else int(random_state_text.value)
#         )
#     except ValueError:
#         print("Invalid random state. Please enter an integer.")
#         return

#     with output_widget:
#         output_widget.clear_output()

#         # Perform UMAP embedding
#         reducer = umap.UMAP(random_state=random_state)
#         embedding = reducer.fit_transform(scaled_df)

#         # Convert embedding to a DataFrame
#         embedding_df = pd.DataFrame(embedding, columns=["UMAP1", "UMAP2"])

#         # Add cluster column to the DataFrame
#         embedding_df["cluster"] = pd.Categorical(
#             communities, ordered=True, categories=sorted(set(communities))
#         )

#         # Plot the UMAP embedding with cluster colors
#         fig = px.scatter(
#             embedding_df,
#             x="UMAP1",
#             y="UMAP2",
#             color="cluster",
#             hover_data=[embedding_df.index],
#             color_continuous_scale="jet",
#             title="UMAP plot colored by cluster",
#             width=1000,
#             height=800,
#         )

#         fig.show()


# update_button.on_click(run_umap)

# # Display the widgets
# widgets.VBox(
#     [
#         n_neighbors_text,
#         min_dist_text,
#         random_state_text,
#         update_button,
#         output_widget,
#     ]
# )

## Cluster analysis using a cluster heatmap
The sections below are used to create a cluster heatmap of the previously generated clusters. This can aid in phenotyping cells or distinguishing whole tissue areas.
The heatmap is generated in the following way:

1. The mean or median of each channel for every cluster is calculated
2. Z-scores (the number of standard deviations 'away' from the mean) are calculated for each previously calculated mean/median of each channel for each cluster
3. The calculated values are visualised as coloured squares ranging from red to blue - red: z-scores higher than 0 (the mean of all clusters); blue z-scores lower than 0

### Select whether median or mean values should be used to form the cluster heatmap (comment/uncomment)
By default the median is chosen because it is less affected by outliers.

In [None]:
# One dataframe per cluster, ordered by cluster label. Iterating over the labels that
# actually occur (instead of range(number_of_clusters)) does not assume the labels are
# exactly 0..n-1: PhenoGraph assigns -1 to cells in clusters below its min_cluster_size,
# which previously was skipped while an empty, all-NaN extra "cluster" was created.
cluster_labels = sorted(filtered_df["cluster"].unique())
cluster_list = [filtered_df[filtered_df["cluster"] == label] for label in cluster_labels]

# Median value per channel for each cluster (rows indexed by cluster label);
# uncomment the bottom section to use the mean instead of the median
median_clusters = pd.DataFrame(
    [cluster.median(axis=0, numeric_only=True) for cluster in cluster_list],
    index=cluster_labels,
)

# Z-score each channel across the cluster medians for a comparable heatmap colour scale
z_scaler = StandardScaler()
z_scaled = z_scaler.fit_transform(median_clusters)
heatmap_clusters = pd.DataFrame(
    z_scaled, columns=median_clusters.columns, index=median_clusters.index
)

# Mean value per channel for each cluster:
# mean_clusters = pd.DataFrame(
#     [cluster.mean(axis=0, numeric_only=True) for cluster in cluster_list],
#     index=cluster_labels,
# )

# Z-score each channel across the cluster means for the heatmap:
# z_scaler = StandardScaler()
# z_scaled = z_scaler.fit_transform(mean_clusters)
# heatmap_clusters = pd.DataFrame(
#     z_scaled, columns=mean_clusters.columns, index=mean_clusters.index
# )

### Select channels to be included and plot a cluster heatmap
Any of the selected channels will be included in the final cluster heatmap as described above

In [None]:
# Select channels and plot a clusterheatmap
heatmap_channels = widgets.SelectMultiple(
    options=singlecell_df.columns,
    rows=10,
    description="Channels",
    style={"description_width": "initial"},
)

heatmap_title = widgets.Text(
    value="",
    placeholder="Heatmap title",
    description="Title:",
    style={"description_width": "initial"},
)

update_button = widgets.Button(
    description="Update Heatmap",
    disabled=False,
    button_style="",
    tooltip="Click to update the heatmap",
    icon="refresh",
)

out_heatmap = widgets.Output()


def generate_heatmap_figure(channel_order, title):
    final_df = heatmap_clusters.loc[:, channel_order]
    cluster_sizes = []
    for n in cluster_list:
        cluster_sizes.append(len(n))

    cmap = sns.diverging_palette(255, 12, as_cmap=True)

    fig, ax = plt.subplots(figsize=(13, 10))

    ax = sns.heatmap(
        final_df,
        cmap=cmap,
        annot=False,
        linewidths=0.5,
        linecolor="black",
        cbar_kws={"pad": 0.15},
        center=0.0,
    )

    ax2 = ax.twinx()

    sns.heatmap(
        final_df,
        cmap=cmap,
        annot=False,
        linewidths=0.5,
        linecolor="black",
        ax=ax2,
        yticklabels=cluster_sizes,
        cbar=False,
        center=0.0,
    )

    xtick_labels = [label for label in channel_order]
    ax.set_xticks([*np.arange(0.5, len(final_df.columns), 1)])
    ax.set_xticklabels(xtick_labels, rotation=45, ha="right", rotation_mode="anchor")

    ax.tick_params(axis="y", rotation=0)
    ax2.tick_params(axis="y", rotation=0)

    ax2.set(ylabel="cell counts per cluster", title=title)

    fig.tight_layout()

    return fig


def update_heatmap(*args):
    """Redraw the cluster heatmap in ``out_heatmap`` from the current widget values.

    Bound to ``update_button.on_click``. ``*args`` receives the clicked button
    and is unused.

    The figure from the previous click is closed first, so repeated updates do
    not accumulate open matplotlib figures. The newly drawn figure stays open
    and current, so a following ``plt.savefig`` still saves it.
    """
    with out_heatmap:
        clear_output(wait=True)

        channel_order = heatmap_channels.value

        if not channel_order:
            print("Please select at least one channel.")
            return

        # Release the figure drawn by the previous click (it is no longer displayed).
        previous_figure = getattr(update_heatmap, "_last_figure", None)
        if previous_figure is not None:
            plt.close(previous_figure)

        update_heatmap._last_figure = generate_heatmap_figure(
            channel_order, heatmap_title.value
        )
        plt.show()


update_button.on_click(update_heatmap)

# Controls on top and the rendered heatmap underneath. The panel is the cell's
# final expression, so Jupyter displays it.
heatmap_panel = widgets.VBox(
    [widgets.HBox([heatmap_channels, heatmap_title, update_button]), out_heatmap],
)
heatmap_panel

##### Uncomment the cell below to save the cluster heatmap as a PNG file

In [None]:
# plt.savefig(fname=os.path.join(PATH, "clustermap.png"), dpi=600)

### Select a cluster and a channel to display in the histogram below
**Note**:  
The following cell uses the column names that CellProfiler produces for area (AreaShape_Area) and diameter (AreaShape_EquivalentDiameter). If the single-cell CSV was produced by other software, or the column names were changed manually, replace the emphasised strings below with the name of your area column and, optionally, your diameter column. If there is no diameter column, the diameter line can be removed.  
  
**\["AreaShape_Area"\]**

**\["AreaShape_EquivalentDiameter"\]**

In [None]:
export_df = filtered_df.copy()

# Comment this code to plot and later save size-normalised data.
# When a pixel size has been entered:
# - AreaShape_Area is divided by pixel_size**2,
# - AreaShape_EquivalentDiameter is divided by pixel_size,
# - every column in cols_to_divide is multiplied by the already-rescaled area.
# NOTE(review): presumably this undoes an earlier per-area normalisation of
# cols_to_divide. Confirm that the area units used there match the rescaled
# area used here.
if pixelsize_select.value != "":
    _px_size = float(pixelsize_select.value)
    export_df["AreaShape_Area"] = export_df["AreaShape_Area"].div(_px_size**2)
    export_df["AreaShape_EquivalentDiameter"] = export_df[
        "AreaShape_EquivalentDiameter"
    ].div(_px_size)
    _rescaled_area = export_df["AreaShape_Area"]
    export_df[cols_to_divide] = export_df[cols_to_divide].mul(_rescaled_area, axis=0)

# Histogram controls: which cluster (by the categorical code of the "cluster"
# column) and which channel to plot.
cluster_select = widgets.RadioButtons(
    description="Histogram cluster",
    options=np.sort(filtered_df["cluster"].cat.codes.unique()),
    style={"description_width": "initial"},
)

channel_select = widgets.RadioButtons(
    description="Histogram channel",
    options=singlecell_df.columns,
    style={"description_width": "initial"},
)


def _refresh_histogram(change):
    """Redraw the histogram from the widgets' current cluster and channel selection."""
    update_histogram(change, export_df, channel_select.value, cluster_select.value)


# A change to either selection triggers a redraw.
for _selector in (cluster_select, channel_select):
    _selector.observe(_refresh_histogram, names="value")

display(widgets.HBox([cluster_select, channel_select, plot_output]))

# Draw once so a histogram is visible before any selection changes.
_refresh_histogram(None)

##### Uncomment the cell below to save the histogram shown above

In [None]:
# plt.savefig(
#     fname=os.path.join(
#         PATH,
#         "cluster"
#         + str(cluster_select.value)
#         + "_"
#         + channel_select.value
#         + "_hist.png",
#     ),
#     dpi=600,
# )

### Select clusters you wish to save as a .csv file

In [None]:
# Pick the clusters to export as .csv files,
# e.g. for quantification with BIO-logi-CAL standards.
# Options are the sorted categorical codes of the "cluster" column.
_export_cluster_codes = np.sort(export_df["cluster"].cat.codes.unique())

quant_clusters = widgets.SelectMultiple(
    description="Clusters",
    options=_export_cluster_codes,
    rows=10,
    style={"description_width": "initial"},
)

display(quant_clusters)

##### Run the cell below to save each selected cluster as a separate .csv file in the working directory (uncomment its last line to also save all clusters combined)

In [None]:
# Write one CSV per selected cluster into PATH, named cluster_<code>.csv.
# The selector offers the categorical *codes* of the "cluster" column, so rows
# are matched by code too. Comparing the column itself against a code only
# selects the intended rows when every category value happens to equal its
# code, which fails for e.g. a -1 outlier label or string categories.
_row_cluster_codes = export_df["cluster"].cat.codes
for c in quant_clusters.value:
    cluster_df = export_df[_row_cluster_codes == c]
    cluster_df.to_csv(os.path.join(PATH, "cluster_" + str(c) + ".csv"), index=False)

# Uncomment code below to save a filtered .csv file of all clusters combined
# export_df.to_csv(os.path.join(PATH, "filtered_SC_data.csv"), index=False)

## Plot a heatmap of a saved cluster for the selected channel, overlaid on a raw image
This repeats the earlier heatmap-overlay cell, but takes a saved cluster CSV (see above) as input instead of the whole dataset.

In [None]:
def read_uploaded_file(upload_widget):
    """Decode the first file held by a FileUpload widget as an image.

    The raw bytes are wrapped in a uint8 numpy buffer and passed to
    ``cv2.imdecode`` with ``IMREAD_UNCHANGED``, and that result is returned.
    Returns ``None`` when nothing has been uploaded.
    """
    if not upload_widget.value:
        return None
    encoded = np.frombuffer(upload_widget.value[0].content, np.uint8)
    return cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)


def read_uploaded_csv(upload_widget):
    """Parse the first file held by a FileUpload widget as CSV.

    Returns a pandas DataFrame, or ``None`` when nothing has been uploaded.
    """
    if not upload_widget.value:
        return None
    return pd.read_csv(io.BytesIO(upload_widget.value[0].content))


@jit(nopython=True)
def update_color_mask_numba(img, color_mask, cell_ids, cell_colors):
    """Paint every pixel whose label is in ``cell_ids`` with that cell's colour.

    Parameters
    ----------
    img : 2-D integer array
        Label image. Assumed to hold non-negative integer labels, as decoded
        from the PNG mask; TODO confirm if other mask types are used.
    color_mask : (H, W, 3) array
        Output buffer, modified in place. Pixels whose label is not in
        ``cell_ids`` are left untouched.
    cell_ids : 1-D array
        Labels to paint.
    cell_colors : (N, 3) array
        Colour for the label at the same position in ``cell_ids``. Extra
        entries in the longer of the two arrays are ignored, and for duplicate
        labels the last colour wins.

    Builds a label -> colour lookup table and then visits each pixel once.
    This is O(pixels + cells), instead of scanning the whole image once per
    cell.
    """
    if img.size == 0:
        return
    n_labels = int(img.max()) + 1
    lut = np.zeros((n_labels, 3), dtype=np.uint8)
    has_color = np.zeros(n_labels, dtype=np.bool_)

    for k in range(min(len(cell_ids), len(cell_colors))):
        cid = int(cell_ids[k])
        # Labels outside the image's range match no pixel, so they are skipped.
        if 0 <= cid < n_labels:
            lut[cid, :] = cell_colors[k]
            has_color[cid] = True

    for y in range(img.shape[0]):
        for x in range(img.shape[1]):
            pix = img[y, x]
            if has_color[pix]:
                color_mask[y, x, :] = lut[pix]


def plot_heatmap(
    data,
    img,
    color_mask,
    im,
    cell_positions,
    vmin,
    vmax,
    overlay_image_resized,
    transparency,
):
    """Recolour the segmentation overlay for the selected channel and redraw it.

    Each row of ``data`` is coloured with the turbo colormap applied to its
    value in the channel chosen in ``column_select``, normalised to
    ``[vmin, vmax]``. Each matched cell in the mask is then painted with the
    colour of *its own* data row.

    Parameters
    ----------
    data : pandas.DataFrame
        Per-cell table. A ``"color_values"`` column of RGB tuples (0-255) is
        added or overwritten in place.
    img : numpy.ndarray
        Label image whose pixel values are cell ids.
    color_mask : numpy.ndarray
        (H, W, 3) uint8 buffer painted in place.
    im : matplotlib.image.AxesImage
        Image artist that receives the blended result.
    cell_positions : dict
        Maps a label in ``img`` to the index label of its row in ``data``.
    vmin, vmax : float
        Colormap limits.
    overlay_image_resized : numpy.ndarray
        Background image blended under ``color_mask``.
    transparency : float
        Weight of ``color_mask`` in the blend. The background gets
        ``1 - transparency``.
    """
    # Apply the turbo colormap to the selected column
    cmap = plt.cm.turbo
    selected_column = column_select.value

    norm = plt.Normalize(vmin, vmax)
    data_colors = cmap(norm(data[selected_column].values))
    data["color_values"] = [
        tuple(int(v * 255) for v in color[:3]) for color in data_colors
    ]

    # Pair each label with the colour of the row it was matched to. Taking the
    # colours in plain row order would misalign them with the labels whenever
    # two centroids map to the same label (the dict keeps only one entry).
    # Label 0 is skipped: it is assumed to be background (the usual convention
    # for label masks), and painting it would flood the whole background.
    matched = [(cid, row) for cid, row in cell_positions.items() if cid != 0]
    cell_ids = np.array([cid for cid, _ in matched], dtype=np.int64)
    cell_colors = np.array(
        data.loc[[row for _, row in matched], "color_values"].tolist(),
        dtype=np.uint8,
    ).reshape(-1, 3)
    update_color_mask_numba(img, color_mask, cell_ids, cell_colors)

    # Blend the overlay image with the color_mask
    blended_image = cv2.addWeighted(
        color_mask, transparency, overlay_image_resized, 1 - transparency, 0
    )

    # Update the artist and its colour limits, then redraw *its* figure,
    # not whichever pyplot figure happens to be current.
    im.set_data(blended_image)
    im.set_clim(vmin, vmax)
    fig = im.figure
    fig.tight_layout()
    fig.canvas.draw_idle()


def _to_bgr_uint8(image):
    """Return ``image`` as a 3-channel uint8 array that can be blended with the colour mask.

    Non-uint8 images (e.g. 16-bit TIFFs) are min-max rescaled to 0-255. A plain
    ``astype(np.uint8)`` would wrap values above 255. Grayscale and 4-channel
    images are converted to 3 channels; 3-channel images are returned as-is.
    """
    if image.dtype != np.uint8:
        lo, hi = float(image.min()), float(image.max())
        scale = 255.0 / (hi - lo) if hi > lo else 0.0
        image = ((image.astype(np.float64) - lo) * scale).astype(np.uint8)
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    return image


def _map_cells_to_rows(data, img):
    """Map the mask label under each cell's (truncated) centroid to that cell's index label in ``data``.

    If several rows hit the same label, the last row wins, matching the
    original row-by-row loop.
    """
    xs = data["Location_Center_X"].astype(int).to_numpy()
    ys = data["Location_Center_Y"].astype(int).to_numpy()
    cell_positions = {}
    for cell_id, index in zip(img[ys, xs], data.index):
        cell_positions[cell_id] = index
    return cell_positions


# Define a function to load data and prepare the heatmap visualization
def load_data(btn):
    """Load the uploaded mask, data and raw image and show an interactive heatmap overlay.

    Bound to ``load_data_button.on_click``; ``btn`` is the clicked button and
    is unused. Reads the three FileUpload widgets and reports problems in
    ``out``. It then displays limit and transparency sliders plus the figure,
    and wires the sliders to ``plot_heatmap``.
    """
    # Read the segmentation mask
    img = read_uploaded_file(mask_upload)
    if img is None:
        with out:
            clear_output()
            print(
                "Error reading the segmentation mask. Please upload a valid PNG file."
            )
            return

    # Count the number of unique cell IDs in the segmentation mask
    num_cells = len(np.unique(img)) - 1  # subtract 1 to exclude the background
    with out:
        clear_output()
        print(f"Number of segmented objects: {num_cells}")

    data = read_uploaded_csv(data_upload)
    if data is None:
        with out:
            clear_output()
            print("Error reading the data. Please upload a valid CSV file.")
            return

    # Empty colour image with the same height/width as the mask
    color_mask = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)

    # Apply the turbo colormap to the selected column
    cmap = plt.cm.turbo
    selected_column = column_select.value

    # Get the minimum and maximum values of the selected column
    vmin, vmax = data[selected_column].min(), data[selected_column].max()

    # Keep the original 0.1 step for wide ranges, but refine it for channels
    # whose whole range is small, where 0.1 would leave only a few positions.
    step = min(0.1, (vmax - vmin) / 100) if vmax > vmin else 0.1

    # Range slider to modify the heatmap limits
    limits_slider = widgets.FloatRangeSlider(
        value=[vmin, vmax],
        min=vmin,
        max=vmax,
        step=step,
        description="Limits:",
        continuous_update=False,  # Only update when user stops dragging the slider
    )

    with out:
        display(limits_slider)

    # Slider for the weight of the colour mask in the blend
    transparency_slider = widgets.FloatSlider(
        value=0.6,
        min=0,
        max=1.0,
        step=0.1,
        description="Transparency:",
        continuous_update=False,  # Only update when user stops dragging the slider
    )

    with out:
        display(transparency_slider)

    # Read the overlay image
    overlay_image = read_uploaded_file(overlay_image_upload)
    if overlay_image is None:
        with out:
            clear_output()
            print("Error reading the overlay image. Please upload a valid TIFF file.")
            return

    # Resize to the mask's dimensions, then convert to 3-channel uint8 so it can
    # be blended with color_mask. This handles grayscale, BGR, BGRA and 16-bit input.
    overlay_image_resized = _to_bgr_uint8(
        cv2.resize(overlay_image, (img.shape[1], img.shape[0]))
    )

    # Display the colored segmentation mask
    with out:
        fig, ax = plt.subplots()
        blended_image = cv2.addWeighted(color_mask, 0.6, overlay_image_resized, 0.4, 0)
        im = ax.imshow(blended_image, cmap=cmap, vmin=vmin, vmax=vmax)
        plt.title(f"Heatmap of {selected_column} Marker Signal")
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(selected_column)
        plt.tight_layout()
        plt.show()

    # Precompute which data row each mask label belongs to
    cell_positions = _map_cells_to_rows(data, img)

    # Redraw whenever either slider changes
    def on_slider_change(change):
        transparency = transparency_slider.value
        vmin, vmax = limits_slider.value
        plot_heatmap(
            data,
            img,
            color_mask,
            im,
            cell_positions,
            vmin,
            vmax,
            overlay_image_resized,
            transparency,
        )

    transparency_slider.observe(on_slider_change, names="value")
    limits_slider.observe(on_slider_change, names="value")

    # Initial call
    on_slider_change(None)


# Output area for status messages, the sliders and the overlay figure
out = widgets.Output()

# Upload widgets for the segmentation mask, the raw image and the per-cell data
mask_upload = widgets.FileUpload(
    accept=".png",
    multiple=False,
    description="Upload Mask as PNG",
    layout={"width": "auto"},
)

overlay_image_upload = widgets.FileUpload(
    # Accept both TIFF extensions; ".tiff" alone hides the common ".tif" files.
    accept=".tif,.tiff",
    multiple=False,
    description="Upload raw image as TIFF",
    layout={"width": "auto"},
)

data_upload = widgets.FileUpload(
    accept=".csv",
    multiple=False,
    description="Upload Data as CSV",
    layout={"width": "auto"},
)

column_select = widgets.Dropdown(
    options=filtered_df.columns,
    description="Select overlay channel:",
    style={"description_width": "initial"},
)

# Load data and visualize button
load_data_button = widgets.Button(
    description="Load data and visualize",
    layout={"width": "auto"},
)

load_data_button.on_click(load_data)

# Combine the widgets into a single layout using HBox
widgets_box = widgets.HBox(
    [data_upload, mask_upload, column_select, overlay_image_upload]
)
display(widgets_box, load_data_button, out)