In [None]:
%load_ext autoreload
%autoreload 2

# Multitask Classification with SELFIES

In [None]:
from dotenv import load_dotenv
import os
from Code.Utils.util_methods import NNUtils
import pandas as pd
import deepchem as dc
import logging
import numpy as np
from rdkit import Chem, DataStructs
import selfies as sf
from tqdm import tqdm
import deepchem as dc
import json


# Show where the kernel was started; the project root is searched for from here.
print(f"Current working directory: {os.getcwd()}")

# Resolve the project root with the project helper so that every path below
# can be expressed relative to `base`.
base = NNUtils.find_project_root(os.getcwd())
print(f"Project root found: {base}")

# Load the settings from the project's .env file; later cells read them via os.getenv().
load_dotenv(f'{base}/.env')

In [None]:
# All featurized inputs are read from one directory. The file names of the
# train/test splits are taken from the environment variables X_TRAIN, X_TEST,
# Y_TRAIN and Y_TEST.
# To work on a subset, append e.g. `.loc[:10000]` to each of the reads below.
# NOTE: pickle files execute code on load - only read files produced by this project.
featurization_dir = f'{base}/Code/Full_systems/Selfies_Mol/Featurization'

X_train = pd.read_pickle(f'{featurization_dir}/{os.getenv("X_TRAIN")}')
X_test = pd.read_pickle(f'{featurization_dir}/{os.getenv("X_TEST")}')
y_train = pd.read_pickle(f'{featurization_dir}/{os.getenv("Y_TRAIN")}')
y_test = pd.read_pickle(f'{featurization_dir}/{os.getenv("Y_TEST")}')
# Unsplit alternatives (disabled):
#X = pd.read_pickle(f'{base}/Code/Full_systems/Selfies_Mol/Featurization/X.pkl')
#y = pd.read_pickle(f'{base}/Code/Full_systems/Selfies_Mol/Featurization/y.pkl')

# SELFIES strings of the test molecules; used later to compare against the predictions.
selfies_X_test = pd.read_pickle(f'{featurization_dir}/selfies_X_test.pkl')

In [None]:
X_train

In [None]:
y_train

## Training set size (the optional `.loc[:10000]` subset in the loading cell is currently commented out)

In [None]:
print(f'The size of the training set is: {X_train.shape[0]}')

In [None]:
# Wrap the feature/label tables in in-memory DeepChem datasets so they can be
# passed to model.fit() and model.predict().
train_dataset = dc.data.NumpyDataset(X=X_train, y=y_train)
test_dataset = dc.data.NumpyDataset(X=X_test, y=y_test)
# Disk-backed alternative for the unsplit data (disabled):
#ds = dc.data.DiskDataset.from_numpy(X.values, y.values)

In [None]:
print(f'train: {train_dataset}')
print(f'test: {test_dataset}')

In [None]:
# Load the SELFIES vocabulary written by the featurization step.
# A later cell inverts this dictionary, so it presumably maps symbol -> index.
group_dict_path = f'{base}/Code/Full_systems/Selfies_Mol/Featurization/selfies_group_dict.json'
with open(group_dict_path, 'r') as json_file:
    selfies_group_dict = json.load(json_file)
selfies_group_dict

## Train a multitask model

In [None]:
class EpochCallback:
    """Training callback that prints a message whenever an epoch is completed.

    The instance is meant to be passed in the ``callbacks`` list of
    ``model.fit`` and is then invoked as ``callback(model, step)`` after every
    training step. It keeps its own step counter, so the ``step`` argument is
    accepted but not used.
    """

    def __init__(self, steps_per_epoch):
        """
        Initialize the callback with the number of steps per epoch.

        Args:
            steps_per_epoch: number of training steps (batches) in one epoch.

        Raises:
            ValueError: if ``steps_per_epoch`` is not positive. A value of 0
                (e.g. from ``n_samples // batch_size`` on a tiny dataset)
                would otherwise fail with ZeroDivisionError on the first step.
        """
        if steps_per_epoch <= 0:
            raise ValueError(
                f"steps_per_epoch must be a positive integer, got {steps_per_epoch!r}"
            )
        self.steps_per_epoch = steps_per_epoch
        self.current_step = 0  # Number of steps seen by this callback
        self.current_epoch = 0  # Number of completed epochs

    def __call__(self, model, step):
        """
        This method is called at the end of each training step.

        Args:
            model: the model being trained (unused).
            step: the step number reported by the trainer (unused; the
                callback counts its own invocations instead).
        """
        self.current_step += 1
        # Every `steps_per_epoch`-th call closes an epoch.
        if self.current_step % self.steps_per_epoch == 0:
            self.current_epoch += 1
            print(f"Epoch {self.current_epoch} completed at step {self.current_step}.")

In [None]:
from tqdm import tqdm

class ProgressBarCallback:
    """Training callback that shows a tqdm progress bar over all training steps.

    The instance is meant to be passed in the ``callbacks`` list of
    ``model.fit`` and is then invoked as ``callback(model, step)`` after every
    training step. Call :meth:`close` once training has finished.
    """

    def __init__(self, steps_per_epoch, total_epochs):
        """
        Initialize the progress bar callback.

        Args:
            steps_per_epoch: number of training steps (batches) in one epoch.
            total_epochs: number of epochs the training will run for.

        Raises:
            ValueError: if either argument is not positive. A zero
                ``steps_per_epoch`` would otherwise fail with
                ZeroDivisionError on the first step.
        """
        if steps_per_epoch <= 0:
            raise ValueError(
                f"steps_per_epoch must be a positive integer, got {steps_per_epoch!r}"
            )
        if total_epochs <= 0:
            raise ValueError(
                f"total_epochs must be a positive integer, got {total_epochs!r}"
            )
        self.steps_per_epoch = steps_per_epoch
        self.total_epochs = total_epochs
        self.current_step = 0
        self.current_epoch = 0
        self.pbar = None  # Created lazily on the first step

    def __call__(self, model, step):
        """
        Update the progress bar at each step.

        Args:
            model: the model being trained (unused).
            step: the step number reported by the trainer (unused; the
                callback counts its own invocations instead).
        """
        if self.pbar is None:  # Initialize the progress bar at the start
            self.pbar = tqdm(total=self.steps_per_epoch * self.total_epochs, desc="Training Progress", unit="step")

        # Update the progress bar
        self.pbar.update(1)
        self.current_step += 1

        # Check if an epoch is completed
        if self.current_step % self.steps_per_epoch == 0:
            self.current_epoch += 1
            # tqdm.write prints the message without breaking the live bar,
            # which a plain print() would do.
            tqdm.write(f"Epoch {self.current_epoch}/{self.total_epochs} completed.")

    def close(self):
        """
        Close the progress bar when training is done.

        Safe to call before the first step; in that case nothing happens.
        """
        if self.pbar is not None:
            self.pbar.close()


In [None]:
X_train.shape[0]

In [None]:
dc.utils.logger.setLevel(logging.INFO)

# The batch size must be the same for the model and for the step bookkeeping
# below, otherwise the progress bar total and the epoch messages are wrong.
# Previously the value was never handed to the model, so training silently ran
# with DeepChem's default of 100 while the callbacks assumed 50. 100 is kept
# here so the training itself is unchanged; tune it in this one place.
batch_size = 100
epochs = 50

model = dc.models.MultitaskClassifier(
    n_tasks=y_train.shape[1],
    n_features=int(os.getenv('MAX_MASS')),
    layer_sizes=[1000, 1000, 1000], #3000, 1000, #3000, 2000
    dropouts=0.1,
    batch_size=batch_size,
    #activation_fns=['relu', 'relu', 'sigmoid'],
    #learning_rate=0.001,
)

# Ceiling division: the last, possibly smaller, batch of an epoch is a
# training step too and therefore triggers the callbacks as well.
steps_per_epoch = (X_train.shape[0] + batch_size - 1) // batch_size
epoch_callback = EpochCallback(steps_per_epoch)
progress_bar_callback = ProgressBarCallback(steps_per_epoch, epochs)

model.fit(train_dataset, nb_epoch=epochs, callbacks=[progress_bar_callback])

In [None]:
# # Initialize your DeepChem model
# model = dc.models.MultitaskClassifier(
#     n_tasks=y_train.shape[1],
#     n_features=int(os.getenv('MAX_MASS')),
#     batch_size=128  # Example batch size
# )

# # Wrap the model with DCLightningModule
# lit_model = DCLightningModule(model)

# # Prepare your dataset
# train_dataset_module = DCLightningDatasetModule(train_dataset, batch_size=128, collate_fn=collate_dataset_wrapper)

# # Initialize the PyTorch Lightning trainer with GPU settings
# trainer = pl.Trainer(max_epochs=10, devices=1, accelerator='gpu', strategy='ddp_notebook')  # Adjust devices as needed

# # Train the model
# trainer.fit(lit_model, train_dataset_module)

In [None]:
progress_bar_callback.close()

In [None]:
predictions = model.predict(test_dataset)

In [None]:
predictions

In [None]:
from statistics import mean, stdev

# NOT CORRECT YET!!
# THE [nop] SHOULD NOT BE INCLUDED IN THE EVALUATION !!!

cutoff = 0.5    # predicted >= 0.5 will turn bit=1

# Number of bits compared per molecule (read once instead of once per molecule).
n_bits = int(os.getenv("ENCODING_BITS")) * int(os.getenv("MAX_SELFIES_LENGTH"))

# predictions[q][i][1] is the score that is thresholded for bit i of molecule q.
pred_scores = np.asarray(predictions)
real_values = np.asarray(test_dataset.y)
if pred_scores.shape[1] < n_bits or real_values.shape[1] < n_bits:
    raise ValueError(
        f"ENCODING_BITS*MAX_SELFIES_LENGTH={n_bits} exceeds the number of "
        f"predicted/real bits ({pred_scores.shape[1]}/{real_values.shape[1]})"
    )

pred_on = pred_scores[:, :n_bits, 1] >= cutoff   # predicted fingerprint per molecule
real_on = real_values[:, :n_bits] == 1           # real fingerprint per molecule
real_off = real_values[:, :n_bits] == 0

# Per-molecule counts, converted to plain Python ints so the percentage
# arithmetic below is exactly the one used before.
tp = (real_on & pred_on).sum(axis=1).tolist()    # true positive (correct prediction)
fn = (real_on & ~pred_on).sum(axis=1).tolist()   # false negative (missed)
fp = (real_off & pred_on).sum(axis=1).tolist()   # false positive (not correct)
pr = pred_on.sum(axis=1).tolist()                # total number of predicted on-bits
n_real_on = real_on.sum(axis=1).tolist()         # number of real on-bits

epsilon = 10e-7  # avoids division by zero for molecules without on-bits

fp_p = [int(c / (e + epsilon) * 100) for c, e in zip(fp, pr)]          # false pos / predicted on-bits * 100%
tp_p = [int(a / (d + epsilon) * 100) for a, d in zip(tp, n_real_on)]   # true pos / real number on-bits * 100%


def _summarize(percentages):
    """Return (avg, sd, cv) of per-molecule percentages, truncated to ints.

    sd is 0 when fewer than two molecules are available (stdev is undefined
    there) and cv is 'n/a' when the average is 0 (division by zero).
    """
    average = int(mean(percentages))
    deviation = int(stdev(percentages)) if len(percentages) > 1 else 0
    variation = int(deviation / average * 100) if average else 'n/a'
    return average, deviation, variation


# % True positive average, stdev and cv% for all test molecules
avg, sd, cv = _summarize(tp_p)
print (f'BITWISE EVALUATION OF TEST_DATASET CONTAINING: {len(test_dataset)} MOLECULES')
print (f'--------------------------------------------------------------------')
print (f'TRUE POS:    AVG={avg}%    STDEV={sd}    CV%={cv}')

# % False positive average, stdev and cv% for all test molecules
avg, sd, cv = _summarize(fp_p)
print (f'FALSE POS:   AVG={avg}%    STDEV={sd}    CV%={cv}')

In [None]:
# Create the inverse dictionary (itos) that maps each value of
# selfies_group_dict back to its key. It is used below to turn the one-hot
# encoded arrays back into SELFIES (and from there SMILES) strings.
itos = {index: symbol for symbol, index in selfies_group_dict.items()}
itos

In [None]:
# Evaluation whole test set
# Decode every predicted one-hot array back into a SELFIES/SMILES string and
# check if predicted smiles == real smiles.

cutoff = 0.5
hit = 0
score=[]
to_print = []

columns = ['SMILES_ID', 'Original_SMILES', 'Predicted_SMILES', 'Original_SELFIES', 'Predicted_SELFIES', 'Fingerprint_SIMILARITY', 'SELFIES_SIMILARITY']

# Threshold all predictions once: predictions[m][t][1] >= cutoff turns bit t of molecule m on.
pred_bits = np.asarray(predictions)[:, :, 1] >= cutoff

n_symbols = len(itos)                           # size of the SELFIES alphabet (56 per the original notes)
n_positions = pred_bits.shape[1] // n_symbols   # symbol slots per molecule (5096 / 56 = 91 per the original notes)

# Collect the result rows in a list and build the DataFrame once at the end;
# growing a DataFrame row by row gets slower with every insertion.
records = []

for test_compound_id in tqdm(range(len(test_dataset)), desc='Evaluate all test molecules'):   # loop over all test molecules
    # Flat indices of the on-bits, in ascending order (slot by slot, symbol by
    # symbol). The remainder modulo the alphabet size is the symbol index.
    on_bits = np.flatnonzero(pred_bits[test_compound_id, :n_positions * n_symbols])
    predicted_selfies = ''.join(
        itos[symbol_id]
        for symbol_id in (on_bits % n_symbols).tolist()
        if itos[symbol_id] != '[nop]'   # [nop] symbols are not part of the molecule
    )
    predicted_smiles = sf.decoder(predicted_selfies)

    # real molecule
    # NOTE(review): .loc looks the row up by index label - this assumes
    # selfies_X_test is labelled 0..n-1 in the same order as test_dataset; confirm.
    real_selfies = selfies_X_test.loc[test_compound_id]
    real_smiles = sf.decoder(real_selfies)

    # Convert SMILES to RDKit molecule objects
    predicted_mol = Chem.MolFromSmiles(predicted_smiles)
    real_mol = Chem.MolFromSmiles(real_smiles)

    # Only proceed if both molecules are valid; invalid ones are skipped silently.
    if predicted_mol is None or real_mol is None:
        continue

    fingerprint_similarity = DataStructs.FingerprintSimilarity(
        Chem.RDKFingerprint(predicted_mol),
        Chem.RDKFingerprint(real_mol),
        metric=DataStructs.DiceSimilarity,
    )
    score.append(fingerprint_similarity)
    selfies_similarity = NNUtils.selfies_similarity(real_selfies, predicted_selfies)
    records.append([test_compound_id, real_smiles, predicted_smiles, real_selfies, predicted_selfies, fingerprint_similarity, selfies_similarity])

    # NOTE(review): this compares the decoded SMILES strings as text, not
    # canonical SMILES - confirm that this is the intended notion of a hit.
    if predicted_smiles == real_smiles:
        hit = hit + 1

# Keep the test compound id as the row label, as before.
res_df = pd.DataFrame(records, columns=columns, index=[record[0] for record in records])
res_df


In [None]:
res_df['SELFIES_SIMILARITY'].mean()

In [None]:
# Summarise the evaluation: exact SMILES matches and the number of molecules
# whose fingerprint similarity reaches a few thresholds.
# The scores were computed with DataStructs.DiceSimilarity, so they are
# reported as Dice (not Tanimoto) similarities.
score = np.asarray(score)
n_test = len(test_dataset.X)


def _percent_of_test_set(count):
    """Return count as a truncated integer percentage of the test set (0 if it is empty)."""
    return int(count / n_test * 100) if n_test else 0


sum_score_1 = int(np.count_nonzero(score >= 1))
sum_score_09 = int(np.count_nonzero(score >= 0.9))
sum_score_06 = int(np.count_nonzero(score >= 0.6))

print (f'Correct smiles predictions: {hit} (={_percent_of_test_set(hit)}%). Test set contains in total {n_test} compounds.')
print (f'Dice similarity >= 1.0: {sum_score_1} (={_percent_of_test_set(sum_score_1)}%). Test set contains in total {n_test} compounds.')
print (f'Dice similarity >= 0.9: {sum_score_09} (={_percent_of_test_set(sum_score_09)}%). Test set contains in total {n_test} compounds.')
print (f'Dice similarity >= 0.6: {sum_score_06} (={_percent_of_test_set(sum_score_06)}%). Test set contains in total {n_test} compounds.')