##### This notebook generates train / valid / test (0.8 / 0.1 / 0.1) splits for Tox21, based on the random splitting method in MoleculeNet

Splits were created for seeds of 122, 123, 124 (same seeds as MoleculeNet). 

All splits are saved in the `data/datasets/tox21/split_data/seed_<seed>/` folder (one sub-folder per seed).

Raw Tox21 data was obtained from MoleculeNet.
- MoleculeNet Data: https://github.com/deepchem/deepchem/blob/7463d93d0f85a3ba58cd155209540d8e649d875e/deepchem/molnet/load_function/tox21_datasets.py specifies this location - "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz"
- MoleculeNet Splitting method: https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py

Tasks (endpoints) defined in the Tox21 are: 
- 'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD','NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'

In [None]:
# general and data handling
import numpy as np
import pandas as pd
import os
from collections import Counter
import time
import random
import joblib

# Required RDKit modules
import rdkit as rd
from rdkit import DataStructs
from rdkit.Chem import AllChem

# modeling
import sklearn as sk
from sklearn.model_selection import train_test_split

In [None]:
import torch

# Pick the compute device: the first CUDA device when PyTorch reports CUDA
# as available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import datetime, os

##### Settings

In [None]:
# Seed every random number generator used in this notebook with one value.
seed_value = 124  # one of 122, 123, 124 (the seeds used in MoleculeNet)

# Python standard library and NumPy generators
random.seed(seed_value)
np.random.seed(seed_value)

# PyTorch generators
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)

# cuDNN flags: request deterministic behaviour and switch the backend off
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False

In [None]:
# Number of bits for Morgan fingerprints.
# NOTE(review): not referenced by any cell visible in this notebook - confirm
# whether it is still needed here.
morgan_bits = 4096

In [None]:
# Radius for Morgan fingerprints.
# NOTE(review): not referenced by any cell visible in this notebook - confirm
# whether it is still needed here.
morgan_radius = 2

##### Load data

In [None]:
# Names of the twelve Tox21 task (endpoint) columns
tox21_tasks = [
    'NR-AR', 'NR-Aromatase', 'NR-PPAR-gamma', 'SR-HSE',
    'NR-AR-LBD', 'NR-ER', 'SR-ARE', 'SR-MMP',
    'NR-AhR', 'NR-ER-LBD', 'SR-ATAD5', 'SR-p53',
]

# Location of the raw Tox21 CSV, relative to this notebook
tox21_file = '../../data/datasets/tox21/raw_data/tox21.csv'

# Read the raw table, report how many rows were loaded, then preview it
tox21_data = pd.read_csv(tox21_file)
print('Reading {}... {} data loaded.'.format(tox21_file, len(tox21_data)))
tox21_data.head()

In [None]:
# Wrap the Tox21 frame in a list: the processing and split cells iterate over
# `data`, one DataFrame per dataset.
data = [tox21_data]

In [None]:
# Task (endpoint) names across all datasets; only Tox21 is loaded here, so this
# is the same list object as `tox21_tasks` (an alias, not a copy).
all_tasks = tox21_tasks

##### Process data

In [None]:
# Canonicalise the SMILES of every dataset in `data`.
# Each entry of the `smiles` column is parsed with RDKit; entries that parse
# get their canonical SMILES written to a new `SMILES` column, entries that do
# not parse get NaN there and their rows are then dropped from the dataset.

for idx, frame in enumerate(data):
    canonical = []
    for raw_smiles in frame.smiles:
        parsed = rd.Chem.MolFromSmiles(raw_smiles)
        # a SMILES string RDKit cannot parse yields a falsy result
        canonical.append(rd.Chem.MolToSmiles(parsed, True) if parsed else np.nan)

    # the new column is added to the frame currently stored in `data`
    frame['SMILES'] = canonical

    # keep only the rows whose molecule could be parsed
    data[idx] = frame[frame['SMILES'].notna()]

##### <font color = 'blue'> MoleculeNet Split </font>

In [None]:
### Method borrowed from MoleculeNet for random splits of 0.8 / 0.1 / 0.1, train / valid / test
# Returns the shuffled positional indices as a tuple of arrays, in the order (train, valid, test)

def split(dataset, seed=None, frac_train=.8, frac_valid=.1, frac_test=.1,
          log_every_n=None):
    """Randomly partition the positions of `dataset` into train/valid/test.

    Parameters
    ----------
    dataset : sized object
        Only ``len(dataset)`` is used.
    seed : int or None
        When not None, NumPy's global random generator is re-seeded with it
        before shuffling; otherwise the current global state is used.
    frac_train, frac_valid, frac_test : float
        Fractions of the data per subset; they must sum to 1.
    log_every_n : unused
        Kept in the signature but not used by this implementation.

    Returns
    -------
    tuple of three index arrays, in the order (train, valid, test).

    Raises
    ------
    AssertionError
        If the three fractions do not sum to 1 (almost-equal comparison).
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    if seed is not None:
        np.random.seed(seed)

    total = len(dataset)
    # cut points are truncated to whole positions, so the last subset absorbs
    # any remainder
    first_cut = int(frac_train * total)
    second_cut = int((frac_train + frac_valid) * total)

    order = np.random.permutation(range(total))
    train_idx, valid_idx, test_idx = np.split(order, [first_cut, second_cut])
    return (train_idx, valid_idx, test_idx)

In [None]:
# Split every dataset in `data` into train / valid / test subsets.
#
# Re-seed NumPy here so that running this cell on its own gives the same
# permutation as a fresh "Run All" (the settings cell seeds NumPy with the
# same value and nothing in between draws from it).
np.random.seed(seed_value)

temp_train_data = []
temp_test_data = []
temp_valid_data = []
for frame in data:
    # split() returns positional index arrays in the order (train, valid, test);
    # unpack by name so the validation and test subsets cannot be swapped.
    train_idx, valid_idx, test_idx = split(frame)
    temp_train_data.append(frame.iloc[train_idx])
    temp_valid_data.append(frame.iloc[valid_idx])
    temp_test_data.append(frame.iloc[test_idx])

train_data = temp_train_data[0]
test_data  = temp_test_data[0]
valid_data  = temp_valid_data[0]


# With more than one dataset, combine the per-dataset subsets on the raw
# `smiles` column (no-op while `data` holds a single dataset).
for i in range(1, len(data)):
    train_data = train_data.merge(temp_train_data[i], how='outer', on='smiles')
    test_data  = test_data.merge(temp_test_data[i], how='outer', on='smiles')
    valid_data  = valid_data.merge(temp_valid_data[i], how='outer', on='smiles')


# From here on `data` holds the three subsets in the order [train, test, valid],
# which is the order the save cell indexes into.
data = [train_data, test_data, valid_data]

##### Save data

In [None]:
# Output directory for the split files of the current seed.
# NOTE: the cells below write to and read from this directory; existing files
# with the same names are overwritten.
data_path = f"../../data/datasets/tox21/split_data/seed_{seed_value}/"

# Create the directory (and any missing parents); no error if it already exists.
os.makedirs(data_path, exist_ok=True)

In [None]:
# Write the three subsets to disk; `data` is ordered [train, test, valid],
# matching the order of the file names below.
split_file_names = ("train_data_tox21.pth", "test_data_tox21.pth", "valid_data_tox21.pth")
for position, file_name in enumerate(split_file_names):
    torch.save(data[position], data_path + file_name)

In [None]:
# Check: reload the saved Tox21 train/test/valid subsets from disk.
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
# which may refuse to unpickle pandas objects - confirm against the installed
# version before relying on this cell.
train_data, test_data, valid_data = (
    torch.load(data_path + file_name)
    for file_name in ('train_data_tox21.pth', 'test_data_tox21.pth', 'valid_data_tox21.pth')
)

# Same ordering as used when saving: [train, test, valid]
data = [train_data, test_data, valid_data]