In [2]:
import numpy as np

from analysis_tools.clustering import compute_partial_scores_matrix_fast
from analysis_tools.clustering import compute_clusters
from sklearn import cluster
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, fcluster,linkage
from scipy.cluster.hierarchy import set_link_color_palette


import matplotlib.pyplot as plt
from sklearn.cluster import OPTICS
from sklearn.manifold import MDS

from pyminimax import minimax

COMPUTATION OF SCORES MATRIX

Compute scores matrix (with Spectral descriptors OR Zernike descriptors)

In [None]:
# Example score and entry files ship in /pdb_shape_retrieval/example_data.
#   scores_file  : three columns - entry ID #1, entry ID #2, score for that pair
#   entries_file : list of entry IDs
scores_file = './example_data/clustering_data/complexes/SARS_Spike_protein/zernike_scores.txt'
entries_file = './example_data/clustering_data/complexes/SARS_Spike_protein/list_entries.txt'

# Build the scores matrix (and its entry labels), then rescale by the largest score.
scores_matrix, labels = compute_partial_scores_matrix_fast(scores_file, entries_file)
scores_matrix_single = scores_matrix / scores_matrix.max()

# Heatmap of the rescaled matrix.
fig, ax1 = plt.subplots(1, figsize=(4, 4))
sns.heatmap(scores_matrix_single, ax=ax1, vmin=0.0, cmap='seismic')


Define a combined score from the Zernike and Spectral similarity scores to compute the scores matrix (optional)

In [None]:
# Example score and entry files ship in /pdb_shape_retrieval/example_data.
scores_spectral_file = './example_data/clustering_data/complexes/SARS_Spike_protein/spectral_scores.txt'
scores_zernike_file = './example_data/clustering_data/complexes/SARS_Spike_protein/zernike_scores.txt'
entries_file = './example_data/clustering_data/complexes/SARS_Spike_protein/list_entries.txt'

# One scores matrix per descriptor; both calls read the same entries file,
# and `labels` keeps the value returned by the second call.
scores_spectral_matrix, labels = compute_partial_scores_matrix_fast(scores_spectral_file, entries_file)
scores_zernike_matrix, labels = compute_partial_scores_matrix_fast(scores_zernike_file, entries_file)

# Rescale each matrix by its own maximum so the two are comparable.
scores_zernike_matrix = scores_zernike_matrix / scores_zernike_matrix.max()
scores_spectral_matrix = scores_spectral_matrix / scores_spectral_matrix.max()

# Combined score: square root of the product of the two rescaled matrices.
scores_matrix_combined = (scores_spectral_matrix * scores_zernike_matrix) ** 0.5

# Heatmap of the combined matrix, with entry IDs on both axes.
fig, ax1 = plt.subplots(1, figsize=(4, 4))
sns.heatmap(
    scores_matrix_combined,
    ax=ax1,
    vmin=0.0,
    cmap='seismic',
    xticklabels=labels,
    yticklabels=labels,
)


CLUSTERING ALGORITHMS 

Define scores matrix to use for clustering

In [None]:
# Choose the matrix that the clustering cells below read as `scores_matrix`:
#   1. scores_matrix = scores_matrix_single    # single-descriptor matrix (Zernike or Spectral)
#   2. scores_matrix = scores_matrix_combined  # combined Zernike + Spectral matrix
# Option 2 requires the optional combined-score cell to have been run first,
# since that cell is where scores_matrix_combined is defined.

scores_matrix = scores_matrix_combined

Hierarchical clustering

In [None]:
# Cut the tree either at a distance threshold or into a fixed number of clusters.
threshold = None
n_clusters = 2

# Linkage criterion: 'ward' or 'average'.
linkage_method = "ward"

# Run the clustering. `cluster` is the sklearn.cluster module imported at the top;
# n_clusters is overwritten with the value handed back by compute_clusters.
clusters, n_clusters, link_matrix, threshold_dist = compute_clusters(
    scores_matrix, labels, cluster, linkage_method, threshold, n_clusters
)

print('optimal number of clusters', n_clusters)
print('threshold distance', threshold_dist)

# Colour palette used for the dendrogram links.
set_link_color_palette(['red', 'blue', 'green', 'orange', 'purple', 'black'])

# Dendrogram, coloured below the cut height; the dashed line marks the cut.
dendrogram(
    link_matrix,
    truncate_mode="level",
    p=50,
    color_threshold=threshold_dist,
    labels=labels,
    above_threshold_color='lightgrey',
)
plt.axhline(threshold_dist, color='k', linestyle='--')
plt.xticks(rotation=90)  # vertical tick labels so entry IDs stay readable
plt.tight_layout()
plt.show()

Minimax clustering

In [None]:
from scipy.spatial.distance import pdist, squareform

# NOTE(review): pdist() measures distances between the ROWS of scores_matrix,
# i.e. it treats each row as a feature vector. If scores_matrix is already a
# symmetric, zero-diagonal distance matrix, squareform(scores_matrix) (imported
# but unused here) may be what was intended - confirm before relying on this.
Z = minimax(pdist(scores_matrix))

# Cut the minimax tree into at most `max_clusters` flat clusters.
max_clusters = 6
clusters = fcluster(Z, t=max_clusters, criterion='maxclust')

# List the entry labels assigned to each cluster, in ascending cluster id.
for cluster_id in np.unique(clusters):
    members = [name for name, assigned in zip(labels, clusters) if assigned == cluster_id]
    print(f"Cluster {cluster_id}: {members}")

# Full dendrogram of the minimax linkage.
fig = plt.figure(figsize=(10, 4))
dendrogram(Z, labels=labels)
plt.xticks(rotation=90)  # vertical tick labels so entry IDs stay readable
plt.tight_layout()
plt.show()

OPTICS clustering

In [None]:

from sklearn.preprocessing import MinMaxScaler

# NOTE(review): the scaler is instantiated but never applied - D_scaled is the
# selected scores matrix itself, unscaled.
scaler = MinMaxScaler()
D_scaled = scores_matrix

# NOTE(review): no metric= argument is given, so OPTICS uses its default metric
# on the rows of the matrix, whereas the MDS embedding in this cell passes
# dissimilarity='precomputed'. Confirm which interpretation is intended.
optics = OPTICS(min_samples=2, xi=0.001).fit(D_scaled)

# Cluster assignment per entry; -1 is handled as noise by the plotting code.
cluster_labels = optics.labels_
print("Cluster labels length:", len(cluster_labels))  # original note: should be 29 (dataset-specific)
print("Cluster labels:", cluster_labels)

# Project the scores matrix, taken as precomputed dissimilarities, onto 2-D
# coordinates for plotting. The fixed random_state keeps the layout repeatable.
mds = MDS(
    n_components=2,
    dissimilarity='precomputed',
    random_state=42,
)
coords = mds.fit_transform(scores_matrix)

# Scatter plot of the 2-D MDS coordinates, coloured by OPTICS cluster.
plt.figure(figsize=(5, 5), facecolor='white')

unique_clusters = np.unique(cluster_labels)

# Only real clusters take a colour from the palette; noise (-1) is drawn in grey.
# Size the palette by the number of non-noise labels rather than
# len(unique_clusters) - 1: the latter assumes a noise label is always present
# and runs out of colours (IndexError) when OPTICS assigns every point to a cluster.
n_real_clusters = int(np.count_nonzero(unique_clusters != -1))
cmap = plt.cm.Set2
colors = cmap(np.linspace(0, 1, max(n_real_clusters, 1)))

color_idx = 0
for cluster_id in unique_clusters:
    mask = cluster_labels == cluster_id

    if cluster_id == -1:
        # Noise points: small, grey, semi-transparent.
        plt.scatter(
            coords[mask, 0], coords[mask, 1],
            c='0.7',
            s=20,
            alpha=0.6,
            label='Noise'
        )
    else:
        plt.scatter(
            coords[mask, 0], coords[mask, 1],
            s=35,
            color=colors[color_idx],
            alpha=0.85,
            label=f'Cluster {cluster_id}'
        )
        color_idx += 1

# ---- Entry label next to every point ----
for i, (x, y) in enumerate(coords):
    plt.text(
        x, y,
        labels[i],
        fontsize=11,
        alpha=0.8
    )

# ---- Axis limits: compress the view to the 5th-95th percentile range,
# ---- padded by 20% on each side; points beyond that fall outside the window.
x_lo, x_hi = np.percentile(coords[:, 0], [5, 95])
y_lo, y_hi = np.percentile(coords[:, 1], [5, 95])

x_pad = 0.2 * (x_hi - x_lo)
y_pad = 0.2 * (y_hi - y_lo)

plt.xlim(x_lo - x_pad, x_hi + x_pad)
plt.ylim(y_lo - y_pad, y_hi + y_pad)

# ---- Axis labels, title, legend ----
plt.xlabel('MDS dimension 1')
plt.ylabel('MDS dimension 2')
plt.title('OPTICS clustering (MDS)', fontsize=11)

plt.legend(frameon=False, fontsize=7)
plt.tight_layout()
plt.show()


