In [1]:
pdb_path = "/lustre/fs6/lyu_lab/scratch/ichen/data/boltz_runs/ampc_pdb/ZINCnk00000bpPs7.pdb"

from boltz.rescore.load.protein import Protein
protein = Protein(pdb_path)
protein_seq = protein.get_sequence()
protein_coords = protein.get_coords()

In [2]:
from boltz.rescore.utils.mol2_utils import Mol2Utils
docked_mol = Mol2Utils.read_single_mol2_file("/lustre/fs6/lyu_lab/scratch/ichen/data/boltz_runs/ampc_mol2/ZINCnk00000bpPs7.mol2")



In [3]:
protein_seq

({'A': 'APQQINDIVHRTITPLIEQQKIPGMAVAVIYQGKPYYFTWGYADIAKKQPVTQQTLFELGSVSKTFTGVLGGDAIARGEIKLSDPTTKYWPELTAKQWNGITLLHLATYTAGGLPLQVPDEVKSSSDLLRFYQNWQPAWAPGTQRLYANSSIGLFGALAVKPSGLSFEQAMQTRVFQPLKLNHTWINVPPAEEKNYAWGYREGKAVHVSPGALDAEAYGVKSTIEDMARWVQSNLKPLDINEKTLQQGIQLAQSRYWQTGDMYQGLGWEMLDWPVNPDSIINGSDNKIALAARPVKAITPPTPAVRASWVHKTGATGGFGSYVAFIPEKELGIVMLANKNYPNPARVDAAWQILNALQ'},
 'B')

In [4]:
import sys

sys.path.append("/lustre/fs6/lyu_lab/scratch/ichen/boltz/src")

from pathlib import Path


In [5]:
path = Path('/lustre/fs6/lyu_lab/scratch/ichen/data/boltz_runs/ampc_rescore_cofold_fasta/ZINCnk00000bpPs7.fasta')

In [6]:
cache: str = "~/.boltz"
cache = Path(cache).expanduser()
cache.mkdir(parents=True, exist_ok=True)        

#download(cache)
ccd_path = cache / "ccd.pkl"

import pickle
with ccd_path.open("rb") as file:
    ccd = pickle.load(file)  # noqa: S301

In [7]:
from collections.abc import Mapping
from pathlib import Path

from Bio import SeqIO
from rdkit.Chem.rdchem import Mol

from boltz.data.types import Target

In [8]:
def parse_fasta(
    path: Path, ccd: Mapping[str, Mol], protein_coords, docked_mol, out_dir
) -> Target:  # noqa: C901
    """Parse a fasta file.

    The name of the fasta file is used as the name of this job.
    We rely on the fasta record id to determine the entity type.

    > CHAIN_ID|ENTITY_TYPE|MSA_ID
    SEQUENCE
    > CHAIN_ID|ENTITY_TYPE|MSA_ID
    ...

    Where ENTITY_TYPE is either protein, rna, dna, ccd or smiles,
    and CHAIN_ID is the chain identifier, which should be unique.
    The MSA_ID is optional and should only be used on proteins.

    Parameters
    ----------
    fasta_file : Path
        Path to the fasta file.
    ccd : Dict
        Dictionary of CCD components.

    Returns
    -------
    Target
        The parsed target.

    """
    # Read fasta file
    with path.open("r") as f:
        records = list(SeqIO.parse(f, "fasta"))

    # Make sure all records have a chain id and entity
    for seq_record in records:
        if "|" not in seq_record.id:
            msg = f"Invalid record id: {seq_record.id}"
            raise ValueError(msg)

        header = seq_record.id.split("|")
        assert len(header) >= 2, f"Invalid record id: {seq_record.id}"

        chain_id, entity_type = header[:2]
        if entity_type.lower() not in {"protein", "dna", "rna", "ccd", "smiles"}:
            msg = f"Invalid entity type: {entity_type}"
            raise ValueError(msg)
        if chain_id == "":
            msg = "Empty chain id in input fasta!"
            raise ValueError(msg)
        if entity_type == "":
            msg = "Empty entity type in input fasta!"
            raise ValueError(msg)

    # Convert to yaml format
    sequences = []
    for seq_record in records:
        # Get chain id, entity type and sequence
        header = seq_record.id.split("|")
        chain_id, entity_type = header[:2]
        if len(header) == 3 and header[2] != "":
            assert (
                entity_type.lower() == "protein"
            ), "MSA_ID is only allowed for proteins"
            msa_id = header[2]
        else:
            msa_id = None

        entity_type = entity_type.upper()
        seq = str(seq_record.seq)

        if entity_type == "PROTEIN":
            molecule = {
                "protein": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                    "msa": msa_id,
                },
            }
        elif entity_type == "RNA":
            molecule = {
                "rna": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                },
            }
        elif entity_type == "DNA":
            molecule = {
                "dna": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                }
            }
        elif entity_type.upper() == "CCD":
            molecule = {
                "ligand": {
                    "id": chain_id,
                    "ccd": seq,
                }
            }
        elif entity_type.upper() == "SMILES":
            molecule = {
                "ligand": {
                    "id": chain_id,
                    "smiles": seq,
                }
            }

        sequences.append(molecule)

    data = {
        "sequences": sequences,
        "bonds": [],
        "version": 1,
    }

    name = path.stem
    return name, data, ccd, protein_coords, docked_mol,out_dir


In [9]:
name, data, ccd, protein_coords, docked_mol,out_dir = parse_fasta(path, ccd, protein_coords, docked_mol, "/lustre/fs6/lyu_lab/scratch/ichen")

In [10]:
data

{'sequences': [{'protein': {'id': 'A',
    'sequence': 'APQQINDIVHRTITPLIEQQKIPGMAVAVIYQGKPYYFTWGYADIAKKQPVTQQTLFELGSVSKTFTGVLGGDAIARGEIKLSDPTTKYWPELTAKQWNGITLLHLATYTAGGLPLQVPDEVKSSSDLLRFYQNWQPAWAPGTQRLYANSSIGLFGALAVKPSGLSFEQAMQTRVFQPLKLNHTWINVPPAEEKNYAWGYREGKAVHVSPGALDAEAYGVKSTIEDMARWVQSNLKPLDINEKTLQQGIQLAQSRYWQTGDMYQGLGWEMLDWPVNPDSIINGSDNKIALAARPVKAITPPTPAVRASWVHKTGATGGFGSYVAFIPEKELGIVMLANKNYPNPARVDAAWQILNALQ',
    'modifications': [],
    'msa': '/lustre/fs6/lyu_lab/scratch/ichen/data/boltz_rescore/ampc/ampc_raw/ampc.a3m'}},
  {'ligand': {'id': 'B',
    'smiles': 'CCCN(CCC)C(=O)CCC(=O)[N-]CC1=NOC=C1C([O-])=O'}}],
 'bonds': [],
 'version': 1}

In [11]:
import os
import pickle

from rdkit import Chem
from rdkit.Chem import AllChem

from boltz.rescore.utils.rdkit_utils import RdkitUtils

import pdb
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Optional

import click
import numpy as np
from rdkit import Chem, rdBase
from rdkit.Chem import AllChem
from rdkit.Chem.rdchem import Conformer, Mol

from boltz.data import const
from boltz.data.types import (
    Atom,
    Bond,
    Chain,
    ChainInfo,
    Connection,
    InferenceOptions,
    Interface,
    Record,
    Residue,
    Structure,
    StructureInfo,
    Target,
)
from boltz.rescore.coordinate.construct import Construct

####################################################################################################
# DATACLASSES
####################################################################################################


@dataclass(frozen=True)
class ParsedAtom:
    """A parsed atom object."""

    name: str
    element: int
    charge: int
    coords: tuple[float, float, float]
    conformer: tuple[float, float, float]
    is_present: bool
    chirality: int


@dataclass(frozen=True)
class ParsedBond:
    """A parsed bond object."""

    atom_1: int
    atom_2: int
    type: int


@dataclass(frozen=True)
class ParsedResidue:
    """A parsed residue object."""

    name: str
    type: int
    idx: int
    atoms: list[ParsedAtom]
    bonds: list[ParsedBond]
    orig_idx: Optional[int]
    atom_center: int
    atom_disto: int
    is_standard: bool
    is_present: bool


@dataclass(frozen=True)
class ParsedChain:
    """A parsed chain object."""

    entity: str
    type: str
    residues: list[ParsedResidue]


####################################################################################################
# HELPERS
####################################################################################################


def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Convert an atom name to a standard format.

    Parameters
    ----------
    name : str
        The atom name.

    Returns
    -------
    Tuple[int, int, int, int]
        The converted atom name.

    """
    name = name.strip()
    name = [ord(c) - 32 for c in name]
    name = name + [0] * (4 - len(name))
    return tuple(name)


def compute_3d_conformer(mol: Mol, version: str = "v3") -> bool:
    """Generate 3D coordinates using EKTDG method.

    Taken from `pdbeccdutils.core.component.Component`.

    Parameters
    ----------
    mol: Mol
        The RDKit molecule to process
    version: str, optional
        The ETKDG version, defaults ot v3

    Returns
    -------
    bool
        Whether computation was successful.

    """
    if version == "v3":
        options = AllChem.ETKDGv3()
    elif version == "v2":
        options = AllChem.ETKDGv2()
    else:
        options = AllChem.ETKDGv2()

    options.clearConfs = False
    conf_id = -1

    try:
        conf_id = AllChem.EmbedMolecule(mol, options)

        if conf_id == -1:
            print(
                f"WARNING: RDKit ETKDGv3 failed to generate a conformer for molecule "
                f"{Chem.MolToSmiles(AllChem.RemoveHs(mol))}, so the program will start with random coordinates. "
                f"Note that the performance of the model under this behaviour was not tested."
            )
            options.useRandomCoords = True
            conf_id = AllChem.EmbedMolecule(mol, options)

        AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000)

    except RuntimeError:
        pass  # Force field issue here
    except ValueError:
        pass  # sanitization issue here

    if conf_id != -1:
        conformer = mol.GetConformer(conf_id)
        conformer.SetProp("name", "Computed")
        conformer.SetProp("coord_generation", f"ETKDG{version}")

        return True

    return False


def get_conformer(mol: Mol) -> Conformer:
    """Retrieve an rdkit object for a deemed conformer.

    Inspired by `pdbeccdutils.core.component.Component`.

    Parameters
    ----------
    mol: Mol
        The molecule to process.

    Returns
    -------
    Conformer
        The desired conformer, if any.

    Raises
    ------
    ValueError
        If there are no conformers of the given tyoe.

    """
    # Try using the computed conformer
    for c in mol.GetConformers():
        try:
            if c.GetProp("name") == "Computed":
                return c
        except KeyError:  # noqa: PERF203
            pass

    # Fallback to the ideal coordinates
    for c in mol.GetConformers():
        try:
            if c.GetProp("name") == "Ideal":
                return c
        except KeyError:  # noqa: PERF203
            pass

    msg = "Conformer does not exist."
    raise ValueError(msg)


####################################################################################################
# PARSING
####################################################################################################


def parse_ccd_residue(
    name: str,
    ref_mol: Mol,
    res_idx: int,
) -> Optional[ParsedResidue]:
    """Parse an MMCIF ligand.

    First tries to get the SMILES string from the RCSB.
    Then, tries to infer atom ordering using RDKit.

    Parameters
    ----------
    name: str
        The name of the molecule to parse.
    ref_mol: Mol
        The reference molecule to parse.
    res_idx : int
        The residue index.

    Returns
    -------
    ParsedResidue, optional
       The output ParsedResidue, if successful.

    """
    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]

    # Remove hydrogens
    ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)

    # Check if this is a single atom CCD residue
    if ref_mol.GetNumAtoms() == 1:
        pos = (0, 0, 0)
        ref_atom = ref_mol.GetAtoms()[0]
        chirality_type = const.chirality_type_ids.get(
            str(ref_atom.GetChiralTag()), unk_chirality
        )
        atom = ParsedAtom(
            name=ref_atom.GetProp("name"),
            element=ref_atom.GetAtomicNum(),
            charge=ref_atom.GetFormalCharge(),
            coords=pos,
            conformer=(0, 0, 0),
            is_present=True,
            chirality=chirality_type,
        )
        unk_prot_id = const.unk_token_ids["PROTEIN"]
        residue = ParsedResidue(
            name=name,
            type=unk_prot_id,
            atoms=[atom],
            bonds=[],
            idx=res_idx,
            orig_idx=None,
            atom_center=0,  # Placeholder, no center
            atom_disto=0,  # Placeholder, no center
            is_standard=False,
            is_present=True,
        )
        return residue

    # Get reference conformer coordinates
    conformer = get_conformer(ref_mol)

    # Parse each atom in order of the reference mol
    atoms = []
    atom_idx = 0
    idx_map = {}  # Used for bonds later

    for i, atom in enumerate(ref_mol.GetAtoms()):
        # Get atom name, charge, element and reference coordinates
        atom_name = atom.GetProp("name")
        charge = atom.GetFormalCharge()
        element = atom.GetAtomicNum()
        ref_coords = conformer.GetAtomPosition(atom.GetIdx())
        ref_coords = (ref_coords.x, ref_coords.y, ref_coords.z)
        chirality_type = const.chirality_type_ids.get(
            str(atom.GetChiralTag()), unk_chirality
        )

        # Get PDB coordinates, if any
        coords = (0, 0, 0)
        atom_is_present = True

        # Add atom to list
        atoms.append(
            ParsedAtom(
                name=atom_name,
                element=element,
                charge=charge,
                coords=coords,
                conformer=ref_coords,
                is_present=atom_is_present,
                chirality=chirality_type,
            )
        )
        idx_map[i] = atom_idx
        atom_idx += 1  # noqa: SIM113

    # Load bonds
    bonds = []
    unk_bond = const.bond_type_ids[const.unk_bond_type]
    for bond in ref_mol.GetBonds():
        idx_1 = bond.GetBeginAtomIdx()
        idx_2 = bond.GetEndAtomIdx()

        # Skip bonds with atoms ignored
        if (idx_1 not in idx_map) or (idx_2 not in idx_map):
            continue

        idx_1 = idx_map[idx_1]
        idx_2 = idx_map[idx_2]
        start = min(idx_1, idx_2)
        end = max(idx_1, idx_2)
        bond_type = bond.GetBondType().name
        bond_type = const.bond_type_ids.get(bond_type, unk_bond)
        bonds.append(ParsedBond(start, end, bond_type))

    unk_prot_id = const.unk_token_ids["PROTEIN"]
    return ParsedResidue(
        name=name,
        type=unk_prot_id,
        atoms=atoms,
        bonds=bonds,
        idx=res_idx,
        atom_center=0,
        atom_disto=0,
        orig_idx=None,
        is_standard=False,
        is_present=True,
    )


def parse_polymer(
    sequence: list[str],
    entity: str,
    chain_type: str,
    components: dict[str, Mol],
) -> Optional[ParsedChain]:
    """Process a sequence into a chain object.

    Performs alignment of the full sequence to the polymer
    residues. Loads coordinates and masks for the atoms in
    the polymer, following the ordering in const.atom_order.

    Parameters
    ----------
    sequence : list[str]
        The full sequence of the polymer.
    entity : str
        The entity id.
    entity_type : str
        The entity type.
    components : dict[str, Mol]
        The preprocessed PDB components dictionary.

    Returns
    -------
    ParsedChain, optional
        The output chain, if successful.

    Raises
    ------
    ValueError
        If the alignment fails.

    """
    ref_res = set(const.tokens)
    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]

    # Get coordinates and masks
    parsed = []
    for res_idx, res_name in enumerate(sequence):
        # Check if modified residue
        # Map MSE to MET
        res_corrected = res_name if res_name != "MSE" else "MET"

        # Handle non-standard residues
        if res_corrected not in ref_res:
            ref_mol = components[res_corrected]
            residue = parse_ccd_residue(
                name=res_corrected,
                ref_mol=ref_mol,
                res_idx=res_idx,
            )
            parsed.append(residue)
            continue

        # Load ref residue
        ref_mol = components[res_corrected]
        ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)
        ref_conformer = get_conformer(ref_mol)

        # Only use reference atoms set in constants
        ref_name_to_atom = {a.GetProp("name"): a for a in ref_mol.GetAtoms()}
        ref_atoms = [ref_name_to_atom[a] for a in const.ref_atoms[res_corrected]]

        # Iterate, always in the same order
        atoms: list[ParsedAtom] = []

        for ref_atom in ref_atoms:
            # Get atom name
            atom_name = ref_atom.GetProp("name")
            idx = ref_atom.GetIdx()

            # Get conformer coordinates
            ref_coords = ref_conformer.GetAtomPosition(idx)
            ref_coords = (ref_coords.x, ref_coords.y, ref_coords.z)

            # Set 0 coordinate
            atom_is_present = True
            coords = (0, 0, 0)

            # Add atom to list
            atoms.append(
                ParsedAtom(
                    name=atom_name,
                    element=ref_atom.GetAtomicNum(),
                    charge=ref_atom.GetFormalCharge(),
                    coords=coords,
                    conformer=ref_coords,
                    is_present=atom_is_present,
                    chirality=const.chirality_type_ids.get(
                        str(ref_atom.GetChiralTag()), unk_chirality
                    ),
                )
            )

        atom_center = const.res_to_center_atom_id[res_corrected]
        atom_disto = const.res_to_disto_atom_id[res_corrected]
        parsed.append(
            ParsedResidue(
                name=res_corrected,
                type=const.token_ids[res_corrected],
                atoms=atoms,
                bonds=[],
                idx=res_idx,
                atom_center=atom_center,
                atom_disto=atom_disto,
                is_standard=True,
                is_present=True,
                orig_idx=None,
            )
        )

    # Return polymer object
    return ParsedChain(
        entity=entity,
        residues=parsed,
        type=chain_type,
    )

def parse_boltz_schema(  # noqa: C901, PLR0915, PLR0912
    name: str, schema: dict, ccd: Mapping[str, Mol], protein_coords, docked_mol, out_dir
) -> Target:
    """Parse a Boltz input yaml / json.

    The input file should be a dictionary with the following format:

    version: 1
    sequences:
        - protein:
            id: A
            sequence: "MADQLTEEQIAEFKEAFSLF"
            msa: path/to/msa1.a3m
        - protein:
            id: [B, C]
            sequence: "AKLSILPWGHC"
            msa: path/to/msa2.a3m
        - rna:
            id: D
            sequence: "GCAUAGC"
        - ligand:
            id: E
            smiles: "CC1=CC=CC=C1"
        - ligand:
            id: [F, G]
            ccd: []
    constraints:
        - bond:
            atom1: [A, 1, CA]
            atom2: [A, 2, N]
        - pocket:
            binder: E
            contacts: [[B, 1], [B, 2]]

    Parameters
    ----------
    name : str
        A name for the input.
    schema : dict
        The input schema.
    components : dict
        Dictionary of CCD components.

    Returns
    -------
    Target
        The parsed target.

    """
    # Assert version 1
    version = schema.get("version", 1)
    if version != 1:
        msg = f"Invalid version {version} in input!"
        raise ValueError(msg)

    # Disable rdkit warnings
    blocker = rdBase.BlockLogs()  # noqa: F841

    # First group items that have the same type, sequence and modifications
    items_to_group = {}
    for item in schema["sequences"]:
        # Get entity type
        entity_type = next(iter(item.keys())).lower()
        if entity_type not in {"protein", "dna", "rna", "ligand"}:
            msg = f"Invalid entity type: {entity_type}"
            raise ValueError(msg)

        # Get sequence
        if entity_type in {"protein", "dna", "rna"}:
            seq = str(item[entity_type]["sequence"])
        elif entity_type == "ligand":
            assert "smiles" in item[entity_type] or "ccd" in item[entity_type]
            assert "smiles" not in item[entity_type] or "ccd" not in item[entity_type]
            if "smiles" in item[entity_type]:
                seq = str(item[entity_type]["smiles"])
            else:
                seq = str(item[entity_type]["ccd"])
        items_to_group.setdefault((entity_type, seq), []).append(item)

    # Go through entities and parse them
    chains: dict[str, ParsedChain] = {}
    chain_to_msa: dict[str, str] = {}
    entity_to_seq: dict[str, str] = {}
    is_msa_custom = False
    is_msa_auto = False
    for entity_id, items in enumerate(items_to_group.values()):
        # Get entity type and sequence
        entity_type = next(iter(items[0].keys())).lower()

        # Ensure all the items share the same msa
        msa = -1
        if entity_type == "protein":
            # Get the msa, default to 0, meaning auto-generated
            msa = items[0][entity_type].get("msa", 0)
            if (msa is None) or (msa == ""):
                msa = 0

            # Check if all MSAs are the same within the same entity
            for item in items:
                item_msa = item[entity_type].get("msa", 0)
                if (item_msa is None) or (item_msa == ""):
                    item_msa = 0

                if item_msa != msa:
                    msg = "All proteins with the same sequence must share the same MSA!"
                    raise ValueError(msg)

            # Set the MSA, warn if passed in single-sequence mode
            if msa == "empty":
                msa = -1
                msg = (
                    "Found explicit empty MSA for some proteins, will run "
                    "these in single sequence mode. Keep in mind that the "
                    "model predictions will be suboptimal without an MSA."
                )
                click.echo(msg)

            if msa not in (0, -1):
                is_msa_custom = True
            elif msa == 0:
                is_msa_auto = True

        # Parse a polymer
        if entity_type in {"protein", "dna", "rna"}:
            # Get token map
            if entity_type == "rna":
                token_map = const.rna_letter_to_token
            elif entity_type == "dna":
                token_map = const.dna_letter_to_token
            elif entity_type == "protein":
                token_map = const.prot_letter_to_token
            else:
                msg = f"Unknown polymer type: {entity_type}"
                raise ValueError(msg)

            # Get polymer info
            chain_type = const.chain_type_ids[entity_type.upper()]
            unk_token = const.unk_token[entity_type.upper()]

            # Extract sequence
            seq = items[0][entity_type]["sequence"]
            entity_to_seq[entity_id] = seq

            # Convert sequence to tokens
            seq = [token_map.get(c, unk_token) for c in list(seq)]

            # Apply modifications
            for mod in items[0][entity_type].get("modifications", []):
                code = mod["ccd"]
                idx = mod["position"] - 1  # 1-indexed
                seq[idx] = code

            # Parse a polymer
            parsed_chain = parse_polymer(
                sequence=seq,
                entity=entity_id,
                chain_type=chain_type,
                components=ccd,
            )

        # Parse a non-polymer
        elif (entity_type == "ligand") and "ccd" in (items[0][entity_type]):
            seq = items[0][entity_type]["ccd"]
            if isinstance(seq, str):
                seq = [seq]

            residues = []
            for code in seq:
                if code not in ccd:
                    msg = f"CCD component {code} not found!"
                    raise ValueError(msg)

                # Parse residue
                residue = parse_ccd_residue(
                    name=code,
                    ref_mol=ccd[code],
                    res_idx=0,
                )
                residues.append(residue)

            # Create multi ligand chain
            parsed_chain = ParsedChain(
                entity=entity_id,
                residues=residues,
                type=const.chain_type_ids["NONPOLYMER"],
            )
        elif (entity_type == "ligand") and ("smiles" in items[0][entity_type]):
            seq = items[0][entity_type]["smiles"]
            mol = AllChem.MolFromSmiles(seq)
            mol = AllChem.AddHs(mol)

            # Set atom names
            canonical_order = AllChem.CanonicalRankAtoms(mol)
            for atom, can_idx in zip(mol.GetAtoms(), canonical_order):
                atom_name = atom.GetSymbol().upper() + str(can_idx + 1)
                if len(atom_name) > 4:
                    raise ValueError(
                        f"{seq} has an atom with a name longer than 4 characters: {atom_name}"
                    )
                atom.SetProp("name", atom_name)

            success = compute_3d_conformer(mol)
            if not success:
                msg = f"Failed to compute 3D conformer for {seq}"
                raise ValueError(msg)

            mol_no_h = AllChem.RemoveHs(mol)
            residue = parse_ccd_residue(
                name="LIG",
                ref_mol=mol_no_h,
                res_idx=0,
            )
            parsed_chain = ParsedChain(
                entity=entity_id,
                residues=[residue],
                type=const.chain_type_ids["NONPOLYMER"],
            )
        else:
            msg = f"Invalid entity type: {entity_type}"
            raise ValueError(msg)

        # Add as many chains as provided ids
        for item in items:
            ids = item[entity_type]["id"]
            if isinstance(ids, str):
                ids = [ids]
            for chain_name in ids:
                chains[chain_name] = parsed_chain
                chain_to_msa[chain_name] = msa

    # Check if msa is custom or auto
    if is_msa_custom and is_msa_auto:
        msg = "Cannot mix custom and auto-generated MSAs in the same input!"
        raise ValueError(msg)

    # If no chains parsed fail
    if not chains:
        msg = "No chains parsed!"
        raise ValueError(msg)

    # Create tables
    atom_data = []
    atom_data2 = []
    bond_data = []
    res_data = []
    chain_data = []

    # Convert parsed chains to tables
    atom_idx = 0
    res_idx = 0
    asym_id = 0
    sym_count = {}
    chain_to_idx = {}

    # Keep a mapping of (chain_name, residue_idx, atom_name) to atom_idx
    atom_idx_map = {}

    for asym_id, (chain_name, chain) in enumerate(chains.items()):
        # Compute number of atoms and residues
        res_num = len(chain.residues)
        atom_num = sum(len(res.atoms) for res in chain.residues)

        # Find all copies of this chain in the assembly
        entity_id = int(chain.entity)
        sym_id = sym_count.get(entity_id, 0)
        chain_data.append(
            (
                chain_name,
                chain.type,
                entity_id,
                sym_id,
                asym_id,
                atom_idx,
                atom_num,
                res_idx,
                res_num,
            )
        )
        chain_to_idx[chain_name] = asym_id
        sym_count[entity_id] = sym_id + 1

        # Add residue, atom, bond, data
        for res in chain.residues:
            atom_center = atom_idx + res.atom_center
            atom_disto = atom_idx + res.atom_disto
            res_data.append(
                (
                    res.name,
                    res.type,
                    res.idx,
                    atom_idx,
                    len(res.atoms),
                    atom_center,
                    atom_disto,
                    res.is_standard,
                    res.is_present,
                )
            )

            for bond in res.bonds:
                atom_1 = atom_idx + bond.atom_1
                atom_2 = atom_idx + bond.atom_2
                bond_data.append((atom_1, atom_2, bond.type))

            for atom in res.atoms:
                # Add atom to map
                atom_idx_map[(chain_name, res.idx, atom.name)] = (
                    asym_id,
                    res_idx,
                    atom_idx,
                )

                # Add atom to data
                atom_data.append(
                    (
                        convert_atom_name(atom.name),
                        atom.element,
                        atom.charge,
                        atom.coords,
                        atom.conformer,
                        atom.is_present,
                        atom.chirality
                    )
                )
                atom_data2.append(
                    (
                        atom.name,
                        atom.element,
                        atom.charge,
                        atom.conformer,
                        atom.is_present,
                        atom.chirality,
                        res.name,
                        res.idx,
                    )
                )
                atom_idx += 1

            res_idx += 1

    # Parse constraints
    connections = []
    pocket_binders = []
    pocket_residues = []
    constraints = schema.get("constraints", [])
    for constraint in constraints:
        if "bond" in constraint:
            if "atom1" not in constraint["bond"] or "atom2" not in constraint["bond"]:
                msg = f"Bond constraint was not properly specified"
                raise ValueError(msg)

            c1, r1, a1 = tuple(constraint["bond"]["atom1"])
            c2, r2, a2 = tuple(constraint["bond"]["atom2"])
            c1, r1, a1 = atom_idx_map[(c1, r1 - 1, a1)]  # 1-indexed
            c2, r2, a2 = atom_idx_map[(c2, r2 - 1, a2)]  # 1-indexed
            connections.append((c1, c2, r1, r2, a1, a2))
        elif "pocket" in constraint:
            if (
                "binder" not in constraint["pocket"]
                or "contacts" not in constraint["pocket"]
            ):
                msg = f"Pocket constraint was not properly specified"
                raise ValueError(msg)

            binder = constraint["pocket"]["binder"]
            contacts = constraint["pocket"]["contacts"]

            if len(pocket_binders) > 0:
                if pocket_binders[-1] != chain_to_idx[binder]:
                    msg = f"Only one pocket binders is supported!"
                    raise ValueError(msg)
                else:
                    pocket_residues[-1].extend(
                        [
                            (chain_to_idx[chain_name], residue_index - 1)
                            for chain_name, residue_index in contacts
                        ]
                    )

            else:
                pocket_binders.append(chain_to_idx[binder])
                pocket_residues.extend(
                    [
                        (chain_to_idx[chain_name], residue_index - 1)
                        for chain_name, residue_index in contacts
                    ]
                )
        else:
            msg = f"Invalid constraint: {constraint}"
            raise ValueError(msg)

    # Convert into datatypes
    atoms = np.array(atom_data, dtype=Atom)
    bonds = np.array(bond_data, dtype=Bond)
    residues = np.array(res_data, dtype=Residue)
    chains = np.array(chain_data, dtype=Chain)
    interfaces = np.array([], dtype=Interface)
    connections = np.array(connections, dtype=Connection)
    mask = np.ones(len(chain_data), dtype=bool)

    data = Structure(
        atoms=atoms,
        bonds=bonds,
        residues=residues,
        chains=chains,
        connections=connections,
        interfaces=interfaces,
        mask=mask,
    )

    # Create metadata
    struct_info = StructureInfo(num_chains=len(chains))
    chain_infos = []
    for chain in chains:
        chain_info = ChainInfo(
            chain_id=int(chain["asym_id"]),
            chain_name=chain["name"],
            mol_type=int(chain["mol_type"]),
            cluster_id=-1,
            msa_id=chain_to_msa[chain["name"]],
            num_residues=int(chain["res_num"]),
            valid=True,
            entity_id=int(chain["entity_id"]),
        )
        chain_infos.append(chain_info)

    options = InferenceOptions(binders=pocket_binders, pocket=pocket_residues)

    record = Record(
        id=name,
        structure=struct_info,
        chains=chain_infos,
        interfaces=[],
        inference_options=options,
    )
    return atom_idx_map, protein_coords, mol_no_h, docked_mol, name, out_dir, res_data, atom_data2


In [12]:
atom_idx_map, protein_coords, mol_no_h, docked_mol, name, out_dir, res_data,atom_data2 = parse_boltz_schema(name, data, ccd, protein_coords, docked_mol,out_dir)

In [13]:
from Bio.PDB import PDBParser

def check_backbone_completeness(pdb_file):
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("PDB", pdb_file)
    missing = []

    for model in structure:
        for chain in model:
            for residue in chain:
                if residue.id[0] != ' ':  # Skip heteroatoms
                    continue
                atoms = {atom.get_id() for atom in residue}
                for atom_name in ["N", "CA", "C", "O"]:
                    if atom_name not in atoms:
                        missing.append((chain.id, residue.id[1], atom_name))

    return missing

# Example usage
if __name__ == "__main__":
    missing_atoms = check_backbone_completeness(pdb_path)
    for chain, res_id, atom in missing_atoms:
        print(f"Missing {atom} in residue {res_id} (chain {chain})")


In [14]:
# Construct.get_dock_coords(atom_idx_map, protein_coords, mol_no_h, docked_mol, name, out_dir)


atom_idx_list = list(atom_idx_map.keys())

# Get name symbols from mol_no_h
mol_no_h_names = []
for atom in mol_no_h.GetAtoms():
    atom_name = atom.GetProp("name")
    atom_name = atom_name.strip()
    mol_no_h_names.append(atom_name)

print(mol_no_h_names)

docked_mol = AllChem.RemoveHs(docked_mol)

match = docked_mol.GetSubstructMatch(mol_no_h)
print("Before...")
print(match)
if len(match) == mol_no_h.GetNumAtoms():
    atom_map = [(i, match[i]) for i in range(mol_no_h.GetNumAtoms())]
else:
    raise ValueError("Couldn't find complete match - partial mapping only")

print(RdkitUtils.get_coord_from_mol(docked_mol))

# Validation
reordered_mol = Chem.RenumberAtoms(docked_mol, list(match))
print("After...")

match2 = reordered_mol.GetSubstructMatch(mol_no_h)
print(match2)
atom_map2 = [(i, match2[i]) for i in range(mol_no_h.GetNumAtoms())]
for i, j in atom_map2:
    if i != j:
        raise ValueError(
            f"Mismatch found in atom_map2: ({i}, {j}) is not equal"
        )

ligand_coords = RdkitUtils.get_coord_from_mol(reordered_mol)

if len(ligand_coords) != len(mol_no_h_names):
    raise ValueError("Dim not matching")

print(ligand_coords)

# pdb.
updated_molecule_coords = [
    ("B", 0, new_value, t[0], t[1], t[2])
    for t, new_value in zip(ligand_coords, mol_no_h_names)
]

result = Construct.filter_and_validate_table_fast(
    protein_coords + updated_molecule_coords, atom_idx_list
)

trimmed_list = [t[3:] for t in result]


['C37', 'C43', 'C41', 'N35', 'C42', 'C44', 'C38', 'C31', 'O24', 'C40', 'C39', 'C30', 'O23', 'N27', 'C36', 'C32', 'N26', 'O28', 'C33', 'C34', 'C29', 'O25', 'O22']
Before...
(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)
[(0.948, 5.759, 16.814), (1.061, 6.325, 15.488), (1.99, 5.511, 14.578), (1.831, 6.028, 13.294), (2.597, 7.217, 12.995), (4.068, 6.844, 12.798), (4.863, 8.015, 12.43), (0.998, 5.49, 12.409), (0.454, 4.465, 12.699), (0.867, 6.098, 11.033), (0.691, 5.033, 9.994), (1.645, 5.222, 8.835), (2.178, 6.293, 8.585), (1.826, 4.134, 8.078), (2.672, 4.164, 6.921), (4.106, 3.914, 7.249), (4.944, 4.891, 7.652), (6.111, 4.313, 7.906), (6.095, 2.997, 7.689), (4.805, 2.739, 7.286), (4.304, 1.401, 6.988), (3.319, 1.289, 6.287), (4.856, 0.414, 7.455)]
After...
(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)
[(0.948, 5.759, 16.814), (1.061, 6.325, 15.488), (1.99, 5.511, 14.578), (1.831, 6.028, 13.294), (2.597, 7.217, 12.99

In [15]:
print(mol_no_h)
num_atoms = mol_no_h.GetNumAtoms() if hasattr(mol_no_h, "GetNumAtoms") else "N/A"
smiles = Chem.MolToSmiles(mol_no_h) if hasattr(mol_no_h, "GetNumAtoms") else "N/A"
print(f"[INFO] mol_no_h num_atoms: {num_atoms}")
print(f"[INFO] mol_no_h SMILES: {smiles}")

<rdkit.Chem.rdchem.Mol object at 0x7f7979065d90>
[INFO] mol_no_h num_atoms: 23
[INFO] mol_no_h SMILES: CCCN(CCC)C(=O)CCC(=O)[N-]Cc1nocc1C(=O)[O-]
