In [1]:
from IPython.display import display, HTML
# Inject a CSS rule that stretches the notebook's `.container` element to the
# full width of the browser window.
full_width_css = HTML("<style>.container { width:100% !important; }</style>")
display(full_width_css)

In [2]:
import os
import pickle as pkl

import numpy as np
import pandas as pd
from astartes import train_val_test_split
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from scipy.stats.mstats import gmean
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.svm import LinearSVR

Define helper functions

In [3]:
# SMILES parser options passed to Chem.MolFromSmiles in the featurization loop.
# With removeHs disabled, hydrogens written explicitly in a SMILES string are
# kept on the parsed molecule instead of being stripped.
params = Chem.SmilesParserParams()
params.removeHs = False

Function taken from Chemprop: https://github.com/chemprop/chemprop/blob/master/chemprop/features/features_generators.py

In [4]:
MORGAN_RADIUS = 2
MORGAN_NUM_BITS = 2048


def morgan_counts_features_generator(mol,
                                     radius=MORGAN_RADIUS,
                                     num_bits=MORGAN_NUM_BITS):
    """
    Generates a counts-based Morgan fingerprint for a molecule.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
                Note that a SMILES string is parsed with RDKit's default
                parser settings, not with any custom parser parameters.
    :param radius: Morgan fingerprint radius.
    :param num_bits: Number of bits in Morgan fingerprint.
    :return: A 1D numpy array containing the counts-based Morgan fingerprint.
    :raises ValueError: If `mol` is None or is a SMILES string that RDKit
                        cannot parse.
    """
    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses such as numpy.str_ values pulled out of a string array.
    if isinstance(mol, str):
        smiles = mol
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    elif mol is None:
        # Chem.MolFromSmiles returns None on failure; fail here with a clear
        # message instead of an opaque error from inside the fingerprint call.
        raise ValueError("mol is None (did SMILES parsing fail upstream?)")

    features_vec = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=num_bits)

    # ConvertToNumpyArray resizes the destination array in place to the
    # fingerprint length, so a one-element placeholder is sufficient here.
    features = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(features_vec, features)

    return features

# Read in the data
- This CSV file comes from `data.tar.gz` in [Chemprop](https://github.com/chemprop/chemprop). It contains the QM9 data published in: Ramakrishnan, R., Dral, P.O., Rupp, M. and Von Lilienfeld, O.A. "Quantum Chemistry Structures and Properties of 134 kilo Molecules." In: Sci. Data 1.1 (2014), pp. 1-7. [link](https://www.nature.com/articles/sdata201422)

In [5]:
# Path to the QM9 table, resolved relative to the notebook's working directory.
CSV_PATH = 'qm9.csv'
df = pd.read_csv(CSV_PATH)
# NOTE(review): the rendered table shows an "Unnamed: 0" column, which suggests
# the CSV carries a saved index column; confirm whether index_col=0 is wanted.
df

Unnamed: 0,smiles,mu,alpha,homo,lumo,gap,r2,zpve,cv,u0,u298,h298,g298
0,C,0.0000,13.21,-0.3877,0.1171,0.5048,35.3641,0.044749,6.469,-40.478930,-40.476062,-40.475117,-40.498597
1,N,1.6256,9.46,-0.2570,0.0829,0.3399,26.1563,0.034358,6.316,-56.525887,-56.523026,-56.522082,-56.544961
2,O,1.8511,6.31,-0.2928,0.0687,0.3615,19.0002,0.021375,6.002,-76.404702,-76.401867,-76.400922,-76.422349
3,C#C,0.0000,16.28,-0.2845,0.0506,0.3351,59.5248,0.026841,8.574,-77.308427,-77.305527,-77.304583,-77.327429
4,C#N,2.8937,12.99,-0.3604,0.0191,0.3796,48.7476,0.016601,6.278,-93.411888,-93.409370,-93.408425,-93.431246
...,...,...,...,...,...,...,...,...,...,...,...,...,...
133880,C1C2C3C4C5OC14C5N23,1.6637,69.37,-0.2254,0.0588,0.2842,760.7472,0.127406,23.658,-400.633868,-400.628599,-400.627654,-400.663098
133881,C1N2C3C2C2C4OC12C34,1.2976,69.52,-0.2393,0.0608,0.3002,762.6354,0.127495,23.697,-400.629713,-400.624444,-400.623500,-400.658942
133882,C1N2C3C4C5C2C13CN45,1.2480,73.60,-0.2233,0.0720,0.2953,780.3553,0.140458,23.972,-380.753918,-380.748619,-380.747675,-380.783148
133883,C1N2C3C4C5CC13C2C45,1.9576,77.40,-0.2122,0.0881,0.3003,803.1904,0.152222,24.796,-364.720374,-364.714974,-364.714030,-364.749650


In [6]:
# Summary statistics (count, mean, std, quartiles, min/max) of the numeric columns.
df.describe()

Unnamed: 0,mu,alpha,homo,lumo,gap,r2,zpve,cv,u0,u298,h298,g298
count,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0
mean,2.706037,75.191296,-0.239977,0.011124,0.2511,1189.52745,0.148524,31.600676,-411.543985,-411.535513,-411.534569,-411.577397
std,1.530394,8.187793,0.022131,0.046936,0.047519,279.757172,0.033274,4.062471,40.06023,40.060012,40.060012,40.060741
min,0.0,6.31,-0.4286,-0.175,0.0246,19.0002,0.015951,6.002,-714.568061,-714.560153,-714.559209,-714.602138
25%,1.5887,70.38,-0.2525,-0.0238,0.2163,1018.3226,0.125289,28.942,-437.913936,-437.905942,-437.904997,-437.947682
50%,2.5,75.5,-0.241,0.012,0.2494,1147.5858,0.148329,31.555,-417.864758,-417.857351,-417.856407,-417.895731
75%,3.6361,80.52,-0.2287,0.0492,0.2882,1308.8166,0.17115,34.276,-387.049166,-387.039746,-387.038802,-387.083279
max,29.5564,196.62,-0.1017,0.1935,0.6221,3374.7532,0.273944,46.969,-40.47893,-40.476062,-40.475117,-40.498597


# Random splits

In [7]:
# Build five random 85:5:10 train/val/test partitions of the row positions,
# one per seed, and collect the index arrays of each partition.
RANDOM_SPLITS = []
sampler='random'
for seed in range(5):
    split_output = train_val_test_split(
        np.arange(len(df)),
        train_size=0.85,
        val_size=0.05,
        test_size=0.1,
        sampler=sampler,
        random_state=seed,
        return_indices=True,
    )
    # Only the trailing three entries (the index arrays) are kept.
    _, _, _, train_indices, val_indices, test_indices = split_output
    print(
        len(train_indices),
        len(val_indices),
        len(test_indices),
        f'first val index {val_indices[0]}',
        f'first test index {test_indices[0]}',
    )
    RANDOM_SPLITS.append([train_indices, val_indices, test_indices])

113802 6694 13389 first val index 74350 first test index 83476
113802 6694 13389 first val index 61913 first test index 116029
113802 6694 13389 first val index 1528 first test index 8964
113802 6694 13389 first val index 81899 first test index 53153
113802 6694 13389 first val index 94694 first test index 45398


In [8]:
# Persist the random splits: a list with one [train, val, test] index triple per seed.
# open(..., 'wb') does not create missing parent directories, so make sure the
# output folder exists before writing (keeps Restart & Run All working on a fresh checkout).
os.makedirs('QM9_splits', exist_ok=True)
with open('QM9_splits/QM9_splits_random.pkl', 'wb') as f:
    pkl.dump(RANDOM_SPLITS, f)

# Scaffold splits

In [9]:
# Build five scaffold-based 85:5:10 train/val/test partitions of the SMILES
# column, one per seed, and collect the index arrays of each partition.
SCAFFOLD_SPLITS = []
sampler='scaffold'
for seed in range(5):
    split_output = train_val_test_split(
        df.smiles.values,
        train_size=0.85,
        val_size=0.05,
        test_size=0.1,
        sampler=sampler,
        random_state=seed,
        return_indices=True,
    )
    # NOTE(review): the middle three return values are bound to *_labels names
    # but are not used in this cell; presumably they are the sampler's
    # per-molecule group assignments. Verify against the astartes docs.
    (
        _, _, _,
        train_labels, val_labels, test_labels,
        train_indices, val_indices, test_indices,
    ) = split_output
    print(
        len(train_indices),
        len(val_indices),
        len(test_indices),
        f'first val index {val_indices[0]}',
        f'first test index {test_indices[0]}',
    )
    SCAFFOLD_SPLITS.append([train_indices, val_indices, test_indices])

/Users/kevin/Dropbox (MIT)/code/astartes/astartes/samplers/extrapolation/scaffold.py:47: NoMatchingScaffold: No matching scaffold was found for the 13998 molecules corresponding to indices {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 98313, 98314, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 98337, 98338, 98339, 98340, 98341, 98342, 48, 98344, 98345, 98346, 98347, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 98380, 98381, 98382, 98383, 98384, 98385, 98386, 98387, 98388, 98389, 98390, 98391, 98392, 98393, 98394, 98395, 98396, 98397, 98398, 98399, 98400, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 101604, 177, 178, 179, 121918, 183, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 2

113803 6694 13388 first val index 18766 first test index 127816
113803 6694 13388 first val index 43439 first test index 40559
113803 6694 13388 first val index 44692 first test index 38869
113803 6694 13388 first val index 46938 first test index 37655
113803 6694 13388 first val index 662 first test index 38999


In [10]:
# Persist the scaffold splits: a list with one [train, val, test] index triple per seed.
# open(..., 'wb') does not create missing parent directories, so make sure the
# output folder exists before writing (keeps Restart & Run All working on a fresh checkout).
os.makedirs('QM9_splits', exist_ok=True)
with open('QM9_splits/QM9_splits_scaffold.pkl', 'wb') as f:
    pkl.dump(SCAFFOLD_SPLITS, f)

# KMeans splits

### Featurize the data using count-based Morgan fingerprints with standard settings (radius 2, 2048 bits)

In [11]:
# One count-based Morgan fingerprint per molecule, stored row-wise.
# The width comes from MORGAN_NUM_BITS (the generator's default num_bits) so the
# matrix cannot drift out of sync with the fingerprint length.
morgan_fps = np.zeros((len(df), MORGAN_NUM_BITS))
# enumerate yields positional row numbers, so the assignment below stays correct
# even if df's index is not a default 0..n-1 range; it also avoids the per-row
# Series construction that iterrows() performs.
for i, smiles in enumerate(df.smiles):
    rmol = Chem.MolFromSmiles(smiles, params)
    morgan = morgan_counts_features_generator(rmol)
    morgan_fps[i, :] = morgan

In [12]:
# Build five KMeans-based train/val/test partitions (target ratio 85:5:10) of
# the fingerprint matrix, one per seed, and collect their index arrays.
KMEANS_SPLITS = []
sampler='kmeans'
for seed in range(5):
    split_output = train_val_test_split(
        morgan_fps,
        train_size=0.85,
        val_size=0.05,
        test_size=0.1,
        sampler=sampler,
        hopts={"n_clusters": 100},
        random_state=seed,
        return_indices=True,
    )
    # Only the trailing three entries (the index arrays) are kept.
    _, _, _, _, _, _, train_indices, val_indices, test_indices = split_output
    print(
        len(train_indices),
        len(val_indices),
        len(test_indices),
        f'first val index {val_indices[0]}',
        f'first test index {test_indices[0]}',
    )
    KMEANS_SPLITS.append([train_indices, val_indices, test_indices])

113912 6668 13305 first val index 50 first test index 411
113980 6540 13365 first val index 96 first test index 2019
113974 6554 13357 first val index 101 first test index 139
114011 6644 13230 first val index 104 first test index 42
114325 6493 13067 first val index 10 first test index 2216


In [13]:
# Persist the KMeans splits: a list with one [train, val, test] index triple per seed.
# open(..., 'wb') does not create missing parent directories, so make sure the
# output folder exists before writing (keeps Restart & Run All working on a fresh checkout).
os.makedirs('QM9_splits', exist_ok=True)
with open('QM9_splits/QM9_splits_kmeans.pkl', 'wb') as f:
    pkl.dump(KMEANS_SPLITS, f)