In [None]:
import re
import sys

import yaml
from openeye import oechem, oespruce

sys.path.append("../")
from asapdiscovery.data import pdb
from asapdiscovery.data.fragalysis import parse_xtal
from asapdiscovery.data.openeye import (
    load_openeye_pdb,
    save_openeye_pdb,
    split_openeye_mol,
)
from asapdiscovery.data.schema import CrystalCompoundData
from asapdiscovery.data.utils import edit_pdb_file, seqres_to_res_list
from asapdiscovery.docking.modeling import (
    align_receptor,
    du_to_complex,
    mutate_residues,
    prep_receptor,
    remove_extra_ligands,
)

In [None]:
## Input locations: the directory of aligned structures and the metadata CSV that
## describes the single test structure used throughout this notebook.
frag_dir = "/lila/data/chodera/asap-datasets/mpro_fragalysis_2022_10_12/aligned"
main_dir = "/lila/data/chodera/kaminowb/stereochemistry_pred/mers/"
frag_fn = f"{main_dir}/test_sequence_metadata_xtal.csv"
xtal_compounds = parse_xtal(frag_fn, frag_dir)
## Everything below operates on exactly one structure.
assert (
    len(xtal_compounds) == 1
), f"Expected exactly one crystal structure, found {len(xtal_compounds)}"
xtal = xtal_compounds[0]

## From prep_proteins.py
## Pull the chain tag (one digit followed by one capital letter, e.g. "0A") out of
## the structure path. The dataset name is escaped so that any regex
## metacharacters it contains are matched literally.
re_pat = rf"/{re.escape(xtal.dataset)}_([0-9][A-Z])/"
chain_match = re.search(re_pat, xtal.str_fn)
if chain_match is None:
    ## No chain tag in the path: report it and fall back to "0A"
    print(
        f"Regex chain search failed: {re_pat}, {xtal.str_fn}.",
        "Using A as default.",
        flush=True,
    )
    frag_chain = "0A"
else:
    frag_chain = chain_match.group(1)
xtal.output_name = f"{xtal.dataset}_{frag_chain}_{xtal.compound_id}"

## We also want the chain in the form of a single letter ('A', 'B'), etc
xtal.active_site_chain = frag_chain[-1]

In [None]:
## Show the parsed crystal compound record, after output_name and
## active_site_chain were assigned in the previous cell.
print(xtal, flush=True)

In [None]:
## Read the reference SEQRES block from the metadata YAML file
seqres_fn = "../metadata/mpro_sars2_seqres.yaml"
with open(seqres_fn) as seqres_file:
    seqres_dict = yaml.safe_load(seqres_file)

## Convert the SEQRES text into a list of residue names, then overwrite the
## entry at index 3 (the fourth residue) with GLY.
## NOTE(review): the reason for this substitution is not recorded here -- confirm.
seqres_list = seqres_to_res_list(seqres_dict["SEQRES"])
seqres_list[3] = "GLY"

## Space-separated form of the edited sequence, used for the Spruce metadata
seqres = " ".join(seqres_list)

## Show the list, the joined string and the number of residues
for shown in (seqres_list, seqres, len(seqres_list)):
    print(shown, flush=True)

In [None]:
## Load the crystal structure and keep the protein component of the split
initial_complex = load_openeye_pdb(xtal.str_fn)
initial_split = split_openeye_mol(initial_complex)
initial_prot = initial_split["pro"]

## Residue names of chain A, in the order OEGetResidues yields them.
## (Additional filters excluding residues named DMS or LIG are left disabled.)
prot_seq = []
for residue in oechem.OEGetResidues(initial_prot):
    if residue.GetExtChainID() != "A":
        continue
    prot_seq.append(residue.GetName())

print(prot_seq)
print(len(prot_seq))

In [None]:
## Structure metadata that carries one sequence entry per protein chain
metadata = oespruce.OEStructureMetadata()

## Distinct chain IDs present in the protein
all_prot_chains = set(
    residue.GetExtChainID() for residue in oechem.OEGetResidues(initial_prot)
)

## Every chain is given the same (edited) SEQRES sequence
for chain_id in all_prot_chains:
    chain_entry = oespruce.OESequenceMetadata()
    chain_entry.SetChainID(chain_id)
    chain_entry.SetSequence(seqres)
    metadata.AddSequenceMetadata(chain_entry)

## Echo back what was registered: chain ID, then its sequence
for registered in metadata.GetSequenceMetadata():
    print(registered.GetChainID(), registered.GetSequence(), sep="\n")

In [None]:
## Loop database file passed to the loop builder when the design-unit options are
## configured in the next cell; setting this to None skips that step.
## NOTE(review): this is an absolute path into one user's home directory, so the
## notebook will not run unchanged elsewhere -- consider making it configurable.
loop_db = "/lila/home/kaminowb/.openeye/rcsb_spruce.loop_db"

In [None]:
## Set up DU building options
opts = oespruce.OEMakeDesignUnitOptions()
opts.SetSuperpose(False)

## Fetch each nested options object once rather than re-walking the accessor
## chain for every individual setting.
prep_opts = opts.GetPrepOptions()
build_opts = prep_opts.GetBuildOptions()
loop_opts = build_opts.GetLoopBuilderOptions()

## Only point the loop builder at a database when one was given
if loop_db is not None:
    loop_opts.SetLoopDBFilename(loop_db)

## Options set from John's function ########################################
## (https://github.com/FoldingAtHome/covid-moonshot/blob/454098f4255467f4655102e0330ebf9da0d09ccb/synthetic-enumeration/sprint-14-quinolones/00-prep-receptor.py)
prep_opts.SetStrictProtonationMode(True)
# set minimal number of ligand atoms to 5, e.g. a 5-membered ring fragment
opts.GetSplitOptions().SetMinLigAtoms(5)

# also consider alternate locations outside binding pocket, important for later filtering
prep_opts.GetEnumerateSitesOptions().SetCollapseNonSiteAlts(False)

# alignment options, only matches are important
loop_opts.SetSeqAlignMethod(oechem.OESeqAlignmentMethod_Identity)
loop_opts.SetSeqAlignGapPenalty(-1)
loop_opts.SetSeqAlignExtendPenalty(0)

# Both N- and C-termini should be zwitterionic
# Mpro cleaves its own N- and C-termini
# See https://www.pnas.org/content/113/46/12997
build_opts.SetCapNTermini(False)
build_opts.SetCapCTermini(False)
# Don't allow truncation of termini, since force fields don't have
#  parameters for this
build_opts.GetCapBuilderOptions().SetAllowTruncate(False)
# Build loops and sidechains
build_opts.SetBuildLoops(True)
build_opts.SetBuildSidechains(True)

# Generate ligand tautomers
prep_opts.GetProtonateOptions().SetGenerateTautomers(True)

## Allow spruce to build tails at the ends of the protein
loop_opts.SetBuildTails(True)

In [None]:
## Reassemble a protein + ligand molecule from copies of the split components,
## so the originals in initial_prot / initial_split are not modified.
initial_bound = initial_prot.CreateCopy()
ligand_copy = initial_split["lig"].CreateCopy()
oechem.OEAddMols(initial_bound, ligand_copy)

## Build design units using the sequence metadata and options set up above
dus = oespruce.OEMakeDesignUnits(initial_bound, metadata, opts)
print(dus)

In [None]:
## Report, for every design unit produced, whether it has a protein and a ligand.
## The cells below keep working with `du`, i.e. the last design unit yielded by
## this loop, so initialise it explicitly and fail here with a clear message if
## nothing was produced (rather than with a NameError further down).
du = None
for du in dus:
    print(du.HasProtein(), du.HasLigand())
assert du is not None, "OEMakeDesignUnits did not produce any design units"

In [None]:
## Sanity check on the design unit left over from the loop above (the last one
## yielded): stop here if it reports no PDB metadata.
assert du.HasPDBMetaData()

In [None]:
## Convert the design unit into a single complex molecule, keeping solvent
complex_mol = du_to_complex(du, include_solvent=True)
## Disabled alternatives for writing the complex out at this point, kept for
## reference (the file is saved in a later cell, after SEQRES data is attached):
# save_openeye_pdb(complex_mol, "../../test_seqres_du.pdb")
# ofs = oechem.oemolostream()
# ofs.open("../../test_seqres_du.pdb")
# print(oechem.OEWritePDBFile(ofs, complex_mol))

In [None]:
## Report whether the complex molecule has PDB data attached
## (OEHasPDBData called without a specific tag).
print(oechem.OEHasPDBData(complex_mol))

In [None]:
## Attach SEQRES records to the complex, but only when it has none yet
if not oechem.OEHasPDBData(complex_mol, "SEQRES"):
    ## Non-empty lines of the SEQRES text loaded from the YAML file.
    ## NOTE(review): this is the text as loaded, not the edited `seqres` string
    ## given to Spruce -- confirm that is intended.
    nonempty_lines = [
        line for line in seqres_dict["SEQRES"].split("\n") if line != ""
    ]
    for line in nonempty_lines:
        ## Drop the first six characters of each line (presumably the "SEQRES"
        ## record name, since the tag is passed separately)
        oechem.OEAddPDBData(complex_mol, "SEQRES", line[6:])

In [None]:
## Write the complex, with whatever SEQRES data it now carries, to a PDB file.
## NOTE(review): the relative path resolves against the kernel's working directory.
save_openeye_pdb(complex_mol, "../../test_seqres_du.pdb")

In [None]:
## Extract the protein component of the design unit into its own molecule
du_prot_mol = oechem.OEGraphMol()
protein_result = du.GetProtein(du_prot_mol)
print(protein_result)

## Same call pattern for the design unit's PDB metadata; print what it returns
metadata_result = du.GetPDBMetaData(du_prot_mol)
print(metadata_result)

## Save the extracted protein next to the full complex
save_openeye_pdb(du_prot_mol, "../../test_seqres_du_prot.pdb")

In [None]:
## List every piece of generic data attached to the complex:
## its tag, the data object itself, and its data type.
for data_entry in complex_mol.GetDataIter():
    print(data_entry.GetTag(), data_entry, data_entry.GetDataType())

In [None]:
## Position-by-position comparison of chain A residue names from three sources:
## the input structure, the protein extracted from the design unit, and the
## SEQRES list given to Spruce. Positions where the three do not all agree are
## printed as: index, initial residue, DU residue, SEQRES residue.
init_res_names = [
    res.GetName()
    for res in oechem.OEGetResidues(initial_prot)
    if res.GetExtChainID() == "A"
]
## du_prot_mol is the molecule filled by du.GetProtein() in an earlier cell
du_res_names = [
    res.GetName()
    for res in oechem.OEGetResidues(du_prot_mol)
    if res.GetExtChainID() == "A"
]

## zip() stops at the shortest input, so make any length difference visible
## instead of silently ignoring the trailing residues.
res_counts = (len(init_res_names), len(du_res_names), len(seqres_list))
if len(set(res_counts)) > 1:
    print(
        f"Residue counts differ (initial, DU, SEQRES): {res_counts};",
        f"only the first {min(res_counts)} positions are compared.",
        flush=True,
    )

for i, (init_res, du_res, seqres_res) in enumerate(
    zip(init_res_names, du_res_names, seqres_list)
):
    ## Print the position unless all three names are identical
    if not (init_res == du_res == seqres_res):
        print(i, init_res, du_res, seqres_res, flush=True)