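## Graph Clustering

This notebook encodes molecules (given as SMILES) into GraphsGPT fingerprint tokens, reduces the fingerprints to 2-D with UMAP, clusters them with HDBSCAN, and then visualizes both the clusters and sample molecules from each cluster.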

In [None]:
import hdbscan
import numpy as np
import os
import torch
import umap
from matplotlib import pyplot as plt
from tqdm import tqdm

# %matplotlib notebook
%matplotlib inline

### Configurations

In [None]:
# Pretrained GraphsGPT checkpoint and the input file (read below as one SMILES string per line).
model_name_or_path = "DaizeDong/GraphsGPT-8W"
smiles_file = "../../data/examples/zinc_example.txt"

# Encoding batch size, the cap on how many molecules are encoded / clustered / plotted,
# and the directory that receives all figures and per-cluster outputs.
batch_size = 1024
vis_sample_num = 32 * 1024
vis_save_dir = "./clustering_results"

# Prefer the GPU whenever PyTorch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

### Load SMILES

In [None]:
# Read one SMILES string per line. Surrounding whitespace is stripped and blank
# lines are skipped, so empty strings are never handed to the tokenizer.
with open(smiles_file, "r", encoding="utf-8") as f:
    smiles_list = [line.strip() for line in f]
smiles_list = [smiles for smiles in smiles_list if smiles]

print(f"Total SMILES loaded: {len(smiles_list)}")
# Slice instead of range(10) so that files with fewer than 10 entries do not raise IndexError.
for i, smiles in enumerate(smiles_list[:10]):
    print(f"Example SMILES {i}: {smiles}")

### Load Model Checkpoint

In [None]:
from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM
from data.tokenizer import GraphsGPTTokenizer

# Load the pretrained model and its matching tokenizer from the same checkpoint.
model = GraphsGPTForCausalLM.from_pretrained(model_name_or_path)
tokenizer = GraphsGPTTokenizer.from_pretrained(model_name_or_path)

# Quick sanity check of the loaded weights: parameter names and total parameter count.
print(model.state_dict().keys())
print(f"Total parameters: {sum(x.numel() for x in model.parameters())}")

### Generate

In [None]:
from utils.operations.operation_list import split_list_with_yield
from utils.operations.operation_tensor import move_tensors_to_device

# Encode the SMILES into fingerprint tokens batch by batch. At most `vis_sample_num`
# samples are kept; they are the ones embedded, clustered and plotted below.
now_sample_num = 0
all_fingerprint_tokens = []

model.to(device)
model.eval()
with torch.no_grad():
    for batched_smiles in split_list_with_yield(smiles_list, batch_size):
        inputs = tokenizer.batch_encode(batched_smiles, return_tensors="pt")
        move_tensors_to_device(inputs, device)

        fingerprint_tokens = model.encode_to_fingerprints(**inputs)  # (batch_size, num_fingerprints, hidden_dim)

        # Keep only the rows that still fit in the sample budget. Slicing is required:
        # appending the whole batch would overshoot vis_sample_num on the final batch.
        append_sample_num = min(fingerprint_tokens.shape[0], vis_sample_num - now_sample_num)
        if append_sample_num > 0:
            # Move each kept slice to the CPU immediately so GPU memory does not grow
            # with the number of processed batches.
            all_fingerprint_tokens.append(fingerprint_tokens[:append_sample_num].cpu())
            now_sample_num += append_sample_num

        # Stop as soon as the budget is filled, rather than encoding one more batch
        # only to discard it.
        if now_sample_num >= vis_sample_num:
            print("Max sample num reached, stopping forwarding.")
            break

if not all_fingerprint_tokens:
    raise ValueError(f"No fingerprints were generated; check that {smiles_file!r} contains SMILES.")

# fingerprint tokens to numpy
all_fingerprint_tokens = torch.cat(all_fingerprint_tokens, dim=0)
# (num_samples, num_fingerprints, hidden_dim), where num_samples <= vis_sample_num
all_fingerprint_tokens = all_fingerprint_tokens.cpu().numpy()
num_fingerprint_tokens = all_fingerprint_tokens.shape[1]
print(f"Number of samples is {all_fingerprint_tokens.shape[0]}")
print(f"Number of fingerprints for each sample is {num_fingerprint_tokens}")

### UMAP Dimensionality Reduction

For reference, suggested UMAP and HDBSCAN hyperparameters are listed in [README-Clustering.md](README-Clustering.md).

In [None]:
# Prepare the feature matrices for UMAP. Each entry is a 2-D array with one row per
# encoded sample:
#   - "fp_<i>": the i-th fingerprint token of every sample (only when there are several),
#   - "fp_all": all fingerprint tokens of a sample flattened into a single vector.
features = {}

if num_fingerprint_tokens > 1:  # per-fingerprint
    for i in range(num_fingerprint_tokens):
        features[f"fp_{i}"] = all_fingerprint_tokens[:, i, :]
# Reshape by the number of samples actually encoded. `vis_sample_num` is only an upper
# bound: for a smaller dataset, reshape(vis_sample_num, -1) would either raise or, when
# the sizes happen to divide, silently mix rows belonging to different molecules.
features["fp_all"] = all_fingerprint_tokens.reshape(all_fingerprint_tokens.shape[0], -1)  # aggregated fingerprints

# Compute two 2-D UMAP embeddings per feature set:
#   - a compact one (large n_neighbors, small min_dist), which HDBSCAN clusters on;
#   - a more spread-out one (smaller n_neighbors, large min_dist), used only for plotting.
# NOTE(review): no random_state is passed, so embeddings (and therefore clusters) may
# differ between runs.
umap_features_for_clustering = {}
umap_features_for_visualization = {}

for key, value in tqdm(features.items(), desc="Computing UMAP features"):
    umap_features_for_clustering[key] = umap.UMAP(
        n_neighbors=100,  # bigger value --> more compact
        min_dist=0.05,  # smaller value --> more compact
        n_components=2,
    ).fit_transform(value)

    umap_features_for_visualization[key] = umap.UMAP(
        n_neighbors=40,  # same trend as above: bigger value --> more compact
        min_dist=0.7,  # same trend as above: smaller value --> more compact
        n_components=2,
    ).fit_transform(value)

### HDBSCAN Clustering

In [None]:
# Cluster each compact UMAP embedding (the "for clustering" variant) with HDBSCAN.
# cluster_labels[key] holds one integer label per sample; -1 marks noise points.
cluster_labels = {}

for feature_key in tqdm(features.keys(), desc="Performing HDBSCAN clustering"):
    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=32,  # larger --> fewer clusters
        min_samples=48,  # larger --> fewer clusters and more noise
        cluster_selection_epsilon=0.2,  # larger --> less noise
        alpha=1.0,
        gen_min_span_tree=True,  # the spanning tree itself is not used later in this notebook
    )
    embedding = umap_features_for_clustering[feature_key]
    cluster_labels[feature_key] = clusterer.fit_predict(embedding)

### Visualization for Clusters

In [None]:
# Draw one scatter plot per feature set on the spread-out UMAP embedding. Noise points
# are grey, clustered points are colored by label, and each cluster id is written at
# the centroid of its points. Every figure is saved to `<vis_save_dir>/<key>.png`.
os.makedirs(vis_save_dir, exist_ok=True)  # create the output directory once, not per figure

for key in tqdm(features.keys(), desc="Performing visualization"):
    labels = cluster_labels[key]
    points = umap_features_for_visualization[key]

    # features
    noise_point_mask = (labels == -1)
    noise_sample_num = np.sum(noise_point_mask)
    # np.bincount has length (max label + 1), i.e. 0 when every point is noise.
    cluster_num = np.bincount(labels[~noise_point_mask]).shape[0]
    # Centroid of each cluster in the visualization embedding, used to place its id text.
    cluster_centroids = [np.mean(points[labels == i], axis=0) for i in range(cluster_num)]

    save_img_file = os.path.join(vis_save_dir, key + ".png")

    fig, ax = plt.subplots(figsize=(16, 12))

    ax.scatter(
        points[noise_point_mask][:, 0],
        points[noise_point_mask][:, 1],
        c="#CCCCCC",
        label="noise",
        alpha=0.9,
        s=16,
        linewidths=0
    )
    ax.scatter(
        points[~noise_point_mask][:, 0],
        points[~noise_point_mask][:, 1],
        c=labels[~noise_point_mask],
        label="clusters",
        alpha=0.9,
        s=16,
        linewidths=0,
        cmap="rainbow"
    )
    for i, centroid in enumerate(cluster_centroids):  # Add text label at the cluster centroid
        ax.text(
            centroid[0],
            centroid[1],
            str(i),
            color="black",
            fontsize=14,
            weight='bold',
            ha='center',
            va='center'
        )

    ax.set_title(f"{key} (total cluster {cluster_num}) (total noise sample {noise_sample_num})")
    ax.legend(loc="best")
    fig.tight_layout()
    # Save first, then display. fig.show() is a no-op/warning on the inline backend, and a
    # figure closed before the cell ends is never rendered, so use plt.show() instead.
    fig.savefig(save_img_file, dpi=480, bbox_inches="tight")
    plt.show()
    plt.close(fig)  # free memory; harmless if the backend already closed the figure

### Analysis of Different Clusters

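We visualize sample molecules from each cluster and save the SMILES of every cluster member to a `summary.csv` file in that cluster's directory.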

In [None]:
from utils.io import delete_file_or_dir, create_dir, save_mol_png
from utils.operations.operation_dict import reverse_dict

# read bond dict
bond_dict = tokenizer.bond_dict
inverse_bond_dict = reverse_dict(bond_dict, aggregate_same_results=False)  # NOTE(review): not used below

# Maximum number of molecule images rendered per cluster.
max_imgs_per_cluster = 10

# For every feature set and every cluster: render up to `max_imgs_per_cluster` member
# molecules and write all members (sample index + standardized SMILES) to summary.csv.
for key in tqdm(features.keys(), desc="Iterating over fingerprint features"):
    labels = cluster_labels[key]
    # np.bincount has length (max label + 1) over the non-noise labels.
    this_cluster_num = np.bincount(labels[labels != -1]).shape[0]

    # get molecule images in different clusters
    for i in tqdm(range(this_cluster_num), desc="Iterating over clusters"):
        this_cluster_indices_list = []
        this_cluster_smiles_list = []

        # Sample indices belonging to cluster i. They index smiles_list directly, assuming
        # batching preserved the original order. np.flatnonzero avoids assuming that
        # exactly vis_sample_num samples were encoded.
        this_cluster_point_id = np.flatnonzero(labels == i).tolist()

        # Start each cluster directory from a clean state.
        mole_vis_save_dir = os.path.join(vis_save_dir, f"moles_{key}", f"cluster_{i}")
        delete_file_or_dir(mole_vis_save_dir)
        create_dir(mole_vis_save_dir)

        save_img_cnt = 0
        for index in this_cluster_point_id:
            mol = tokenizer._convert_smiles_to_standard_molecule(smiles_list[index])
            if mol is None:
                continue  # SMILES that cannot be converted are left out of the images and the summary

            if save_img_cnt < max_imgs_per_cluster:
                save_img_file = os.path.join(mole_vis_save_dir, f"{index}.png")
                save_mol_png(mol, save_img_file)
                save_img_cnt += 1
            this_cluster_indices_list.append(index)
            this_cluster_smiles_list.append(tokenizer._convert_molecule_to_standard_smiles(mol))

        # Write this cluster's own SMILES. The original wrote smiles_list[j], i.e. the
        # j-th SMILES of the whole dataset, which did not match the listed index.
        save_summary_file = os.path.join(mole_vis_save_dir, "summary.csv")
        with open(save_summary_file, "w", encoding="utf-8") as f:
            f.write("index,smiles\n")
            for index, smiles in zip(this_cluster_indices_list, this_cluster_smiles_list):
                f.write(f"{index},{smiles}\n")

print(f"Visualization molecules saved to {vis_save_dir}.")
print(f"Molecule SMILES of each cluster saved to {vis_save_dir}.")

All done.
You can check the saved files for further analysis.