In [5]:
import numpy as np
import pymolviz as pmv
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.ForceField import rdForceField
from scipy.spatial import cKDTree
from xbpy import rdutil

In [6]:
# Take the first entry yielded for each input file.
receptor_reader = rdutil.read_molecules("AF3_1_prepped.pdb")
receptor = next(receptor_reader)
ligand_reader = rdutil.read_molecules("rank1_confidence-1.57.sdf")
ligand = next(ligand_reader)

In [7]:
# Rebinds `ligand` to the result of rdutil.proximity_bond. Re-running this cell
# feeds the already-processed molecule back in, so reload the ligand first.
# presumably (re)derives bonds from inter-atomic distances — TODO confirm in xbpy
ligand = rdutil.proximity_bond(ligand)

In [8]:
# Let the library routine resolve ligand/receptor clashes with a 1.4 cutoff.
rotated = rdutil.resolve_small_clashes(
    ligand,
    receptor,
    distance_threshold=1.4,
)

sampled angle 0.0
sampled angle 0.0
sampled angle 0.0
sampled angle 0.17951958020513104
sampled angle 0.35903916041026207
sampled angle 0.5385587406153931
sampled angle 0.7180783208205241
sampled angle 0.8975979010256552
sampled angle 1.0771174812307862
sampled angle 1.2566370614359172
sampled angle 1.4361566416410483
sampled angle 1.6156762218461793
sampled angle 1.7951958020513104
sampled angle 1.9747153822564414
sampled angle 2.1542349624615724
sampled angle 2.3337545426667035
rotation found
no rotation found
sampled angle 0.0
sampled angle 0.17951958020513104
rotation found
sampled angle 0.0
sampled angle 0.17951958020513104
sampled angle 0.35903916041026207
sampled angle 0.5385587406153931
sampled angle 0.7180783208205241
sampled angle 0.8975979010256552
sampled angle 1.0771174812307862
sampled angle 1.2566370614359172
sampled angle 1.4361566416410483
sampled angle 1.6156762218461793
sampled angle 1.7951958020513104
sampled angle 1.9747153822564414
sampled angle 2.1542349624615724

In [5]:
# Save the pose returned by rdutil.resolve_small_clashes.
# NOTE(review): the final cell of this notebook writes `rotated_mol` to the same
# "rotated.sdf" path, so running both cells overwrites this file.
rdutil.write_molecules(rotated, "rotated.sdf")

While RDKit should technically support this, in practice it is unreliable and consistently fails for no apparent reason. We therefore write a simple script that resolves clashes by rotating rotatable bonds until no clashes remain.

In [6]:
# First we construct the kinematic chain. The call yields three values that the
# cells below use as follows:
#   rigid_components - sequence of atom-index tuples, one per rigid component
#   adjacency_matrix - component-by-component matrix, indexed [i, j]
#   bonds_to_atoms   - mapping from a component pair (i, j) to the atom indices
#                      of the bond that connects the two components
rigid_components, adjacency_matrix, bonds_to_atoms = rdutil.get_kinematic_graph(ligand)

In [7]:
# Weights of the kinematic chain; later cells index the result as a 2-D array
# `kinematic_weights[i, j]` over component pairs.
# presumably entry [i, j] is the weight on the j-side of edge i-j — TODO confirm in xbpy
kinematic_weights = rdutil.kinematic_chain_weights(ligand, rigid_components, return_node_weights=False)

In [8]:
# Pick an arbitrary bond for the visual sanity checks below: entry 10 of the
# mapping's values, i.e. a pair of atom indices.
all_bond_atom_pairs = list(bonds_to_atoms.values())
bond = all_bond_atom_pairs[10]

In [9]:
# Export the two atoms of the selected bond as coloured points.
bond_atoms = [ligand.GetAtomWithIdx(atom_idx) for atom_idx in (bond[0], bond[1])]
positions = rdutil.position(bond_atoms)
bond_points = pmv.Points(positions, color=["green", "red"], name = "bond")
bond_points.write("selected_bond.py")

In [10]:
# Atom indices reported for the selected bond; note the argument order is
# (bond[1], bond[0]), i.e. the second bond atom is passed first.
# presumably these are the atoms on one side of the bond — TODO confirm which side
selected_rest_indices = rdutil.get_bond_connected_atoms(ligand, bond[1], bond[0])

In [11]:
# Export the atoms returned for the selected bond as blue points.
selected_atoms = list(
    map(lambda rest_index: ligand.GetAtomWithIdx(int(rest_index)), selected_rest_indices)
)
positions = rdutil.position(selected_atoms)
rest_points = pmv.Points(positions, color="blue", name = "rest_atoms")
rest_points.write("rest_atoms.py")

In [12]:
# Reverse lookup: atom index -> index of the rigid component that contains it.
atom_to_component = {
    member_atom: component_idx
    for component_idx, component_atoms in enumerate(rigid_components)
    for member_atom in component_atoms
}

In [13]:
# Construct a KD-tree over the receptor coordinates so that ligand/receptor
# proximity can be queried with ball queries instead of a full distance matrix.
# (cKDTree is imported in the notebook's import cell.)
receptor_coords = rdutil.position(receptor)
receptor_kd_tree = cKDTree(receptor_coords)

In [14]:
# Clash cutoff: used as the radius of the receptor KD-tree ball query and as the
# minimum allowed distance between non-bonded ligand atoms in the cells below.
# presumably in Angstrom, the unit of the input coordinates — TODO confirm
distance_threshold = 1.3

In [15]:
# Identify clashing ligand atoms: those with at least one receptor atom within
# distance_threshold. (Clashing *components* are derived from these in a later cell.)
ligand_positions = rdutil.position(ligand)
# With return_length=True the ball query returns, per ligand atom, the number
# of receptor atoms inside the radius.
receptor_neighbor_counts = np.array(
    receptor_kd_tree.query_ball_point(ligand_positions, distance_threshold, return_length=True)
)
clashing_indices = np.flatnonzero(receptor_neighbor_counts > 0)
print(clashing_indices)

[27 51 63 91 92 93]


In [16]:
# Export the directly clashing ligand atoms as red points.
# NOTE(review): this cell calls `rdutil.positions` while every other cell calls
# `rdutil.position` — confirm both exist and behave the same.
clashing_atom_points = pmv.Points(
    rdutil.positions([ligand.GetAtomWithIdx(int(atom_idx)) for atom_idx in clashing_indices]),
    color="red",
    name = "clashing",
)
clashing_atom_points.write("clashing.py")

In [17]:
# Map every clashing atom to the rigid component that contains it.
clashing_components = {atom_to_component[atom_idx] for atom_idx in clashing_indices}

# Gather all atoms of the clashing components together with their coordinates.
# NOTE(review): `clashing_positions` is not used by any cell visible here.
clashing_atoms = []
for component_idx in clashing_components:
    for atom_idx in rigid_components[component_idx]:
        clashing_atoms.append(ligand.GetAtomWithIdx(int(atom_idx)))
clashing_positions = rdutil.position(clashing_atoms)

In [18]:
# For every row of kinematic_weights: set entry 1 to 0 when entry 0 is the index
# of a clashing component, otherwise set entry 1 to 1.
# NOTE(review): the original comment here read "make adjacency matrix only point
# into heaviest component", which this loop does not do. It treats each row as a
# (component, weight) pair, while the next cell indexes kinematic_weights as a
# square [i, j] matrix — confirm this cell is still intended. The assignment
# mutates kinematic_weights, so the cell is not idempotent.
for weight in kinematic_weights:
    weight[1] = 0 if weight[0] in clashing_components else 1

In [19]:
# Build a directed version of the component adjacency matrix. Every edge found
# in the lower triangle (j <= i) is kept in exactly one direction, chosen by
# comparing kinematic_weights[i, j] with kinematic_weights[j, i].
heavier_adjacency_matrix = np.zeros_like(adjacency_matrix)
n_rows, n_cols = kinematic_weights.shape
for i in range(n_rows):
    # only visit columns with j <= i
    for j in range(min(i + 1, n_cols)):
        if adjacency_matrix[i, j] != 1:
            continue
        if kinematic_weights[i, j] >= kinematic_weights[j, i]:
            source, target = j, i
        else:
            source, target = i, j
        heavier_adjacency_matrix[source, target] = 1

In [20]:
# Inspect the directed component adjacency matrix (one row/column per rigid component).
heavier_adjacency_matrix

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0],
       [0, 1, 0, 0, 

In [21]:
# Draw one arrow per directed edge of the heavier-adjacency matrix, using the
# coordinates of the atoms of the bond that connects the two components.
# The loop variable is deliberately NOT called `bond`: at notebook top level it
# would overwrite the `bond` selected earlier and silently change the result of
# re-running the cells that visualise that bond.
positions = []
for component_pair in np.argwhere(heavier_adjacency_matrix != 0):
    edge_atoms = [ligand.GetAtomWithIdx(int(idx)) for idx in bonds_to_atoms[tuple(component_pair)]]
    positions.extend(rdutil.position(edge_atoms))

# reshape to (n_edges, 2 atoms, 3 coordinates) as start/end points of each arrow
pmv.Arrows(np.array(positions).reshape(-1, 2, 3), name = "heavier_adjacency", linewidth=0.33).write("heavier_adjacency.py")

ColorMap.py:180 Infered color [0.96779756 0.44127456 0.53581032] and alpha 1 from value (0.9677975592919913, 0.44127456009157356, 0.5358103155058701)


In [22]:
# Scratch check: `**` on an array is element-wise (2 ** 3 == 8 in every cell),
# not a matrix power.
(np.ones((4,4)) * 2) ** 3

array([[8., 8., 8., 8.],
       [8., 8., 8., 8.],
       [8., 8., 8., 8.],
       [8., 8., 8., 8.]])

In [23]:
# Scratch check of np.setdiff1d: values of clashing_indices not present in [1, 2, 3].
np.setdiff1d(clashing_indices, np.array([1,2,3]))

array([27, 51, 63, 91, 92, 93])

In [24]:
# Inspect the ligand atom indices that clash with the receptor.
clashing_indices

array([27, 51, 63, 91, 92, 93])

In [25]:
# Inspect the atom-index tuples of every clashing rigid component.
[rigid_components[clashing_component] for clashing_component in clashing_components]

[(8, 35, 36, 39, 88, 89, 90, 91, 92, 93),
 (34, 63),
 (18, 27, 62, 65, 72, 85),
 (51,)]

In [26]:
max_depth = 4
found_points = []
ligand_positions = rdutil.position(ligand)
ligand_adjacency_matrix = Chem.GetAdjacencyMatrix(ligand)
# Directly bonded atom pairs are excluded from the self-clash test; computed once.
bonded_pairs_mask = ligand_adjacency_matrix.astype(bool)
# 0 and 2*pi describe the same rotation, so the endpoint is excluded. This gives
# 36 distinct angles in 10-degree steps instead of 35 distinct angles plus a
# duplicate of the unrotated pose.
sampled_angles = np.linspace(0, 2 * np.pi, 36, endpoint=False)
rotated_mol = Chem.Mol(ligand)


def rotate_mol(mol, other_component_mask_index, angle=None, *, path, ignored_mask):
    """Depth-first search over bond rotations along ``path`` for a clash-free pose.

    Args:
        mol: molecule to rotate / test.
        other_component_mask_index: current position in ``path``. At index ``k > 0``
            the bond between the components flagged by ``path[k]`` and
            ``path[k - 1]`` is rotated; at index 0 the pose is only tested.
        angle: unused; kept so existing positional calls keep working.
        path: list of component indicator rows, ordered from the component next
            to the clashing one outwards.
        ignored_mask: boolean per-atom mask; atoms set to True are skipped in the
            receptor-clash test.

    Returns:
        The first molecule found that has no receptor contact within
        ``distance_threshold`` (outside ``ignored_mask``) and no pair of
        non-bonded ligand atoms closer than ``distance_threshold``; ``None`` if
        no sampled combination of angles satisfies both.

    Note:
        The search visits up to ``len(sampled_angles) ** (len(path) - 1)`` poses,
        so long paths are very slow.
    """
    if other_component_mask_index == 0:
        mol_positions = rdutil.position(mol)
        # receptor clashes, ignoring atoms that are handled by another component
        receptor_contacts = np.array(
            receptor_kd_tree.query_ball_point(mol_positions, distance_threshold, return_length=True)
        ) > 0
        if np.any(receptor_contacts[~ignored_mask]):
            return None
        # self clashes between atoms that are neither identical nor bonded
        distances = np.linalg.norm(mol_positions[:, None] - mol_positions[None, :], axis=-1)
        np.fill_diagonal(distances, np.inf)
        distances[bonded_pairs_mask] = np.inf
        if np.any(distances < distance_threshold):
            return None
        return mol

    # The next bond is the one between this path entry and the entry one step
    # closer to the clashing component.
    next_component_idx = other_component_mask_index - 1
    bond_indices = bonds_to_atoms[(
        np.argwhere(path[other_component_mask_index])[0][0],
        np.argwhere(path[next_component_idx])[0][0],
    )]
    for sampled_angle in sampled_angles:
        rot_mol = rdutil.rotate_around_bond(mol, *bond_indices, sampled_angle)
        rot_mol = rotate_mol(rot_mol, next_component_idx, sampled_angle, path=path, ignored_mask=ignored_mask)
        if rot_mol is not None:
            return rot_mol
    return None


# For each clashing component, walk along the directed component graph until
# max_depth is reached or a branching component is hit, then search rotations
# of the bonds along that walk.
# NOTE(review): path[0] is the neighbour of the clashing component, and only bonds
# between consecutive path entries are rotated, so the bond that directly attaches
# the clashing component is never rotated — confirm whether this is intended.
for clashing_component in clashing_components:
    next_matrix = np.array(heavier_adjacency_matrix)
    cur_depth = 0
    cur_other_component = next_matrix[clashing_component]
    path = [cur_other_component]
    while cur_depth < max_depth and np.sum(adjacency_matrix[cur_other_component.astype(bool)]) < 3:
        # another power of the directed matrix = one more step along the chain
        next_matrix = next_matrix @ heavier_adjacency_matrix
        cur_depth += 1
        cur_other_component = next_matrix[clashing_component]
        path.append(cur_other_component)
    print("path: ", [np.argwhere(i) for i in path])
    if sum(cur_other_component) != 1:
        # the walk did not end on exactly one component; nothing to rotate
        continue
    # we don't care about clashing atoms of the other components at the moment
    ignored_mask = np.zeros(ligand.GetNumAtoms(), dtype=bool)
    ignored_mask[np.setdiff1d(clashing_indices, rigid_components[clashing_component])] = True
    next_mol = rotate_mol(rotated_mol, len(path) - 1, path=path, ignored_mask=ignored_mask)
    if next_mol is None:
        print("no rotation found")
    else:
        rotated_mol = next_mol


path:  [array([[19]]), array([[6]])]
no rotation found
path:  [array([[7]]), array([[19]]), array([[6]])]


KeyboardInterrupt: 

In [None]:
# Inspect every rigid component together with its index (component index, atom indices).
[(i,c) for i,c in enumerate(rigid_components)]

[(0, (0, 13, 30, 84)),
 (1, (1, 3, 15, 17, 19, 48, 52, 60, 67, 70, 73)),
 (2, (2, 59, 74, 79)),
 (3, (4, 12, 33, 106)),
 (4, (5, 77)),
 (5, (6, 28, 83)),
 (6, (7, 86)),
 (7, (8, 35, 36, 39, 88, 89, 90, 91, 92, 93)),
 (8, (9, 24, 40, 96)),
 (9, (10, 25, 47, 71)),
 (10, (11, 50)),
 (11, (14, 41, 45, 49, 53, 54, 55, 57, 58, 61, 64)),
 (12, (16, 22, 43, 46, 100, 101, 102, 103, 104, 105)),
 (13, (18, 27, 62, 65, 72, 85)),
 (14, (20, 42, 82)),
 (15, (21, 98)),
 (16, (23, 76)),
 (17, (26, 44, 99)),
 (18, (29, 81)),
 (19, (31, 32, 87)),
 (20, (34, 63)),
 (21, (37, 56, 75, 80)),
 (22, (38, 95)),
 (23, (51,)),
 (24, (66, 69, 94)),
 (25, (68,)),
 (26, (78, 97))]

In [None]:
# Save the pose produced by the manual rotation search.
# NOTE(review): this writes to the same "rotated.sdf" path as the earlier cell
# that saves `rotated`, so it overwrites that file.
rdutil.write_molecules(rotated_mol, "rotated.sdf")