In [1]:
import os
from glob import glob
from scipy.sparse.csgraph import minimum_spanning_tree, connected_components
from scipy.sparse import csr_matrix
from scipy.spatial import cKDTree
import numpy as np
from xbpy.rdutil import read_molecules, write_as_batches, position
import json
from rdkit import Chem

In [2]:
out_folder = "batch_test_out"
in_file = "../../../Thesis/XB_BB_Docking_Score/Data/Structure_Files/1_N-heterocycles/bromo-1-2-pyridazine-PreComputation.sdf"
similarity_threshold = 0.7
batch_size = 100

In [3]:
to_batch = list(read_molecules(in_file))

In [4]:
positions = np.array([position(mol.GetAtoms()) for mol in to_batch])
mean_positions = np.mean(positions, axis=1)

In [5]:
# check which positions we even need to compare to cut down on rmsd computations
tree = cKDTree(mean_positions)
incidence_matrix = tree.sparse_distance_matrix(tree, max_distance=similarity_threshold)

In [6]:
cols, rows = incidence_matrix.nonzero()
pairwise_rmsds = np.array([np.sqrt(np.mean((positions[cols[i]] - positions[rows[i]])**2)) for i in range(len(cols))])

In [7]:
incidence_matrix = csr_matrix((pairwise_rmsds[pairwise_rmsds < similarity_threshold], np.array([cols, rows])[:,pairwise_rmsds < similarity_threshold]), shape=incidence_matrix.shape)

In [8]:
mst_matrix = minimum_spanning_tree(incidence_matrix)

In [9]:
# identify connected components to correctly batch the molecules
n_components, labels = connected_components(mst_matrix, directed=False)

In [10]:
n_components

312

In [16]:
label_counts = np.bincount(labels)

In [17]:
# greedily fill up batches
batches = []
cur_batch_size = 0
cur_batch = []
for i in range(n_components):
    labelled_indices = np.where(labels == i)[0]
    for j in np.arange(0, label_counts[i], batch_size):
        end = min(j+batch_size, label_counts[i])
        difference = end - j
        cur_batch_size += difference
        if cur_batch_size > batch_size:
            batches.append(np.concatenate(cur_batch))
            cur_batch = [labelled_indices[j:end]]
            cur_batch_size = difference
        else:
            cur_batch.append(labelled_indices[j:end])
            cur_batch_size += difference
batches.append(np.concatenate(cur_batch))

In [18]:
batch_matrices = []
for batch in batches:
    batch_matrices.append(incidence_matrix[batch, :][:] != 0)

In [19]:
# write batches to files and create computational dependencies
paths_to_mols = {}
os.makedirs(out_folder, exist_ok=True)
mol_idx = 0
for i, batch, batch_matrix in zip(range(len(batches)), batches, batch_matrices):
    os.makedirs(os.path.join(out_folder, f"batch_{i}"), exist_ok=True)
    for mol_index in batch:
        out_path = os.path.join(f"batch_{i}", str(mol_index) + ".xyz")
        paths_to_mols[mol_idx] = out_path
        Chem.MolToXYZFile(to_batch[mol_index], os.path.join(out_folder, out_path))
        mol_idx += 1


In [20]:
# write dependencies for each batch
for i, batch in enumerate(batches):
    with open(os.path.join(out_folder, f"batch_{i}", "dependencies.json"), "w") as f:
        print(i, {str(mol_idx): [(paths_to_mols[j], incidence_matrix[mol_idx, j]) for j in incidence_matrix[mol_idx].nonzero()[1]] for mol_idx in batch}, f)
        json.dump({str(mol_idx): [(paths_to_mols[j], incidence_matrix[mol_idx, j]) for j in incidence_matrix[mol_idx].nonzero()[1]] for mol_idx in batch}, f)

0 {'0': [('batch_0/40.xyz', 0.4645059030188895), ('batch_1/170.xyz', 0.5931575640228958), ('batch_1/217.xyz', 0.34568662985514215), ('batch_1/255.xyz', 0.6814528127908829), ('batch_1/300.xyz', 0.46817316839611883), ('batch_2/401.xyz', 0.6637682243118155), ('batch_8/1082.xyz', 0.6741656232527417), ('batch_23/1321.xyz', 0.3536872473519422)], '1': [('batch_0/39.xyz', 0.46276825095387547), ('batch_0/42.xyz', 0.682868225453003), ('batch_1/171.xyz', 0.593112313575727), ('batch_1/218.xyz', 0.4056302232702872), ('batch_1/255.xyz', 0.5699539677053143), ('batch_1/301.xyz', 0.5344207915343886), ('batch_2/438.xyz', 0.30034504500758136), ('batch_3/507.xyz', 0.6695205347972171), ('batch_3/509.xyz', 0.6294590870523097), ('batch_5/670.xyz', 0.6113458615417646), ('batch_8/1031.xyz', 0.5844763329578699), ('batch_8/1082.xyz', 0.4922213900453553), ('batch_9/1178.xyz', 0.4805340477641579)], '2': [('batch_0/40.xyz', 0.5088836692092593), ('batch_1/169.xyz', 0.6997884683635088), ('batch_1/172.xyz', 0.59310169