In [1]:
import os
import json
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from PIL import Image, ImageDraw, ImageFont
import io
import numpy as np
from rdkit import RDLogger
import matplotlib.cm as cm
import matplotlib.pyplot as plt

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.*')

def load_cluster_info(json_path):
    """Load cluster information from a JSON file.

    Args:
        json_path: Path to the JSON file to read.

    Returns:
        The decoded JSON document. The rest of this script expects a list of
        cluster dicts carrying ``cluster_id``, ``size`` and
        ``representatives`` keys.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: ``json.JSONDecodeError`` / ``UnicodeDecodeError`` if the
            file is not valid UTF-8 JSON.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding
    # (e.g. cp1252 on Windows) when decoding it.
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)

def create_molecule_from_smiles(smiles):
    """Parse *smiles* into an RDKit molecule prepared for 2D drawing.

    Returns:
        The molecule with 2D coordinates computed, or ``None`` when RDKit
        cannot parse the string or any step raises (the error is printed
        instead of propagated).
    """
    try:
        molecule = Chem.MolFromSmiles(smiles)
        if molecule is not None:
            # 2D coordinates are needed before the molecule can be depicted.
            AllChem.Compute2DCoords(molecule)
        return molecule
    except Exception as err:
        print(f"Error creating molecule from SMILES: {err}")
        return None

def generate_molecule_image(mol, size=(300, 250), highlight_atoms=None):
    """Render an RDKit molecule to an RGBA image with a transparent background.

    Args:
        mol: RDKit molecule; ``None`` is accepted and yields ``None``.
        size: ``(width, height)`` of the rendered image in pixels.
        highlight_atoms: Optional sequence of atom indices to highlight in
            light green.

    Returns:
        A PIL image in RGBA mode in which every near-white pixel (R, G and B
        all above 240) is fully transparent, or ``None`` if *mol* is ``None``
        or rendering fails (the error is printed).
    """
    if mol is None:
        return None

    # Only forward highlight options when there is something to highlight,
    # instead of passing explicit None values through to RDKit.
    highlight_kwargs = {}
    if highlight_atoms:
        highlight_kwargs["highlightAtoms"] = highlight_atoms
        highlight_kwargs["highlightColor"] = (0.7, 0.9, 0.7)

    try:
        img = Draw.MolToImage(
            mol,
            size=size,
            kekulize=True,
            fitImage=True,
            **highlight_kwargs,
        )

        if img.mode != 'RGBA':
            img = img.convert('RGBA')

        # Make the white background transparent in one vectorized pass
        # rather than a per-pixel Python loop over getdata()/putdata().
        pixels = np.array(img)  # writable (height, width, 4) uint8 copy
        near_white = (pixels[..., :3] > 240).all(axis=-1)
        pixels[near_white] = (255, 255, 255, 0)
        return Image.fromarray(pixels)
    except Exception as e:
        print(f"Error generating molecule image: {e}")
        return None

def _load_collage_fonts():
    """Return ``(title_font, info_font)``: Arial at size 36/20, else Pillow's default."""
    try:
        return ImageFont.truetype("arial.ttf", 36), ImageFont.truetype("arial.ttf", 20)
    except (OSError, ImportError):
        # OSError: font file not found; ImportError: Pillow built without FreeType.
        default_font = ImageFont.load_default()
        return default_font, default_font


def _render_representatives(representatives):
    """Draw each representative's SMILES at 250x200 px, skipping any that fail."""
    images = []
    for rep in representatives:
        smiles = rep.get('smiles')
        if not smiles:
            continue

        mol = create_molecule_from_smiles(smiles)
        if mol is None:
            continue

        img = generate_molecule_image(mol, size=(250, 200))
        if img is not None:
            images.append(img)
    return images


def create_cluster_collage(cluster_data, output_path, size=(800, 400)):
    """Create a titled PNG collage of one cluster's representative molecules.

    Args:
        cluster_data: Dict with ``cluster_id``, ``size`` (shown as the
            molecule count) and ``representatives`` (list of dicts, each
            expected to carry a ``smiles`` string).
        output_path: Destination path; the collage is saved there as PNG.
        size: ``(width, height)`` of the collage in pixels.

    Returns:
        The RGBA collage image, or ``None`` (and nothing is saved) if the
        cluster has no representatives or none of them could be drawn.
    """
    cluster_id = cluster_data['cluster_id']
    representatives = cluster_data['representatives']

    if not representatives:
        print(f"No representatives for cluster {cluster_id}")
        return None

    # Render first so no canvas is built for a cluster with nothing to show.
    mol_images = _render_representatives(representatives)
    if not mol_images:
        print(f"No valid molecule images for cluster {cluster_id}")
        return None

    collage = Image.new('RGBA', size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(collage)
    title_font, info_font = _load_collage_fonts()

    # Title and molecule count, centred horizontally above the molecule row.
    centre_x = size[0] // 2
    draw.text((centre_x, 30), f"Cluster {cluster_id}",
              fill=(0, 0, 0, 255), font=title_font, anchor="mm")
    draw.text((centre_x, 70), f"{cluster_data['size']} molecules",
              fill=(0, 0, 0, 200), font=info_font, anchor="mm")

    # One row of equally sized cells, evenly spaced, starting below the title.
    row_top = 100
    n_mols = len(mol_images)
    cell_width = min(size[0] // n_mols, 250)
    cell_height = min(size[1] - row_top, 200)
    margin = (size[0] - n_mols * cell_width) // (n_mols + 1)

    for i, img in enumerate(mol_images):
        if img.size != (cell_width, cell_height):
            # Shrink to fit the cell while keeping the aspect ratio; a plain
            # resize() to the cell's shape would squash the drawing.
            img = img.copy()
            img.thumbnail((cell_width, cell_height), Image.LANCZOS)

        cell_x = margin + i * (cell_width + margin)
        x_pos = cell_x + (cell_width - img.width) // 2
        y_pos = row_top + (cell_height - img.height) // 2

        # Use the image itself as the alpha mask so transparency is kept.
        collage.paste(img, (x_pos, y_pos), img)

    collage.save(output_path, "PNG")
    print(f"Saved cluster {cluster_id} collage to {output_path}")

    return collage

def create_all_cluster_collages(cluster_info, output_dir, collage_size=(800, 400), max_cols=3):
    """Create one collage per cluster plus a combined grid of all of them.

    Args:
        cluster_info: Iterable of cluster dicts (``cluster_id``, ``size``,
            ``representatives``).
        output_dir: Directory receiving the PNG files; created if missing.
        collage_size: ``(width, height)`` of every per-cluster collage, also
            used as the cell size of the combined grid.
        max_cols: Maximum number of collages per row in the combined grid.

    Side effects:
        Saves ``cluster_<id>.png`` for each cluster that produced a collage
        and, if at least one did, ``all_clusters.png`` in *output_dir*.
    """
    os.makedirs(output_dir, exist_ok=True)

    collages = []
    for cluster_data in cluster_info:
        cluster_id = cluster_data['cluster_id']

        # Skip clusters with no representatives
        if not cluster_data['representatives']:
            continue

        output_path = os.path.join(output_dir, f"cluster_{cluster_id}.png")
        collage = create_cluster_collage(cluster_data, output_path, size=collage_size)
        if collage is not None:
            collages.append(collage)

    if not collages:
        return

    # Grid cells share the collage size, so collages tile without overlap.
    cell_width, cell_height = collage_size
    n_cols = max(1, min(max_cols, len(collages)))
    n_rows = -(-len(collages) // n_cols)  # ceiling division
    grid_img = Image.new('RGBA', (cell_width * n_cols, cell_height * n_rows),
                         (255, 255, 255, 0))

    for i, collage in enumerate(collages):
        row, col = divmod(i, n_cols)
        grid_img.paste(collage, (col * cell_width, row * cell_height), collage)

    grid_path = os.path.join(output_dir, "all_clusters.png")
    grid_img.save(grid_path, "PNG")
    print(f"Saved combined cluster image to {grid_path}")

def main(info_path='./visualization_files/cluster_info.json',
         output_dir='./visualization_files/cluster_images'):
    """Load the cluster JSON and write per-cluster and combined collages.

    Args:
        info_path: JSON file describing the clusters.
        output_dir: Directory that receives the generated PNG files.
    """
    print(f"Loading cluster information from {info_path}...")
    try:
        cluster_info = load_cluster_info(info_path)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError covers
        # json.JSONDecodeError and UnicodeDecodeError from a malformed file.
        print(f"Error loading cluster information: {e}")
        return

    print(f"Found {len(cluster_info)} clusters")

    print("Generating cluster collages...")
    create_all_cluster_collages(cluster_info, output_dir)

    print("Done! You can now combine the clustering visualization with the molecule collages.")


if __name__ == "__main__":
    main()

Loading cluster information from ./visualization_files/cluster_info.json...
Found 8 clusters
Generating cluster collages...
Saved cluster 1 collage to ./visualization_files/cluster_images\cluster_1.png
Saved cluster 2 collage to ./visualization_files/cluster_images\cluster_2.png
Saved cluster 3 collage to ./visualization_files/cluster_images\cluster_3.png
No valid molecule images for cluster 4
Saved cluster 5 collage to ./visualization_files/cluster_images\cluster_5.png
Saved cluster 6 collage to ./visualization_files/cluster_images\cluster_6.png
Saved cluster 7 collage to ./visualization_files/cluster_images\cluster_7.png
Saved cluster 8 collage to ./visualization_files/cluster_images\cluster_8.png
Saved combined cluster image to ./visualization_files/cluster_images\all_clusters.png
Done! You can now combine the clustering visualization with the molecule collages.
