In [1]:
import argparse
import csv
import collections
import subprocess
import tqdm
import os

import chemprop
import numpy as np
import rdkit

# Absolute path to the ChEMBL CSV file (the same path that is handed to
# load_dataset further down in this cell).
CHEMBL_PATH = "/data/scratch/fisch/third_party/chemprop/data/chembl.csv"


def filter_invalid_smiles(smiles):
    """Return True when `smiles` should be dropped from the dataset.

    A SMILES string is rejected when it is empty/None, when RDKit cannot
    parse it, or when the parsed molecule contains no heavy atoms.

    Args:
        smiles: SMILES string to check (may be empty or None).

    Returns:
        True if the entry must be skipped, False if it is usable.
    """
    if not smiles:
        return True
    mol = rdkit.Chem.MolFromSmiles(smiles)
    # MolFromSmiles signals a parse failure by returning None rather than
    # raising, so guard before calling any method on the molecule.
    if mol is None:
        return True
    return mol.GetNumHeavyAtoms() == 0


def load_dataset(path, N):
    """Load a multi-task CSV and keep properties with enough distinct scaffolds.

    The first CSV column is read as the SMILES string and every remaining
    column as an integer-labelled property (blank cells are treated as
    unlabeled and skipped). Rows rejected by `filter_invalid_smiles` are
    ignored.

    Args:
        path: Path to the CSV file (must contain a header row).
        N: Minimum number of distinct scaffolds required for BOTH label 0
            and label 1 for a property to be kept.

    Returns:
        Dict mapping property name -> {label: [(scaffold, smiles), ...]} for
        every property that passes the scaffold-count filter.

    Raises:
        ValueError: If the file has no header row (e.g. it is empty).
    """
    # Count data rows in pure Python (no dependency on an external `wc`) and
    # exclude the header line so the progress-bar total matches the rows read.
    with open(path, "r") as f:
        num_rows = max(sum(1 for _ in f) - 1, 0)

    # Keep track of property --> list of molecules (by active/inactive).
    property_to_smiles = collections.defaultdict(lambda: collections.defaultdict(list))

    # newline="" lets the csv module handle line endings itself.
    with open(path, "r", newline="") as f:
        reader = csv.DictReader(f)
        columns = reader.fieldnames
        if not columns:
            raise ValueError("No header row found in %s" % path)
        smiles_column = columns[0]
        target_columns = columns[1:]

        # Read in all the dataset smiles.
        for row in tqdm.tqdm(reader, total=num_rows, desc="reading smiles"):
            smiles = row[smiles_column]
            if filter_invalid_smiles(smiles):
                continue
            scaffold = chemprop.data.scaffold.generate_scaffold(smiles)
            for target in target_columns:
                value = row[target]
                if not value:
                    continue
                property_to_smiles[target][int(value)].append((scaffold, smiles))

    # Filter properties with sufficient examples: both classes need at least
    # N distinct scaffolds.
    valid_properties = {}
    for target, values in property_to_smiles.items():
        if len({scaffold for scaffold, _ in values[0]}) < N:
            continue
        if len({scaffold for scaffold, _ in values[1]}) < N:
            continue
        valid_properties[target] = values

    print("Initially kept %d of %d properties." % (len(valid_properties), len(property_to_smiles)))
    return valid_properties


def make_splits(dataset, num_test, num_val, num_train, N, seed=None):
    """Partition properties into test / val / train with disjoint scaffolds.

    Properties are visited in random order. For each one, N molecules with
    pairwise-distinct scaffolds are sampled per label (0 and 1), skipping any
    scaffold already claimed by a previously completed split. A property is
    accepted only if both labels yield N samples. Splits are filled in the
    order test, val, train; a property rejected for one split is discarded and
    not reconsidered for a later split.

    Note: `dataset` is modified in place. Every per-label list is shuffled,
    and for accepted properties it is replaced by the N sampled entries. The
    return value holds property names only, so keep a reference to `dataset`
    if the sampled molecules are needed afterwards.

    Args:
        dataset: Dict mapping property -> {0: [(scaffold, smiles), ...],
            1: [(scaffold, smiles), ...]}.
        num_test: Maximum number of properties in the test split.
        num_val: Maximum number of properties in the val split.
        num_train: Maximum number of properties in the train split.
        N: Number of molecules to sample per label for each property.
        seed: Optional integer seed for reproducible splits. When None, the
            global `np.random` state is used.

    Returns:
        Dict with keys "val", "test" and "train", each a list of property
        names.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    targets = list(dataset.keys())
    rng.shuffle(targets)

    # Scaffolds claimed by splits that are already complete.
    used = set()

    def _choose(values):
        # Shuffle values.
        rng.shuffle(values)

        # Filter to disjoint scaffolds.
        seen = set()
        filtered = []
        for scaffold, smiles in values:
            if scaffold in seen or scaffold in used:
                continue
            seen.add(scaffold)
            filtered.append((scaffold, smiles))

        # Sample N values.
        if len(filtered) < N:
            return None
        picks = rng.choice(len(filtered), N, replace=False)
        return [filtered[k] for k in picks]

    def _gather(candidates, limit, name):
        """Fill one split; return (accepted targets, not-yet-visited candidates)."""
        chosen = []
        rejected = 0
        visited = 0
        for target in candidates:
            if len(chosen) == limit:
                break
            visited += 1
            sampled = {}
            for active in [0, 1]:
                samples = _choose(dataset[target][active])
                if not samples:
                    break
                sampled[active] = samples
            if len(sampled) == 2:
                # Commit only once both labels succeeded, so rejected targets
                # are never left half-subsampled.
                for active in [0, 1]:
                    dataset[target][active] = sampled[active]
                chosen.append(target)
            else:
                rejected += 1

        print("Filtered %d overlaps from %s." % (rejected, name))

        # Claim this split's scaffolds only after it is complete: targets in
        # the same split may share scaffolds, later splits may not reuse them.
        for target in chosen:
            for active in [0, 1]:
                used.update(scaffold for scaffold, _ in dataset[target][active])

        # Slice by the number of targets actually visited. This is safe for an
        # empty candidate list and never hands an already-processed target to
        # the next split.
        return chosen, candidates[visited:]

    # Gather test properties.
    test, targets = _gather(targets, num_test, "test")

    # Gather val properties, without molecule overlap in test.
    val, targets = _gather(targets, num_val, "val")

    # Gather train properties, without molecule overlap in val/test.
    train, targets = _gather(targets, num_train, "train")

    splits = {"val": val, "test": test, "train": train}
    for split, keys in splits.items():
        print("%s: %d" % (split, len(keys)))

    return splits
    
# Use the CHEMBL_PATH constant defined at the top of this cell instead of
# repeating the path literal; 250 is the minimum number of distinct scaffolds
# required per label (the `N` argument of load_dataset).
dataset = load_dataset(CHEMBL_PATH, 250)

reading smiles: 100%|█████████▉| 456331/456332 [13:10<00:00, 576.93it/s]


Initially kept 231 of 1310 properties.


In [32]:
import copy
# make_splits shuffles and subsamples the per-property lists in place, so a
# deep copy is passed to keep the loaded `dataset` intact for re-runs.
# Arguments: num_test=25, num_val=25, num_train=200, N=200 molecules per label.
splits = make_splits(copy.deepcopy(dataset), 25, 25, 200, 200)

Filtered 0 overlaps from test.
Filtered 2 overlaps from val.
Filtered 75 overlaps from train.
val: 25
test: 25
train: 104


In [21]:
# NOTE(review): scratch arithmetic with no stated meaning; presumably
# (number of properties) x (N molecules per label) x (2 labels) — confirm and
# label it, or delete the cell.
70 * 200 * 2

28000

In [25]:
# NOTE(review): ad-hoc ratio whose operands are not defined anywhere in this
# notebook — document what 35 and 140 refer to, or delete the cell.
35 / 140

0.25

In [23]:
# NOTE(review): ad-hoc ratio whose operands are not defined anywhere in this
# notebook — document what 20 and 130 refer to, or delete the cell.
20 / 130

0.15384615384615385