# Tutorial for AFMpy.REC.DSC

## Imports

In [None]:
# Standard library imports
import json
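import logging
import logging.config  # Required explicitly; `import logging` alone does not load the config submodule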

# Third party imports
import numpy as np
import matplotlib.pyplot as plt

# AFMpy imports
from AFMpy import Stack, DL, Plotting, REC

## Configure Logging

Each AFMpy module emits log messages for debugging through Python's standard logging library. Logging should always be configured at the application level. These tutorials include example logging configuration files, which can be loaded as shown below. You may adjust these configuration files as you see fit.

In [None]:
# Load the preconfigured logging settings
with open('logs/DSC_Tutorial_LoggingConfig.json', 'r') as f:
    LOGGING_CONFIG = json.load(f)

# Set up the logging configuration
logging.config.dictConfig(LOGGING_CONFIG)
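
For readers who want to write their own configuration instead of loading the bundled JSON file, here is a minimal sketch of a dictionary in the schema that `logging.config.dictConfig` accepts. The logger name `'AFMpy'`, the format string, and the `INFO` level are assumptions chosen for illustration; they are not taken from the shipped configuration file.

```python
import logging
import logging.config

# Hypothetical configuration; the logger name 'AFMpy' is an assumption.
EXAMPLE_LOGGING_CONFIG = {
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'simple': {'format': '%(asctime)s %(name)s %(levelname)s: %(message)s'},
    },
    'handlers': {
        # Send formatted records to the console (stderr by default).
        'console': {
            'class': 'logging.StreamHandler',
            'formatter': 'simple',
            'level': 'INFO',
        },
    },
    'loggers': {
        'AFMpy': {'handlers': ['console'], 'level': 'INFO', 'propagate': False},
    },
}

# Apply the configuration and emit a test message.
logging.config.dictConfig(EXAMPLE_LOGGING_CONFIG)
logging.getLogger('AFMpy').info('Logging configured.')
```

The same dictionary can be saved as JSON and loaded with `json.load`, which is the pattern used in the cell above.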

## Matplotlib Config

The ```Plotting``` module includes functions for creating high-quality figures. Running the following function activates a default configuration that matches the figures in the publication.

In [None]:
Plotting.configure_formatting()

## Check if the GPU is being used for TensorFlow

This tutorial uses TensorFlow and Keras for its deep learning algorithms. Performance is substantially degraded without a GPU, especially for large image stacks, so using a GPU is highly recommended if one is available. The helper function ```DL.is_gpu_available``` checks whether TensorFlow has GPU access.

In [None]:
# Check to see if the GPU is available
if DL.is_gpu_available():
    print('GPU is accessible by tensorflow.')
else:
    print('GPU is NOT accessible by tensorflow. If you want to use GPU, please check your AFMpy version and tensorflow installation.')
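
If the check above reports no GPU, it can help to ask TensorFlow directly which devices it sees. The sketch below uses TensorFlow's own `tf.config.list_physical_devices`. The helper name `list_gpus` is hypothetical, and whether `DL.is_gpu_available` performs the same query internally is an assumption.

```python
def list_gpus():
    """Return the GPUs TensorFlow can see, or an empty list if TensorFlow is not importable."""
    try:
        import tensorflow as tf
    except ImportError:
        # TensorFlow is not installed in this environment.
        return []
    return tf.config.list_physical_devices('GPU')

gpus = list_gpus()
print(f'{len(gpus)} GPU(s) visible to TensorFlow.')
```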

## Load the Stacks

Here we load the compressed pickle file of our stack. A comprehensive explanation of the loading functions is available in the ```LAFM``` tutorial.

In [None]:
# Set the filepath for the public key to verify the integrity of the stacks.
PUBLIC_KEY_FILEPATH = '../common/keys/Tutorial_Public.pub'

# Load the cytoplasmic and periplasmic stacks
cytoplasmic_stack = Stack.Stack.load_compressed_pickle(pickle_filepath = '../common/stacks/Example_AC-20-4.xz',
                                                       public_key_filepath = PUBLIC_KEY_FILEPATH)

periplasmic_stack = Stack.Stack.load_compressed_pickle(pickle_filepath = '../common/stacks/Example_AP-20-4.xz',
                                                       public_key_filepath = PUBLIC_KEY_FILEPATH)

Here we display the metadata of each stack.

In [None]:
print('Cytoplasmic Stack Metadata:')
cytoplasmic_stack.display_metadata()
print('Periplasmic Stack Metadata:')
periplasmic_stack.display_metadata()

## Prepare the Convolutional Autoencoders

Here we prepare the convolutional autoencoder (CAE) for deep spectral clustering.

Distributed in ```AFMpy.Models``` is the abstract base class (ABC) ```ConvolutionalAutoencoder```. This ABC is a template for custom convolutional autoencoders to inherit from. Subclasses must provide the following attributes and abstract methods to be considered valid:

- Attributes:
    - ```Encoder``` (```keras.models.Model```): The encoder model that reduces an input image to a latent feature vector.
    - ```Decoder``` (```keras.models.Model```): The decoder model that reconstructs an output image from a latent feature vector.
    - ```Autoencoder``` (```keras.models.Model```): The combined model which concatenates ```Decoder``` onto ```Encoder```.
- Abstract Methods:
    - ```_build_models```: Method that sets the above model attributes.
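
The contract listed above can be illustrated with a plain Python ABC. This is a sketch only: `AutoencoderTemplate` and `LinearToyAutoencoder` are hypothetical names, and the three "models" are simple NumPy callables standing in for ```keras.models.Model``` objects. It is not AFMpy's implementation.

```python
from abc import ABC, abstractmethod

import numpy as np

class AutoencoderTemplate(ABC):
    """Hypothetical stand-in that mirrors the attribute and method contract."""

    def __init__(self, input_shape, latent_dim=4):
        self.input_shape = input_shape
        self.latent_dim = latent_dim
        self.Encoder = None
        self.Decoder = None
        self.Autoencoder = None
        self._build_models()

    @abstractmethod
    def _build_models(self):
        """Set the Encoder, Decoder, and Autoencoder attributes."""

class LinearToyAutoencoder(AutoencoderTemplate):
    """Toy subclass: a fixed random linear projection instead of a trained network."""

    def _build_models(self):
        n_pixels = int(np.prod(self.input_shape))
        weights = np.random.default_rng(0).normal(size=(n_pixels, self.latent_dim))
        # Encoder: image -> latent feature vector.
        self.Encoder = lambda image: image.reshape(-1) @ weights
        # Decoder: latent feature vector -> image-shaped output.
        self.Decoder = lambda latent: (latent @ weights.T).reshape(self.input_shape)
        # Autoencoder: the decoder applied to the encoder's output.
        self.Autoencoder = lambda image: self.Decoder(self.Encoder(image))

toy = LinearToyAutoencoder(input_shape=(8, 8, 1), latent_dim=4)
image = np.ones((8, 8, 1))
latent = toy.Encoder(image)
reconstruction = toy.Autoencoder(image)
print(latent.shape, reconstruction.shape)  # (4,) (8, 8, 1)
```

Because `_build_models` is abstract, instantiating `AutoencoderTemplate` directly raises a `TypeError`, which is how an ABC enforces that subclasses define their own models.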

Included in AFMpy is an example ```ConvolutionalAutoencoder``` called ```DefaultCAE``` used for the analysis in our publication. To initialize an instance of the CAE, we pass the shape of an individual image (expanded to include a channel axis) to ```DL.DefaultCAE```. Hyperparameters such as the shape of convolutional filters (```filter_shape```), number of convolutional filters (```num_filters```), and the size of the latent feature vectors (```latent_dim```) can also be adjusted.

Here we create a separate CAE for the cytoplasmic and periplasmic stack.

In [None]:
# Determine the input shape for the CAE. It should be (width,height,channels). In our case (64,64,1)
cyto_input_shape = (*cytoplasmic_stack.images.shape[1:], 1)
peri_input_shape = (*periplasmic_stack.images.shape[1:], 1)

# Create the Convolutional Autoencoder models to train with our data.
cytoplasmic_CAE = DL.DefaultCAE(input_shape = cyto_input_shape)
periplasmic_CAE = DL.DefaultCAE(input_shape = peri_input_shape)
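
The shape arithmetic in the cell above can be checked on a synthetic array. This sketch assumes, as the cell implies, that a stack's images are stored as an array of shape `(n_images, height, width)`. The array below is made up for illustration, and whether AFMpy adds the channel axis to the images internally during training is not shown here.

```python
import numpy as np

# Synthetic stand-in for Stack.images: 5 images of 64 x 64 pixels.
images = np.zeros((5, 64, 64))

# Drop the leading image axis and append a single channel axis.
input_shape = (*images.shape[1:], 1)
print(input_shape)  # (64, 64, 1)

# The equivalent operation on the image data itself.
images_with_channel = np.expand_dims(images, axis=-1)
print(images_with_channel.shape)  # (5, 64, 64, 1)
```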

## Apply Deep Spectral Clustering

### Cytoplasmic

#### Train the model and apply spectral clustering

With the CAE now prepared, we can apply Deep Spectral Clustering (DSC). DSC is effective when the input stack is well aligned and the number of distinct conformations is known. If these conditions are not met, alternative methods such as Hierarchical DSC, REC, or IREC may be more appropriate (see the respective tutorials for details).

The ```REC.DSC``` function applies Deep Spectral Clustering to a given ```Stack``` object using a specified ```ConvolutionalAutoencoder```. The CAE is trained end-to-end on the images in the input stack. Once trained, the latent feature vectors are compared using the locally scaled affinity with a local neighborhood size defined by ```k_neighbors```. The resulting affinity matrix is then clustered using ```sklearn.cluster.SpectralClustering``` with the number of clusters set by ```n_clusters```. A new ```Stack``` object is created for each resulting cluster and returned as a list.
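
To make the affinity and clustering steps concrete, here is a minimal sketch on synthetic latent vectors. It assumes that "locally scaled affinity" means the self-tuning form in which the squared distance between two points is divided by the product of each point's distance to its k-th nearest neighbor. The function name `locally_scaled_affinity`, the value `k_neighbors=7`, and the synthetic data are illustrative assumptions; AFMpy's implementation may differ.

```python
import numpy as np
from sklearn.cluster import SpectralClustering

def locally_scaled_affinity(features, k_neighbors=7):
    """Affinity exp(-d_ij^2 / (sigma_i * sigma_j)), with sigma_i the k-th neighbor distance."""
    # Pairwise Euclidean distances between all feature vectors.
    diff = features[:, None, :] - features[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Column 0 of each sorted row is the point itself (distance 0),
    # so column k is the distance to the k-th nearest neighbor.
    sigma = np.sort(dist, axis=1)[:, k_neighbors]
    return np.exp(-dist ** 2 / (sigma[:, None] * sigma[None, :] + 1e-12))

# Synthetic "latent vectors": two well-separated groups of 30 points each.
rng = np.random.default_rng(0)
latent = np.vstack([rng.normal(0.0, 0.3, size=(30, 8)),
                    rng.normal(5.0, 0.3, size=(30, 8))])

affinity = locally_scaled_affinity(latent, k_neighbors=7)

# Cluster the precomputed affinity matrix.
labels = SpectralClustering(n_clusters=2, affinity='precomputed',
                            random_state=0).fit_predict(affinity)
```

In the tutorial, the latent vectors come from the trained encoder rather than from a random generator, and each label is then used to split the stack into per-cluster ```Stack``` objects.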

In [None]:
# Use DSC to determine the cluster labels for the cytoplasmic stack with n=4 clusters
cyto_clusters = REC.DSC(cytoplasmic_stack, cytoplasmic_CAE, n_clusters = 4)

#### Generate the LAFM images for each cytoplasmic cluster

With the clusters now computed, we can calculate the Mean and LAFM images for each clustered stack.

In [None]:
# Generate the mean and LAFM images for each cluster.
for cluster in cyto_clusters:
    cluster.calc_mean_image()
    cluster.calc_LAFM_image(target_resolution = (96,96), sigma = 2.25)

#### Display the cytoplasmic clustered LAFM images

And now we can display the mean and LAFM images from each cluster.

In [None]:
# Set the color range based upon the maximum LAFM image value across all clusters.
vmin = 0
vmax = np.max([cluster.LAFM_image for cluster in cyto_clusters])

# Create the figure.
fig, ax = plt.subplots(4,2, figsize = (6,12))

# Turn off the tick marks
for axis in ax.ravel():
    axis.set_xticks([])
    axis.set_yticks([])

# Set the column titles and the row labels
ax[0,0].set_title('Mean Image', fontsize = 16)
ax[0,1].set_title('LAFM Image', fontsize = 16)
ax[0,0].set_ylabel('Cluster 1', fontsize = 16)
ax[1,0].set_ylabel('Cluster 2', fontsize = 16)
ax[2,0].set_ylabel('Cluster 3', fontsize = 16)
ax[3,0].set_ylabel('Cluster 4', fontsize = 16)

# Plot the mean and LAFM images for each cluster, one cluster per row.
for row, cluster in enumerate(cyto_clusters):
    cluster.plot_mean_image(ax = ax[row,0], cmap = Plotting.LAFMcmap, vmin = vmin, vmax = vmax)
    Plotting.add_scalebar(10/cytoplasmic_stack.resolution, label = '1nm', ax = ax[row,0])
    cluster.plot_LAFM_image(ax = ax[row,1], cmap = Plotting.LAFMcmap, vmin = vmin, vmax = vmax)
    Plotting.add_scalebar(30/cytoplasmic_stack.resolution, label = '1nm', size_vertical = 3/8, ax = ax[row,1])

# Add the colorbars to the right of the images.
cbar_ax = fig.add_axes([0.91, 0.11, 0.03, 0.77])
Plotting.draw_colorbar_to_ax(vmin, vmax, Plotting.LAFMcmap,
                             label = 'Height (Å)', cbar_ax = cbar_ax)
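
The shared color scale used above can be reproduced with plain Matplotlib. This sketch uses synthetic panels and the built-in `'viridis'` colormap as stand-ins for the cluster images and `Plotting.LAFMcmap`. It illustrates the general pattern and is not the internals of `Plotting.draw_colorbar_to_ax`.

```python
import matplotlib
matplotlib.use('Agg')  # Headless backend so the sketch runs without a display; omit in a notebook.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

# Synthetic panels with different intensity ranges.
rng = np.random.default_rng(0)
panels = [rng.random((16, 16)) * scale for scale in (1.0, 2.0, 3.0)]

# One color range for every panel, set by the global maximum.
vmin = 0
vmax = np.max(panels)

fig, ax = plt.subplots(1, 3, figsize=(9, 3))
for axis, panel in zip(ax, panels):
    axis.imshow(panel, cmap='viridis', vmin=vmin, vmax=vmax)
    axis.set_xticks([])
    axis.set_yticks([])

# A single colorbar drawn in its own axes to the right of the panels.
cbar_ax = fig.add_axes([0.91, 0.11, 0.03, 0.77])
colorbar = fig.colorbar(ScalarMappable(norm=Normalize(vmin, vmax), cmap='viridis'),
                        cax=cbar_ax, label='Height (Å)')
```

Passing the same `vmin` and `vmax` to every panel is what makes heights comparable between clusters; without it, each image would be scaled to its own range.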

### Periplasmic

#### Train the model and apply spectral clustering

Here we repeat the deep spectral clustering algorithm for the periplasmic stack.

In [None]:
# Use DSC to determine the cluster labels for the periplasmic stack with n=2 clusters
peri_clusters = REC.DSC(periplasmic_stack, periplasmic_CAE, n_clusters = 2)

#### Generate the LAFM images for each periplasmic cluster

In [None]:
# Generate the mean and LAFM images for each cluster.
for cluster in peri_clusters:
    cluster.calc_mean_image()
    cluster.calc_LAFM_image(target_resolution = (96,96), sigma = 2.25)

#### Display the periplasmic clustered LAFM images

In [None]:
# Set the color range based on the maximum LAFM image value across all periplasmic clusters
vmin = 0
vmax = np.max([cluster.LAFM_image for cluster in peri_clusters])

# Create the figure
fig, ax = plt.subplots(2, 2, figsize=(6, 6))

# Turn off tick marks
for axis in ax.ravel():
    axis.set_xticks([])
    axis.set_yticks([])

# Set axis labels
ax[0, 0].set_title('Mean Image', fontsize=16)
ax[0, 1].set_title('LAFM Image', fontsize=16)
ax[0, 0].set_ylabel('Cluster 1', fontsize=16)
ax[1, 0].set_ylabel('Cluster 2', fontsize=16)

# Plot the mean and LAFM images for each cluster, one cluster per row.
for row, cluster in enumerate(peri_clusters):
    cluster.plot_mean_image(ax=ax[row, 0], cmap=Plotting.LAFMcmap, vmin=vmin, vmax=vmax)
    Plotting.add_scalebar(10 / periplasmic_stack.resolution, label='1nm', ax=ax[row, 0])
    cluster.plot_LAFM_image(ax=ax[row, 1], cmap=Plotting.LAFMcmap, vmin=vmin, vmax=vmax)
    Plotting.add_scalebar(30 / periplasmic_stack.resolution, label='1nm', size_vertical=3/8, ax=ax[row, 1])

# Add the colorbar
cbar_ax = fig.add_axes([0.91, 0.11, 0.03, 0.77])
Plotting.draw_colorbar_to_ax(vmin, vmax, Plotting.LAFMcmap,
                             label='Height (Å)', cbar_ax=cbar_ax)