In [1]:
#@title Install Conda
%%capture
!pip install -q condacolab
import condacolab
condacolab.install()

In [None]:
#@title Install Dependencies
%%capture
!mamba install pdbfixer gradio=4.44.1 pydantic==2.10.6 rdkit openff-toolkit openmmforcefields
!mamba update pluggy

In [None]:
#@title Gradio Interface
import os, io, lzma, time
import parmed, pickle, shutil, zipfile
from openmm.app import Modeller, HBonds, Simulation, NoCutoff
from openmm import LangevinIntegrator, Platform, Context, System, CustomExternalForce
from openff.toolkit import Molecule
from openff.toolkit.utils.toolkits import ToolkitRegistry, RDKitToolkitWrapper
from openmm.unit import (nanometer, kelvin, picoseconds,
                         femtoseconds, bar, kilocalorie_per_mole, kilojoule_per_mole)
from openmmforcefields.generators import SystemGenerator
from pdbfixer import PDBFixer
from rdkit import Chem
from rdkit.Chem import Descriptors, QED
from concurrent.futures import ProcessPoolExecutor, as_completed

import pandas as pd
import numpy as np
import gradio as gr

from openmm import Platform

# Working directories live under the current working directory; both the
# upload (input) and result (output) folders are created eagerly on import.
working_dir = os.path.join(os.path.abspath(''), 'working_dir')
output_dir = os.path.join(working_dir, 'output_dir')
input_dir = os.path.join(working_dir, 'input_dir')
os.makedirs(input_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

# Toolkit registry restricted to the RDKit wrapper; it is passed to
# `assign_partial_charges` when preparing the ligand.
toolkit_registry = ToolkitRegistry([RDKitToolkitWrapper])

# Short descriptor key -> human-readable column name. The values are used as
# CSV header names and as keys of the per-ligand property dict, so the key set
# must stay in sync with `property_functions` below.
chem_prop_to_full_name_map = {'mw'  : 'Molecular Weight'        ,
                              'hbd' : 'Hydrogen Bond Donors', 'hba' : 'Hydrogen Bond Acceptors' ,
                              'logp': 'LogP'                , 'tpsa': 'Topological Polar Surface Area',
                              'rb'  : 'Rotatable Bonds'     , 'nor' : 'Number of Rings'         ,
                              'fc'  : 'Formal Charge'       , 'nha' : 'Number of Heavy Atoms'   ,
                              'mr'  : 'Molar Refractivity'  , 'na'  : 'Number of Atoms'         ,
                              'QED' : 'QED'}

# Short descriptor key -> callable taking an RDKit molecule and returning the
# descriptor value. Iteration order of this dict defines the order of the
# property columns written to the CSV.
property_functions = {'mw'  : Descriptors.MolWt,
                      'hbd' : Descriptors.NumHDonors,
                      'hba' : Descriptors.NumHAcceptors,
                      'logp': Descriptors.MolLogP,
                      'tpsa': Descriptors.TPSA,
                      'rb'  : Descriptors.NumRotatableBonds,
                      'nor' : lambda mol: mol.GetRingInfo().NumRings(),
                      'fc'  : lambda mol: sum([atom.GetFormalCharge() for atom in mol.GetAtoms()]),
                      'nha' : Descriptors.HeavyAtomCount,
                      'mr'  : Descriptors.MolMR,
                      'na'  : lambda mol: mol.GetNumAtoms(),
                      'QED' : QED.qed}

class ImplicitMinimizeComplex:
    """Energy-minimize a protein-ligand complex in implicit solvent.

    Calling an instance runs the whole pipeline (protein clean-up, ligand
    parameterisation, system creation, minimization) and yields
    ``(message, result)`` tuples: ``result`` is ``None`` for the first progress
    message and the tuple returned by :meth:`split_complex` for the last one.
    """

    def __init__(self, protein_pth: str, ligand_pth: str,
                 minimize_tolerance: float=None,
                 minimize_maxiter: int=0,
                 platform: str=None,
                 flex_bb: bool=False,
                 ph: float=7.0,
                 forcefield: str='openff_unconstrained-2.2.1.offxml'):
        """Store the run configuration and resolve the OpenMM platform.

        Args:
            protein_pth: Path to the protein PDB file.
            ligand_pth: Path to the ligand file readable by ``Molecule.from_file``.
            minimize_tolerance: Optional tolerance forwarded to ``minimizeEnergy``;
                ``None`` uses the OpenMM default.
            minimize_maxiter: Maximum minimizer iterations (passed through as is).
            platform: OpenMM platform name; ``None`` tries CUDA, then OpenCL,
                then CPU.
            flex_bb: When ``False`` the protein backbone atoms are restrained.
            ph: pH used when adding hydrogens to the protein.
            forcefield: Small-molecule force field name given to ``SystemGenerator``.
        """
        # Partial-charge method handed to `assign_partial_charges`
        # (alternatives previously tried: 'am1bcc', 'am1-mulliken', 'gasteiger').
        self.ligand_partial_charge = 'mmff94'
        self.forcefield_kwargs = {
            "constraints"   : None,
            'soluteDielectric': 1.0,
            'solventDielectric': 80.0,}

        self.protein_forcefield = ['amber14-all.xml', 'implicit/obc1.xml']
        self.ligand_forcefield = forcefield
        self.temperature = 298 * kelvin
        self.friction = 1 / picoseconds
        self.pressure = 1 * bar
        self.minimize_tolerance = minimize_tolerance
        self.minimize_maxiter = minimize_maxiter
        self.protein_pth = protein_pth
        self.ligand_pth = ligand_pth
        self.flex_bb = flex_bb
        self.ph = ph
        self.platform = self._pick_platform(platform)

    @staticmethod
    def _pick_platform(name):
        """Return the named platform, or the first available of CUDA/OpenCL/CPU."""
        if name is not None:
            return Platform.getPlatformByName(name)
        for candidate in ('CUDA', 'OpenCL'):
            try:
                return Platform.getPlatformByName(candidate)
            except Exception:
                # Platform not available in this OpenMM build; try the next one.
                continue
        # Last resort: let a failure here propagate to the caller.
        return Platform.getPlatformByName('CPU')

    @staticmethod
    def pdb_fix_and_cleanup(pdb_pth: str, ph: float):
        """Run PDBFixer on ``pdb_pth`` and return the fixer.

        Replaces non-standard residues, removes heterogens (waters included,
        since ``removeHeterogens(False)`` is used), adds missing atoms and adds
        hydrogens at the requested ``ph``.
        """
        fixer = PDBFixer(pdb_pth)
        fixer.findMissingResidues()
        fixer.findNonstandardResidues()
        fixer.replaceNonstandardResidues()
        fixer.removeHeterogens(False)
        fixer.findMissingAtoms()
        fixer.addMissingAtoms()
        # Previously called without arguments, which ignored `ph` entirely.
        fixer.addMissingHydrogens(ph)
        return fixer

    def setup_protein(self, protein_pth: str):
        """Clean the protein and wrap its topology/positions in a ``Modeller``."""
        protein = self.pdb_fix_and_cleanup(protein_pth, self.ph)
        return Modeller(protein.topology, protein.positions)

    def setup_ligand(self, ligand_pth: str):
        """Load the ligand and assign partial charges with the RDKit-only registry."""
        ligand = Molecule.from_file(ligand_pth, allow_undefined_stereo=True)
        ligand.assign_partial_charges(self.ligand_partial_charge, toolkit_registry=toolkit_registry)
        return ligand

    def setup_system_generator(self):
        """Build the ``SystemGenerator`` for the protein and ligand force fields."""
        return SystemGenerator(forcefields=self.protein_forcefield,
                               small_molecule_forcefield=self.ligand_forcefield,
                               forcefield_kwargs=self.forcefield_kwargs,
                               periodic_forcefield_kwargs={'nonbondedMethod': NoCutoff},
                               molecules=self.ligand)

    def constrain_backbone(self):
        """Add a harmonic positional restraint on the protein backbone atoms.

        Atoms named N, CA, C or O in residues not named ``UNK`` (the ligand
        residue name used throughout this file) are tethered to their current
        positions with a force constant of 1e5 kJ/mol/nm^2.
        """
        force = CustomExternalForce('0.5 * k * ((x - x0)^2 + (y - y0)^2 + (z - z0)^2)')
        force.addPerParticleParameter('x0')
        force.addPerParticleParameter('y0')
        force.addPerParticleParameter('z0')
        force.addGlobalParameter('k', 1e5*kilojoule_per_mole/nanometer**2)

        positions = self.modeller.positions
        backbone_names = ('N', 'CA', 'C', 'O')
        for atom in self.modeller.topology.atoms():
            if atom.residue.name != 'UNK' and atom.name in backbone_names:
                position = positions[atom.index]
                force.addParticle(atom.index, [position.x, position.y, position.z])

        self.system.addForce(force)

    def simulate_annealing(self, initial_temp=1000*kelvin, final_temp=298*kelvin,
                           total_steps=1000, steps_per_temp=10):
        """Step the simulation while linearly ramping the integrator temperature.

        The schedule has ``total_steps // steps_per_temp`` temperatures from
        ``initial_temp`` to ``final_temp``; ``steps_per_temp`` steps are run at each.
        """
        integrator = self.simulation.integrator
        num_temp_steps = total_steps // steps_per_temp
        temp_schedule = np.linspace(initial_temp.value_in_unit(kelvin),
                                    final_temp.value_in_unit(kelvin), num_temp_steps)
        for temp in temp_schedule:
            integrator.setTemperature(temp * kelvin)
            self.simulation.step(steps_per_temp)

    def setup_simulation(self):
        """Merge the ligand into the modeller, create the system and the simulation."""
        ligand_topology = self.ligand.to_topology()
        self.modeller.add(ligand_topology.to_openmm(), ligand_topology.get_positions().to_openmm())
        self.system: System = self.sys_generator.create_system(self.modeller.topology)
        if not self.flex_bb:
            self.constrain_backbone()
        integrator = LangevinIntegrator(self.temperature, self.friction, 1 * femtoseconds)
        self.simulation = Simulation(self.modeller.topology, self.system, integrator, self.platform)
        self.simulation.context.setPositions(self.modeller.positions)

    def minimize_energy(self):
        """Run OpenMM's local energy minimizer with the configured limits."""
        if self.minimize_tolerance is None:
            self.simulation.minimizeEnergy(maxIterations=self.minimize_maxiter)
        else:
            # NOTE(review): recent OpenMM releases document `tolerance` as a force
            # threshold (kJ/mol/nm); confirm an energy-unit quantity is accepted.
            self.simulation.minimizeEnergy(self.minimize_tolerance * kilojoule_per_mole,
                                           self.minimize_maxiter)

    def split_complex(self):
        """Return ``(complex, protein, ligand, ligand_molecule)`` after minimization.

        The first three are ParmEd structures built from the current positions
        (water and Na/Cl residues stripped); the protein/ligand split is made on
        the ``UNK`` residue name. The last item is the OpenFF ``Molecule``.
        """
        struct = parmed.openmm.load_topology(self.simulation.topology,
                                             self.system,
                                             self.simulation.context.getState(getPositions=True).getPositions())
        struct.strip(':HOH,NA,CL')
        return struct, struct['!:UNK'], struct[':UNK'], self.ligand

    def _current_energy(self):
        """Potential energy of the current context state in kJ/mol."""
        state = self.simulation.context.getState(getEnergy=True)
        return state.getPotentialEnergy().value_in_unit(kilojoule_per_mole)

    def __call__(self):
        """Run the pipeline, yielding ``(progress message, result or None)``."""
        self.modeller = self.setup_protein(self.protein_pth)
        self.ligand = self.setup_ligand(self.ligand_pth)
        self.sys_generator = self.setup_system_generator()
        self.setup_simulation()
        curr_eng = self._current_energy()
        yield f'Simulation setup, current energy: {curr_eng:.4f} kJ/mol', None
        self.minimize_energy()
        curr_eng = self._current_energy()
        yield f'Energy minimized, current energy: {curr_eng:.4f} kJ/mol', self.split_complex()

class CalculateBindingEnergy:
    """Single-point interaction energy of a complex in implicit solvent.

    The energy reported is E(complex) - E(protein) - E(ligand), each evaluated
    on the supplied coordinates, in kcal/mol.
    """

    def __init__(self,
                 complex_struct: parmed.structure.Structure,
                 protein_struct: parmed.structure.Structure,
                 ligand_struct : parmed.structure.Structure,
                 ligand: Molecule):
        """Keep the three structures and prepare the system generator.

        Args:
            complex_struct: Structure holding protein and ligand together.
            protein_struct: Protein-only structure.
            ligand_struct: Ligand-only structure.
            ligand: OpenFF molecule registered with the system generator.
        """
        self.complex = complex_struct
        self.protein = protein_struct
        self.ligand  = ligand_struct
        self.forcefield_kwargs = dict(constraints=None,
                                      soluteDielectric=1.0,
                                      solventDielectric=80.0)
        generator_options = dict(
            forcefields=['amber14-all.xml', 'implicit/obc1.xml'],
            small_molecule_forcefield='openff_unconstrained-2.2.1.offxml',
            molecules=[ligand],
            forcefield_kwargs=self.forcefield_kwargs,
            periodic_forcefield_kwargs={'nonbondedMethod': NoCutoff},
        )
        self.implicit_solvent_system_generator = SystemGenerator(**generator_options)

    def retrieve_potential_energy(self, struct: parmed.structure.Structure) -> float:
        """Return the potential energy of ``struct`` in kcal/mol."""
        omm_system = self.implicit_solvent_system_generator.create_system(struct.topology)
        # The integrator is only needed to construct a Context; no steps are taken.
        ctx = Context(omm_system, LangevinIntegrator(298 * kelvin, 1 / picoseconds, 2 * femtoseconds))
        ctx.setPositions(struct.positions)
        potential = ctx.getState(getEnergy=True).getPotentialEnergy()
        energy_kcal = potential.value_in_unit(kilocalorie_per_mole)
        del ctx
        return energy_kcal

    def calculate_binding_energy(self) -> float:
        """Return E(complex) - E(protein) - E(ligand) in kcal/mol."""
        receptor_energy = self.retrieve_potential_energy(self.protein)
        small_mol_energy = self.retrieve_potential_energy(self.ligand)
        bound_energy = self.retrieve_potential_energy(self.complex)
        return bound_energy - receptor_energy - small_mol_energy

def process_protein_ligand_to_dict(protein, ligand, ligand_mol):
    """Build the merged complex PDB text and the ligand descriptor table.

    Args:
        protein: Structure with a ``save(file, format='pdb')`` method (protein part).
        ligand: Structure with the same ``save`` method (ligand part).
        ligand_mol: Molecule exposing ``to_rdkit()``; its atom coordinates are
            overwritten with those read back from the ligand PDB text.

    Returns:
        ``(output_dict, prop)`` where ``output_dict`` has the keys ``'complex'``
        (PDB text) and ``'rdmol'`` (hydrogen-stripped RDKit molecule) and
        ``prop`` maps descriptor display names to stringified values.
    """
    protein_buffer, ligand_buffer = io.StringIO(), io.StringIO()
    protein.save(protein_buffer, format='pdb')
    ligand.save(ligand_buffer, format='pdb')

    # Copy the coordinates from the PDB text onto the RDKit molecule, atom by atom.
    rd_ligand = ligand_mol.to_rdkit()
    target_conf = rd_ligand.GetConformer()
    mol_from_pdb = Chem.MolFromPDBBlock(ligand_buffer.getvalue(), removeHs=False)
    source_conf = mol_from_pdb.GetConformer()
    for atom_idx in range(rd_ligand.GetNumAtoms()):
        target_conf.SetAtomPosition(atom_idx, source_conf.GetAtomPosition(atom_idx))

    # Descriptors are computed while explicit hydrogens are still present.
    prop = {}
    for key, descriptor in property_functions.items():
        prop[chem_prop_to_full_name_map[key]] = str(descriptor(rd_ligand))
    rd_ligand = Chem.RemoveHs(rd_ligand)

    ligand_block = Chem.MolToPDBBlock(mol_from_pdb).replace(' UNK ', ' UNL ')
    ligand_lines = ligand_block.strip().split('\n')
    # Drop the final line of the protein text so ligand records can be appended.
    merged_lines = protein_buffer.getvalue().strip().split('\n')[:-1]
    # Serial number found in columns 7-11 of the last kept protein line.
    serial_offset = int(merged_lines[-1][6:11])

    def shift_serials(record):
        """Offset atom serials in HETATM/CONECT records; pass others through."""
        if record.startswith('HETATM'):
            shifted = int(record[6:11]) + serial_offset
            return record[:6] + f'{shifted:>5}' + record[11:]
        if record.startswith('CONECT'):
            shifted_fields = []
            for start in range(6, 27, 5):
                field = record[start:start+5].strip()
                if not field:
                    break
                shifted_fields.append(f'{int(field) + serial_offset:>5}')
            return 'CONECT' + ''.join(shifted_fields)
        return record

    merged_lines.extend(shift_serials(record) for record in ligand_lines)
    output_dict = {'complex': '\n'.join(merged_lines),
                   'rdmol': rd_ligand}
    return output_dict, prop

def save_dict_to_mdm(output_dict: dict, pth: str):
    """Pickle ``output_dict`` into an LZMA-compressed file at ``pth``."""
    with lzma.LZMAFile(pth, mode='wb') as archive:
        pickle.dump(output_dict, archive)

def _read_old_score(ligand_pth):
    """Return the docking score stored in the ligand SDF, or NaN if unavailable.

    The value is taken from the line following the first ``VINA Energy`` or
    ``Old Score`` data header. A missing header or an unparseable value gives
    ``float('nan')`` instead of raising.
    """
    score_headers = ('>  <VINA Energy>', '>  <Old Score>')
    with open(ligand_pth) as f:
        read_next = False
        for line in f:
            if read_next:
                try:
                    return float(line)
                except ValueError:
                    return float('nan')
            if line.startswith(score_headers):
                read_next = True
    return float('nan')


def single_minimize_complex(name, complex_pth, out_dir, csv_pth, ph, forcefield, platform_name):
    """Minimize one complex, score it, and append a row to the results CSV.

    Args:
        name: Ligand name; the ligand is read from ``<complex_pth>/<name>.sdf``.
        complex_pth: Directory containing ``protein.pdb`` and the ligand SDF.
        out_dir: Directory receiving ``<name>_output.mdm``.
        csv_pth: CSV file that gets one appended row per call.
        ph: pH forwarded to ``ImplicitMinimizeComplex``.
        forcefield: Small-molecule force field name.
        platform_name: OpenMM platform name.

    Returns:
        ``(log_text, row)`` where ``row`` is ``{'Name': [...], 'Minimized Energy': [...]}``
        (energy is NaN on failure) or ``{}`` when interrupted.
    """
    output_mdm = os.path.join(out_dir, f'{name}_output.mdm')
    protein_pth = os.path.join(complex_pth, 'protein.pdb')
    ligand_pth = os.path.join(complex_pth, f'{name}.sdf')
    passed_str = ''
    new_row = {}
    # Default so the failure handler can always write a row, even when reading
    # the ligand file itself is what failed.
    old_eng = float('nan')
    tik = time.perf_counter()
    try:
        old_eng = _read_old_score(ligand_pth)
        minimize_complex = ImplicitMinimizeComplex(protein_pth, ligand_pth,
                                                   platform=platform_name,
                                                   ph=ph,
                                                   forcefield=forcefield)
        for message, result in minimize_complex():
            passed_str += f'{message}\n'
            if result is not None:
                complex_struct, protein, ligand, ligand_mol = result
                output_dict, prop = process_protein_ligand_to_dict(protein, ligand, ligand_mol)
                calculator = CalculateBindingEnergy(complex_struct, protein, ligand, ligand_mol)
                binding_energy = calculator.calculate_binding_energy()
                output_dict['binding_energy'] = binding_energy
                output_dict['old_score'] = old_eng
                output_dict.update(prop)
                new_row = {'Name': [name], 'Minimized Energy': [binding_energy]}
                save_dict_to_mdm(output_dict, output_mdm)
                passed_str += f'{name} Binding Energy: {binding_energy:.4f} kcal/mol\n'
                tok = time.perf_counter()-tik
                with open(csv_pth, 'a') as f:
                    f.write(f'{name},{binding_energy},{old_eng},{",".join(v for v in prop.values())}\n')
                passed_str += f'Minimization took {tok:.4f} seconds.\n'
        return passed_str, new_row
    except KeyboardInterrupt:
        # The original appended to an undefined name here and raised NameError.
        passed_str += 'Minimization interrupted by user.\n'
        return passed_str, {}
    except Exception as e:
        passed_str += f'{e}\n'
        new_row = {'Name': [name], 'Minimized Energy': [float('nan')]}
        # One empty field per descriptor column keeps the CSV rectangular.
        empty = (len(property_functions) - 1) * ','
        with open(csv_pth, 'a') as f:
            f.write(f'{name},,{old_eng},{empty}\n')
        return passed_str, new_row

def recursive_rm_file(parent_dir: str):
    """Delete every dot-prefixed file below ``parent_dir``.

    Directories (hidden or not) are descended into but never removed.
    """
    for entry_name in os.listdir(parent_dir):
        entry_path = os.path.join(parent_dir, entry_name)
        if os.path.isdir(entry_path):
            recursive_rm_file(entry_path)
            continue
        if entry_name.startswith('.'):
            os.remove(entry_path)

class OpenMMSimulationInterface:
    """Gradio front end for batch minimization of uploaded complexes."""

    def __init__(self):
        # Run-state flags; `executor` holds the pool of the run in progress.
        self.is_minimizing = False
        self.stop_minimizing = False
        self.executor = None

    def start_interface(self):
        """Build the Gradio layout, wire the callbacks and launch the app."""
        self.is_minimizing = False
        self.stop_minimizing = False
        input_dir = os.path.join(working_dir, 'input_dir')
        with gr.Blocks(css='footer{display:none !important}') as Interface:
            gr.Markdown('<span style="font-size:25px; font-weight:bold; ">OpenMM Minimization</span>')
            with gr.Row():
                with gr.Column(scale=1):
                    minimize_input_stat = gr.Textbox(label='Input OpenMM Minimization Format ZIP')
                    minimize_input = gr.File(label='Target File',
                                             file_count='single',
                                             file_types=['.zip'])
                with gr.Column(scale=1):
                    output_stat = gr.Textbox(label='Minimized ZIP File')
                    output_input = gr.File(label='Minimized File',
                                           file_count='single',
                                           file_types=['.zip'])
            with gr.Row():
                ph_value = gr.Number(value=7.0,
                                     label='pH to add hydrogen',
                                     step=0.1)
                concurrent_num = gr.Number(value=2,
                                           label='Concurrent Num.')
                forcefield_type = gr.Dropdown(['openff_unconstrained-2.2.1.offxml', 'gaff-2.11'],
                                              value='openff_unconstrained-2.2.1.offxml',
                                              label='Force Field')
            with gr.Row():
                dock_progress = gr.Text(label='Progress', interactive=False, scale=4)
                with gr.Column():
                    # Enabled only when the input folder already holds something.
                    minimize_button = gr.Button('Minimize',
                                                interactive=bool(os.listdir(input_dir)))
                    zip_checkbox = gr.Checkbox(label='Zip Result')
                    zip_progress = gr.Textbox(label='Progress', interactive=False)
                    zip_name = gr.Textbox(label='Zip File Name',
                                          placeholder='docked_result',
                                          interactive=True)
                    zipped_file_output = gr.File(label='Zipped file',
                                                 file_count='single',
                                                 file_types=['.zip'],)
            energy_df = gr.DataFrame(value=None,
                                     headers=['Name', 'Minimized Energy'],
                                     interactive=False,)
            minimize_input.change(self.upload_target_files,
                                  inputs=minimize_input,
                                  outputs=[minimize_input_stat, minimize_button])
            output_input.change(self.upload_docked_ligand,
                                inputs=output_input,
                                outputs=[output_stat, energy_df])
            minimize_button.click(self.start_minimizing,
                                  inputs=[ph_value, concurrent_num, forcefield_type],
                                  outputs=[dock_progress, energy_df])
            # Polled callback: zips the results whenever the checkbox is ticked.
            Interface.load(self.zip_docked_files, [zip_checkbox, zip_name], [zipped_file_output, zip_checkbox, zip_progress], every=0.5)

            Interface.queue().launch(share=True, debug=True)

    @staticmethod
    def _extract_and_clean(file, target_dir):
        """Extract ``file`` into ``target_dir`` and drop top-level hidden entries."""
        with zipfile.ZipFile(file, 'r') as zip_f:
            zip_f.extractall(target_dir)
        for f_name in os.listdir(target_dir):
            if not f_name.startswith('.'):
                continue
            hidden = os.path.join(target_dir, f_name)
            # os.remove fails on directories, so hidden folders need rmtree.
            if os.path.isdir(hidden) and not os.path.islink(hidden):
                shutil.rmtree(hidden)
            else:
                os.remove(hidden)

    @staticmethod
    def _reset_dir(target_dir):
        """Empty ``target_dir`` by recreating it (tolerates a missing directory)."""
        shutil.rmtree(target_dir, ignore_errors=True)
        os.makedirs(target_dir, exist_ok=True)

    def upload_target_files(self, file):
        """Handle the input ZIP widget.

        A file extracts into the input folder and enables the button; clearing
        the widget empties the folder and disables the button.
        """
        input_dir = os.path.join(working_dir, 'input_dir')
        if file:
            self._extract_and_clean(file, input_dir)
            return f'{len(os.listdir(input_dir))} target complexes uploaded.', gr.update(interactive=True)
        self._reset_dir(input_dir)
        return 'Files removed.', gr.update(interactive=False)

    def upload_docked_ligand(self, file):
        """Handle the previously-minimized ZIP widget.

        A file extracts into the output folder and, when it contains
        ``minimize.csv``, returns its Name / Minimized Energy columns for display.
        Clearing the widget empties the output folder.
        """
        output_dir = os.path.join(working_dir, 'output_dir')
        if file:
            self._extract_and_clean(file, output_dir)
            minimize_csv = os.path.join(output_dir, 'minimize.csv')
            if os.path.isfile(minimize_csv):
                df = pd.read_csv(minimize_csv)[['Name', 'Minimized Energy']]
            else:
                df = None
            return f'{len(os.listdir(output_dir))} minimized complexes uploaded.', df
        self._reset_dir(output_dir)
        return 'Files removed.', None

    def process_protein_ligand_to_dict(self, protein, ligand, ligand_mol):
        """Delegate to the module-level ``process_protein_ligand_to_dict``."""
        # Inside a method the bare name resolves to the module-level function.
        return process_protein_ligand_to_dict(protein, ligand, ligand_mol)

    def save_dict_to_mdm(self, output_dict: dict, pth: str):
        """Delegate to the module-level ``save_dict_to_mdm``."""
        return save_dict_to_mdm(output_dict, pth)

    def zip_docked_files(self, check_status, zipped_name):
        """Zip the output folder when the checkbox is ticked, streaming progress.

        Yields ``(file update, checkbox update, progress text)``; the first
        yield unticks the checkbox and the last one returns the archive path.
        """
        if not check_status:
            yield gr.update(), gr.update(), gr.update()
            return

        if not zipped_name:
            zipped_name = 'docked_result'
        zipped_file = os.path.join(working_dir, zipped_name + '.zip')
        output_dir = os.path.join(working_dir, 'output_dir')
        all_files = [os.path.join(root, file)
                     for root, _dirs, files in os.walk(output_dir)
                     for file in files]
        file_cnt = len(all_files)
        width = len(str(file_cnt))
        yield gr.update(), False, f'Progress ({0:{width}}/{file_cnt})'
        with zipfile.ZipFile(zipped_file, 'w', zipfile.ZIP_LZMA) as zipf:
            # `all_files` already holds full paths; no re-joining with a walk root.
            for num, file_path in enumerate(all_files, start=1):
                zipf.write(file_path, os.path.relpath(file_path, output_dir))
                yield gr.update(), gr.update(), f'Progress ({num:{width}}/{file_cnt})'
        yield zipped_file, gr.update(), 'Zipping Done'

    @staticmethod
    def _detect_platform_name():
        """Return 'CUDA' or 'OpenCL' when OpenMM provides it, otherwise 'CPU'."""
        for candidate in ('CUDA', 'OpenCL'):
            try:
                Platform.getPlatformByName(candidate)
                return candidate
            except Exception:
                continue
        return 'CPU'

    def start_minimizing(self, protein_ph: float, concurrent_num: int, forcefield: str):
        """Minimize every not-yet-processed complex, streaming log and table.

        This is a generator yielding ``(log text, dataframe update)``. Calling
        it while a run is active requests a stop: queued jobs are cancelled and
        the active run reports ``User stop`` once the running jobs finish.

        Args:
            protein_ph: pH forwarded to each worker.
            concurrent_num: Number of worker processes (coerced to an int >= 1).
            forcefield: Small-molecule force field name forwarded to each worker.
        """
        if self.is_minimizing:
            self.stop_minimizing = True
            if self.executor is not None:
                self.executor.shutdown(wait=False, cancel_futures=True)
            # Leave the widgets untouched; the active run owns the log text.
            yield gr.update(), gr.update()
            return

        self.is_minimizing = True
        self.stop_minimizing = False
        try:
            input_dir = os.path.join(working_dir, 'input_dir')
            output_dir = os.path.join(working_dir, 'output_dir')
            minimize_csv = os.path.join(output_dir, 'minimize.csv')
            if not os.path.isfile(minimize_csv):
                with open(minimize_csv, 'w') as f:
                    f.write('Name,Minimized Energy,Old Score,'+','.join(list(chem_prop_to_full_name_map.values()))+'\n')
            df = pd.read_csv(minimize_csv)[['Name', 'Minimized Energy']]
            calculated_ids = df['Name'].to_list()
            all_ligands = {f: os.path.join(input_dir, f) for f in os.listdir(input_dir)
                           if f not in calculated_ids and not f.startswith('.')}
            platform_name = self._detect_platform_name()
            all_strings = f'Using {platform_name}\n'
            yield all_strings, gr.update(value=df)

            # The number widget may deliver a float; the pool needs a positive int.
            workers = max(1, int(concurrent_num or 1))
            with ProcessPoolExecutor(workers) as self.executor:
                futures = [self.executor.submit(single_minimize_complex, name, pth, output_dir,
                                                minimize_csv, protein_ph, forcefield, platform_name)
                           for name, pth in all_ligands.items()]
                for future in as_completed(futures):
                    if future.cancelled():
                        # Dropped by a stop request; there is no result to read.
                        continue
                    try:
                        result_str, df_row_dict = future.result()
                    except Exception as exc:
                        # A crashed worker must not abort the remaining jobs.
                        result_str, df_row_dict = f'{exc}\n', {}
                    all_strings += result_str
                    if df_row_dict:
                        if df.empty:
                            df = pd.DataFrame(df_row_dict)
                        else:
                            df = pd.concat([df, pd.DataFrame(df_row_dict)], ignore_index=True)
                    all_strings += '-'*60 + '\n'
                    yield all_strings, gr.update(value=df)
            all_strings += 'User stop\n' if self.stop_minimizing else 'Minimization Done'
            # Must be a yield: a generator's return value never reaches the UI.
            yield all_strings, gr.update(value=df)
        finally:
            self.is_minimizing = False
            self.executor = None

# Entry point: build the Gradio interface and launch it (blocks while serving
# because `launch` is called with debug=True).
if __name__ == '__main__':
    interface = OpenMMSimulationInterface()
    interface.start_interface()