# ColabReaction v1.0.0
Reaction path optimization using Direct MaxFlux (DMF) and the UMA machine learning potential.

<img src="https://github.com/BILAB/ColabReaction/raw/main/TOC_logo.jpg" alt="TOC logo" width="600">


In [2]:
#@title Installation for app (This may take minutes, for installation of customized Panels)
! time  pip install \
  git+https://github.com/luvwinnie/panel@1.7.2-rc.1-chem-new \
  git+https://github.com/luvwinnie/panel-3dmol.git \
  param jupyter_bokeh comm plotly py3Dmol rdkit -U

Collecting git+https://github.com/luvwinnie/panel@1.7.2-rc.1-chem-new
  Cloning https://github.com/luvwinnie/panel (to revision 1.7.2-rc.1-chem-new) to /tmp/pip-req-build-gxo3jxjx
  Running command git clone --filter=blob:none --quiet https://github.com/luvwinnie/panel /tmp/pip-req-build-gxo3jxjx
  Running command git checkout -b 1.7.2-rc.1-chem-new --track origin/1.7.2-rc.1-chem-new
  Switched to a new branch '1.7.2-rc.1-chem-new'
  Branch '1.7.2-rc.1-chem-new' set up to track remote branch '1.7.2-rc.1-chem-new' from 'origin'.
  Resolved https://github.com/luvwinnie/panel to commit d80acae43fa3f23781e0c96fae5054c787961a46
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting git+https://github.com/luvwinnie/panel-3dmol.git
  Cloning https://github.com/luvwinnie/panel-3dmol.git to /tmp/pip-req-build-_e__ajg8
  Running command git clone --filter=blob:none --qui

# 📝 I. Setup section

## Input arguments

Please complete the following steps:
1. Upload the reactant and product files (.xyz, .com, .gjf, .pdb, .mol, .sdf).
2. Set the calculation parameters for DMF.
3. Enter your Hugging Face API token to access the pretrained UMA model.

In [3]:
#@title 1. Upload reactant and product files (.xyz, .com, .gjf, .pdb, .mol, .sdf). Rerun the cell code for uploading new file. (Dual Visualizer)
import panel as pn
import param
from panel.reactive import ReactiveHTML
from panel_3dmol import Mol3DViewer
import tempfile
import os
import re

# RDKit imports
try:
    from rdkit import Chem
    from rdkit.Chem import AllChem
    RDKIT_AVAILABLE = True
except ImportError:
    RDKIT_AVAILABLE = False

# Enable Panel extensions
# ('filedropper' is requested because pn.widgets.FileDropper is used below.)
pn.extension('filedropper')
pn.config.sizing_mode = 'stretch_width'

# Global variables to store XYZ data only
# Both start as "" (falsy). After a successful upload, on_reactant_drop /
# on_product_drop replace them with the dict returned by process_file_upload
# (keys: original_filename, original_format, xyz_filename, xyz_content,
# xyz_path, atom_count); the get_* helpers read them.
uploaded_reactant = ""
uploaded_product = ""

# Create viewers
# One viewer per structure; the drop handlers set .filetype/.structure on them.
reactant_viewer = Mol3DViewer(show_atom_labels=True)
product_viewer = Mol3DViewer(show_atom_labels=True)

# Create file droppers
# The drop handlers read `value` as a {filename: content} mapping and use only
# its first entry; content may arrive as bytes or str.
reactant_dropper = pn.widgets.FileDropper(
    name="🧪 Drop Reactant File (.xyz, .pdb, .mol, .sdf, .com, .gjf)",
    height=100
)

product_dropper = pn.widgets.FileDropper(
    name="🎯 Drop Product File (.xyz, .pdb, .mol, .sdf, .com, .gjf)",
    height=100
)

def mol_to_xyz(mol, title="Converted molecule"):
    """Convert RDKit molecule to XYZ format string

    Args:
        mol: an RDKit Mol, or None.
        title: text for the XYZ comment (second) line.

    Returns:
        XYZ text built from the coordinates of ``mol.GetConformer()`` with
        6-decimal precision, or None if ``mol`` is None or has no conformer.
    """
    # Check explicitly instead of wrapping GetConformer() in a bare `except:`,
    # which would also swallow KeyboardInterrupt/SystemExit.
    if mol is None or mol.GetNumConformers() == 0:
        return None

    conf = mol.GetConformer()
    atom_lines = []
    for i, atom in enumerate(mol.GetAtoms()):
        pos = conf.GetAtomPosition(i)
        atom_lines.append(f"{atom.GetSymbol()} {pos.x:.6f} {pos.y:.6f} {pos.z:.6f}")

    return "\n".join([str(mol.GetNumAtoms()), title, *atom_lines])

def molblock_to_xyz(mol_block, title="Converted from MOL block"):
    """Convert MOL block format to XYZ format

    The first three lines of a MOL block are a header (title, program line,
    comment), and the title line may be empty. Therefore only trailing
    whitespace is stripped before line 4 (the counts line) is located.

    Args:
        mol_block: V2000 MOL/CTAB text.
        title: text for the XYZ comment line.

    Returns:
        XYZ text whose coordinates and element symbols are copied verbatim
        from the atom lines, or None if no atom lines could be read.
    """
    lines = mol_block.rstrip().splitlines()
    if len(lines) < 4:
        return None

    counts_line = lines[3]
    counts_fields = counts_line.split()
    if len(counts_fields) < 2:
        return None
    try:
        # The counts line is fixed-width: the atom count occupies columns 1-3,
        # so values such as "120125" (120 atoms, 125 bonds) can run together.
        atom_count = int(counts_line[:3])
    except ValueError:
        # Tolerate non-standard, whitespace-separated counts lines.
        try:
            atom_count = int(counts_fields[0])
        except ValueError:
            return None

    atoms = []
    for atom_line in lines[4:4 + atom_count]:
        parts = atom_line.split()
        if len(parts) >= 4:
            x, y, z, element = parts[:4]
            atoms.append(f"{element} {x} {y} {z}")

    if not atoms:
        return None
    return f"{len(atoms)}\n{title}\n" + "\n".join(atoms)

def parse_gaussian_to_xyz(content, filename):
    """Convert Gaussian input (.gjf/.com) to XYZ format

    The input is scanned by counting blank lines. After the second blank line,
    the first non-blank line (charge/multiplicity) is skipped. The following
    lines are read as Cartesian atom lines of the form
    ``<label> [integer freeze code] x y z [...]``, up to the first line that
    does not match. Everything after the next blank line is ignored.

    Args:
        content: text of the Gaussian input file.
        filename: used only in the XYZ comment line.

    Returns:
        XYZ text with 6-decimal coordinates, or None if no Cartesian atom
        lines were found (e.g. Z-matrix input).
    """
    lines = [line.strip() for line in content.strip().split('\n')]

    blank_count = 0
    atoms = []

    for line in lines:
        if line == "":
            blank_count += 1
            continue

        if blank_count < 2:
            # Link0 / route / title lines.
            continue
        if blank_count == 2:
            # Charge/multiplicity line: the atom lines start after it.
            blank_count += 1
            continue
        if blank_count > 3:
            # Past the molecule specification; nothing further is read.
            break

        tokens = line.split()
        if len(tokens) < 4:
            break
        # An optional integer freeze code may sit between the label and x
        # ("C 0 x y z", "C -1 x y z"); without this check it would be read as x.
        has_freeze_code = (
            len(tokens) >= 5 and re.fullmatch(r'[+-]?\d+', tokens[1]) is not None
        )
        coord_tokens = tokens[2:5] if has_freeze_code else tokens[1:4]
        try:
            x, y, z = (float(value) for value in coord_tokens)
        except ValueError:
            break

        # Labels may carry decorations such as "C-CA--0.1" or "C(Fragment=1)";
        # keep only the leading alphabetic symbol.
        match = re.match(r'[A-Za-z]+', tokens[0])
        element = match.group(0) if match else tokens[0]
        atoms.append(f"{element} {x:.6f} {y:.6f} {z:.6f}")

    if not atoms:
        return None
    return f"{len(atoms)}\nConverted from {filename}\n" + "\n".join(atoms)

def fallback_convert_sdf_to_xyz(content, filename):
    """Fallback SDF to XYZ conversion (no RDKit).

    Reads the atom block of the first V2000 record. Line 4 is the counts line,
    whose first fixed-width field (columns 1-3) is the atom count. Each atom
    line starts with x, y, z and the element symbol.

    Args:
        content: SDF file text.
        filename: used in the XYZ comment line.

    Returns:
        XYZ text with 6-decimal coordinates and capitalized element symbols
        (e.g. "CL" -> "Cl"), or None if no atoms could be read.
    """
    # Strip only trailing whitespace. The header is exactly three lines and the
    # title (line 1) may be blank; removing leading text would shift the
    # counts line.
    lines = content.rstrip().splitlines()
    if len(lines) < 4:
        return None

    counts_line = lines[3]
    counts_fields = counts_line.split()
    if len(counts_fields) < 2:
        return None
    try:
        # Fixed-width parse keeps counts such as "120125" (120 atoms,
        # 125 bonds) from being read as one number.
        atom_count = int(counts_line[:3])
    except ValueError:
        # Tolerate non-standard, whitespace-separated counts lines.
        try:
            atom_count = int(counts_fields[0])
        except ValueError:
            return None

    atoms = []
    for atom_line in lines[4:4 + atom_count]:
        parts = atom_line.split()
        if len(parts) < 4:
            continue
        try:
            x, y, z = (float(value) for value in parts[:3])
        except ValueError:
            continue
        element = parts[3].capitalize()
        atoms.append(f"{element} {x:.6f} {y:.6f} {z:.6f}")

    if not atoms:
        return None
    return f"{len(atoms)}\nConverted from {filename}\n" + "\n".join(atoms)

def fallback_convert_mol_to_xyz(content, filename):
    """Fallback MOL to XYZ conversion (no RDKit).

    Reads the V2000 atom block. Line 4 is the counts line, whose first
    fixed-width field (columns 1-3) is the atom count. Each atom line starts
    with x, y, z and the element symbol.

    Args:
        content: MOL file text.
        filename: used in the XYZ comment line.

    Returns:
        XYZ text with 6-decimal coordinates and capitalized element symbols
        (e.g. "CL" -> "Cl"), or None if no atoms could be read.
    """
    # Strip only trailing whitespace. The header is exactly three lines and the
    # title (line 1) may be blank; removing leading text would shift the
    # counts line.
    lines = content.rstrip().splitlines()
    if len(lines) < 4:
        return None

    counts_line = lines[3]
    counts_fields = counts_line.split()
    if len(counts_fields) < 2:
        return None
    try:
        # Fixed-width parse keeps counts such as "120125" (120 atoms,
        # 125 bonds) from being read as one number.
        atom_count = int(counts_line[:3])
    except ValueError:
        # Tolerate non-standard, whitespace-separated counts lines.
        try:
            atom_count = int(counts_fields[0])
        except ValueError:
            return None

    atoms = []
    for atom_line in lines[4:4 + atom_count]:
        parts = atom_line.split()
        if len(parts) < 4:
            continue
        try:
            x, y, z = (float(value) for value in parts[:3])
        except ValueError:
            continue
        element = parts[3].capitalize()
        atoms.append(f"{element} {x:.6f} {y:.6f} {z:.6f}")

    if not atoms:
        return None
    return f"{len(atoms)}\nConverted from {filename}\n" + "\n".join(atoms)

def fallback_convert_pdb_to_xyz(content, filename):
    """Fallback PDB to XYZ conversion (no RDKit).

    Reads ATOM/HETATM records by fixed PDB columns:
    - atom name: 13-16
    - coordinates: 31-38, 39-46, 47-54
    - element symbol: 77-78

    Records whose coordinates do not parse are skipped.

    Args:
        content: PDB file text.
        filename: used in the XYZ comment line.

    Returns:
        XYZ text with 6-decimal coordinates and capitalized element symbols,
        or None if no atom records were read.
    """
    atoms = []

    for line in content.strip().split('\n'):
        if not line.startswith(('ATOM', 'HETATM')):
            continue
        try:
            x = float(line[30:38])
            y = float(line[38:46])
            z = float(line[46:54])
        except ValueError:
            continue

        # Prefer the explicit element symbol (columns 77-78) when present.
        element = line[76:78].strip() if len(line) > 77 else ""

        if not element:
            # PDB right-justifies the element symbol within columns 13-14 of
            # the atom name: " CA " is an alpha carbon, "CA  " is calcium.
            # Reading the whole stripped name would turn " CA " into "Ca".
            name_field = line[12:16]
            element = re.sub(r'[^A-Za-z]', '', name_field[:2])
            if not element:
                element = (re.sub(r'[^A-Za-z]', '', name_field)[:1]
                           or name_field.strip()[:1])

        element = element.capitalize()
        atoms.append(f"{element} {x:.6f} {y:.6f} {z:.6f}")

    if not atoms:
        return None
    return f"{len(atoms)}\nConverted from {filename}\n" + "\n".join(atoms)

def _read_rdkit_file(path, file_format):
    """Read the first molecule in ``path`` with RDKit, or return None.

    The file is first read unsanitized (then a best-effort SanitizeMol is
    tried), so molecules RDKit would reject on sanitization are still kept.
    Only if that yields nothing is a sanitized read attempted.
    """
    def read(sanitize):
        if file_format == 'sdf':
            suppl = Chem.SDMolSupplier(path, removeHs=False, sanitize=sanitize)
            return next(suppl) if suppl else None
        if file_format == 'mol':
            return Chem.MolFromMolFile(path, removeHs=False, sanitize=sanitize)
        return Chem.MolFromPDBFile(path, removeHs=False, sanitize=sanitize)

    try:
        mol = read(False)
        if mol is not None:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                pass  # keep the unsanitized molecule
    except Exception:
        mol = None

    if mol is not None:
        return mol
    try:
        return read(True)
    except Exception:
        return None


def _rdkit_mol_from_text(content, file_format):
    """Write ``content`` to a temporary file and read it with _read_rdkit_file."""
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix=f'.{file_format}',
                                         delete=False, encoding='utf-8') as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(content)
        return _read_rdkit_file(tmp_path, file_format)
    finally:
        # Always remove the temp file, even when writing it failed.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def convert_to_xyz_rdkit(content, file_format, filename):
    """Convert SDF, MOL, PDB formats to XYZ using RDKit

    If the molecule read has no conformer, hydrogens are added and 3D
    coordinates are embedded (randomSeed=42) and UFF-optimized.

    Args:
        content: file text.
        file_format: 'sdf', 'mol' or 'pdb'; anything else returns None.
        filename: used in the XYZ comment line.

    Returns:
        XYZ text, or None when RDKit is unavailable or any step fails.
    """
    if not RDKIT_AVAILABLE or file_format not in ('sdf', 'mol', 'pdb'):
        return None

    title = f"Converted from {filename}"
    try:
        mol = _rdkit_mol_from_text(content, file_format)
        if mol is None:
            return None

        # Generate 3D coordinates if missing.
        # NOTE(review): a 2D MOL/SDF still has a (flat) conformer and is passed
        # through unchanged -- confirm that is intended.
        if mol.GetNumConformers() == 0:
            mol = Chem.AddHs(mol)
            AllChem.EmbedMolecule(mol, randomSeed=42)
            AllChem.UFFOptimizeMolecule(mol)

        # Go through a MOL block first; fall back to reading atoms directly.
        try:
            xyz_content = molblock_to_xyz(Chem.MolToMolBlock(mol), title)
        except Exception:
            xyz_content = None
        return xyz_content or mol_to_xyz(mol, title) or None
    except Exception:
        return None

def count_atoms(file_content, file_ext):
    """Count atoms in various molecular file formats

    Args:
        file_content: file text.
        file_ext: lower-case extension. 'xyz', 'pdb', 'sdf'/'mol' and
            'com'/'gjf' are understood; for anything else, non-empty lines
            not starting with '#' are counted.

    Returns:
        The atom count as an int; None for an SDF/MOL whose counts line is
        missing or too short; "Unknown" if parsing raised an exception.
    """
    try:
        if file_ext in ('sdf', 'mol'):
            # Strip only trailing whitespace. The CTAB header is exactly three
            # lines and its title may be blank, so stripping leading text
            # would shift the counts line.
            raw_lines = file_content.rstrip().splitlines()
            if len(raw_lines) >= 4:
                counts_line = raw_lines[3]
                counts_fields = counts_line.split()
                if len(counts_fields) >= 2:
                    try:
                        # Fixed-width field: atom count is columns 1-3.
                        return int(counts_line[:3])
                    except ValueError:
                        return int(counts_fields[0])
            return None

        lines = [line.strip() for line in file_content.strip().split('\n')]

        if file_ext == 'xyz':
            return int(lines[0]) if lines else 0

        if file_ext == 'pdb':
            return sum(1 for line in lines if line.startswith(('ATOM', 'HETATM')))

        if file_ext in ('com', 'gjf'):
            # Atom lines follow the charge/multiplicity line, which is the
            # first non-blank line after the second blank line.
            blank_count = 0
            atom_count = 0

            for line in lines:
                if not line:
                    blank_count += 1
                    continue

                if blank_count == 2:
                    blank_count += 1
                    continue
                if blank_count == 3:
                    tokens = line.split()
                    if len(tokens) < 4:
                        break
                    try:
                        float(tokens[1])
                        float(tokens[2])
                        float(tokens[3])
                    except ValueError:
                        break
                    atom_count += 1

            return atom_count

        return len([line for line in lines if line and not line.startswith('#')])

    except Exception:
        return "Unknown"

def convert_to_xyz(content, file_format, filename):
    """Master conversion function - converts ALL formats to XYZ

    Dispatch:
      * 'xyz' is returned unchanged.
      * 'com'/'gjf' go through the Gaussian input parser.
      * 'sdf'/'mol'/'pdb' use RDKit when it is available. If the RDKit result
        holds fewer than 80% of the atoms counted in the source text, the
        format's fallback parser is tried, and its output wins when it
        recovers more atoms. When RDKit is unavailable or returns nothing,
        the fallback parser is used directly.

    Returns:
        XYZ text, or None for unsupported formats or failed conversions.
    """
    if file_format == 'xyz':
        return content

    if file_format in ('com', 'gjf'):
        return parse_gaussian_to_xyz(content, filename) or None

    fallback = {
        'pdb': fallback_convert_pdb_to_xyz,
        'mol': fallback_convert_mol_to_xyz,
        'sdf': fallback_convert_sdf_to_xyz,
    }.get(file_format)
    if fallback is None:
        return None

    if RDKIT_AVAILABLE:
        rdkit_xyz = convert_to_xyz_rdkit(content, file_format, filename)
        if rdkit_xyz:
            expected = count_atoms(content, file_format)
            obtained = count_atoms(rdkit_xyz, 'xyz')
            atoms_lost = (
                isinstance(expected, int)
                and isinstance(obtained, int)
                and obtained < expected * 0.8
            )
            if atoms_lost:
                candidate = fallback(content, filename)
                if candidate and count_atoms(candidate, 'xyz') > obtained:
                    return candidate
            return rdkit_xyz

    return fallback(content, filename) or None

def process_file_upload(filename, file_content, molecule_type):
    """Process uploaded file and convert to XYZ

    On success, the original text and its XYZ conversion are written to the
    current directory as ``<name>`` and ``<stem>.xyz``.

    Args:
        filename: name supplied by the uploader; only its final path
            component is used.
        file_content: file data as bytes (decoded as UTF-8, undecodable
            bytes dropped) or str.
        molecule_type: 'reactant' or 'product'; currently not used here.

    Returns:
        ``(xyz_data, xyz_content, 'xyz')`` where ``xyz_data`` is a dict with
        original_filename, original_format, xyz_filename, xyz_content,
        xyz_path and atom_count; or ``(None, None, None)`` if conversion failed.
    """
    if isinstance(file_content, bytes):
        file_content = file_content.decode('utf-8', errors='ignore')

    # The name comes from the client, so keep only its last component; a
    # crafted name such as "../../x.sdf" must not write outside the cwd.
    filename = os.path.basename(filename)
    extension = filename.split('.')[-1].lower()

    # Convert to XYZ format
    xyz_content = convert_to_xyz(file_content, extension, filename)
    if not xyz_content:
        return None, None, None

    # Save original file
    original_path = f"./{filename}"
    with open(original_path, 'w', encoding='utf-8') as f:
        f.write(file_content)

    # Save XYZ version.
    # NOTE(review): reactant and product files sharing a stem (e.g. "a.sdf"
    # and "a.pdb") map to the same "a.xyz", and the later upload overwrites it.
    xyz_filename = filename.rsplit('.', 1)[0] + '.xyz'
    xyz_path = f"./{xyz_filename}"
    with open(xyz_path, 'w', encoding='utf-8') as f:
        f.write(xyz_content)

    xyz_data = {
        'original_filename': filename,
        'original_format': extension,
        'xyz_filename': xyz_filename,
        'xyz_content': xyz_content,
        'xyz_path': xyz_path,
        'atom_count': count_atoms(xyz_content, 'xyz')
    }

    return xyz_data, xyz_content, 'xyz'

def on_reactant_drop(event):
    """Watcher for ``reactant_dropper.value``: convert the dropped file and show it.

    On success, ``uploaded_reactant`` is replaced and the reactant viewer is
    refreshed. On failure the previous state is kept, and the problem is
    logged instead of being silently discarded.
    """
    import logging

    global uploaded_reactant
    if not reactant_dropper.value:
        return

    logger = logging.getLogger(__name__)
    try:
        # Only the first dropped file is used.
        filename, file_content = next(iter(reactant_dropper.value.items()))

        xyz_data, display_content, display_format = process_file_upload(
            filename, file_content, 'reactant'
        )
        if xyz_data is None:
            logger.warning("Could not convert reactant file %r to XYZ", filename)
            return

        uploaded_reactant = xyz_data

        # Update viewer
        reactant_viewer.filetype = display_format
        reactant_viewer.structure = display_content
        reactant_viewer.param.trigger('structure')
    except Exception:
        # Best-effort UI callback: report the failure rather than swallow it,
        # but do not let it propagate into the widget's event handling.
        logger.exception("Failed to load reactant file")

def on_product_drop(event):
    """Watcher for ``product_dropper.value``: convert the dropped file and show it.

    On success, ``uploaded_product`` is replaced and the product viewer is
    refreshed. On failure the previous state is kept, and the problem is
    logged instead of being silently discarded.
    """
    import logging

    global uploaded_product
    if not product_dropper.value:
        return

    logger = logging.getLogger(__name__)
    try:
        # Only the first dropped file is used.
        filename, file_content = next(iter(product_dropper.value.items()))

        xyz_data, display_content, display_format = process_file_upload(
            filename, file_content, 'product'
        )
        if xyz_data is None:
            logger.warning("Could not convert product file %r to XYZ", filename)
            return

        uploaded_product = xyz_data

        # Update viewer
        product_viewer.filetype = display_format
        product_viewer.structure = display_content
        product_viewer.param.trigger('structure')
    except Exception:
        # Best-effort UI callback: report the failure rather than swallow it,
        # but do not let it propagate into the widget's event handling.
        logger.exception("Failed to load product file")

# Register event handlers
# Each handler is called whenever its dropper's `value` parameter changes.
reactant_dropper.param.watch(on_reactant_drop, 'value')
product_dropper.param.watch(on_product_drop, 'value')

# Utility functions
def sync_styles():
    """Mirror the reactant viewer's display options onto the product viewer, then re-render it."""
    mirrored = (
        'show_stick',
        'show_sphere',
        'show_cartoon',
        'show_line',
        'show_surface',
        'background_color',
    )
    for option in mirrored:
        setattr(product_viewer, option, getattr(reactant_viewer, option))
    product_viewer.render()

def get_xyz_content(molecule_type='reactant'):
    """Return the converted XYZ text for ``molecule_type``.

    Args:
        molecule_type: 'reactant' or 'product'.

    Returns:
        The stored ``xyz_content`` string, or None when that molecule has not
        been uploaded or ``molecule_type`` is not recognised.
    """
    if molecule_type == 'reactant':
        record = uploaded_reactant
    elif molecule_type == 'product':
        record = uploaded_product
    else:
        record = None
    return record['xyz_content'] if record else None

def get_xyz_path(molecule_type='reactant'):
    """Return the path of the saved XYZ file for ``molecule_type``.

    Args:
        molecule_type: 'reactant' or 'product'.

    Returns:
        The stored ``xyz_path``, or None when that molecule has not been
        uploaded or ``molecule_type`` is not recognised.
    """
    if molecule_type == 'reactant':
        record = uploaded_reactant
    elif molecule_type == 'product':
        record = uploaded_product
    else:
        record = None
    return record['xyz_path'] if record else None

def get_molecule_info(molecule_type='reactant'):
    """Return the full upload record for ``molecule_type``.

    Args:
        molecule_type: 'reactant' or 'product'.

    Returns:
        The stored dict (as built by process_file_upload), or None when that
        molecule has not been uploaded or ``molecule_type`` is not recognised.
    """
    if molecule_type == 'reactant':
        record = uploaded_reactant
    elif molecule_type == 'product':
        record = uploaded_product
    else:
        record = None
    return record or None

def get_original_filename(molecule_type='reactant'):
    """Return the name of the file originally uploaded for ``molecule_type``.

    Args:
        molecule_type: 'reactant' or 'product'.

    Returns:
        The stored ``original_filename``, or None when that molecule has not
        been uploaded or ``molecule_type`` is not recognised.
    """
    if molecule_type == 'reactant':
        record = uploaded_reactant
    elif molecule_type == 'product':
        record = uploaded_product
    else:
        record = None
    return record['original_filename'] if record else None

def get_rdkit_mol(molecule_type='reactant'):
    """Get RDKit molecule object from XYZ content

    Args:
        molecule_type: 'reactant' or 'product'.

    Returns:
        The RDKit Mol parsed from the stored XYZ text, or None when RDKit is
        unavailable, nothing is stored, or parsing fails.
    """
    if not RDKIT_AVAILABLE:
        return None

    xyz_content = get_xyz_content(molecule_type)
    if not xyz_content:
        return None

    try:
        # Parse from memory; a temporary-file round trip is not needed.
        return Chem.MolFromXYZBlock(xyz_content)
    except Exception:
        return None

def list_saved_files():
    """Return locally saved structure files grouped by kind.

    Returns:
        ``{'original': [...], 'xyz': [...]}``. The 'original' list holds
        ./*.pdb, ./*.sdf, ./*.mol, ./*.com and ./*.gjf matches, in that
        extension order; the 'xyz' list holds ./*.xyz matches.
    """
    import glob

    source_extensions = ('pdb', 'sdf', 'mol', 'com', 'gjf')
    originals = []
    for ext in source_extensions:
        originals.extend(glob.glob(f'./*.{ext}'))

    return {'original': originals, 'xyz': glob.glob('./*.xyz')}

# Layout
# Dropper row on top, then the reactant and product viewers side by side.
app = pn.Column(
    "## 🧬 Dual Molecular Viewer",

    # File droppers
    pn.Row(reactant_dropper, product_dropper),

    # Side-by-side viewers
    pn.Row(
        pn.Column("### 🧪 Reactant", reactant_viewer),
        pn.Column("### 🎯 Product", product_viewer)
    ),

    sizing_mode='stretch_width'
)

# Mark the layout as servable; the bare `app` expression on the last line of
# the cell is what displays it in the notebook output.
app.servable()
app

In [4]:
#@title 2. Calculation settings (Please click the "Apply" button; otherwise, the settings will not take effect.)
# Panel-based Calculation Settings Interface
# Creates a nice UI for molecular calculation parameters
# Global variables to store calculation settings
import panel as pn
import param

# Enable Panel extensions
# NOTE(review): no Tabulator widget is created in this cell; 'tabulator' may be unnecessary.
pn.extension('tabulator')
pn.config.sizing_mode = 'stretch_width'

# Global variables to store calculation settings
# Initial values match the CalculationSettingsApp parameter defaults; the dict
# is updated in place by CalculationSettingsApp.update_globals.
calculation_settings = {
    'charge': 0,
    'mult': 1,
    'nmove': 20,
    'update_teval': False,
    'DMF_convergence': 'tight'
}

class CalculationSettingsApp(param.Parameterized):
    """Panel app for molecular calculation settings

    The five DMF settings are ``param`` parameters. Every parameter change is
    copied into the module-level ``calculation_settings`` dict and into the
    individual globals ``charge``, ``mult``, ``nmove``, ``update_teval`` and
    ``DMF_convergence`` (see ``update_globals``).

    ``create_layout`` builds the compact widget layout. The panes and buttons
    built in ``setup_ui`` are kept as attributes but are not part of that
    layout.
    """

    # Parameters
    charge = param.Integer(default=0, bounds=(-10, 10), doc="Total molecular charge")
    mult = param.Integer(default=1, bounds=(1, 10), doc="Spin multiplicity (2S+1)")
    nmove = param.Integer(default=20, bounds=(1, 200), doc="Number of movable evaluation points")
    update_teval = param.Boolean(default=False, doc="Concentrate evaluation points around highest-energy region")
    DMF_convergence = param.Selector(default='tight', objects=['tight', 'middle', 'loose'], doc="DMF convergence criterion")

    def __init__(self, **params):
        super().__init__(**params)
        self.setup_ui()
        # Update global variables when parameters change
        self.param.watch(self.update_globals, list(self.param))

    def setup_ui(self):
        """Setup the user interface (header, guide, settings summary, buttons).

        These components are stored on the instance; ``create_layout`` does
        not include them.
        """

        # Header
        self.header = pn.pane.HTML("""
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                    color: white; padding: 20px; border-radius: 12px; margin: 10px 0;
                    box-shadow: 0 4px 8px rgba(0,0,0,0.1); text-align: center;">
            <h2 style="margin: 0; font-size: 24px;">⚙️ Calculation Settings</h2>
            <p style="margin: 10px 0 0 0; opacity: 0.9;">Configure molecular calculation parameters</p>
        </div>
        """, width=700)

        # Parameter descriptions
        self.descriptions = pn.pane.HTML("""
        <div style="background: #f8f9fa; padding: 20px; border-radius: 10px; margin: 15px 0;
                    border-left: 4px solid #007bff;">
            <h3 style="color: #007bff; margin-top: 0;">🔧 Parameter Guide</h3>
            <div style="display: grid; gap: 15px;">
                <div style="background: white; padding: 12px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
                    <strong style="color: #495057;">⚡ charge:</strong> Total molecular charge (e.g., 0 for neutral, +1 for cation, -1 for anion)
                </div>
                <div style="background: white; padding: 12px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
                    <strong style="color: #495057;">🌀 mult:</strong> Spin multiplicity (2S+1) - 1 for singlet, 2 for doublet, 3 for triplet
                </div>
                <div style="background: white; padding: 12px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
                    <strong style="color: #495057;">🎯 nmove:</strong> Number of movable evaluation points in DMF (20-50 recommended for initial guesses)
                </div>
                <div style="background: white; padding: 12px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
                    <strong style="color: #495057;">📍 update_teval:</strong> Concentrate evaluation points around highest-energy regions
                </div>
                <div style="background: white; padding: 12px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
                    <strong style="color: #495057;">🎚️ DMF_convergence:</strong> Convergence criterion - tight (more accurate), middle (balanced), loose (faster)
                </div>
            </div>
        </div>
        """, width=700)

        # Current settings display
        self.current_settings = pn.pane.HTML("", width=700)
        self.update_current_display()

        # Apply button
        self.apply_btn = pn.widgets.Button(
            name="✅ Apply Settings",
            button_type="primary",
            width=200,
            height=40,
            margin=(15, 10)
        )
        self.apply_btn.on_click(self.apply_settings)

        # Reset button
        self.reset_btn = pn.widgets.Button(
            name="🔄 Reset to Defaults",
            button_type="light",
            width=200,
            height=40,
            margin=(15, 10)
        )
        self.reset_btn.on_click(self.reset_settings)

    def update_globals(self, event):
        """Update global variables when parameters change

        Args:
            event: the param event (unused; None when called directly).
        """
        global calculation_settings

        calculation_settings.update({
            'charge': self.charge,
            'mult': self.mult,
            'nmove': self.nmove,
            'update_teval': self.update_teval,
            'DMF_convergence': self.DMF_convergence
        })

        # Also set individual global variables for backward compatibility
        globals()['charge'] = self.charge
        globals()['mult'] = self.mult
        globals()['nmove'] = self.nmove
        globals()['update_teval'] = self.update_teval
        globals()['DMF_convergence'] = self.DMF_convergence

        self.update_compact_display()

    def update_current_display(self):
        """Update the current settings display (the ``setup_ui`` summary pane)."""
        self.current_settings.object = f"""
        <div style="background: #e8f5e9; padding: 20px; border-radius: 10px; margin: 15px 0;
                    border-left: 4px solid #28a745;">
            <h3 style="color: #28a745; margin-top: 0;">📊 Current Settings</h3>
            <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px;">
                <div style="background: white; padding: 12px; border-radius: 8px; text-align: center;">
                    <strong style="color: #007bff;">⚡ Charge</strong><br>
                    <span style="font-size: 18px; color: #495057;">{self.charge}</span>
                </div>
                <div style="background: white; padding: 12px; border-radius: 8px; text-align: center;">
                    <strong style="color: #007bff;">🌀 Multiplicity</strong><br>
                    <span style="font-size: 18px; color: #495057;">{self.mult}</span>
                </div>
                <div style="background: white; padding: 12px; border-radius: 8px; text-align: center;">
                    <strong style="color: #007bff;">🎯 nmove</strong><br>
                    <span style="font-size: 18px; color: #495057;">{self.nmove}</span>
                </div>
                <div style="background: white; padding: 12px; border-radius: 8px; text-align: center;">
                    <strong style="color: #007bff;">📍 update_teval</strong><br>
                    <span style="font-size: 18px; color: #495057;">{'Yes' if self.update_teval else 'No'}</span>
                </div>
                <div style="background: white; padding: 12px; border-radius: 8px; text-align: center;">
                    <strong style="color: #007bff;">🎚️ Convergence</strong><br>
                    <span style="font-size: 18px; color: #495057;">{self.DMF_convergence}</span>
                </div>
            </div>
        </div>
        """

    def apply_settings(self, event):
        """Apply current settings and confirm it in the compact display.

        The previous version built a confirmation pane that was never attached
        to any layout, so the click produced no visible feedback. The
        confirmation is now shown inline until the next parameter change.
        """
        self.update_globals(None)
        self.update_compact_display(
            status=' <span style="color: #155724; font-weight: bold;">✅ Applied</span>'
        )

    def reset_settings(self, event):
        """Reset to default settings

        The compact widgets only write into the parameters (nothing syncs the
        other way). The defaults are therefore also pushed back into them when
        ``create_layout`` has already been called.
        """
        self.charge = 0
        self.mult = 1
        self.nmove = 20
        self.update_teval = False
        self.DMF_convergence = 'tight'

        widgets = getattr(self, 'widgets', None)
        if widgets:
            widgets['charge'].value = self.charge
            widgets['mult'].value = self.mult
            widgets['nmove'].value = self.nmove
            widgets['update_teval'].value = self.update_teval
            widgets['convergence'].value = self.DMF_convergence

        self.update_globals(None)

    def create_layout(self):
        """Create compact single-row layout (controls, summary line, Apply button)."""

        # Create compact parameter widgets
        charge_widget = pn.widgets.IntSlider(
            name="⚡ Charge",
            value=self.charge,
            start=-10, end=10,
            width=120,
            margin=(5, 5)
        )
        charge_widget.param.watch(lambda event: setattr(self, 'charge', event.new), 'value')

        mult_widget = pn.widgets.IntSlider(
            name="🌀 Mult",
            value=self.mult,
            start=1, end=10,
            width=120,
            margin=(5, 5)
        )
        mult_widget.param.watch(lambda event: setattr(self, 'mult', event.new), 'value')

        nmove_widget = pn.widgets.IntSlider(
            name="🎯 nmove",
            value=self.nmove,
            start=1, end=200,
            width=120,
            margin=(5, 5)
        )
        nmove_widget.param.watch(lambda event: setattr(self, 'nmove', event.new), 'value')

        update_teval_widget = pn.widgets.Checkbox(
            name="📍 update_teval",
            value=self.update_teval,
            width=120,
            margin=(5, 5)
        )
        update_teval_widget.param.watch(lambda event: setattr(self, 'update_teval', event.new), 'value')

        convergence_widget = pn.widgets.Select(
            name="🎚️ Convergence",
            value=self.DMF_convergence,
            options=['tight', 'middle', 'loose'],
            width=120,
            margin=(5, 5)
        )
        convergence_widget.param.watch(lambda event: setattr(self, 'DMF_convergence', event.new), 'value')

        # Compact apply button
        apply_btn = pn.widgets.Button(
            name="✅ Apply",
            button_type="primary",
            width=80,
            height=35,
            margin=(5, 5)
        )
        apply_btn.on_click(self.apply_settings)

        # Compact header
        compact_header = pn.pane.HTML("""
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                    color: white; padding: 12px; border-radius: 8px; margin: 5px 0;
                    text-align: center;">
            <h3 style="margin: 0; font-size: 18px;">⚙️ Calculation Settings</h3>
        </div>
        """, width=750)

        # Compact current settings
        self.compact_settings = pn.pane.HTML("", width=750, height=50)
        self.update_compact_display()

        # Store widgets for updates
        self.widgets = {
            'charge': charge_widget,
            'mult': mult_widget,
            'nmove': nmove_widget,
            'update_teval': update_teval_widget,
            'convergence': convergence_widget
        }

        # Controls row WITHOUT the apply button
        controls_row = pn.Row(
            charge_widget,
            mult_widget,
            nmove_widget,
            update_teval_widget,
            convergence_widget,
            margin=(10, 5)
        )

        # Apply button row (separate, below current settings)
        apply_row = pn.Row(
            apply_btn,
            margin=(10, 5)
        )

        return pn.Column(
            compact_header,
            controls_row,
            self.compact_settings,
            apply_row,  # Apply button now below current settings
            width=800,
            margin=(5, 5)
        )

    def update_compact_display(self, status=""):
        """Update compact settings display

        Args:
            status: optional short HTML snippet appended after the values
                (used by ``apply_settings``); the default leaves the output
                unchanged.
        """
        # The compact pane exists only after create_layout(); a parameter
        # change before that (e.g. programmatic assignment or reset) must not
        # raise AttributeError.
        if getattr(self, 'compact_settings', None) is None:
            return
        self.compact_settings.object = f"""
        <div style="background: #f8f9fa; padding: 8px; border-radius: 6px; margin: 5px 0;
                    border-left: 3px solid #28a745; font-size: 14px;">
            <strong>Current:</strong>
            Charge={self.charge}, Mult={self.mult}, nmove={self.nmove},
            update_teval={self.update_teval}, Convergence={self.DMF_convergence}{status}
        </div>
        """

# Helper functions
def create_settings_app():
    """Build a fresh ``CalculationSettingsApp`` and return its Panel layout."""
    return CalculationSettingsApp().create_layout()

def get_calculation_settings():
    """Return a shallow copy of the shared ``calculation_settings`` mapping.

    Callers can modify the returned object without affecting the module-level
    settings.
    """
    snapshot = calculation_settings.copy()
    return snapshot

def show_current_settings():
    """Print every entry of ``calculation_settings`` under a short header.

    Output: a title line, a 30-character ``=`` rule, then one labelled line
    per setting in a fixed order.
    """
    labelled_keys = (
        ("⚡ Charge", 'charge'),
        ("🌀 Multiplicity", 'mult'),
        ("🎯 nmove", 'nmove'),
        ("📍 update_teval", 'update_teval'),
        ("🎚️ DMF_convergence", 'DMF_convergence'),
    )
    print("Current Calculation Settings:")
    print("=" * 30)
    for label, key in labelled_keys:
        print(f"{label}: {calculation_settings[key]}")

def export_settings_dict():
    """Return the module-level setting globals as a new dict.

    Unlike the other helpers, this reads the globals ``charge``, ``mult``,
    ``nmove``, ``update_teval`` and ``DMF_convergence`` (the names the DMF cell
    uses), not ``calculation_settings``.
    """
    return dict(
        charge=charge,
        mult=mult,
        nmove=nmove,
        update_teval=update_teval,
        DMF_convergence=DMF_convergence,
    )

def validate_settings():
    """Print sanity warnings and tuning hints for ``calculation_settings``.

    "Issues" flag probably-wrong values (charge outside [-5, 5], multiplicity
    above 5). "Recommendations" are soft hints about ``nmove`` (below 10 or
    above 100). If neither list has entries, a single all-clear line is
    printed. Returns None.
    """
    cfg = calculation_settings
    issues = []
    recommendations = []

    charge_value = cfg['charge']
    if charge_value < -5 or charge_value > 5:
        issues.append("⚠️ Unusual charge value - double-check your molecule")

    if cfg['mult'] > 5:
        issues.append("⚠️ High multiplicity - ensure this is correct for your system")

    # NOTE(review): the hint fires below 10 but suggests >= 20; confirm the
    # intended threshold.
    nmove_value = cfg['nmove']
    if nmove_value < 10:
        recommendations.append("💡 Consider nmove ≥ 20 for better accuracy")
    elif nmove_value > 100:
        recommendations.append("💡 Large nmove may slow down calculations")

    for heading, entries in (("Issues found:", issues), ("Recommendations:", recommendations)):
        if entries:
            print(heading)
            for entry in entries:
                print(f"  {entry}")

    if not (issues or recommendations):
        print("✅ Settings look good!")

# Initialize global variables for backward compatibility
# Module-level copies of each setting, taken when this cell runs. The DMF cell
# reads these names (charge, mult, nmove, update_teval, DMF_convergence)
# directly instead of calculation_settings.
# NOTE(review): presumably apply_settings (defined above, not shown here)
# refreshes these globals when Apply is clicked — confirm, otherwise they keep
# the values captured here.
charge = calculation_settings['charge']
mult = calculation_settings['mult']
nmove = calculation_settings['nmove']
update_teval = calculation_settings['update_teval']
DMF_convergence = calculation_settings['DMF_convergence']

# Create the app
settings_app = create_settings_app()
# Bare final expression so the notebook displays the settings panel as output.
settings_app

In [None]:
#@title 3. Hugging Face Token Input
import getpass
import os

# Prompt without echoing the token. It is kept in the notebook global HF_TOKEN
# (used by the login command in the DMF cell) and exported to this process's
# environment.
print("Enter your Hugging Face access token. The input will be hidden for security.")
HF_TOKEN = getpass.getpass("Hugging Face token:")
os.environ["HF_TOKEN"] = HF_TOKEN


#### 💡 Input arguments are complete!

# 🚀 II. Execution Section


**`Runtime`** → **`Run this cell and all below`** to start your calculation.

In [None]:
#@title Install dependencies (This may take a few minutes)
# Pin numba to 0.60.0 first (presumably for compatibility with the packages below)
! pip install "numba==0.60.0"

# Install/upgrade fairchem-core to its latest release (`-U`; no version pin).
# It provides pretrained_mlip and FAIRChemCalculator, used in the DMF cell.
# NOTE(review): a later install may replace the numba pinned above; check with
# `pip show numba` if numba-related errors appear.
! time pip install fairchem-core -U

# Install DMF (DirectMaxFlux, interpolate_fbenm) from GitHub
! pip install git+https://github.com/shin1koda/dmf.git

# Print the installed fairchem-core version for the record
! pip show fairchem-core

In [None]:
#@title Optimization of the reaction path by DMF/UMA  (This may take a few minutes)
# Run DMF with UMA
import time
import ase.io
import numpy as np
from ase.io import write, read
from dmf import DirectMaxFlux, interpolate_fbenm
from fairchem.core import pretrained_mlip, FAIRChemCalculator
from ase.io.trajectory import Trajectory
import pandas as pd
import matplotlib.pyplot as plt
from google.colab import output
from copy import deepcopy
from ase import Atoms
# Enable Colab's custom (third-party) widget manager for this notebook.
output.enable_custom_widget_manager()

# Start timer (total)
t_total_start = time.perf_counter()

# Hugging Face CLI login with the token from the token cell (non-interactive).
# NOTE(review): IPython substitutes $HF_TOKEN into the shell command, so the
# token appears on the command line (process list, possibly logs). The token
# cell already exports HF_TOKEN via os.environ — confirm whether this login is
# still required.
! huggingface-cli login --token $HF_TOKEN

# input args
# Copy each uploaded structure (the file at its 'xyz_path') to a fixed local
# name used for the rest of the run. uploaded_reactant / uploaded_product are
# presumably set by the upload cell (section 1) — not defined in this cell.
for uploaded, name in [(uploaded_reactant, 'reactant.xyz'), (uploaded_product, 'product.xyz')]:
    with open(uploaded['xyz_path'], 'r') as src, open(name, 'w') as dst:
        dst.write(src.read())

reactant = 'reactant.xyz'
product = 'product.xyz'

# UMA model. s: small model, m: medium model (accurate but slow).
## uma-s-1 will be deprecated in the future.
model_name = 'uma-s-1p1' # "uma-s-1p1" or "uma-m-1p1"


# == Run DMF/UMA ===================
# Read reactant and product
ref_images = [read(reactant), read(product)]

# Generate initial path by FB-ENM
t_fbenm_start = time.perf_counter()
mxflx_fbenm = interpolate_fbenm(ref_images, correlated=True)
t_fbenm_end = time.perf_counter()
write('DMF_init.xyz', mxflx_fbenm.images)

# Write initial path and its coefficients
write('DMF_init.traj', mxflx_fbenm.images)
# The FB-ENM path coefficients seed the DirectMaxFlux run below
# (np.save stores them as DMF_init_coefs.npy).
coefs = mxflx_fbenm.coefs.copy()
np.save('DMF_init_coefs', coefs)

# Set up and solve Direct MaxFlux
# The t_dmf_start..t_dmf_end window covers construction, calculator setup and solve.
t_dmf_start = time.perf_counter()
mxflx = DirectMaxFlux(ref_images, coefs=coefs, nmove=nmove, update_teval=update_teval)

# Set up predictor
# A single UMA predictor on CUDA is shared by every image's calculator.
predictor = pretrained_mlip.get_predict_unit(model_name, device='cuda')
for image in mxflx.images:
    # charge / mult come from the settings cell; info["spin"] receives the
    # multiplicity value.
    image.info["charge"] = charge
    image.info["spin"] = mult
    image.calc = FAIRChemCalculator(predictor, task_name='omol')

# Send the Ipopt solver log to DMF_ipopt.out
mxflx.add_ipopt_options({'output_file': 'DMF_ipopt.out'})
try:
    mxflx.solve(tol=DMF_convergence)
except Exception as e:
    # Best effort: report the failure, snapshot the current images, and carry
    # on so the post-processing below still runs on the (unconverged) path.
    print("solve failed:", e)
    write("DMF_last_before_error.xyz", mxflx.images)
    write("DMF_last_before_error.traj", mxflx.images)
t_dmf_end = time.perf_counter()

# DMF_final.traj: Recalculate SPC for mxflx.images (some frames lack energy)
# Each image is rebuilt as a fresh Atoms from positions + atomic numbers only
# (other Atoms attributes are not copied), then evaluated with the UMA calculator.
final_images = []
for img in mxflx.images:
    # Copy atoms and info
    atoms = Atoms(positions=img.get_positions(), numbers=img.get_atomic_numbers())
    atoms.info["charge"] = img.info.get("charge", charge)
    atoms.info["spin"] = img.info.get("spin", mult)

    # Set calculator
    atoms.calc = FAIRChemCalculator(predictor, task_name='omol')

    try:
        # Explicitly calculate energy
        _ = atoms.get_potential_energy()
    except Exception as e:
        # len(final_images) is this image's index (it is appended just below).
        # The image is kept; its energy is requested again when DMF_energy.csv
        # is built, and a failure there yields an empty row.
        print(f"Warning: failed to compute energy for image {len(final_images)}: {e}")

    final_images.append(atoms)

# End timer (total) — taken before the file output below, so writing is not timed.
t_total_end = time.perf_counter()


# ===== Final file IO =============
# Write energy and force history
# energy_history.txt: one "# Iteration k" section per entry of
# mxflx.history.energies, one line per image (values formatted as eV).
with open('energy_history.txt', 'w') as f:
    for step, energies in enumerate(mxflx.history.energies):
        f.write(f"# Iteration {step}\n")
        for i, energy in enumerate(energies):
            f.write(f"Image {i}: {energy:.8f} eV\n")
        f.write("\n")

# force_history.txt: per iteration and image, one "fx fy fz" line per atom.
with open('force_history.txt', 'w') as f:
    for step, forces in enumerate(mxflx.history.forces):
        f.write(f"# Iteration {step}\n")
        for i, force_array in enumerate(forces):  # image index
            f.write(f"Image {i}:\n")
            for vec in force_array:  # force for each atom
                fx, fy, fz = vec
                f.write(f"{fx:.6f} {fy:.6f} {fz:.6f}\n")
            f.write("\n")

# Convert traj to xyz
def traj_to_xyz(traj, out_xyz_path):
    """Write a sequence of ASE ``Atoms`` frames to an .xyz file.

    Before writing, each frame's ``info`` dict is rebuilt in place with every
    key converted to ``str``, so the caller's frames are modified.

    Parameters:
        traj (list of ase.Atoms): frames to write, in order.
        out_xyz_path (str): destination .xyz path.

    Any exception (during key conversion or writing) is printed as a warning
    and swallowed; nothing is returned.
    """
    try:
        for frame in traj:
            frame.info = dict((str(key), value) for key, value in frame.info.items())
        write(out_xyz_path, traj)
    except Exception as err:
        print(f"Warning: An error occurred while writing {out_xyz_path}: {err}")


# x(tmax): path and history
# mxflx.history.images_tmax is written both as .traj and as .xyz.
images_tmax = mxflx.history.images_tmax
write('DMF_tmax.traj', images_tmax)
traj_to_xyz(images_tmax, 'DMF_tmax.xyz')

# final_images: save images to .traj
# The .traj is written before traj_to_xyz, which rewrites each frame's info keys in place.
write('DMF_final.traj', final_images)
traj_to_xyz(final_images, 'DMF_final.xyz')


# final images: energies
EV_TO_KCAL_MOL = 23.0605
EV_TO_HARTREE = 1 / 27.2114  # ≒ 0.0367493

# One row per final image: index and energy in eV / hartree / kcal/mol.
# A row of None is recorded when the energy cannot be obtained.
data = []

for i, atoms in enumerate(final_images):
    try:
        energy_ev = atoms.get_potential_energy()
        energy_hartree = energy_ev * EV_TO_HARTREE
        energy_kcal = energy_ev * EV_TO_KCAL_MOL
        data.append([i, energy_ev, energy_hartree, energy_kcal])
    except Exception:
        data.append([i, None, None, None])

df = pd.DataFrame(data, columns=[
    "image",
    "energy [eV]",
    "energy [hartree]",
    "energy [kcal/mol]"
])

# Relative energy (kcal/mol)
# Image 0 (the reactant end of the path) is the reference.
# NOTE(review): the guard only checks that *some* energy exists; if image 0
# itself has no energy, ref is NaN and every ΔE becomes NaN.
if df["energy [kcal/mol]"].notna().any():
    ref = df.loc[0, "energy [kcal/mol]"]
    df["Delta E vs. reactant [kcal/mol]"] = df["energy [kcal/mol]"] - ref
else:
    df["Delta E vs. reactant [kcal/mol]"] = None

# The visualizer cell reads this CSV back.
df.to_csv('DMF_energy.csv', index=False)


# Output time log: one line per phase, durations in seconds.
with open('timing_log.txt', 'w') as logf:
    logf.write(
        f"Time for FB-ENM interpolation: {t_fbenm_end - t_fbenm_start:.2f} s\n"
        f"Time for DMF total (setup + solve): {t_dmf_end - t_dmf_start:.2f} s\n"
        f"Total time: {t_total_end - t_total_start:.2f} s\n"
    )



In [None]:
#@title Latest Molecule Animation Visualizer (It may take a few minutes)
#!/usr/bin/env python3
import panel as pn
import param
import plotly.express as px
import pandas as pd
from panel_3dmol import Mol3DViewer
import warnings
import os
import asyncio

import warnings

# Quiet Bokeh/Panel/param warnings that would otherwise clutter the cell output.

# Additional warning suppression
warnings.filterwarnings('ignore', category=UserWarning, module='bokeh')
warnings.filterwarnings('ignore', message='.*bokeh.*')
warnings.filterwarnings("ignore", category=UserWarning, message=".*reference already known.*")
warnings.filterwarnings('ignore', message='.*width-responsive sizing_mode.*')
warnings.filterwarnings('ignore', category=UserWarning, module='param')

# Bokeh validation silence
# Silence Bokeh's MISSING_RENDERERS validation warning; skipped if the import fails.
try:
    from bokeh.core.validation import silence
    from bokeh.core.validation.warnings import MISSING_RENDERERS
    silence(MISSING_RENDERERS, True)
except ImportError:
    pass

# NOTE(review): silence_all_warnings may not exist in the installed Bokeh;
# the AttributeError is swallowed in that case — confirm this has any effect.
try:
    import bokeh.core.validation as validation
    validation.silence_all_warnings()
except (ImportError, AttributeError):
    pass
# for debugging the COMM message, 1/true for debugging.
os.environ["PANEL_DEBUG_COMM"] = ""

# Enable Panel extensions
pn.extension('plotly')
# Default sizing for all Panel components created below
pn.config.sizing_mode = 'stretch_width'

def extract_xyz_frames_to_list(input_file='DMF_final.xyz'):
    """Split a multi-frame XYZ file into one XYZ string per frame.

    Each frame is an atom-count header line, a comment line, then that many
    atom lines. The count is read from every frame's own header, so files
    whose frames differ in size are split correctly. The previous version
    assumed all frames had the first frame's size.

    Parameters
    ----------
    input_file : str
        Path to the (extended) XYZ file.

    Returns
    -------
    list of str
        One string per frame, lines joined with newlines (no trailing newline).
        Trailing blank lines in the file are ignored. An incomplete final frame
        is dropped. An empty file yields ``[]``.

    Raises
    ------
    ValueError
        If a frame header is not a non-negative integer atom count.
    """
    with open(input_file, 'r') as f:
        lines = f.read().split('\n')

    # Drop trailing blank lines so they are not mistaken for a frame header.
    while lines and lines[-1].strip() == '':
        lines.pop()

    frames = []
    start = 0
    while start < len(lines):
        natoms = int(lines[start])
        if natoms < 0:
            # A negative count would make `end` <= `start` and loop forever.
            raise ValueError(f"Invalid atom count {natoms} at line {start + 1} of {input_file}")
        end = start + natoms + 2  # header + comment + atom lines
        if end > len(lines):
            break  # truncated final frame: skip it, as before
        frames.append('\n'.join(lines[start:end]))
        start = end

    return frames

# Load data
# Module-level inputs read by AnimatedMolecularViewer below: one XYZ string per
# path image, and the per-image energy table written by the DMF cell.
xyz_frames = extract_xyz_frames_to_list('DMF_final.xyz')
df_energy = pd.read_csv('DMF_energy.csv')

# Also used for the current_frame parameter bounds when the class is defined.
num_frames = len(xyz_frames)

class AnimatedMolecularViewer(param.Parameterized):
    """
    Dashboard that plays the reaction-path frames in a 3Dmol.js viewer next to
    the ΔE energy profile, with play/stop, stepping, speed and display controls.

    Reads module-level globals prepared earlier in this cell: ``xyz_frames``
    (one XYZ string per image), ``df_energy`` (contents of DMF_energy.csv) and
    ``num_frames``. Playback runs in the browser via ``Mol3DViewer.animate``.
    ``current_frame`` mirrors the viewer's frame so the energy plot and the
    info panel follow it.
    """

    # Animation parameters that sync with Mol3DViewer
    current_frame = param.Integer(default=0, bounds=(0, max(0, num_frames-1)))
    animation_speed = param.Integer(default=200, bounds=(10, 2000), doc="Animation interval in ms")
    is_playing = param.Boolean(default=False)
    # Shown in the UI but not yet applied to playback (see on_animation_control).
    loop_mode = param.Selector(default="forward", objects=["forward", "backward", "pingpong"])

    # Display parameters
    show_stick = param.Boolean(default=True)
    show_sphere = param.Boolean(default=True)

    def __init__(self, **params):
        """Create the viewer, load all frames, build the UI and register watchers."""
        super().__init__(**params)

        # Create the molecular viewer; animation stays off until Play is pressed.
        self.mol_viewer = Mol3DViewer(
            min_width=600,
            height=600,
            background_color='white',
            show_atom_labels=True,
            animate=False,  # Start with animation disabled
            current_frame=0,
            total_frames=num_frames,
            animation_speed=self.animation_speed
        )

        # Load all frames into the viewer and show the first one.
        if xyz_frames:
            self.mol_viewer.addFrames(xyz_frames, 'xyz')
            self.mol_viewer.setStyle({}, {
                'stick': {'radius': 0.08},
                'sphere': {'scale': 0.12}
            })
            self.mol_viewer.setFrame(0)

        # Animation control state
        self._animation_active = False
        self._animation_callback = None  # not used elsewhere in this class
        self._animation_direction = 1  # 1 for forward, -1 for backward (not used elsewhere in this class)
        # True while on_frame_change pushes a frame into the viewer, so the
        # viewer's echo is ignored by on_mol_viewer_frame_change.
        self._updating_from_panel = False

        # Create UI components
        self.energy_plot = self.create_energy_plot()
        self.controls = self.create_controls()
        self.info_panel = pn.pane.HTML(self.get_frame_info_html(0),
                                      min_width=400, height=500)

        # Python-side watchers
        self.param.watch(self.on_frame_change, 'current_frame')
        self.param.watch(self.on_animation_control, ['is_playing', 'animation_speed', 'loop_mode'])
        self.param.watch(self.on_display_change, ['show_stick', 'show_sphere'])

        # Follow frame changes reported by the viewer (JavaScript -> Python)
        self.mol_viewer.param.watch(self.on_mol_viewer_frame_change, 'current_frame')

        print(f"Initialized with {num_frames} frames using native 3Dmol.js animation")

    def on_mol_viewer_frame_change(self, event):
        """Mirror a frame change reported by the 3Dmol.js viewer into ``current_frame``.

        Assigning ``current_frame`` fires :meth:`on_frame_change`, which already
        refreshes the info panel and the energy plot. The previous version
        rebuilt both a second time here.
        """
        new_frame = event.new

        if new_frame != self.current_frame and not self._updating_from_panel:
            self.current_frame = new_frame

    def apply_molecular_style(self):
        """Apply the stick/sphere style chosen in the display options (sticks if neither)."""
        style = {}
        if self.show_stick:
            style['stick'] = {'radius': 0.08}
        if self.show_sphere:
            style['sphere'] = {'scale': 0.12}

        if style:
            self.mol_viewer.setStyle({}, style)
        else:
            # Default style if nothing selected
            self.mol_viewer.setStyle({}, {'stick': {'radius': 0.08}})

    def _build_energy_figure(self):
        """Return the styled ΔE-vs-image Plotly figure with the current-frame marker.

        Shared by :meth:`create_energy_plot` and :meth:`update_energy_plot`.
        The size-8 trace styling is applied *before* the marker is added.
        Previously ``update_traces`` ran afterwards and shrank the size-15
        current-frame diamond to size 8.
        """
        fig = px.scatter(
            df_energy,
            x='image',
            y='Delta E vs. reactant [kcal/mol]',
            labels={
                'image': 'Image index',
                'Delta E vs. reactant [kcal/mol]': 'ΔE (kcal/mol)'
            }
        )

        # Style the data points first so the marker added next keeps its own size.
        fig.update_traces(marker=dict(size=8))

        # Add current frame indicator
        self.update_energy_plot_marker(fig)

        fig.update_layout(
            font=dict(family='Arial', size=14, color='black'),
            title=dict(text="Energy Profile Along Reaction Path",
                      font=dict(family='Arial', size=16, color='black')),
            xaxis=dict(title_font=dict(family='Arial', size=14, color='black'),
                      tickfont=dict(family='Arial', size=12, color='black'),
                      showline=True, linecolor='black', linewidth=1,
                      mirror=True, ticks='outside', showgrid=False),
            yaxis=dict(title_font=dict(family='Arial', size=14, color='black'),
                      tickfont=dict(family='Arial', size=12, color='black'),
                      showline=True, linecolor='black', linewidth=1,
                      mirror=True, ticks='outside', showgrid=False),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
            height=500,
            showlegend=False
        )
        return fig

    def create_energy_plot(self):
        """Create the Plotly pane for the energy profile; clicking a point jumps to that frame."""
        plot_pane = pn.pane.Plotly(self._build_energy_figure(), min_width=600, height=500)
        plot_pane.param.watch(self.on_plot_click, 'click_data')

        return plot_pane

    def update_energy_plot_marker(self, fig):
        """Add a red diamond at ``current_frame`` to *fig* (skipped if the frame has no row)."""
        if len(df_energy) > self.current_frame:
            current_energy = df_energy.iloc[self.current_frame]['Delta E vs. reactant [kcal/mol]']
            fig.add_scatter(
                x=[self.current_frame],
                y=[current_energy],
                mode='markers',
                marker=dict(size=15, color='red', symbol='diamond'),
                name='Current Frame',
                showlegend=False
            )

    def create_controls(self):
        """Build the control column: frame slider, navigation, play/stop, speed, loop and display options."""

        # Frame controls
        frame_slider = pn.Param(
            self, parameters=['current_frame'],
            widgets={'current_frame': pn.widgets.IntSlider},
            name="Frame Control"
        )

        # Animation controls with explicit play/stop buttons
        play_button = pn.widgets.Button(
            name='▶️ Play',
            button_type='primary',
            width=80,
            height=35
        )
        play_button.on_click(self.start_animation)

        stop_button = pn.widgets.Button(
            name='⏹️ STOP',
            button_type='default',
            width=80,
            height=35,
            css_classes=['stop-button']  # Add CSS class for styling
        )
        stop_button.on_click(self.stop_animation)

        # Animation speed control with live update capability
        speed_control = pn.Param(
            self, parameters=['animation_speed'],
            widgets={'animation_speed': {
                'type': pn.widgets.IntSlider,
                'throttled': True,  # Reduces excessive updates while dragging
                'step': 10  # Make it easier to adjust in meaningful increments
            }},
            name="Speed (ms) - Live Update"
        )

        loop_control = pn.Param(
            self, parameters=['loop_mode'],
            widgets={'loop_mode': pn.widgets.Select},
            name="Loop Mode"
        )

        # Display controls
        display_controls = pn.Param(
            self, parameters=['show_stick', 'show_sphere'],
            name="Display Options"
        )

        # Navigation always stops playback first, then moves the frame.
        def go_to_start():
            self.force_stop_animation()
            self.current_frame = 0

        def go_to_end():
            self.force_stop_animation()
            self.current_frame = num_frames - 1

        def step_backward():
            self.force_stop_animation()
            if self.current_frame > 0:
                self.current_frame -= 1

        def step_forward():
            self.force_stop_animation()
            if self.current_frame < num_frames - 1:
                self.current_frame += 1

        nav_buttons = pn.Row(
            pn.widgets.Button(name='⏮️ Start', button_type='primary', width=80),
            pn.widgets.Button(name='⏪ Step', button_type='default', width=80),
            pn.widgets.Button(name='⏩ Step', button_type='default', width=80),
            pn.widgets.Button(name='⏭️ End', button_type='primary', width=80)
        )

        # Connect button callbacks
        nav_buttons[0].on_click(lambda event: go_to_start())
        nav_buttons[1].on_click(lambda event: step_backward())
        nav_buttons[2].on_click(lambda event: step_forward())
        nav_buttons[3].on_click(lambda event: go_to_end())

        # Animation control buttons
        animation_buttons = pn.Row(
            play_button,
            stop_button,
            margin=(10, 5)
        )

        return pn.Column(
            "### 🎬 Animation Controls",
            frame_slider,
            nav_buttons,
            animation_buttons,
            speed_control,
            loop_control,
            "### 🎨 Display Options",
            display_controls,
            width=350
        )

    def start_animation(self, event):
        """Play button handler: push the speed to the viewer and enable its animation."""
        print("🎬 Starting animation...")

        # Update the animation speed first
        self.mol_viewer.animation_speed = self.animation_speed

        # Enable animation on the mol_viewer
        self.mol_viewer.animate = True
        self.is_playing = True
        self._animation_active = True

        # Update info panel to show playing status
        self.info_panel.object = self.get_frame_info_html(self.current_frame)

        print(f"Animation started with speed: {self.animation_speed}ms")

    def stop_animation(self, event):
        """Stop button handler: call ``stopAnimationImmediate()``, or set ``animate=False`` if that raises."""
        print("⏸️ Stopping animation instantly...")

        try:
            # Use the panel-3dmol specific method for immediate stop
            self.mol_viewer.stopAnimationImmediate()

            # Update internal state
            self.is_playing = False
            self._animation_active = False

            print("✅ Animation stopped immediately using stopAnimationImmediate()")

        except Exception as e:
            print(f"⚠️ Error with stopAnimationImmediate: {e}")
            # Fallback to original method
            self.mol_viewer.animate = False
            self.is_playing = False
            self._animation_active = False
            print("🔄 Used fallback stop method")

        # Update info panel to show stopped status
        self.info_panel.object = self.get_frame_info_html(self.current_frame)

    def get_frame_info_html(self, frame_id):
        """Return the info-panel HTML for *frame_id* (ΔE, progress, play status, speed, loop mode).

        ΔE falls back to 0.0 when *frame_id* has no row in ``df_energy``.
        """
        if frame_id < len(df_energy):
            energy_value = df_energy.iloc[frame_id]['Delta E vs. reactant [kcal/mol]']
        else:
            energy_value = 0.0

        progress = (frame_id / max(1, num_frames - 1) * 100) if num_frames > 1 else 0

        # Status indicator
        status = "🎬 Playing" if self.is_playing else "⏸️ Stopped"

        return f"""
        <div style="padding: 15px; border: 1px solid #2E86C1; border-radius: 8px; background-color: #f8f9fa;">
            <h3 style="margin-top: 0; color: #2E86C1;">Current Frame Information</h3>
            <hr style="border-color: #2E86C1; margin: 10px 0;">
            <p><strong>Frame:</strong> {frame_id} / {num_frames - 1}</p>
            <p><strong>Energy:</strong> {energy_value:.2f} kcal/mol</p>
            <p><strong>Progress:</strong> {progress:.1f}%</p>
            <p><strong>Status:</strong> {status}</p>
            <p><strong>Speed:</strong> {self.animation_speed}ms</p>
            <p><strong>Loop Mode:</strong> {self.loop_mode}</p>
            <div style="background-color: #e7f3ff; padding: 8px; border-radius: 4px; margin-top: 10px;">
                <small><em>Use Play/Stop buttons for reliable animation control</em></small>
            </div>
        </div>
        """

    def on_frame_change(self, event):
        """``current_frame`` watcher: push the frame to the viewer and refresh the info panel and plot."""
        frame_id = event.new

        # Suppress the viewer's echo while we set its frame. The flag is reset
        # in `finally` so an exception cannot leave frame syncing disabled.
        self._updating_from_panel = True
        try:
            self.mol_viewer.setFrame(frame_id)
            self.info_panel.object = self.get_frame_info_html(frame_id)
            self.update_energy_plot()
        finally:
            self._updating_from_panel = False

    def update_energy_plot(self):
        """Replace the energy plot's figure with one marking the current frame."""
        self.energy_plot.object = self._build_energy_figure()

    def on_animation_control(self, event):
        """Watcher for speed / loop-mode / play-state parameters.

        A speed change is forwarded to the viewer. While playing, the animation
        is stopped and restarted so the new interval takes effect. Loop mode
        and ``is_playing`` changes are currently no-ops here.
        """
        if event.name == 'animation_speed':
            print(f"🎚️ Changing animation speed to {event.new}ms")

            # Update speed on mol_viewer
            self.mol_viewer.animation_speed = event.new

            # If currently playing, use immediate stop and restart
            if self.is_playing and self._animation_active:
                print("🔄 Restarting animation with new speed...")
                try:
                    self.mol_viewer.stopAnimationImmediate()
                except Exception:
                    self.mol_viewer.animate = False

                # NOTE(review): this blocks the Python callback for 1 s so the
                # stop can reach the browser before restarting; confirm whether
                # a shorter (or no) delay works with panel-3dmol.
                import time
                time.sleep(1)

                # Restart with new speed
                self.mol_viewer.animate = True
                print(f"✅ Animation restarted with speed: {event.new}ms")

            # Update info panel to reflect new speed
            self.info_panel.object = self.get_frame_info_html(self.current_frame)

        elif event.name == 'loop_mode':
            # Future: could modify JavaScript to support different loop modes
            pass

    def on_display_change(self, event):
        """Re-apply the molecular style when a display checkbox changes."""
        self.apply_molecular_style()

    def force_stop_animation(self):
        """Stop playback from internal callers (no button event)."""
        self.stop_animation(None)

    def on_plot_click(self, event):
        """Energy-plot click handler: stop playback and jump to the clicked image."""
        if event.new and 'points' in event.new:
            point = event.new['points'][0]
            frame_id = int(point['x'])
            if 0 <= frame_id < num_frames:
                # Force stop animation when user clicks on plot
                self.force_stop_animation()
                self.current_frame = frame_id

    def create_dashboard(self):
        """Assemble the full dashboard: title, instructions, then viewer | energy plot | controls + info."""

        # Title
        title = pn.pane.HTML("""
        <h1 style="color: #2E86C1; text-align: center; font-family: Arial;">
            🧬 Animated Molecular Reaction Path Viewer
        </h1>
        <hr style="border-color: #2E86C1;">
        """, min_width=1400, height=80)

        # Instructions
        instructions = pn.pane.HTML("""
        <div style="padding: 10px; background-color: #e7f3ff; border-left: 4px solid #2E86C1; border-radius: 4px; margin-bottom: 20px;">
            <strong>Enhanced Animation Controls:</strong> Using stopAnimationImmediate() for instant, single-click stop functionality.
            <strong>Wide Speed Range:</strong> Animation speed from 10ms (ultra-fast) to 2000ms (very slow) - adjust in real-time while playing!
            The animation will stop immediately when you click the stop button or use navigation controls.
        </div>
        """, min_width=1400, height=60)

        # Main layout
        main_content = pn.Row(
            # Left: Molecular viewer
            pn.Column(
                pn.pane.HTML("<h3 style='color: #2E86C1; text-align: center;'>🧬 3D Structure</h3>"),
                self.mol_viewer,
                min_width=600
            ),

            # Center: Energy plot
            pn.Column(
                pn.pane.HTML("<h3 style='color: #2E86C1; text-align: center;'>📊 Energy Profile</h3>"),
                self.energy_plot,
                min_width=600
            ),

            # Right: Controls and info
            pn.Column(
                self.controls,
                pn.Spacer(height=20),
                self.info_panel,
                min_width=400
            ),

            sizing_mode='stretch_width'
        )

        return pn.Column(
            title,
            instructions,
            main_content,
            min_width=1600,
            height=900,
            sizing_mode='stretch_width'
        )

# Create the application
viewer_app = AnimatedMolecularViewer()
app = viewer_app.create_dashboard()

# Make it servable
# (marks the layout for `panel serve`; inside the notebook the bare `app` below renders it)
app.servable()

# Display the app
app

In [None]:
#@title Calculating Imaginary Frequency of All Local Maxima (This may take a few minutes)
import numpy as np
import os
from ase.io import read, write
from ase import Atoms
from ase.vibrations import Vibrations
from fairchem.core import pretrained_mlip, FAIRChemCalculator
from ase.io.trajectory import Trajectory
from scipy.signal import find_peaks

def extract_peaks_from_traj(trajfile, maxima_filename, prominence=0.01):
    """Find local energy maxima along a trajectory and save each as an .xyz file.

    Parameters
    ----------
    trajfile : str
        Multi-frame structure file read with ``ase.io.read(..., index=':')``.
    maxima_filename : str
        Only its base name (directory and extension stripped) is used. Peaks
        are written to the current directory as ``{base}_{frame_index}.xyz``.
    prominence : float
        Minimum peak prominence passed to ``scipy.signal.find_peaks``, in the
        energy unit returned by ``get_potential_energy``.

    Returns
    -------
    list of str
        File names written, one per detected peak, in frame order.

    Frames whose energy cannot be obtained are forward-filled with the
    previous frame's energy for peak detection only. Leading frames without
    energy stay NaN.
    """
    # Load all frames from trajectory
    traj = read(trajfile, index=':')

    energies = []
    for atoms in traj:
        try:
            energy = atoms.get_potential_energy()
        except Exception:
            # No stored or computable energy for this frame. Mark it missing
            # and fill it below. (A bare `except:` also swallowed
            # KeyboardInterrupt/SystemExit.)
            energy = np.nan
        energies.append(energy)
    energies = np.array(energies)

    def forward_fill_nan(arr):
        """Return a copy of *arr* with each NaN replaced by the last preceding non-NaN value."""
        filled = arr.copy()
        last_valid = np.nan
        for i, value in enumerate(filled):
            if np.isnan(value):
                filled[i] = last_valid
            else:
                last_valid = value
        return filled

    energies_filled = forward_fill_nan(energies)

    # Local maxima of the (gap-filled) energy profile
    peaks, _ = find_peaks(energies_filled, prominence=prominence)

    # Output prefix from the given name (e.g. dir/maxima.xyz -> maxima)
    base_name = os.path.splitext(os.path.basename(maxima_filename))[0]

    print(f"Detected {len(peaks)} peak(s). Saving structures:")
    peak_list = []

    for idx in peaks:
        atoms = traj[idx]
        filename = f"{base_name}_{idx}.xyz"
        peak_list.append(filename)
        write(filename, atoms)
        # Prints the raw energy, so a peak on a forward-filled frame shows nan.
        print(f"  → {filename} (energy = {energies[idx]:.6f})")

    return peak_list

def write_gaussian_like_freq(vib: "Vibrations", output):
    """Write frequencies and the ZPE-corrected energy in a Gaussian-like text layout.

    Parameters
    ----------
    vib : ase.vibrations.Vibrations
        Vibrational analysis. ``vib.atoms`` must be able to return its
        potential energy (eV).
    output : str or path-like
        Destination text file (overwritten).

    Frequencies (cm^-1) are printed three per `` Frequencies --`` line.
    Imaginary modes (complex values with a non-negligible imaginary part) are
    written as negative wavenumbers, matching ``write_gaussian_style_log``.
    Previously they were printed as raw complex numbers such as
    ``0.0000+500.0000j``. The old layout also left a dangling empty
    `` Frequencies --`` label when the count was a multiple of three.
    Energies are converted from eV to hartree with 27.2114 eV/hartree.
    """
    ev_per_hartree = 27.2114
    freqs_cm1 = vib.get_frequencies(method="standard")  # unit: cm^-1
    zpe_eV = vib.get_zero_point_energy()  # unit: eV
    energy_eV = vib.atoms.get_potential_energy()  # unit: eV

    zpe_hartree = zpe_eV / ev_per_hartree
    energy_hartree = energy_eV / ev_per_hartree
    total_energy = energy_hartree + zpe_hartree

    def as_signed_real(freq):
        # Imaginary mode -> negative wavenumber; otherwise its real value.
        if isinstance(freq, complex):
            return -freq.imag if abs(freq.imag) > 1e-6 else freq.real
        return freq

    values = [as_signed_real(freq) for freq in freqs_cm1]
    # Three values per line. With no frequencies, keep a single empty label
    # line as before.
    rows = [values[i:i + 3] for i in range(0, len(values), 3)] or [[]]

    with open(output, "w") as f:
        for row in rows:
            f.write(" Frequencies -- " + "".join(f"{value:10.4f}" for value in row) + "\n")
        f.write("\n")
        f.write(f" Zero-point correction= {zpe_hartree:.6f} (Hartree/Particle)\n")
        f.write(f" Electronic energy    = {energy_hartree:.6f} Hartree\n")
        f.write(f" E + ZPE              = {total_energy:.6f} Hartree\n")
    print(f"[INFO] Energies & frequencies are saved to: {output}")

def write_gaussian_style_log(vib, atoms, filename, charge=0, mult=1):
    """Write a minimal Gaussian 16-style .log file for viewing in GaussView.

    The file contains the following sections:

    - a route-section header;
    - blocks of up to three normal modes, each giving the frequency and the
      per-atom (x, y, z) displacement for every atom;
    - a forces table;
    - an archive-style entry holding the charge, multiplicity and Cartesian
      coordinates;
    - a fixed "Normal termination" line.

    Imaginary frequencies are written as negative numbers.

    Args:
        vib: ASE ``Vibrations`` object (this notebook calls it after
            ``vib.run()``); its frequencies and modes are written.
        atoms: ASE ``Atoms`` at the analysed geometry with a calculator
            attached; ``get_forces()`` is called on it.
        filename: Output .log path (overwritten).
        charge: Total charge written to the archive entry.
        mult: Spin multiplicity written to the archive entry.
    """
    # ASE forces are eV/Angstrom, but the table header below says
    # Hartrees/Bohr (Gaussian's unit), so convert before writing.
    # 27.2114 eV/Hartree matches the notebook's other conversions;
    # 0.52917721 Angstrom per Bohr.
    ev_per_ang_to_hartree_per_bohr = 0.52917721 / 27.2114

    freqs = vib.get_frequencies()
    natoms = len(atoms)
    numbers = atoms.get_atomic_numbers()
    modes = [vib.get_mode(i) for i in range(len(freqs))]
    forces = atoms.get_forces() * ev_per_ang_to_hartree_per_bohr

    def _signed_frequency(freq):
        # Imaginary (complex) frequencies become negative, Gaussian-style.
        if isinstance(freq, complex):
            return -freq.imag if abs(freq.imag) > 1e-6 else freq.real
        return freq

    # One block of up to three vibrational modes (cm^-1).
    def write_modes_block(f, start, end):
        f.write("\n{:>12}".format(""))
        for i in range(start, end):
            f.write("{:^24}".format(i + 1))
        f.write("\n{:>12}".format(""))
        for _ in range(start, end):
            f.write("{:^24}".format("A"))
        f.write("\n Frequencies --")
        for i in range(start, end):
            f.write(f"{_signed_frequency(freqs[i]):>16.4f}       ")
        f.write("\n  Atom  AN" + "      X      Y        Z " * (end - start) + "\n")

        # Per-atom displacement vectors, shown under Results > Vibrations.
        for a in range(natoms):
            f.write(f"{a+1:6d}{numbers[a]:4d}")
            for j in range(start, end):
                dx, dy, dz = modes[j][a]
                f.write(f"{dx:8.2f}{dy:8.2f}{dz:8.2f}")
            f.write("\n")

    # Archive-style coordinate entry, wrapped at 70 characters.
    def format_atoms(atoms, charge, mult):
        parts = [f"q\\\\t\\\\{charge},{mult}"]
        for atom in atoms:
            x, y, z = atom.position
            parts.append(f"{atom.symbol},{x:.10f},{y:.10f},{z:.10f}")
        full_string = "\\".join(parts) + "\\\\Version=Fujitsu-XTC-G16RevC.02\\"
        line_width = 70
        lines = [full_string[i:i + line_width] for i in range(0, len(full_string), line_width)]
        return "\n ".join(lines)

    with open(filename, "w") as f:
        f.write(" ----------------------------------------------------------------------\n")
        f.write(" #P \n")
        f.write(" ----------------------------------------------------------------------\n")
        f.write("\n  and normal coordinates:\n")

        for i in range(0, len(freqs), 3):
            write_modes_block(f, i, min(i + 3, len(freqs)))

        # Forces table (Hartree/Bohr after the conversion above).
        f.write("\n ***** Axes restored to original set *****\n")
        f.write(" -------------------------------------------------------------------\n")
        f.write(" Center     Atomic                   Forces (Hartrees/Bohr)\n")
        f.write(" Number     Number              X              Y              Z\n")
        f.write(" -------------------------------------------------------------------\n")
        for i, (num, force) in enumerate(zip(numbers, forces)):
            f.write(f"{i+1:10d}{num:10d}{force[0]:16.9f}{force[1]:14.9f}{force[2]:14.9f}\n")
        f.write("-------------------------------------------------------------------\n\n\n")
        f.write(" Test job not archived.\n")
        f.write(" 1\\1\\G\\OPT\\R\\C\\S\n")
        f.write(" 5\\0\\\\#P\n ")
        f.write(format_atoms(atoms, charge, mult) + "\n")
        # NOTE: the termination timestamp is a fixed literal, not the run time.
        f.write(" [X()]\\\\\\@\n\n Normal termination of Gaussian 16 at Sat Jun 28 18:15:40 2025.\n")

    print(f"[INFO] frequency log file compatible to GaussView is saved to: {filename}")

def generate_vibration_xyz(atoms, vib, mode_index, steps, scale, output):
    """
    Generate a multi-frame .xyz animation of one vibrational mode.

    The displacement factor runs 1/steps, ..., 1, (steps-1)/steps, ..., 0
    along +mode, then the same sequence along -mode. Every frame is
    ``original + factor * scale * mode``, giving 4 * steps frames in total.

    Parameters (all required; there are no defaults):
    - atoms: ASE Atoms object (reference geometry; not modified)
    - vib: ASE Vibrations object providing ``get_mode``
    - mode_index: Index of the vibration mode to animate
    - steps: Number of frames per quarter cycle (must be >= 1)
    - scale: Scaling factor applied to the mode displacement
    - output: Output .xyz file name

    Raises:
    - ValueError: if ``steps`` < 1 (the original silently produced no frames)
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    mode = np.array(vib.get_mode(mode_index))  # per-atom displacement vectors
    original_positions = atoms.get_positions()

    # Fractions of the full displacement for one sweep out and back.
    out_and_back = [(i + 1) / steps for i in range(steps)]
    out_and_back += [(steps - i - 1) / steps for i in range(steps)]

    images = []
    for sign in (+1, -1):  # +mode -> original, then -mode -> original
        for fraction in out_and_back:
            factor = sign * fraction
            frame = atoms.copy()  # copy() already yields an independent Atoms
            frame.set_positions(original_positions + factor * scale * mode)
            images.append(frame)

    write(output, images)
    print(f"[Info] Written {len(images)} frames to {output}")



# animation options
steps = 10  #@param {type:"integer", min:1}
scale = 1.0  #@param {type:"number", min:0.5}

# Number of lowest-index normal modes animated per maximum (modes 0, 1, 2).
# NOTE(review): presumably ASE lists imaginary modes first, so mode 0 is the
# reaction-coordinate mode of a saddle point — confirm for your ASE version.
N_ANIMATED_MODES = 3

# detect & save local maxima
peak_files = extract_peaks_from_traj("DMF_final.traj", "local_maxima.xyz", prominence=0.01)

if peak_files:
    # Loading the UMA predict unit is expensive and identical for every peak,
    # so do it once instead of once per maximum.
    predictor = pretrained_mlip.get_predict_unit(model_name, device="cuda")
else:
    print("[INFO] No local maxima detected; skipping vibrational analysis.")

for peak_file in peak_files:
    # output file names: <peak>.log (GaussView), <peak>.txt, <peak>_vib_<mode>.xyz
    base, _ = os.path.splitext(peak_file)
    log_filename = base + ".log"
    txt_filename = base + ".txt"
    atoms = read(peak_file)
    atoms.info["charge"] = charge
    atoms.info["spin"] = mult
    calc = FAIRChemCalculator(predictor, task_name="omol")
    atoms.calc = calc

    # Every peak shares the "vib_uma" cache name, so clear it before running
    # and always after (even on error): otherwise cached displacements from
    # one structure could be reused for the next.
    vib = Vibrations(atoms, name="vib_uma")
    vib.clean()
    try:
        vib.run()
        write_gaussian_style_log(vib, atoms, log_filename, charge, mult)
        write_gaussian_like_freq(vib, txt_filename)
        for mode in range(N_ANIMATED_MODES):
            vib_filename = f"{base}_vib_{mode}.xyz"
            generate_vibration_xyz(atoms, vib, mode, steps, scale, vib_filename)
    finally:
        vib.clean()




In [None]:
#@title Visualizing Imaginary Frequencies of All Local Maxima
import ipywidgets as widgets
from IPython.display import display
import py3Dmol
import glob
import re
from IPython.display import Markdown

def _vib_sort_key(path):
    """Sort key: mode index (last number in the name), then all numbers.

    The secondary key keeps the order of different peaks with the same mode
    deterministic; names without digits sort first instead of crashing.
    """
    numbers = [int(tok) for tok in re.findall(r'\d+', path)]
    return (numbers[-1] if numbers else -1, numbers)


def _make_draw_animation(xyz_data, out):
    """Return a slider callback bound to one file's frames and output area.

    A plain function defined inside the loop would look up ``xyz_data`` and
    ``out`` as globals at call time. Every slider would then redraw the
    LAST file into the LAST output area; binding them here avoids that.
    """
    def draw_animation(interval):
        with out:
            out.clear_output(wait=True)
            view = py3Dmol.view(width=600, height=600)
            view.addModelsAsFrames(xyz_data, 'xyz')
            view.setStyle({'sphere': {'scale': 0.12}, 'stick': {'radius': 0.08}})
            view.setBackgroundColor('white')

            view.animate({'loop': 'forward', 'step': 1, 'interval': interval})
            view.zoomTo()
            view.show()

    return draw_animation


vib_files = glob.glob("*_vib_*.xyz")
vib_files.sort(key=_vib_sort_key)

for vib_file in vib_files:

    # Read multi-frame xyz file
    with open(vib_file) as f:
        xyz_data = f.read()

    # Parse frame count (first line of an xyz frame is its atom count)
    lines = xyz_data.strip().split('\n')
    if not lines[0].strip():
        print(f"⚠️ Skipped empty file: {vib_file}")
        continue
    num_atoms = int(lines[0])
    frame_size = num_atoms + 2
    num_frames = len(lines) // frame_size

    print(f"Detected frames: {num_frames}")

    # Interval slider (1–500 ms)
    interval_slider = widgets.IntSlider(
        value=30, min=1, max=500, step=1, description='Interval (ms)', continuous_update=False
    )

    # Output area
    out = widgets.Output()

    # Animation rendering function bound to this file's data and output area
    draw_animation = _make_draw_animation(xyz_data, out)

    # Bind slider to rendering
    display(Markdown(f"**✅ Imaginary frequency of `{vib_file}`**"))
    interactive_plot = widgets.interactive_output(draw_animation, {'interval': interval_slider})

    # Display UI
    display(interval_slider, out)


In [None]:
#@title  (Option) Converting the Reaction Animation to GaussView Compatible File
from ase.io import read
import sys

# Conversion factor from eV (the unit of the trajectory energies) to Hartree,
# using 1 Hartree = 27.2114 eV.
EV_TO_HARTREE = 1 / 27.2114

# Banner written verbatim at the top of the GaussView-compatible log produced by
# convert_traj_to_log(). It names the method and lists the references to cite.
# This is runtime output, so the text (including its hard line wraps) is
# intentionally left as-is.
HEADER = """\
-----------------------------------------------------------------------
# DirectMaxFlux/UMA, Energy(wB97M-V/def2-TZVPD)
-----------------------------------------------------------------------

IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC

Copyright (c) 2025
Computational Biology Laboratory＠the University of Tokyo
(https://www.bi.a.u-tokyo.ac.jp/)　All Rights Reserved.


***********************************************************************
   18-June-2025  Created by Computational Biology Laboratory＠UTokyo
***********************************************************************


If you use this program in your research, please cite the following ref
erences.

1. Nakano, M.; Karasawa, M.; Ohmura, T.; Terada, T.; Sato, H. ChemRxiv
   2025. DOI: 10.26434/chemrxiv-2025-md8k6-v2
2. Koda, S.; Saito, S. Locating Transition States by Variational Reacti
   on Path Optimization with an Energy-Derivative-Free Objective Functi
   on. J. Chem. Theory Comput. 2024, 20 (7), 2798-2811.
3. Koda, S.; Saito, S. Flat-Bottom Elastic Network Model for Generating
   Improved Plausible Reaction Paths. Journal of Chemical Theory and Co
   mputation 2024, 20 (16), 7176-7187.
4. Koda, S.; Saito, S. Correlated Flat-Bottom Elastic Network Model for
   Improved Bond Rearrangement in Reaction Paths. J. Chem. Theory Compu
   t. 2025, 21 (7), 3513-3522.
5. Wood, B. M.; Dzamba, M.; Fu, X.; Gao, M.; Shuaibi, M.; Barroso-Luque
   , L.; Abdelmaqsoud, K.; Gharakhanyan, V.; Kitchin, J. R.; Levine, D.
   S.; et al. UMA: A Family of Universal Models for Atoms. arXiv prepri
   nt 2025, https://ai.meta.com/research/publications/uma-a-family-of-u
   niversal-models-for-atoms.
6. Levine, D. S.; Shuaibi, M.; Spotte-Smith, E. W. C.; Taylor, M. G.; H
   asyim, M. R.; Michel, K.; Batatia, I.; Csányi, G.; Dzamba, M.; Eastm
   an, P.; et al. The Open Molecules 2025 (OMol25) Dataset, Evaluations
   , and Models. arXiv preprint 2025, arXiv:2505.08762 [physics.chem-ph
   ].
7. fairchem; https://github.com/facebookresearch/fairchem
"""

def format_structure(atoms):
    """Render ``atoms`` as a Gaussian-style "Input orientation" table.

    The result is one newline-joined string with no trailing newline. It
    contains an IRC separator line, the table header, one tab-separated row
    per atom and a closing dashed rule. Each row gives the 1-based centre
    number, the atomic number, atomic type 0, and x/y/z to 10 decimals.
    """
    rule = "---------------------------------------------------------------------"
    table = [
        "IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC",
        "                         Input orientation:",
        rule,
        "Center     Atomic      Atomic             Coordinates (Angstroms)",
        "Number     Number       Type             X           Y           Z",
        rule,
    ]
    for center, atom in enumerate(atoms, start=1):
        x, y, z = atom.position
        table.append(f"{center:5d}\t{atom.number:<2d}\t0\t{x: .10f}\t{y: .10f}\t{z: .10f}")
    table.append(rule)
    return "\n".join(table)

def convert_traj_to_log(traj_file, output_log):
    """Convert an ASE trajectory into a pseudo-IRC Gaussian log for GaussView.

    After ``HEADER``, every frame is written as an "Input orientation" table
    followed by an "SCF Done" line with its energy in Hartree. Each frame
    after the first is preceded by a "Pt <n>" IRC point banner. The file then
    ends with one more point banner and "Normal termination of Gaussian".

    Frames whose energy cannot be obtained reuse the previous frame's value
    (0.0 if no earlier energy exists). Their indices are printed as a warning
    once the file is written.
    """
    irc_bar = "IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC-IRC"

    def point_banner(pt):
        # Separator, IRC point number and cumulative reaction coordinate.
        return (
            f"{irc_bar}\n"
            f"Pt {pt} Step number   1 out of a maximum of  1\n"
            f"NET REACTION COORDINATE UP TO THIS POINT = {float(pt):20.10f}\n\n"
        )

    frames = read(traj_file, ":")
    missing = []
    previous_hartree = 0.0

    with open(output_log, "w") as log:
        log.write(HEADER + "\n\n")

        for idx, frame in enumerate(frames):
            table = format_structure(frame)

            try:
                energy_ev = frame.get_potential_energy()
                if energy_ev is None:
                    raise ValueError("Energy is None")
                hartree = energy_ev * EV_TO_HARTREE
                previous_hartree = hartree
            except Exception:
                # Best effort: carry the last known energy forward.
                hartree = previous_hartree
                missing.append(idx)

            if idx:
                log.write(point_banner(idx - 1))
            log.write(table + "\n")
            log.write(f"SCF Done:  E(scf) =  {hartree: .10f}     A.U.\n\n")

        log.write(point_banner(len(frames) - 1))
        log.write(irc_bar + "\n")
        log.write("Normal termination of Gaussian\n")

    # Warning for empty energy
    if missing:
        print("⚠️ Warning: Energy not found at the following structure(s):")
        print("  " + ", ".join(f"#{i}" for i in missing))

if __name__ == "__main__":
    # Inside Jupyter/Colab this cell also runs with __name__ == "__main__",
    # but sys.argv holds the kernel launcher's arguments, not file names.
    # The old `len(sys.argv) != 3` check therefore passed or failed by
    # accident, and real CLI arguments were ignored. In a kernel, use the
    # notebook's DMF files; as a script, honour the two arguments.
    if "ipykernel" in sys.modules:
        traj_file = "DMF_final.traj"
        output_log = "DMF_final_gv.log"
    elif len(sys.argv) == 3:
        traj_file, output_log = sys.argv[1], sys.argv[2]
    else:
        print("Usage: python traj_to_gaussian_log.py input.traj output.log")
        sys.exit(1)

    convert_traj_to_log(traj_file, output_log)


In [None]:
#@title Export DMF Path as Multi-model PDB or SDF
from rdkit.Chem import AllChem, SDWriter, PDBWriter, MolToMolBlock, rdmolops
from ase.io import read

def load_rdkit_mol(input_file):
    """Load the uploaded reactant with RDKit to serve as a topology template.

    Args:
        input_file: Either this notebook's upload record (a dict with a
            ``'filename'`` key and optionally ``'original_format'``, e.g.
            ``'.pdb'``) or a plain file path. The old version ignored this
            argument for the format (it always read the global
            ``uploaded_reactant``). It also handed the dict itself to RDKit's
            file readers, which require a path string.

    Returns:
        An unsanitized RDKit ``Mol`` with hydrogens kept. Returns ``None``
        when the format is not .pdb/.mol2/.sdf or the file cannot be read.
    """
    if isinstance(input_file, dict):
        # NOTE(review): assumes 'filename' is the path of the uploaded file as
        # saved on disk — confirm against the upload cell.
        path = input_file.get('filename')
        ext = input_file.get('original_format') or os.path.splitext(path or '')[1]
    else:
        path = input_file
        ext = os.path.splitext(path)[1]

    # Normalise e.g. "PDB" / "pdb" / ".PDB" to ".pdb".
    ext = (ext or '').lower()
    if ext and not ext.startswith('.'):
        ext = '.' + ext

    if ext not in ('.pdb', '.mol2', '.sdf'):
        print(f"❌ Unsupported extension: {ext}")
        return None
    if not path:
        print("❌ No input file path available")
        return None

    try:
        if ext == '.pdb':
            mol = Chem.MolFromPDBFile(path, removeHs=False, sanitize=False)
        elif ext == '.mol2':
            mol = Chem.MolFromMol2File(path, removeHs=False, sanitize=False)
        else:  # '.sdf': use only the first record
            suppl = Chem.SDMolSupplier(path, removeHs=False, sanitize=False)
            mol = suppl[0] if suppl and len(suppl) > 0 else None
    except OSError as err:
        # RDKit reports unreadable/missing files as OSError; the caller treats
        # None as "skip export", so degrade gracefully instead of crashing.
        print(f"❌ Could not read {path}: {err}")
        mol = None

    return mol


def write_clean_multipdb(mol, frames, output_pdb):
    """Write each path frame as one MODEL of a multi-model PDB file.

    ``mol`` acts as the topology template. Each frame's coordinates replace
    its conformer, and RDKit renders the PDB block. REMARK, CONECT, MASTER
    and the per-block MODEL/ENDMDL/END records are stripped, so no bond
    records are written. One MODEL/ENDMDL pair per frame and one final END
    are written by hand.

    Args:
        mol: RDKit ``Mol`` whose atom order matches the frames. Its
            conformers are removed and replaced (it ends up holding the last
            frame's coordinates).
        frames: Sequence of ASE ``Atoms`` (e.g. ``read(..., index=':')``).
        output_pdb: Output PDB path (overwritten).

    Raises:
        ValueError: If any frame's atom count differs from ``mol``'s.
            Previously extra atoms were silently dropped, and missing atoms
            crashed after a partial file had been written.
    """
    natoms = mol.GetNumAtoms()
    for i, frame in enumerate(frames):
        if len(frame) != natoms:
            raise ValueError(
                f"Frame {i} has {len(frame)} atoms but the template molecule has {natoms}."
            )

    dropped_records = ("MODEL", "ENDMDL", "REMARK", "CONECT", "MASTER", "END")
    with open(output_pdb, 'w') as f:
        for i, frame in enumerate(frames):
            f.write(f"{'MODEL':<6}{' ':4}{i+1:>4}\n")

            # Replace coordinates with those from this frame
            conf = Chem.Conformer(natoms)
            for j in range(natoms):
                x, y, z = frame.positions[j]
                conf.SetAtomPosition(j, (x, y, z))
            mol.RemoveAllConformers()
            mol.AddConformer(conf)

            # Generate the PDB block and remove the unwanted records
            pdb_block = Chem.MolToPDBBlock(mol)
            clean_block = "\n".join(
                line for line in pdb_block.splitlines()
                if not line.startswith(dropped_records)
            )
            f.write(clean_block + "\n")
            f.write("ENDMDL\n")
        f.write("END\n")
    print(f"✅ Multi-model PDB saved: {output_pdb}")


def write_clean_sdf(mol, frames, output_sdf):
    """Write each path frame as a separate record of a multi-record SDF file.

    Every frame becomes a conformer of ``mol`` with id equal to its frame
    index. Each conformer is written as a mol block without kekulization,
    followed by ``$$$$``.

    Args:
        mol: RDKit ``Mol`` whose atom order matches the frames. Its existing
            conformers are removed and replaced by one per frame.
        frames: Sequence of ASE ``Atoms`` (e.g. ``read(..., index=':')``).
        output_sdf: Output SDF path (overwritten).

    Raises:
        ValueError: If any frame's atom count differs from ``mol``'s. This is
            checked before ``mol`` or the file is touched.
    """
    natoms = mol.GetNumAtoms()
    for i, frame in enumerate(frames):
        if len(frame) != natoms:
            raise ValueError(
                f"Frame {i} has {len(frame)} atoms but the template molecule has {natoms}."
            )

    mol.RemoveAllConformers()

    # Insert each frame as a conformer (id == frame index)
    for i, frame in enumerate(frames):
        conf = Chem.Conformer(natoms)
        for j in range(natoms):
            x, y, z = frame.positions[j]
            conf.SetAtomPosition(j, (x, y, z))
        conf.SetId(i)
        mol.AddConformer(conf)

    # Write each conformer as a separate entry (without kekulization)
    with open(output_sdf, 'w') as f:
        for i in range(mol.GetNumConformers()):
            f.write(MolToMolBlock(mol, confId=i, kekulize=False))
            f.write('$$$$\n')

    print(f"✅ Multi-model SDF saved: {output_sdf}")

# The uploaded reactant doubles as the topology template for the exported path;
# the DMF frames supply the coordinates.
print(f"template: {uploaded_reactant}")
input_file = uploaded_reactant
mol = load_rdkit_mol(input_file)
frames = read('DMF_final.xyz', index=':')

if mol is None:
    # Nothing RDKit could parse: skip the export entirely.
    print("ℹ️ Skipped export: mol is None (unsupported or failed to parse input file)")
else:
    # The output format follows the extension of the uploaded file name.
    basename, ext = os.path.splitext(input_file['filename'])
    suffix = ext.lower()

    if suffix in ('.pdb', '.mol2'):
        if suffix == '.mol2':
            # No MOL2 writer in RDKit, so MOL2 input is exported as PDB.
            print("⚠️ RDKit does not support writing MOL2 format. The structure was saved as a PDB file instead.")
        write_clean_multipdb(mol, frames, output_pdb="DMF_final.pdb")
    elif suffix == '.sdf':
        write_clean_sdf(mol, frames, output_sdf="DMF_final.sdf")
    else:
        print(f"❌ Unsupported extension: {ext}")


In [None]:
#@title Download results (zip file)
import zipfile
import re
import glob
from google.colab import files

# Name of the zip archive to be created
zip_name = "DMF_results.zip"

# Candidate output files. Explicit names come first, then glob matches. The
# patterns overlap (e.g. "*.txt" also matches energy_history.txt and
# force_history.txt), so duplicates are removed below. Otherwise zipfile
# would store the same member twice (and warn "Duplicate name").
candidates = (
    ["DMF_energy.csv", "energy_history.txt"]
    + sorted(glob.glob("DMF_*.traj"))
    + sorted(glob.glob("DMF_*.xyz"))
    + sorted(glob.glob("*.out"))
    + sorted(glob.glob("*.log"))
    + sorted(glob.glob("*.txt"))
    + ["reactant.xyz", "product.xyz"]
)

# Include any .pdb, .sdf, or .mol2 files in the current directory
allowed_exts = ('.pdb', '.sdf', '.mol2')
candidates += sorted(fname for fname in os.listdir() if fname.endswith(allowed_exts))

# De-duplicate while keeping first-seen order.
output_files = list(dict.fromkeys(candidates))

# Create a compressed zip archive (trajectory/text files compress well)
with zipfile.ZipFile(zip_name, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
    for f in output_files:
        if os.path.exists(f):
            zipf.write(f)
        else:
            print(f"⚠️ Skipped missing file: {f}")

# Provide download link
files.download(zip_name)


# 📦 Output Description
**`reactant.xyz` and `product.xyz`**

Coordinates of the reactant and product used for the DMF calculation, converted from the input files.

---

## 🗝️ Key Output Files

### `DMF_final`

* The final reaction pathway optimized by DMF/UMA.
* `.xyz`: Atomic coordinates in XYZ format.
* `.traj`: ASE trajectory file.
* `.pdb` or `.sdf`: Generated only when the reactant input file is in `.pdb`, `.sdf`, or `.mol2` format (`.mol2` input is exported as `.pdb`, because RDKit cannot write MOL2).
* `_gv.log`: GaussView-compatible file for reaction pathway animation.


### `DMF_energy.csv`

* Energies along the optimized reaction path by DMF/UMA.
* “Image” refers to a point along the path at which the energy was evaluated.


### `DMF_ipopt.out`

* Optimization log output by IPOPT during the DMF calculation.

---

## 📎 Supplementary Files

### `DMF_tmax`

* Structure at the highest-energy point (`tmax`) along the path, saved at each iteration.
* `.xyz`: XYZ format.
* `.traj`: ASE trajectory file.

### `timing.log`

* Log of the calculation time.

### `DMF_init`

* Initial path generated using Correlated Flat-Bottom Elastic Network Model (CFB-ENM).
* `.xyz`: XYZ format.
* `.traj`: ASE trajectory file.

### `fbenm_ipopt.out`

* Optimization log from IPOPT during the CFB-ENM initial path construction.

### `energy_history.txt`

* Energies at each iteration during optimization.

### `force_history.txt`

* Atomic forces at each iteration during optimization.

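### `local_maxima_<frame>`

One set of files is written for each energy maximum detected along the final DMF path; `<frame>` is the index of that structure in `DMF_final.traj`.

* `.log`: Vibrational-analysis file compatible with GaussView. You can view the normal-mode animations in GaussView via Results > Vibrations.
* `.txt`: Text summary of the vibrational analysis (frequencies, zero-point correction, and electronic energy in Hartree).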




