In [1]:
import os
import json
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from PIL import Image, ImageDraw, ImageFont
import io
import numpy as np
from rdkit import RDLogger

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.*')

def load_molecule_info(json_path):
    """Load molecule information from JSON file"""
    with open(json_path, 'r') as f:
        molecule_info = json.load(f)
    
    return molecule_info

def create_molecule_from_smiles(smiles):
    """Create RDKit molecule from SMILES string"""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
            
        # Calculate 2D coordinates
        AllChem.Compute2DCoords(mol)
        return mol
    except Exception as e:
        print(f"Error creating molecule from SMILES: {e}")
        return None

def generate_molecule_image(mol, index, size=(600, 500), highlight_atoms=None):
    """Generate a high-quality image of the molecule with index number"""
    if mol is None:
        return None
        
    try:
        # Draw molecule with transparent background
        drawer = Draw.rdMolDraw2DCairo(size[0], size[1])
        drawer.drawOptions().clearBackground = False
        drawer.drawOptions().useBWAtomPalette = False  # Use color atoms
        drawer.drawOptions().additionalAtomLabelPadding = 0.15  # Add padding to labels
        drawer.drawOptions().bondLineWidth = 2.0  # Thicker bonds
        
        # Draw molecule
        if highlight_atoms:
            drawer.DrawMolecule(
                mol, 
                highlightAtoms=highlight_atoms,
                highlightAtomColors={i: (0.7, 0.9, 0.7) for i in highlight_atoms}
            )
        else:
            drawer.DrawMolecule(mol)
            
        drawer.FinishDrawing()
        png_data = drawer.GetDrawingText()
        
        # Convert to PIL Image
        img = Image.open(io.BytesIO(png_data))
        
        # Convert to RGBA if not already
        if img.mode != 'RGBA':
            img = img.convert('RGBA')
        
        # Make white background transparent
        datas = img.getdata()
        newData = []
        for item in datas:
            if item[0] > 240 and item[1] > 240 and item[2] > 240:
                newData.append((255, 255, 255, 0))
            else:
                newData.append(item)
        
        img.putdata(newData)
        
        # Add index number in a circle
        draw = ImageDraw.Draw(img)
        
        # Try to load a font, use default if not available
        try:
            font = ImageFont.truetype("arial.ttf", 40)
        except:
            font = ImageFont.load_default()
        
        # Draw circle with number
        circle_radius = 30
        circle_position = (50, 50)  # Top-left corner
        
        # Draw white filled circle
        draw.ellipse(
            [
                circle_position[0] - circle_radius, 
                circle_position[1] - circle_radius,
                circle_position[0] + circle_radius, 
                circle_position[1] + circle_radius
            ], 
            fill=(255, 255, 255, 255), 
            outline=(0, 0, 0, 255),
            width=2
        )
        
        # Draw index number
        text_width = draw.textlength(str(index), font=font)
        text_position = (
            circle_position[0] - text_width/2, 
            circle_position[1] - 20
        )
        draw.text(text_position, str(index), fill=(0, 0, 0, 255), font=font)
        
        return img
        
    except Exception as e:
        print(f"Error generating molecule image: {e}")
        return None

def generate_all_molecule_images(molecule_info, output_dir):
    """Generate images for all molecules in the info list"""
    os.makedirs(output_dir, exist_ok=True)
    
    # Create a composite image with all molecules
    composite_width = 1500
    composite_height = 1200
    composite = Image.new('RGBA', (composite_width, composite_height), (255, 255, 255, 0))
    
    # Size for individual molecules in the composite
    mol_width = composite_width // 4
    mol_height = composite_height // 3
    
    # Track successful molecules
    successful_molecules = []
    
    for mol_data in molecule_info:
        index = mol_data["index"]
        smiles = mol_data["smiles"]
        
        # Create RDKit molecule
        mol = create_molecule_from_smiles(smiles)
        if mol is None:
            print(f"Skipping molecule {index}: Invalid SMILES")
            continue
        
        # Generate individual image
        img = generate_molecule_image(mol, index, size=(800, 600))
        if img is None:
            print(f"Skipping molecule {index}: Image generation failed")
            continue
        
        # Save individual image
        img_path = os.path.join(output_dir, f"molecule_{index}.png")
        img.save(img_path, "PNG")
        print(f"Saved molecule {index} to {img_path}")
        
        # Add to composite
        # Calculate position in grid (3 rows x 4 columns)
        row = (index - 1) // 4
        col = (index - 1) % 4
        
        # Resize for composite
        small_img = img.resize((mol_width, mol_height), Image.LANCZOS)
        
        # Calculate position in composite
        x_pos = col * mol_width
        y_pos = row * mol_height
        
        # Paste into composite
        composite.paste(small_img, (x_pos, y_pos), small_img)
        
        # Remember this molecule was successful
        successful_molecules.append(mol_data)
    
    # Save composite image
    composite_path = os.path.join(output_dir, "all_molecules.png")
    composite.save(composite_path, "PNG")
    print(f"Saved composite image to {composite_path}")
    
    # Save updated molecule info (only successful ones)
    updated_info_path = os.path.join(output_dir, "successful_molecules.json")
    with open(updated_info_path, 'w') as f:
        json.dump(successful_molecules, f, indent=2)
    
    return successful_molecules

def main():
    # Load molecule information from Step 1
    info_path = './visualization_files/selected_molecules.json'
    output_dir = './visualization_files/molecule_images'
    
    print(f"Loading molecule information from {info_path}...")
    try:
        molecule_info = load_molecule_info(info_path)
    except Exception as e:
        print(f"Error loading molecule information: {e}")
        return
    
    print(f"Found {len(molecule_info)} molecules to visualize")
    
    # Generate images for all molecules
    print("Generating molecule images...")
    successful_molecules = generate_all_molecule_images(molecule_info, output_dir)
    
    print(f"Successfully generated images for {len(successful_molecules)} molecules")
    print(f"Images saved to {output_dir}")
    print(f"You can now combine the embedding plot with these molecule images")

if __name__ == "__main__":
    main()

Loading molecule information from ./visualization_files/selected_molecules.json...
Found 12 molecules to visualize
Generating molecule images...
Skipping molecule 1: Invalid SMILES
Skipping molecule 2: Invalid SMILES
Error generating molecule image: module 'rdkit.Chem.Draw' has no attribute 'rdMolDraw2DCairo'
Skipping molecule 3: Image generation failed
Error generating molecule image: module 'rdkit.Chem.Draw' has no attribute 'rdMolDraw2DCairo'
Skipping molecule 4: Image generation failed
Skipping molecule 5: Invalid SMILES
Skipping molecule 6: Invalid SMILES
Error generating molecule image: module 'rdkit.Chem.Draw' has no attribute 'rdMolDraw2DCairo'
Skipping molecule 7: Image generation failed
Error generating molecule image: module 'rdkit.Chem.Draw' has no attribute 'rdMolDraw2DCairo'
Skipping molecule 8: Image generation failed
Error generating molecule image: module 'rdkit.Chem.Draw' has no attribute 'rdMolDraw2DCairo'
Skipping molecule 9: Image generation failed
Skipping molecul