# Tutorial for AFMpy.REC.HierarchicalDSC

## Imports

In [None]:
# Standard library imports
import json
import logging

# Third party imports
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt


# AFMpy imports
from AFMpy import Stack, DL, Plotting, REC, SSIM

## Configure Logging

Each module in AFMpy contains logging for debugging purposes via the default Python logging library. Logging for the modules should always be configured at the application level. Included in these tutorials are example logging configuration files that can be loaded with the following code. You may adjust these logging configuration files as you see fit.

In [None]:
# `import logging` on its own does not load the `logging.config` submodule, so it is
# imported explicitly here. Without this, `logging.config.dictConfig` only works if
# some other package happened to import `logging.config` first.
import logging.config

# Load the preconfigured logging settings from the tutorial's JSON file.
with open('logs/HierarchicalDSC_Tutorial_LoggingConfig.json', 'r', encoding = 'utf-8') as f:
    LOGGING_CONFIG = json.load(f)

# Set up the logging configuration
logging.config.dictConfig(LOGGING_CONFIG)

## Matplotlib Config

Included within the ```Plotting``` module are functions for creating high-quality figures. A default configuration that matches the figures in the publication is activated by running the following function.

In [None]:
# Activate the default figure formatting provided by the Plotting module
# (the configuration that matches the figures in the publication).
Plotting.configure_formatting()

## Check if GPU is accessible by Tensorflow

This tutorial uses Tensorflow and Keras in its deep learning algorithms. Performance, especially for large image stacks, is substantially degraded when not using the GPU, so it is highly recommended to use the GPU if available. The helper function ```DL.is_gpu_available``` will check to see if Tensorflow has GPU access.

In [None]:
# Ask the AFMpy helper whether Tensorflow can see a GPU and report the result.
gpu_status_message = (
    'GPU is accessible by tensorflow.'
    if DL.is_gpu_available()
    else 'GPU is NOT accessible by tensorflow. If you want to use GPU, please check your AFMpy version and tensorflow installation.'
)
print(gpu_status_message)

## Load the Stacks

Here we load the compressed pickle file of our stack.

A comprehensive explanation of the loading functions is available in the ```LAFM``` tutorial.

In [None]:
# Public key used by the loader to verify the integrity of the stack files.
PUBLIC_KEY_FILEPATH = '../common/keys/Tutorial_Public.pub'

# Load the cytoplasmic stack first, then the periplasmic stack, each verified
# against the same public key.
cytoplasmic_stack, periplasmic_stack = (
    Stack.Stack.load_compressed_pickle(pickle_filepath = stack_filepath,
                                       public_key_filepath = PUBLIC_KEY_FILEPATH)
    for stack_filepath in ('../common/stacks/Example_AC-20-4.xz',
                           '../common/stacks/Example_AP-20-4.xz')
)

## Prepare the Convolutional Autoencoders

Prepare the convolutional autoencoder (CAE) for deep spectral clustering.

A comprehensive description of the ```ConvolutionalAutoencoder``` object is available in the ```DSC``` tutorial.

In [None]:
# The CAE input shape is every dimension of `images` after the first axis, followed
# by a single channel axis.
# NOTE(review): the original note quoted (64,64,1) as the expected shape - confirm
# against the actual `images.shape` of the tutorial stacks.
cyto_input_shape = tuple(cytoplasmic_stack.images.shape[1:]) + (1,)
peri_input_shape = tuple(periplasmic_stack.images.shape[1:]) + (1,)

# Compile, fit and predict settings for the autoencoders. Adjust as needed.
compile_config = DL.CompileConfig(optimizer = 'adam', loss = DL.Losses.combined_ssim_loss)
fit_config = DL.FitConfig(epochs = 25, batch_size = 32, verbose = 1)
predict_config = DL.PredictConfig(batch_size = 32, verbose = 1)

# Both autoencoders are built with the same three configuration objects; only the
# input shape differs between them.
shared_cae_settings = dict(compile_config = compile_config,
                           fit_config = fit_config,
                           predict_config = predict_config)
cytoplasmic_CAE = DL.DefaultCAE(input_shape = cyto_input_shape, **shared_cae_settings)
periplasmic_CAE = DL.DefaultCAE(input_shape = peri_input_shape, **shared_cae_settings)

The same pretrained weights can be used for Hierarchical DSC, so load them if model training is too computationally expensive.

In [None]:
# Set this flag to True to load the pretrained weights shipped with the tutorial.
use_pretrained_weights = False

# When the flag is set, load each autoencoder's weights file (cytoplasmic first).
if use_pretrained_weights:
    pretrained_weights = (
        (cytoplasmic_CAE, '../common/weights/Cytoplasmic_CAE.weights.h5'),
        (periplasmic_CAE, '../common/weights/Periplasmic_CAE.weights.h5'),
    )
    for autoencoder, weights_filepath in pretrained_weights:
        autoencoder.load_weights(weights_filepath)

## Apply Hierarchical Deep Spectral Clustering

### Cytoplasmic Hierarchical DSC

#### Train the Model and Apply Hierarchical Deep Spectral Clustering

Before applying Hierarchical Deep Spectral Clustering (Hierarchical DSC), we first shuffle the stack of images using the ```Stack.shuffle()``` method. This shuffling helps reduce potential biases arising from sequentially scanned images. The ```shuffle()``` method is called directly on the ```Stack``` object and does not require any arguments.

The ```REC.hierarchical_DSC``` function performs Hierarchical Deep Spectral Clustering on a given ```Stack``` object using a specified ```ConvolutionalAutoencoder```. This method iteratively applies spectral clustering, refining clusters based on their stability as measured by Structural Similarity Index Measure (SSIM) comparisons between LAFM images of clusters.

Parameters:

- ```k_neighbors```: Defines the local neighborhood size used to calculate the locally scaled affinity matrix from latent feature vectors.
- ```max_iterations```: Limits the maximum number of iterations for hierarchical clustering.
- ```lafm_target_resolution```: Sets the resolution for LAFM image generation, defaulting to three times the input resolution if unspecified.
- ```lafm_sigma```: Specifies the Gaussian smoothing parameter for LAFM images.
- ```distinct_cluster_threshold```: SSIM threshold above which clusters from a 2-way split are considered indistinct, halting further splitting.
- ```stability_threshold```: SSIM threshold to determine the stability of clusters across iterations.
- ```min_cluster_size```: Minimum number of images required to form a valid cluster.

The function returns a list of new ```Stack``` objects, each representing a distinct conformational cluster.

After we apply the Hierarchical DSC algorithm, we generate the mean and LAFM images for each cluster.

In [None]:
# Randomise the image order so that the acquisition sequence does not bias training.
cytoplasmic_stack.shuffle()

# LAFM settings shared by the clustering call and the per-cluster images below.
lafm_resolution = (96, 96)
lafm_smoothing_sigma = 2.25

# Run Hierarchical DSC on the cytoplasmic stack with its autoencoder.
cyto_clusters = REC.hierarchical_DSC(
    cytoplasmic_stack,
    cytoplasmic_CAE,
    k_neighbors = 7,
    max_iterations = 5,
    lafm_target_resolution = lafm_resolution,
    lafm_sigma = lafm_smoothing_sigma,
    stability_threshold = 0.85,
    distinct_cluster_threshold = 0.65,
    min_cluster_size = 750,
)

# Compute both the mean image and the LAFM image of every cluster.
for cyto_cluster in cyto_clusters:
    cyto_cluster.calc_mean_image()
    cyto_cluster.calc_LAFM_image(target_resolution = lafm_resolution, sigma = lafm_smoothing_sigma)

#### Load And Process the Cytoplasmic Benchmark Stack

Load the benchmark stack and group its images according to the cluster labels found by applying hierarchical deep spectral clustering to AC-20-4, creating equivalent benchmark clusters. An LAFM image is then generated for each benchmark cluster.

In [None]:
# Load the cytoplasmic benchmark stack, verified with the tutorial public key.
benchmark_cytoplasmic_stack = Stack.Stack.load_compressed_pickle(
    pickle_filepath = '../common/stacks/Example_AC-2-2.xz',
    public_key_filepath = PUBLIC_KEY_FILEPATH,
)

# For every cluster found above, pick the benchmark images at the same indexes,
# wrap them in a new Stack and generate that stack's LAFM image.
benchmark_cyto_clusters = []
for cyto_cluster in cyto_clusters:
    member_indexes = cyto_cluster.indexes
    matched_benchmark = Stack.Stack(images = benchmark_cytoplasmic_stack.images[member_indexes],
                                    resolution = benchmark_cytoplasmic_stack.resolution,
                                    indexes = member_indexes)
    matched_benchmark.calc_LAFM_image(target_resolution = (192,192), sigma = 2.25)
    benchmark_cyto_clusters.append(matched_benchmark)

#### Display the Cytoplasmic Clustered LAFM Images

Plot the mean, LAFM, and benchmark LAFM images from each cluster.

In [None]:
# Set the color range based upon the maximum LAFM image value across all clusters.
vmin = 0
vmax = np.max([cluster.LAFM_image for cluster in cyto_clusters])

# Create the figure: one row per cluster, three columns (mean, LAFM, benchmark).
# `squeeze = False` keeps `ax` two-dimensional even when only one cluster was found;
# without it plt.subplots returns a 1-D array and the `ax[row, col]` indexing below fails.
n_clusters = len(cyto_clusters)
fig, ax = plt.subplots(n_clusters, 3, figsize = (9, 3 * n_clusters), squeeze = False)

# Turn off the tick marks
for axis in ax.ravel():
    axis.set_xticks([])
    axis.set_yticks([])

# Set the column titles on the first row only.
ax[0,0].set_title('Mean Image', fontsize = 16)
ax[0,1].set_title('LAFM Image', fontsize = 16)
ax[0,2].set_title('Benchmark Image', fontsize = 16)

# Plot the mean images, LAFM images, and benchmark images for each cluster.
# NOTE(review): the scalebar lengths (10, 30, 60 divided by the stack resolution) presumably
# track the upsampling of the LAFM and benchmark images relative to the raw images - confirm.
for i, (cluster, benchmark_cluster) in enumerate(zip(cyto_clusters, benchmark_cyto_clusters)):
    # Label the row with the cluster number
    ax[i,0].set_ylabel(f'Cluster {i}', fontsize = 16)

    # Plot the mean image and its scalebar
    cluster.plot_mean_image(ax = ax[i,0], cmap = Plotting.LAFMcmap, vmin = vmin, vmax = vmax)
    Plotting.add_scalebar(10/cytoplasmic_stack.resolution, label = '1nm', ax = ax[i,0])

    # Plot the LAFM image and its scalebar
    cluster.plot_LAFM_image(ax = ax[i,1], cmap = Plotting.LAFMcmap, vmin = vmin, vmax = vmax)
    Plotting.add_scalebar(30/cytoplasmic_stack.resolution, label = '1nm', size_vertical = 3/8, ax = ax[i,1])

    # Plot the benchmark image and its scalebar
    benchmark_cluster.plot_LAFM_image(ax = ax[i,2], cmap = Plotting.LAFMcmap, vmin = vmin, vmax = vmax)
    Plotting.add_scalebar(60/cytoplasmic_stack.resolution, label = '1nm', size_vertical = 6/8, ax = ax[i,2])

# Add a single shared colorbar to the right of the images.
cbar_ax = fig.add_axes([0.91, 0.11, 0.03, 0.77])
Plotting.draw_colorbar_to_ax(vmin, vmax, Plotting.LAFMcmap,
                             label = 'Height (Å)', cbar_ax = cbar_ax)

plt.show()

#### Calculate the SSIM Between the LAFM Images

Finally, we evaluate the increase in image quality via LAFM by calculating the Structural Similarity Index Measure (SSIM) between the mean/LAFM image and the benchmark LAFM. 

In [None]:
# Score each cluster against its benchmark LAFM image. The cluster's mean image and
# LAFM image are each resized to (192,192) before the masked SSIM is computed.
mean_image_scores = []
lafm_image_scores = []
for position, found_cluster in enumerate(cyto_clusters):
    reference_image = benchmark_cyto_clusters[position].LAFM_image
    mean_image_scores.append(SSIM.masked_SSIM(cv2.resize(found_cluster.mean_image, (192,192)), reference_image))
    lafm_image_scores.append(SSIM.masked_SSIM(cv2.resize(found_cluster.LAFM_image, (192,192)), reference_image))

# Tabulate the scores with one row per cluster and show them rounded to two decimals.
ssim_df = pd.DataFrame({'Mean Image': mean_image_scores, 'LAFM Image': lafm_image_scores})
ssim_df.index = [f'Cluster {position}' for position in range(len(cyto_clusters))]
display(ssim_df.round(2))

### Periplasmic Hierarchical DSC

Repeat the hierarchical deep spectral clustering algorithm for the periplasmic stack.

In [None]:
# Randomise the image order so that the acquisition sequence does not bias training.
periplasmic_stack.shuffle()

# LAFM settings shared by the clustering call and the per-cluster images below.
lafm_resolution = (96, 96)
lafm_smoothing_sigma = 2.25

# Run Hierarchical DSC on the periplasmic stack with its autoencoder.
peri_clusters = REC.hierarchical_DSC(
    periplasmic_stack,
    periplasmic_CAE,
    k_neighbors = 7,
    max_iterations = 5,
    lafm_target_resolution = lafm_resolution,
    lafm_sigma = lafm_smoothing_sigma,
    stability_threshold = 0.85,
    distinct_cluster_threshold = 0.65,
    min_cluster_size = 750,
)

# Compute both the mean image and the LAFM image of every cluster.
for peri_cluster in peri_clusters:
    peri_cluster.calc_mean_image()
    peri_cluster.calc_LAFM_image(target_resolution = lafm_resolution, sigma = lafm_smoothing_sigma)

#### Load and Process the Periplasmic Benchmark Stack

In [None]:
# Load the periplasmic benchmark stack, verified with the tutorial public key.
benchmark_periplasmic_stack = Stack.Stack.load_compressed_pickle(
    pickle_filepath = '../common/stacks/Example_AP-2-2.xz',
    public_key_filepath = PUBLIC_KEY_FILEPATH,
)

# For every cluster found above, pick the benchmark images at the same indexes,
# wrap them in a new Stack and generate that stack's LAFM image.
benchmark_peri_clusters = []
for peri_cluster in peri_clusters:
    member_indexes = peri_cluster.indexes
    matched_benchmark = Stack.Stack(images = benchmark_periplasmic_stack.images[member_indexes],
                                    resolution = benchmark_periplasmic_stack.resolution,
                                    indexes = member_indexes)
    matched_benchmark.calc_LAFM_image(target_resolution = (192,192), sigma = 2.25)
    benchmark_peri_clusters.append(matched_benchmark)

#### Display the Periplasmic Clustered LAFM Images

In [None]:
# Set the color range based on the maximum LAFM image value across all periplasmic clusters
vmin = 0
vmax = np.max([cluster.LAFM_image for cluster in peri_clusters])

# Create the figure: one row per cluster, three columns (mean, LAFM, benchmark).
# `squeeze = False` keeps `ax` two-dimensional even when only one cluster was found;
# without it plt.subplots returns a 1-D array and the `ax[row, col]` indexing below fails.
n_clusters = len(peri_clusters)
fig, ax = plt.subplots(n_clusters, 3, figsize = (9, 3 * n_clusters), squeeze = False)

# Turn off tick marks
for axis in ax.ravel():
    axis.set_xticks([])
    axis.set_yticks([])

# Set the column titles on the first row only.
ax[0,0].set_title('Mean Image', fontsize = 16)
ax[0,1].set_title('LAFM Image', fontsize = 16)
ax[0,2].set_title('Benchmark Image', fontsize = 16)

# Plot the mean images, LAFM images, and benchmark images for each cluster.
# NOTE(review): the scalebar lengths (10, 30, 60 divided by the stack resolution) presumably
# track the upsampling of the LAFM and benchmark images relative to the raw images - confirm.
for i, (cluster, benchmark_cluster) in enumerate(zip(peri_clusters, benchmark_peri_clusters)):
    # Label the row with the cluster number
    ax[i,0].set_ylabel(f'Cluster {i}', fontsize = 16)

    # Plot the mean image and its scalebar
    cluster.plot_mean_image(ax = ax[i,0], cmap = Plotting.LAFMcmap, vmin = vmin, vmax = vmax)
    Plotting.add_scalebar(10/periplasmic_stack.resolution, label = '1nm', ax = ax[i,0])

    # Plot the LAFM image and its scalebar
    cluster.plot_LAFM_image(ax = ax[i,1], cmap = Plotting.LAFMcmap, vmin = vmin, vmax = vmax)
    Plotting.add_scalebar(30/periplasmic_stack.resolution, label = '1nm', size_vertical = 3/8, ax = ax[i,1])

    # Plot the benchmark image and its scalebar
    benchmark_cluster.plot_LAFM_image(ax = ax[i,2], cmap = Plotting.LAFMcmap, vmin = vmin, vmax = vmax)
    Plotting.add_scalebar(60/periplasmic_stack.resolution, label = '1nm', size_vertical = 6/8, ax = ax[i,2])

# Add a single shared colorbar to the right of the images.
cbar_ax = fig.add_axes([0.91, 0.11, 0.03, 0.77])
Plotting.draw_colorbar_to_ax(vmin, vmax, Plotting.LAFMcmap,
                             label = 'Height (Å)', cbar_ax = cbar_ax)

plt.show()

#### Calculate the SSIM between the LAFM Images

In [None]:
# Score each periplasmic cluster against its benchmark LAFM image. The cluster's mean
# image and LAFM image are each resized to (192,192), the benchmark LAFM size, first.
#
# The interpolation mode is passed by keyword because the third positional parameter of
# cv2.resize is `dst`, not `interpolation`; a flag passed positionally there never selects
# the interpolation mode. Bilinear interpolation (cv2.resize's default) is used so these
# scores are computed the same way as in the cytoplasmic section. If bicubic resampling is
# wanted, pass `interpolation = cv2.INTER_CUBIC` in both sections.
ssim_df = {
    'Mean Image':[SSIM.masked_SSIM(cv2.resize(cluster.mean_image, (192,192), interpolation = cv2.INTER_LINEAR), benchmark_cluster.LAFM_image) for cluster, benchmark_cluster in zip(peri_clusters, benchmark_peri_clusters)],
    'LAFM Image':[SSIM.masked_SSIM(cv2.resize(cluster.LAFM_image, (192,192), interpolation = cv2.INTER_LINEAR), benchmark_cluster.LAFM_image) for cluster, benchmark_cluster in zip(peri_clusters, benchmark_peri_clusters)]
}

# Tabulate the scores with one row per cluster and show them rounded to two decimals.
ssim_df = pd.DataFrame(ssim_df)
ssim_df.index = [f'Cluster {index}' for index in range(len(peri_clusters))]

display(ssim_df.round(2))