# CryoDRGN - Diego visualization and figures

This Jupyter notebook provides a template for regenerating and customizing cryoDRGN visualizations and figures.

In [None]:
from cryodrgn import analysis
from cryodrgn import utils
import cryodrgn.config

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import DBSCAN

In [None]:
# Enable the ipywidgets notebook extension so interactive widgets can render.
# NOTE(review): `jupyter nbextension` belongs to the classic Notebook front end;
# it may be unavailable on JupyterLab / Notebook 7 -- confirm for your environment.
!jupyter nbextension enable --py widgetsnbextension

In [None]:
def plot_by_cluster_2(
    x,
    y,
    K,
    labels,
    centers=None,
    centers_ind=None,
    annotate=False,
    s=2,
    alpha=0.1,
    colors=None,
    cmap=None,
    figsize=None,
    ax=None,
):
    """Scatter 2-D points coloured by cluster label, optionally on an existing axes.

    Parameters
    ----------
    x, y : array
        Coordinates of every point; indexed with a boolean mask per cluster and
        with integer positions when ``centers_ind`` is given.
    K : int or sequence of labels
        An integer count plots labels ``0 .. K-1``; a sequence plots exactly the
        labels it contains.
    labels : array
        Cluster label of every point, compared element-wise with each entry of ``K``.
    centers : array, optional
        Center coordinates (first two columns are used), drawn in black.
        Mutually exclusive with ``centers_ind``.
    centers_ind : int or array-like of int, optional
        Positions in ``x``/``y`` of the points to draw as centers.
    annotate : bool
        When true, write each label ``i`` from ``K`` at ``centers[i]``.
    s, alpha :
        Marker size and opacity forwarded to ``ax.scatter`` for the cluster points.
    colors : optional
        Colours looked up by label value (``colors[label]``). When omitted they
        are obtained from ``analysis._get_colors``.
    cmap : optional
        Forwarded to ``analysis._get_colors`` when ``colors`` is omitted.
    figsize : optional
        Used only when a new figure has to be created.
    ax : optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    (fig, ax)
        The figure that owns ``ax`` and the axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Accept NumPy integer counts (e.g. values read from arrays) as well as ints.
    if isinstance(K, (int, np.integer)):
        K = list(range(K))

    if colors is None:
        # Colours are indexed by label value below, so the generated palette must
        # reach the largest label even when K is a sparse subset such as [3, 5].
        n_colors = len(K)
        if n_colors > 0:
            n_colors = max(n_colors, int(max(K)) + 1)
        colors = analysis._get_colors(n_colors, cmap)

    # scatter by cluster
    for i in K:
        in_cluster = labels == i
        ax.scatter(
            x[in_cluster],
            y[in_cluster],
            s=s,
            alpha=alpha,
            label="cluster {}".format(i),
            color=colors[i],
            rasterized=True,
        )

    # plot cluster centers
    if centers_ind is not None:
        assert centers is None, "pass either centers or centers_ind, not both"
        # A one-line index file loaded with np.loadtxt is a 0-d array; make it iterable.
        centers = np.array([[x[i], y[i]] for i in np.atleast_1d(centers_ind)])
    if centers is not None:
        ax.scatter(centers[:, 0], centers[:, 1], c="k")
    if annotate:
        assert centers is not None, "annotate=True requires centers or centers_ind"
        for i in K:
            ax.annotate(str(i), centers[i, 0:2])
    return fig, ax


### Load dataset

In [None]:
# User-editable settings for the rest of the notebook:
#   WORKDIR - directory holding config.yaml, run.log, z.{EPOCH}.pkl and analysis_diego.{EPOCH}/
#   EPOCH   - training epoch to analyse (0-based index)
#   K       - number of clusters; selects the kmeans{K}_* and gmm{K}_* result folders

WORKDIR = '..' # CHANGE ME
EPOCH = 49 # CHANGE ME
K = 20 # CHANGE ME

In [None]:
# Load the run's configuration file from WORKDIR and print it for reference
config = cryodrgn.config.load(f'{WORKDIR}/config.yaml')
print(config)

### Load results

In [None]:
# Load z for the selected epoch and the UMAP embedding stored in that epoch's
# analysis folder.
# NOTE(review): load_pkl presumably unpickles the file -- only point WORKDIR at trusted outputs.
z = utils.load_pkl(f'{WORKDIR}/z.{EPOCH}.pkl')
umap = utils.load_pkl(f'{WORKDIR}/analysis_diego.{EPOCH}/umap.pkl')

### Learning curve

In [None]:
# Loss parsed from the run log, with the analysed epoch marked by a dashed line
loss = analysis.parse_loss(f'{WORKDIR}/run.log')
plt.plot(loss)
plt.axvline(x=EPOCH, color="black", linestyle="--", label=f"Epoch {EPOCH}")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

# Plot PCA

Visualize the latent space by principal component analysis (PCA).

In [None]:
# Run PCA on z. `pc` holds the projected points (its first two columns are plotted
# below as PC1/PC2); `pca` exposes explained_variance_ratio_, used in the axis labels.
pc, pca = analysis.run_pca(z)

In [None]:
# Style 1 -- Scatter
# Axis labels carry the explained-variance ratio of each component.

var_pc1 = pca.explained_variance_ratio_[0]
var_pc2 = pca.explained_variance_ratio_[1]

plt.figure(figsize=(4,4))
plt.scatter(x=pc[:,0], y=pc[:,1], s=1, alpha=.1, rasterized=True)
plt.xlabel('PC1 ({:.2f})'.format(var_pc1))
plt.ylabel('PC2 ({:.2f})'.format(var_pc2))
#plt.savefig('pca_style1.pdf')

In [None]:
# Style 2 -- Scatter with marginals
# The labels are set on the grid's joint axes and carry the explained-variance
# ratio of each component.

g = sns.jointplot(x=pc[:,0], y=pc[:,1], alpha=.1, s=1,rasterized=True, height=4)
g.ax_joint.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
g.ax_joint.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
# Uncomment to save the figure:
#plt.savefig('pca_style2.pdf')

In [None]:
# Style 3 -- Hexbin/heatmap

g = sns.jointplot(x=pc[:,0], y=pc[:,1], height=4, kind='hex')
# Label the joint axes explicitly (as the other jointplot cells do) instead of
# relying on whichever axes pyplot considers current after jointplot returns.
g.ax_joint.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
g.ax_joint.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
#plt.savefig('pca_style3.pdf')

In [None]:
# Explained variance
# One bar per principal component, numbered from 1.

pc_numbers = np.arange(z.shape[1])+1
plt.bar(pc_numbers, pca.explained_variance_ratio_)
plt.xticks(pc_numbers)
plt.xlabel('PC')
plt.ylabel('explained variance')

# Plot UMAP

Visualize the latent space by Uniform Manifold Approximation and Projection (UMAP). 

In [None]:
# Style 1 -- Scatter

plt.figure(figsize=(4,4))
plt.scatter(x=umap[:,0], y=umap[:,1], s=1, alpha=.1, rasterized=True)
# Remove the tick marks on both axes
for clear_ticks in (plt.xticks, plt.yticks):
    clear_ticks([])
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
#plt.savefig('umap_style1.pdf')

In [None]:
# Style 2 -- Scatter with marginal distributions
# The axis labels are set on the grid's joint axes.

g = sns.jointplot(x=umap[:,0], y=umap[:,1], alpha=.1, s=1,rasterized=True, height=4)
g.ax_joint.set_xlabel('UMAP1')
g.ax_joint.set_ylabel('UMAP2')
# Uncomment to save the figure:
#plt.savefig('umap_style2.pdf')

In [None]:
# Style 3 -- Hexbin / heatmap
# Same layout as style 2, drawn with kind='hex'.

g = sns.jointplot(x=umap[:,0], y=umap[:,1], kind='hex',height=4)
g.ax_joint.set_xlabel('UMAP1')
g.ax_joint.set_ylabel('UMAP2')
# Uncomment to save the figure:
#plt.savefig('umap_style3.pdf')

# Plot K-Means samples by clustering on z, PCA and UMAP

In [None]:
# Load the center indices saved for the K-Means runs on z, on PCA and on UMAP.
# They are passed below as `centers_ind`, i.e. positions of the data points that
# are drawn as cluster centers.
kmeans_ind_z = np.loadtxt(f'{WORKDIR}/analysis_diego.{EPOCH}/kmeans{K}_z/centers_ind.txt', dtype=int)
kmeans_ind_pca = np.loadtxt(f'{WORKDIR}/analysis_diego.{EPOCH}/kmeans{K}_pca/centers_ind.txt', dtype=int)
kmeans_ind_umap = np.loadtxt(f'{WORKDIR}/analysis_diego.{EPOCH}/kmeans{K}_umap/centers_ind.txt', dtype=int)

# Default ChimeraX color map (presumably one color per cluster -- confirm);
# used to color the GMM centers further down.
colors = analysis._get_chimerax_colors(K)

In [None]:
# Load the K-Means labels and cluster centers computed on z
kmeans_labels_z = utils.load_pkl(f'{WORKDIR}/analysis_diego.{EPOCH}/kmeans{K}_z/labels.pkl')
kmeans_centers_z = np.loadtxt(f'{WORKDIR}/analysis_diego.{EPOCH}/kmeans{K}_z/centers.txt')
# Or re-run kmeans with the desired number of classes
#kmeans_labels_z, kmeans_centers_z = analysis.cluster_kmeans(z, K)

# Replace each center with the first value returned by get_nearest_point
# (presumably the nearest actual data point in z -- confirm). The second return
# value (presumably its index) is discarded; the on-data indices used for
# plotting are the ones read from centers_ind.txt.
kmeans_centers_z, _ = analysis.get_nearest_point(z, kmeans_centers_z)

In [None]:
# Load the K-Means labels and cluster centers computed on the PCA projection
kmeans_labels_pca = utils.load_pkl(f'{WORKDIR}/analysis_diego.{EPOCH}/kmeans{K}_pca/labels.pkl')
kmeans_centers_pca = np.loadtxt(f'{WORKDIR}/analysis_diego.{EPOCH}/kmeans{K}_pca/centers.txt')
# Or re-run kmeans with the desired number of classes
# (NOTE(review): clustering `pc` rather than `z` is assumed here -- confirm)
#kmeans_labels_pca, kmeans_centers_pca = analysis.cluster_kmeans(pc, K)

# Replace each center with the first value returned by get_nearest_point
# (presumably the nearest actual data point in pc -- confirm). The second return
# value (presumably its index) is discarded; the on-data indices used for
# plotting are the ones read from centers_ind.txt.
kmeans_centers_pca, _ = analysis.get_nearest_point(pc, kmeans_centers_pca)

In [None]:
# Plot PCA-kmeans on PCA
# Clusters found by K-Means on the PCA projection, with their on-data centers
# annotated. Axis names follow the 'PC1'/'PC2' convention of the other PCA plots.

cluster_fig, cluster_ax = plot_by_cluster_2(pc[:,0], pc[:,1], K,
                                            kmeans_labels_pca,
                                            centers_ind=kmeans_ind_pca,
                                            annotate=True)
# Label the axes that were actually drawn on rather than pyplot's current axes.
cluster_ax.set_title('PCA-KMeans, Centers and Clusters')
cluster_ax.set_xlabel('PC1')
cluster_ax.set_ylabel('PC2')

In [None]:
# Plot each PCA-KMeans cluster (PC1 vs PC2) via analysis.plot_by_cluster_subplot

fig, ax = analysis.plot_by_cluster_subplot(pc[:,0], pc[:,1], K, 
                            kmeans_labels_pca)

In [None]:
# Load kmeans on UMAP

# Load the K-Means labels and cluster centers computed on the UMAP embedding
kmeans_labels_umap = utils.load_pkl(f'{WORKDIR}/analysis_diego.{EPOCH}/kmeans{K}_umap/labels.pkl')
kmeans_centers_umap = np.loadtxt(f'{WORKDIR}/analysis_diego.{EPOCH}/kmeans{K}_umap/centers.txt')
# Or re-run kmeans with the desired number of classes
# (NOTE(review): clustering `umap` rather than `z` is assumed here -- confirm)
#kmeans_labels_umap, kmeans_centers_umap = analysis.cluster_kmeans(umap, K)

# Replace each center with the first value returned by get_nearest_point
# (presumably the nearest actual data point in umap -- confirm). The second return
# value (presumably its index) is discarded; the on-data indices used for
# plotting are the ones read from centers_ind.txt.
kmeans_centers_umap, _ = analysis.get_nearest_point(umap, kmeans_centers_umap)

In [None]:
# Plot UMAP-kmeans on UMAP
# Clusters found by K-Means on the UMAP embedding, with their centers annotated.

cluster_fig, cluster_ax = plot_by_cluster_2(
    umap[:,0],
    umap[:,1],
    K,
    kmeans_labels_umap,
    centers_ind=kmeans_ind_umap,
    annotate=True,
)
cluster_ax.set_title('UMAP-KMeans, Centers and Clusters')
cluster_ax.set_xlabel('UMAP1')
cluster_ax.set_ylabel('UMAP2')

In [None]:
# Plot each UMAP-KMeans cluster (UMAP1 vs UMAP2) via analysis.plot_by_cluster_subplot
fig, ax = analysis.plot_by_cluster_subplot(umap[:,0], umap[:,1], K, 
                            kmeans_labels_umap)

In [None]:
# Plot K-Means clustering computed on z, on PCA and on UMAP, each shown in PCA space

fig, axs = plt.subplots(1, 3, figsize=(10, 4))

# One panel per clustering: (point labels, on-data center indices, panel title)
pca_panels = [
    (kmeans_labels_z, kmeans_ind_z, 'z-KMeans sobre PCA'),
    (kmeans_labels_pca, kmeans_ind_pca, 'PCA-KMeans sobre PCA'),
    (kmeans_labels_umap, kmeans_ind_umap, 'UMAP-KMeans sobre PCA'),
]

for panel_ax, (panel_labels, panel_centers_ind, panel_title) in zip(axs, pca_panels):
    plot_by_cluster_2(
        pc[:,0], pc[:,1], K,
        panel_labels,
        centers_ind=panel_centers_ind,
        annotate=True, ax=panel_ax
    )
    # Every panel shares the same coordinates, so they share the same axis names
    # (the first panel used to read 'PCA1'/'PCA2').
    panel_ax.set_xlabel('PC1')
    panel_ax.set_ylabel('PC2')
    panel_ax.set_title(panel_title)

fig.tight_layout()
plt.show()


In [None]:
# Plot K-Means clustering computed on z, on PCA and on UMAP, each shown in UMAP space

fig, axs = plt.subplots(1, 3, figsize=(10, 4))

# One panel per clustering: (point labels, on-data center indices, panel title)
umap_panels = [
    (kmeans_labels_z, kmeans_ind_z, 'z-KMeans sobre UMAP'),
    (kmeans_labels_pca, kmeans_ind_pca, 'PCA-KMeans sobre UMAP'),
    (kmeans_labels_umap, kmeans_ind_umap, 'UMAP-KMeans sobre UMAP'),
]

for panel_ax, (panel_labels, panel_centers_ind, panel_title) in zip(axs, umap_panels):
    plot_by_cluster_2(
        umap[:,0], umap[:,1], K,
        panel_labels,
        centers_ind=panel_centers_ind,
        annotate=True, ax=panel_ax
    )
    # All three panels plot UMAP coordinates; the third panel used to be
    # mislabelled 'PC1'/'PC2'.
    panel_ax.set_xlabel('UMAP1')
    panel_ax.set_ylabel('UMAP2')
    panel_ax.set_title(panel_title)

fig.tight_layout()
plt.show()

### Plot GMM centers on UMAP

In [None]:
# Plot UMAP-gmm on UMAP
# Load the indices of the data points used as centers of the GMM fitted on UMAP.
gmm_ind_umap = np.loadtxt(f'{WORKDIR}/analysis_diego.{EPOCH}/gmm{K}_umap/centers_ind.txt', dtype=int)

f, ax = plt.subplots(figsize=(4,4))
# Faint background of all points, then the GMM centers on top.
ax.scatter(umap[:,0], umap[:,1], alpha=.05, s=1, rasterized=True)
ax.scatter(umap[gmm_ind_umap,0], umap[gmm_ind_umap,1], c=colors, edgecolor='black')
# Number the centers from the GMM index list itself (not the K-Means one).
labels = np.arange(len(gmm_ind_umap))
centers = umap[gmm_ind_umap]
for i in labels:
    # Small offset so the text does not sit on top of the marker
    ax.annotate(str(i), centers[i, 0:2] + np.array([0.1, 0.1]))
ax.set_xticks([])
ax.set_yticks([])
# The coordinates are UMAP dimensions; the GMM only supplies the centers.
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
ax.set_title('UMAP-GMM, Centers')

### DBSCAN clustering

In [None]:
# DBSCAN on z

dbscan = DBSCAN(eps=0.1, min_samples=250)  # Tune eps and min_samples as needed
labels_z = dbscan.fit_predict(z)

# Keep only the points that were not labelled as noise (-1)
mask = labels_z != -1
filtered_z = z[mask]
filtered_labels_z = labels_z[mask]

# Distinct labels
unique_labels_all_z = set(labels_z)  # includes the noise label (-1)
unique_labels_filtered_z = set(filtered_labels_z)  # noise excluded

# One figure with two panels: left keeps the noise points, right drops them
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

z_panels = (
    (z, labels_z, unique_labels_all_z, "DBSCAN - Clusters con ruido"),
    (filtered_z, filtered_labels_z, unique_labels_filtered_z, "DBSCAN - Clusters sin ruido"),
)

for ax, (panel_points, panel_labels, panel_unique, panel_title) in zip(axes, z_panels):
    for label in panel_unique:
        cluster_points = panel_points[panel_labels == label]
        ax.scatter(cluster_points[:, 0], cluster_points[:, 1],
                   label=f"Cluster {label}" if label != -1 else "Ruido",
                   alpha=0.8, edgecolors="k")

    # The first two columns of z are plotted here, not principal components,
    # so the axes are named after z (they used to read "PC1"/"PC2").
    ax.set_xlabel("z1")
    ax.set_ylabel("z2")
    ax.legend()
    ax.set_title(panel_title)

plt.tight_layout()
plt.show()

In [None]:
# DBSCAN on PCA

dbscan = DBSCAN(eps=0.1, min_samples=250)  # Tune eps and min_samples as needed
labels_pca = dbscan.fit_predict(pc)

# Keep only the points that were not labelled as noise (-1)
mask = labels_pca != -1
filtered_pca, filtered_labels_pca = pc[mask], labels_pca[mask]

# Distinct labels, with and without the noise label (-1)
unique_labels_all_pca = set(labels_pca)
unique_labels_filtered_pca = set(filtered_labels_pca)

# One figure with two panels: left keeps the noise points, right drops them
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

pca_panels = (
    (pc, labels_pca, unique_labels_all_pca, "DBSCAN - Clusters con ruido"),
    (filtered_pca, filtered_labels_pca, unique_labels_filtered_pca, "DBSCAN - Clusters sin ruido"),
)

for ax, (panel_points, panel_labels, panel_unique, panel_title) in zip(axes, pca_panels):
    for label in panel_unique:
        cluster_points = panel_points[panel_labels == label]
        ax.scatter(cluster_points[:, 0], cluster_points[:, 1],
                   label="Ruido" if label == -1 else f"Cluster {label}",
                   alpha=0.8, edgecolors="k")

    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.legend()
    ax.set_title(panel_title)

plt.tight_layout()
plt.show()

In [None]:
# DBSCAN on UMAP

dbscan = DBSCAN(eps=0.1, min_samples=50)  # Tune eps and min_samples as needed
labels = dbscan.fit_predict(umap)

# Keep only the points that were not labelled as noise (-1)
mask = labels != -1
filtered_umap, filtered_labels = umap[mask], labels[mask]

# Distinct labels, with and without the noise label (-1)
unique_labels_all = set(labels)
unique_labels_filtered = set(filtered_labels)

# One figure with two panels: left keeps the noise points, right drops them
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

umap_panels = (
    (umap, labels, unique_labels_all, "DBSCAN - Clusters con ruido"),
    (filtered_umap, filtered_labels, unique_labels_filtered, "DBSCAN - Clusters sin ruido"),
)

for ax, (panel_points, panel_labels, panel_unique, panel_title) in zip(axes, umap_panels):
    for label in panel_unique:
        cluster_points = panel_points[panel_labels == label]
        ax.scatter(cluster_points[:, 0], cluster_points[:, 1],
                   label="Ruido" if label == -1 else f"Cluster {label}",
                   alpha=0.8, edgecolors="k")

    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")
    ax.legend()
    ax.set_title(panel_title)

plt.tight_layout()
plt.show()