In [None]:
import os
import threading
from io import BytesIO

import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from IPython.display import clear_output, display
from ipywidgets import Label
from scipy.stats import gaussian_kde

## Gating
This notebook provides manual population gating comparable to that used in flow and mass cytometry data analysis.  
It plots two features of single-cell data against each other in an interactive scatterplot, with points colored by local density.  
Gates can then be drawn onto the plot with the selection tools (e.g. the lasso tool for polygon gates) to separate and save individual cell populations.

In [None]:
# Placeholder until a CSV is uploaded: an empty DataFrame *instance* (the
# original bound the ``pd.DataFrame`` class itself, whose attributes such as
# ``.columns`` are descriptors rather than data).
singlecell_df = pd.DataFrame()
# Named populations: "initial population" plus every saved gate / residual.
dataframes_dict = {}


def handle_file_upload(change):
    """Load the uploaded CSV into ``singlecell_df`` and preview its first rows.

    Observer for ``upload_widget``'s ``value`` trait. The parsed frame is also
    registered as the "initial population" in ``dataframes_dict``.

    Args:
        change: traitlets change notification (unused; the widget is read directly).
    """
    global singlecell_df
    with output_widget:
        clear_output()
        uploads = upload_widget.value
        # If the widget's value has been cleared there is nothing to load;
        # indexing [0] would raise IndexError.
        if not uploads:
            return
        content = uploads[0]["content"]
        singlecell_df = pd.read_csv(BytesIO(content))
        # Populations gated on a previously uploaded file do not belong to this
        # dataset (and would be exported under the new file's base name).
        dataframes_dict.clear()
        dataframes_dict["initial population"] = singlecell_df

        display(singlecell_df.head())


# File picker restricted to a single CSV; the parsed preview is rendered into
# ``output_widget`` by ``handle_file_upload``.
upload_widget = widgets.FileUpload(
    description="Upload single-cell CSV",
    accept=".csv",
    multiple=False,
    layout={"width": "auto"},
)
output_widget = widgets.Output()

# Start listening for uploads, then show the picker with its preview area below.
upload_widget.observe(handle_file_upload, names="value")
display(upload_widget, output_widget)

In [None]:
# Widgets section

# Fail early with a clear message instead of an opaque error from the widget
# constructors when this cell runs before a CSV has been uploaded.
if "initial population" not in dataframes_dict:
    raise RuntimeError(
        "No data loaded yet: upload a single-cell CSV with the widget above, "
        "then re-run this cell."
    )

feature_columns = list(singlecell_df.columns)

x_dropdown = widgets.Dropdown(
    options=feature_columns,
    description="X-Axis:",
    value=feature_columns[0],
)

# Default to the second column; fall back to the first for single-column files.
y_dropdown = widgets.Dropdown(
    options=feature_columns,
    description="Y-Axis:",
    value=feature_columns[1] if len(feature_columns) > 1 else feature_columns[0],
)

population_dropdown = widgets.Dropdown(
    options=list(dataframes_dict.keys()),
    description="Population:",
    value=next(iter(dataframes_dict)),
)

# "width" is a layout property, not a description-style key, so it goes in
# ``layout`` rather than ``style``.
gate_name = widgets.Text(
    description="Gate name:",
    style={"description_width": "initial"},
    layout={"width": "auto"},
)

residual_name = widgets.Text(
    description="Residual name:",
    style={"description_width": "initial"},
    layout={"width": "auto"},
)

extract_button = widgets.Button(
    description="Save gate", style={"button_color": "#98FB98"}
)

save_residual_button = widgets.Button(
    description="Save Residual Population", style={"button_color": "#FF6347"}, layout={"width": "auto"}
)

# Shared status line for warnings and prompts.
warning_label = Label()

In [None]:
# Calculate point density
def calculate_density(x, y):
    """Estimate the 2-D point density at each (x, y) sample with a Gaussian KDE.

    Args:
        x: sequence of x coordinates.
        y: sequence of y coordinates, same length as ``x``.

    Returns:
        numpy array with one density value per point, or ``None`` when the KDE
        cannot be fitted (e.g. fewer than two points, empty input, non-finite
        values, or identical/collinear points giving a singular covariance).
    """
    xy = np.vstack([x, y])
    try:
        return gaussian_kde(xy)(xy)
    except (np.linalg.LinAlgError, ValueError):
        # LinAlgError: singular covariance. ValueError: scipy's input checks
        # (too few samples for two dimensions, empty or non-finite data).
        return None


x_values = singlecell_df[x_dropdown.value]
y_values = singlecell_df[y_dropdown.value]
point_density = calculate_density(x_values, y_values)

# Scatter of every cell, coloured by local point density (colour is None when
# the KDE could not be fitted).
scatter_trace = go.Scatter(
    x=x_values,
    y=y_values,
    mode="markers",
    marker={"color": point_density, "colorscale": "Turbo", "showscale": False, "size": 3},
)
scatter_layout = go.Layout(
    xaxis={"title": x_dropdown.value},
    yaxis={"title": y_dropdown.value},
    margin={"l": 1, "r": 1, "t": 2, "b": 5},
    height=700,
)
fig = go.FigureWidget(data=scatter_trace, layout=scatter_layout)

In [None]:
# Module-level state shared by the gating callbacks below.
residual_timer = None  # pending threading.Timer that hides the residual prompt
selected_data, residual_data = pd.DataFrame(), pd.DataFrame()  # last gate / its complement


def update_plot(change=None, data=None):
    """Redraw the scatter plot for the chosen axes and population.

    Serves as the observer for the X/Y/population dropdowns (``change`` is the
    traitlets change notification and is not inspected). When ``data`` is
    given it is plotted directly; otherwise the population selected in
    ``population_dropdown`` is used, falling back to ``singlecell_df`` if that
    name is not in ``dataframes_dict``. Points are coloured by density, or
    plain blue when there are fewer than two of them.
    """
    clear_output(wait=True)

    if data is not None:
        target_df = data
    elif population_dropdown.value in dataframes_dict:
        target_df = dataframes_dict[population_dropdown.value]
    else:
        target_df = singlecell_df

    x_col, y_col = x_dropdown.value, y_dropdown.value

    # A KDE needs at least two samples.
    if len(target_df) > 1:
        marker_color = calculate_density(target_df[x_col], target_df[y_col])
    else:
        marker_color = "blue"

    with fig.batch_update():
        trace = fig.data[0]
        trace.x = target_df[x_col]
        trace.y = target_df[y_col]
        trace.marker.color = marker_color
        fig.update_layout(
            xaxis_title=x_col,
            yaxis_title=y_col,
            margin=dict(l=1, r=1, t=2, b=5),
            height=700,
        )


# Function to detect selected and unselected data
def selection_fn(trace, points, selector):
    """Record the selected points (the gate) and the remaining points (the residual).

    Registered via ``fig.data[0].on_selection``. ``points.point_inds`` are
    treated as positional indices into the population currently chosen in
    ``population_dropdown`` (assumes the plot is showing that population).
    Updates the globals ``selected_data`` and ``residual_data``.
    """
    global selected_data
    global residual_data
    population = dataframes_dict[population_dropdown.value]
    selected_points = np.asarray(points.point_inds, dtype=int)

    if selected_points.size == 0:
        print("No points selected")
        # Forget any earlier selection so "Save gate" cannot store stale points.
        selected_data = population.iloc[[]]
        residual_data = population.iloc[[]]
        return

    # Split by position so duplicate index labels cannot remove extra rows
    # from the residual (as a label-based ``drop`` would).
    in_gate = np.zeros(len(population), dtype=bool)
    in_gate[selected_points] = True
    selected_data = population.iloc[selected_points]
    residual_data = population.iloc[~in_gate]


def clear_warning_label(_=None):
    """Blank the status label and hide the residual-population controls.

    Used as a button callback and as the ``threading.Timer`` target, hence the
    ignored optional argument.
    """
    warning_label.value = ""
    for control in (save_residual_button, residual_name):
        control.layout.display = "none"


def save_gate(_=None):
    """Save the current selection as a new named population and offer to save its residual.

    Button callback for ``extract_button``. The gate name must be non-empty and
    unused, and points must be selected. On success it stores
    ``selected_data`` in ``dataframes_dict`` and selects the new entry in the
    population dropdown. Changing the dropdown value redraws the plot through
    its ``update_plot`` observer. It then shows the residual prompt, which a
    ``threading.Timer`` hides again after 10 seconds.
    """
    global residual_timer
    clear_output(wait=True)
    name = gate_name.value
    if not name:
        warning_label.value = "Warning: No name provided for the gate."
        return
    if name in dataframes_dict:
        warning_label.value = (
            "Warning: Name already present. Please use a different name."
        )
        return
    # Before any selection ``selected_data`` is an empty, column-less frame
    # that update_plot cannot draw; refuse to save empty gates.
    if selected_data.empty:
        warning_label.value = "Warning: No points selected for the gate."
        return

    dataframes_dict[name] = selected_data
    population_dropdown.options = list(dataframes_dict.keys())
    # Selecting the (new, hence different) entry fires the dropdown's
    # update_plot observer; an extra explicit redraw would only repeat the KDE.
    population_dropdown.value = name

    warning_label.value = "Do you want to save the residual population? Enter a name and click the button below."
    residual_name.layout.display = "block"
    save_residual_button.layout.display = "block"

    # Auto-hide the residual prompt if the user ignores it.
    if residual_timer:
        residual_timer.cancel()
    residual_timer = threading.Timer(10.0, clear_warning_label)
    residual_timer.start()


def save_residual_population(_=None):
    """Store the points outside the last gate under the name typed by the user.

    Button callback for ``save_residual_button``. On a validation failure the
    warning stays visible so the user can correct the name. On success the
    auto-hide timer is cancelled, and the prompt and residual controls are
    hidden.
    """
    global residual_timer
    name = residual_name.value
    if not name:
        warning_label.value = "Warning: No name provided for the residual population."
        return
    if name in dataframes_dict:
        warning_label.value = (
            "Warning: Name already present. Please use a different name."
        )
        return

    dataframes_dict[name] = residual_data
    # Reassigning ``options`` can reset the dropdown to its first entry, which
    # would switch the plot away from the population being viewed. Restore the
    # current choice; this is a no-op if it was kept.
    current_population = population_dropdown.value
    population_dropdown.options = list(dataframes_dict.keys())
    population_dropdown.value = current_population

    if residual_timer:
        residual_timer.cancel()
        residual_timer = None
    clear_warning_label()


def on_residual_name_change(change):
    """Cancel the pending auto-hide timer once the user edits the residual name.

    Keeps the residual prompt visible while the user is typing.
    """
    global residual_timer
    if residual_timer:
        residual_timer.cancel()
        residual_timer = None


# Event handlers
fig.data[0].on_selection(selection_fn)
extract_button.on_click(save_gate)

x_dropdown.observe(update_plot, names="value")
y_dropdown.observe(update_plot, names="value")
population_dropdown.observe(update_plot, names="value")

# Only one click handler: save_residual_population hides the prompt itself on
# success. Also registering clear_warning_label would wipe the empty- or
# duplicate-name warnings as soon as they were shown.
save_residual_button.on_click(save_residual_population)
residual_name.observe(on_residual_name_change, names="value")
save_residual_button.layout.display = "none"
residual_name.layout.display = "none"

# display
display(
    widgets.VBox(
        [
            widgets.HBox([x_dropdown, y_dropdown, population_dropdown]),
            fig,
            widgets.HBox([gate_name, extract_button]),
            widgets.HBox([residual_name, save_residual_button]),
            warning_label,
        ]
    )
)

### Export of gated populations

In [None]:
# Prefix for exported files: the uploaded file's name without its extension
# (robust to extensions other than the 4-character ".csv").
base_name = os.path.splitext(upload_widget.value[0].name)[0]

path_widget = widgets.Text(description="Save Path", disabled=False)

select_widget = widgets.SelectMultiple(
    options=list(dataframes_dict.keys()),
    description="Select DataFrames",
    disabled=False,
)

save_button = widgets.Button(description="Save Selected DataFrames")


def save_selected_dataframes(b):
    """Write each population chosen in ``select_widget`` to ``<base_name>_<key>.csv``.

    Files go into the directory typed in ``path_widget``; a blank path means
    paths relative to the current working directory. A non-existent target
    directory is created instead of failing on the first write.

    Args:
        b: the clicked button (unused).
    """
    target_dir = path_widget.value
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    for key in select_widget.value:
        filename = f"{base_name}_{key}.csv"
        filepath = os.path.join(target_dir, filename)
        dataframes_dict[key].to_csv(filepath, index=False)
        print(f"Saved {filename} to {path_widget.value}")


display(path_widget, select_widget, save_button)
save_button.on_click(save_selected_dataframes)