# Dendritic Spine Clusterization

In [None]:
from spine_metrics import SpineMetricDataset
from notebook_widgets import SpineMeshDataset, intersection_ratios_mean_distance
from spine_segmentation import apply_scale
from spine_fitter import SpineGrouping


dataset_path = "0.025 0.025 0.1 dataset"
# NOTE(review): the dataset name suggests 0.025 x 0.025 x 0.1 voxel spacing while the applied
# scale is the identity — presumably the meshes are already in physical units; confirm.
scale = (1, 1, 1)

# load meshes and apply scale
spine_dataset = SpineMeshDataset().load(dataset_path)
spine_dataset.apply_scale(scale)

# load merged and reduced manual classification
manual_classification = SpineGrouping().load(
    f"{dataset_path}/manual_classification/manual_classification_merged_reduced.json")

# load metrics and keep only the spines that appear in the manual classification
spine_metrics = SpineMetricDataset().load(f"{dataset_path}/metrics.csv")
spine_metrics = spine_metrics.get_spines_subset(manual_classification.samples)

# extract metric subsets: scalar ("classic") shape metrics and chord-distribution histograms
classic_metric_names = ['OpenAngle', 'CVD', "JunctionArea", 'AverageDistance', 'Length', 'Area',
                        'Volume', 'ConvexHullVolume', 'ConvexHullRatio',
                        "LengthVolumeRatio", "LengthAreaRatio"]
classic = spine_metrics.get_metrics_subset(classic_metric_names)
classic_prefix = f"{dataset_path}_classic"
chord = spine_metrics.get_metrics_subset(['OldChordDistribution'])
chord_prefix = f"{dataset_path}_chord"


def score_func(clusterizer):
    """Score a fitted clusterizer against the manual classification.

    The score is the mean distance between class-over-cluster distributions, computed by
    ``intersection_ratios_mean_distance`` from ``manual_classification`` and
    ``clusterizer.grouping``.

    NOTE(review): the meaning of the trailing positional ``False`` is not visible in this
    notebook; pass it by keyword once the helper's signature is confirmed.
    """
    return intersection_ratios_mean_distance(manual_classification, clusterizer.grouping, False)

## k-Means Classic Metrics

In [None]:
from notebook_widgets import k_means_clustering_experiment_widget

# Interactive k-means sweep (up to 100 clusters) on the classic scalar metrics.
classic_kmeans_widget = k_means_clustering_experiment_widget(
    classic, spine_metrics, spine_dataset, score_func,
    max_num_of_clusters=100,
    classification=manual_classification,
    filename_prefix=classic_prefix,
)
display(classic_kmeans_widget)

## k-Means Chord Histograms

In [None]:
from notebook_widgets import k_means_clustering_experiment_widget

# Interactive k-means sweep (up to 100 clusters) on the chord-distribution histograms.
chord_kmeans_widget = k_means_clustering_experiment_widget(
    chord, spine_metrics, spine_dataset, score_func,
    max_num_of_clusters=100,
    classification=manual_classification,
    filename_prefix=chord_prefix,
)
display(chord_kmeans_widget)

## DBSCAN Classic Metrics

In [None]:
from notebook_widgets import dbscan_clustering_experiment_widget

# eps grid for the DBSCAN sweep over the classic metrics (with the widget's PCA option on).
min_eps, max_eps, eps_step = 0.2, 6, 0.1
use_pca = True

classic_dbscan_widget = dbscan_clustering_experiment_widget(
    classic, spine_metrics, spine_dataset, score_func,
    min_eps=min_eps, max_eps=max_eps, eps_step=eps_step, use_pca=use_pca,
    classification=manual_classification,
    filename_prefix=classic_prefix,
)
display(classic_dbscan_widget)

## DBSCAN Chord Histograms (Euclidean Distance)

In [None]:
from notebook_widgets import dbscan_clustering_experiment_widget

# eps grid for the DBSCAN sweep over chord histograms with the default (Euclidean) metric.
min_eps, max_eps, eps_step = 0.1, 3, 0.1
use_pca = True

chord_euclidean_dbscan_widget = dbscan_clustering_experiment_widget(
    chord, spine_metrics, spine_dataset, score_func,
    min_eps=min_eps, max_eps=max_eps, eps_step=eps_step, use_pca=use_pca,
    classification=manual_classification,
    filename_prefix=f"{chord_prefix}_euclidean",
)
display(chord_euclidean_dbscan_widget)

## DBSCAN Chord Histograms (Jensen–Shannon Distance)

In [None]:
from notebook_widgets import dbscan_clustering_experiment_widget
from scipy.spatial.distance import jensenshannon
import numpy as np

# The Jensen–Shannon based metric used below is bounded, hence a narrow, fine eps grid.
# PCA is off — presumably so the metric sees the raw histograms (confirm).
min_eps, max_eps, eps_step = 0.1, 1, 0.01
use_pca = False

def js_distance(x, y) -> float:
    """Jensen–Shannon distance between two histograms.

    ``scipy.spatial.distance.jensenshannon`` already returns the Jensen–Shannon
    *distance* (the square root of the JS divergence), so no extra square root is
    applied. Wrapping it in ``np.sqrt`` would give the fourth root of the divergence.

    Parameters
    ----------
    x, y : array_like
        1-D histograms of equal length (presumably chord-distribution rows passed by
        DBSCAN). scipy normalizes each to sum to 1.

    Returns
    -------
    float
        Distance in ``[0, sqrt(ln 2)]`` (scipy's default natural-log base).
    """
    return jensenshannon(x, y)

# DBSCAN sweep over chord histograms using the custom Jensen–Shannon based metric.
chord_js_dbscan_widget = dbscan_clustering_experiment_widget(
    chord, spine_metrics, spine_dataset, score_func,
    metric=js_distance,
    min_eps=min_eps, max_eps=max_eps, eps_step=eps_step, use_pca=use_pca,
    classification=manual_classification,
    filename_prefix=f"{chord_prefix}_jensenshannon",
)
display(chord_js_dbscan_widget)

## View clusterization

In [None]:
from ipywidgets import widgets
from notebook_widgets import inspect_grouping_widget


# Saved clusterization to inspect (presumably written by the k-means chord experiment above).
# The file name is built from `chord_prefix` so it follows `dataset_path` instead of
# repeating the dataset name by hand.
clusterization_path = f"output/clusterization/{chord_prefix}_kmeans_num_of_clusters=6_6_clusters.json"

clusterization = SpineGrouping().load(clusterization_path)

# Show the same grouping in three feature spaces side by side.
print("Combined Metrics / Classic Metrics / Chord Histograms")
display(widgets.HBox([clusterization.show(spine_metrics), clusterization.show(classic), clusterization.show(chord)]))

# NOTE(review): `spine_metrics` is passed for two consecutive arguments — confirm that
# inspect_grouping_widget is meant to receive the same dataset for both.
display(inspect_grouping_widget(clusterization, spine_dataset, spine_metrics, spine_metrics, manual_classification))