##### This notebook generates train / valid / test (0.8 / 0.1 / 0.1) splits for RTECS, based on the random splitting method in MoleculeNet

Splits were created for seeds of 122, 123, 124 (same seeds as MoleculeNet), using:

- MoleculeNet Splitting method: https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py

The RTECS dataset is a commercial dataset provided by BIOVIA. We used its acute oral toxicity data in mice to create binary "toxic"/"nontoxic" classes from the LD50 (lethal dose for 50% of the population) values: chemicals with LD50 ≤ 5000 mg/kg are labeled toxic and those above 5000 mg/kg nontoxic. The 5000 mg/kg cutoff follows the EPA definition.

However, because this is a commercial dataset, we cannot provide it. Instead, the cells below show the method used to create the binary classes and splits.



In [None]:
# general and data handling
import numpy as np
import pandas as pd
import os
from collections import Counter
import time
import random
import joblib

# Required RDKit modules
import rdkit as rd
from rdkit import DataStructs
from rdkit.Chem import AllChem

# modeling
import sklearn as sk
from sklearn.model_selection import train_test_split

In [None]:
import torch

# Prefer the first GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
import datetime, os

##### Settings

In [None]:
# Seed every random number generator this notebook may touch.
# MoleculeNet used seeds 122, 123 and 124; re-run the notebook once per value.
seed_value = 122

random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)

# Turn cuDNN off and request deterministic kernels, trading speed for reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False

In [None]:
# Length of the Morgan fingerprint bit vector (2**12 = 4096 bits).
# NOTE(review): not referenced by the cells of this notebook; presumably kept for downstream featurization -- confirm.
morgan_bits = 2 ** 12

In [None]:
# Radius used when generating Morgan fingerprints.
# NOTE(review): not referenced by the cells of this notebook; presumably kept for downstream featurization -- confirm.
morgan_radius = 2

##### Load data

In [None]:
# RTECS is a commercial (BIOVIA) dataset and cannot be provided with this notebook;
# point a_oral_file at your own export. The cells below read the columns
# 'numeric_mgkg' (LD50 in mg/kg), 'pubchem_CASRN_SMILES', 'CASRN_canonical_SMILES' and 'seqnum'.
a_oral_file = "pathway-to-data/rtecs/a_oral_data.csv"  # placeholder -- cannot provide

a_oral_data = pd.read_csv(a_oral_file)

###### Define binary labels based on LD50 values

In [None]:
def binary_non_toxic_catg(numeric_mgkg, cutoff=5000):
    ''' Map an acute oral LD50 value (lethal dose for 50% of population) to a binary class.

        Values at or below ``cutoff`` (5000 mg/kg by default) are "toxic"; values above it
        are "non-toxic".

        Parameters
        ----------
        numeric_mgkg : float
            LD50 value in mg/kg.
        cutoff : float, optional
            Inclusive upper bound of the "toxic" class. Defaults to 5000.

        Returns
        -------
        str or None
            "toxic_a_oral" or "non-toxic_a_oral"; None when the value is missing
            (NaN/None) and therefore cannot be classified.
    '''
    # A missing LD50 cannot be binned; return None explicitly so callers can filter it out
    # (comparisons with NaN are always False, so it must not reach the branches below).
    if pd.isna(numeric_mgkg):
        return None
    if numeric_mgkg <= cutoff:
        # Toxic
        return "toxic_a_oral"
    # Nontoxic
    return "non-toxic_a_oral"
    
# Bin each LD50 value into "toxic_a_oral" / "non-toxic_a_oral" (5000 mg/kg cutoff).
a_oral_data['toxic_catg_5000'] = a_oral_data["numeric_mgkg"].apply(binary_non_toxic_catg)

# Rows that received no label (missing LD50) would otherwise get all-zero one-hot
# columns below and be silently counted as non-toxic; drop them instead.
a_oral_data = a_oral_data[a_oral_data['toxic_catg_5000'].notna()]

# Keep only the label, the two SMILES columns and seqnum
a_oral_data = a_oral_data[['toxic_catg_5000', 'pubchem_CASRN_SMILES', 'CASRN_canonical_SMILES', 'seqnum']]

# One-hot encode the label into 'toxic_a_oral' / 'non-toxic_a_oral' columns
# (get_dummies only creates columns for classes present in the data)
a_oral_data_toxic_labels = pd.get_dummies(a_oral_data.toxic_catg_5000)
a_oral_data = pd.concat([a_oral_data, a_oral_data_toxic_labels], axis=1)
a_oral_data = a_oral_data.drop(['toxic_catg_5000'], axis=1)

# SMILES for the molecules in the dataset had been curated by matching CASRN to PubChem;
# those become the 'smiles' column and the CASRN canonical SMILES are discarded.
a_oral_data = a_oral_data.rename(columns={'pubchem_CASRN_SMILES': 'smiles',
                                          'CASRN_canonical_SMILES': 'canonical_smiles'})
a_oral_data = a_oral_data.drop(['canonical_smiles'], axis=1)

# Keep a reference that still carries seqnum (drop() returns a new frame, so it is unaffected).
a_oral_data_seqnum = a_oral_data
a_oral_data = a_oral_data.drop(['seqnum'], axis=1)
# Remove exact duplicate (smiles, label) rows.
# NOTE(review): the same SMILES with conflicting labels is kept twice and may land in different splits.
a_oral_data = a_oral_data.drop_duplicates()

# Single label column: 1 - Toxic (LD50 <= 5000 mg/kg), 0 - NonToxic
a_oral_tasks = ['toxic_a_oral']

print("Acute oral tasks: %s" % str(a_oral_tasks))
print("%d tasks in total" % len(a_oral_tasks))

##### Setting all tasks 

In [None]:
# Gather every per-task frame into one list so later cells can loop over datasets uniformly.
data = []
data.append(a_oral_data)

In [None]:
# Task (label) columns used for modeling. Only the acute oral dataset is processed here,
# so this is the same list object as a_oral_tasks (an alias, not a copy).
all_tasks = a_oral_tasks 

##### Process data

In [None]:
# Convert SMILES to canonical SMILES (stored in a new 'SMILES' column).
# Any SMILES that RDKit cannot parse -- or that is missing -- becomes NaN and that
# row is removed from the dataset.
# NOTE(review): duplicates were dropped earlier on the raw PubChem SMILES, so two raw
# strings that canonicalize to the same molecule can both remain and end up in different
# splits; deduplicating on 'SMILES' would fix that but change the published split sizes.

for i in range(len(data)):
    cans = []
    for smi in data[i].smiles:
        # MolFromSmiles raises on non-string input (e.g. NaN for a missing SMILES) and
        # returns None for strings it cannot parse; treat both as invalid.
        mol = rd.Chem.MolFromSmiles(smi) if isinstance(smi, str) else None
        cans.append(rd.Chem.MolToSmiles(mol, True) if mol else np.nan)

    # assign() returns a new frame: avoids SettingWithCopyWarning on a frame produced by
    # boolean filtering, and leaves the frame referenced by a_oral_data untouched.
    data[i] = data[i].assign(SMILES=cans)

    # drop data points that have an invalid molecule
    data[i] = data[i][data[i]['SMILES'].notna()]

##### MoleculeNet Split

In [None]:
### Method borrowed from MoleculeNet (deepchem splitters.py) for random 0.8 / 0.1 / 0.1 splits
# Returns a tuple of index arrays in MoleculeNet's order (train, valid, test); the next cell
# stores the second array as "test" and the third as "valid" (both fractions are 0.1 here).

def split(dataset,
          seed=None,
          frac_train=.8,
          frac_valid=.1,
          frac_test=.1,
          log_every_n=None):
    """Randomly split the rows of ``dataset`` into three positional index arrays.

    Mirrors deepchem's random splitter.

    Parameters
    ----------
    dataset : sized
        Anything supporting ``len()``; only its length is used.
    seed : int, optional
        If given, the permutation is drawn from a private ``np.random.RandomState(seed)``.
        This yields the same indices as ``np.random.seed(seed)`` would, without clobbering
        the global RNG. If None, NumPy's global RNG is used, so the result follows any
        earlier ``np.random.seed`` call.
    frac_train, frac_valid, frac_test : float
        Fractions of the three slices; must sum to 1.
    log_every_n : ignored
        Accepted for signature compatibility with deepchem; unused.

    Returns
    -------
    tuple of np.ndarray
        ``(train_idx, valid_idx, test_idx)`` in that order. Sizes are ``int(frac_train*n)``,
        ``int((frac_train+frac_valid)*n) - int(frac_train*n)`` and the remainder.

    Raises
    ------
    AssertionError
        If the fractions do not sum to 1 (to 7 decimal places).
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    # A seeded RandomState reproduces exactly what seeding the global RNG would give.
    rng = np.random if seed is None else np.random.RandomState(seed)
    num_datapoints = len(dataset)
    train_cutoff = int(frac_train * num_datapoints)
    valid_cutoff = int((frac_train + frac_valid) * num_datapoints)
    shuffled = rng.permutation(range(num_datapoints))
    return (shuffled[:train_cutoff], shuffled[train_cutoff:valid_cutoff],
            shuffled[valid_cutoff:])

In [None]:
# Split every dataset in `data` with the MoleculeNet-style splitter. No seed is passed, so the
# permutation comes from NumPy's global RNG seeded in the Settings cell.
# split() returns index arrays ordered (train, valid, test); this notebook files the second
# array under "test" and the third under "valid" -- each holds roughly 10% of the rows.
temp_train_data, temp_test_data, temp_valid_data = [], [], []
for frame in data:
    train_idx, second_idx, third_idx = split(frame)
    temp_train_data.append(frame.iloc[train_idx])
    temp_test_data.append(frame.iloc[second_idx])
    temp_valid_data.append(frame.iloc[third_idx])

train_data, test_data, valid_data = temp_train_data[0], temp_test_data[0], temp_valid_data[0]

# With several datasets, outer-join their per-dataset splits on the raw 'smiles' column.
for k in range(1, len(data)):
    train_data = train_data.merge(temp_train_data[k], how='outer', on='smiles')
    test_data = test_data.merge(temp_test_data[k], how='outer', on='smiles')
    valid_data = valid_data.merge(temp_valid_data[k], how='outer', on='smiles')

data = [train_data, test_data, valid_data]

##### Save data

In [None]:
# Output directory for this seed's split files (adjust the prefix to your setup).
# Keep the trailing slash: the save cell builds file paths by plain string concatenation.
data_path = f"pathway-to-data/rtecs/split_data/seed_{seed_value}/"

# exist_ok makes the cell idempotent on re-runs.
os.makedirs(data_path, exist_ok=True)

In [None]:
# Persist the train / test / valid DataFrames (in the order they appear in `data`).
# NOTE(review): torch.save pickles the DataFrames; recent torch versions default
# torch.load(weights_only=True), which may reject them -- load with weights_only=False
# (trusted files only). Confirm against your torch version.
split_files = ("train_data_rtecs.pth", "test_data_rtecs.pth", "valid_data_rtecs.pth")
for frame, file_name in zip(data, split_files):
    torch.save(frame, data_path + file_name)

In [None]:
# Report the number of rows in each split and their total.
split_labels = ("Total number of examples, train: ",
                "Total number of examples, test: ",
                "Total number of examples, valid: ")
for label, frame in zip(split_labels, data):
    print(label + str(frame.shape[0]))
print("Total number of examples, train+test+valid: " + str(sum(frame.shape[0] for frame in data[:3])))