In [None]:
import deepdih
from rdkit import Chem
import os
import numpy as np
from tblite.ase import TBLite
import matplotlib.pyplot as plt
from dp_calculator import DPCalculator

In [None]:
# Load every input .mol file (explicit hydrogens kept) and cut the molecules into a
# fragment library that is written to fragments.sdf. Later cells address fragments by
# their index in that file, so files are visited in sorted order: os.listdir order is
# arbitrary and would make fragment numbering machine/run dependent.
molecules = []
for file in sorted(os.listdir("molecules")):
    if not file.endswith(".mol"):
        continue
    path = os.path.join("molecules", file)
    mol = Chem.MolFromMolFile(path, sanitize=True, removeHs=False)
    if mol is None:
        # RDKit signals parse/sanitization failure by returning None rather than raising.
        raise ValueError(f"RDKit failed to parse {path}")
    molecules.append(mol)
fragments = deepdih.mollib.create_lib(molecules)
deepdih.utils.write_sdf(fragments, "fragments.sdf")

In [None]:
# Scan every rotamer of each fragment with the DP model stored in "model.pt" (via
# DPCalculator) and save the scanned conformers, with energies re-evaluated by the same
# calculator, to fragments/fragment_<i>_dihedral_scan.sdf. These energies are used later
# as the reference ("QM") side of the fit.
# NOTE(review): an earlier comment said GFN2-xTB; this cell actually uses DPCalculator
# (TBLite is imported but not used here).
opt_calculator = DPCalculator("model.pt")

os.makedirs("fragments", exist_ok=True)

for n_frag, frag in enumerate(fragments):
    rotamers = deepdih.utils.get_rotamers(frag)
    dih_results = []
    for rot in rotamers:
        # NOTE(review): 12 is forwarded unchanged to dihedral_scan; presumably the number of
        # scan points per torsion -- confirm against deepdih.geomopt.dihedral_scan.
        dih_results.extend(deepdih.geomopt.dihedral_scan(frag, opt_calculator, rot, 12))
    # Re-evaluate each scanned conformer so the stored ENERGY comes from opt_calculator.
    recalc_confs = [deepdih.geomopt.recalc_energy(c, opt_calculator) for c in dih_results]
    deepdih.utils.write_sdf(recalc_confs, f"fragments/fragment_{n_frag}_dihedral_scan.sdf")

In [None]:
# Build a GROMACS topology for every fragment, relax each scanned conformer under that
# MM force field, and store the relaxed conformers with their MM energies. A comparison
# plot from plot_opt_results (relaxed MM vs. scan conformers) is saved alongside them.
fragments = deepdih.utils.read_sdf("fragments.sdf")

for out_dir in ("topologies", "mm_relax"):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

for nfrag, frag in enumerate(fragments):
    top_path = f"topologies/fragment_{nfrag}.top"
    deepdih.preparation.build_gmx_top(frag, top=top_path)
    calculator = deepdih.calculators.GromacsTopCalculator(frag, top_path)

    init_conformations = deepdih.utils.read_sdf(f"fragments/fragment_{nfrag}_dihedral_scan.sdf")
    relax_conformations = [deepdih.geomopt.relax_conformation(conf, calculator) for conf in init_conformations]
    recalc_conformations = [deepdih.geomopt.recalc_energy(conf, calculator) for conf in relax_conformations]

    deepdih.utils.write_sdf(recalc_conformations, f"mm_relax/fragment_{nfrag}_dihedral_scan.sdf")
    deepdih.geomopt.plot_opt_results(recalc_conformations, init_conformations, f"mm_relax/fragment_{nfrag}_opt.png")

In [None]:
# Build the fine-tuning data set. For each fragment, the target is the difference between
# the scan ("QM") energies and MM energies recomputed on the MM-relaxed geometries with
# the scanned proper torsions passed as turnoff_propers to the GROMACS calculator.
fragments = deepdih.utils.read_sdf("fragments.sdf")

# prepare training data
training_data = []
for nfrag, frag in enumerate(fragments):
    qm_conformations = deepdih.utils.read_sdf(f"fragments/fragment_{nfrag}_dihedral_scan.sdf")
    mm_conformations = deepdih.utils.read_sdf(f"mm_relax/fragment_{nfrag}_dihedral_scan.sdf")
    # The two files are paired index-by-index (mm_relax was produced by relaxing the scan
    # conformers in order). A length mismatch would otherwise mispair energies, or silently
    # broadcast when one side has a single entry.
    if not mm_conformations or len(qm_conformations) != len(mm_conformations):
        raise ValueError(
            f"fragment {nfrag}: {len(qm_conformations)} scan vs "
            f"{len(mm_conformations)} MM-relaxed conformers; re-run the scan/relax cells"
        )

    # Collect the torsions being fitted and disable those propers in the MM calculator.
    tmp_embedded_mol = deepdih.utils.TorEmbeddedMolecule(mm_conformations[0])
    torsions = [tor.torsion for tor in tmp_embedded_mol.torsions]
    calculator = deepdih.calculators.GromacsTopCalculator(frag, f"topologies/fragment_{nfrag}.top", turnoff_propers=torsions)
    recalc_conformations = [deepdih.geomopt.recalc_energy(c, calculator) for c in mm_conformations]

    mm_positions = [c.GetConformer().GetPositions() for c in mm_conformations]
    qm_energies = np.array([float(c.GetProp("ENERGY")) for c in qm_conformations])
    mm_energies = np.array([float(c.GetProp("ENERGY")) for c in recalc_conformations])
    # Only relative energies are fitted: centre each profile on its own mean.
    qm_energies = qm_energies - qm_energies.mean()
    mm_energies = mm_energies - mm_energies.mean()
    # NOTE(review): ENERGY props are assumed to be in Hartree; the conversion below goes
    # Hartree -> eV -> kJ/mol using deepdih's constants -- confirm the calculators' units.
    delta_energies = (qm_energies - mm_energies) / deepdih.utils.EV_TO_HARTREE * deepdih.utils.EV_TO_KJ_MOL
    embedded_mol = deepdih.utils.TorEmbeddedMolecule(mm_conformations[0], conf=mm_positions, target=delta_energies)
    training_data.append(embedded_mol)

In [None]:
# Fit the torsion parameters to the QM-minus-MM residuals (n_fold=3 is passed to the
# workflow, presumably 3-fold cross-validation) and persist them to params.pkl so the
# following cells can run without refitting.
import pickle

params = deepdih.finetune.finetune_workflow(training_data, n_fold=3)

with open("params.pkl", mode="wb") as params_file:
    pickle.dump(params, params_file)

In [None]:
# Load the fitted torsion parameters and write a tuned copy of every fragment topology
# (topologies/fragment_<i>.top -> top_tuned/fragment_<i>.top).
import pickle

fragments = deepdih.utils.read_sdf("fragments.sdf")
with open("params.pkl", "rb") as params_file:
    # Local file written by the fitting cell; never point this at untrusted pickles.
    params = pickle.load(params_file)

if not os.path.exists("top_tuned"):
    os.makedirs("top_tuned")

for nfrag, frag in enumerate(fragments):
    deepdih.finetune.update_gmx_top(
        frag,
        f"topologies/fragment_{nfrag}.top",
        params,
        f"top_tuned/fragment_{nfrag}.top",
    )

In [None]:
def _validate_topology(fragments, top_dir, out_dir, write_conformations=False):
    """Re-score the MM-relaxed conformers of every fragment with the topologies in ``top_dir``.

    For fragment ``i`` this loads ``{top_dir}/fragment_i.top``, recomputes the energies of
    the conformers in ``mm_relax/fragment_i_dihedral_scan.sdf``, compares them with the scan
    conformers in ``fragments/fragment_i_dihedral_scan.sdf`` via ``plot_opt_results``
    (figure saved to ``{out_dir}/fragment_i_opt.png``) and prints the returned R2 / RMSE.

    Args:
        fragments: fragments in the same order used to name the files on disk.
        top_dir: directory holding the ``fragment_<i>.top`` topologies to evaluate.
        out_dir: directory receiving the plots (and the SDFs when requested).
        write_conformations: if True, also write the re-scored conformers to
            ``{out_dir}/fragment_<i>_dihedral_scan.sdf``.
    """
    for nfrag, frag in enumerate(fragments):
        calculator = deepdih.calculators.GromacsTopCalculator(frag, f"{top_dir}/fragment_{nfrag}.top")
        init_conformations = deepdih.utils.read_sdf(f"fragments/fragment_{nfrag}_dihedral_scan.sdf")
        relax_conformations = deepdih.utils.read_sdf(f"mm_relax/fragment_{nfrag}_dihedral_scan.sdf")
        recalc_conformations = [deepdih.geomopt.recalc_energy(c, calculator) for c in relax_conformations]
        if write_conformations:
            deepdih.utils.write_sdf(recalc_conformations, f"{out_dir}/fragment_{nfrag}_dihedral_scan.sdf")
        r2, rmse = deepdih.geomopt.plot_opt_results(recalc_conformations, init_conformations, f"{out_dir}/fragment_{nfrag}_opt.png")
        print(f"Frag {nfrag} R2: {r2:.3f}, RMSE: {rmse:.3f}")


# Baseline: original (untuned) topologies. This rewrites mm_relax/fragment_<i>_opt.png.
_validate_topology(fragments, "topologies", "mm_relax")

# Tuned topologies produced by the fine-tuning step.
print("====== Valid ======")
os.makedirs("mm_valid", exist_ok=True)
_validate_topology(fragments, "top_tuned", "mm_valid", write_conformations=True)

In [None]:
# Patch the full input molecules: build a GROMACS topology for each molecules/<name>.mol
# (written next to it as molecules/<name>.top), then write a copy updated with the fitted
# torsion parameters to molecules_patched/<name>.top.
import pickle

os.makedirs("molecules_patched", exist_ok=True)

molecules = {}
for file in sorted(os.listdir("molecules")):
    if not file.endswith(".mol"):
        continue
    path = os.path.join("molecules", file)
    mol = Chem.MolFromMolFile(path, sanitize=True, removeHs=False)
    if mol is None:
        # RDKit signals parse/sanitization failure by returning None rather than raising.
        raise ValueError(f"RDKit failed to parse {path}")
    molecules[file] = mol

with open("params.pkl", "rb") as f:
    # Local file written by the fitting cell; never point this at untrusted pickles.
    params = pickle.load(f)

for file, mol in molecules.items():
    # splitext strips only the final extension; split(".")[0] would map "lig.v1.mol" and
    # "lig.v2.mol" to the same "lig.top" and overwrite one with the other.
    name = os.path.splitext(file)[0]
    inp_top = f"molecules/{name}.top"
    out_top = f"molecules_patched/{name}.top"
    deepdih.preparation.build_gmx_top(mol, top=inp_top)
    deepdih.finetune.update_gmx_top(mol, inp_top, params, out_top)