In [3]:
from tqdm import tqdm
import pickle
import pandas as pd
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from rdkit.Chem import AllChem
from rdkit import DataStructs, Chem

In [11]:
store = pd.HDFStore('/home/mila/v/vincent.quirion/gflownetdata/mols/data/docked_mols.h5', 'r')
df = store.select('df')

mols = []

for i in tqdm(range(len(df))):
    mols.append(df.iloc[i].name)
store.close()

100%|██████████| 316936/316936 [00:34<00:00, 9273.69it/s]


In [None]:
from rdkit import Chem
import networkx as nx
from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext

ctx = FragMolBuildingEnvContext()
fragmols = list(enumerate(ctx.frags_mol))
# Largest fragment first
fragmols = sorted(fragmols, key=lambda x: -x[1].GetNumAtoms())

def recursive_decompose(m, all_matches, a2f, frags, bonds, max_depth=9, numiters=None):
    if numiters is None:
        numiters = [0]
    numiters[0] += 1
    if numiters[0] > 1_000:
        raise ValueError('too many iterations')
    if max_depth == 0 or len(a2f) == m.GetNumAtoms():
        # try to make a mol, does it work?
        # Did we match all the atoms?
        if len(a2f) < m.GetNumAtoms():
            return None
        # graph is a tree, e = n - 1
        if len(bonds) != len(frags) - 1:
            return None
        g = nx.Graph()
        g.add_nodes_from(range(len(frags)))
        g.add_edges_from([(i[0], i[1]) for i in bonds])
        assert nx.is_connected(g), 'Somehow we got here but fragments dont connect?'
        for fi, f in enumerate(frags):
            g.nodes[fi]['v'] = f
        for a, b, stemidx_a, stemidx_b, _, _ in bonds:
            g.edges[(a, b)][f'{a}_attach'] = stemidx_a
            g.edges[(a, b)][f'{b}_attach'] = stemidx_b
        m2 = ctx.graph_to_mol(g)
        if m2.HasSubstructMatch(m) and m.HasSubstructMatch(m2):
            return g
        return None
    for fragidx, frag in fragmols:
        # Some fragments have symmetric versions, so we need all matches up to isomorphism!
        matches = all_matches[fragidx]
        for match in matches:
            if any(i in a2f for i in match):
                continue
            # Verify that atoms actually have the same charge
            if any(frag.GetAtomWithIdx(ai).GetFormalCharge() != m.GetAtomWithIdx(bi).GetFormalCharge()
                   for ai, bi in enumerate(match)):
                continue
            new_frag_idx = len(frags)
            new_frags = frags + [fragidx]
            new_a2f = {**a2f, **{i: (fi, new_frag_idx) for fi, i in enumerate(match)}}
            possible_bonds = []
            is_valid_match = True
            # Is every atom that has a bond outside of this fragment also a stem atom?
            for fi, i in enumerate(match):
                for j in m.GetAtomWithIdx(i).GetNeighbors():
                    j = j.GetIdx()
                    if j in match:
                        continue
                    # There should only be single bonds between fragments
                    if m.GetBondBetweenAtoms(i, j).GetBondType() != Chem.BondType.SINGLE:
                        is_valid_match = False
                        break
                    # At this point, we know (i, j) is a single bond that goes outside the fragment
                    # so we check if the fragment we chose has that atom as a stem atom
                    if fi not in ctx.frags_stems[fragidx]:
                        is_valid_match = False
                        break
                if not is_valid_match:
                    break
            if not is_valid_match:
                continue
            for this_frag_stemidx, i in enumerate([match[s] for s in ctx.frags_stems[fragidx]]):
                for j in m.GetAtomWithIdx(i).GetNeighbors():
                    j = j.GetIdx()
                    if j in match:
                        continue
                    if m.GetBondBetweenAtoms(i, j).GetBondType() != Chem.BondType.SINGLE:
                        continue
                    # Make sure the neighbor is part of an already identified fragment
                    if j in a2f and a2f[j] != new_frag_idx:
                        other_frag_atomidx, other_frag_idx = a2f[j]
                        try:
                            # Make sure that fragment has that atom as a stem atom
                            other_frag_stemidx = ctx.frags_stems[frags[other_frag_idx]].index(other_frag_atomidx)
                        except ValueError as e:
                            continue
                        # Make sure that that fragment's stem atom isn't already used
                        for b in bonds + possible_bonds:
                            if b[0] == other_frag_idx and b[2] == other_frag_stemidx:
                                break
                            if b[1] == other_frag_idx and b[3] == other_frag_stemidx:
                                break
                            if b[0] == new_frag_idx and b[2] == this_frag_stemidx:
                                break
                            if b[1] == new_frag_idx and b[3] == this_frag_stemidx:
                                break
                        else:
                            possible_bonds.append((other_frag_idx, new_frag_idx, other_frag_stemidx, this_frag_stemidx, i, j))
            new_bonds = bonds + possible_bonds
            dec = recursive_decompose(m, all_matches, new_a2f, new_frags, new_bonds, max_depth-1, numiters)
            if dec:
                return dec
def f(smi):
    m = Chem.MolFromSmiles(smi)
    all_matches = {}
    for fragidx, frag in fragmols:
        all_matches[fragidx] = m.GetSubstructMatches(frag, uniquify=False)
    try:
        g = recursive_decompose(m, all_matches, {}, [], [], 9)
    except:
        g = None

In [4]:
with open('data_files/training_set_graphs.pkl', 'rb') as file:
    graphs = pickle.load(file)
valid_idxs = [i for i, g in enumerate(graphs) if g is not None]

In [12]:
fp_list = []
for i in tqdm(valid_idxs):
    fp = AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(mols[i]), 2, nBits=1024)
    arr = np.zeros((1,), int)
    DataStructs.ConvertToNumpyArray(fp, arr)
    fp_list.append(arr)

df = pd.DataFrame(fp_list)

100%|██████████| 246983/246983 [02:35<00:00, 1591.94it/s]


KeyboardInterrupt: 

In [56]:
def kmeans_cluster(train_list, num_clusters, sample_size=None):
    arr = np.array(train_list, dtype=np.float16)

    km = MiniBatchKMeans(n_clusters=num_clusters, random_state=0, batch_size=600)
    km.fit(arr)
    return km

In [57]:
km = kmeans_cluster(fp_list, num_clusters=25)



In [58]:
np.bincount(km.labels_)

array([ 2081, 16450,  6111, 17022,  6717,  4397, 10927, 12753,   319,
       17279,  7227,  4008,  5367, 11471, 12715, 13730, 18289,  5264,
       21262,  5847,  9245, 13792,  5002,  8329, 11379])

In [59]:
with open('data_files/25_kmeans_model.pkl', 'wb') as f:
    pickle.dump(km, f)