# Molecular Disconnection Analysis Tool

This notebook provides a comprehensive workflow for analyzing and visualizing potential disconnection sites in a target molecule.

## Getting Started

1. **Configure your molecule**:
   - Set `working_folder` to the 'paper_code' path
   - Specify `molecule_name` for file naming 
   - Provide `selected_smiles` as the target molecule with atom mapping
   - Optionally switch the prompt dataset (`selected_dataset`) from `paroutes` to `uspto50k`. Precalculated example responses are only provided for `paroutes`, so with `uspto50k` you will need to run the LLM inference yourself.

2. **Run the prompt generation**:
   - The tool populates a template with your molecule's SMILES
   - The populated prompt is saved to `<molecule_name>/position_model/populated_prompt.md`
   - Submit this prompt to your LLM of choice

3. **Process the results**:
   - Save the LLM's JSON response to `<molecule_name>/position_model/response.json`
   - The tool will generate:
     - Visual report with highlighted disconnection sites
     - CSV summary of all disconnections
     - LaTeX table for publication
     - Full markdown report convertible to PDF

4. **Generate PDF report**:
   - Use the provided pandoc command to convert the markdown report to PDF
   - Visualizations color-code disconnections by priority (red = highest priority)

This workflow enables rapid, visual evaluation of synthetically-relevant disconnection points in complex molecules.

In [None]:
# imports

from rdkit import Chem
from rdkit.Chem import Draw, AllChem
import seaborn as sns
from typing import List, Dict
import os
import json
from typing import List, Dict, Optional, Tuple
import pandas as pd
import os
import re

In [None]:
# General Settings
# molecule_name: names the per-molecule output folder and prefixes the report files.
molecule_name = "LEI_515"
# selected_smiles: target molecule as SMILES. The atom-map numbers written here are
# replaced by RDKit atom indices in a later cell, before the prompt is populated.
selected_smiles = "C[CH2:1][C:2]([C:3](=[O:4])[CH2:5][S:6](=[O:7])[c:8]1[cH:9][cH:10][c:11]([C:12](=[O:13])[N:14]2[CH2:15][CH2:16][N:17]([c:18]3[cH:19][c:20]([Cl:21])[cH:22][cH:23][cH:24]3)[C@@H:25]([CH3:26])[C@@H:27]2[CH3:28])[cH:29][c:30]1[Cl:31])([F:32])[F:33]"

# other settings
# FIGURE_SIZE: two-element size handed to the SVG drawer for every rendered image.
FIGURE_SIZE = (600, 400)
# selected_dataset: which prompt template variant to use.
selected_dataset = "paroutes" # pick either "paroutes" or "uspto50k"

In [None]:
# The project root is taken to be the parent of the current working directory.
current_directory = os.getcwd()
working_folder = os.path.abspath(os.path.join(current_directory, '..'))
print(f"Working Directory: {working_folder}")

# set paths
# NOTE(review): every name assigned in this cell is assigned again by the
# following cell, which builds the same locations with os.path.join.
example_folder = working_folder + "/examples/"

# Choose the prompt template that belongs to the configured dataset and reject
# any other value early.
if selected_dataset == "paroutes":
    position_prompt_template = working_folder + "/prompts/position_model_paroutes.md"
elif selected_dataset == "uspto50k":
    position_prompt_template = working_folder + "/prompts/position_model_uspto50k.md"
else:
    raise ValueError("selected_dataset must be either 'paroutes' or 'uspto50k'")

# Per-molecule output folders (one sub-folder per model).
save_path = example_folder + molecule_name
position_model_path = save_path + "/position_model/"
transition_model_path = save_path + "/transition_model/"

# report css file path
REPORT_CSS_FILE_PATH = working_folder + "/data/report.css"

In [None]:
# The project root is taken to be the parent of the current working directory.
current_directory = os.getcwd()
working_folder = os.path.abspath(os.path.join(current_directory, '..'))
print(f"Working Directory: {working_folder}")

# set paths using os.path.join for cross-platform compatibility
# This creates the path strings you need.
example_folder = os.path.join(working_folder, "examples")

# The template must follow `selected_dataset`; hard-coding one template here
# would silently ignore the dataset chosen in the settings cell.
if selected_dataset not in ("paroutes", "uspto50k"):
    raise ValueError("selected_dataset must be either 'paroutes' or 'uspto50k'")
position_prompt_template = os.path.join(
    working_folder, "prompts", f"position_model_{selected_dataset}.md"
)

# Per-molecule output folders (one sub-folder per model).
save_path = os.path.join(example_folder, molecule_name)
position_model_path = os.path.join(save_path, "position_model")
transition_model_path = os.path.join(save_path, "transition_model")

# report css file path
REPORT_CSS_FILE_PATH = os.path.join(working_folder, "data", "report.css")

# print paths
print(f"\nSave Path: {save_path}")
print(f"Position Model Path: {position_model_path}")
print(f"CSS Path: {REPORT_CSS_FILE_PATH}")

In [None]:
# Create the whole output folder tree up front; folders that already exist are left alone.
import os
for _folder in (save_path, position_model_path, transition_model_path):
    os.makedirs(_folder, exist_ok=True)
print(f"Saving to {save_path}")

In [None]:
# Add atom mapping by index
# Every atom is relabelled with its RDKit index so that map numbers in the
# LLM's disconnection strings (e.g. "C:6-C:7") can be resolved back to atoms.
selected_mol = Chem.MolFromSmiles(selected_smiles)
if selected_mol is None:
    # MolFromSmiles reports a parse failure by returning None rather than raising;
    # fail here with a clear message instead of an AttributeError below.
    raise ValueError(f"Could not parse selected_smiles: {selected_smiles!r}")

for atom in selected_mol.GetAtoms():
    # NOTE: the atom with index 0 receives map number 0, which RDKit uses for
    # "no map number", so that atom is written without a map label.
    atom.SetAtomMapNum(atom.GetIdx())

selected_mol_smiles = Chem.MolToSmiles(selected_mol, canonical=True)

# save molecules smiles to save_path/selected_mol.smiles
with open(os.path.join(save_path, "selected_mol.smiles"), "w", encoding="utf-8") as f:
    f.write(selected_mol_smiles)

# Last expression of the cell: displays the mapped SMILES in the notebook.
selected_mol_smiles

In [None]:
def draw_molecule_with_highlights(
    mol: Chem.Mol,
    highlight_atoms: Optional[List[int]] = None,
    highlight_bonds: Optional[List[int]] = None,
    atom_colors: Optional[Dict] = None,
    bond_colors: Optional[Dict] = None,
    filepath: str = 'molecule.svg',
    size: Tuple[int, int] = (600, 400),
):
    """Render ``mol`` as an SVG with optional highlights and write it to ``filepath``.

    Args:
        mol: Molecule to draw.
        highlight_atoms: Atom indices to highlight; nothing is highlighted when omitted.
        highlight_bonds: Bond indices to highlight; nothing is highlighted when omitted.
        atom_colors: Mapping from atom index to highlight colour.
        bond_colors: Mapping from bond index to highlight colour.
        filepath: Destination of the SVG file.
        size: Two-element canvas size forwarded to the SVG drawer.
    """
    canvas = Draw.rdMolDraw2D.MolDraw2DSVG(size[0], size[1])
    options = canvas.drawOptions()
    options.bondLineWidth = 2
    options.useBWAtomPalette()

    # RDKit must not receive None for the highlight arguments, so every falsy
    # value is swapped for an empty container of the right kind.
    canvas.DrawMolecule(
        mol,
        highlightAtoms=highlight_atoms or [],
        highlightBonds=highlight_bonds or [],
        highlightAtomColors=atom_colors or {},
        highlightBondColors=bond_colors or {},
    )
    canvas.FinishDrawing()

    # Persist the finished drawing
    with open(filepath, 'w') as svg_file:
        svg_file.write(canvas.GetDrawingText())
    print(f"Saved image to: {filepath}")

# Render the target molecule without any highlights into the molecule's output folder.
molecule_save_path = os.path.join(save_path, "selected_mol.svg")

draw_molecule_with_highlights(
    mol=selected_mol,
    filepath=molecule_save_path,
    size=FIGURE_SIZE
)


In [None]:
def populate_and_save_template(template_path: str, product_smiles: str, output_save_path: str) -> Optional[str]:
    """
    Loads a template file, replaces the product SMILES placeholder, and saves the result.

    The populated text is written to ``populated_prompt.md`` inside ``output_save_path``
    (the directory is created if needed). Both files are handled as UTF-8 so that
    non-ASCII characters in the template survive regardless of the platform default.

    Args:
        template_path (str): The full path to the template file (e.g., prompts/position_model.md).
        product_smiles (str): The SMILES string to insert into the template.
        output_save_path (str): The directory path where the populated file will be saved.

    Returns:
        The populated template text, or None when the template file does not exist
        (an error message is printed in that case and nothing is written).
    """
    placeholder = '<canonicalized_product>'

    try:
        # Read the template file content
        with open(template_path, 'r', encoding='utf-8') as f:
            template_content = f.read()
    except FileNotFoundError:
        print(f"Error: Template file not found at '{template_path}'")
        return None

    # A template without the placeholder would yield a prompt that never mentions
    # the molecule; surface that instead of silently writing it.
    if placeholder not in template_content:
        print(f"Warning: placeholder '{placeholder}' not found in '{template_path}'; "
              "the saved prompt will not contain the SMILES")

    # Replace the placeholder with the provided SMILES string
    populated_content = template_content.replace(placeholder, product_smiles)

    # Ensure the output directory exists
    os.makedirs(output_save_path, exist_ok=True)

    # Define the output filename and full path
    output_filename = "populated_prompt.md"
    output_filepath = os.path.join(output_save_path, output_filename)

    # Write the populated content to the new file
    with open(output_filepath, 'w', encoding='utf-8') as f:
        f.write(populated_content)

    print(f"Populated template saved to: {output_filepath}")
    return populated_content

# --- Build the prompt ---
# Fill the selected prompt template with the index-mapped SMILES and store the
# result in the position-model folder.

populated_content = populate_and_save_template(
     template_path=position_prompt_template,
     product_smiles=selected_mol_smiles,
     output_save_path=position_model_path 
)

Two options:
1) Take the populated template and run it through your LLM via OpenRouter, vLLM, or aistudio.google.com (it's free). Save the resulting JSON as `<molecule_name>/position_model/response.json`.
2) Copy an example `response.json` from the `results/<molecule>/position_model` folder.

In [None]:
# NOTE(review): `break` outside a loop is a SyntaxError, so this cell always fails.
# Presumably a deliberate stop for "Run All": the cells below read response.json,
# which has to be put in place by hand first.
break

In [None]:
# Load the LLM's JSON answer from the position-model folder.
response_path = os.path.join(position_model_path, "response.json")

# JSON text is UTF-8; reading with the platform default encoding can garble or
# reject non-ASCII characters in the free-text fields.
with open(response_path, 'r', encoding='utf-8') as f:
    important_reactions = json.load(f)
# Last expression of the cell: displays the parsed response in the notebook.
important_reactions

In [None]:
def flatten_important_reactions(important_reactions):
    """Flatten the nested response into one record per (disconnection, reaction) pair.

    Each record carries the parent's ``disconnection`` string plus the reaction's
    own fields listed below. A missing key raises ``KeyError``.

    Returns:
        ``{"disconnections": [record, ...]}`` in the order the reactions appear.
    """
    reaction_fields = (
        "forwardReaction",
        "isInOntology",
        "forwardReactionClass",
        "Retrosynthesis Importance",
        "Priority",
        "rationale",
    )

    records = []
    for site in important_reactions["disconnections"]:
        records.extend(
            {
                "disconnection": site["disconnection"],
                **{field: rxn[field] for field in reaction_fields},
            }
            for rxn in site["reactions"]
        )
    return {"disconnections": records}
# Flatten the loaded response once; the report and table cells below consume this shape.
flattened_reactions = flatten_important_reactions(important_reactions)
# Last expression of the cell: displays the flattened records in the notebook.
flattened_reactions

In [None]:

def parse_atoms_from_disconnection(disconnection_str: str) -> List[int]:
    """Extract the atom-map numbers from a disconnection string like 'C:1-C:2'.

    Every ``:<digits>`` occurrence counts as one atom-map number, so the result
    does not depend on the separator or decoration around the atoms:
    'C:6-C:7', 'C:11 N:12', '[C:6]=[C:7]' and 'C:3,N:14' are all understood.

    Args:
        disconnection_str: Disconnection description; empty or None yields no atoms.

    Returns:
        The map numbers in order of appearance (empty when none are found).
    """
    if not disconnection_str:
        return []
    return [int(number) for number in re.findall(r':(\d+)', disconnection_str)]

def get_atom_indices_from_map_nums(mol: Chem.Mol, map_nums: List[int]) -> List[int]:
    """Return the indices of the atoms in ``mol`` whose atom-map number occurs in ``map_nums``.

    The indices come back in the molecule's atom iteration order.
    """
    return [
        atom.GetIdx()
        for atom in mol.GetAtoms()
        if atom.GetAtomMapNum() in map_nums
    ]

def _bond_indices_between(mol: Chem.Mol, atom_indices: List[int]) -> List[int]:
    """Return the indices of the bonds in ``mol`` whose two end atoms are both in ``atom_indices``."""
    wanted = set(atom_indices)
    return [
        bond.GetIdx()
        for bond in mol.GetBonds()
        if bond.GetBeginAtomIdx() in wanted and bond.GetEndAtomIdx() in wanted
    ]


def create_position_report(
    product_mol: Chem.Mol, 
    reactions_data: dict, 
    position_model_path: str, 
    molecule_name: str
):
    """
    Generates a full visual and textual report for disconnection position analysis.

    Writes, below ``<position_model_path>/report``:
      * ``images/product_all_positions.svg`` - all disconnection sites in one picture,
      * one ``images/position_priority_<priority>[_<n>].svg`` per entry,
      * ``<molecule_name>_position_analysis_report.md`` - the markdown report.
    Finally prints the pandoc command that converts the report to PDF.

    Args:
        product_mol: Atom-mapped target molecule.
        reactions_data: Dict whose 'disconnections' entry is a list of flat records
            (keys read here: 'disconnection', 'Priority', 'forwardReaction',
            'Retrosynthesis Importance', 'isInOntology', 'rationale').
        position_model_path: Folder that receives the ``report`` sub-folder.
        molecule_name: Used in the report title and the report file name.

    Relies on the module-level ``FIGURE_SIZE`` and ``REPORT_CSS_FILE_PATH`` settings.
    """
    # --- Setup Paths ---
    report_path = os.path.join(position_model_path, "report")
    image_dir = os.path.join(report_path, "images")
    os.makedirs(image_dir, exist_ok=True)

    md_content = [f"# Disconnection Position Analysis for **{molecule_name}**\n"]

    def sort_priority(entry):
        # One lookup rule everywhere: entries without a priority sort (and are coloured) as 99.
        return entry.get('Priority', 99)

    entries = reactions_data.get('disconnections', [])

    # --- Prepare Data and Colors ---
    # Sort disconnections by priority in REVERSE order so highest priority (lowest number) is processed last
    # and therefore wins the colour of atoms/bonds shared between several sites.
    overview_order = sorted(entries, key=sort_priority, reverse=True)

    # --- Assign a unique color to each UNIQUE priority to ensure consistency ---
    unique_priorities = sorted({sort_priority(d) for d in overview_order})
    # Generate palette and REVERSE it - coolwarm by default goes blue->red, we want red->blue,
    # so the lowest priority number (highest importance) gets the red end.
    palette = sns.color_palette("coolwarm", len(unique_priorities))[::-1]
    priority_to_color = dict(zip(unique_priorities, palette))

    all_highlight_atoms = []
    all_highlight_bonds = []
    all_atom_colors = {}
    all_bond_colors = {}

    # --- Generate Overall Highlight Image ---
    for d in overview_order:
        map_nums = parse_atoms_from_disconnection(d.get('disconnection', ''))
        atom_indices = get_atom_indices_from_map_nums(product_mol, map_nums)
        color = priority_to_color.get(sort_priority(d), (0, 0, 0))

        all_highlight_atoms.extend(atom_indices)
        for idx in atom_indices:
            all_atom_colors[idx] = color

        # Highlight the bonds that connect the highlighted atoms
        for bond_idx in _bond_indices_between(product_mol, atom_indices):
            all_highlight_bonds.append(bond_idx)
            all_bond_colors[bond_idx] = color

    # Draw and save the main image with all highlights
    main_img_path = os.path.join(image_dir, "product_all_positions.svg")
    draw_molecule_with_highlights(
        product_mol, all_highlight_atoms, all_highlight_bonds, 
        all_atom_colors, all_bond_colors, main_img_path, size=FIGURE_SIZE
    )

    md_content.append("#### Overview of All Predicted Disconnection Sites\n")
    md_content.append(f"![All Positions](./images/product_all_positions.svg)\n")
    md_content.append("---\n")

    # Add page break for PDF output - both HTML and LaTeX style for compatibility with different renderers
    md_content.append("\n<div style=\"page-break-after: always;\"></div>\n")
    md_content.append("\n\\pagebreak\n")

    # --- Generate Individual Reports for Each Position ---
    # For the report, show highest priority (lowest number) first
    report_order = sorted(entries, key=sort_priority)

    # Several entries may share one priority. Count them so that each gets its own
    # image file instead of overwriting the previous one; the first entry of a
    # priority keeps the plain "position_priority_<priority>.svg" name.
    images_per_priority = {}

    for d in report_order:
        priority = d.get('Priority', 'N/A')
        disconnection_str = d.get('disconnection', '')

        md_content.append(f"##### Priority {priority}: `{disconnection_str}`\n")

        # --- Generate Individual Highlight Image ---
        map_nums = parse_atoms_from_disconnection(disconnection_str)
        atom_indices = get_atom_indices_from_map_nums(product_mol, map_nums)
        # Same colour as in the overview image
        color = priority_to_color.get(sort_priority(d), (0, 0, 0))

        # Fresh colour dictionaries for this specific position
        position_atom_colors = {idx: color for idx in atom_indices}
        bond_indices = _bond_indices_between(product_mol, atom_indices)
        position_bond_colors = {idx: color for idx in bond_indices}

        occurrence = images_per_priority.get(priority, 0) + 1
        images_per_priority[priority] = occurrence
        suffix = "" if occurrence == 1 else f"_{occurrence}"
        img_filename = f"position_priority_{priority}{suffix}.svg"
        img_path = os.path.join(image_dir, img_filename)
        draw_molecule_with_highlights(
            product_mol, atom_indices, bond_indices, position_atom_colors, position_bond_colors, img_path, size=FIGURE_SIZE
        )

        md_content.append(f"![Position for Priority {priority}](./images/{img_filename})\n")

        # --- Add Details to Markdown ---
        md_content.append(f"- **Forward Reaction:** {d.get('forwardReaction', 'N/A')}")
        md_content.append(f"- **Importance Score:** `{d.get('Retrosynthesis Importance', 'N/A')}`")
        md_content.append(f"- **In Ontology:** {d.get('isInOntology', 'N/A')}")
        md_content.append(f"- **Rationale:** {d.get('rationale', 'No rationale provided.')}\n")
        md_content.append("---\n")

    # --- Write Final Report ---
    report_filepath = os.path.join(report_path, f"{molecule_name}_position_analysis_report.md")

    # Free-text fields may contain non-ASCII characters; write UTF-8 explicitly
    # instead of depending on the platform default encoding.
    with open(report_filepath, 'w', encoding='utf-8') as f:
        f.write("\n".join(md_content))

    print(f"Full report generated at: {report_filepath}")
    print("To generate a PDF, run the following command in your terminal:")
    print(f"cd {report_path}")
    print(f"pandoc {molecule_name}_position_analysis_report.md -s --css={REPORT_CSS_FILE_PATH} --pdf-engine=weasyprint -o {molecule_name}_position_analysis_report.pdf")

# --- Execute the report generation ---
# Uses the index-mapped molecule and the flattened LLM response; output goes to
# the "report" sub-folder of the position-model folder.
create_position_report(
    product_mol=selected_mol,
    reactions_data=flattened_reactions,
    position_model_path=position_model_path,
    molecule_name=molecule_name
)

In [None]:

# Replacement text for every character that has a special meaning in LaTeX text.
_LATEX_ESCAPES = str.maketrans({
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
})


def escape_latex(text: str) -> str:
    """
    Escapes all special LaTeX characters in a given string.

    Non-string input is converted with ``str()`` first. The input is scanned
    once, so characters introduced by a replacement (e.g. the braces of
    ``\\textbackslash{}``) are not escaped a second time.
    """
    as_text = text if isinstance(text, str) else str(text)
    return as_text.translate(_LATEX_ESCAPES)

def save_reactions_to_csv_and_latex(reactions_data, save_path, molecule_name):
    """
    Save flattened reaction data to CSV and create a LaTeX table template.

    Both files are written to ``<save_path>/position_model/report/``:
    ``<molecule_name>_position_analysis.csv`` and
    ``<molecule_name>_position_analysis_table.tex``. Rows are ordered by
    ascending 'Priority' (entries without one sort as 99).

    Args:
        reactions_data: Dictionary containing flattened reaction data
        save_path: Path to save the output files
        molecule_name: Name of the molecule for file naming
    """
    # Create the report directory if it doesn't exist
    report_dir = os.path.join(save_path, "position_model", "report")
    os.makedirs(report_dir, exist_ok=True)

    # Sort disconnections by priority for better readability
    disconnections = sorted(reactions_data.get('disconnections', []), 
                           key=lambda x: x.get('Priority', 99))

    def rationale_text(entry):
        # A JSON null arrives as None; treat it like a missing rationale so that
        # neither .replace() below fails nor the table shows the word "None".
        value = entry.get('rationale', '')
        return '' if value is None else str(value)

    # Extract relevant fields for CSV
    csv_data = [
        {
            'Priority': d.get('Priority', ''),
            'Disconnection': d.get('disconnection', ''),
            'Forward_Reaction': d.get('forwardReaction', ''),
            'In_Ontology': d.get('isInOntology', ''),
            'Importance': d.get('Retrosynthesis Importance', ''),
            'Rationale': rationale_text(d).replace('\n', ' '),  # Remove line breaks for CSV
        }
        for d in disconnections
    ]

    # Convert to DataFrame and save as CSV
    df = pd.DataFrame(csv_data)
    csv_path = os.path.join(report_dir, f"{molecule_name}_position_analysis.csv")
    df.to_csv(csv_path, index=False)
    print(f"Saved CSV to: {csv_path}")

    # --- Create LaTeX template using longtable for multi-page support ---
    # This requires the \usepackage{longtable} in your LaTeX document preamble.
    latex_content = []

    # Define column alignments for longtable
    # c=center, p{width}=paragraph for wrapping
    column_format = "c c p{3cm} c c p{5cm}"
    latex_content.append(f"\\begin{{longtable}}{{{column_format}}}")

    # --- Caption and Label with Descriptions ---
    # The short caption in [] is for the List of Tables. The long caption in {} is the main one.
    short_caption = "Predicted Disconnection Sites for " + escape_latex(molecule_name)

    long_caption_parts = [
        short_caption + ". ",
        "Header descriptions are as follows: ",
        "\\textbf{Prio.}: Priority Ranking of the Disconnections; ",
        "\\textbf{Position}: The position where the disconnection is; ",
        "\\textbf{Reaction}: The forward reaction; ",
        "\\textbf{Ontology}: If the reaction is in the reaction ontology for which examples are available; ",
        "\\textbf{Imp.}: Retrosynthesis Importance, alignment with retrosynthesis goals; ",
        "\\textbf{Rationale}: The chemical rationale."
    ]

    # Emit the short caption as the optional argument; the extra braces keep a
    # ']' inside the molecule name from ending the optional argument early.
    latex_content.append(
        "\\caption[{" + short_caption + "}]{" + "".join(long_caption_parts) + "}"
    )
    latex_content.append("\\label{table:disconnections-" + molecule_name.lower().replace('_', '-') + "} \\\\")

    # --- Table Headers ---
    # This header will appear on the first page
    headers = [
        "\\multicolumn{1}{c}{\\textbf{Prio.}}",
        "\\multicolumn{1}{c}{\\textbf{Position}}",
        "\\multicolumn{1}{c}{\\textbf{Reaction}}",
        "\\multicolumn{1}{c}{\\textbf{Ontology}}",
        "\\multicolumn{1}{c}{\\textbf{Imp.}}",
        "\\multicolumn{1}{c}{\\textbf{Rationale}}"
    ]
    header_row = " & ".join(headers) + " \\\\"
    latex_content.append(header_row)
    latex_content.append("\\hline")
    latex_content.append("\\endfirsthead")

    # This header will repeat on all subsequent pages
    latex_content.append(header_row)
    latex_content.append("\\hline")
    latex_content.append("\\endhead")

    # --- Table Footer ---
    # This will appear at the bottom of the table on all pages
    latex_content.append("\\endfoot")

    # --- Table Body ---
    for d in disconnections:
        priority = escape_latex(d.get('Priority', ''))
        disconnection = escape_latex(d.get('disconnection', ''))
        reaction = escape_latex(d.get('forwardReaction', ''))
        in_ontology = "Yes" if d.get('isInOntology', False) else "No"
        importance = escape_latex(d.get('Retrosynthesis Importance', ''))
        rationale = escape_latex(rationale_text(d))

        # Join row data with '&'
        row_data = [priority, disconnection, reaction, in_ontology, importance, rationale]
        latex_content.append(" & ".join(row_data) + " \\\\")

    latex_content.append("\\end{longtable}")

    # Save LaTeX template (UTF-8 explicitly: the free-text cells may hold non-ASCII characters)
    latex_path = os.path.join(report_dir, f"{molecule_name}_position_analysis_table.tex")
    with open(latex_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(latex_content))
    print(f"Saved LaTeX template to: {latex_path}")

# Write the CSV summary and the LaTeX table for the flattened response into
# <save_path>/position_model/report.
save_reactions_to_csv_and_latex(flattened_reactions, save_path, molecule_name)