# Transition Analysis Tool

This notebook provides a comprehensive workflow for generating and visualizing proposed reactants for specific disconnection sites in a target molecule.

## Getting Started

1. **Configure your reaction**:
   - Specify `selected_molecule_name` for file naming 
   - Provide `smiles` as the target molecule with atom mapping taken from the position model
   - Set `selected_priority` and `selected_position` based on position model output
   - Specify `selected_forward_reaction` name
   - You can also switch the reaction-example dataset (`selected_dataset`) from `paroutes` to `uspto50k`. Precalculated example responses are only provided for `paroutes`, so for `uspto50k` you will need to run the LLM inference yourself.

2. **Generate reaction-specific prompt**:
   - The tool automatically pulls training examples for the selected forward reaction from the chosen dataset (PaRoutes by default)
   - Populates a template with your molecule's SMILES, position, and training examples
   - Saves the populated prompt at `examples/<molecule_name>/transition_model/priority_<N>/transition_model_prompt_populated.md`
   - Submit this prompt to your LLM of choice

3. **Process the results**:
   - Save the LLM's JSON response to `examples/<molecule_name>/transition_model/priority_<N>/response.json`
   - The tool will generate:
     - Visual report with highlighted reaction site on the product
     - Individual images of each proposed reactant
     - Full markdown report describing each proposed transition

4. **Generate PDF report**:
   - Run the pandoc command returned by the final cell from inside the `report/` folder (image links in the report are relative) to convert the markdown report to PDF
   - Each transition includes chemical validity assessment and reasoning

This workflow enables rapid evaluation of possible disconnections for retrosynthetic analysis, building on the positions identified in the position model.

In [None]:
import pandas
import os
import json
from dataclasses import dataclass
from typing import List, Dict, Any
import os
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from typing import List, Dict, Optional, Tuple, NamedTuple

# settings

In [None]:
# --- reaction-specific inputs (taken from the position-model output) ---
selected_molecule_name: str = "LEI_515"  # used in output folder and report file names
# Atom-mapped target molecule; `selected_position` refers to these map numbers.
smiles: str = "C[CH2:1][C:2]([C:3](=[O:4])[CH2:5][S:6](=[O:7])[c:8]1[cH:9][cH:10][c:11]([C:12](=[O:13])[N:14]2[CH2:15][CH2:16][N:17]([c:18]3[cH:19][c:20]([Cl:21])[cH:22][cH:23][cH:24]3)[C@@H:25]([CH3:26])[C@@H:27]2[CH3:28])[cH:29][c:30]1[Cl:31])([F:32])[F:33]"

selected_priority: str = "13"
selected_position: str = "N:14"
selected_forward_reaction: str = "Boc amine deprotection"

# --- general settings ---
selected_dataset: str = "paroutes"  # one of "paroutes" / "uspto50k"
# (width, height) handed to the SVG drawer for reactant and product images
REACTANT_FIGURE_SIZE, PRODUCT_FIGURE_SIZE = (300, 200), (600, 400)

In [None]:
# SET PATHS
# Every path is resolved against the parent of the notebook's current working directory.

current_directory = os.getcwd()
working_folder = os.path.abspath(os.path.join(current_directory, '..'))
print(f"Working Directory: {working_folder}")

prompt_template = os.path.join(working_folder, "prompts", "transition_model_prompt.md")
# The trailing "" keeps a trailing separator, matching the original save_path string.
save_path = os.path.join(working_folder, "examples", selected_molecule_name, "transition_model", "")

# Training-example CSV (5 examples per reaction name) for each supported dataset.
_TRAINING_DATA_FILES = {
    "paroutes": "paroutes_TRAIN_reactions_with_models_matter_splits_annotated_subsample_5_examples_per_reaction_name.csv",
    "uspto50k": "uspto50k_graphretro_canonicalized_TRAIN_atom_and_bond_changes_final_subsample_5_examples_per_reaction_name.csv",
}
if selected_dataset not in _TRAINING_DATA_FILES:
    raise ValueError(
        "selected_dataset must be either 'paroutes' or 'uspto50k'"
        f", got {selected_dataset!r}"
    )
training_data_examples = os.path.join(working_folder, "data", _TRAINING_DATA_FILES[selected_dataset])

# report css file path (built with os.path.join like the other paths)
REPORT_CSS_FILE_PATH = os.path.join(working_folder, "data", "report.css")

# print paths
print(f"Prompt Template Path: {prompt_template}")
print(f"Save Path: {save_path}")
print(f"Training Data Examples Path: {training_data_examples}")
print(f"Report CSS File Path: {REPORT_CSS_FILE_PATH}")

In [None]:
# populate training examples for selected reaction
# The loaded table gets its own name so `training_data_examples` stays a path and
# this cell can be re-run without read_csv receiving a DataFrame.
training_data_df = pandas.read_csv(training_data_examples)
matching_train_rows = training_data_df[training_data_df['rxn_insight_name'] == selected_forward_reaction]
if matching_train_rows.empty:
    print(
        f"Warning: no training examples found for reaction '{selected_forward_reaction}' "
        f"in {training_data_examples}"
    )
matching_train_rows

In [None]:
# Retro-reaction SMILES of the matching training rows. Missing values are dropped so
# they are not written into the prompt as the literal string "nan".
train_examples = matching_train_rows['canonicalized_retro_reaction'].dropna().tolist()
train_examples

In [None]:
# function to create and save populated transition file

def create_transition_file(product_smiles, priority, reaction_position, reaction_name, train_examples,
                           template_path=None, output_root=None):
    """
    Load the prompt template, fill in its placeholders, and save the populated prompt.

    The result is written to
    ``<output_root>/priority_<priority>/<template name without extension>_populated.md``.

    Args:
        product_smiles: SMILES string for the product (replaces ``<PRODUCT_SMILES>``).
        priority: Priority level, used to name the ``priority_<priority>`` folder.
        reaction_position: Position/bonds for the reaction; inserted via ``str()``
            for ``<REACTION_POSITION>``.
        reaction_name: Name of the forward reaction (replaces ``<REACTION_NAME>``).
        train_examples: Iterable of training reaction strings. Each is wrapped in double
            quotes and rendered as a bracketed list for ``<TRAIN_REACTION_EXAMPLES>``.
        template_path: Template file to read. Defaults to the notebook-level
            ``prompt_template``.
        output_root: Directory in which the priority folder is created. Defaults to the
            notebook-level ``save_path``.

    Returns:
        The populated prompt text.
    """
    if template_path is None:
        template_path = prompt_template
    if output_root is None:
        output_root = save_path

    with open(template_path, 'r', encoding='utf-8') as f:
        template = f.read()

    output_dir = os.path.join(output_root, f"priority_{priority}")
    os.makedirs(output_dir, exist_ok=True)

    # Render the examples as a list literal, one quoted example per line.
    examples_formatted = '[\n' + ',\n'.join(f'    "{example}"' for example in train_examples) + '\n]'

    # Placeholders are substituted sequentially, in this order.
    replacements = {
        '<REACTION_POSITION>': str(reaction_position),
        '<REACTION_NAME>': reaction_name,
        '<PRODUCT_SMILES>': product_smiles,
        '<TRAIN_REACTION_EXAMPLES>': examples_formatted,
    }
    populated_content = template
    for placeholder, value in replacements.items():
        populated_content = populated_content.replace(placeholder, value)

    # splitext only touches the real extension. A plain str.replace('.md', ...) would also
    # rewrite '.md' inside the name, and would leave a non-.md name unchanged (overwrite risk).
    template_stem, _ = os.path.splitext(os.path.basename(template_path))
    output_path = os.path.join(output_dir, f"{template_stem}_populated.md")
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(populated_content)

    print(f"Populated template saved to: {output_path}")
    return populated_content

# Populate and save the prompt for the configured molecule, priority and reaction.
populated_file_content = create_transition_file(
    product_smiles=smiles,
    priority=selected_priority,
    reaction_position=selected_position,
    reaction_name=selected_forward_reaction,
    train_examples=train_examples,
)
populated_file_content

Two options:
1) Run the populated prompt through your LLM of choice (e.g. via OpenRouter, vLLM, or aistudio.google.com, which is free) and save the resulting JSON as `examples/<molecule_name>/transition_model/priority_<N>/response.json`.
2) Copy an example `response.json` from the `results/<molecule_name>/transition_model/<transition_n>` folder into that same location.

In [None]:
# Intentional stop so "Run All" halts here. Save the LLM response as response.json
# (see the instructions above), then run the remaining cells one by one.
raise RuntimeError(
    "Stopping here on purpose: save the LLM response as response.json in the "
    "priority folder, then run the cells below."
)

In [None]:
# Load the LLM response that was saved for the selected priority.
file_path = os.path.join(save_path, f"priority_{selected_priority}", "response.json")
try:
    # JSON is UTF-8 by spec. Being explicit keeps non-ASCII text in the model's
    # reasoning loading identically on every platform.
    with open(file_path, "r", encoding="utf-8") as f:
        response = json.load(f)
except FileNotFoundError as err:
    raise FileNotFoundError(
        f"No LLM response found at {file_path}. Save the model's JSON output there first."
    ) from err

response

In [None]:
# Define the dataclass for a single reactant permutation result
@dataclass(init=True, repr=True, eq=True)
class ReactantResult:
    """One proposed reactant set (permutation) extracted from the LLM response."""
    product: str  # product SMILES shared by the whole response
    forward_reaction_name: str  # forward reaction this permutation belongs to
    reactant_list: List[str]  # reactant strings proposed by the model
    is_valid: bool  # model's chemical-validity flag
    is_template: bool  # model's template-based flag
    reasoning: str  # model's free-text justification

def process_response(response: Dict[str, Any]) -> List[ReactantResult]:
    """
    Flatten a parsed LLM response into a list of ReactantResult objects.

    Each entry in ``response["reaction_analysis"]`` contributes one result per item of its
    ``reactant_permutations`` list, in order. Every result carries ``response["product"]``.
    """
    shared_product = response["product"]
    collected: List[ReactantResult] = []
    for entry in response["reaction_analysis"]:
        rxn_name = entry["forward_reaction_name"]
        collected.extend(
            ReactantResult(
                product=shared_product,
                forward_reaction_name=rxn_name,
                reactant_list=perm["reactants"],
                is_valid=perm["is_valid"],
                is_template=perm["is_template"],
                reasoning=perm["reasoning"],
            )
            for perm in entry["reactant_permutations"]
        )
    return collected

# Flatten the loaded response into one ReactantResult per proposed reactant set.
reactant_results = process_response(response=response)
reactant_results

In [None]:


# Self-contained NamedTuple definition of ReactantResult for the reporting cell.
# NOTE: this rebinds the name ReactantResult. Its field order deliberately matches the
# dataclass defined earlier, so positional construction assigns the same fields.
class ReactantResult(NamedTuple):
    """One proposed reactant set (permutation) for the product."""
    product: str
    forward_reaction_name: str
    reactant_list: List[str]
    is_valid: bool
    is_template: bool
    reasoning: str

def extract_atom_map_numbers(position_str):
    """
    Return the atom-map numbers mentioned in a position string, in order of appearance.

    Any ``:<digits>`` token counts, so "N:14", "C:12 N:14", "C:12-N:14", "[NH:14]" and
    "C:12, N:14" are all understood. Repeated numbers are kept only once.
    """
    import re

    return list(dict.fromkeys(int(num) for num in re.findall(r":(\d+)", position_str)))


def parse_position(position_str, smiles):
    """
    Find the atoms named (by atom-map number) in ``position_str`` within ``smiles``.

    Returns:
        ``(highlight_atoms, highlight_bonds)``: the indices of atoms whose map number
        appears in ``position_str``, and the indices of bonds that directly join two of
        those atoms. Both lists are empty if the SMILES cannot be parsed or any error
        occurs, because highlighting is best effort and must not stop report generation.
    """
    try:
        target_mappings = extract_atom_map_numbers(position_str)
        print(f"Looking for atoms with mappings: {target_mappings}")

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            print(f"Could not parse SMILES for highlighting: {smiles}")
            return [], []

        wanted = set(target_mappings)
        highlight_atoms = []
        found_mappings = set()
        for atom in mol.GetAtoms():
            mapping = atom.GetAtomMapNum()
            if mapping in wanted:
                atom_idx = atom.GetIdx()
                highlight_atoms.append(atom_idx)
                found_mappings.add(mapping)
                print(f"Found atom with mapping {mapping} at index {atom_idx}, type {atom.GetSymbol()}")

        missing = [m for m in target_mappings if m not in found_mappings]
        if missing:
            print(f"Warning: no atoms found with mappings {missing}")

        # Check every pair of highlighted atoms for a direct bond.
        highlight_bonds = []
        for i, atom1_idx in enumerate(highlight_atoms):
            for atom2_idx in highlight_atoms[i + 1:]:
                bond = mol.GetBondBetweenAtoms(atom1_idx, atom2_idx)
                if bond is not None:
                    bond_idx = bond.GetIdx()
                    highlight_bonds.append(bond_idx)
                    print(f"Found bond at index {bond_idx} between atoms {atom1_idx} and {atom2_idx}")

        return highlight_atoms, highlight_bonds
    except Exception as e:  # best effort: a highlighting failure should not abort the report
        print(f"Error finding atoms and bonds: {e}")
        return [], []

def draw_molecule_with_highlights(
    mol: Chem.Mol,
    highlight_atoms: Optional[List[int]] = None,
    highlight_bonds: Optional[List[int]] = None,
    atom_colors: Optional[Dict] = None,
    bond_colors: Optional[Dict] = None,
    filepath: str = 'molecule.svg',
    size: Tuple[int, int] = (600, 400),
):
    """
    Render ``mol`` as SVG and write it to ``filepath``, optionally with highlights.

    Args:
        mol: Molecule to draw.
        highlight_atoms: Atom indices to highlight; omitted or empty means none.
        highlight_bonds: Bond indices to highlight; omitted or empty means none.
        atom_colors: Atom index -> RGB tuple. If empty while atoms are highlighted,
            all highlighted atoms are drawn in red.
        bond_colors: Same as ``atom_colors`` but for bonds.
        filepath: Destination file for the SVG text.
        size: Canvas (width, height). Widths below 600 use a 1-unit bond line, others 2.

    Returns:
        True after the file has been written.
    """
    canvas_width, canvas_height = size[0], size[1]
    red = (1, 0, 0)

    # Normalise the optional inputs so RDKit never receives None.
    atoms_to_mark = highlight_atoms if highlight_atoms else []
    bonds_to_mark = highlight_bonds if highlight_bonds else []
    atom_palette = atom_colors if atom_colors else {}
    bond_palette = bond_colors if bond_colors else {}
    if atoms_to_mark and not atom_palette:
        atom_palette = {atom_idx: red for atom_idx in atoms_to_mark}
    if bonds_to_mark and not bond_palette:
        bond_palette = {bond_idx: red for bond_idx in bonds_to_mark}

    drawer = Draw.rdMolDraw2D.MolDraw2DSVG(canvas_width, canvas_height)
    options = drawer.drawOptions()
    options.bondLineWidth = 1 if canvas_width < 600 else 2
    options.highlightRadius = 0.4  # a larger highlight radius is easier to see
    drawer.DrawMolecule(
        mol,
        highlightAtoms=atoms_to_mark,
        highlightBonds=bonds_to_mark,
        highlightAtomColors=atom_palette,
        highlightBondColors=bond_palette,
    )
    drawer.FinishDrawing()

    with open(filepath, 'w') as svg_file:
        svg_file.write(drawer.GetDrawingText())
    print(f"Saved image to: {filepath}")
    return True

def _try_draw(mol, filepath, size, highlight_atoms=None, highlight_bonds=None):
    """Lay out ``mol`` if it has no coordinates, draw it to ``filepath``, and report success."""
    if mol is None:
        return False
    try:
        if mol.GetNumConformers() == 0:
            AllChem.Compute2DCoords(mol)
        draw_molecule_with_highlights(
            mol=mol,
            highlight_atoms=highlight_atoms,
            highlight_bonds=highlight_bonds,
            filepath=filepath,
            size=size,
        )
    except Exception as e:  # best effort, e.g. SMARTS query molecules may not draw cleanly
        print(f"Could not draw {filepath}: {e}")
        return False
    return True


def _parse_reactant(reactant_smiles, reactant_index):
    """Parse a reactant as SMILES, falling back to SMARTS; return None if parsing fails."""
    try:
        mol = Chem.MolFromSmiles(reactant_smiles)
        if mol is None:
            print(f"Trying to parse as SMARTS: {reactant_smiles}")
            mol = Chem.MolFromSmarts(reactant_smiles)
    except Exception as e:
        print(f"Error parsing reactant {reactant_index}: {e}")
        return None
    return mol


def _product_section(product_smiles, image_dir, selected_position):
    """Draw the product with the reaction site highlighted and return its markdown lines."""
    product_img_filename = "product.svg"

    highlight_atoms, highlight_bonds = [], []
    if selected_position:
        highlight_atoms, highlight_bonds = parse_position(selected_position, product_smiles)
        print(f"Highlighting atoms at positions: {highlight_atoms}")
        print(f"Highlighting bonds: {highlight_bonds}")

    drawn = _try_draw(
        Chem.MolFromSmiles(product_smiles),
        os.path.join(image_dir, product_img_filename),
        PRODUCT_FIGURE_SIZE,
        highlight_atoms,
        highlight_bonds,
    )

    lines = ["## Product"]
    if drawn:
        if highlight_atoms:
            lines.append(f"*Reaction site highlighted at position `{selected_position}`*")
        lines.append(f"![Product](./images/{product_img_filename})")
    else:
        # Do not link an image file that was never written.
        lines.append("(Could not visualize)")
    lines.append(f"\n**Product SMILES:** ```{product_smiles}```\n")
    return lines


def _transition_section(transition_number, result, image_dir):
    """Draw each reactant of one proposed transition and return its markdown lines."""
    lines = [
        f"### Transition {transition_number}",
        f"- **Forward Reaction:** {result.forward_reaction_name}",
        f"- **Is Chemically Valid:** {result.is_valid}",
        f"- **Is Template-based:** {result.is_template}",
        f"- **Reasoning:** {result.reasoning}",
    ]
    print(f"\nProcessing Transition {transition_number}:")
    print(f"Number of reactants: {len(result.reactant_list)}")

    for reactant_index, reactant_smiles in enumerate(result.reactant_list, start=1):
        reactant_img_filename = f"transition_{transition_number}_reactant_{reactant_index}.svg"
        mol = _parse_reactant(reactant_smiles, reactant_index)
        drawn = _try_draw(mol, os.path.join(image_dir, reactant_img_filename), REACTANT_FIGURE_SIZE)

        lines.append(f"- **Reactant {reactant_index}:** `{reactant_smiles}`")
        if drawn:
            lines.append(f"![Reactant {reactant_index}](./images/{reactant_img_filename})")
        else:
            lines.append("(Could not visualize)")
        lines.append("\n")

    lines.append("\n")
    return lines


def create_report(reactant_results: List[ReactantResult], save_path: str, selected_molecule_name: str, selected_priority: str, selected_position: str = None, selected_forward_reaction: str = None):
    """
    Write a markdown report, with SVG images, describing every proposed transition.

    Output goes to ``<save_path>/priority_<selected_priority>/report/``, with images in
    its ``images/`` subfolder. A molecule that cannot be parsed or drawn is listed as
    "(Could not visualize)" and does not stop report generation.

    Returns:
        A pandoc command that converts the report to PDF. Run it from the report folder,
        because image links are relative. Returns None if ``reactant_results`` is empty.
    """
    if not reactant_results:
        print("No reactant results to generate a report for.")
        return

    report_dir = os.path.join(save_path, f"priority_{selected_priority}", "report")
    image_dir = os.path.join(report_dir, "images")
    os.makedirs(image_dir, exist_ok=True)

    md_content = [f"# Reaction Report for ***{selected_molecule_name}***, Priority {selected_priority}\n"]
    if selected_position:
        md_content.append(f"**Selected Position:** `{selected_position}`")
    if selected_forward_reaction:
        md_content.append(f"\n**Selected Forward Reaction:** {selected_forward_reaction}\n")

    # 1. Product (all results share the same product).
    md_content.extend(_product_section(reactant_results[0].product, image_dir, selected_position))

    # 2. One section per proposed transition.
    md_content.append("## Proposed Transitions")
    for transition_number, result in enumerate(reactant_results, start=1):
        md_content.extend(_transition_section(transition_number, result, image_dir))

    # 3. Write the report (UTF-8, since model reasoning may contain non-ASCII text).
    base_name = f"{selected_molecule_name}_transition_priority_{selected_priority}"
    report_filename = f"{base_name}.md"
    pdf_report_filename = f"{base_name}.pdf"
    report_path = os.path.join(report_dir, report_filename)
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("\n".join(md_content))

    print(f"Report successfully generated at: {report_path}")
    print(f"Run the pandoc command from: {report_dir}")
    return f"pandoc {report_filename} -s --css={REPORT_CSS_FILE_PATH} --pdf-engine=weasyprint -o {pdf_report_filename}"

# Build the report for the configured molecule; the returned pandoc command is the cell output.
pandoc_command = create_report(
    reactant_results,
    save_path,
    selected_molecule_name,
    selected_priority,
    selected_forward_reaction=selected_forward_reaction,
    selected_position=selected_position,
)
pandoc_command