In [2]:

import os
from os.path import join, exists, basename
import argparse
import numpy as np
import paddle.fluid as fluid
import paddle
paddle.seed(0)
import random
np.random.seed(42) 
random.seed(42)
fluid.default_startup_program().random_seed = 42
fluid.default_main_program().random_seed = 42
import paddle.nn as nn
import pgl
import pandas as pd
from pahelix.model_zoo.gem_model import GeoGNNModel
from pahelix.utils import load_json_config
from pahelix.datasets.inmemory_dataset import InMemoryDataset
from rdkit.Chem import AllChem

from src.model import DownstreamModel
from src.featurizer import DownstreamTransformFn, DownstreamCollateFn
from src.utils import get_dataset, create_splitter, get_downstream_task_names, get_dataset_stat, \
        calc_rocauc_score, calc_rmse, calc_mae, exempt_parameters
from rdkit.ML.Descriptors import MoleculeDescriptors
from rdkit.Chem import Descriptors
#from descriptastorus.descriptors import rdDescriptors, rdNormalizedDescriptors

In [4]:
data_path=f"./cached_data/esol/rdkit"
dataset = InMemoryDataset(npz_data_path=data_path)

In [5]:
data_path=f"./cached_data/esol/rdkit/TestDataset/"
test_dataset = InMemoryDataset(npz_data_path=data_path)

In [16]:
dataset.data_list[0].keys()

dict_keys(['atomic_num', 'chiral_tag', 'degree', 'explicit_valence', 'formal_charge', 'hybridization', 'implicit_valence', 'is_aromatic', 'total_numHs', 'mass', 'bond_dir', 'bond_type', 'is_in_ring', 'edges', 'morgan_fp', 'maccs_fp', 'daylight_fg_counts', 'atom_pos', 'bond_length', 'BondAngleGraph_edges', 'bond_angle', 'label', 'smiles'])

In [17]:
test_dataset.data_list[0].keys()

dict_keys(['atomic_num', 'chiral_tag', 'degree', 'explicit_valence', 'formal_charge', 'hybridization', 'implicit_valence', 'is_aromatic', 'total_numHs', 'mass', 'bond_dir', 'bond_type', 'is_in_ring', 'edges', 'morgan_fp', 'maccs_fp', 'daylight_fg_counts', 'atom_pos', 'bond_length', 'BondAngleGraph_edges', 'bond_angle', 'label', 'smiles'])

In [34]:
cached_data_path=f"./cached_data/bbbp/rdkit"
splitter = create_splitter("scaffold")
dataset_new = InMemoryDataset(npz_data_path=cached_data_path)
train_dataset, valid_dataset, test_dataset = splitter.split(dataset_new,0.8,0.1,0.1)
test_dataset[0]["label"]

array([1], dtype=int64)

In [None]:
dataset_new[0]

In [10]:
stri="COC(=O)[C@@H]1[C@H]2C[C@H]3c4[nH]c5ccccc5c4CCN3C[C@@H]2CC[C@@H]1O"

mol=AllChem.MolFromSmiles(stri)

In [14]:
AllChem.MolToSmiles(mol)

'COC(=O)[C@H]1[C@@H](O)CC[C@H]2CN3CCc4c([nH]c5ccccc45)[C@@H]3C[C@@H]21'

In [None]:
stri="Cc1c[nH+][o+]c(C([NH])CC(C)C(C)(C)N(C(C)(C)C)C(C)(N)N)c1[O-]"
dataset_new = InMemoryDataset(npz_data_path="./cached_data/bbbp/rdkit")
for item in dataset_new:
        print(item["smiles"])

In [50]:
data_path=f"./chemrl_downstream_datasets/bbbp/raw/bbbp.csv"
ref_df=pd.read_csv(data_path)
ref_df["smiles"]=ref_df["smiles"].apply(std_smiles)
ref_df=ref_df.set_index("smiles")

In [64]:
b=ref_df.loc["CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1"][["p_np"]]
a=ref_df.loc["C[C@H](N)Cc1ccccc1"][["p_np"]]
if len(ref_df.loc["C[C@H](N)Cc1ccccc1"][["p_np"]])>1:
    print("asd")

In [69]:
b["p_np"][0]

1

In [6]:
def dataset_feat_extractor(dataset):
    rdkit_desc_list=[]
    generator = rdNormalizedDescriptors.RDKit2DNormalized()
    for item in dataset:
        smi=item["smiles"]
        features = generator.process(smiles)[1:]
        print(features)
    features=np.stack(rdkit_desc_list, axis=0)
    return features
        
def feature_extractor(dataset):
    cached_data_path=f"./cached_data/{dataset}/rdkit"
    splitter = create_splitter("scaffold")
    dataset_new = InMemoryDataset(npz_data_path=cached_data_path)
    train_dataset, valid_dataset, test_dataset = splitter.split(dataset_new,0.8,0.1,0.1)
    train_features={"features":None}
    valid_features={"features":None}
    test_features={"features":None}
    train_features["features"]=dataset_feat_extractor(train_dataset)
    valid_features["features"]=dataset_feat_extractor(valid_dataset)
    test_features["features"]=dataset_feat_extractor(test_dataset)
    
    np.savez(f"test_data/"+dataset+"-features-train.npz",features=train_features["features"],allow_pickle=True)
    np.savez(f"test_data/"+dataset+"-features-valid.npz",features=valid_features["features"],allow_pickle=True)
    np.savez(f"test_data/"+dataset+"-features-test.npz",features=test_features["features"],allow_pickle=True)
    

In [7]:
datasets=["bbbp","bace"]
for dataset in datasets:
    feature_extractor(dataset)

(11.682267749519735, -0.40969098783565605, 11.682267749519735, 0.13470350189866753, 0.4748208495574166, 360.32500000000005, 333.10900000000004, 359.141884464, 130, 0, 0.3058126544757872, -0.4599478444235159, 0.4599478444235159, 0.3058126544757872, 1.0869565217391304, 1.6521739130434783, 2.130434782608696, 35.49675362605588, 10.074096807959336, 2.151891909617768, -2.226406972350156, 2.171389352437839, -2.2696328329032203, 6.182875633901804, -0.15440227772484644, 2.458834471614553, 462.905083873753, 17.294682450997076, 14.278787667510478, 15.79064555954739, 10.852685169664095, 8.031431235536765, 9.100476203186464, 6.471456179320208, 7.227385125338664, 3.2767284299878363, 3.6147901318792424, 2.276861130093487, 2.684937702906069, -0.9299999999999999, 79428.7448714238, 20.11531037607612, 10.04313306942461, 8.474206944686983, 149.45214194577363, 9.636772684650527, 5.601050810983688, 0.0, 0.0, 0.0, 5.969305287951849, 4.794537184071822, 0.0, 0.0, 23.20187978046503, 12.13273413692322, 51.309040

In [79]:
def std_smiles(item):
    mol=AllChem.MolFromSmiles(item)
    if mol==None:
        return "asd"
    smi=AllChem.MolToSmiles(mol)
    return smi
def save_test_data(dataset,targets,smiles_col):
    cached_data_path=f"./cached_data/{dataset}/rdkit"
    data_path=f"./chemrl_downstream_datasets/{dataset}/raw/{dataset}.csv"
    ref_df=pd.read_csv(data_path)
    if dataset=="bace":
        smiles_col="mol"
    print("before")
    print(len(ref_df[smiles_col]))
    print(len(ref_df[smiles_col].unique()))
    ref_df[smiles_col]=ref_df[smiles_col].apply(std_smiles)
    print("before")
    print(len(ref_df[smiles_col]))
    print(len(ref_df[smiles_col].unique()))
    ref_df=ref_df.set_index(smiles_col)
    print(dataset)
    splitter = create_splitter("scaffold")
    dataset_new = InMemoryDataset(npz_data_path=cached_data_path)
    train_dataset, valid_dataset, test_dataset = splitter.split(dataset_new,0.8,0.1,0.1)
    smiles_list_test={"smiles":[]}
    smiles_list_val={"smiles":[]}
    smiles_list_train={"smiles":[]}
    for target in targets:
        smiles_list_test[target]=[]
        smiles_list_val[target]=[]
        smiles_list_train[target]=[]
    for item in test_dataset:
        smiles_list_test["smiles"].append(item["smiles"])
        mol=AllChem.MolFromSmiles(item["smiles"])
        smi=AllChem.MolToSmiles(mol)
        
        for i in range(len(targets)):
            labels=ref_df.loc[smi][[targets[i]]]
            if len(labels)>1:
                label=labels[targets[i]][0]
            else:
                label=labels[0]
            smiles_list_test[targets[i]].append(label)
    for item in train_dataset:
        smiles_list_train["smiles"].append(item["smiles"])
        mol=AllChem.MolFromSmiles(item["smiles"])
        smi=AllChem.MolToSmiles(mol)
        #print(ref_df.loc[smi]["p_np"])
        for i in range(len(targets)):
            labels=ref_df.loc[smi][[targets[i]]]
            if len(labels)>1:
                label=labels[targets[i]][0]
            else:
                label=labels[0]
            smiles_list_train[targets[i]].append(label)
    for item in valid_dataset:
        smiles_list_val["smiles"].append(item["smiles"])
        mol=AllChem.MolFromSmiles(item["smiles"])
        smi=AllChem.MolToSmiles(mol)
        for i in range(len(targets)):
            labels=ref_df.loc[smi][[targets[i]]]
            if len(labels)>1:
                label=labels[targets[i]][0]
            else:
                label=labels[0]
            smiles_list_val[targets[i]].append(label)
    df_test=pd.DataFrame.from_dict(smiles_list_test)
    df_train=pd.DataFrame.from_dict(smiles_list_train)
    df_val=pd.DataFrame.from_dict(smiles_list_val)
    
    df_test.to_csv(f"test_data/"+dataset+"-gem-test.csv",index=False)
    df_train.to_csv(f"test_data/"+dataset+"-gem-train.csv",index=False)
    df_val.to_csv(f"test_data/"+dataset+"-gem-val.csv",index=False)


In [88]:
datasets=["esol","freesolv","lipo","qm7","bbbp","bace"]
#datasets=["bbbp"]
targets={"esol":['measured log solubility in mols per litre'],"freesolv":['expt'],"lipo":['exp'],"qm7":['u0_atom'],"bbbp":['p_np'],"bace":['Class']}
for dataset in datasets:
    smiles_col="smiles"
    
    save_test_data(dataset,targets[dataset],smiles_col)

before
1128
1128
before
1128
1117
esol
before
642
642
before
642
642
freesolv
before
4200
4200
before
4200
4200
lipo
before
6830
6830
before
6830
6830
qm7
before
2050
2050
before
2050
1976
bbbp
before
1513
1513
before
1513
1513
bace


In [12]:
def test_set_check(dataset,mode):
    cached_data_path=f"./cached_data/{dataset}/{mode}"
    splitter = create_splitter("scaffold")
    dataset = InMemoryDataset(npz_data_path=cached_data_path)
    train_dataset, valid_dataset, test_dataset = splitter.split(dataset,0.8,0.1,0.1)
    smiles_list=[]
    for item in test_dataset:
        smiles_list.append(item["smiles"])
    return smiles_list

In [16]:
lis1=["asd","a","b","c"]
list2=["a","b"]
set(lis1)-set(list2)

{'asd', 'c'}

In [23]:
datasets=["esol","freesolv","lipo","qm7","bbbp","bace"]
modes=["rdkit","graph","mmffless","geomol"]
for dataset in datasets:
    smiles_list=[]
    for mode in modes:
        smiles=test_set_check(dataset,mode)
        smiles_list.append(smiles)
    for i in range(len(modes)-1):
        for j in range(i+1,len(modes)):
            set1=set(smiles_list[i])
            set2=set(smiles_list[j])
            diff=0
            for item in set1:
                if item not in set2:
                    diff+=1
            for item in set2:
                if item not in set1:
                    diff+=1
            print(f"test_data/{dataset}:mode={modes[i]}-{modes[j]}:{diff}")


test_data/esol:mode=rdkit-graph:0
test_data/esol:mode=rdkit-mmffless:0
test_data/esol:mode=rdkit-geomol:0
test_data/esol:mode=graph-mmffless:0
test_data/esol:mode=graph-geomol:0
test_data/esol:mode=mmffless-geomol:0
test_data/freesolv:mode=rdkit-graph:0
test_data/freesolv:mode=rdkit-mmffless:0
test_data/freesolv:mode=rdkit-geomol:0
test_data/freesolv:mode=graph-mmffless:0
test_data/freesolv:mode=graph-geomol:0
test_data/freesolv:mode=mmffless-geomol:0
test_data/lipo:mode=rdkit-graph:0
test_data/lipo:mode=rdkit-mmffless:0
test_data/lipo:mode=rdkit-geomol:0
test_data/lipo:mode=graph-mmffless:0
test_data/lipo:mode=graph-geomol:0
test_data/lipo:mode=mmffless-geomol:0
test_data/qm7:mode=rdkit-graph:0
test_data/qm7:mode=rdkit-mmffless:0
test_data/qm7:mode=rdkit-geomol:0
test_data/qm7:mode=graph-mmffless:0
test_data/qm7:mode=graph-geomol:0
test_data/qm7:mode=mmffless-geomol:0
test_data/bbbp:mode=rdkit-graph:28
test_data/bbbp:mode=rdkit-mmffless:0
test_data/bbbp:mode=rdkit-geomol:28
test_data/

In [20]:
def model_datasets_check(path1,path2):
    df1=pd.read_csv(path1)
    df2=pd.read_csv(path2)
    smiles1=set(df1["smiles"])
    smiles2=set(df2["smiles"])
    for item in smiles1:
        if item not in smiles2:
            print(item)
    diff=list(smiles1-smiles2).extend((smiles2-smiles1))
    return diff

In [30]:
print(model_datasets_check("test_data/test.csv","test_data/bbbp-gem-test.csv"))

None


In [56]:
test_dataset="test_data/bace-gem-test.csv"
ref_dataset="chemrl_downstream_datasets/bace/raw/bace.csv"
ref_df=pd.read_csv(ref_dataset)
test_df=pd.read_csv(test_dataset)
ref_df=ref_df.set_index("mol")
for item in test_df.iterrows():
    smi=item[1]["smiles"]
    ref_mol=ref_df.loc[smi]
    if ref_mol["Class"]!=item[1]["Class"]:
        print(ref_mol["Class"],item[1]["Class"])


0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1
0 -1


In [46]:
ref_df

Unnamed: 0,mol,CID,Class,Model,pIC50,MW,AlogP,HBA,HBD,RB,...,PEOE6 (PEOE6),PEOE7 (PEOE7),PEOE8 (PEOE8),PEOE9 (PEOE9),PEOE10 (PEOE10),PEOE11 (PEOE11),PEOE12 (PEOE12),PEOE13 (PEOE13),PEOE14 (PEOE14),canvasUID
0,O1CC[C@@H](NC(=O)[C@@H](Cc2cc3cc(ccc3nc2N)-c2c...,BACE_1,1,Train,9.154901,431.56979,4.4014,3,2,5,...,53.205711,78.640335,226.855410,107.434910,37.133846,0.000000,7.980170,0.000000,0.000000,1
1,Fc1cc(cc(F)c1)C[C@H](NC(=O)[C@@H](N1CC[C@](NC(...,BACE_2,1,Train,8.853872,657.81073,2.6412,5,4,16,...,73.817162,47.171600,365.676940,174.076750,34.923889,7.980170,24.148668,0.000000,24.663788,2
2,S1(=O)(=O)N(c2cc(cc3c2n(cc3CC)CC1)C(=O)N[C@H](...,BACE_3,1,Train,8.698970,591.74091,2.5499,4,3,11,...,70.365707,47.941147,192.406520,255.752550,23.654478,0.230159,15.879790,0.000000,24.663788,3
3,S1(=O)(=O)C[C@@H](Cc2cc(O[C@H](COCC)C(F)(F)F)c...,BACE_4,1,Train,8.698970,591.67828,3.1680,4,3,12,...,56.657166,37.954151,194.353040,202.763350,36.498634,0.980913,8.188327,0.000000,26.385181,4
4,S1(=O)(=O)N(c2cc(cc3c2n(cc3CC)CC1)C(=O)N[C@H](...,BACE_5,1,Train,8.698970,629.71283,3.5086,3,3,11,...,78.945702,39.361153,179.712880,220.461300,23.654478,0.230159,15.879790,0.000000,26.100143,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1508,Clc1cc2nc(n(c2cc1)C(CC(=O)NCC1CCOCC1)CC)N,BACE_1543,0,Test,3.000000,364.86969,2.5942,3,2,6,...,37.212799,37.681076,180.226410,95.670128,30.107586,9.368159,7.980170,0.000000,0.000000,1543
1509,Clc1cc2nc(n(c2cc1)C(CC(=O)NCc1ncccc1)CC)N,BACE_1544,0,Test,3.000000,357.83731,2.8229,3,2,6,...,45.792797,47.349350,122.401500,99.877144,30.107586,9.368159,7.980170,0.000000,0.000000,1544
1510,Brc1cc(ccc1)C1CC1C=1N=C(N)N(C)C(=O)C=1,BACE_1545,0,Test,2.953115,320.18451,3.0895,2,1,2,...,47.790600,22.563574,96.290794,58.798935,20.071724,9.368159,0.000000,6.904104,0.000000,1545
1511,O=C1N(C)C(=NC(=C1)C1CC1c1cc(ccc1)-c1ccccc1)N,BACE_1546,0,Test,2.733298,317.38440,3.8595,2,1,3,...,77.219978,9.316234,95.907784,112.609720,20.071724,9.368159,0.000000,6.904104,0.000000,1546
