In [1]:
import numpy as np
import prody
from pathlib import Path

import rdkit

import dgl
import torch

  from .autonotebook import tqdm as notebook_tqdm


# define hard-coded filepaths

In [2]:
train_index_path = '/home/ian/projects/mol_diffusion/ligdiff/data/PDBbind_processed/train_index.txt'
data_dir = '/home/ian/projects/mol_diffusion/ligdiff/data/PDBbind/refined-set'

data_dir = Path(data_dir)

# get a list of all pdb_ids in the training set

In [3]:
with open(train_index_path, 'r') as f:
    pdb_ids = [line.strip() for line in f]

print(len(pdb_ids))

5316


# get atom coordinates, charges, and types from a pdb id

## get prody object from a pdb id

In [4]:
def parse_protein(pdb_id: str, data_dir: Path = data_dir):
    pdb_path = data_dir / pdb_id / f'{pdb_id}_protein.pdb'
    pdb_path = str(pdb_path)
    protein_atoms = prody.parsePDB(pdb_path)
    return protein_atoms

protein_atoms = parse_protein(pdb_ids[0])

@> 4858 atoms and 1 coordinate set(s) were parsed in 0.03s.


## get ligand atoms, atom types, and charges from a pdb id 

In [5]:
def parse_ligand(pdb_id: str, data_dir: Path = data_dir):
    # construct path to ligand file
    ligand_path = data_dir / pdb_id / f'{pdb_id}_ligand.sdf'

    # read ligand into a rdkit mol
    suppl = rdkit.Chem.SDMolSupplier(str(ligand_path), sanitize=False, removeHs=False)
    ligands = list(suppl)
    if len(ligands) > 1:
        raise NotImplementedError('Multiple ligands found. Code is not written to handle multiple ligands.')
    ligand = ligands[0]

    # get atom positions
    ligand_conformer = ligand.GetConformer()
    atom_positions = ligand_conformer.GetPositions()

    # get atom types and charges
    atom_types = []
    atom_charges = []
    for atom in ligand.GetAtoms():
        atom_types.append(atom.GetAtomicNum())
        atom_charges.append(atom.GetFormalCharge()) # equibind code calls ComputeGasteigerCharges(mol), not sure why/if necessary

    return ligand, atom_positions, atom_types, atom_charges


ligand, lig_atom_positions, lig_atom_types, lig_atom_charges = parse_ligand(pdb_ids[0])

## get the positons, types, and charges of all atoms within some cutoff radius of the ligand center of mass

In [6]:
def get_pocket_atoms(protein_atoms, ligand_atom_positions, pocket_cutoff=8):
    # note that pocket_cutoff is in units of angstroms

    # TODO: maybe it would be better to find all atoms that are within some distance of any ligand atom
    # this is more computataionally expensive but might give more accurate models

    ligand_com = lig_atom_positions.mean(axis=0, keepdims=False)

    pocket_atoms = protein_atoms.select(f'within {pocket_cutoff} of center', center=ligand_com)
    pocket_atom_positions = pocket_atoms.getCoords()
    pocket_atom_types = pocket_atoms.getElements()
    pocket_atom_charges = pocket_atoms.getCharges()
    return pocket_atom_positions, pocket_atom_types, pocket_atom_charges


protein_atoms = parse_protein(pdb_ids[0])
ligand, lig_atom_positions, lig_atom_types, lig_atom_charges = parse_ligand(pdb_ids[0])
pocket_atom_positions, pocket_atom_types, pocket_atom_charges = get_pocket_atoms(protein_atoms, lig_atom_positions)

@> 4858 atoms and 1 coordinate set(s) were parsed in 0.03s.


# build binding pocket graph

1. one-hot encode atom types
2. determine edges for KNN graph

In [7]:
def build_pocket_graph(atom_positions, atom_features, k=3, edge_algorithm='bruteforce'):
    # atom_positions is an array of shape (N, 3) where N is the number of atoms
    # atom_features is an array of shape (N, d) where N is the number of atoms and d is the size of the feature vector for each atom


    # construct KNN graph
    g = dgl.knn_graph(atom_positions, k=k, algorithm=edge_algorithm, dist='euclidean')
    return g

protein_atoms = parse_protein(pdb_ids[0])
ligand, lig_atom_positions, lig_atom_types, lig_atom_charges = parse_ligand(pdb_ids[0])
pocket_atom_positions, pocket_atom_types, pocket_atom_charges = get_pocket_atoms(protein_atoms, lig_atom_positions)

g = build_pocket_graph(pocket_atom_positions, None)

@> 4858 atoms and 1 coordinate set(s) were parsed in 0.02s.


AttributeError: 'numpy.ndarray' object has no attribute 'dim'