<a href="https://colab.research.google.com/github/GalJakob/Toxicity-prediction-WS/blob/main/code/chemBerta.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

***import datasets***

In [None]:
import io
import pandas as pd
from google.colab import files,drive

# --- run configuration ---
WITH_TRAINING = False  # change to True if training is needed
WITH_HYPER_PARAMS = True  # change to True if testing hyper parameters is needed
MODEL = "chemBerta"  # constant; used to build the model/result paths on Drive
AUGMENTED_CASE = "none augmented"  # can be "none augmented" / "only train augmented" / "both augmented"
dataset_name = "clintox"  # can be cardio / tox21 / clintox

# The test split always comes from "<dataset>_test"; only the training file
# differs between the plain and the augmented cases.
ds_test = dataset_name + "_test"
if AUGMENTED_CASE == "none augmented":
  ds_train = dataset_name + "_train"
else:  # "only train augmented" and "both augmented" use the same augmented train file
  ds_train = dataset_name + "_train_aug"

path_train = f"https://raw.githubusercontent.com/GalJakob/Toxicity-prediction-WS/main/datasets/train%20datasets/{ds_train}.csv"
path_test = f"https://raw.githubusercontent.com/GalJakob/Toxicity-prediction-WS/main/datasets/test%20datasets/{ds_test}.csv"
drive.mount("/content/drive")


def _uploaded_csv_bytes(uploaded, name):
  """Return the raw bytes of an uploaded file stored under `name` or `name + ".csv"`."""
  for key in (name, name + ".csv"):
    if key in uploaded:
      return uploaded[key]
  raise KeyError(f"expected an uploaded file named {name}.csv, got: {list(uploaded)}")


try:  # getting data from github
  test_data = pd.read_csv(path_test)
  train_data = pd.read_csv(path_train)

except Exception as download_error:  # uploading data instead of reading it from github
  print(f"could not read the datasets from github ({download_error}); upload {ds_train}.csv and {ds_test}.csv instead")
  data = files.upload()
  # Parse the uploaded bytes so that both code paths leave DataFrames behind;
  # the later cells index train_data/test_data by column name.
  train_data = pd.read_csv(io.BytesIO(_uploaded_csv_bytes(data, ds_train)))
  test_data = pd.read_csv(io.BytesIO(_uploaded_csv_bytes(data, ds_test)))


Mounted at /content/drive


***installation required***

In [None]:
!git clone https://github.com/seyonechithrananda/bert-loves-chemistry.git
!pip install transformers
!pip install simpletransformers
!pip install --pre deepchem
!pip install datasets scipy sklearn torch tqdm wandb
%cd /content/bert-loves-chemistry
!wget https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/vocab.txt

***split data  (80/10/10)***

In [None]:
from sklearn.model_selection import train_test_split
import pandas as pd

# Original (un-split) dataset; it is only split below in the "none augmented" case.
path_original_dataset = f"https://raw.githubusercontent.com/GalJakob/Toxicity-prediction-WS/main/datasets/original%20datasets/{dataset_name}.csv"
data = pd.read_csv(path_original_dataset)

if AUGMENTED_CASE == "none augmented":
  # 80% train; the remaining 20% is halved into test and validation (80/10/10).
  X_train, X_test, y_train, y_test = train_test_split(data["smiles"], data["label"], test_size=0.2, random_state=42,shuffle = True)
  X_test, X_val, y_test, y_val = train_test_split(X_test, y_test, test_size=0.5, random_state=42,shuffle = True)
  # reset_index: the split keeps the shuffled row labels of the source frame,
  # which would break positional lookups such as test_data['smiles'][i] later on.
  train_data = pd.DataFrame({"smiles": X_train, "label": y_train}).reset_index(drop=True)

else:  # split only test_data
  # NOTE: this branch halves the already loaded test_data, so re-running the
  # cell without reloading the data halves it again.
  X_test, X_val, y_test, y_val = train_test_split(test_data["smiles"], test_data["label"], test_size=0.5, random_state=42,shuffle = True)

test_data = pd.DataFrame({"smiles": X_test, "label": y_test}).reset_index(drop=True)
val_data = pd.DataFrame({"smiles": X_val, "label": y_val}).reset_index(drop=True)



***model builder***

In [None]:
from simpletransformers.classification import ClassificationModel, ClassificationArgs
import torch
import sklearn
# Fine-tuning options, applied to a fresh ClassificationArgs in the listed order.
training_settings = {
    "train_batch_size": 16,
    "evaluate_during_training": WITH_HYPER_PARAMS,
    "evaluate_during_training_silent": False,
    "evaluate_during_training_steps": -1,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    "learning_rate": 0.0000243,
    "manual_seed": 4,
    "no_cache": True,
    "num_train_epochs": 35,
    "overwrite_output_dir": True,
    "reprocess_input_data": True,
    "output_dir": "default_output",
    "best_model_dir": "default_output/best_model",
    "auto_weights": True,
}

model_args = ClassificationArgs()
for option_name, option_value in training_settings.items():
  setattr(model_args, option_name, option_value)

# RoBERTa-architecture classifier initialised from the pretrained ChemBERTa checkpoint;
# the GPU is used whenever torch reports one.
model = ClassificationModel(
    'roberta',
    'seyonec/PubChem10M_SMILES_BPE_396_250',
    use_cuda=torch.cuda.is_available(),
    args=model_args,
)

***hyperparameter tuning***

In [None]:
### hyperparameter imports ###
from datasets import load_dataset, load_metric
from transformers import (AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments)
import wandb

### chemberta imports ###
from rdkit import Chem
from simpletransformers.classification import ClassificationModel, ClassificationArgs

# Bayesian W&B sweep that maximises the logged "auprc" metric.
# Parameter names must match the simpletransformers argument names, because the
# sweep values are copied onto the model args by name; the batch size argument
# is "train_batch_size" (a key called "batch_size" would be set but never used).
sweep_config = {
    "name": "chemBerta",
    "method": "bayes",
    "metric": {"name": "auprc", "goal": "maximize"},
    "parameters": {
        "num_train_epochs": {"min": 10, "max": 30},
        "learning_rate": {"min": 0.0000001, "max": 0.001},
        "train_batch_size": {"values": [32, 16]},
    },
    "early_terminate": {"type": "hyperband", "min_iter": 6,},
}
sweep_id = wandb.sweep(sweep_config, project="chemBerta")

In [None]:
import sklearn
from sklearn.metrics import accuracy_score
def train_for_hyper_params():
  """Run one sweep trial: build a fresh model from the sweep values and train it.

  The model is created with `model_args` plus the values W&B selected for this
  trial (`wandb.config`), trained on `train_data` and evaluated on `val_data`.
  An extra metric named "accuracy" is computed from the rounded predictions.
  """
  wandb.init()
  try:
    model = ClassificationModel('roberta',
                                'seyonec/PubChem10M_SMILES_BPE_396_250',
                                use_cuda = torch.cuda.is_available(),
                                args=model_args,
                                sweep_config=wandb.config)

    model.train_model(train_data, eval_df=val_data,
                      accuracy=lambda truth, predictions: accuracy_score(truth, [round(p) for p in predictions]) )
  finally:
    # wandb.join() is the deprecated name of wandb.finish(); closing the run in
    # `finally` keeps a failed trial from leaving the run open.
    wandb.finish()

# Only launch the sweep agent when hyper-parameter testing was requested.
if WITH_HYPER_PARAMS:
  wandb.agent(sweep_id,train_for_hyper_params)

***training***

In [None]:
#training if wanted
if WITH_TRAINING:
  # train_model raises when evaluate_during_training is enabled and no eval_df
  # is given, so pass the validation split whenever that option is on.
  eval_df = val_data if model.args.evaluate_during_training else None
  model.train_model(train_data, eval_df=eval_df)

***augmentation builder code, essential for majority vote code***


In [None]:
from rdkit import Chem
import numpy as np
import threading

class Iterator(object):
    """Base class for batch-index iterators.

    Subclasses are expected to provide a `next()` method; `__next__` simply
    forwards to it.

    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """

    def __init__(self, n, batch_size, shuffle, seed):
        # A dataset smaller than one batch cannot be iterated.
        if n < batch_size:
            raise ValueError('Input data length is shorter than batch_size\nAdjust batch_size')
        self.n, self.batch_size, self.shuffle = n, batch_size, shuffle
        self.batch_index = 0
        self.total_batches_seen = 0
        self.lock = threading.Lock()
        # Lazy generator: nothing is computed until the first next() on it.
        self.index_generator = self._flow_index(n, batch_size, shuffle, seed)

    def reset(self):
        """Restart batching from the first batch of an epoch."""
        self.batch_index = 0

    def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
        """Endlessly yield `(indices, start, size)` for consecutive batches.

        The last batch of an epoch is shorter when `n` is not a multiple of
        `batch_size`; afterwards the sequence starts over with a new index order.
        """
        self.reset()  # make sure we start at batch 0
        while True:
            if seed is not None:
                np.random.seed(seed + self.total_batches_seen)
            if self.batch_index == 0:
                # start of an epoch: (re)build the visiting order
                index_array = np.random.permutation(n) if shuffle else np.arange(n)

            start = (self.batch_index * batch_size) % n
            if n > start + batch_size:
                size = batch_size
                self.batch_index += 1
            else:
                # final (possibly short) batch of the epoch
                size = n - start
                self.batch_index = 0
            self.total_batches_seen += 1
            yield index_array[start:start + size], start, size

    def __iter__(self):
        # Allows `for x, y in data_gen.flow(...)` style loops.
        return self

    def __next__(self, *args, **kwargs):
        return self.next(*args, **kwargs)




class SmilesIterator(Iterator):
    """Iterator yielding vectorized batches from a SMILES array.

    # Arguments
        x: Array-like (Numpy array, Pandas series or list) of SMILES input data.
        y: Array-like of targets data, or None to yield inputs only.
        smiles_data_generator: Instance of `SmilesEnumerator`
            to use for random SMILES generation.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        dtype: dtype to use for returned batch. Set to keras.backend.floatx if using Keras

    # Raises
        ValueError: if `x` and `y` differ in length, or (from the base class)
            if there are fewer samples than `batch_size`.
    """

    def __init__(self, x, y, smiles_data_generator,
                 batch_size=32, shuffle=False, seed=None,
                 dtype=np.float32
                 ):
        if y is not None and len(x) != len(y):
            raise ValueError('X (SMILES array) and y (labels) '
                             'should have the same length. '
                             'Found: X.shape = %s, y.shape = %s' %
                             (np.asarray(x).shape, np.asarray(y).shape))

        self.x = np.asarray(x)
        self.y = np.asarray(y) if y is not None else None
        self.smiles_data_generator = smiles_data_generator
        self.dtype = dtype
        # Use the converted array for the sample count so that plain lists
        # (which have no .shape) are accepted as well.
        super(SmilesIterator, self).__init__(self.x.shape[0], batch_size, shuffle, seed)

    def next(self):
        """Return the next batch.

        # Returns
            `batch_x` when no targets were given, otherwise `(batch_x, batch_y)`.
        """
        # Only the advancing of the batch index is done under the lock.
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)
        # The vectorization itself is not under the lock, so it can run in parallel.
        generator = self.smiles_data_generator
        batch_x = np.zeros((current_batch_size, generator.pad, generator._charlen), dtype=self.dtype)
        for i, j in enumerate(index_array):
            # slice (not index) so transform() receives a length-1 array
            batch_x[i] = generator.transform(self.x[j:j+1])

        if self.y is None:
            return batch_x
        return batch_x, self.y[index_array]


class SmilesEnumerator(object):
    """SMILES Enumerator, vectorizer and devectorizer

    #Arguments
        charset: string containing the characters for the vectorization
          can also be generated via the .fit() method
        pad: Length of the vectorization
        leftpad: Add spaces to the left of the SMILES
        isomericSmiles: Generate SMILES containing information about stereogenic centers
        enum: Enumerate the SMILES during transform
        canonical: use canonical SMILES during transform (overrides enum)
    """
    def __init__(self, charset = '@C)(=cOn1S2/H[N]\\', pad=120, leftpad=True, isomericSmiles=True, enum=True, canonical=False):
        self._charset = None
        self.charset = charset  # the property setter also builds the lookup tables
        self.pad = pad
        self.leftpad = leftpad
        self.isomericSmiles = isomericSmiles
        self.enumerate = enum
        self.canonical = canonical

    @property
    def charset(self):
        """String of characters; a character's position is its one-hot index."""
        return self._charset

    @charset.setter
    def charset(self, charset):
        self._charset = charset
        self._charlen = len(charset)
        self._char_to_int = dict((c,i) for i,c in enumerate(charset))
        self._int_to_char = dict((i,c) for i,c in enumerate(charset))

    def fit(self, smiles, extra_chars=(), extra_pad = 5):
        """Performs extraction of the charset and length of a SMILES datasets and sets self.pad and self.charset

        #Arguments
            smiles: Numpy array or Pandas series containing smiles as strings
            extra_chars: Iterable of extra chars to add to the charset (e.g. "\\\\" when "/" is present)
            extra_pad: Extra padding to add before or after the SMILES vectorization
        """
        charset = set("".join(list(smiles)))
        # Sort so the character -> index mapping is the same on every run;
        # plain set iteration order is not stable across interpreter sessions.
        self.charset = "".join(sorted(charset.union(set(extra_chars))))
        self.pad = max(len(smile) for smile in smiles) + extra_pad

    def randomize_smiles(self, smiles):
        """Perform a randomization of a SMILES string
        must be RDKit sanitizable

        #Raises
            ValueError: if RDKit cannot parse `smiles`.
        """
        m = Chem.MolFromSmiles(smiles)
        if m is None:
            # MolFromSmiles signals a parse failure by returning None
            raise ValueError("RDKit could not parse SMILES: %r" % (smiles,))
        ans = list(range(m.GetNumAtoms()))
        np.random.shuffle(ans)  # random atom order -> different SMILES for the same molecule
        nm = Chem.RenumberAtoms(m,ans)
        return Chem.MolToSmiles(nm, canonical=self.canonical, isomericSmiles=self.isomericSmiles)

    def transform(self, smiles):
        """Perform an enumeration (randomization) and vectorization of a Numpy array of smiles strings

        #Arguments
            smiles: Numpy array or Pandas series containing smiles as strings

        #Returns
            int8 array of shape (len(smiles), self.pad, len(self.charset)).

        #Raises
            IndexError: if a (possibly randomized) SMILES is longer than self.pad.
            KeyError: if a character is not part of the charset.
        """
        one_hot = np.zeros((len(smiles), self.pad, self._charlen), dtype=np.int8)

        for i, ss in enumerate(smiles):
            if self.enumerate:
                ss = self.randomize_smiles(ss)
            if len(ss) > self.pad:
                # With left padding an over-long string would silently wrap
                # around through negative indices and corrupt the encoding.
                raise IndexError("SMILES of length %d does not fit into pad=%d: %r"
                                 % (len(ss), self.pad, ss))
            # left padding shifts the string so that it ends at the last position
            offset = self.pad - len(ss) if self.leftpad else 0
            for j, c in enumerate(ss):
                one_hot[i, j + offset, self._char_to_int[c]] = 1
        return one_hot

    def reverse_transform(self, vect):
        """ Performs a conversion of a vectorized SMILES to a smiles strings
        charset must be the same as used for vectorization.
        #Arguments
            vect: Numpy array of vectorized SMILES.
        """
        smiles = []
        for v in vect:
            # keep only the positions that encode a character (drops the padding rows)
            v = v[v.sum(axis=1) == 1]
            # Find one hot encoded index with argmax, translate to char and join to string
            smiles.append("".join(self._int_to_char[i] for i in v.argmax(axis=1)))
        return np.array(smiles)

# Self-check of SmilesEnumerator / SmilesIterator on two example molecules.
# Problems are reported with print(); nothing is raised by the checks themselves.
# NOTE(review): inside a notebook kernel this guard is presumably true, so the
# checks run every time the cell is executed - confirm that this is intended.
if __name__ == "__main__":
    # two distinct SMILES strings, each repeated 10 times (20 entries)
    smiles = np.array([ "CCC(=O)O[C@@]1(CC[NH+](C[C@H]1CC=C)C)c2ccccc2",
                        "CCC[S@@](=O)c1ccc2c(c1)[nH]/c(=N/C(=O)OC)/[nH]2"]*10
                        )
    #Test canonical SMILES vectorization
    # with enum=False transform() vectorizes the strings without randomizing them
    sm_en = SmilesEnumerator(canonical=True, enum=False)
    sm_en.fit(smiles, extra_chars=["\\"])
    v = sm_en.transform(smiles)
    transformed = sm_en.reverse_transform(v)
    if len(set(transformed)) > 2: print("Too many different canonical SMILES generated")

    #Test enumeration
    # randomization switched on: more than two distinct strings are expected
    sm_en.canonical = False
    sm_en.enumerate = True
    v2 = sm_en.transform(smiles)
    transformed = sm_en.reverse_transform(v2)
    if len(set(transformed)) < 3: print("Too few enumerated SMILES generated")

    #Reconstruction
    # v was produced without randomization, so decoding should return the inputs
    reconstructed = sm_en.reverse_transform(v[0:5])
    for i, smile in enumerate(reconstructed):
        if smile != smiles[i]:
            print("Error in reconstruction %s %s"%(smile, smiles[i]))
            break

    #test Pandas
    import pandas as pd
    df = pd.DataFrame(smiles)
    v = sm_en.transform(df[0])
    # expected shape for this fixture: (samples, pad, charset length)
    if v.shape != (20, 52, 18): print("Possible error in pandas use")

    #BUG, when batchsize > x.shape[0], then it only returns x.shape[0]!
    #Test batch generation
    sm_it = SmilesIterator(smiles, np.array([1,2]*10), sm_en, batch_size=10, shuffle=True)
    X, y = sm_it.next()
    # only an excess of label 1 over label 2 is reported (the difference is signed)
    if sum(y==1) - sum(y==2) > 1:
        print("Unbalanced generation of batches")
    if len(X) != 10: print("Error in batchsize generation")



***eval by majority (both augmented) functions***

In [None]:
from rdkit import Chem
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
import numpy as np
from sklearn import metrics

def get_idx_of_matched_smile(list_of_dict_smiles,curr_smile):
  ''' returns the idx of the list entry whose SMILES describes the same molecule
    as curr_smile (possibly written differently), returns None if not found.

    list_of_dict_smiles: list of single-key dicts keyed by a SMILES string
    curr_smile: SMILES string to look for

    Two SMILES count as the same molecule when each one is a substructure
    match of the other.'''
  if len(list_of_dict_smiles) == 0:
    return None

  # the query molecule does not change inside the loop, so parse it only once
  curr_smile_obj = Chem.MolFromSmiles(curr_smile)

  for idx, entry in enumerate(list_of_dict_smiles):
    curr_list_smile = list(entry.keys())[0]
    curr_list_smile_obj = Chem.MolFromSmiles(curr_list_smile)

    if curr_smile_obj.HasSubstructMatch(curr_list_smile_obj) and curr_list_smile_obj.HasSubstructMatch(curr_smile_obj): #smiles are same molecule
      return idx

  return None


def build_list_dict_of_smiles(predictions,smiles,true_labels):
  ''' groups per-SMILES predictions by molecule and returns the list:
    [{"random_smile":{"count_all":#,"count_correct":#,"true_label":0/1  }},{}]
    where the key is the first SMILES seen for that molecule.  '''
  grouped = []

  for position, predicted in enumerate(predictions):
    smile = smiles[position]
    is_correct = 1 if predicted == true_labels[position] else 0
    found_at = get_idx_of_matched_smile(grouped, smile)

    if found_at is not None:
      # molecule already known: update the counters stored under its first SMILES
      entry = grouped[found_at]
      stats = entry[list(entry.keys())[0]]
      stats["count_all"] += 1
      stats["count_correct"] += is_correct
      continue

    # first occurrence of this molecule
    grouped.append({f"{smile}": {"count_all": 1,
                                 "count_correct": is_correct,
                                 "true_label": true_labels[position]}})
  return grouped

def ensure_dup_is_not_orginial(original_SMILE,dup_SMILE,sme):
  '''re-randomizes dup_SMILE until it differs from original_SMILE and returns it.
  After 1001 unsuccessful attempts a message is printed and the last candidate
  is returned as it is.'''
  max_attempts = 1001
  for _ in range(max_attempts):
    if dup_SMILE != original_SMILE:
      return dup_SMILE
    dup_SMILE = sme.randomize_smiles(original_SMILE)

  # gave up: report it and hand back whatever the last attempt produced
  print("over 1000 attempts to duplicate")
  return dup_SMILE

def evaluate_by_majority(predictions_of_dup_SMILE,true_label,dup_count,threshold):
  '''gets predictions for specific SMILE duplicates and returns the prediction for original SMILE.
  The true label is returned when at least `threshold` of the `dup_count`
  duplicates were predicted correctly, otherwise the opposite label.
  Returns None for a true label other than 0/1. '''
  hits = sum(1 for prediction in predictions_of_dup_SMILE if prediction == true_label)
  enough_correct = (hits / dup_count) >= threshold

  if true_label == 1:
    return 1 if enough_correct else 0
  if true_label == 0:
    return 0 if enough_correct else 1

def predictions_by_majority(model,test_data):
  '''duplicates each test SMILE as randomized SMILES, lets the model predict every
   duplicate and combines the duplicate predictions into one prediction per
   original SMILE by a threshold.

   model: object whose predict(list_of_smiles) returns the predicted labels first
   test_data: DataFrame with "smiles" and "label" columns

   returns a dict with acc / auroc / auprc / precision / recall, the confusion
   matrix counts (tp, fp, fn, tn), the threshold and the duplication settings.'''
  threshold = 0.6
  predictions = []
  sme = SmilesEnumerator()
  # plain lists give positional access, independent of the DataFrame index labels
  all_smiles = test_data["smiles"].values.tolist()
  true_labels = test_data["label"].values.tolist()
  dict_of_dup_weights= {'label_1':{'dups':[10],'probs':[1]},
                        'label_0':{'dups':[10],'probs':[1]},}

  for original_SMILE, true_label in zip(all_smiles, true_labels):
    # the number of duplicates is drawn from the per-label settings
    weights = dict_of_dup_weights['label_1' if true_label == 1 else 'label_0']
    num_of_duplicates = np.random.choice(weights['dups'], p=weights['probs'])

    dup_SMILEs_list = []
    for _ in range(num_of_duplicates):
      dup_SMILE = sme.randomize_smiles(original_SMILE)
      if dup_SMILE == original_SMILE:
        dup_SMILE = ensure_dup_is_not_orginial(original_SMILE,dup_SMILE,sme)
      # always collect the randomized duplicate, for both labels
      dup_SMILEs_list.append(dup_SMILE)

    predictions_of_dup_SMILE = list(model.predict(dup_SMILEs_list)[0])
    prediction_for_original_SMILE = evaluate_by_majority(predictions_of_dup_SMILE,true_label,num_of_duplicates,threshold)
    predictions.append(prediction_for_original_SMILE)

  accuracy = metrics.accuracy_score(true_labels,predictions)
  roc_auc = metrics.roc_auc_score(true_labels,predictions)
  precision_list, recall_list, thresholds = metrics.precision_recall_curve(true_labels,predictions)
  pr_auc = metrics.auc(recall_list, precision_list)
  precision = metrics.precision_score(true_labels,predictions)
  recall = metrics.recall_score(true_labels,predictions)
  # labels=[0, 1] keeps the matrix 2x2 even if only one class shows up
  tn, fp, fn, tp = metrics.confusion_matrix(true_labels, predictions, labels=[0, 1]).ravel()

  result = {'acc':accuracy,'auroc':roc_auc,'auprc':pr_auc,'precision':precision,'recall':recall,
            'tp':tp,'fp':fp,'fn':fn,'tn':tn,'threshold':threshold,'dict_of_dup_weights':dict_of_dup_weights}
  return result


***evaluate and write results to files***

In [None]:
import torch
from rdkit import Chem
import pickle
import sklearn
from sklearn import metrics
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

result = None
model_outputs = None
wrong_predictions = None
loaded_model = None
predictions=[]
true_labels = list(test_data["label"])

if not (AUGMENTED_CASE == "both augmented"):
  # Plain evaluation of the in-memory model; precision and recall are added
  # from a separate predict() pass over the test SMILES.
  # NOTE(review): with WITH_TRAINING = False this evaluates `model` exactly as
  # the model builder cell created it - confirm that this is intended.
  result, model_outputs, wrong_predictions = model.eval_model(test_data, acc=sklearn.metrics.accuracy_score)
  predictions = list(model.predict(list(test_data["smiles"]))[0])

  result["precision"] = metrics.precision_score(true_labels,predictions)

  result["recall"]= metrics.recall_score(true_labels,predictions)

else: # evaluate by majority vote

  if not WITH_TRAINING:
    model_pkl_file = f"/content/drive/MyDrive/Toxicity prediction WS/models/{AUGMENTED_CASE}/{MODEL}/{dataset_name}_model.pkl"

    # NOTE: unpickling runs arbitrary code - only load model files you saved yourself.
    if device.type == "cuda":
      with open(model_pkl_file, 'rb') as model_file: #IF CUDA AVAILABLE
        loaded_model = pickle.load(model_file)
    else:
      loaded_model = torch.load(model_pkl_file, map_location=device)
    result = predictions_by_majority(loaded_model,test_data)

  else:
    result = predictions_by_majority(model,test_data)


# F1 is reported as 0 when precision and recall are both 0 (instead of NaN)
precision_plus_recall = result['recall'] + result['precision']
if precision_plus_recall:
  result['f1_score'] = 2 * (result['precision'] *result['recall']) / precision_plus_recall
else:
  result['f1_score'] = 0.0

print(result)

with open(f"/content/drive/MyDrive/Toxicity prediction WS/results/{AUGMENTED_CASE}/{MODEL}/{dataset_name}_results.txt", 'w') as resFile:
    resFile.write(f"results for {ds_test} :\n")
    resFile.write(f"accuracy: {result['acc']} \n")
    resFile.write(f"precision: {result['precision']} \n")
    resFile.write(f"recall: {result['recall']} \n")
    resFile.write(f"Area under the ROC curve: {result['auroc']}\n")
    resFile.write(f"Area under the PR curve: {result['auprc']}\n")
    resFile.write(f"confusion matrix:  \n")
    resFile.write(f"true positive :{result['tp']},false positive:{result['fp']} \n")
    resFile.write(f"false negative :{result['fn']},true negative:{result['tn']} \n")
    resFile.write(f"F1 Score: {result['f1_score']}\n")
    if AUGMENTED_CASE == "both augmented":
      resFile.write(f"threshold: {result['threshold']} \n")
      # this line holds the duplication settings, not the threshold
      resFile.write(f"dict_of_dup_weights: {result['dict_of_dup_weights']} \n")



***save the model as a pickle file***

In [None]:
import pickle
# Same absolute Drive location that the evaluation and loading cells read from
# (the relative "../drive" form only resolves while the working directory is
# one level below /content).
model_pkl_file = f"/content/drive/MyDrive/Toxicity prediction WS/models/{AUGMENTED_CASE}/{MODEL}/{dataset_name}_model.pkl"

if WITH_TRAINING:
  with open(model_pkl_file, 'wb') as file:
     pickle.dump(model, file)
else:
  # Without training, `model` is still the model as built above; saving it
  # would overwrite a previously trained pickle.
  print("WITH_TRAINING is False - model not saved, existing pickle file left untouched")



***load the model from the pickle file and evaluate it on the test set***


In [None]:
import pickle
import sklearn
model_pkl_file = f"/content/drive/MyDrive/Toxicity prediction WS/models/{AUGMENTED_CASE}/{MODEL}/{dataset_name}_model.pkl"

# NOTE: unpickling runs arbitrary code - only load model files you saved yourself.
# The with-block closes the file handle once the model is loaded.
with open(model_pkl_file, 'rb') as model_file:
  loaded_model = pickle.load(model_file)

# Evaluate the restored model on the test split, adding accuracy as an extra metric "acc".
result, model_outputs, wrong_predictions = loaded_model.eval_model(test_data, acc=sklearn.metrics.accuracy_score)
print(result)