In [None]:
import codecs
import os

# Optional: Disable warnings
import warnings
from io import StringIO

warnings.filterwarnings("ignore")

import ipywidgets as widgets
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display
from ipywidgets import FileUpload, interact
from sklearn.cluster import KMeans
from sklearn.preprocessing import MinMaxScaler, RobustScaler

%matplotlib widget

## Load your pixel-based data

In [None]:
# Layout hints intended for the widget row.
# NOTE(review): nothing visible in this notebook reads `style` — confirm it is still needed.
style = dict(
    description_width="initial",
    flex_flow="row wrap",
    justify_content="space-between",
)

# Free-text box for the directory that holds the .csv files
image_path = widgets.Text(
    description="Path:",
    placeholder="Enter your working directory path",
    style=dict(description_width="initial"),
    value="",
)


def update_csv_files(change):
    """Refresh the CSV picker from the directory typed into ``image_path``.

    Registered as an observer on ``image_path``'s ``value``. Lists the
    ``.csv`` files of the typed directory and offers them in ``csv_select``.

    Parameters
    ----------
    change
        Change notification supplied by the observer machinery. Not used:
        the current path is read from ``image_path.value`` directly.
    """
    folder = image_path.value

    # The callback runs on every change of the text box, so the value is
    # frequently empty or not (yet) an existing directory. os.listdir would
    # raise on such input, so offer no options until the path is usable.
    if not folder or not os.path.isdir(folder):
        csv_select.options = []
        return

    try:
        entries = os.listdir(folder)
    except OSError:
        # e.g. the directory exists but cannot be read
        csv_select.options = []
        return

    # os.listdir returns entries in arbitrary order; sort for a stable picker
    csv_select.options = sorted(f for f in entries if f.endswith(".csv"))


# Radio list of the .csv files found in the chosen directory; it starts empty
# and is filled by update_csv_files.
csv_select = widgets.RadioButtons(
    description="Image data",
    options=[],
    disabled=False,
)

# Re-scan the directory every time the path text changes
image_path.observe(update_csv_files, names="value")

# Show the path box and the file picker side by side
widgets_box = widgets.HBox([image_path, csv_select])
display(widgets_box)

In [None]:
# Read the selected pixel-data .csv into a pandas DataFrame
PATH = image_path.value
PX_DATA = csv_select.value

# csv_select.value is None while no file is selected; stop with a readable
# message instead of the TypeError that os.path.join raises on None.
if PX_DATA is None:
    raise ValueError(
        "No .csv file selected: enter your working directory and pick a file "
        "in the cell above before running this cell."
    )

file_path = os.path.join(PATH, PX_DATA)
image_df = pd.read_csv(file_path, sep=",")

# Drop every column that contains at least one missing value
image_df = image_df.dropna(axis=1)

## Kmeans clustering of selected channel(s)

### Select a scaling method
Rescale the data values to bring them onto a comparable scale. This facilitates accurate clustering results. \
Two methods are available: Robust scaling and MinMax scaling. \
In the cell below, also enter the desired number of clusters and select the channel(s) to cluster on. \
For more information on data scaling, see [here](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html) and [here](https://scikit-learn.org/stable/modules/preprocessing.html).

In [None]:
# The cluster count is typed as free text; the clustering cell converts it
# to an integer.
number_clusters = widgets.Text(
    description="Number of clusters:",
    placeholder="Desired number of clusters",
    style=dict(description_width="initial"),
    value="",
)

# Multi-select list offering every column of the loaded dataframe
channels_cluster = widgets.SelectMultiple(
    description="Channels for clustering",
    options=image_df.columns,
    rows=10,
    disabled=False,
    style=dict(description_width="initial"),
)

# Choice between the two scalers imported from sklearn.preprocessing
scaler_select = widgets.RadioButtons(
    description="Scaling method",
    options=["RobustScaler", "MinMaxScaler"],
    style=dict(description_width="initial"),
)

# Last expression of the cell, so the row of controls becomes its output
widgets.HBox([number_clusters, channels_cluster, scaler_select])

In [None]:
# Columns chosen in the channel selector
selected_channels = list(channels_cluster.value)
if not selected_channels:
    raise ValueError(
        "Select at least one channel for clustering before running this cell."
    )

# The cluster count is typed as free text, so parse it up front and fail with
# a readable message before any fitting work is done.
try:
    n_clusters = int(number_clusters.value)
except ValueError:
    raise ValueError(
        f"Number of clusters must be a whole number, got {number_clusters.value!r}."
    ) from None

cluster_df = image_df[selected_channels]

# A single selected channel is handed over as an (n_rows, 1) array
if len(selected_channels) == 1:
    cluster_df = cluster_df.values.reshape(-1, 1)

# Scale the data with the method picked in the radio buttons.
# `scaler` is kept as a notebook-level name: the correlation cell reuses it.
scaler = RobustScaler() if scaler_select.value == "RobustScaler" else MinMaxScaler()
cluster_std = scaler.fit_transform(cluster_df)

# Run KMeans with a fixed random_state
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(cluster_std)

# Store each row's cluster id in a new "label" column
image_df["label"] = kmeans.labels_

print("Clustering has been completed!")

In [None]:
# Image dimensions are typed as free text; the plotting cell converts them
# to integers when reshaping the labels.
img_width = widgets.Text(
    description="Width:",
    placeholder="Enter your image width in pixels",
    style=dict(description_width="initial"),
    value="",
)

img_height = widgets.Text(
    description="Height:",
    placeholder="Enter your image height in pixels",
    style=dict(description_width="initial"),
    value="",
)

# Last expression of the cell, so both boxes are rendered side by side
widgets.HBox([img_width, img_height])

#### Plot and save clusters

In [None]:
# Cluster label of every row of the dataframe as a flat integer array
labels = image_df["label"].to_numpy().astype(int)

# Fold the flat labels back into a (height, width) grid, filled row by row.
# NOTE(review): assumes the rows of the .csv are stored in row-major pixel order — confirm.
image_shape = (int(img_height.value), int(img_width.value))
image = np.reshape(labels, image_shape, order="C")

# Remove the padding that a later savefig call would otherwise add
mpl.rcParams["savefig.pad_inches"] = 0

# Show the label image without axes
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis("off")
plt.tight_layout()
plt.show()

#### Uncomment the cell below to save the figure

In [None]:
# plt.savefig(fname=os.path.join(PATH, "kmeans_clusters.png"), dpi=600)

#### Print and save simple statistics

In [None]:
def generate_info(cluster_label, channel, df):
    """Summarise one channel's values inside one cluster as text.

    Parameters
    ----------
    cluster_label : int
        Value of the ``"label"`` column that selects the cluster.
    channel : str
        Name of the column of ``df`` to summarise.
    df : pandas.DataFrame
        Table containing a ``"label"`` column and the ``channel`` column.

    Returns
    -------
    str
        Three newline-separated sentences reporting the mean, minimum and
        maximum of ``channel`` over the rows belonging to the cluster.
    """
    # Select the cluster's rows once and reuse the selection for all statistics
    pixel_values = df.loc[df["label"] == cluster_label, channel]
    mean_val = pixel_values.mean()
    min_val = pixel_values.min()
    max_val = pixel_values.max()

    # Label 0 is reported as background.
    # NOTE(review): KMeans label numbering is presumably arbitrary — confirm
    # that cluster 0 really is the background for your data.
    cluster_name = (
        "background (cluster 0)" if cluster_label == 0 else f"cluster_{cluster_label}"
    )

    return (
        f"The average single pixel value in counts for {channel} in the {cluster_name} is: {mean_val}\n"
        f"The minimum single pixel value in counts for {channel} in the {cluster_name} is: {min_val}\n"
        f"The maximum single pixel value in counts for {channel} in the {cluster_name} is: {max_val}"
    )


info_strings = []

# Take the cluster ids from the dataframe itself instead of the `labels`
# array: `labels` only exists once the plotting cell has run successfully,
# while the "label" column is written by the clustering cell. The range
# covered is the same (0 up to the highest label).
n_labels = int(image_df["label"].max()) + 1

for i in range(n_labels):
    for j in channels_cluster.value:
        info = generate_info(i, j, image_df)
        print(info)
        info_strings.append(info)
    # Blank line between clusters, both on screen and in the collected text
    print()
    info_strings.append("")

#### Uncomment the section below to save the information to a .txt file

In [None]:
# filename = os.path.join(PATH, "kmeans_cluster_statistics.txt")
# with open(filename, "w") as file:
#     for line in info_strings:
#         file.write(line + "\n")

## Correlation heatmap
Select the channels and the correlation method (Spearman or Pearson) for which you'd like to calculate correlation coefficients and draw a correlation heatmap.

In [None]:
# Multi-select list of the columns to correlate with each other
correlation_channels = widgets.SelectMultiple(
    description="Channels",
    options=image_df.columns,
    rows=10,
    style=dict(description_width="initial"),
)

# Correlation method handed to the heatmap cell
correlation_select = widgets.RadioButtons(
    description="Correlation method",
    options=["spearman", "pearson"],
    style=dict(description_width="initial"),
)

# Last expression of the cell, so both controls are rendered side by side
widgets.HBox([correlation_channels, correlation_select])

In [None]:
# Apply the square root transformation to the selected channels.
# np.sqrt is applied to the whole frame at once: DataFrame.applymap is
# deprecated (pandas >= 2.1) and calls the function once per element, which
# is slow on per-pixel tables. The resulting values are the same.
transformed_data = np.sqrt(image_df[list(correlation_channels.value)])

# Scale the transformed data and keep the channel names as column labels.
# `scaler` is the object created in the clustering cell; it is refitted here.
scaled_data = pd.DataFrame(
    scaler.fit_transform(transformed_data), columns=list(correlation_channels.value)
)


def plot_correlation_heatmap(df, channels, method):
    """Draw an annotated correlation heatmap for the chosen channels.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding one column per channel.
    channels : list
        Columns of ``df`` to correlate with each other.
    method : str
        Correlation method forwarded to ``DataFrame.corr``.

    Notes
    -----
    Calls ``sns.set(font_scale=0.5)`` and shows the figure with
    ``plt.show()``. Nothing is returned.
    """
    correlations = df[channels].corr(method=method)

    # Font scale is set before the figure is created
    sns.set(font_scale=0.5)
    _, ax = plt.subplots()

    # Colour scale fixed to [-1, 1] and centred on 0
    palette = sns.diverging_palette(20, 220, n=200)
    heatmap_options = dict(
        vmin=-1,
        vmax=1,
        center=0,
        cmap=palette,
        square=True,
        annot=True,
        ax=ax,
    )
    sns.heatmap(correlations, **heatmap_options)

    # Rotate the x tick labels and right-align them
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment="right")

    plt.tight_layout()
    plt.show()


# Render the heatmap for the current widget selections
plot_correlation_heatmap(
    df=scaled_data,
    channels=list(correlation_channels.value),
    method=correlation_select.value,
)

#### Uncomment the cell below to save the figure

In [None]:
# plt.savefig(fname=os.path.join(PATH, correlation_select.value + "_correlation_heatmap.png"), dpi=600)