# **Using aeon Distances with scikit-learn Clusterers**

This notebook demonstrates how to integrate aeon’s distance metrics with hierarchical, density-based, and spectral clustering methods from scikit-learn. While aeon primarily supports partition-based clustering algorithms, such as $k$-means and $k$-medoids, its robust distance measures can be leveraged to enable other clustering techniques using scikit-learn.

To measure similarity between time series and enable clustering, we use aeon’s precomputed distance matrices. For details about distance metrics, see the [distance examples](../distances/distances.ipynb).

## **Contents**
1. **Example Dataset**: Using the `load_unit_test` dataset from aeon.
2. **Computing Distance Matrices with aeon**: Precomputing distance matrices with aeon’s distance metrics.
3. **Hierarchical Clustering**
4. **Density-Based Clustering**
5. **Spectral Clustering**

## **Example Dataset**

We'll begin by loading a sample dataset. For this demonstration, we'll use the `load_unit_test` dataset from aeon.


In [None]:
# Import & load data
from aeon.datasets import load_unit_test
X, y = load_unit_test(split="train")

print(f"Data shape: {X.shape}")
print(f"Labels shape: {y.shape}")
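
aeon stores a collection of equal-length series as a 3D NumPy array of shape `(n_cases, n_channels, n_timepoints)`, which matters whenever series are averaged or plotted. The sketch below uses a random stand-in array (the name `toy_X` and its dimensions are made up for illustration, not the loaded dataset) to show that averaging over cases keeps the channel axis and that `squeeze` removes it for univariate data.

```python
import numpy as np

# Hypothetical stand-in for a univariate collection:
# 6 cases, 1 channel, 10 time points.
rng = np.random.default_rng(0)
toy_X = rng.normal(size=(6, 1, 10))

# Averaging over cases (axis 0) keeps the channel axis.
mean_series = toy_X.mean(axis=0)
print(mean_series.shape)

# Dropping the length-1 channel axis gives a 1D series that plots as one line.
flat_mean = mean_series.squeeze()
print(flat_mean.shape)
```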


## **Computing Distance Matrices with aeon**

aeon provides a variety of distance measures for time series data. As an example, we compute the pairwise distance matrix using Dynamic Time Warping (DTW).

For a comprehensive overview of all available distance metrics in aeon, see the [aeon distances API reference](https://www.aeon-toolkit.org/en/stable/api_reference/distances.html).


In [None]:
from aeon.distances import pairwise_distance

# Compute the pairwise distance matrix using DTW
distance_matrix = pairwise_distance(X, metric="dtw")

print(f"Distance matrix shape: {distance_matrix.shape}")
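
To make the structure of a precomputed matrix concrete, here is a minimal NumPy sketch of textbook DTW and a pairwise loop. It is not aeon's implementation: the function names `dtw_sketch` and `pairwise_sketch` are hypothetical, the squared pointwise cost and the absence of a warping window are assumptions of this sketch, and the toy data is random.

```python
import numpy as np


def dtw_sketch(a, b):
    """Textbook DTW between two 1D series, using a squared pointwise cost."""
    n, m = len(a), len(b)
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = (a[i - 1] - b[j - 1]) ** 2
            # Extend the cheapest of the three admissible warping steps.
            cost[i, j] = d + min(cost[i - 1, j], cost[i, j - 1], cost[i - 1, j - 1])
    return cost[n, m]


def pairwise_sketch(series):
    """Fill a symmetric matrix with the DTW cost of every pair of series."""
    n = len(series)
    D = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            D[i, j] = D[j, i] = dtw_sketch(series[i], series[j])
    return D


rng = np.random.default_rng(0)
toy_series = rng.normal(size=(5, 12))  # 5 hypothetical series of length 12
toy_D = pairwise_sketch(toy_series)
print(toy_D.shape)
```

The result is square and symmetric with a zero diagonal, which is the form scikit-learn estimators expect when they are given a precomputed distance matrix.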


## **Hierarchical Clustering**


Hierarchical clustering builds a hierarchy of clusters by progressively merging or splitting existing clusters. [AgglomerativeClustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) takes the agglomerative, bottom-up approach: every time series starts in its own cluster, and the two closest clusters are merged repeatedly until the requested number of clusters remains. We use it with the precomputed distance matrix.

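Not all linkage methods can be used with a precomputed distance matrix. The `ward` linkage only works with Euclidean distances computed from the raw features, so it is not available here. The following linkage methods work with aeon distances:
- `single`
- `complete`
- `average`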

In [None]:
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
import numpy as np

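# Perform Agglomerative Clustering
# `metric` replaces the `affinity` argument, which was deprecated in
# scikit-learn 1.2 and later removed
agg_clustering = AgglomerativeClustering(
    n_clusters=2, metric="precomputed", linkage="average"
)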
labels = agg_clustering.fit_predict(distance_matrix)

# Visualize the clustering results
plt.figure(figsize=(10, 6))
for label in np.unique(labels):
    # X is 3D (n_cases, n_channels, n_timepoints); squeeze drops the channel axis
    plt.plot(X[labels == label].mean(axis=0).squeeze(), label=f"Cluster {label}")
plt.title("Hierarchical Clustering with DTW Distance")
plt.legend()
plt.show()

## **Density-Based Clustering**
Density-based clustering identifies clusters based on the density of data points in the feature space. We'll demonstrate this using scikit-learn's `DBSCAN` and `OPTICS` algorithms.

### **DBSCAN**

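[DBSCAN](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html) is a density-based clustering algorithm that groups points that are density-connected: each core point must have at least `min_samples` points (itself included) within distance `eps`. Points that belong to no dense region are labelled as noise (`-1`).
We use the `DBSCAN` algorithm from scikit-learn with a precomputed distance matrix.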


In [None]:
from sklearn.cluster import DBSCAN

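# Perform DBSCAN clustering
# eps is measured in the units of the distance matrix, so tune it to the
# scale of the DTW distances
dbscan = DBSCAN(eps=0.5, min_samples=5, metric="precomputed")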
dbscan_labels = dbscan.fit_predict(distance_matrix)

# Visualize the clustering results
plt.figure(figsize=(10, 6))
for label in np.unique(dbscan_labels):
    # X is 3D (n_cases, n_channels, n_timepoints); squeeze drops the channel axis
    mean_series = X[dbscan_labels == label].mean(axis=0).squeeze()
    if label == -1:
        # Noise points
        plt.plot(mean_series, label="Noise", linestyle="--")
    else:
        plt.plot(mean_series, label=f"Cluster {label}")
plt.title("DBSCAN Clustering with DTW Distance")
plt.legend()
plt.show()
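
Because `eps` is expressed in the units of the distance matrix, a value that suits one metric may label every series as noise under another. One common heuristic is to inspect each point's distance to its k-th nearest neighbour. The sketch below applies it to a made-up 1D example; the doubling of the median is an illustrative rule of thumb chosen for this toy data, not a recommendation.

```python
import numpy as np
from sklearn.cluster import DBSCAN

# Hypothetical data: two tight groups of three points and one outlier.
points = np.array([0.0, 0.1, 0.2, 5.0, 5.1, 5.2, 20.0])
toy_distances = np.abs(points[:, None] - points[None, :])

min_samples = 3
# Each sorted row starts with the zero self-distance, so column
# (min_samples - 1) is the radius needed to reach min_samples points.
k_dist = np.sort(toy_distances, axis=1)[:, min_samples - 1]
print(np.sort(k_dist))

# Illustrative choice only: twice the median k-distance.
eps = 2 * float(np.median(k_dist))

toy_labels = DBSCAN(
    eps=eps, min_samples=min_samples, metric="precomputed"
).fit_predict(toy_distances)
print(toy_labels)
```

Plotting the sorted `k_dist` values for a real distance matrix and looking for the point where they rise sharply gives a starting range for `eps`.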


### **OPTICS**
[OPTICS](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.OPTICS.html) is a density-based clustering algorithm similar to DBSCAN, but it handles clusters of varying density better because it does not rely on a single fixed `eps`. We use the `OPTICS` algorithm from scikit-learn with a precomputed distance matrix.

In [None]:
from sklearn.cluster import OPTICS

# Perform OPTICS clustering
optics = OPTICS(min_samples=5, metric="precomputed")
optics_labels = optics.fit_predict(distance_matrix)

# Visualize the clustering results
plt.figure(figsize=(10, 6))
for label in np.unique(optics_labels):
    # X is 3D (n_cases, n_channels, n_timepoints); squeeze drops the channel axis
    mean_series = X[optics_labels == label].mean(axis=0).squeeze()
    if label == -1:
        # Noise points
        plt.plot(mean_series, label="Noise", linestyle="--")
    else:
        plt.plot(mean_series, label=f"Cluster {label}")
plt.title("OPTICS Clustering with DTW Distance")
plt.legend()
plt.show()
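
OPTICS computes an ordering of the points and a reachability distance for each one, from which DBSCAN-like labels can be extracted at several thresholds without refitting. The following sketch illustrates this on made-up data; `points`, `toy_distances`, and the two `eps` values are assumptions chosen for the example.

```python
import numpy as np
from sklearn.cluster import OPTICS, cluster_optics_dbscan

# Hypothetical data: two tight groups of three points and one outlier.
points = np.array([0.0, 0.1, 0.2, 5.0, 5.1, 5.2, 20.0])
toy_distances = np.abs(points[:, None] - points[None, :])

# Fit once, extracting DBSCAN-style labels at a fine threshold.
optics_toy = OPTICS(
    min_samples=3, metric="precomputed", cluster_method="dbscan", eps=0.5
)
optics_toy.fit(toy_distances)
fine_labels = optics_toy.labels_
print(fine_labels)

# Reachability distances in the order the points were visited.
print(optics_toy.reachability_[optics_toy.ordering_])

# Reuse the fitted reachability information at a coarser threshold.
coarse_labels = cluster_optics_dbscan(
    reachability=optics_toy.reachability_,
    core_distances=optics_toy.core_distances_,
    ordering=optics_toy.ordering_,
    eps=6.0,
)
print(coarse_labels)
```

With the fine threshold the two groups stay separate, while the coarse threshold merges them; the outlier is noise in both cases.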


## **Spectral Clustering**
[SpectralClustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.SpectralClustering.html) embeds the data in a low-dimensional space derived from the eigenvectors of a graph Laplacian, then clusters the points in that space. With `affinity="precomputed"` it expects a similarity (affinity) matrix, where larger values mean more alike, so we convert our distance matrix accordingly.

In [None]:
from sklearn.cluster import SpectralClustering
import numpy as np
import matplotlib.pyplot as plt

# SpectralClustering expects similarities (larger = more alike), not distances.
# Scale the distances to [0, 1] and flip them: the diagonal becomes 1 and the
# most distant pair gets similarity 0.
inverse_distance_matrix = 1 - (distance_matrix / distance_matrix.max())

# Perform Spectral Clustering with affinity="precomputed"
spectral = SpectralClustering(
    n_clusters=2, affinity="precomputed", random_state=42
)
spectral_labels = spectral.fit_predict(inverse_distance_matrix)

# Visualize the clustering results
plt.figure(figsize=(10, 6))
for label in np.unique(spectral_labels):
    # X is 3D (n_cases, n_channels, n_timepoints); squeeze drops the channel axis
    plt.plot(X[spectral_labels == label].mean(axis=0).squeeze(), label=f"Cluster {label}")
plt.title("Spectral Clustering with Normalized Similarity Matrix")
plt.legend()
plt.show()
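
The linear rescaling used above is one way to turn distances into similarities. Another common choice is a Gaussian (RBF) kernel, which decays smoothly with distance. The sketch below applies it to made-up data; the bandwidth `sigma`, set here to the median off-diagonal distance, is a hypothetical choice that usually needs tuning.

```python
import numpy as np
from sklearn.cluster import SpectralClustering

# Hypothetical data: two groups of four points on a line.
points = np.array([0.0, 0.1, 0.2, 0.3, 5.0, 5.1, 5.2, 5.3])
toy_distances = np.abs(points[:, None] - points[None, :])

# Illustrative bandwidth: the median of the off-diagonal distances.
off_diagonal = toy_distances[~np.eye(len(points), dtype=bool)]
sigma = np.median(off_diagonal)

# Gaussian kernel: similarity 1 at distance 0, decaying towards 0.
gaussian_similarity = np.exp(-(toy_distances**2) / (2 * sigma**2))

toy_spectral = SpectralClustering(
    n_clusters=2, affinity="precomputed", random_state=0
)
toy_labels = toy_spectral.fit_predict(gaussian_similarity)
print(toy_labels)
```

Unlike the linear rescaling, the kernel never produces an exact zero, so every pair of series stays connected in the similarity graph.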
