In [1]:
import sys
import warnings
import tempfile
import threading
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import List, Set
import re
import numpy as np

from Bio.PDB import PDBParser, NeighborSearch, Select, PDBIO
from loguru import logger
import pymol
from pymol import cmd
from rdkit import Chem

from spyrmsd import graph, molecule

warnings.filterwarnings("ignore", category=UserWarning, module="MDAnalysis.core.universe")

logger.remove()
logger.add(sys.stdout, format="{message}", level="DEBUG")

BOND_LENGTHS = {
    ("C", "C"): 1.54,
    ("C", "N"): 1.65,
    ("C", "O"): 1.63,
    ("C", "S"): 1.80,
    ("P", "O"): 1.64,
    ("N", "O"): 1.55,
    ("N", "S"): 1.70,
    ("Fe", "S"): 2.33,
    ("Fe", "O"): 2.20,
    ("Fe", "N"): 2.15,
    ("Zn", "S"): 2.39,
    ("Zn", "N"): 2.26,
    ("Zn", "O"): 2.42,
    ("Cu", "S"): 2.26,
    ("Cu", "N"): 1.98,
    ("Cu", "O"): 1.95,
    ("Ni", "S"): 2.28,
    ("Ni", "N"): 2.18,
    ("Ni", "O"): 2.20,
    ("Mn", "O"): 2.18,
    ("Mn", "N"): 2.20,
    ("Co", "S"): 2.32,
    ("Co", "N"): 2.25,
    ("Co", "O"): 2.20,
    ("Mg", "O"): 2.18,
    ("Ca", "O"): 2.45,
    ("Mo", "S"): 2.42,
    ("Mo", "O"): 2.00,
    ("Na", "O"): 2.35,
    ("K", "O"): 2.75,
}

METALS = {"Fe", "Zn", "Cu", "Ni", "Mn", "Co", "Mg", "Ca", "Mo", "Na", "K"}

def _get_bond_threshold(el1, el2, default_bond_distance):
    pair = (el1, el2)
    if pair in BOND_LENGTHS:
        return BOND_LENGTHS[pair]
    rev = (el2, el1)
    if rev in BOND_LENGTHS:
        return BOND_LENGTHS[rev]
    return default_bond_distance

def _elem(atom) -> str:
    e = getattr(atom, "element", "").strip()
    if e:
        return e.capitalize()
    name = atom.get_name().strip()
    m = re.match(r"([A-Za-z]{1,2})", name)
    if not m:
        return name[0].upper()
    s = m.group(1).upper()
    return s.capitalize()

class _PymolSession:
    _lock = threading.Lock()
    _started = False

    @classmethod
    def ensure(cls) -> None:
        with cls._lock:
            if not cls._started:
                pymol.finish_launching(["pymol", "-qc"])
                cls._started = True

    @classmethod
    def locked(cls):
        cls.ensure()
        return cls._lock


def _ligands_connected(r1, r2, default_bond_distance):
    for a1 in r1.get_atoms():
        for a2 in r2.get_atoms():
            el1 = _elem(a1)
            el2 = _elem(a2)
            thr = _get_bond_threshold(el1, el2, default_bond_distance)
            dist = np.linalg.norm(a1.coord - a2.coord)
            #logger.debug(f"_ligands_connected: Checking {r1.get_resname()}-{r2.get_resname()} : "
            #             f"{a1.get_name()}-{a2.get_name()}, dist={dist:.3f}, thr={thr}")
            if dist <= thr:
                #logger.debug(f"_ligands_connected: FOUND BOND between {r1.get_resname()} and {r2.get_resname()}")
                return True
    return False

def _rmsd_on_ca(pdb1: str, pdb2: str) -> float:
    with _PymolSession.locked():
        o1, o2 = f"o{uuid.uuid4().hex}", f"o{uuid.uuid4().hex}"
        cmd.delete("all")
        cmd.load(pdb1, o1)
        cmd.load(pdb2, o2)
        try:
            return float(cmd.align(f"{o1} and name CA", f"{o2} and name CA")[0])
        finally:
            cmd.delete("all")


def _rmsd_on_ligand(pdb1: str, pdb2: str) -> float:
    with _PymolSession.locked():
        o1, o2 = f"o{uuid.uuid4().hex}", f"o{uuid.uuid4().hex}"
        cmd.delete("all")
        cmd.load(pdb1, o1)
        cmd.load(pdb2, o2)

        cmd.align(f"{o1} and name CA", f"{o2} and name CA", cycles=0)

        def _dump(obj, fn):
            tmp_obj = f"{obj}_x"
            cmd.create(tmp_obj, obj, 1, 1)
            cmd.save(fn, f"{tmp_obj} and hetatm and not resn HOH", state=1)
            cmd.delete(tmp_obj)

        t1 = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
        t2 = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
        _dump(o1, t1.name)
        _dump(o2, t2.name)

        m1 = Chem.MolFromPDBFile(t1.name, removeHs=False)
        m2 = Chem.MolFromPDBFile(t2.name, removeHs=False)
        t1.close()
        t2.close()
        if not m1 or not m2:
            logger.debug("    ligand RMSD -> RDKit build failed (m1 or m2 is None)")
            cmd.delete("all")
            return float("inf")

        c1 = np.array([a.coord for a in cmd.get_model(o1 + " and hetatm and not resn HOH").atom])
        c2 = np.array([a.coord for a in cmd.get_model(o2 + " and hetatm and not resn HOH").atom])

        try:
            mol1 = molecule.Molecule.from_rdkit(m1)
            mol2 = molecule.Molecule.from_rdkit(m2)
            G1 = graph.graph_from_adjacency_matrix(mol1.adjacency_matrix, mol1.atomicnums)
            G2 = graph.graph_from_adjacency_matrix(mol2.adjacency_matrix, mol2.atomicnums)


            for idx1, idx2 in graph.match_graphs(G1, G2):
                r = np.sqrt(np.mean((c1[idx1] - c2[idx2]) ** 2))
                logger.debug("    ligand RMSD -> graph mode matched. RMSD=%.3f" % r)
                cmd.delete("all")
                return float(r)
        except Exception as e:
            logger.debug(f"    ligand RMSD -> graph error: {e}")

        logger.debug("    ligand RMSD -> return inf (no match)")
        cmd.delete("all")
        return float("inf")


@dataclass
class Pocket:
    ligands: List
    chains: Set[str]
    bond_type: str


class LigandPocketExtractor:

    def __init__(self, interact_d=4.5, ligand_cluster_d=1.5, default_bond_d=1.5, short_peptide=10):
        self.interact_d = interact_d
        self.default_bond_d = default_bond_d
        self.short_peptide = short_peptide
        self.ligand_cluster_d = max(ligand_cluster_d, default_bond_d)
        self._search_radius = max(self.ligand_cluster_d, max(BOND_LENGTHS.values(), default=0.0))

    @staticmethod

    def _chains_near(lig_atoms, prot_atoms, d):
        if not prot_atoms:
            return set()
        ns = NeighborSearch(prot_atoms)
        return {
            a.get_parent().get_parent().id
            for la in lig_atoms
            for a in ns.search(la.coord, d, level="A")
        }

    @staticmethod
    def _save_temp(structure, selector):
        io = PDBIO()
        tmp = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
        io.set_structure(structure)
        io.save(tmp.name, select=selector)
        tmp.close()
        return tmp.name

    def _check_bond_type(self, la, pa):
        el_l = _elem(la)
        el_p = _elem(pa)
        thr = _get_bond_threshold(el_l, el_p, self.default_bond_d)
        dist = np.linalg.norm(la.coord - pa.coord)
        if dist <= thr:
            if el_l in METALS or el_p in METALS:
                return "coord"
            return "cov"
        return None

    def _get_pocket_bond_type(self, lig_atoms, prot_atoms):

        p_ns = NeighborSearch(prot_atoms)
        all_ns = NeighborSearch(prot_atoms + lig_atoms)


        for la in lig_atoms:
            for pa in all_ns.search(la.coord, 3.0, level="A"):
                if pa is la:
                    continue
                if _elem(la) in METALS or _elem(pa) in METALS:
                    bt = self._check_bond_type(la, pa)
                    if bt == "coord":
                        return "coord"


        for la in lig_atoms:
            for pa in p_ns.search(la.coord, 3.0, level="A"): 
                bt = self._check_bond_type(la, pa)
                if bt == "cov":
                    return "cov"

        return ""

    def extract(self, structure):

        if len(structure) > 1:
            logger.debug("extract: multiple models, using first model.")
            structure = structure[0]

        chain_res = {}
        for r in structure.get_residues():
            chain_res.setdefault(r.get_parent().id, []).append(r)

        protein_chains = {cid for cid, res in chain_res.items()
                          if sum(1 for r in res if r.id[0] == " ") > self.short_peptide}

        prot_atoms = [a for a in structure.get_atoms()
                      if a.get_parent().id[0] == " "
                      and a.get_parent().get_parent().id in protein_chains]
        
        for cid, reslist in chain_res.items():
            logger.debug(f"Chain {cid} has {len(reslist)} residues. "
                         f"Num_std={sum(1 for r in reslist if r.id[0]==' ')}")


        candidates = [r for r in structure.get_residues()
                      if (r.id[0] != " " and r.get_resname() != "HOH")
                      or len(chain_res[r.get_parent().id]) <= self.short_peptide]

        logger.debug(f"Candidate residues: {len(candidates)}")

        index = {r: i for i, r in enumerate(candidates)}
        parent = list(range(len(candidates)))

        def find(i):
            while parent[i] != i:
                parent[i] = parent[parent[i]]
                i = parent[i]
            return i

        def union(i, j):
            pi, pj = find(i), find(j)
            if pi != pj:
                parent[pj] = pi

        ns = NeighborSearch([a for r in candidates for a in r.get_atoms()])
        for r in candidates:
            i = index[r]
            for a in r.get_atoms():
                for n in ns.search(a.coord, self._search_radius, level="A"):
                    r2 = n.get_parent()
                    if r2 is r or r2 not in index:
                        continue
                    dist = np.linalg.norm(a.coord - n.coord)
                    thr = _get_bond_threshold(_elem(a), _elem(n), self.ligand_cluster_d)
                    if dist <= thr:
                        union(i, index[r2])

        components = {}
        for r in candidates:
            components.setdefault(find(index[r]), []).append(r)

        pockets = []
        for residues in components.values():
            lig_atoms = [a for rr in residues for a in rr.get_atoms()]
            chains = self._chains_near(lig_atoms, prot_atoms, self.interact_d)
            if not chains:
                continue
            btype = self._get_pocket_bond_type(lig_atoms, prot_atoms)
            logger.debug(f"New pocket with residues {[r.get_resname() for r in residues]} -> bond={btype}, chains={chains}")
            pockets.append(Pocket(residues, chains, btype))

        logger.debug(f"Pockets found: {len(pockets)}")
        return pockets


import copy

def _unique_name(base, taken):
    if base not in taken:
        return base
    for i in range(1, 100):
        cand = (base[:3] + str(i))[:4]
        if cand not in taken:
            return cand
    raise ValueError("too many duplicates")

class _MergedSelect(Select):
    def __init__(self, ligands, chains):
        self._ligands = set(ligands)
        self._chains = chains
    def accept_chain(self, chain):
        return chain.id in self._chains or any(r.get_parent() is chain for r in self._ligands)
    def accept_residue(self, residue):
        if residue in self._ligands:
            return True
        return residue.get_parent().id in self._chains and residue.id[0] == " "

class PocketWriter:
    def __init__(self,
                 rmsd_thr=2.0,
                 lig_rmsd_thr=0.5,
                 ligand_cluster_distance=3.0,
                 default_bond_distance=1.5):
        self.rmsd_thr = rmsd_thr
        self.lig_rmsd_thr = lig_rmsd_thr
        self.ligand_cluster_distance = ligand_cluster_distance
        self.default_bond_distance = default_bond_distance
        self._saved = []

    def _is_duplicate(self, new_tmp, lig_resnames, new_chains):
        new_names = set(lig_resnames)
        for s in self._saved:
            if s["chains"] != new_chains and _rmsd_on_ca(s["tmp"], new_tmp) >= self.rmsd_thr:
                continue
            if new_names != set(s["ligand_res"]):
                continue
            if _rmsd_on_ligand(s["tmp"], new_tmp) < self.lig_rmsd_thr:
                return True
        return False

    @staticmethod
    def _merge_residues(residues, new_name):
        main = residues[0]
        main.resname = "LIG"
        taken = {a.get_name() for a in main.get_atoms()}
        
        for r in residues[1:]:
            if r is main:
                continue
            
            for at in list(r.get_atoms()):
                c = copy.copy(at)
                nn = _unique_name(at.name.strip(), taken)
                c.id = c.name = nn
                c.fullname = f"{nn:>4}"
                taken.add(nn)
                main.add(c)
            
            parent = r.get_parent()
            if parent is not None:
                parent.detach_child(r.id)
        return main

    def save(self, structure, pocket, out_dir, extra_lines, pdb_basename):
        lig_raw = [r for r in pocket.ligands if r.get_resname() != "HOH"]
        if not lig_raw:
            return

        parent = list(range(len(lig_raw)))
        def find(i):
            while parent[i] != i:
                parent[i] = parent[parent[i]]
                i = parent[i]
            return i
        def union(i, j):
            pi, pj = find(i), find(j)
            if pi != pj:
                parent[pj] = pi
        for i in range(len(lig_raw)):
            for j in range(i + 1, len(lig_raw)):
                if _ligands_connected(lig_raw[i], lig_raw[j], self.default_bond_distance):
                    union(i, j)

        comps = {}
        for idx, r in enumerate(lig_raw):
            comps.setdefault(find(idx), []).append(r)

        merged_ligs = []
        group_names = []
        for comp in comps.values():
            names = sorted({r.get_resname() for r in comp})
            if len(comp) > 1:
                new_name = "".join(names)[:3]
                main = self._merge_residues(comp, new_name)
                merged_ligs.append(main)
            else:
                merged_ligs.append(comp[0])
            group_names.append("".join(names))

        pocket.ligands = merged_ligs

        file_lig_code = "_".join(sorted(group_names)) if len(group_names) > 1 else group_names[0]
        chains_code = "_".join(sorted(pocket.chains))
        parts = [pdb_basename, file_lig_code, "chains", chains_code]
        if pocket.bond_type:
            parts.append(pocket.bond_type)
        root = "_".join(parts)

        out_path = out_dir / f"{root}.pdb"
        v = 2
        while out_path.exists():
            out_path = out_dir / f"{root}_v{v}.pdb"
            v += 1

        selector = _MergedSelect(pocket.ligands, pocket.chains)
        tmp = LigandPocketExtractor._save_temp(structure, selector)
        if self._is_duplicate(tmp, group_names, pocket.chains):
            Path(tmp).unlink(missing_ok=True)
            return

        io = PDBIO()
        io.set_structure(structure)
        with open(out_path, "w") as fh:
            fh.writelines(extra_lines)
            io.save(fh, select=selector)
            fh.write("\n")

        self._saved.append({"output": out_path,
                            "tmp": tmp,
                            "ligand_res": group_names,
                            "chains": pocket.chains})
        logger.debug(f"Saved: {out_path}")




class ProteinAnalyzer:
    def __init__(
        self,
        interaction_distance=4.5,
        ligand_cluster_distance=0.5,
        rmsd_threshold=2.0,
        default_bond_distance=1.5,
        short_peptide_length=10,
        ligand_rmsd_threshold=0.5,
        overlap_distance=0.5,
        limit_low=3,
        limit_high=5,
        protect_distance=3.0,
    ):
        self.extractor = LigandPocketExtractor(
            interaction_distance,
            ligand_cluster_distance,
            default_bond_distance,
            short_peptide_length,
        )
        self.writer = PocketWriter(
            rmsd_threshold,
            ligand_rmsd_threshold,
            ligand_cluster_distance,
            default_bond_distance,
        )
        self.overlap_distance = overlap_distance
        self.limit_low = limit_low
        self.limit_high = limit_high
        self.protect_distance = protect_distance

    @staticmethod
    def _read_extra_lines(pdb_filepath: Path):
        other = []
        with open(pdb_filepath) as fh:
            for line in fh:
                if not line.startswith(("ATOM", "HETATM", "END", "MASTER", "TER", "CONECT", "ANISOU")):
                    other.append(line)
        return other

    @staticmethod
    def _min_dist(r1, r2):
        a1 = np.array([a.coord for a in r1.get_atoms()])
        a2 = np.array([a.coord for a in r2.get_atoms()])
        return np.min(np.linalg.norm(a1[:, None, :] - a2[None, :, :], axis=-1))

    @staticmethod
    def _pocket_min_dist(p1, p2):
        a1 = np.array([a.coord for r in p1.ligands for a in r.get_atoms()])
        a2 = np.array([a.coord for r in p2.ligands for a in r.get_atoms()])
        return np.min(np.linalg.norm(a1[:, None, :] - a2[None, :, :], axis=-1))

    def _split_overlapping(self, pocket, prot_atoms):

        ligs = [r for r in pocket.ligands if r.get_resname() != "HOH"]
        if len(ligs) < 2:
            return [pocket]

        overlapped = {
            (i, j)
            for i in range(len(ligs))
            for j in range(i + 1, len(ligs))
            if self._min_dist(ligs[i], ligs[j]) < self.overlap_distance
        }
        if not overlapped:
            return [pocket]

        others = [r for r in ligs if all((ligs.index(r), k) not in overlapped and
                                        (k, ligs.index(r)) not in overlapped for k in range(len(ligs)))]

        pockets = []
        for lig in {ligs[i] for i, _ in overlapped}.union({ligs[j] for _, j in overlapped}):
            lig_atoms = [a for a in lig.get_atoms()]
            btype = self.extractor._get_pocket_bond_type(lig_atoms, prot_atoms)
            pockets.append(Pocket([lig] + others, pocket.chains, btype))
        return pockets

    def _filter_overabundant(self, pockets):
        counts = {}
        for p in pockets:
            for n in {r.get_resname() for r in p.ligands if r.get_resname() != "HOH"}:
                counts[n] = counts.get(n, 0) + 1
        if not counts:
            return pockets
        min_count = min(counts.values())
        limit = self.limit_low if min_count <= 2 else self.limit_high
        abundant = {n for n, c in counts.items() if c > limit}
        if not abundant:
            return pockets
        rare_pockets = [
            p for p in pockets
            if not {r.get_resname() for r in p.ligands if r.get_resname() != "HOH"}.issubset(abundant)
        ]
        kept, removed = [], 0
        for p in pockets:
            names = {r.get_resname() for r in p.ligands if r.get_resname() != "HOH"}
            if names and names.issubset(abundant):
                close_to_rare = any(self._pocket_min_dist(p, rp) < self.protect_distance for rp in rare_pockets)
                if not close_to_rare:
                    removed += 1
                    continue
            kept.append(p)
        logger.debug(
            f"Overabundant filter (min {min_count}, limit {limit}, protect {self.protect_distance} Å): "
            f"removed {removed} pockets"
        )
        return kept

    def analyze(self, pdb_filepath, output_directory="separated_complexes"):
        pdb_filepath = Path(pdb_filepath)
        out_dir = Path(output_directory)
        out_dir.mkdir(exist_ok=True)
        logger.debug(f"Start: {pdb_filepath}")
        extra = self._read_extra_lines(pdb_filepath)
        structure = PDBParser(QUIET=True).get_structure("pdb", pdb_filepath)
        raw = self.extractor.extract(structure)
        prot_atoms = [a for a in structure.get_atoms() if a.get_parent().id[0] == " "]
        pockets = []
        for p in raw:
            splitted = self._split_overlapping(p, prot_atoms)
            pockets.extend(splitted)
        pockets = self._filter_overabundant(pockets)
        logger.debug(f"Total pockets to export: {len(pockets)}")
        for p in pockets:
            self.writer.save(structure, p, out_dir, extra, pdb_filepath.stem)
        logger.debug("Done")


def analyze_protein(
    pdb_filepath,
    output_directory="separated_complexes",
    interaction_distance=4.5,
    ligand_cluster_distance=1.5,
    rmsd_threshold=2.0,
    default_bond_distance=1.5,
    short_peptide_length=8,
    ligand_rmsd_threshold=1.0,
    overlap_distance=0.6,
):
    analyzer = ProteinAnalyzer(
        interaction_distance,
        ligand_cluster_distance,
        rmsd_threshold,
        default_bond_distance,
        short_peptide_length,
        ligand_rmsd_threshold,
        overlap_distance,
    )
    analyzer.analyze(pdb_filepath, output_directory)

In [2]:
!rm -r test

In [3]:
Path('test').mkdir(exist_ok=True)
analyze_protein('/home/nikolenko/work/Projects/LPCE/lpce/tests/test_data/1ahy.pdb', 'test/separated_complexes')


Start: /home/nikolenko/work/Projects/LPCE/lpce/tests/test_data/1ahy.pdb
Chain A has 398 residues. Num_std=396
Chain B has 398 residues. Num_std=396
Candidate residues: 4
New pocket with residues ['PLP'] -> bond=cov, chains={'B', 'A'}
New pocket with residues ['MAE'] -> bond=, chains={'B', 'A'}
New pocket with residues ['PLP'] -> bond=cov, chains={'B', 'A'}
New pocket with residues ['MAE'] -> bond=, chains={'B', 'A'}
Pockets found: 4
Total pockets to export: 4
Saved: test/separated_complexes/1ahy_PLP_chains_A_B_cov.pdb
Saved: test/separated_complexes/1ahy_MAE_chains_A_B.pdb
    ligand RMSD -> graph mode matched. RMSD=16.383
Saved: test/separated_complexes/1ahy_PLP_chains_A_B_cov_v2.pdb
    ligand RMSD -> graph mode matched. RMSD=18.979
Saved: test/separated_complexes/1ahy_MAE_chains_A_B_v2.pdb
Done


In [4]:
Path('test').mkdir(exist_ok=True)
analyze_protein('/home/nikolenko/work/Projects/LPCE/lpce/tests/test_data/1bxr.pdb', 'test/separated_complexes')


Start: /home/nikolenko/work/Projects/LPCE/lpce/tests/test_data/1bxr.pdb
Chain A has 1763 residues. Num_std=1073
Chain B has 550 residues. Num_std=379
Chain C has 1712 residues. Num_std=1073
Chain D has 615 residues. Num_std=379
Chain E has 1737 residues. Num_std=1073
Chain F has 572 residues. Num_std=379
Chain G has 1618 residues. Num_std=1073
Chain H has 545 residues. Num_std=379
Candidate residues: 62
New pocket with residues ['MN'] -> bond=coord, chains={'A'}
New pocket with residues ['K'] -> bond=coord, chains={'A'}
New pocket with residues ['K'] -> bond=coord, chains={'A'}
New pocket with residues ['MN', 'MN', 'ANP'] -> bond=coord, chains={'A'}
New pocket with residues ['K'] -> bond=coord, chains={'A'}
New pocket with residues ['CL'] -> bond=, chains={'A'}
New pocket with residues ['CL'] -> bond=, chains={'A'}
New pocket with residues ['CL'] -> bond=, chains={'A'}
New pocket with residues ['ANP'] -> bond=, chains={'A'}
New pocket with residues ['ORN'] -> bond=, chains={'A'}
New po

In [14]:
Path('test').mkdir(exist_ok=True)
analyze_protein('/home/nikolenko/work/Projects/LPCE/lpce/tests/test_data/3k8l.pdb', 'test/separated_complexes')


{'saved': 1, 'skipped': 0}

In [None]:
from pathlib import Path
from tqdm.auto import tqdm
Path('test').mkdir(exist_ok=True)
test_data_dir = Path('/home/nikolenko/work/Projects/LPCE/lpce/tests/test_data')
pdb_files = list(test_data_dir.glob('*.pdb'))
for pdb_file in tqdm(pdb_files, desc='Processing PDB files', total=len(pdb_files)):
    analyze_protein(
        str(pdb_file),
        f'test/separated_complexes_{pdb_file.stem}'
    )