In [2]:
import argparse
import collections
import csv
import json
import numpy as np
import os
import subprocess
import tqdm
from rdkit import Chem

# Location of the ChEMBL pchembl CSV. Set the CHEMBL_PATH environment variable
# to point elsewhere instead of editing this cell; the default is the path the
# notebook was originally run against.
CHEMBL_PATH = os.environ.get(
    "CHEMBL_PATH",
    "/data/rsg/chemistry/swansonk/antibiotic_moa/data/pchembl_100.csv",
)

# One dataset row: `smiles` is the SMILES string and `targets` is a dict mapping
# target column name -> float value (only columns with a non-empty cell).
Molecule = collections.namedtuple(
    "Molecule", ["smiles", "targets"])

def filter_invalid_smiles(smiles):
    """Return True if a SMILES string should be dropped from the dataset.

    A SMILES is dropped when it is empty/None, when RDKit cannot parse it,
    or when the parsed molecule has no heavy atoms.

    Args:
        smiles: SMILES string to check (may be empty or None).

    Returns:
        True if the molecule should be filtered out, False if it is usable.
    """
    if not smiles:
        return True
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles signals a parse failure by returning None rather than
    # raising, so guard before touching the molecule.
    if mol is None:
        return True
    return mol.GetNumHeavyAtoms() == 0


def load_dataset(path):
    """Load a CSV of SMILES and per-target values into a list of Molecules.

    The first CSV column is treated as the SMILES column and every remaining
    column as a target. Rows whose SMILES is rejected by
    ``filter_invalid_smiles`` are skipped, and empty target cells are left
    out of that molecule's ``targets`` dict.

    Args:
        path: Path to a CSV file with a header row.

    Returns:
        List of ``Molecule(smiles, targets)`` where ``targets`` maps target
        column name -> float value.

    Raises:
        ValueError: if the file has no header row.
    """
    # Count data rows up front so tqdm can show a correct total. The header
    # line is excluded (counting it left the bar stuck at N/N+1). Counting in
    # Python avoids shelling out to `wc`, which is not available everywhere.
    # NOTE: assumes one record per physical line (no quoted embedded newlines).
    with open(path, "rb") as f:
        num_rows = max(sum(1 for _ in f) - 1, 0)

    # The csv module requires newline="" so quoted fields parse correctly.
    with open(path, "r", newline="") as f:
        reader = csv.DictReader(f)
        columns = reader.fieldnames
        if not columns:
            raise ValueError(f"{path!r} has no header row")
        smiles_column = columns[0]
        target_columns = columns[1:]

        dataset = []
        for row in tqdm.tqdm(reader, total=num_rows, desc="reading smiles"):
            smiles = row[smiles_column]
            if filter_invalid_smiles(smiles):
                continue
            # Skip empty (or missing, i.e. None) cells for this molecule.
            targets = {t: float(row[t]) for t in target_columns if row[t]}
            dataset.append(Molecule(smiles, targets))

    return dataset
    
    
# Parse the whole pchembl CSV into Molecule records (the recorded run took ~5 minutes).
chembl_data = load_dataset(path=CHEMBL_PATH)

reading smiles: 100%|█████████▉| 458047/458048 [04:42<00:00, 1622.34it/s]


In [5]:
# Gather the distinct SMILES strings, then group every measured value by the
# target it was recorded for (values keep their dataset order).
unique_smiles = {mol.smiles for mol in chembl_data}

targets_to_vals = collections.defaultdict(list)
for mol in chembl_data:
    for t, value in mol.targets.items():
        targets_to_vals[t].append(value)

In [6]:
# Number of distinct SMILES; equal to len(chembl_data) means no SMILES repeats.
len(unique_smiles)

458047

In [7]:
# Number of molecules kept after SMILES filtering.
len(chembl_data)

458047

In [8]:
# Number of distinct targets with at least one recorded value.
len(targets_to_vals)

1499

In [87]:
# Keep only targets that have enough measurements and enough spread in their
# values. Thresholds are named so they can be tuned in one place.
MIN_MEASUREMENTS = 300  # minimum number of values recorded for a target
MIN_STD = 0.5  # minimum standard deviation of a target's values

# `and` short-circuits, so np.std is only computed for targets with enough
# values. Written as >= so a NaN std (which fails every comparison) is dropped.
filtered_targets = {
    t: vals
    for t, vals in targets_to_vals.items()
    if len(vals) >= MIN_MEASUREMENTS and np.std(vals) >= MIN_STD
}

In [88]:
# Number of targets passing the measurement-count and spread filters.
len(filtered_targets)

296

In [82]:
# Standard deviation of each kept target's values (np.std evaluated once per target).
# NOTE(review): this cell's saved run (In[82]) predates the current filter cell
# (In[87]), so the saved min/max/len outputs below reflect an older
# filtered_targets. Re-run the notebook top to bottom.
stds = [s for s in (np.std(v) for v in filtered_targets.values()) if s > 0.5]

In [83]:
# Smallest per-target standard deviation in stds.
min(stds)

0.5025533693260913

In [84]:
# Largest per-target standard deviation in stds.
max(stds)

1.890189337673948

In [85]:
# Locate the kept target whose values have the smallest standard deviation.
# min(stds) is loop-invariant, so compute it once; default=None leaves
# found = None (instead of raising) when stds is empty.
min_std = min(stds, default=None)
found = None
found_target = None
for t, v in filtered_targets.items():
    # Exact float equality works because stds was produced by np.std on these
    # same lists -- provided filtered_targets has not changed since then.
    if np.std(v) == min_std:
        found = v
        found_target = t
        break

In [86]:
# Number of targets whose std exceeds 0.5.
# NOTE(review): the saved output (367) is larger than len(filtered_targets) (296),
# so it was produced from an older filtered_targets -- re-run cells in order.
len(stds)

367

In [81]:
# Values of the lowest-spread kept target.
# NOTE(review): this cell ran as In[81], before the cell that assigns `found`
# (In[85]), so the saved output comes from an earlier definition. It also dumps
# the entire list; consider `found[:10]` or `len(found)` instead.
found

[5.25,
 5.21,
 5.42,
 5.35,
 5.21,
 4.54,
 5.17,
 4.79,
 5.35,
 5.31,
 4.78,
 5.3,
 4.79,
 4.51,
 5.24,
 4.65,
 5.29,
 5.22,
 5.45,
 5.62,
 5.5,
 5.66,
 5.62,
 5.68,
 4.36,
 5.21,
 5.27,
 5.32,
 5.47,
 4.74,
 5.54,
 5.07,
 5.0,
 5.3,
 4.8,
 4.79,
 5.64,
 4.71,
 5.43,
 5.6,
 5.58,
 5.25,
 5.45,
 5.37,
 5.08,
 5.57,
 5.29,
 5.83,
 5.34,
 5.65,
 5.65,
 5.24,
 5.17,
 5.17,
 4.39,
 5.33,
 5.46,
 5.53,
 5.37,
 5.2,
 5.19,
 5.47,
 5.44,
 5.25,
 5.55,
 4.38,
 5.27,
 5.28,
 5.68,
 5.25,
 5.88,
 5.33,
 5.58,
 5.6,
 4.8,
 5.5,
 5.72,
 4.81,
 5.2,
 5.59,
 5.55,
 5.66,
 5.04,
 5.3,
 5.41,
 5.34,
 5.62,
 4.6,
 5.58,
 5.54,
 4.46,
 5.39,
 5.68,
 5.49,
 5.09,
 5.5,
 5.54,
 5.09,
 5.17,
 5.07,
 5.42,
 5.61,
 5.3,
 5.44,
 5.14,
 5.34,
 5.38,
 5.71,
 5.28,
 5.37,
 5.1,
 5.06,
 5.2,
 5.45,
 5.25,
 5.25,
 5.46,
 5.24,
 5.42,
 5.18,
 5.38,
 5.14,
 5.13,
 5.28,
 5.25,
 5.04,
 4.88,
 5.12,
 4.83,
 5.9,
 4.55,
 5.23,
 5.4,
 5.08,
 5.66,
 5.38,
 5.39,
 5.26,
 4.96,
 5.63,
 5.5,
 5.32,
 4.32,
 5.65,
 5.5,
 5.19,