In [1]:
import os
import json
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from PIL import Image, ImageDraw, ImageFont
import io
import numpy as np
from rdkit import RDLogger
import matplotlib.cm as cm
import matplotlib.pyplot as plt

# Silence all RDKit logging channels ('rdApp.*') so RDKit's own warning and
# error messages (e.g. from SMILES parsing) are not printed to the console.
RDLogger.DisableLog('rdApp.*')

def load_cluster_info(json_path):
    """Load cluster information from a JSON file.

    Args:
        json_path: Path to the JSON file to read.

    Returns:
        The decoded JSON document. The rest of this script treats it as a
        list of cluster dicts (with keys such as ``cluster_id``, ``size`` and
        ``representatives``).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid UTF-8 JSON
            (``json.JSONDecodeError`` or ``UnicodeDecodeError``).
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding
    # (e.g. cp1252 on Windows), which mis-decodes non-ASCII content.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def create_molecule_from_smiles(smiles):
    """Parse a SMILES string into an RDKit molecule with 2D coordinates.

    Returns the molecule, or ``None`` when RDKit cannot parse ``smiles``
    or raises while processing it (the error is printed, not re-raised).
    """
    try:
        molecule = Chem.MolFromSmiles(smiles)
        if molecule is not None:
            # Lay the atoms out in 2D so the molecule can be drawn
            AllChem.Compute2DCoords(molecule)
        return molecule
    except Exception as exc:
        print(f"Error creating molecule from SMILES: {exc}")
        return None

def generate_molecule_image(mol, size=(300, 250), highlight_atoms=None):
    """Render ``mol`` as an RGBA PIL image with a transparent background.

    Args:
        mol: RDKit molecule to draw; ``None`` yields ``None``.
        size: ``(width, height)`` of the image in pixels.
        highlight_atoms: Optional atom indices to highlight in light green.

    Returns:
        An ``RGBA`` image in which every near-white pixel (R, G and B all
        above 240) is fully transparent, or ``None`` if ``mol`` is ``None``
        or rendering fails (the error is printed, not re-raised).
    """
    if mol is None:
        return None

    try:
        # Use RDKit's MolToImage function
        img = Draw.MolToImage(
            mol,
            size=size,
            kekulize=True,
            fitImage=True,
            highlightAtoms=highlight_atoms,
            highlightColor=(0.7, 0.9, 0.7) if highlight_atoms else None
        )

        # Convert to RGBA if not already
        if img.mode != 'RGBA':
            img = img.convert('RGBA')

        # Make the (near-)white background transparent with one vectorised
        # numpy operation instead of a per-pixel Python loop.
        pixels = np.array(img)  # (height, width, 4) uint8 copy of the image
        near_white = (pixels[..., :3] > 240).all(axis=-1)
        pixels[near_white] = (255, 255, 255, 0)
        return Image.fromarray(pixels)
    except Exception as e:
        print(f"Error generating molecule image: {e}")
        return None

def create_property_colorbar(min_val, max_val, property_name, width=30, height=200):
    """Create a labelled vertical colorbar image for one property.

    From left to right the image contains: the colour gradient (``width`` x
    ``height`` px, high end of the colormap at the top), the ``max_val`` /
    ``min_val`` labels next to the top / bottom of the gradient, and
    ``property_name`` written vertically (reading bottom-to-top).

    Args:
        min_val: Numeric value labelled at the bottom of the bar.
        max_val: Numeric value labelled at the top of the bar.
        property_name: Text label for the property.
        width: Width of the gradient strip in pixels.
        height: Height of the gradient strip (and of the image) in pixels.

    Returns:
        An RGBA PIL image with a transparent background, ``height`` px tall
        and wide enough to hold the bar, the labels and the name.
    """
    # Try to load a font, use default if not available
    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except OSError:
        font = ImageFont.load_default()

    # One column running from 1 (top row) down to 0 (bottom row), repeated
    # across the bar, so the colormap's high end sits next to the max label.
    gradient = np.tile(np.linspace(1, 0, height).reshape(-1, 1), (1, width))

    # Colormap from light yellow (low end) to purple (high end)
    colors = [
        (1.0, 1.0, 0.0),      # light yellow
        (1.0, 0.7, 0.0),      # gold/amber
        (1.0, 0.4, 0.0),      # orange
        (1.0, 0.0, 0.0),      # red
        (0.8, 0.0, 0.2),      # dark red
        (0.5, 0.0, 0.5)       # purple
    ]
    cmap = plt.cm.colors.LinearSegmentedColormap.from_list("property_colormap", colors)

    # cmap() yields RGBA floats in [0, 1]; scale to an 8-bit RGBA image
    bar_img = Image.fromarray((cmap(gradient) * 255).astype(np.uint8))

    min_text = f"{min_val:.1f}"
    max_text = f"{max_val:.1f}"

    # Text drawn outside an image's bounds is clipped, so the labels need
    # their own column to the right of the bar, sized to the widest label.
    label_x = width + 2
    measure = ImageDraw.Draw(bar_img)
    try:
        text_width = max(measure.textlength(t, font=font) for t in (min_text, max_text))
    except AttributeError:
        # NOTE(review): older Pillow bitmap fonts may not support textlength;
        # fall back to a rough per-character estimate.
        text_width = 7 * max(len(min_text), len(max_text))
    label_width = int(text_width) + 2

    # Property name: draw it horizontally centred on a (height x 16) strip,
    # then rotate 90 degrees counter-clockwise into a (16 x height) strip.
    name_thickness = 16
    name_img = Image.new('RGBA', (height, name_thickness), (255, 255, 255, 0))
    ImageDraw.Draw(name_img).text(
        (height // 2, name_thickness // 2), property_name,
        fill=(0, 0, 0), font=font, anchor="mm",
    )
    name_img = name_img.rotate(90, expand=True)

    # Assemble: bar | labels | rotated name
    name_x = label_x + label_width
    final_img = Image.new(
        'RGBA', (name_x + name_img.width, max(height, name_img.height)), (255, 255, 255, 0)
    )
    final_img.paste(bar_img, (0, 0))

    draw = ImageDraw.Draw(final_img)
    draw.text((label_x, 0), max_text, fill=(0, 0, 0), font=font)
    draw.text((label_x, height - 12), min_text, fill=(0, 0, 0), font=font)

    final_img.paste(name_img, (name_x, (final_img.height - name_img.height) // 2), name_img)

    return final_img

def create_cluster_collage(cluster_data, output_path, size=(900, 500)):
    """Create and save a collage of one cluster's representative molecules.

    The collage (transparent RGBA background) shows a title with the cluster
    id and size, the representative molecules in one row with selected
    properties (MW, LogP, TPSA, HBA, HBD) printed under each, and colorbars on
    the right spanning the min/max of MW, LogP and TPSA across the
    representatives.

    Args:
        cluster_data: Dict with ``cluster_id``, ``size`` and
            ``representatives`` (dicts with a ``smiles`` string and an
            optional ``properties`` dict of numeric values).
        output_path: Where the PNG is written.
        size: ``(width, height)`` of the collage in pixels.

    Returns:
        The collage image, or ``None`` if the cluster has no representatives
        or none of them could be rendered.
    """
    cluster_id = cluster_data['cluster_id']
    representatives = cluster_data['representatives']

    if not representatives:
        print(f"No representatives for cluster {cluster_id}")
        return None

    # Parse every SMILES first so the per-molecule cell size (which depends on
    # how many molecules there are) is known before rendering.
    parsed = []
    for rep in representatives:
        smiles = rep.get('smiles')
        if not smiles:
            continue
        mol = create_molecule_from_smiles(smiles)
        if mol is not None:
            # A missing or null 'properties' entry means "no properties"
            parsed.append((mol, rep.get('properties') or {}))

    if not parsed:
        print(f"No valid molecule images for cluster {cluster_id}")
        return None

    # Width kept free on the right for the colorbars; height kept free for
    # the headers above and the property values below the molecules.
    right_reserve = 100
    mol_width = min(250, (size[0] - right_reserve) // len(parsed))
    mol_height = min(200, size[1] - 150)

    # Render at the final cell size rather than at 250x200 followed by a
    # resize, which squashed the drawings whenever cells were narrower.
    mol_images = []
    properties_list = []
    for mol, props in parsed:
        img = generate_molecule_image(mol, size=(mol_width, mol_height))
        if img is not None:
            mol_images.append(img)
            properties_list.append(props)

    if not mol_images:
        print(f"No valid molecule images for cluster {cluster_id}")
        return None

    collage = Image.new('RGBA', size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(collage)

    # Try to load a font, use default if not available
    try:
        title_font = ImageFont.truetype("arial.ttf", 36)
        info_font = ImageFont.truetype("arial.ttf", 20)
        prop_font = ImageFont.truetype("arial.ttf", 14)
    except OSError:
        title_font = info_font = prop_font = ImageFont.load_default()

    # Cluster title and size
    draw.text((size[0] // 2, 30), f"Cluster {cluster_id}",
              fill=(0, 0, 0, 255), font=title_font, anchor="mm")
    draw.text((size[0] // 2, 70), f"{cluster_data['size']} molecules",
              fill=(0, 0, 0, 200), font=info_font, anchor="mm")

    # Spread the molecules evenly across the non-reserved width
    n_mols = len(mol_images)
    margin = (size[0] - right_reserve - n_mols * mol_width) // (n_mols + 1)

    # Per-property (min, max) across the representatives; non-numeric values
    # (e.g. null) are ignored rather than crashing min()/max().
    property_ranges = {}
    for props in properties_list:
        for prop_name, value in props.items():
            if not _is_numeric(value):
                continue
            low, high = property_ranges.get(prop_name, (value, value))
            property_ranges[prop_name] = (min(low, value), max(high, value))

    displayed_props = {'MW', 'LogP', 'TPSA', 'HBA', 'HBD'}
    integer_props = {'HBA', 'HBD'}
    y_pos = 100  # Below the title

    for i, (img, props) in enumerate(zip(mol_images, properties_list)):
        # Fallback in case the renderer returned a different size
        if img.size != (mol_width, mol_height):
            img = img.resize((mol_width, mol_height), Image.LANCZOS)

        x_pos = margin + i * (mol_width + margin)
        collage.paste(img, (x_pos, y_pos), img)

        # Selected properties, one per line, centred under the molecule
        y_offset = y_pos + mol_height + 10
        for prop_name, value in sorted(props.items()):
            if prop_name not in displayed_props or not _is_numeric(value):
                continue
            if prop_name in integer_props:
                text = f"{prop_name}: {int(value)}"
            else:
                text = f"{prop_name}: {value:.1f}"
            draw.text((x_pos + mol_width // 2, y_offset), text,
                      fill=(0, 0, 0, 200), font=prop_font, anchor="mm")
            y_offset += 20

    # Colorbars for the key properties, stacked on the right. Their height is
    # reduced so every bar fits above the bottom edge (three 150 px bars with
    # 30 px gaps starting at y=120 would overrun a 500 px canvas).
    shown = [p for p in ('MW', 'LogP', 'TPSA') if p in property_ranges]
    if shown:
        bar_top, bar_gap, bottom_pad = 120, 30, 10
        free_height = size[1] - bar_top - bottom_pad - bar_gap * (len(shown) - 1)
        colorbar_height = min(150, free_height // len(shown))
        if colorbar_height > 0:
            colorbars = []
            for prop in shown:
                low, high = property_ranges[prop]
                colorbars.append(create_property_colorbar(low, high, prop, height=colorbar_height))
            # One column, positioned so the widest colorbar ends 5 px from the edge
            x_bar = max(0, size[0] - max(cb.width for cb in colorbars) - 5)
            for n, colorbar in enumerate(colorbars):
                y_bar = bar_top + n * (colorbar_height + bar_gap)
                collage.paste(colorbar, (x_bar, y_bar), colorbar)

    # Save collage
    collage.save(output_path, "PNG")
    print(f"Saved cluster {cluster_id} collage to {output_path}")

    return collage


def _is_numeric(value):
    """Return True for int/float property values (bools and None excluded)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)

def create_all_cluster_collages(cluster_info, output_dir):
    """Create one collage per cluster plus a combined grid image.

    Each cluster with representatives is rendered by
    ``create_cluster_collage`` to ``<output_dir>/cluster_<id>.png``. Every
    collage that was produced is then tiled, three per row, into
    ``<output_dir>/all_clusters.png``.

    Args:
        cluster_info: Iterable of cluster dicts, each with a ``cluster_id``
            and (optionally) a ``representatives`` list.
        output_dir: Output directory; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    collages = []
    for cluster_data in cluster_info:
        # Skip clusters with no (or missing) representatives
        if not cluster_data.get('representatives'):
            continue

        cluster_id = cluster_data['cluster_id']
        output_path = os.path.join(output_dir, f"cluster_{cluster_id}.png")
        collage = create_cluster_collage(cluster_data, output_path)
        if collage is not None:
            collages.append(collage)

    if not collages:
        return

    # Grid: at most three columns, as many rows as needed
    n_cols = min(3, len(collages))
    n_rows = (len(collages) + n_cols - 1) // n_cols  # Ceiling division

    # Size cells from the collages themselves instead of assuming 900x500
    cell_width = max(c.width for c in collages)
    cell_height = max(c.height for c in collages)
    grid_img = Image.new('RGBA', (cell_width * n_cols, cell_height * n_rows), (255, 255, 255, 0))

    for index, collage in enumerate(collages):
        row, col = divmod(index, n_cols)
        grid_img.paste(collage, (col * cell_width, row * cell_height), collage)

    grid_path = os.path.join(output_dir, "all_clusters.png")
    grid_img.save(grid_path, "PNG")
    print(f"Saved combined cluster image to {grid_path}")

def main(info_path='./visualization_files/cluster_info.json',
         output_dir='./visualization_files/cluster_images'):
    """Generate the cluster collages from a cluster-info JSON file.

    Args:
        info_path: JSON file read with ``load_cluster_info``.
        output_dir: Directory that receives the per-cluster and combined PNGs.
    """
    print(f"Loading cluster information from {info_path}...")
    try:
        cluster_info = load_cluster_info(info_path)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError (malformed content).
        print(f"Error loading cluster information: {e}")
        return

    print(f"Found {len(cluster_info)} clusters")

    # Create cluster collages
    print("Generating cluster collages with property information...")
    create_all_cluster_collages(cluster_info, output_dir)

    print("Done! You can now combine the clustering visualization with the molecule collages.")

if __name__ == "__main__":
    # Run the pipeline when executed directly (this is also true in a notebook
    # cell); importing the module only defines the functions above.
    main()

Loading cluster information from ./visualization_files/cluster_info.json...
Found 6 clusters
Generating cluster collages with property information...
Saved cluster 1 collage to ./visualization_files/cluster_images\cluster_1.png
Saved cluster 2 collage to ./visualization_files/cluster_images\cluster_2.png
Saved cluster 3 collage to ./visualization_files/cluster_images\cluster_3.png
Saved cluster 4 collage to ./visualization_files/cluster_images\cluster_4.png
Saved cluster 5 collage to ./visualization_files/cluster_images\cluster_5.png
No valid molecule images for cluster 6
Saved combined cluster image to ./visualization_files/cluster_images\all_clusters.png
Done! You can now combine the clustering visualization with the molecule collages.
