### Process proteins

Process each DUD-E target into a PyTorch Geometric `Data` object and save it to disk.

In [2]:
import os
import torch
import pickle
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree
from progressbar import progressbar
from biopandas.mol2 import PandasMol2
from torch_geometric.data import Data

<IPython.core.display.Javascript object>

In [3]:
# We will write each of the target Data files to the
# "raw" directory. `exist_ok=True` keeps this cell safe to re-run and
# avoids the gap between checking for the directory and creating it.
os.makedirs("../data/raw", exist_ok=True)

<IPython.core.display.Javascript object>

In [4]:
# Read the DUD-E target table and keep the lower-cased target names.
all_targets = [
    name.lower()
    for name in pd.read_csv("../data/dud-e_targets.csv")["target_name"].tolist()
]

<IPython.core.display.Javascript object>

#### Data quality checks

In [5]:
def read_in_file(target):
    """Return the target's PDB as a list of whitespace-split lines.

    The first four and the last three entries of the newline-split file
    are dropped before returning.
    """
    with open(f"../data/pdb/{target}.pdb", "r") as pdb_file:
        raw_lines = pdb_file.read().split("\n")
    tokenized_lines = [line.split() for line in raw_lines]
    return tokenized_lines[4:-3]


def check_conect_e(target):
    """Print a notice when the target's PDB holds no CONECT records."""
    conect_count = sum(
        1 for entry in read_in_file(target) if entry[0] == "CONECT"
    )
    if conect_count == 0:
        print(f"{target} does not have bonds.")


def check_num_models(target):
    """Print a notice when the target's PDB holds more than one MODEL record."""
    model_count = 0
    for entry in read_in_file(target):
        if entry[0] == "MODEL":
            model_count += 1
    if model_count > 1:
        print(f"{target} has more than one model.")

<IPython.core.display.Javascript object>

In [6]:
# Run both data-quality checks on every target; each check prints a
# message only when it finds a problem.
# NOTE: the checks call `read_in_file`, a name that is defined a second
# time further down in this notebook. Re-running this cell after that later
# definition has been executed would use the later version instead.
for target in all_targets:
    check_conect_e(target)
    check_num_models(target)

aa2ar does not have bonds.
drd3 does not have bonds.


<IPython.core.display.Javascript object>

Since we determine bonds from atom proximity (the `bond_cutoff` mode) rather than
from the original `CONECT` records, we do not remove these two proteins from the
`all_targets` list.

#### Atom and residue mappings

Here, we create mappings from atom names and residue names to integers.
We need to encode these names as integers before saving proteins in the `Data` format.

In [7]:
def read_in_file(target):
    """Get the residue and atom names of every ATOM record in the target PDB.

    Returns two parallel lists: the fourth whitespace-separated token of
    each ATOM line (used as the residue name), followed by the last token
    of each ATOM line (used as the atom name).

    NOTE: this reuses the name of the `read_in_file` helper defined in the
    data quality checks section and replaces it once this cell has run.
    """
    with open(f"../data/pdb/{target}.pdb", "r") as pdb_file:
        tokenized_lines = [line.split() for line in pdb_file.read().split("\n")][5:-3]
    residue_names = []
    atom_names = []
    for tokens in tokenized_lines:
        if tokens[0] == "ATOM":
            residue_names.append(tokens[3])
            atom_names.append(tokens[-1])
    return residue_names, atom_names

<IPython.core.display.Javascript object>

In [8]:
# Create a list of all the atom and residue names
# in all targets.
# Gather the residue and atom names found across all targets.
all_residues = []
all_atoms = []
for target in all_targets:
    target_residues, target_atoms = read_in_file(target)
    all_residues.extend(target_residues)
    all_atoms.extend(target_atoms)

# Keep each distinct name once, in sorted order.
all_residues = sorted(set(all_residues))
all_atoms = sorted(set(all_atoms))

<IPython.core.display.Javascript object>

In [9]:
# Create mapping dictionaries out of these lists.
# Encode each name as a positive integer, numbered from 1 in sorted order.
residue_mapping = {name: code for code, name in enumerate(all_residues, start=1)}
atom_mapping = {name: code for code, name in enumerate(all_atoms, start=1)}

<IPython.core.display.Javascript object>

#### Process proteins

##### Bond-related methods

In [10]:
def get_bonds_cutoff(atom_coords, bond_cutoff, leafsize=32):
    """Get bond list using the proximity-based method.

    Two distinct atoms are treated as bonded when the distance between
    them is at most `bond_cutoff`.

    Parameters
    ----------
    atom_coords : numpy.ndarray
        2-D array with one row of coordinates per atom.
    bond_cutoff : float
        Search radius passed to the KD tree ball query.
    leafsize : int, optional
        Leaf size used when building the KD tree.

    Returns
    -------
    list of [int, int]
        Directed atom-index pairs. Each bond appears once in each
        direction, and no pair joins an atom to itself.
    """
    # Create a KD tree out of the atoms - this will help
    # find an atom's nearest neighbors efficiently.
    KD_tree = cKDTree(atom_coords, leafsize=leafsize)
    processed_bonds = []
    for atom_index in range(atom_coords.shape[0]):
        neighbor_indices = KD_tree.query_ball_point(
            atom_coords[atom_index, :], bond_cutoff
        )
        # Only the outgoing direction is added here: the neighbor relation
        # is symmetric, so the reverse pair is produced when the loop
        # reaches the neighbor. The atom itself is skipped to avoid
        # self-loops, and neighbors are sorted for a deterministic order.
        processed_bonds += [
            [atom_index, neighbor_index]
            for neighbor_index in sorted(neighbor_indices)
            if neighbor_index != atom_index
        ]
    return processed_bonds


def get_bonds_power(num_atoms, chemical_bonds, graph_power):
    """Get bond list using the graph power method.

    Every atom is connected to all atoms that can be reached from it in at
    most `graph_power` steps along the chemical bonds.

    Parameters
    ----------
    num_atoms : int
        Number of atoms; atom indices are taken to be 0..num_atoms - 1.
    chemical_bonds : list of [int, int]
        Directed atom-index pairs. This list is not modified.
    graph_power : int
        Maximum number of bond steps between two connected atoms.

    Returns
    -------
    list of [int, int]
        The entries of `chemical_bonds` in their original order, followed
        by the newly induced directed pairs. Pairs already present in
        `chemical_bonds` are not added again, and no self-loops are added.
    """
    # Induced bonds go into a separate list so that the caller's
    # chemical_bonds list is left untouched.
    processed_bonds = list(chemical_bonds)
    # Map each atom index to the indices of its direct neighbors. Atoms
    # without any bond simply have no entry.
    neighbors_of = {}
    for source, destination in chemical_bonds:
        neighbors_of.setdefault(source, []).append(destination)
    known_pairs = {(source, destination) for source, destination in chemical_bonds}

    for atom_index in range(num_atoms):
        reachable = set()
        # The frontier starts as the direct neighbors and moves one bond
        # further away from the atom on every iteration.
        frontier = neighbors_of.get(atom_index, [])
        for _ in range(graph_power):
            reachable.update(frontier)
            next_frontier = set()
            for neighbor_atom_index in frontier:
                next_frontier.update(neighbors_of.get(neighbor_atom_index, []))
            frontier = next_frontier
        # Walking back along a bond reaches the atom itself; drop it so no
        # self-loop is created.
        reachable.discard(atom_index)
        for neighbor_index in sorted(reachable):
            pair = (atom_index, neighbor_index)
            if pair not in known_pairs:
                known_pairs.add(pair)
                processed_bonds.append([atom_index, neighbor_index])
    return processed_bonds


def process_bonds(bonds):
    """Get bonds list from processed CONECT record list.

    Parameters
    ----------
    bonds : list of list of int
        One entry per CONECT record: the first element is an atom index
        and the remaining elements are the atoms bonded to it.

    Returns
    -------
    list of [int, int]
        Directed atom-index pairs in order of first appearance, with both
        directions of every bond present and no pair repeated.
    """
    processed_bonds = []
    # A set mirrors the list so that the duplicate check is a constant-time
    # lookup instead of a scan over every bond collected so far.
    seen_pairs = set()
    for record in bonds:
        for bonded_atom in record[1:]:
            # Add both "directions" of the bond to the list, if
            # they do not already exist in it.
            for pair in ((record[0], bonded_atom), (bonded_atom, record[0])):
                if pair not in seen_pairs:
                    seen_pairs.add(pair)
                    processed_bonds.append(list(pair))
    return processed_bonds


def get_edge_distances(processed_bonds, coords):
    """Get the L2 length of each bond in the processed bond list.

    Returns one single-element list per bond, holding the distance between
    the coordinate rows of the bond's two atoms.
    """
    return [
        [np.linalg.norm(coords[first_atom, :] - coords[second_atom, :])]
        for first_atom, second_atom in processed_bonds
    ]


def get_edge_list(processed_bonds):
    """Convert processed bond list to edge list.

    Parameters
    ----------
    processed_bonds : list of [int, int]
        Directed atom-index pairs.

    Returns
    -------
    numpy.ndarray
        Integer array of shape (2, number of bonds): the first row holds
        the source atom of every bond and the second row the destination.
        An empty bond list gives an array of shape (2, 0).
    """
    # Reshaping to two columns before transposing keeps the result
    # two-dimensional even when there are no bonds at all.
    return np.array(processed_bonds, dtype=np.int64).reshape(-1, 2).transpose()

<IPython.core.display.Javascript object>

##### Reduction-related methods

In [11]:
def get_ligand_coords(target):
    """Get the coordinates of the target's crystal ligand as a numpy array."""
    mol2_path = f"../data/unproc/{target}/crystal_ligand.mol2"
    ligand_frame = PandasMol2().read_mol2(mol2_path).df
    # The columns at positions 2 to 4 are returned as the coordinates.
    # NOTE(review): presumably these are the x, y, z columns of the mol2
    # DataFrame - confirm against the BioPandas column layout.
    coordinate_columns = ligand_frame.iloc[:, 2:5]
    return coordinate_columns.to_numpy()


def get_pocket_indices(target, target_coords, protein_cutoff):
    """Get the indices of target atoms that are within its pocket.

    The pocket is made up of the atoms whose distance from the centroid of
    the crystal ligand is at most `protein_cutoff`.

    Parameters
    ----------
    target : str
        Target name, used to locate the crystal ligand file.
    target_coords : numpy.ndarray
        2-D array with one row of coordinates per target atom.
    protein_cutoff : float
        Maximum distance from the ligand centroid for an atom to be kept.

    Returns
    -------
    tuple
        The array of pocket atom indices, and the ligand centroid.
    """
    ligand_centroid = get_ligand_coords(target).mean(0)
    # One vectorized norm over all rows replaces a per-atom Python call and
    # also handles a target without any atoms.
    target_atom_dists = np.linalg.norm(
        np.asarray(target_coords) - ligand_centroid, axis=1
    )
    # Return the indices of the target atoms that are no further than
    # the cutoff away from the centroid of the crystal ligand.
    return np.where(target_atom_dists <= protein_cutoff)[0], ligand_centroid


def filter_bonds(bonds, pocket_indices):
    """Get the bond list of bonds between atoms in the target's pocket.

    Parameters
    ----------
    bonds : list of [int, int]
        Directed atom-index pairs over the atoms of the whole target.
    pocket_indices : sequence of int
        Indices of the atoms that make up the pocket.

    Returns
    -------
    list of [int, int]
        The bonds whose two atoms are both in the pocket, in their original
        order, with every atom index replaced by the position of that atom
        within `pocket_indices`.
    """
    # Map each pocket atom index to its position among the remaining atoms,
    # to correspond to the rows of the filtered coordinates array. The same
    # dictionary doubles as a constant-time membership test, instead of
    # scanning `pocket_indices` for both ends of every bond.
    index_mapping = dict(zip(pocket_indices, range(len(pocket_indices))))
    return [
        [index_mapping[entry[0]], index_mapping[entry[1]]]
        for entry in bonds
        if entry[0] in index_mapping and entry[1] in index_mapping
    ]

<IPython.core.display.Javascript object>

In [12]:
def get_dist_from_ligand_centroid(ligand_centroid, atom_coords, atom_index):
    """Get the L2 distance between the ligand centroid and one atom.

    The atom is the row of `atom_coords` at position `atom_index`.
    """
    offset_from_atom = ligand_centroid - atom_coords[atom_index, :]
    return np.linalg.norm(offset_from_atom)

<IPython.core.display.Javascript object>

##### Putting it all together

In [13]:
def process_target(target, bond_mode, protein_cutoff, bond_cutoff, graph_power):
    """Get a Data object after processing the target's PDB.

    Reads `../data/pdb/{target}.pdb`, builds a bond list for the whole
    structure, keeps only the atoms within `protein_cutoff` of the crystal
    ligand centroid, and packs the result into a `Data` object.

    Parameters
    ----------
    target : str
        Target name, used to locate the PDB and crystal ligand files.
    bond_mode : str
        "chemical" (bonds from CONECT records), "graph_power" (CONECT bonds
        augmented via `get_bonds_power`) or "bond_cutoff" (bonds from
        `get_bonds_cutoff`).
    protein_cutoff : float
        Passed to `get_pocket_indices` to select the pocket atoms.
    bond_cutoff : float
        Passed to `get_bonds_cutoff`; only used in "bond_cutoff" mode.
    graph_power : int
        Passed to `get_bonds_power`; only used in "graph_power" mode.

    Returns
    -------
    Data
        `x` holds, per pocket atom, the encoded residue name, the encoded
        atom name and the distance from the ligand centroid; `edge_index`
        holds the pocket bonds; `edge_attr` holds the length of each bond.

    Raises
    ------
    Exception
        If `bond_mode` is not one of the three supported values.

    Notes
    -----
    Relies on the module-level `residue_mapping` and `atom_mapping`
    dictionaries being defined before this function is called.
    """
    # Open the target's PDB file and read in the lines in a list. Each line
    # is split on whitespace; the first five and last three entries of the
    # newline-split file are dropped.
    with open(f"../data/pdb/{target}.pdb", "r") as f:
        pdb_text = [entry.split() for entry in f.read().split("\n")][5:-3]
    # Keep details on target atom elements and coordinates. From each ATOM
    # line, tokens 3, 5, 6, 7 and 10 (0-based) are kept.
    # NOTE(review): this positional selection assumes every ATOM line splits
    # into the same token layout (e.g. no chain ID column and no fused
    # fields) - confirm against the PDB files used.
    atoms = np.array([entry[3:] for entry in pdb_text if entry[0] == "ATOM"])[
        :, [0, 2, 3, 4, 7]
    ]
    # Three atom attributes will be used, residue type, atom element type,
    # and distance from crystal ligand centroid. The third attribute will
    # be appended to this list after pocket indices are calculated.
    atom_attribs = atoms[:, [0, -1]]
    atom_attribs = [
        [residue_mapping[entry[0]], atom_mapping[entry[1]]]
        for entry in atom_attribs.tolist()
    ]
    atom_coords = np.array(atoms[:, [1, 2, 3]]).astype(float)
    # Get the bond list according to the bond_mode.
    if bond_mode in ["chemical", "graph_power"]:
        bonds = [entry[1:] for entry in pdb_text if entry[0] == "CONECT"]
        # Shift every CONECT index down by one.
        # NOTE(review): presumably this turns 1-based PDB serial numbers into
        # 0-based row indices, which only holds if serials are contiguous
        # and start at 1 - confirm.
        bonds = [[int(element) - 1 for element in entry] for entry in bonds]
        bonds = process_bonds(bonds)
        # Augment existing chemical bonds using graph power.
        if bond_mode == "graph_power":
            bonds = get_bonds_power(atom_coords.shape[0], bonds, graph_power)
    elif bond_mode == "bond_cutoff":
        bonds = get_bonds_cutoff(atom_coords, bond_cutoff)
    else:
        raise Exception("Invalid bond mode.")
    pocket_indices, ligand_centroid = get_pocket_indices(
        target, atom_coords, protein_cutoff
    )
    # Filter out all non-pocket atom information from
    # atom_attribs and atom_coords.
    atom_attribs = (np.array(atom_attribs)[pocket_indices, :]).tolist()
    atom_coords = atom_coords[pocket_indices, :]
    # Update atom_attribs with the third attribute.
    atom_attribs = [
        atom_attribs[i]
        + [get_dist_from_ligand_centroid(ligand_centroid, atom_coords, i)]
        for i in range(atom_coords.shape[0])
    ]
    # The bonds were built over the whole structure; keep only those inside
    # the pocket and renumber them to match the filtered atom_coords rows.
    bonds = filter_bonds(bonds, pocket_indices)
    edge_attributes = get_edge_distances(bonds, atom_coords)
    edge_list = get_edge_list(bonds)
    return Data(
        x=torch.tensor(atom_attribs),
        edge_index=torch.LongTensor(edge_list),
        edge_attr=torch.tensor(edge_attributes),
    )

<IPython.core.display.Javascript object>

In [14]:
def test_Data(target, fname):
    """Test the target Data object for correctness."""
    data = pd.read_pickle(fname)
    # The target's graph should neither be directed nor
    # contain self-loops.
    if data.is_directed():
        print(f"{target} graph is directed.")
    if data.contains_self_loops():
        print(f"{target} has self-loops.")

<IPython.core.display.Javascript object>

In [15]:
def process_all_targets(
    all_targets, bond_mode, protein_cutoff=10, bond_cutoff=None, graph_power=None
):
    """Save each target's Data object to disk, and check them for correctness.

    Parameters
    ----------
    all_targets : iterable of str
        Names of the targets to process.
    bond_mode : str
        "chemical", "bond_cutoff" or "graph_power".
    protein_cutoff : float, optional
        Forwarded to `process_target`; also part of the output file name.
    bond_cutoff : float, optional
        Forwarded to `process_target`; part of the output file name in
        "bond_cutoff" mode.
    graph_power : int, optional
        Forwarded to `process_target`; part of the output file name in
        "graph_power" mode.

    Raises
    ------
    ValueError
        If `bond_mode` is not one of the supported values.

    Notes
    -----
    A target whose processing or saving fails is reported with a printed
    message and skipped; the remaining targets are still processed.
    """
    # Work out the bond-mode part of the file name once, before any work is
    # done, so that an unsupported mode is rejected immediately.
    if bond_mode == "chemical":
        mode_tag = f"{bond_mode}"
    elif bond_mode == "bond_cutoff":
        mode_tag = f"{bond_mode}_{bond_cutoff}"
    elif bond_mode == "graph_power":
        mode_tag = f"{bond_mode}_{graph_power}"
    else:
        raise ValueError(f"Invalid bond mode: {bond_mode!r}.")

    for target in progressbar(all_targets):
        fname = f"../data/raw/{protein_cutoff}_{mode_tag}_{target}.pkl"
        try:
            # Build the Data object before opening the file, so that a
            # processing failure does not leave an empty file behind.
            data = process_target(
                target, bond_mode, protein_cutoff, bond_cutoff, graph_power
            )
            with open(fname, "wb") as f:
                pickle.dump(data, f)
        except Exception as e:
            print(f"{target} processing failed: {e}.")
            # Nothing usable was written for this target, so there is
            # nothing to test; move on to the next one.
            continue
        # Test the written Data object for correctness.
        test_Data(target, fname)

<IPython.core.display.Javascript object>

In [16]:
# Process every target with proximity-based bonds and save the results;
# protein_cutoff is left at its default of 10.
# NOTE(review): both cutoffs are presumably in the units of the PDB
# coordinates (angstroms) - confirm.
process_all_targets(all_targets, bond_mode="bond_cutoff", bond_cutoff=5)

	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:882.)
  ptr = mask.nonzero().flatten()
100% (102 of 102) |######################| Elapsed Time: 0:01:39 Time:  0:01:39


<IPython.core.display.Javascript object>