# Analysis of the Choroid Plexus

In [None]:
from tabulate import tabulate
# %matplotlib qt4
%matplotlib inline

%load_ext autoreload
%autoreload 2

from variables import LABEL_TO_NAME, LABEL_TO_MAIN_NAME, PLOT_DIR
from src._types import SavedModelTypes
from variables import MS_TYPE_COLORS, MS_TYPE_MAIN_COLORS, MS_TYPES, MS_TYPE_MAIN_EDGECOLORS, MS_TYPE_EDGECOLORS, MAIN_NAME_TO_LABEL
from statsmodels.formula.api import ols


from tqdm import tqdm
from src.utils.visualisation import add_cloud_plot_to_axs

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pingouin as pg
import torch

import os
import math

plt.rcParams['axes.facecolor'] = 'white'
%matplotlib inline


## Settings
Please provide the setup for running the analysis

In [None]:
# Directory layout (all paths are relative to the notebook location).
checkpoint_dir, data_dir = "../../models/", "../../data/analysis_data/"
results_dir, figures_dir = "../../reports/results/", "../../reports/figures/"
# One folder with original samples per split suffix.
original_folders = [f"original_samples_{split}" for split in ("test", "train", "val")]

# General settings.
device = "cpu"
# Columns of the geometric-feature tables that preprocessing removes.
ignore_geometric_columns = ["idx", *(f"bbox-{axis}" for axis in range(6))]

In [None]:
def preprocess_geometric_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Clean a raw geometric-feature table and attach readable label names.

    Columns listed in ``ignore_geometric_columns`` are removed when present,
    dashes in the remaining column names become underscores, and the
    ``label`` / ``initial_label`` columns are replaced by ``label_name`` /
    ``initial_label_name`` through ``LABEL_TO_MAIN_NAME`` / ``LABEL_TO_NAME``.
    The dataframe passed in is not modified.
    """
    present_ignored = [name for name in ignore_geometric_columns if name in df.columns]
    cleaned = df.drop(columns=present_ignored)
    cleaned.columns = cleaned.columns.str.replace('-', '_')
    cleaned["label_name"] = cleaned["label"].map(LABEL_TO_MAIN_NAME)
    cleaned["initial_label_name"] = cleaned["initial_label"].map(LABEL_TO_NAME)
    return cleaned.drop(columns=["label", "initial_label"])

def prepare_geometric_dfs(df_name: str):
    """
    Read a geometric-feature CSV from ``data_dir`` and preprocess it.

    ``df_name`` is the CSV path relative to ``data_dir``. The loaded table is
    passed through ``preprocess_geometric_df`` and its result is returned.
    """
    csv_path = os.path.join(data_dir, df_name)
    raw_df = pd.read_csv(csv_path)
    return preprocess_geometric_df(raw_df)

In [None]:
# load model independent files
# evaluation metrics table, read unchanged from data_dir
evaluation_metrics_df = pd.read_csv(os.path.join(data_dir, "evaluation_metrics.csv"))
# geometric features of the original data, cleaned by prepare_geometric_dfs
geometrics_features_df = prepare_geometric_dfs("geometric_features.csv")
# preview of the first rows (cell output)
geometrics_features_df.head()

In [None]:
# Model to analyse. Before running this cell, create all required files by running the preparation, evaluation and analysis pipelines, and check that the paths are set accordingly.
model_name = SavedModelTypes.bdae_20_2048_5_cond_encoder_shift_scale_alpha
# Earlier configuration, kept for reference:
# model_params = {"kld": 50, "ch": 64}
# model_checkpoint = "final_cvae_ch64_kld50_base_20250709_161140_best"

# File and folder names derived from the model identifier.
prototypes_filename = "prototypes_{}.pkl".format(model_name.value)
prototypes_3D_filename = "prototypes_3D_{}.npz".format(model_name.value)
cluster_prototypes_filename = "cluster_prototypes.npz"
prototypes_initial_3D_filename = "prototypes_initial_3D_{}.npz".format(model_name.value)
latents_filename = "features_{}.pkl".format(model_name.value)
samples_folder = "samples_{}".format(model_name.value)
samples_folder_full = "samples_{}_full".format(model_name.value)

# Geometric features of the generated samples, one table per samples folder.
sampled_geometrics_features_df = prepare_geometric_dfs(samples_folder + "/geometric_features.csv")
sampled_geometrics_features_full_df = prepare_geometric_dfs(samples_folder_full + "/geometric_features.csv")

# Model Comparison
This section compares the evaluation results of the different models in order to identify the best one.

In [None]:
def make_pretty(styler):
    """
    Apply the table styling used for the evaluation metrics to a pandas ``Styler``.

    Sets a caption and a "Purples" background gradient computed over the whole
    table (``axis=None``) with a fixed value range of 0 to 1.

    Returns:
        The styler that was passed in, so the function can be used with
        ``DataFrame.style.pipe``.
    """
    # The former caption "Weather Conditions" had nothing to do with this table.
    styler.set_caption("Model Evaluation Metrics")
    styler.background_gradient(axis=None, vmin=0, vmax=1, cmap="Purples")
    return styler

# Show the metrics without the checkpoint name column; the make_pretty styling is currently disabled.
evaluation_metrics_df.drop("checkpoint_name", axis=1) #.style.pipe(make_pretty)

# Raw Data

In [None]:
# Voxel-wise mean image per subtype (5 classes) and per main class (2 classes).
# NOTE(review): the volume size (96, 96, 96) and the class counts are hard-coded —
# confirm they match the stored samples.
means_subtypes = np.zeros((5, 96, 96, 96))
means_main = np.zeros((2, 96, 96, 96))
num_samples_subtypes = np.zeros(5)
num_samples_main = np.zeros(2)

for folder in original_folders:
    print("Retrieving data from", folder)
    # Read each array exactly once and close the archive afterwards: indexing an
    # NpzFile loads the whole array again on every access, so looking the labels
    # up inside the loop re-read them for every single sample.
    with np.load(os.path.join(data_dir, folder, "samples.npz")) as raw_data:
        sample_images = raw_data["images"]
        sample_labels = raw_data["labels"]
        sample_initial_labels = raw_data["initial_labels"]

    for image, main_label, subtype_label in tqdm(
        zip(sample_images, sample_labels, sample_initial_labels), total=len(sample_images)
    ):
        # image[0] is the first entry along the leading axis of a sample
        # (presumably the channel axis — TODO confirm).
        means_main[main_label] += image[0]
        means_subtypes[subtype_label] += image[0]
        num_samples_main[main_label] += 1
        num_samples_subtypes[subtype_label] += 1

# Turn the sums into means. A class without any sample yields NaN here (0 / 0).
means_subtypes /= np.expand_dims(num_samples_subtypes, axis=(1, 2, 3))
means_main /= np.expand_dims(num_samples_main, axis=(1, 2, 3))

## Geometrics
Geometric properties have been computed for both the original and the newly generated samples; they are analyzed further below.

In [None]:
from scipy.stats import zscore

# For each group, compute the mean across columns
def get_mean_geometrics_df(base_df, normalize: bool = False, group_by: str = "label_name") -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Compute the per-group mean and standard deviation of every feature column.

    Args:
        base_df: Feature table. Every column except ``label_name`` and
            ``initial_label_name`` is treated as a feature.
        normalize: If True, each feature column is z-scored over the whole
            table before grouping.
        group_by: Name of the label column to group by.

    Returns:
        ``(mean_df, std_df)``: two dataframes indexed by ``group_by`` with one
        column per feature, rounded to 5 decimals. ``base_df`` is not modified.
    """
    label_columns = ["initial_label_name", "label_name"]
    feature_columns = [c for c in base_df.columns if c not in label_columns]
    df = base_df.copy()
    if normalize:
        df[feature_columns] = df[feature_columns].apply(zscore, axis=0)

    # One grouped aggregation replaces the manual per-group loop.
    grouped = df.groupby(group_by)[feature_columns]
    mean_df = grouped.mean().round(5)
    std_df = grouped.std().round(5)
    return mean_df, std_df

# Statistics of the original data, grouped by main class and by subtype.
main_mean_geometrics_df, main_std_geometrics_df = get_mean_geometrics_df(geometrics_features_df)
initial_mean_geometrics_df, initial_std_geometrics_df = get_mean_geometrics_df(geometrics_features_df, group_by="initial_label_name")
main_norm_mean_geometrics_df, _ = get_mean_geometrics_df(geometrics_features_df, normalize=True)
initial_norm_mean_geometrics_df, _ = get_mean_geometrics_df(geometrics_features_df, normalize=True, group_by="initial_label_name")

# Statistics of the generated samples. Fixed: these were previously computed from
# geometrics_features_df (the original data), which made them copies of main_*.
samples_mean_geometrics_df, samples_std_geometrics_df = get_mean_geometrics_df(sampled_geometrics_features_df)
samples_norm_mean_geometrics_df, _ = get_mean_geometrics_df(sampled_geometrics_features_df, normalize=True)
samples_mean_geometrics_df.head()

In [None]:
# Shorter alternative: mean and std of every numeric column per subtype in one table
# only the numeric columns are aggregated, so select them first
numeric_cols = geometrics_features_df.select_dtypes(include='number').columns
# cell output: one mean/std column pair per feature, one row per initial_label_name
geometrics_features_df.groupby('initial_label_name')[numeric_cols].agg(['mean', 'std'])

#### The differences between HC and MS

In [None]:
def create_diff_plot(df):
    """
    Draw a bar chart of the per-feature difference between the first two rows of ``df``.

    The first row is treated as HC and the second as MS, purely by position.
    Bars are green where the difference (first minus second) is positive and
    red otherwise.
    """
    first_minus_second = df.iloc[0] - df.iloc[1]
    directions = []
    for value in first_minus_second.values:
        directions.append('HC > MS' if value > 0 else 'MS > HC')
    diff_df = pd.DataFrame({
        'Feature': first_minus_second.index.astype(str),
        'Mean Difference': first_minus_second.values,
        'Direction': directions,
    })
    direction_colors = {'HC > MS': 'green', 'MS > HC': 'red'}

    plt.figure(figsize=(8, 5))
    sns.barplot(
        data=diff_df,
        x='Feature',
        y='Mean Difference',
        hue='Direction',
        dodge=False,
        palette=direction_colors,
    )
    plt.axhline(0, color='black', linewidth=0.8)
    plt.xticks(rotation=90)
    plt.title("Mean Difference per Feature")
    plt.tight_layout()
    plt.show()
create_diff_plot(main_norm_mean_geometrics_df)

In [None]:
def create_bar_plot(mean_df, columns, std_df=None, label_key: str = "label_name", individually: bool = False, save_to: str = None, plot_name: str = "geometrics_barplot"):
    """
    Plot one bar chart per feature column, comparing the values per group.

    Args:
        mean_df: Table with the values to plot. Its index is reset, so the
            group label may live either in the index or in a ``label_key`` column.
        columns: Columns to plot; ``label_name`` and ``initial_label_name`` are skipped.
        std_df: Optional table shaped like ``mean_df``; when given, its values
            are drawn as error bars on top of the bars.
        label_key: Label column used for the x axis, the hue and the palette.
        individually: If True, every column gets its own figure; otherwise all
            columns share one figure with a two-column subplot grid.
        save_to: File suffix appended to the file name (e.g. ``".png"``).
            Nothing is saved when None. Files are written to ``PLOT_DIR``.
        plot_name: File name prefix.
    """
    mean_sub_df = mean_df.reset_index()
    std_sub_df = std_df.reset_index() if std_df is not None else None

    # Filter the label columns up front so the subplot grid has no holes.
    plot_columns = [c for c in columns if c not in ("label_name", "initial_label_name")]
    if not plot_columns:
        return

    if label_key == "initial_label_name":
        label_order = list(MS_TYPES.keys())
        palette = {label: MS_TYPE_COLORS[MS_TYPES[label]] for label in mean_sub_df['initial_label_name'].unique()}
    else:
        label_order = list(MAIN_NAME_TO_LABEL.keys())
        palette = {label: MS_TYPE_MAIN_COLORS[MAIN_NAME_TO_LABEL[label]] for label in mean_sub_df['label_name'].unique()}

    if not individually:
        cols = 2  # Number of columns in the grid of subplots
        rows = math.ceil(len(plot_columns) / cols)
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 6, rows * 4), constrained_layout=True, squeeze=False)
        axes = axes.flatten()

    for i, column in enumerate(plot_columns):
        if individually:
            fig, ax = plt.subplots(figsize=(12, 8))
        else:
            ax = axes[i]

        ax.set_facecolor('white')
        sns.barplot(data=mean_sub_df, x=label_key, y=column, order=label_order, ax=ax, hue=label_key, palette=palette)

        if std_sub_df is not None:
            # add error bars at the x position of each label
            for index, label in enumerate(label_order):
                mean = mean_sub_df[mean_sub_df[label_key] == label][column].values
                std = std_sub_df[std_sub_df[label_key] == label][column].values
                if len(mean) > 0 and len(std) > 0:
                    ax.errorbar(x=index, y=mean[0], yerr=std[0], color="black")

        ax.axhline(0, color='black', linewidth=0.8)
        ax.set_title(column)
        ax.tick_params(axis='x')

        if individually:
            plt.tight_layout()
            if save_to is not None:
                plt.savefig(os.path.join(PLOT_DIR, plot_name + column + save_to))
            plt.show()

    if not individually:
        # Remove only the unused trailing axes. The previous range(i, len(axes))
        # started at the last used index and deleted the last drawn plot as well.
        for unused_ax in axes[len(plot_columns):]:
            fig.delaxes(unused_ax)
        if save_to is not None:
            plt.savefig(os.path.join(PLOT_DIR, plot_name + save_to))
    plt.show()


def create_box_plot(mean_df, columns, label_key: str = "label_name", individually: bool = False, save_to: str = None, plot_name: str = "geometrics_boxplot"):
    """
    Plot one box plot per feature column, comparing the value distribution per group.

    Args:
        mean_df: Table with the values to plot (despite the name, any table with
            a ``label_key`` column or index works; the index is reset first).
        columns: Columns to plot; ``label_name`` and ``initial_label_name`` are skipped.
        label_key: Label column used for the x axis, the hue and the palette.
        individually: If True, every column gets its own figure; otherwise all
            columns share one figure with a two-column subplot grid.
        save_to: File suffix appended to the file name (e.g. ``".png"``).
            Nothing is saved when None. Files are written to ``PLOT_DIR``.
        plot_name: File name prefix.
    """
    mean_sub_df = mean_df.reset_index()

    # Filter the label columns up front so the subplot grid has no holes.
    plot_columns = [c for c in columns if c not in ("label_name", "initial_label_name")]
    if not plot_columns:
        return

    if label_key == "initial_label_name":
        label_order = list(MS_TYPES.keys())
        palette = {label: MS_TYPE_COLORS[MS_TYPES[label]] for label in mean_sub_df['initial_label_name'].unique()}
    else:
        label_order = list(MAIN_NAME_TO_LABEL.keys())
        palette = {label: MS_TYPE_MAIN_COLORS[MAIN_NAME_TO_LABEL[label]] for label in mean_sub_df['label_name'].unique()}

    if not individually:
        cols = 2  # Number of columns in the grid of subplots
        rows = math.ceil(len(plot_columns) / cols)
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 6, rows * 4), constrained_layout=True, squeeze=False)
        axes = axes.flatten()

    for i, column in enumerate(plot_columns):
        if individually:
            fig, ax = plt.subplots(figsize=(6, 4))
        else:
            ax = axes[i]

        ax.set_facecolor('white')
        sns.boxplot(data=mean_sub_df, x=label_key, y=column, order=label_order, ax=ax, hue=label_key, palette=palette)

        ax.axhline(0, color='black', linewidth=0.8)
        ax.set_title(column)
        ax.tick_params(axis='x')

        if individually:
            plt.tight_layout()
            if save_to is not None:
                plt.savefig(os.path.join(PLOT_DIR, plot_name + "_" + column + save_to))
            plt.show()

    if not individually:
        # Remove only the unused trailing axes. The previous range(i, len(axes))
        # started at the last used index and deleted the last drawn plot as well.
        for unused_ax in axes[len(plot_columns):]:
            fig.delaxes(unused_ax)
        if save_to is not None:
            plt.savefig(os.path.join(PLOT_DIR, plot_name + save_to))

    plt.show()

# Alternative: box plots of the original data in one shared figure
# create_box_plot(geometrics_features_df, geometrics_features_df.columns)
# Active: one box plot figure per feature of the sampled data, grouped by subtype (nothing is saved, save_to is not set)
create_box_plot(sampled_geometrics_features_df, sampled_geometrics_features_df.columns, label_key="initial_label_name", individually=True)
# Alternative: bar plots of the main-class means with std error bars
# create_bar_plot(main_mean_geometrics_df, main_mean_geometrics_df.columns, main_std_geometrics_df)

### Geometric Distribution for Sampled and Original Data

In [None]:
def create_geometric_distributions(df_original, df_sampled, columns, label_key="label_name", save_to: str = None):
    """
    Compare the distribution of each feature between original and sampled data.

    For every feature column one figure is drawn with one subplot per label of
    ``df_original[label_key]``; each subplot overlays the histogram of the
    original data (grey) and of the sampled data (label colour), both
    normalised to probabilities.

    Args:
        df_original: Feature table of the original data.
        df_sampled: Feature table of the sampled data.
        columns: Columns to plot; ``label_name`` and ``initial_label_name`` are skipped.
        label_key: Label column that defines the subplots.
        save_to: File suffix (e.g. ``".png"``). When given, each figure is saved
            as ``<column><save_to>`` in ``PLOT_DIR/distributions``.
    """
    unique_labels = df_original[label_key].unique()
    # Colour by label id, like the bar and box plots do. The previous lookup used
    # the position in unique(), which depends on the row order of the data.
    if label_key == "initial_label_name":
        label_colors = {label: MS_TYPE_COLORS[MS_TYPES[label]] for label in unique_labels}
    else:
        label_colors = {label: MS_TYPE_MAIN_COLORS[MAIN_NAME_TO_LABEL[label]] for label in unique_labels}

    if save_to is not None:
        output_dir = os.path.join(PLOT_DIR, "distributions")
        os.makedirs(output_dir, exist_ok=True)  # savefig does not create missing directories

    for column in columns:
        if column in ("label_name", "initial_label_name"):
            continue
        # One row per label instead of a fixed two rows; squeeze=False keeps
        # `axes` indexable even for a single label.
        fig, axes = plt.subplots(nrows=len(unique_labels), ncols=1, squeeze=False)
        axes = axes.flatten()
        for i, label in enumerate(unique_labels):
            # Filter the data for the current label
            df_original_label = df_original[df_original[label_key] == label]
            df_sampled_label = df_sampled[df_sampled[label_key] == label]
            sns.histplot(df_original_label, alpha=0.2, x=column, element="step", stat="probability", common_norm=False, color="grey", ax=axes[i], label="original")
            sns.histplot(df_sampled_label, alpha=0.2, x=column, color=label_colors[label], element="step", stat="probability", common_norm=False, ax=axes[i], label="sampled")
            axes[i].legend()
            axes[i].set_xlabel("")

        plt.suptitle(column)
        if save_to is not None:
            plt.savefig(os.path.join(output_dir, column + save_to))
        plt.show()
create_geometric_distributions(geometrics_features_df, sampled_geometrics_features_full_df, geometrics_features_df.columns,save_to = ".png")

### Statistical Analysis of geometrics
To assess statistical differences with an ANOVA, we first have to check its assumptions: normally distributed values and homogeneity of variances.

In [None]:
def check_anova_assumptions(df, dv: str, idv: str):
    """
    Print and plot diagnostics for the ANOVA assumptions of ``dv`` grouped by ``idv``.

    Prints pingouin's homoscedasticity test (rounded to 2 decimals) and its
    normality test (method ``'normaltest'``) on the ``dv`` column, then fits the
    OLS model ``dv ~ idv`` and shows a Q-Q plot and a histogram of its residuals.
    """
    print(f"Checking assumptions for {dv}")
    # 1. homogeneity of variances between the groups
    variance_check = pg.homoscedasticity(data=df, dv=dv, group=idv)
    print(variance_check.round(2))
    # 2. normality of the dependent variable
    normality_check = pg.normality(df[dv], method='normaltest')
    print(normality_check)
    # 3. visual check of the model residuals
    residuals = ols(f'{dv} ~ {idv}', data=df).fit().resid
    pg.qqplot(residuals, dist='norm')
    sns.despine()
    plt.show()
    sns.histplot(residuals)
    sns.despine()
    plt.show()
#
# Check the assumptions for every numeric feature of the original data.
# The columns are taken from the dataframe that is actually tested, and the
# label columns are left out: they are not numeric and cannot be a dependent variable.
for column in geometrics_features_df.select_dtypes(include='number').columns:
    check_anova_assumptions(geometrics_features_df, column, "initial_label_name")

In [None]:
# from tabulate import tabulate
def perform_anova(df, idv, columns, welch: bool = False):
    """
    Run a one-way ANOVA of every numeric column in ``columns`` between the groups of ``idv``.

    For each numeric column the pingouin result table and the per-group mean
    and standard deviation are printed; non-numeric columns are reported and
    skipped. With ``welch=True`` pingouin's Welch ANOVA is used instead.
    Nothing is returned.
    """
    for column in columns:
        is_numeric = pd.api.types.is_numeric_dtype(df[column])
        if is_numeric:
            print(f"\nANOVA for column {column}")
            anova_fn = pg.welch_anova if welch else pg.anova
            anova_results = anova_fn(dv=column, between=idv, data=df)
            print(anova_results)
            per_group = df.groupby(idv)[column]
            print("Means", per_group.mean())
            print("std", per_group.std())
            # A pairwise Tukey post-hoc test for significant results was
            # sketched here but is not enabled.
        else:
            print(f"\nskipping column {column}, since it is not numeric")

# ANOVA between the subtypes only: rows labelled "HC" are excluded first.
perform_anova(
    geometrics_features_df[geometrics_features_df["label_name"] != "HC"],
    "initial_label_name",
    [column for column in geometrics_features_df.columns if column not in ignore_geometric_columns],
    welch=False,
)

In [None]:
# ANOVA between the main classes (label_name) on the complete original data
perform_anova(geometrics_features_df, "label_name", [item for item in geometrics_features_df.columns.tolist() if item not in ignore_geometric_columns], welch = False)


In [None]:
# Pairwise two-sided Pearson correlations computed by pingouin on the original feature table (cell output)
pg.pairwise_corr(geometrics_features_df, method='pearson', alternative='two-sided')

In [None]:
from scipy.stats import ks_2samp

def perform_ks_2sample(df, df2, columns, label_key="label_name"):
    """
    Run a two-sample Kolmogorov-Smirnov test per label and column between two tables.

    Both tables are grouped by ``label_key``. For every label of ``df`` and
    every numeric column in ``columns`` the values of ``df`` are compared with
    those of ``df2`` and the scipy result is printed. Non-numeric columns are
    reported and skipped. A label that does not occur in ``df2`` is reported
    and skipped instead of raising ``KeyError``. Nothing is returned.
    """
    grouped_df = df.groupby(label_key)
    grouped_df2 = df2.groupby(label_key)

    for label, group in grouped_df:
        if label not in grouped_df2.groups:
            print(f"\nSkipping label {label}: not present in the second dataframe")
            continue
        group2 = grouped_df2.get_group(label)
        print(f"\nPerforming KS test for label: {label}")

        for column in columns:
            if not pd.api.types.is_numeric_dtype(group[column]):
                print(f"Skipping column {column}, as it is not numeric")
                continue
            ks_result = ks_2samp(group[column].to_numpy(), group2[column].to_numpy())
            print(f"KS test result for {column}:")
            print(ks_result)

# Compare original vs. sampled (full) feature distributions per main class, over the columns of the sampled table
perform_ks_2sample(geometrics_features_df, sampled_geometrics_features_full_df, sampled_geometrics_features_full_df.columns)


## Prototypes
Prototypes have been generated from the samples for the different classes. These prototypes can be analyzed independently.

In [None]:
# Prototype archives stored in the samples folder of the selected model.
# NOTE(review): np.load returns a lazily read archive for .npz files, so every
# ["images"] access below reads the array again — confirm this is acceptable.
prototypes = np.load(os.path.join(data_dir, samples_folder, prototypes_3D_filename))
prototypes_initial = np.load(os.path.join(data_dir, samples_folder, prototypes_initial_3D_filename))

In [None]:
%matplotlib inline
def plot_prototypes(prototype_array, title, colors=MS_TYPE_MAIN_COLORS, edgecolors=MS_TYPE_MAIN_EDGECOLORS, names=LABEL_TO_MAIN_NAME, save_to: str = None, individually: bool = False, overlay_array: np.ndarray = None):
    """
    Draw every prototype as a 3D cloud plot.

    Args:
        prototype_array: Prototypes; entry ``[i][0]`` is plotted with threshold 0.5
            (assumes the layout [prototype][channel] — TODO confirm).
        title: Used as part of the file name when saving.
        colors / edgecolors: Colour and edge colour per prototype, indexed by ``i``.
        names: Name per prototype, indexed by ``i``; only used for file names.
        save_to: File suffix (e.g. ``".png"``). Nothing is saved when None.
            Files are written to ``PLOT_DIR`` and prefixed with the model name.
        individually: If True, each prototype gets its own figure (and file);
            otherwise all prototypes share one figure with one subplot each.
        overlay_array: Optional volumes; entry ``[i]`` is drawn first with
            threshold 0.0 and a low alpha multiplier.
    """
    num_prototypes = prototype_array.shape[0]
    # model_name may be an enum member; its value is what belongs in a file name.
    model_tag = getattr(model_name, "value", model_name)

    if not individually:
        # Shared figure. In individual mode the figures are created inside the
        # loop; creating one here as well left a stray empty figure behind.
        fig = plt.figure(figsize=(num_prototypes * 6, 6))

    for i in range(num_prototypes):
        if individually:
            fig = plt.figure(figsize=(6, 6))
            ax = fig.add_subplot(1, 1, 1, projection="3d")
        else:
            ax = fig.add_subplot(1, num_prototypes, i + 1, projection="3d")

        if overlay_array is not None:
            add_cloud_plot_to_axs(
                overlay_array[i],
                threshold=0.0,
                ax=ax,
                ax_limits=[(0, 100), (0, 100), (0, 100)],
                color=colors[i],
                alpha_mult=0.03,
                edgecolors=edgecolors[i],
            )

        add_cloud_plot_to_axs(
            prototype_array[i][0],
            threshold=0.5,
            ax=ax,
            ax_limits=[(0, 100), (0, 100), (0, 100)],
            color=colors[i],
            edgecolors=edgecolors[i],
        )

        ax.set_facecolor('white')
        if individually:
            if save_to is not None:
                plt.savefig(os.path.join(PLOT_DIR, f"{model_tag}_{names[i]}_{title}{save_to}"))
            plt.show()

    if not individually:
        if save_to is not None:
            plt.savefig(os.path.join(PLOT_DIR, f"{model_tag}_{title}{save_to}"))
        plt.show()

# Main-class prototypes, one figure each, drawn over the mean image of the respective class
plot_prototypes(prototypes["images"], "Prototypes Main Classes", save_to=".png", individually=True, overlay_array=means_main)
# Subtype prototypes with the subtype names/colours, drawn over the subtype mean images
plot_prototypes(prototypes_initial["images"], "Prototypes Subtypes", names=LABEL_TO_NAME, colors=MS_TYPE_COLORS, edgecolors=MS_TYPE_EDGECOLORS, save_to=".png", individually=True, overlay_array=means_subtypes)
plt.show()
# Workaround: if this cell freezes, comment out the scatter and the colour bar in add_cloud_plot_to_axs,
# run the cell, and then re-enable them. It is not yet known why this helps.

### Difference plot between prototypes

In [None]:
from src.analysis.plots import shape_difference_plot

# Difference plot between the subtype prototypes with index 1 and index 0 (first entry of each)
shape_difference_plot(prototypes_initial["images"][1][0],prototypes_initial["images"][0][0])
plt.show()

### Geometrics of Prototypes

In [None]:
from variables import MS_TYPES_TO_MAIN_TYPE
from src.analysis.shape_metrics import compute_metrics_for_sample

def geometrics_for_prototype(proto, label_key="label_name", labels=None):
    """
    Compute the geometric features of every prototype in ``proto["images"]``.

    Each prototype's first entry is passed to ``compute_metrics_for_sample``
    (with ``raw_features=True``). The prototype index is stored as
    ``initial_label``; ``label`` comes from ``labels`` when given and from
    ``MS_TYPES_TO_MAIN_TYPE`` otherwise. The collected rows are returned after
    ``preprocess_geometric_df``. ``label_key`` is currently unused; it only
    served a disabled bar plot of the result.
    """
    def _main_label(index):
        return labels[index] if labels is not None else MS_TYPES_TO_MAIN_TYPE[index]

    rows = []
    for index, prototype in tqdm(enumerate(proto["images"])):
        metrics = compute_metrics_for_sample(prototype[0], raw_features=True)
        metrics["initial_label"] = index
        metrics["label"] = _main_label(index)
        rows.append(metrics)
    return preprocess_geometric_df(pd.DataFrame(rows))
prototype_geometrics_df = geometrics_for_prototype(prototypes_initial)

In [None]:
# Display the geometric features of the subtype prototypes
prototype_geometrics_df

In [None]:
# Keep every row except the first one. NOTE(review): this assumes row 0 is the HC prototype — confirm.
MS_proto_geometrics = prototype_geometrics_df.iloc[1:]
# Mean of each feature over the remaining prototypes (cell output)
MS_proto_geometrics.drop(["label_name", "initial_label_name"], axis=1).mean()

In [None]:
# Standard deviation of each feature over the same prototypes (cell output)
MS_proto_geometrics.drop(["label_name", "initial_label_name"], axis=1).std()

## Clustering
Visualization and playground for clustering of precomputed data samples

In [None]:
from variables import MS_TYPE_DETAILS, MS_MAIN_TYPE_DETAILS
from src.analysis.latents import reduced_clustering
from src.analysis.utils import load_np_latents

def plot_clustering_results(result, label_array, classes, colors, markers=None, save_to: str = None, plot_name: str = "clustering"):
    """
    Scatter-plot a 2D projection, one colour and marker per class.

    Args:
        result: Projection; columns 0 and 1 are used as x and y.
        label_array: Class index per row of ``result``.
        classes: Class names; position ``i`` corresponds to label ``i``.
        colors: Colour per class, indexed like ``classes``.
        markers: Optional marker per class, indexed like ``classes``;
            ``'o'`` is used for every class when None.
        save_to: File suffix (e.g. ``".png"``). Nothing is saved when None.
            The file is written to ``PLOT_DIR``.
        plot_name: File name prefix.
    """
    plt.figure(figsize=(6,4))

    for i, class_name in enumerate(classes):
        idx = (label_array == i)
        # Class 3 is drawn below the other classes.
        zorder = 1 if i == 3 else 2
        plt.scatter(
            result[idx, 0], result[idx, 1],
            color=colors[i],
            # Use the markers that were passed in. The previous code always read
            # MS_TYPE_DETAILS["markers"], ignoring the argument.
            marker=markers[i] if markers is not None else 'o',
            label=class_name,
            alpha=0.5,
            s=40,
            zorder=zorder,
        )
    plt.legend(title="Classes")
    plt.xlabel("UMAP Component 1")
    plt.ylabel("UMAP Component 2")
    plt.title('UMAP Projection of Latent Embeddings')
    if save_to is not None:
        plt.savefig(os.path.join(PLOT_DIR, plot_name + save_to))
    plt.show()

In [None]:
# Load the latent embeddings together with their main-class and subtype labels.
kwargs = {
   "map_location": torch.device('cpu')
}
embeddings, labels, initial_labels = load_np_latents(
    "",
    file_name=latents_filename,
    alternative_paths=[str(os.path.join(data_dir, folder)) for folder in original_folders],
    kwargs_load=kwargs
)

# From here on `kwargs` holds the arguments of the dimensionality reduction.
kwargs = {}
# Parameter values to try per method (n_neighbors for umap, perplexity for tsne).
args = {
    "umap": [10] # [8,10,15],
    # "tsne": [10, 20, 30]
}
for method in ["umap"]: #, "tsne"]:
    for i in args[method]:
        if method == "tsne":
            kwargs["perplexity"] = i
            if kwargs.get("n_neighbors") is not None:
                del kwargs["n_neighbors"]

        if method == "umap":
            kwargs["n_neighbors"] = i
            kwargs["random_state"] = 0

        cluster_result = reduced_clustering(embeddings, kwargs=kwargs, mode=method)

        print("Method: ", method)

        # Both plots used to be saved under the same file name, so the main-class
        # plot was overwritten by the subtype plot. The main-class plot now gets
        # its own name; the subtype plot keeps the existing one.
        plot_clustering_results(cluster_result, labels, MS_MAIN_TYPE_DETAILS["names"], MS_MAIN_TYPE_DETAILS["colors"], MS_MAIN_TYPE_DETAILS["markers"], plot_name=method + str(i) + "_main", save_to=".png")
        plot_clustering_results(cluster_result, initial_labels, MS_TYPE_DETAILS["names"], MS_TYPE_DETAILS["colors"], MS_TYPE_DETAILS["markers"], plot_name=method + str(i), save_to=".png")

In [None]:
from matplotlib.lines import Line2D
from sklearn.cluster import KMeans
# K-means with 8 clusters on the 2D projection from the previous cell
# (cluster_result is the result of the last projection that was computed there).
kmeans = KMeans(n_clusters=8, random_state=0)
cluster_labels = kmeans.fit_predict(cluster_result)
#
# Scatter plot of the projection coloured by cluster id
plt.figure(figsize=(6,4), facecolor="white")
scatter = plt.scatter(cluster_result[:, 0], cluster_result[:, 1], c=cluster_labels, cmap='viridis', s=40, alpha=0.5)
plt.title('UMAP Projection with KMeans Clusters')
plt.xlabel('UMAP Component 1')
plt.ylabel('UMAP Component 2')
# Rebuild the colour of every cluster id from the scatter's colormap and norm for a discrete legend
colors = [scatter.cmap(scatter.norm(i)) for i in range(kmeans.n_clusters)]
legend_handles = [Line2D([0], [0], marker='o', color="whitesmoke",markerfacecolor=color, markersize=10, label=f"{i}")
                  for i, color in enumerate(colors)]
plt.legend(handles=legend_handles, title='Clusters')
plt.savefig(os.path.join(PLOT_DIR, "kmeans_of_umap.png"))
plt.show()

# Mean embedding per cluster: the clusters come from the 2D projection,
# the means are taken over the original embeddings.
cluster_means = []
# NOTE(review): new_embeddings is not used anywhere in the visible code — confirm before removing
new_embeddings = embeddings.copy()
for cluster_id in range(kmeans.n_clusters):
    cluster_data_points = embeddings[cluster_labels == cluster_id]
    cluster_mean = cluster_data_points.mean(axis=0)
    cluster_means.append(cluster_mean)
cluster_means = np.array(cluster_means)
# Append the cluster means as extra rows and project everything again
embeddings_with_means = np.vstack([embeddings, cluster_means])
kwargs_mean = {
    "n_neighbors": 10,
    "random_state": 0
}
cluster_result_means = reduced_clustering(embeddings_with_means, kwargs=kwargs_mean, mode="umap")
#
# Disabled: k-means and plot on the projection that includes the cluster means
# kmeans_means = KMeans(n_clusters=5, random_state=0)
# cluster_labels_means = kmeans_means.fit_predict(cluster_result_means)
# cluster_labels_means[-5:] = 10
# plt.figure(figsize=(6,4))
# scatter = plt.scatter(cluster_result_means[:, 0], cluster_result_means[:, 1], c=cluster_labels_means, label=cluster_labels_means, cmap='rainbow', s=40, alpha=0.5)
# plt.colorbar(scatter, label='Cluster ID')
# plt.title('UMAP Projection with KMeans Clusters')
# plt.xlabel('UMAP Component 1')
# plt.ylabel('UMAP Component 2')
# plt.savefig(os.path.join(PLOT_DIR, "kmeans_with_means.png"))
# # Show the plot
# plt.show()

In [None]:
from matplotlib import cm

cluster_prototypes = np.load(os.path.join(data_dir, samples_folder, cluster_prototypes_filename))
cmap = cm.viridis
# Number of colours; kept at 8 to match the number of k-means clusters used for the cluster plot.
num_clusters = 8
colors = [cmap(i / (num_clusters - 1)) for i in range(num_clusters)]  # Normalized values from 0 to 1
num_prototypes = cluster_prototypes["images"].shape[0]
print(num_prototypes)
# Create a darker color for each cluster (by scaling the RGB values)
dark_colors = [tuple([max(0, x * 0.5) for x in color[:3]]) + (color[3],)  # Darken by reducing brightness
               for color in colors]
# One name per prototype: the previous fixed range(5) raised an IndexError while
# saving as soon as the archive held more than five prototypes.
plot_prototypes(cluster_prototypes["images"], "Cluster Prototypes", names=list(map(str, range(num_prototypes))), colors=colors, edgecolors=dark_colors, save_to=".png", individually=True)

In [None]:
print(cluster_prototypes["images"].shape)
# Geometric features of the cluster prototypes, labelled 0..7; rows 1, 3 and 4 are removed from the displayed table
# NOTE(review): the reason for dropping exactly these rows is not visible here — confirm
geometrics_for_prototype(cluster_prototypes, label_key="initial_label_name", labels = range(8)).drop([1,3,4])

In [None]:
# First entry of the first cluster prototype
c0 = cluster_prototypes["images"][0][0]

In [None]:
def bbox2_3D(img):
    """
    Return the extent of the region of a 3D array that lies above 0.5.

    The extent along an axis is ``max_index - min_index`` of the voxels above
    the threshold, so a region that is one voxel thick has extent 0.

    Args:
        img: 3D array.

    Returns:
        Tuple of three ints, one extent per axis (axis 0, axis 1, axis 2);
        ``(0, 0, 0)`` when no voxel exceeds 0.5. For a non-empty region the
        tuple is also printed, as before.
    """
    segmentation = np.where(img > 0.5)

    # Nothing above the threshold: there is no bounding box to measure.
    if len(segmentation) == 0 or len(segmentation[0]) == 0:
        return 0, 0, 0

    bbox = (
        int(np.max(segmentation[0])) - int(np.min(segmentation[0])),
        int(np.max(segmentation[1])) - int(np.min(segmentation[1])),
        int(np.max(segmentation[2])) - int(np.min(segmentation[2])),
    )
    print(bbox)
    # The result used to be printed only; returning it makes it usable by callers.
    return bbox

In [None]:
# Print the extent of the thresholded region of the first cluster prototype
bbox2_3D(c0)