In [13]:
import os
import sys
import time
import rdkit
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from tqdm.notebook import tqdm

sys.path.insert(0, '../../')

from jtnn.mol_tree import *

In [14]:
scale_to_pm9 = 27.21

## Helper functions

In [15]:
lg = rdkit.RDLogger.logger() 
lg.setLevel(rdkit.RDLogger.CRITICAL)

In [16]:
def progressbar(it, prefix="", size=60, file=sys.stdout):
    count = len(it)
    def show(j):
        x = int(size*j/count)
        file.write("%s[%s%s] %i/%i\r" % (prefix, "#"*x, "."*(size-x), j, count))
        file.flush()   
        
    show(0)
    
    for i, item in enumerate(it):
        yield item
        show(i+1)
        
    file.write("\n")
    file.flush()
    
def check_in_dict(smiles, cset):
    mol = MolTree(smiles)
    for c in mol.nodes:
        if not c.smiles in cset:
            return False
    return True

## WhitePaper

In [4]:
print("Full Zinc Dataset info")
print("Train molecules", end=": ")
with open("./zinc/train.txt", "r") as f:
    print(len(f.readlines()))
    
print("Test molecules", end=": ")
with open("./zinc/test.txt", "r") as f:
    print(len(f.readlines()))

Full Zinc Dataset info
Train molecules: 220011
Test molecules: 5000


In [5]:
print("Full QDB9 Dataset info")
print("Train molecules", end=": ")
with open("./qdb9/backup/train.txt", "r") as f:
    print(len(f.readlines()))
    
print("Test molecules", end=": ")
with open("./qdb9/backup/test.txt", "r") as f:
    print(len(f.readlines()))

Full QDB9 Dataset info
Train molecules: 67591
Test molecules: 45116


In [6]:
print("Train molecules", end=": ")
with open("./zinc/vocab.txt", "r") as f:
    vocab = [x.replace("\n", "") for x in f.readlines()]
    print(len(vocab))
    
cset = set(vocab)

Train molecules: 780


In [7]:
with open("./qdb9/new_qdb9/all_intersect_zinc.txt", "r") as f:
    with open("./qdb9/new_qdb9/all_intersect_zinc_second.txt", "w") as w:
        old_lines = f.readlines()
        
        try:
            while True:
                old_lines.remove("\n")
        except ValueError:
            pass 
        
        w.write("".join(old_lines))

In [None]:
print("Creating all file without QDB9 unique elements")
with open("./qdb9/new_qdb9/all.txt", "r") as f:
    with open("./qdb9/new_qdb9/all_intersect_zinc.txt", "w") as w:
        for i, line in tqdm(enumerate(f.readlines())):
            try:
                if line == "\n":
                    pass
                else:
                    is_valid = True
                    smiles = line.split()[0]
                    mol = MolTree(smiles)
                    for c in mol.nodes:
                        if not c.smiles in cset:
                            is_valid = False
                            break
                        else:
                            pass
            except IndexError as e:
                print("index {}, Error: {}".format(i, e))
                pass

            if is_valid:
                w.write(line)

In [8]:
print("Train molecules intersection qdb9 with zinc", end=": ")
with open("./qdb9/new_qdb9/all_intersect_zinc.txt", "r") as f:
    vocab = f.readlines()
    print(len(vocab))

Train molecules intersection qdb9 with zinc: 96575


In [36]:
with open("./qdb9/new_qdb9/all_intersect_zinc.txt", "r") as f:
    db = f.readlines()
    
    try:
        while True:
            db.remove("\n")
    except ValueError:
        pass    
    
    db_shuffled = np.copy(db)
    np.random.shuffle(db_shuffled)
    with open("./qdb9/new_qdb9/train_intersect_zinc.txt", "w") as wtrain:
        with open("./qdb9/new_qdb9/train_intersect_zinc_shuffled.txt", "w") as wtrains:
            with open("./qdb9/new_qdb9/test_intersect_zinc.txt", "w") as wtest:
                with open("./qdb9/new_qdb9/test_intersect_zinc_shuffled.txt", "w") as wtests:
                    wtrain.write("".join(db[:-5000]))
                    wtest.write("".join(db[-5000:]))
                    wtrains.write("".join(db_shuffled[:-5000]))
                    wtests.write("".join(db_shuffled[-5000:]))

In [37]:
print("Train molecules intersection qdb9 with zinc", end=": ")
with open("./qdb9/new_qdb9/train_intersect_zinc.txt", "r") as f:
    vocab = f.readlines()
    print(len(vocab))

Train molecules intersection qdb9 with zinc: 91575


In [38]:
print("Train molecules intersection qdb9 with zinc", end=": ")
with open("./qdb9/new_qdb9/test_intersect_zinc.txt", "r") as f:
    vocab = f.readlines()
    print(len(vocab))

Train molecules intersection qdb9 with zinc: 5000


In [None]:
import sys
lg = rdkit.RDLogger.logger() 
lg.setLevel(rdkit.RDLogger.CRITICAL)

qdb9_train = '../qvae/data/qdb9/train.txt'
qdb9_test = '../qvae/data/qdb9/test.txt'
cset = set()
for i, line in tqdm(enumerate(sys.stdin)):
    try:
        smiles = line.split()[0]
        mol = MolTree(smiles)
        for c in mol.nodes:
            cset.add(c.smiles)
    except IndexError as e:
        print("index {} {}, Error: {}".format(i, line, e))
        pass

qdb9_train_zinc = '../qvae/data/qdb9/train_zinc_only.txt'

qdb9_test_zinc = '../qvae/data/qdb9/test_zinc_only.txt'

qdb9_full_zinc = '../qvae/data/qdb9/all_zinc_only.txt'

with open(qdb9_train_zinc, 'w') as train:
    with open(qdb9_test_zinc, 'w') as test:
        with open(qdb9_full_zinc, 'w') as full:
            with open(qdb9_train, "r") as qdb9_train:

                for i,line in tqdm(enumerate(qdb9_train.readlines())):
                    try:
                        is_valid = True
                        smiles = line.split()[0]
                        mol = MolTree(smiles)
                        for c in mol.nodes:
                            if not c.smiles in  cset:
                                is_valid = False
                                break
                            else:
                                pass
                    except IndexError as e:
                        print("index {}, Error: {}".format(i, e))
                        pass

                    if is_valid:
                        train.write(line)
                        full.write(line)

            with open(qdb9_test) as qdb9_test:
                for i,line in tqdm(enumerate(qdb9_test.readlines())):
                    try:
                        is_valid = True
                        smiles = line.split()[0]
                        mol = MolTree(smiles)
                        for c in mol.nodes:
                            if not c.smiles in  cset:
                                is_valid = False
                                break
                            else:
                                pass

                        if is_valid:
                            test.write(line)
                            full.write(line)
                    except IndexError as e:
                        print("index {}, Error: {}".format(i, e))
                        pass


## Humo Lumo

In [79]:
df = pd.read_csv("./qdb9/gdb9_prop_smiles.csv", index_col=0)
df = df[["mol_id", 'homo', 'lumo', "gap", "smiles"]]

In [80]:
val_idxes = df.smiles.apply(lambda x: rdkit.Chem.MolFromSmiles(x) != None)
df = df[val_idxes]

In [81]:
val_idxes = df.smiles.apply(lambda x: check_in_dict(x, cset))
df = df[val_idxes]

In [118]:
scaler = StandardScaler()
df[["homo", "lumo", "gap"]] = scaler.fit_transform(df[["homo", "lumo", "gap"]])sample

In [88]:
df = df.sample(frac=1)

In [89]:
train = df[:-5000]
test = df[-5000:]

In [91]:
path = "reg_data/"
train.smiles.to_csv(path + 'train_mols.txt', header=False, index=False)
train.homo.to_csv(path + 'train_homo.txt', header=False, index=False)
train.lumo.to_csv(path + 'train_lumo.txt', header=False, index=False)
train.gap.to_csv(path + 'train_gap.txt', header=False, index=False)

test.smiles.to_csv(path + 'test_mols.txt', header=False, index=False)
test.homo.to_csv(path + 'test_homo.txt', header=False, index=False)
test.lumo.to_csv(path + 'test_lumo.txt', header=False, index=False)
test.gap.to_csv(path + 'test_gap.txt', header=False, index=False)

pd.DataFrame(list(zip(scaler.mean_, scaler.var_)), columns=['mean', 'var'], index=df.columns[1:4]).to_csv(path+"info.txt")

## Real Intersection

In [47]:
with open("zinc/all.txt","r") as file:
    zinc_set = set([x.replace("\n", "").strip() for x in file.readlines()])
    
with open("qdb9/all.txt", "r") as file:
    qdb9_set = set([x.replace("\n", "").strip() for x in file.readlines()])

In [48]:
qdb9_set_normalized = set([rdkit.Chem.MolToSmiles(rdkit.Chem.MolFromSmiles(smiles)) for smiles in qdb9_set])
zinc_set_normalized = set([rdkit.Chem.MolToSmiles(rdkit.Chem.MolFromSmiles(smiles)) for smiles in zinc_set])

In [68]:
counter=0

for mol in zinc_set_normalized:
    if mol in zinc_set:
        counter+=1
        
print("{} molecules affected in ZINC".format(len(zinc_set) - counter))

counter=0

for mol in qdb9_set_normalized:
    if mol in qdb9_set:
        counter+=1
        
print("{} molecules affected in QDB9".format(len(qdb9_set) - counter))

20 molecules affected in ZINC
2033 molecules affected in QDB9


In [59]:
counter = 0
for mol in zinc_set:
    if mol in qdb9_set:
        counter+=1
        
print("Number of elements common to both datasets\n\tCounter:{}\n\tZinc: {}%\n\tQDB9: {}%".format(counter, 
                                                                                        int(counter*100/len(zinc_set)), 
                                                                                        int(counter*100/len(qdb9_set))))

Number of elements common to both datasets
	Counter:0
	Zinc: 0%
	QDB9: 0%


In [58]:
counter = 0
for mol in zinc_set_normalized:
    if mol in qdb9_set_normalized:
        counter+=1
        
print("Number of elements common to both normalized through rdkit datasets\n\tCounter:{}\n\tZinc: {}%\n\tQDB9: {}%".format(counter, 
                                                                                                                    int(counter*100/len(zinc_set_normalized)), 
                                                                                                                    int(counter*100/len(qdb9_set_normalized))))

Number of elements common to both normalized through rdkit datasets
	Counter:0
	Zinc: 0%
	QDB9: 0%


In [77]:
with open("reg_data/test_lumo.txt", "r") as f:
    lines = f.readlines()
    total = [float(x) for x in lines]
    print("Mean: ", np.mean(total)," Var: ", np.var(total))

Mean:  0.012599628660218397  Var:  0.9720976583714979


In [78]:
with open("reg_data/test_homo.txt", "r") as f:
    lines = f.readlines()
    total = [float(x) for x in lines]
    print("Mean: ", np.mean(total)," Var: ", np.var(total))

Mean:  -0.021853699140908463  Var:  0.9937224476477277


## CGRTools normalization

In [24]:
from CGRtools import SMILESRead, smiles
from CGRtools.exceptions import InvalidAromaticRing

In [None]:
with open("./qdb9/all.txt", 'w') as qdb9:
    for i,line in tqdm(enumerate(qdb9_train.readlines())):
        try:
            is_valid = True
            smiles = line.split()[0]
            mol = MolTree(smiles)
            for c in mol.nodes:
                if not c.smiles in  cset:
                    is_valid = False
                    break
                else:
                    pass
        except IndexError as e:
            print("index {}, Error: {}".format(i, e))
            pass

        if is_valid:
            train.write(line)
            full.write(line)

In [None]:
with SMILESRead("qdb9/all.txt") as file:
    CGR_qdb9_set = set()
    for mol in file:
        mol.canonicalize()
        CGR_qdb9_set.add(str(mol))

To check the QM9 to smiles from database i have used so far

    from CGRtools import XYZRead
    with XYZRead("dsgdb9nsd_000001.xyz") as f:
        data = f.read()[0]
        data.implicify_hydrogens()
        print(data)

Original: C#CC(C)C1(C)OC1C
Decoded: C#C[C@H](C)[C@@]1(C)O[C@@H]1C

Original: CCOCC(C=O)CO
Decoded: CCOC[C@@H](C=O)CO

Original: CC1(C)OC2(C)COC12
Decoded: CC1(C)O[C@@]2(C)CO[C@H]12

Original: OCC(O)C1CC(O)=N1
Decoded: OC[C@H](O)[C@H]1CC(O)=N1

Original: OCC1(C2CO2)CC1O
Decoded: OC[C@]1([C@H]2CO2)C[C@H]1O

Original: CC1CCC2(C)CC2O1
Decoded: C[C@@H]1CC[C@]2(C)C[C@@H]2O1

In [1]:
from CGRtools import SMILESRead, smiles
from CGRtools.exceptions import InvalidAromaticRing

In [6]:
mol = smiles("C#C[C@H](C)[C@@]1(C)O[C@@H]1C")
mol.clean_stereo()
mol.canonicalize()
print(mol)

CC1OC1(C(C#C)C)C


In [7]:
mol2 = smiles("C#CC(C)C1(C)OC1C")
mol2.clean_stereo()
mol2.canonicalize()
print(mol2)

CC1OC1(C(C#C)C)C


In [8]:
print(mol == mol2)

True


In [149]:
from CGRtools import SDFRead
with SDFRead("./qm9/gdb9.sdf", indexable=True) as f:
    for mol in f[:24]:
        mol.implicify_hydrogens()
        mol.canonicalize()
        print(mol)

C
N
O
C#C
C#N
O=C
CC
OC
CC#C
N#CC
O=CC
O=CN
CCC
C(O)C
COC
C1CC1
C1CO1
CC(C)=O
O=C(N)C
NC(N)=O
CC(C)C
CC(C)O
CC#CC
C([NH3+])#CC


In [30]:
with open("zinc/all.txt","r") as file:
    zinc_set = set([x.replace("\n", "").strip() for x in file.readlines()])

In [31]:
with open("qdb9/all.txt", "r") as file:
    qdb9_set = set([x.replace("\n", "").strip() for x in file.readlines()])

In [2]:
with SMILESRead("qdb9/all.txt") as file:
    CGR_qdb9_set = set()
    for mol in file:
        mol.canonicalize()
        CGR_qdb9_set.add(str(mol))

In [38]:
with SMILESRead("zinc/all.txt", ignore=True) as file:
    CGR_zinc_set = set()
    try:
        for mol in file:
            mol.standardize()
            CGR_zinc_set.add(str(mol))
    except:
        pass

In [None]:
smiles = set()
with SMILESRead("all.txt", ignore=True) as f:
    for n,i in enumerate(f):
        try:
            i.canonicalize()
        except InvalidAromaticRing:
            continue
        smiles.add(str(i))
        
    else:
        last = n

len(smiles)

In [39]:
len(CGR_zinc_set)

249440

In [27]:
len(zinc_set)

249456

In [23]:
print(len(qdb9_set))
print(len(CGR_zinc_set.intersection(CGR_zinc_set)))
print(len(zinc_set.intersection(CGR_zinc_set)), len(qdb9_set.intersection(CGR_qdb9_set)))

111529
9995
51 3070


# Filtering out non-frequent molecules

In [6]:
lg = rdkit.RDLogger.logger() 
lg.setLevel(rdkit.RDLogger.CRITICAL)

In [10]:
len(all_data)

133885

In [7]:
print("QDB9 prunned vocabulary", end=": ")
with open("./qdb9/vocab_prunned_10.txt", "r") as f:
    vocab = [x.replace("\n", "") for x in f.readlines()]
    print(len(vocab))
    
cset = set(vocab)

QDB9 prunned vocabulary: 331


In [8]:
all_data = pd.read_csv('./qdb9/gdb9_prop_smiles.csv', index_col=0)
all_data = all_data.loc[:, ["mol_id", "lumo", "homo", "smiles"]]

In [9]:
def check_rdkit_validity(mol):
    try:
        rdkit.Chem.MolToSmiles(rdkit.Chem.MolFromSmiles(mol))
        return True
    except:
        return False

def check_vocab_validity(smiles, vocab):
    mol = MolTree(smiles)
    
    for c in mol.nodes:
        if c.smiles not in vocab:
            return False
        
    return True
        
def normalize(smiles_str):
    mol = smiles(smiles_str)
    mol.canonicalize()
    return mol

In [135]:
smiles_valid_idxes = all_data.smiles.apply(lambda x: check_rdkit_validity(x))

In [136]:
all_data = all_data.loc[smiles_valid_idxes, :]

In [137]:
prunned_vocab_idxes = all_data.smiles.apply(lambda x: check_vocab_validity(x, cset))

In [138]:
all_data = all_data.loc[prunned_vocab_idxes, :]

In [139]:
all_data.lumo = all_data.lumo*scale_to_pm9
all_data.homo = all_data.homo*scale_to_pm9

In [140]:
qdb9_smiles_canonized = all_data.smiles.apply(lambda x: normalize(x))

In [11]:
qdb9

NameError: name 'qdb9' is not defined

In [141]:
qdb9 = set(list(qdb9_smiles_canonized))

zinc = set()
with open("./zinc/zinc_all_canon.txt", "r") as f:
    for mol in [x.replace("\n", "") for x in f.readlines()]:
        zinc.add(mol)
        
print("Intersection between canonized QDB9 and ZINC: {}".format(len(qdb9.intersection(zinc))))
print("Intersection between non-canonized QDB9 and ZINC: {}".format(len(set(list(all_data.smiles)).intersection(zinc))))

Intersection between canonized QDB9 and ZINC: 0
Intersection between non-canonized QDB9 and ZINC: 0


## Lets extract zinc that is in the prunned qdb9 vocab

In [None]:
zinc_db = pd.Series(list(zinc))
zinc_prunned_indexes = zinc_db.apply(lambda x: check_vocab_validity(x, cset))
zinc_db.rename("smiles", inplace=True)

### Saving the resulting files for training

In [152]:
all_data.loc["smiles"] = qdb9_smiles_canonized
zinc_db_prunned = zinc_db[zinc_prunned_indexes]

In [153]:
all_data = all_data.dropna()
zinc_db_prunned = zinc_db_prunned.dropna()

In [None]:
#all_data.to_csv("qdb9/prunned/qdb9/qdb9_prunned.csv", header=None, index=None)
#pd.Series(vocab).to_csv("qdb9/prunned/vocab.txt", header=None, index=None)

In [1]:
import pandas as pd

In [14]:
all_data = pd.read_csv("qdb9/prunned/qdb9/qdb9_prunned.csv", header=None)
all_data.columns = ["mol_id", "lumo", "homo", "smiles"]

In [20]:
all_data.homo.mean()

-6.616149834752479

Splitting train-test-validation

In [36]:
from sklearn.model_selection import train_test_split
qdb9_train_X, qdb9_test_X = train_test_split(all_data, test_size=0.05)
qdb9_test_X, qdb9_val_X = train_test_split(qdb9_test_X, test_size=0.1)

In [155]:
qdb9_train_X.smiles.to_csv("qdb9/prunned/qdb9/train_smiles.txt", header=None, index=None)
qdb9_train_X.homo.to_csv("qdb9/prunned/qdb9/train_homo.txt", header=None, index=None)
qdb9_train_X.homo.to_csv("qdb9/prunned/qdb9/train_lumo.txt", header=None, index=None)

qdb9_test_X.smiles.to_csv("qdb9/prunned/qdb9/test_smiles.txt", header=None, index=None)
qdb9_test_X.homo.to_csv("qdb9/prunned/qdb9/test_homo.txt", header=None, index=None)
qdb9_test_X.homo.to_csv("qdb9/prunned/qdb9/test_lumo.txt", header=None, index=None)

qdb9_val_X.smiles.to_csv("qdb9/prunned/qdb9/val_smiles.txt", header=None, index=None)
qdb9_val_X.homo.to_csv("qdb9/prunned/qdb9/val_homo.txt", header=None, index=None)
qdb9_val_X.homo.to_csv("qdb9/prunned/qdb9/val_lumo.txt", header=None, index=None)

In [156]:
zinc_qdb9 = set([*list(all_data.smiles), *list(zinc_db_prunned)])
zinc_qdb9 = pd.Series(list(zinc_qdb9))

In [157]:
zinc_qdb9_train_X, zinc_qdb9_test_X = train_test_split(zinc_qdb9, test_size=0.05)
zinc_qdb9_test_X, zinc_qdb9_val_X = train_test_split(zinc_qdb9_test_X, test_size=0.1)

In [160]:
zinc_qdb9_train_X.to_csv("qdb9/prunned/qdb9+zinc/train_smiles.txt", header=None, index=None)
zinc_qdb9_test_X.to_csv("qdb9/prunned/qdb9+zinc/test_smiles.txt", header=None, index=None)
zinc_qdb9_val_X.to_csv("qdb9/prunned/qdb9+zinc/val_smiles.txt", header=None, index=None)
zinc_db_prunned.to_csv("qdb9/prunned/qdb9+zinc/zinc_prunned.txt", header=None, index=None)