In [8]:
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.spatial.distance import squareform
import matplotlib.pyplot as plt
import os

def parse_pdbqt_to_mols(filename):
    """Parse a multi-model PDBQT file and return each docking pose as an RDKit molecule.

    Every ``MODEL`` ... ``ENDMDL`` section is treated as one pose. Inside a
    model only ``ATOM``/``HETATM`` records are kept, truncated to the first 66
    columns (dropping the extra PDBQT columns) before the block is handed to
    ``Chem.MolFromPDBBlock``. Records outside a MODEL section are ignored.

    Args:
        filename: Path to the PDBQT file.

    Returns:
        List of RDKit molecules in file order. Models that RDKit cannot parse
        are skipped with a printed warning; in that case list positions no
        longer line up with the MODEL sections of the file.
    """
    poses = []
    current_pdb = []
    in_model = False
    model_number = 0  # count of MODEL records seen so far (1-based), for warnings

    with open(filename, 'r') as f:
        # Stream the file rather than materialising every line up front.
        for line in f:
            if line.startswith('MODEL'):
                current_pdb = []
                in_model = True
                model_number += 1
            elif line.startswith('ENDMDL'):
                if current_pdb:
                    pdb_block = ''.join(current_pdb)
                    mol = Chem.MolFromPDBBlock(pdb_block, removeHs=False)
                    if mol is not None:
                        poses.append(mol)
                    else:
                        # A dropped pose shifts every later pose index, so report it.
                        print(f"   Warning: RDKit could not parse MODEL section #{model_number}; skipped")
                current_pdb = []
                in_model = False
            elif in_model and line.startswith(('ATOM', 'HETATM')):
                # Strip the terminator before slicing: a record shorter than 66
                # columns would otherwise keep its own newline and gain a second.
                current_pdb.append(line.rstrip('\r\n')[:66] + '\n')

    return poses

def calculate_rmsd_matrix(mols, use_symmetry=True):
    """Compute the symmetric pairwise RMSD matrix of poses after optimal superposition.

    Both RDKit routines used here superimpose the first (probe) molecule onto
    the second and leave the probe's coordinates in the aligned position. The
    probe is therefore always a throw-away copy. This keeps the caller's
    molecules untouched, so any structures written out later keep their
    original coordinates.

    NOTE(review): superposition removes rigid-body offsets between poses. For
    docking poses that share one receptor frame, an in-place (no-fit) RMSD is
    often what is wanted instead. Confirm the intent.

    Args:
        mols: Sequence of RDKit molecules to compare pairwise.
        use_symmetry: If True use ``AllChem.GetBestRMS`` (symmetry-aware atom
            mapping); otherwise ``AllChem.AlignMol`` (atom-index mapping).

    Returns:
        ``(n, n)`` float ndarray with zeros on the diagonal and
        ``matrix[i, j] == matrix[j, i]``.
    """
    n = len(mols)
    rmsd_matrix = np.zeros((n, n))

    print(f"   Calculating RMSD for {n} poses...")
    for i in range(n):
        for j in range(i + 1, n):
            if use_symmetry:
                # Symmetry-aware; the probe (first argument) gets moved, so pass a copy.
                probe = Chem.Mol(mols[i])
                rmsd = AllChem.GetBestRMS(probe, mols[j])
            else:
                # Index-mapped alignment; the probe gets moved, so pass a copy.
                probe = Chem.Mol(mols[j])
                rmsd = AllChem.AlignMol(probe, mols[i])

            rmsd_matrix[i, j] = rmsd
            rmsd_matrix[j, i] = rmsd

        if (i + 1) % 10 == 0:
            print(f"   Progress: {i+1}/{n} poses processed")

    return rmsd_matrix

def cluster_poses(rmsd_matrix, threshold=2.0, method='average'):
    """Group poses by hierarchical clustering of a square RMSD matrix.

    Args:
        rmsd_matrix: Symmetric ``(n, n)`` distance matrix with a zero diagonal.
        threshold: Distance at which the tree is cut into flat clusters.
        method: Linkage method forwarded to ``scipy.cluster.hierarchy.linkage``.

    Returns:
        ``(linkage_matrix, clusters)``: the scipy linkage matrix and one flat
        cluster label per pose (``fcluster`` with the distance criterion).
    """
    # linkage() expects the condensed (upper-triangle vector) form.
    tree = linkage(squareform(rmsd_matrix), method=method)
    labels = fcluster(tree, threshold, criterion='distance')
    return tree, labels

def plot_dendrogram(linkage_matrix, output_dir, filename='dendrogram.png', threshold=2.0):
    """Plot the clustering dendrogram with its distance cut-off and save it as an image.

    Args:
        linkage_matrix: Linkage matrix as returned by ``scipy.cluster.hierarchy.linkage``.
        output_dir: Directory the image is written to.
        filename: Image file name inside ``output_dir``.
        threshold: Height of the dashed cut-off line. Pass the value used for
            ``fcluster`` so the plot matches the actual clustering. Defaults
            to 2.0, the value this function previously hard-coded.
    """
    filepath = os.path.join(output_dir, filename)
    fig = plt.figure(figsize=(12, 6))
    try:
        dendrogram(linkage_matrix, leaf_font_size=10)
        plt.xlabel('Pose Index', fontsize=12)
        plt.ylabel('RMSD (Å)', fontsize=12)
        plt.title('Hierarchical Clustering Dendrogram of Docking Poses', fontsize=14)
        plt.axhline(y=threshold, color='r', linestyle='--',
                    label=f'Clustering threshold ({threshold} Å)')
        plt.legend()
        plt.tight_layout()
        plt.savefig(filepath, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails so figures don't accumulate.
        plt.close(fig)
    print(f"   Dendrogram saved to {filepath}")

def plot_rmsd_heatmap(rmsd_matrix, clusters, output_dir, filename='rmsd_heatmap.png'):
    """Plot the RMSD matrix reordered by cluster label and save it as an image.

    Rows and columns are grouped by cluster (ascending label). Within a
    cluster, the original pose order is preserved. Red lines mark cluster
    boundaries. Tick labels show the original pose indices, so a cell can be
    traced back to the poses it compares.

    Args:
        rmsd_matrix: Square pairwise RMSD matrix.
        clusters: One cluster label per pose (same order as the matrix rows).
        output_dir: Directory the image is written to.
        filename: Image file name inside ``output_dir``.
    """
    max_labeled_poses = 50  # above this, per-pose tick labels become unreadable

    clusters = np.asarray(clusters)
    # Stable sort: deterministic, keeps original pose order inside each cluster.
    order = np.argsort(clusters, kind='stable')
    sorted_matrix = rmsd_matrix[np.ix_(order, order)]
    sorted_clusters = clusters[order]

    filepath = os.path.join(output_dir, filename)
    fig, ax = plt.subplots(figsize=(10, 9))
    try:
        im = ax.imshow(sorted_matrix, cmap='viridis', aspect='auto')

        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label('RMSD (Å)', rotation=270, labelpad=20, fontsize=12)

        # Without explicit labels the ticks would show sorted positions, not pose indices.
        if len(order) <= max_labeled_poses:
            positions = np.arange(len(order))
            pose_labels = order.tolist()
            ax.set_xticks(positions)
            ax.set_xticklabels(pose_labels)
            ax.set_yticks(positions)
            ax.set_yticklabels(pose_labels)

        # Boundaries sit halfway between the last pose of one cluster and the first of the next.
        cluster_changes = np.flatnonzero(np.diff(sorted_clusters)) + 0.5
        for change in cluster_changes:
            ax.axhline(y=change, color='red', linestyle='-', linewidth=2)
            ax.axvline(x=change, color='red', linestyle='-', linewidth=2)

        ax.set_xlabel('Pose Index (sorted by cluster)', fontsize=12)
        ax.set_ylabel('Pose Index (sorted by cluster)', fontsize=12)
        ax.set_title('RMSD Matrix Heatmap', fontsize=14)
        fig.tight_layout()
        fig.savefig(filepath, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"   RMSD heatmap saved to {filepath}")

def save_cluster_representatives(mols, clusters, rmsd_matrix, output_dir, output_prefix='cluster'):
    """Write one representative pose per cluster to a PDB file.

    The representative is the member whose mean RMSD to all members of its
    cluster (itself included) is smallest. A single-member cluster uses its
    only pose. Each file is written to
    ``output_dir/{output_prefix}_{cluster_id}_rep_pose{index}.pdb``.

    Args:
        mols: Molecules indexed like the rows of ``rmsd_matrix``.
        clusters: One cluster label per pose.
        rmsd_matrix: Square pairwise RMSD matrix.
        output_dir: Directory the PDB files are written to.
        output_prefix: File-name prefix for the PDB files.

    Returns:
        Representative pose indices, one per cluster, in ascending cluster-id order.
    """
    chosen = []
    for cluster_id in sorted(np.unique(clusters)):
        member_idx = np.where(clusters == cluster_id)[0]

        if len(member_idx) == 1:
            best = member_idx[0]
        else:
            within = rmsd_matrix[np.ix_(member_idx, member_idx)]
            best = member_idx[np.argmin(within.mean(axis=1))]

        chosen.append(best)

        pdb_path = os.path.join(output_dir, f"{output_prefix}_{cluster_id}_rep_pose{best}.pdb")
        Chem.MolToPDBFile(mols[best], pdb_path)

    return chosen

def save_cluster_info(clusters, rmsd_matrix, representatives, output_dir, output_file='cluster_info.txt',
                      output_prefix='cluster'):
    """Write a human-readable clustering report to ``output_dir/output_file``.

    For every cluster the report lists its members and representative. For
    multi-member clusters it also gives the mean and maximum pairwise RMSD,
    and the mean RMSD from the representative to the *other* members.

    Args:
        clusters: One cluster label per pose.
        rmsd_matrix: Square pairwise RMSD matrix.
        representatives: Representative pose index per cluster, in ascending
            cluster-id order (the order ``save_cluster_representatives`` returns).
        output_dir: Directory the report is written to.
        output_file: Report file name.
        output_prefix: Prefix of the representative PDB file names listed in the
            report. Should match the prefix used when those files were written.
    """
    clusters = np.asarray(clusters)
    cluster_ids = np.unique(clusters)  # np.unique returns sorted labels
    filepath = os.path.join(output_dir, output_file)
    rule = "=" * 70

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(rule + "\n")
        f.write("RMSD-Based Clustering Results\n")
        f.write(rule + "\n\n")
        f.write(f"Total poses: {len(clusters)}\n")
        f.write(f"Number of clusters: {len(cluster_ids)}\n")
        f.write("\nCluster Details:\n")
        f.write("-" * 70 + "\n")

        for i, cluster_id in enumerate(cluster_ids):
            members = np.flatnonzero(clusters == cluster_id)
            representative = representatives[i]

            f.write(f"\n{rule}\n")
            f.write(f"Cluster {cluster_id}:\n")
            f.write(f"{rule}\n")
            f.write(f"  Number of poses: {len(members)}\n")
            # tolist() yields plain ints; list(ndarray) prints np.int64(...) under NumPy 2.
            f.write(f"  Member poses: {members.tolist()}\n")
            f.write(f"  Representative pose: {representative}\n")

            if len(members) > 1:
                cluster_rmsd = rmsd_matrix[np.ix_(members, members)]
                pair_rmsd = cluster_rmsd[np.triu_indices_from(cluster_rmsd, k=1)]
                # Leave out the representative's own zero distance, which would bias the mean low.
                others = members[members != representative]
                avg_rmsd_to_rep = rmsd_matrix[representative, others].mean()

                f.write(f"  Average intra-cluster RMSD: {pair_rmsd.mean():.3f} Å\n")
                f.write(f"  Maximum intra-cluster RMSD: {pair_rmsd.max():.3f} Å\n")
                f.write(f"  Average RMSD to representative: {avg_rmsd_to_rep:.3f} Å\n")
            else:
                f.write("  Single pose cluster\n")

            f.write(f"  Output file: {output_prefix}_{cluster_id}_rep_pose{representative}.pdb\n")

    print(f"   Cluster information saved to {filepath}")

def plot_cluster_size_distribution(clusters, output_dir, filename='cluster_distribution.png'):
    """Save a bar chart showing how many poses fall into each cluster.

    Args:
        clusters: One cluster label per pose.
        output_dir: Directory the image is written to.
        filename: Image file name inside ``output_dir``.
    """
    cluster_ids, sizes = np.unique(clusters, return_counts=True)
    target = os.path.join(output_dir, filename)

    plt.figure(figsize=(10, 6))
    plt.bar(cluster_ids, sizes, color='steelblue', edgecolor='black')
    plt.xticks(cluster_ids)  # one tick per cluster label
    plt.xlabel('Cluster ID', fontsize=12)
    plt.ylabel('Number of Poses', fontsize=12)
    plt.title('Cluster Size Distribution', fontsize=14)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(target, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"   Cluster distribution saved to {target}")

# Main execution
if __name__ == "__main__":
    import sys

    # Configuration
    input_file = "docking_results/pocket3  _docked_poses.pdbqt"  # Change to your file name
    output_dir = "rmsd_results"  # Output directory for all results
    rmsd_threshold = 2.0  # RMSD threshold in Angstroms for clustering
    clustering_method = 'average'  # 'single', 'complete', 'average', 'ward'
    use_symmetry = True  # Consider molecular symmetry in RMSD calculation

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    print(f"\n📁 Output directory: {output_dir}")

    print("\n" + "=" * 70)
    print(" RMSD-Based Clustering of Docking Poses (RDKit)")
    print("=" * 70)

    # Parse PDBQT file
    print(f"\n[1/5] Reading poses from {input_file}...")
    try:
        mols = parse_pdbqt_to_mols(input_file)
    except Exception as e:  # top-level boundary: report and stop
        print(f"   ✗ Error reading file: {e}")
        sys.exit(1)
    print(f"   ✓ Successfully loaded {len(mols)} poses")

    if len(mols) < 2:
        print("   ✗ Error: Need at least 2 poses for clustering")
        sys.exit(1)

    # Calculate RMSD matrix
    print(f"\n[2/5] Calculating RMSD matrix (symmetry={'ON' if use_symmetry else 'OFF'})...")
    rmsd_matrix = calculate_rmsd_matrix(mols, use_symmetry=use_symmetry)
    # Summarise each unordered pair exactly once. Filtering on "> 0" would drop
    # identical poses and crash on .min() when every pair has RMSD 0.
    pairwise = rmsd_matrix[np.triu_indices_from(rmsd_matrix, k=1)]
    print(f"   ✓ RMSD range: {pairwise.min():.3f} - {pairwise.max():.3f} Å")
    print(f"   ✓ Average RMSD: {pairwise.mean():.3f} Å")

    # Perform clustering
    print("\n[3/5] Performing hierarchical clustering...")
    print(f"   Method: {clustering_method}")
    print(f"   Threshold: {rmsd_threshold} Å")
    linkage_matrix, clusters = cluster_poses(rmsd_matrix, threshold=rmsd_threshold,
                                             method=clustering_method)
    n_clusters = len(np.unique(clusters))
    print(f"   ✓ Identified {n_clusters} clusters")

    # Save representative structures
    print("\n[4/5] Saving representative structures...")
    representatives = save_cluster_representatives(mols, clusters, rmsd_matrix, output_dir)
    print(f"   ✓ Saved {len(representatives)} representative PDB files")

    # Generate outputs
    print("\n[5/5] Generating visualizations and reports...")
    plot_dendrogram(linkage_matrix, output_dir)
    plot_rmsd_heatmap(rmsd_matrix, clusters, output_dir)
    plot_cluster_size_distribution(clusters, output_dir)
    save_cluster_info(clusters, rmsd_matrix, representatives, output_dir)

    # Print summary
    print("\n" + "=" * 70)
    print(" Clustering Summary")
    print("=" * 70)
    for i, cluster_id in enumerate(sorted(np.unique(clusters))):
        members = np.where(clusters == cluster_id)[0]
        rep = representatives[i]
        # tolist() prints plain ints instead of np.int64(...) reprs under NumPy 2.
        print(f"Cluster {cluster_id}: {len(members):2d} poses | "
              f"Representative: Pose {rep:2d} | Members: {members.tolist()}")

    print("\n" + "=" * 70)
    print(" Output Files Generated")
    print("=" * 70)
    print(f"  📊 {output_dir}/dendrogram.png - Hierarchical clustering tree")
    print(f"  📊 {output_dir}/rmsd_heatmap.png - RMSD distance matrix visualization")
    print(f"  📊 {output_dir}/cluster_distribution.png - Cluster size distribution")
    print(f"  📄 {output_dir}/cluster_info.txt - Detailed cluster information")
    print(f"  🧬 {output_dir}/cluster_*_rep_pose*.pdb - {n_clusters} representative structures")
    print("\n✓ Clustering complete!\n")


📁 Output directory: rmsd_results

 RMSD-Based Clustering of Docking Poses (RDKit)

[1/5] Reading poses from docking_results/pocket3  _docked_poses.pdbqt...
   ✓ Successfully loaded 10 poses

[2/5] Calculating RMSD matrix (symmetry=ON)...
   Calculating RMSD for 10 poses...
   Progress: 10/10 poses processed
   ✓ RMSD range: 0.010 - 2.171 Å
   ✓ Average RMSD: 0.987 Å

[3/5] Performing hierarchical clustering...
   Method: average
   Threshold: 2.0 Å
   ‚úì Identified 2 clusters

[4/5] Saving representative structures...
   ‚úì Saved 2 representative PDB files

[5/5] Generating visualizations and reports...
   Dendrogram saved to rmsd_results/dendrogram.png
   RMSD heatmap saved to rmsd_results/rmsd_heatmap.png
   Cluster distribution saved to rmsd_results/cluster_distribution.png
   Cluster information saved to rmsd_results/cluster_info.txt

 Clustering Summary
Cluster 1:  2 poses | Representative: Pose  7 | Members: [np.int64(7), np.int64(8)]
Cluster 2:  8 poses | Represen