## Import Libraries ##

In [1]:
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.ML.Cluster import Butina
from typing import List, Tuple, Dict
from sklearn.model_selection import train_test_split


In [2]:
import os
import time
import copy
import warnings
warnings.filterwarnings('ignore')

# Abstract classes
from abc import ABC, abstractmethod

# Chemprop and DeepChem
import chemprop
from chemprop import data, featurizers, models
import deepchem as dc

from sklearn.utils import resample

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

# Hugging Face Transformers for ChemBERTa
from transformers import (
    RobertaTokenizer, RobertaModel, RobertaConfig, 
    AdamW, get_linear_schedule_with_warmup, BertModel
)

# Scikit-learn
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import roc_auc_score

# Pandas and NumPy
import pandas as pd
import numpy as np

# RDKit
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, Draw


No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
2024-10-21 14:14:57.721581: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-21 14:14:57.734281: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-21 14:14:57.738219: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-21 14:14:57.748942: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow w

In [3]:
def PRINT() -> None: print(f"{'-'*80}\nDone\n{'-'*80}")
def PRINTC() -> None: print(f"{'-'*80}")
def PRINTM(M) -> None: print(f"{'-'*80}\n{M}\n{'-'*80}")

## Verify GPU Availability ##

In [4]:
!nvidia-smi

Mon Oct 21 14:15:00 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX 6000 Ada Gene...    On  |   00000000:A1:00.0 Off |                  Off |
| 30%   36C    P8             17W /  300W |       2MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

For this task, we'll use the BGU cluster GPU `NVIDIA RTX 6000 Ada Generation` to achieve better performance during the training of our pre-trained and fine-tuned models, allowing for more efficient processing of large datasets and complex computations.

In [5]:
if torch.cuda.is_available():
    PRINTM(f"GPU is available.")
    device = "cuda"
else:
    PRINTM(f"GPU is not available. Using CPU instead.")
    device = "cpu"

--------------------------------------------------------------------------------
GPU is available.
--------------------------------------------------------------------------------


In [6]:
PRINTM(f"PyTorch version: {torch.__version__}")
PRINTM(f"CUDA available: {torch.cuda.is_available()}")
PRINTM(f"CUDA version:  {torch.version.cuda}")
print(f"CUDA device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No CUDA'}")

--------------------------------------------------------------------------------
PyTorch version: 2.3.1+cu121
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CUDA available: True
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CUDA version:  12.1
--------------------------------------------------------------------------------
CUDA device: NVIDIA RTX 6000 Ada Generation


In [7]:
def convert_uniprot_ids(dataset, mapping_df):
    # Create a dictionary from the mapping dataframe
    mapping_dict = mapping_df.set_index('From')['Entry'].to_dict()

    # Map the uniprot_id1 and uniprot_id2 columns to their respective Entry values
    dataset['uniprot_id1'] = dataset['uniprot_id1'].map(mapping_dict)
    dataset['uniprot_id2'] = dataset['uniprot_id2'].map(mapping_dict)
    return dataset.drop_duplicates()

In [8]:
class PretrainedChempropModel(nn.Module):
    def __init__(self, checkpoints_path, batch_size):
        super(PretrainedChempropModel, self).__init__()
        self.mpnn = self.load_pretrained_model(checkpoints_path)
        self.featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
        self.batch_size = batch_size
        
    def forward(self, smiles):
        # Prepare the data in order to generate embeddings from modulators SMILES
        self.smiles_data = [data.MoleculeDatapoint.from_smi(smi) for smi in smiles]
        self.smiles_dset = data.MoleculeDataset(self.smiles_data, featurizer=self.featurizer)
        self.smiles_loader = data.build_dataloader(self.smiles_dset, batch_size=batch_size, shuffle=False)
        
        embeddings = [
            # Etract the embedding from the last FFN layer, i.e., before the final prediction(thus i=-1)
            self.mpnn.predictor.encode(self.fingerprints_from_batch_molecular_graph(batch, self.mpnn), i=-1) 
            for batch in self.smiles_loader
        ]
        #print(embeddings)
        if not embeddings:
             return torch.empty(0, device=device)
        embeddings = torch.cat(embeddings, 0).to(device)
        return embeddings

    def fingerprints_from_batch_molecular_graph(self, batch, mpnn):
        batch.bmg.to(device)
        H_v = mpnn.message_passing(batch.bmg, batch.V_d)
        H = mpnn.agg(H_v, batch.bmg.batch)
        H = mpnn.bn(H)
        fingerprints = H if batch.X_d is None else torch.cat((H, mpnn.batch.X_d_transform(X_d)), 1)
        return fingerprints

    def load_pretrained_model(self, checkpoints_path):
        mpnn = models.MPNN.load_from_checkpoint(checkpoints_path).to(device)
        return mpnn

In [9]:
class ChemBERTaPT(nn.Module):
    def __init__(self):
        super(ChemBERTaPT, self).__init__()
        self.model_name = "DeepChem/ChemBERTa-77M-MTR"
        self.chemberta = RobertaModel.from_pretrained(self.model_name)

    def forward(self, input_ids, attention_mask):
        bert_output = self.chemberta(input_ids=input_ids, attention_mask=attention_mask)
        return bert_output[1]

In [None]:
import torch

# Create a tensor 'a' with specified values for testing
# First embedding: [1, 1, 1, 1] repeated 16 times (to make 64 samples of size 4)
# Second embedding: [0, 0, 0, 0] repeated 16 times
first_embedding = torch.tensor([[1, 1, 1, 1]] * 16)  # shape: [16, 4]
second_embedding = torch.tensor([[0, 0, 0, 0]] * 16)  # shape: [16, 4]

# Concatenate both embeddings along the second dimension
a = torch.cat((first_embedding, second_embedding), dim=1)  # shape: [16, 8]
print(f'Vector "a" shape -> {a.shape}')

# Split the tensor into two parts along the second dimension
a1, a2 = torch.split(a, 4, dim=1)  # shape of each: [16, 4]

print(f'Vector "a1" and "a2" shape -> {a1.shape}')

# Print the results
a

In [10]:
def compute_fingerprints(unique_smiles_list: List[str], radius: int = 2, n_bits: int = 1024) -> Dict[str, AllChem.rdchem.Mol]:
    """Compute Morgan fingerprints for a list of unique SMILES strings."""
    fingerprints = {}
    for smi in unique_smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
        else:
            # Handle invalid SMILES strings by adding a zeroed fingerprint
            zero_mol = Chem.MolFromSmiles('')
            fp = AllChem.GetMorganFingerprintAsBitVect(zero_mol, radius, nBits=n_bits)
        fingerprints[smi] = fp
    return fingerprints

def cluster_fingerprints(fingerprints: Dict[str, AllChem.rdchem.Mol], cutoff: float = 0.6) -> List[Tuple[int]]:
    """Cluster fingerprints using the Butina algorithm."""
    fps = list(fingerprints.values())
    n_fps = len(fps)
    dists = []
    for i in range(1, n_fps):
        sims = AllChem.DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
        dists.extend([1 - x for x in sims])
    clusters = Butina.ClusterData(dists, n_fps, cutoff, isDistData=True)
    return clusters

def custom_butina_splitter(
    df: pd.DataFrame,
    clusters: List[Tuple[int]],
    unique_smiles_list: List[str],
    smiles_col: str = 'smiles',
    frac_train: float = 0.78,
    frac_test: float = 0.22,
    random_state: int = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Perform a custom Butina split into train and test sets with randomness."""
    # Set the random seed for reproducibility
    rng = np.random.default_rng(random_state)

    # Shuffle clusters to introduce randomness
    cluster_indices = list(range(len(clusters)))
    rng.shuffle(cluster_indices)

    # Desired split sizes based on molecule counts
    total_molecules = len(df)
    desired_train_size = int(frac_train * total_molecules)
    desired_test_size = total_molecules - desired_train_size  # Ensure total sums up

    # Initialize cumulative counts and splits
    train_smiles = set()
    test_smiles = set()
    train_count = 0
    test_count = 0

    # Iterate over clusters and assign to splits
    for idx in cluster_indices:
        cluster = clusters[idx]
        cluster_smiles = [unique_smiles_list[i] for i in cluster]

        # Count the number of molecules in the cluster (including duplicates in df)
        cluster_molecule_count = df[df[smiles_col].isin(cluster_smiles)].shape[0]

        # Decide where to assign the cluster
        # Compute remaining needed molecules for each split
        remaining_train = desired_train_size - train_count
        remaining_test = desired_test_size - test_count

        # Prepare a dictionary of remaining counts
        remaining = {'train': remaining_train, 'test': remaining_test}

        # Assign to the split with the largest remaining count
        possible_splits = {k: v for k, v in remaining.items() if v > 0}

        if not possible_splits:
            # All splits have reached or exceeded desired sizes
            # Assign to the split with the least total molecules
            total_counts = {'train': train_count, 'test': test_count}
            split = min(total_counts, key=total_counts.get)
        else:
            # Assign to the split needing the most molecules
            split = max(possible_splits, key=possible_splits.get)

        # Assign the cluster to the chosen split
        if split == 'train':
            train_smiles.update(cluster_smiles)
            train_count += cluster_molecule_count
        else:
            test_smiles.update(cluster_smiles)
            test_count += cluster_molecule_count

    # Use isin to get the splits from the original dataframe
    train = df[df[smiles_col].isin(train_smiles)].copy()
    test_combined = df[df[smiles_col].isin(test_smiles)].copy()

    return train, test_combined



In [11]:
class CustomButinaSplitter:
    def __init__(self, smiles_col='smiles', label_col='label', cutoff=0.6):
        self.smiles_col = smiles_col
        self.label_col = label_col
        self.cutoff = cutoff

    def split_dataset(self, df: pd.DataFrame) -> List[Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]]:
        unique_smiles = df[self.smiles_col].drop_duplicates().reset_index(drop=True)
        print(f"Number of unique SMILES: {len(unique_smiles)}")

        # Compute fingerprints for unique SMILES (only once)
        unique_smiles_list = unique_smiles.tolist()
        fingerprints_dict = self._compute_fingerprints(unique_smiles_list)

        # Cluster the fingerprints (only once)
        clusters = self._cluster_fingerprints(fingerprints_dict, cutoff=self.cutoff)

        # Prepare for splitting
        splits = []
        for random_seed in [11, 7, 8, 1, 4]:
            print(f"Random Seed: {random_seed}")

            # Perform initial train/test split using Butina clustering
            train, test_combined = self._custom_butina_splitter(
                df,
                clusters,
                unique_smiles_list,
                smiles_col=self.smiles_col,
                frac_train=0.78,
                frac_test=0.22,
                random_state=random_seed
            )

            # Now split test_combined into validation and test sets using stratified split
            labels = test_combined[self.label_col]

            # Split test_combined into validation and test sets (each 50% of test_combined)
            valid, test = train_test_split(
                test_combined,
                test_size=0.5,
                random_state=random_seed,  # Use the same random seed for reproducibility
                stratify=labels
            )

            splits.append((train, valid, test))

            # Verify that the splits do not overlap
            train_indices = set(train.index)
            valid_indices = set(valid.index)
            test_indices = set(test.index)
            assert len(train_indices & valid_indices) == 0
            assert len(train_indices & test_indices) == 0
            assert len(valid_indices & test_indices) == 0

            # Calculate actual split sizes
            total_molecules = len(df)
            actual_train_frac = len(train) / total_molecules
            actual_valid_frac = len(valid) / total_molecules
            actual_test_frac = len(test) / total_molecules

            print(f"Train size: {len(train)} ({actual_train_frac:.2%}), "
                  f"Valid size: {len(valid)} ({actual_valid_frac:.2%}), "
                  f"Test size: {len(test)} ({actual_test_frac:.2%})")

        return splits

    def _compute_fingerprints(self, unique_smiles_list: List[str], radius: int = 2, n_bits: int = 1024) -> Dict[str, AllChem.rdchem.Mol]:
        """Private method to compute Morgan fingerprints for a list of unique SMILES strings."""
        fingerprints = {}
        for smi in unique_smiles_list:
            mol = Chem.MolFromSmiles(smi)
            if mol:
                fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
            else:
                # Handle invalid SMILES strings by adding a zeroed fingerprint
                zero_mol = Chem.MolFromSmiles('')
                fp = AllChem.GetMorganFingerprintAsBitVect(zero_mol, radius, nBits=n_bits)
            fingerprints[smi] = fp
        return fingerprints

    def _cluster_fingerprints(self, fingerprints: Dict[str, AllChem.rdchem.Mol], cutoff: float) -> List[Tuple[int]]:
        """Private method to cluster fingerprints using the Butina algorithm."""
        fps = list(fingerprints.values())
        n_fps = len(fps)
        dists = []
        for i in range(1, n_fps):
            sims = AllChem.DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
            dists.extend([1 - x for x in sims])
        clusters = Butina.ClusterData(dists, n_fps, cutoff, isDistData=True)
        return clusters

    def _custom_butina_splitter(
        self,
        df: pd.DataFrame,
        clusters: List[Tuple[int]],
        unique_smiles_list: List[str],
        smiles_col: str = 'smiles',
        frac_train: float = 0.78,
        frac_test: float = 0.22,
        random_state: int = None
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Private method to perform a custom Butina split into train and test sets."""
        rng = np.random.default_rng(random_state)

        # Shuffle clusters to introduce randomness
        cluster_indices = list(range(len(clusters)))
        rng.shuffle(cluster_indices)

        total_molecules = len(df)
        desired_train_size = int(frac_train * total_molecules)
        desired_test_size = total_molecules - desired_train_size

        train_smiles = set()
        test_smiles = set()
        train_count = 0
        test_count = 0

        # Iterate over clusters and assign to splits
        for idx in cluster_indices:
            cluster = clusters[idx]
            cluster_smiles = [unique_smiles_list[i] for i in cluster]
            cluster_molecule_count = df[df[smiles_col].isin(cluster_smiles)].shape[0]

            remaining_train = desired_train_size - train_count
            remaining_test = desired_test_size - test_count

            remaining = {'train': remaining_train, 'test': remaining_test}
            possible_splits = {k: v for k, v in remaining.items() if v > 0}

            if not possible_splits:
                total_counts = {'train': train_count, 'test': test_count}
                split = min(total_counts, key=total_counts.get)
            else:
                split = max(possible_splits, key=possible_splits.get)

            if split == 'train':
                train_smiles.update(cluster_smiles)
                train_count += cluster_molecule_count
            else:
                test_smiles.update(cluster_smiles)
                test_count += cluster_molecule_count

        train = df[df[smiles_col].isin(train_smiles)].copy()
        test_combined = df[df[smiles_col].isin(test_smiles)].copy()

        return train, test_combined


In [27]:
# 12/10 - for ablation table
class AbstractModel(ABC, nn.Module):
    def __init__(self):
        super(AbstractModel, self).__init__()
        self.early_stopping_patience = 5
        self.delta = 0.001
        self.butinaSplitter = CustomButinaSplitter()

    @abstractmethod
    def forward(self, cpe, esm, fegs, gae, cbae, morgan_fingerprints, chemical_descriptors):
        pass
          
    def train_val_test_model(self, dataset, num_epochs, optimizer, criterion, 
                        batch_size=32, device='cuda', num_workers=5):
            
            best_val_auc = float('-inf')
            no_improve_epochs = 0
            best_model_state_dict = None
            splits = []
            smiles_df = dataset[["smiles"]].drop_duplicates()
            smiles_col = 'smiles'
            
            # split dataset into 5 folds of (train, val, test) dataframes using custom butina splitter obj.
            splits = self.butinaSplitter.split_dataset(dataset)
                    
            for fold_number, (train_subset, val_subset, test_subset) in enumerate(splits, 1):
                print(f"fold number {fold_number}")
                PRINTC()
                train_df, val_df, test_df = train_subset, val_subset, test_subset
                test_dataset = MoleculeDataset(test_df)
                test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
                
                bootstrap_valid_aucs = []
                bootstrap_test_aucs = []
                                
                for bootsrap in range(5):
                    print(f"bootsrap number: {bootsrap + 1}")
                    PRINTC()
                    seed_train = fold_number*1000 + bootsrap + 1
                    labels_list = train_df['label'].values
                    train_b = resample(train_df, random_state=seed_train, stratify=labels_list)

                    train_dataset = MoleculeDataset(train_b)                    
                    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
                    val_dataset = MoleculeDataset(val_df)                    
                    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

                    last_test_auc = 0  # Initialize the last test AUC for this fold
                    for epoch in range(num_epochs):
                        start_time = time.time()
                        self.train()
                        epoch_loss = 0
                        all_preds = []
                        all_labels = []
                        running_loss = 0.0
                        for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                             batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in train_loader:
                            
                            batch_chemprop_features = batch_chemprop_features.to(device)
                            batch_chemberta_features = batch_chemberta_features.to(device)
                            batch_esm_features = batch_esm_features.to(device)
                            batch_fegs_features = batch_fegs_features.to(device)
                            batch_gae_features = batch_gae_features.to(device)
                            batch_morgan = batch_morgan.to(device)
                            batch_chem_desc = batch_chem_desc.to(device)
                            batch_labels = batch_labels.to(device)
                                
                            optimizer.zero_grad()
                            outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                                           batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc) 
                
                            loss = criterion(outputs.squeeze(), batch_labels)
                            loss.backward()
                            optimizer.step()
                            epoch_loss += loss.item()
                
                            all_labels.extend(batch_labels.cpu().numpy())
                            all_preds.extend(outputs.squeeze().detach().cpu().numpy())
                            
                        train_auc = roc_auc_score(all_labels, all_preds)
                                
                        # Evaluate the model on the validation set
                        all_val_labels = []
                        all_val_outputs = []
                        self.eval()
                        with torch.no_grad():
                            for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                                 batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in val_loader:
                                
                                batch_chemprop_features = batch_chemprop_features.to(device)
                                batch_chemberta_features = batch_chemberta_features.to(device)
                                batch_esm_features = batch_esm_features.to(device)
                                batch_fegs_features = batch_fegs_features.to(device)
                                batch_gae_features = batch_gae_features.to(device)
                                batch_morgan = batch_morgan.to(device)
                                batch_chem_desc = batch_chem_desc.to(device)
                                batch_labels = batch_labels.to(device)
                                                
                                outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                                               batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)               
                                
                                all_val_labels.extend(batch_labels.cpu().numpy())
                                all_val_outputs.extend(outputs.squeeze().detach().cpu().numpy())
                                
                        all_val_labels = np.array(all_val_labels)
                        all_val_outputs = np.array(all_val_outputs)
                        
                        # Perform bootstrapping on predictions and labels (validation phase)
                        current_b_aucs = []
                        N_test = all_val_labels.shape[0]
                        for b in range(1000):
                            seed_value = epoch * 1000 + b + (bootsrap+1)*1000  # or any function of your parameters
                            np.random.seed(seed_value)
                            indices = np.random.randint(0, N_test, size=N_test)
                            y_valid_pred_b = all_val_outputs[indices]
                            y_valid_b = all_val_labels[indices]
                            valid_auc = roc_auc_score(y_valid_b, y_valid_pred_b)
                            current_b_aucs.append(valid_auc)
                
                        mean_val_auc = np.mean(current_b_aucs)
                        end_time = time.time()
                        epoch_time = (end_time - start_time) / 60
                        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {(epoch_loss/len(train_loader)):.4f}, Train AUC: {train_auc:.4f}, Mean Validation AUC: {mean_val_auc:.4f}, Epoch Time: {epoch_time:.4f}')

                        # Early stopping logic
                        if mean_val_auc > best_val_auc:
                            best_val_auc = mean_val_auc
                            epochs_without_improvement = 0
                            # Save the best model state dict
                            best_model_state_dict = copy.deepcopy(model.state_dict())
                        else:
                            epochs_without_improvement += 1
            
                        if epochs_without_improvement >= self.early_stopping_patience:
                            print("Early stopping triggered")
                            break
                            
                    # Load the best model in order to evaluate it on the test set
                    self.load_state_dict(best_model_state_dict)
                    
                    all_test_labels = []
                    all_test_outputs = []
                    self.eval()
                    with torch.no_grad():
                        for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                                batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in val_loader:
                                
                            batch_chemprop_features = batch_chemprop_features.to(device)
                            batch_chemberta_features = batch_chemberta_features.to(device)
                            batch_esm_features = batch_esm_features.to(device)
                            batch_fegs_features = batch_fegs_features.to(device)
                            batch_gae_features = batch_gae_features.to(device)
                            batch_morgan = batch_morgan.to(device)
                            batch_chem_desc = batch_chem_desc.to(device)
                            batch_labels = batch_labels.to(device)
                                                
                            outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                                           batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)               
                                
                            all_test_labels.extend(batch_labels.cpu().numpy())
                            all_test_outputs.extend(outputs.squeeze().detach().cpu().numpy())
                                
                    all_test_labels = np.array(all_test_labels)
                    all_test_outputs = np.array(all_test_outputs)                 

                    # Perform bootstrapping on predictions and labels (test phase)
                    current_b_aucs = []
                    N_test = all_test_labels.shape[0]
                    for b in range(1000):
                        seed_value = epoch * 1000 + b + (bootsrap+1)*1000  # or any function of your parameters
                        np.random.seed(seed_value)
                        indices = np.random.randint(0, N_test, size=N_test)
                        y_test_pred_b = all_test_outputs[indices]
                        y_test_b = all_test_labels[indices]
                        test_auc = roc_auc_score(y_test_b, y_test_pred_b)
                        current_b_aucs.append(test_auc)
                
                    mean_test_auc = np.mean(current_b_aucs)
                    print(f'Bootstrap {bootsrap}, Mean Test AUC: {mean_test_auc:.4f}')

                    # Store the best validation and test AUCs for this bootstrap
                    bootstrap_valid_aucs.append(best_val_auc)
                    bootstrap_test_aucs.append(mean_test_auc)
                
                # Compute mean validation and test AUCs for the current fold
                current_fold_mean_valid_auc = np.mean(bootstrap_valid_aucs)
                current_fold_mean_test_auc = np.mean(bootstrap_test_aucs)
                print(f"Fold {fold_number} Mean Validation AUC: {current_fold_mean_valid_auc:.4f}")
                print(f"Fold {fold_number} Mean Test AUC: {current_fold_mean_test_auc:.4f}")
                all_folds_valid_aucs.append(current_fold_mean_valid_auc)
                all_folds_test_aucs.append(current_fold_mean_test_auc)
            
            PRINTC()               
            print("Final Mean Validation AUC across all folds:", np.mean(all_folds_valid_aucs))
            print("Validation AUCs for all folds:", all_folds_valid_aucs)
            print("Final Mean Test AUC across all folds:", np.mean(all_folds_test_aucs))
            print("Test AUCs for all folds:", all_folds_test_aucs)                                    

    def train_model(self, fold, num_epochs, dataset, optimizer, criterion, 
                    batch_size=32, device='cuda', num_workers=5):
        
        train_dataset = MoleculeDataset(dataset)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        print(f'Start training {fold} for {num_epochs} epochs !')
        for epoch in range(num_epochs):
            start_time = time.time()
            self.train()
            running_loss = 0.0
            for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in train_loader:
                # Move tensors to the configured device
                batch_chemprop_features = batch_chemprop_features.to(device)
                batch_chemberta_features = batch_chemberta_features.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()
                outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc) 
                loss = criterion(outputs.squeeze(), batch_labels)
                
                train_loss += loss.item()
                all_labels.extend(batch_labels.cpu().numpy())
                all_outputs.extend(outputs.squeeze().cpu().numpy())
        
                loss.backward()
                optimizer.step()

            # Calculate epoch loss & AUC
            train_loss /= len(train_loader)
            train_auc = roc_auc_score(all_labels, all_outputs)
            end_time = time.time()
            epoch_time = (end_time - start_time) / 60
            print(f"Epoch {epoch+1} Time: {epoch_time:.2f},Train Loss: {train_loss:.4f}, Train AUC: {train_auc:.4f}")


        
    def train_val_model(self, fold, num_epochs, dataset, optimizer, criterion, 
                    batch_size=32, device='cuda', num_workers=5):
        best_val_auc = float('-inf')
        no_improve_epochs = 0

        X = dataset.copy().drop(columns=['smiles'])
        ids = dataset['smiles']
        # extract labels for scaffold split (in order to make sure we got balance train & validation sets)
        labels = dataset['label'].values
        dc_dataset = dc.data.DiskDataset.from_numpy(X=X ,y=labels ,w=np.zeros(len(X)),ids=ids)
        splitter = dc.splits.ScaffoldSplitter()
        train_idx, val_idx, _ = splitter.split(dc_dataset, frac_train=0.8, frac_valid=0.2, frac_test=0)

        #train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.2, stratify=labels, shuffle=shuffle)
        
        train_subset = dataset.iloc[train_idx].reset_index(drop=True)
        val_subset = dataset.iloc[val_idx].reset_index(drop=True)

        train_dataset = MoleculeDataset(train_subset)
        val_dataset = MoleculeDataset(val_subset)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        for epoch in range(num_epochs):
            start_time = time.time()
            self.train()
            running_loss = 0.0
            for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in train_loader:
                batch_chemprop_features = batch_chemprop_features.to(device)
                batch_chemberta_features = batch_chemberta_features.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()
                outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)                 
                loss = criterion(outputs.squeeze(), batch_labels)
                running_loss += loss.item()
                loss.backward()
                optimizer.step()

            val_loss, val_accuracy, val_auc = self.validate_model(val_loader, criterion, device)
            end_time = time.time()
            epoch_time = (end_time - start_time) / 60

            print(f"Epoch {epoch+1} - Validation Loss: {val_loss:.5f}, "
                  f"Validation Accuracy: {val_accuracy:.2f}, Validation AUC: {val_auc:.5f}, Epoch Time: {epoch_time:.2f}")
            # Check whether val_auc > best_val_auc + delta
            if val_auc > best_val_auc + self.delta:
                best_val_auc = val_auc
                train_epoch = epoch+1
                no_improve_epochs = 0 
                print(f"Current best val_auc -> {val_auc:.5f}, at epoch {epoch+1}")
            else:
                no_improve_epochs += 1
                if no_improve_epochs >= self.early_stopping_patience:
                    print(f"Stopping early at epoch {epoch+1}")
                    break

        print(f'Train the model for -> {train_epoch}, best validation auc: {best_val_auc:.5f}')
                
    def test_model(self, test_dataset, criterion, batch_size, device, num_workers):
        test_dataset = MoleculeDataset(test_dataset)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        self.eval()

        test_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []

        with torch.no_grad():
            for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in test_loader:
                
                batch_chemprop_features = batch_chemprop_features.to(device)
                batch_chemberta_features = batch_chemberta_features.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)               
                loss = criterion(outputs.squeeze(), batch_labels)
                test_loss += loss.item()

                all_labels.extend(batch_labels.cpu().numpy())
                all_outputs.extend(outputs.squeeze().cpu().numpy())

                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()

        test_loss /= len(test_loader)
        accuracy = correct / total
        test_auc = roc_auc_score(all_labels, all_outputs)
        PRINTC()
        print(f"Test AUC: {test_auc:.4f}")
        PRINTC()
        return round(test_auc, 4)

    def validate_model(self, val_loader, criterion, device):
        self.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []

        with torch.no_grad():
            for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in val_loader:

                batch_chemprop_features = batch_chemprop_features.to(device)
                batch_chemberta_features = batch_chemberta_features.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)                
                loss = criterion(outputs.squeeze(), batch_labels)
                val_loss += loss.item()

                all_labels.extend(batch_labels.cpu().numpy())
                all_outputs.extend(outputs.squeeze().cpu().numpy())

                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()

        val_loss /= len(val_loader)
        accuracy = correct / total
        val_auc = roc_auc_score(all_labels, all_outputs)
        return val_loss, accuracy, val_auc


In [28]:
class DTF(nn.Module):
    def __init__(self, channels=128, r=4):
        super(DTF, self).__init__()
        inter_channels = int(channels // r)

        self.att1 = nn.Sequential(
            nn.Linear(channels, inter_channels),
            nn.BatchNorm1d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Linear(inter_channels, channels),
            nn.BatchNorm1d(channels)
        )

        self.att2 = nn.Sequential(
            nn.Linear(channels, inter_channels),
            nn.BatchNorm1d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Linear(inter_channels, channels),
            nn.BatchNorm1d(channels)
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, fd, fp):
        w1 = self.sigmoid(self.att1(fd + fp))
        fout1 = fd * w1 + fp * (1 - w1)

        w2 = self.sigmoid(self.att2(fout1))
        fout2 = fd * w2 + fp * (1 - w2)
        return fout2

In [29]:
# 20/10 - Final expirements - all features & without finetune and attention
class AUVG_PPI(AbstractModel):
    def __init__(self ,dropout):
        super(AUVG_PPI, self).__init__()
        self.dropout = dropout
        
        # PPI Features MLP layers: (esm, custom, fegs, gae)
        self.esm_mlp = nn.Sequential(
            nn.Linear(in_features=1280 + 1280 , out_features=1750),
            nn.ReLU(),
            nn.BatchNorm1d(1750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1750, out_features=1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1000, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )

        self.fegs_mlp = nn.Sequential(
            nn.Linear(in_features=578 + 578, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )        

        self.gae_mlp = nn.Sequential(
            nn.Linear(in_features=500 + 500, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )

        # MLP for ppi_features
        self.ppi_mlp = nn.Sequential(
            nn.Linear(in_features=256 * 3, out_features=512),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )
        
        self.fp_mlp = nn.Sequential(
            nn.Linear(in_features=1200, out_features=1024), 
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )

        # Morgan fingerprints & chemical descriptors MLP layers
        self.mfp_cd_mlp = nn.Sequential(
            nn.Linear(in_features=1024 + 194, out_features= 750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )

        self.chemberta_mlp = nn.Sequential(
            nn.Linear(in_features=384, out_features= 256)
        )

        # MLP for smiles_embeddings
        self.smiles_mlp = nn.Sequential(
            nn.Linear(in_features=256 * 3 , out_features= 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )

        self.additional_layers = nn.Sequential(
            nn.Linear(in_features=256 + 256, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=128, out_features=1),
        )
        
        #self.sigmoid = nn.Sigmoid()

    def forward(self, cpe, esm, fegs, gae, cbae, morgan_fingerprints, chemical_descriptors):
        # Forward pass batch mol graph through pretrained chemprop model in order to get fingerprints embeddings
        # Afterwards, pass the fingerprints through MLP layer
        cp_fingerprints = self.fp_mlp(cpe)
        cbae = self.chemberta_mlp(cbae)
        #chemberta_embeddings = self.chemberta_mlp(chemberta_embeddings)
        mfp_chem_descriptors = torch.cat([morgan_fingerprints, chemical_descriptors], dim=1)
        mfp_chem_descriptors = self.mfp_cd_mlp(mfp_chem_descriptors)
        
        # Concatenate all 3 smiles embeddings along a new dimension (3x384) 
        smiles_embeddings = torch.cat([cp_fingerprints, cbae, mfp_chem_descriptors], dim=1).to(device)  # shape ->> (batch_size, 3*384)
        smiles_embeddings = self.smiles_mlp(smiles_embeddings)

        # Pass all PPI features  through MLP layers, and then pass them all together into another MLP layer
        #ppi_features = proteins.to(device)
        esm_embeddings = self.esm_mlp(esm)
        fegs_embeddings = self.fegs_mlp(fegs)
        gae_embeddings = self.gae_mlp(gae)

        # Concatenate all 3 ppi embeddings along a new dimension (3x512) 
        ppi_embeddings = torch.cat([esm_embeddings, fegs_embeddings, gae_embeddings], dim=1).to(device)  # shape ->> (batch_size, 3*512)
        ppi_features = self.ppi_mlp(ppi_embeddings)

        combined_embeddings = torch.cat([smiles_embeddings, ppi_features], dim=1)
        output = self.additional_layers(combined_embeddings)
        
        return output

In [44]:
# 20/10 - Final expirements  - without finetune & with protein attention (from last papaer)
class AUVG_PPI(AbstractModel):
    def __init__(self ,dropout):
        super(AUVG_PPI, self).__init__()
        PRINTM('All features without finetune & with protein attention (from last papaer)')
        self.dropout = dropout
        self.esm_dtf = DTF(channels=1280, r=4)
        self.gae_dtf = DTF(channels=500, r=4)
        self.fegs_dtf = DTF(channels=578, r=2)
        
        # PPI Features MLP layers: (esm, custom, fegs, gae)
        self.esm_mlp = nn.Sequential(
            nn.Linear(in_features=1280 , out_features=750),
            nn.BatchNorm1d(750),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256)
        )

        self.fegs_mlp = nn.Sequential(
            nn.Linear(in_features=578 ,out_features=256),
            nn.BatchNorm1d(256)
        )        

        self.gae_mlp = nn.Sequential(
            nn.Linear(in_features=500, out_features=256),
            nn.BatchNorm1d(256)
        )

        # MLP for ppi_features
        self.ppi_mlp = nn.Sequential(
            nn.Linear(in_features=256 * 3, out_features=512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256)
        )
        
        self.fp_mlp = nn.Sequential(
            nn.Linear(in_features=1200, out_features=1024), 
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256)
        )

        # Morgan fingerprints & chemical descriptors MLP layers
        self.mfp_cd_mlp = nn.Sequential(
            nn.Linear(in_features=1024 + 194, out_features= 750),
            nn.BatchNorm1d(750),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256)
        )

        self.chemberta_mlp = nn.Sequential(
            nn.Linear(in_features=384, out_features= 256),
            nn.BatchNorm1d(256)
        )

        # MLP for smiles_embeddings
        self.smiles_mlp = nn.Sequential(
            nn.Linear(in_features=256 * 3 , out_features= 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256)
        )

        self.additional_layers = nn.Sequential(
            nn.Linear(in_features=256 + 256, out_features=256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=128, out_features=1),
        )
        
        #self.sigmoid = nn.Sigmoid()

    def forward(self, cpe, esm, fegs, gae, cbae, morgan_fingerprints, chemical_descriptors):
        # Forward pass batch mol graph through pretrained chemprop model in order to get fingerprints embeddings
        # Afterwards, pass the fingerprints through MLP layer
        cp_fingerprints = self.fp_mlp(cpe)
        cbae = self.chemberta_mlp(cbae)
        #chemberta_embeddings = self.chemberta_mlp(chemberta_embeddings)
        mfp_chem_descriptors = torch.cat([morgan_fingerprints, chemical_descriptors], dim=1)
        mfp_chem_descriptors = self.mfp_cd_mlp(mfp_chem_descriptors)
        
        # Concatenate all 3 smiles embeddings along a new dimension (3x384) 
        smiles_embeddings = torch.cat([cp_fingerprints, cbae, mfp_chem_descriptors], dim=1).to(device)  # shape ->> (batch_size, 3*384)
        smiles_embeddings = self.smiles_mlp(smiles_embeddings)

        # Pass all PPI features through their DTF module and then through MLP layer
        # in order to reduce shape to (batch_size, 256)
        esm_embedding_p1, esm_embedding_p2 = torch.split(esm, esm.shape[1] // 2, dim=1)
        esm_embeddings = self.esm_mlp(self.esm_dtf(esm_embedding_p1, esm_embedding_p2))
        
        fegs_embedding_p1, fegs_embedding_p2 = torch.split(fegs, fegs.shape[1] // 2, dim=1)
        fegs_embeddings = self.fegs_mlp(self.fegs_dtf(fegs_embedding_p1, fegs_embedding_p2))
        
        gae_embedding_p1, gae_embedding_p2 = torch.split(gae, gae.shape[1] // 2, dim=1)
        gae_embeddings = self.gae_mlp(self.gae_dtf(gae_embedding_p1, gae_embedding_p2))

        # Concatenate all 3 ppi embeddings along a new dimension (3x256) 
        ppi_embeddings = torch.cat([esm_embeddings, fegs_embeddings, gae_embeddings], dim=1).to(device)  # shape ->> (batch_size, 3*512)
        ppi_features = self.ppi_mlp(ppi_embeddings)

        combined_embeddings = torch.cat([smiles_embeddings, ppi_features], dim=1)
        output = self.additional_layers(combined_embeddings)
        
        return output

In [31]:
# For bootstrap and new training architecture - 18/10
class MoleculeDataset(Dataset):
    def __init__(self, ds_):
        self.data = ds_
        self.mapping_df = pd.read_csv(os.path.join('datasets', 'idmapping_unip.tsv'), delimiter = "\t")
        self.esm = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'esm_features.csv'))
        self.fegs = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'fegs_features.csv'))

        # In all predicted values, use zero vectors (after expirements that proved that)
        gae_path = f'GAE_FEATURES_WITH_PREDICTED_alpha_0.csv'            
        self.gae = pd.read_csv(os.path.join('datasets', 'GAE', gae_path))
        self.gae.loc[self.gae['predicted'] == 1, self.gae.columns[9:509]] = 0
        gae_features_columns = self.gae.iloc[:, 9:509]

        gae_uniprot_column = self.gae[['From']].rename(columns={'From': 'UniProt_ID'})
        self.gae = pd.concat([gae_uniprot_column, gae_features_columns], axis=1)
        self.gae_features_ppi = self.merge_datasets(self.data, self.gae).drop(columns=['smiles', 'label']).astype(np.float32)
        self.esm_features_ppi = self.merge_datasets(self.data, self.esm).drop(columns=['smiles', 'label']).astype(np.float32)
        self.fegs_features_ppi = self.merge_datasets(self.data, self.fegs).drop(columns=['smiles', 'label']).astype(np.float32)

         # SMILES RDKit features - Morgan Fingerprints (r=4, nbits=1024)  chemical descriptors, chemprop & chemBERTa
        self.smiles_morgan_fingerprints = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_morgan_fingerprints_dataset.csv'))
        self.smiles_chemical_descriptors = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_chem_descriptors_mapping_dataset.csv'))
        self.chemprop = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'chemprop_features.csv'))
        self.chemberta = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'chemBERTa_features.csv'))
        # Necessary features for ChemBERTa model
        self.smiles_list = self.data['smiles'].tolist()
        self.tokenizer = RobertaTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
        self.encoded_smiles = self.tokenizer(self.smiles_list, truncation=True, padding=True, return_tensors="pt")

    def merge_datasets(self, dataset, features_df):
        dataset = dataset.merge(features_df, how='left', left_on='uniprot_id1', right_on='UniProt_ID', suffixes=('', '_id1'))
        dataset = dataset.drop(columns=['UniProt_ID'])
        
        features_df_renamed = features_df.add_suffix('_id2')
        features_df_renamed = features_df_renamed.rename(columns={'UniProt_ID_id2': 'UniProt_ID'})
        dataset = dataset.merge(features_df_renamed, how='left', left_on='uniprot_id2', right_on='UniProt_ID', suffixes=('', '_id2'))
        dataset = dataset.drop(columns=['UniProt_ID', 'uniprot_id1', 'uniprot_id2'])
        
        # In order to avoid dropping duplicated rows that holds only zeros (in gae when there is zero vectors), which can be represents embeddings of ppi vector when
        # specifying to reset the rows to hold only zeros
        dataset['zero_count'] = (dataset == 0).any(axis=1).astype(int)
        count = 1
        for index in dataset.index:
            if dataset.at[index, 'zero_count'] == 1:
                dataset.at[index, 'zero_count'] = count
                count += 1
                
        # Fill null values with 0
        dataset.fillna(0, inplace=True)
        #dataset.drop_duplicates(inplace=True)

        return dataset.drop(columns=['zero_count'])

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        smiles = self.data.iloc[idx, 0]
        label = np.array(self.data.iloc[idx, -1], dtype=np.float32)  
        esm_features = np.array(self.esm_features_ppi.iloc[idx].values, dtype=np.float32)
        fegs_features = np.array(self.fegs_features_ppi.iloc[idx].values, dtype=np.float32)
        gae_features = np.array(self.gae_features_ppi.iloc[idx].values, dtype=np.float32)

        input_ids = self.encoded_smiles["input_ids"][idx]
        attention_mask = self.encoded_smiles["attention_mask"][idx]

        # Retrieve precomputed RDKit Morgan fingerprints
        morgan_fingerprint = self.smiles_morgan_fingerprints.loc[self.smiles_morgan_fingerprints['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        chemical_descriptors = self.smiles_chemical_descriptors.loc[self.smiles_chemical_descriptors['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        chemprop_features = self.chemprop.loc[self.smiles_chemical_descriptors['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        chemberta_features = self.chemberta.loc[self.smiles_chemical_descriptors['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        
        return (chemprop_features, esm_features, fegs_features, gae_features, 
                chemberta_features, morgan_fingerprint, chemical_descriptors, label)
        

In [None]:
# For ablation table 12/10
class MoleculeDataset(Dataset):
    def __init__(self, ds_):
        self.data = ds_
        self.mapping_df = pd.read_csv(os.path.join('datasets', 'idmapping_unip.tsv'), delimiter = "\t")
        self.esm = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'esm_features.csv'))
        self.fegs = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'fegs_features.csv'))

        # In all predicted values, use zero vectors (after expirements that proved that)
        gae_path = f'GAE_FEATURES_WITH_PREDICTED_alpha_0.csv'            
        self.gae = pd.read_csv(os.path.join('datasets', 'GAE', gae_path))
        self.gae.loc[self.gae['predicted'] == 1, self.gae.columns[9:509]] = 0
        gae_features_columns = self.gae.iloc[:, 9:509]

        gae_uniprot_column = self.gae[['From']].rename(columns={'From': 'UniProt_ID'})
        self.gae = pd.concat([gae_uniprot_column, gae_features_columns], axis=1)
        self.gae_features_ppi = self.merge_datasets(self.data, self.gae).drop(columns=['smiles', 'label']).astype(np.float32)
        self.esm_features_ppi = self.merge_datasets(self.data, self.esm).drop(columns=['smiles', 'label']).astype(np.float32)
        self.fegs_features_ppi = self.merge_datasets(self.data, self.fegs).drop(columns=['smiles', 'label']).astype(np.float32)
        
        # SMILES RDKit features - Morgan Fingerprints (r=4, nbits=1024) & chemical descriptors
        self.smiles_morgan_fingerprints = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_morgan_fingerprints_dataset.csv'))
        self.smiles_chemical_descriptors = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_chem_descriptors_mapping_dataset.csv'))

        # Necessary features for ChemBERTa model
        self.smiles_list = self.data['smiles'].tolist()
        self.tokenizer = RobertaTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
        self.encoded_smiles = self.tokenizer(self.smiles_list, truncation=True, padding=True, return_tensors="pt")

    def merge_datasets(self, dataset, features_df):
        dataset = dataset.merge(features_df, how='left', left_on='uniprot_id1', right_on='UniProt_ID', suffixes=('', '_id1'))
        dataset = dataset.drop(columns=['UniProt_ID'])
        
        features_df_renamed = features_df.add_suffix('_id2')
        features_df_renamed = features_df_renamed.rename(columns={'UniProt_ID_id2': 'UniProt_ID'})
        dataset = dataset.merge(features_df_renamed, how='left', left_on='uniprot_id2', right_on='UniProt_ID', suffixes=('', '_id2'))
        dataset = dataset.drop(columns=['UniProt_ID', 'uniprot_id1', 'uniprot_id2'])
        
        # In order to avoid dropping duplicated rows that holds only zeros (in gae when there is zero vectors), which can be represents embeddings of ppi vector when
        # specifying to reset the rows to hold only zeros
        dataset['zero_count'] = (dataset == 0).any(axis=1).astype(int)
        count = 1
        for index in dataset.index:
            if dataset.at[index, 'zero_count'] == 1:
                dataset.at[index, 'zero_count'] = count
                count += 1
                
        # Fill null values with 0
        dataset.fillna(0, inplace=True)
        dataset.drop_duplicates(inplace=True)

        return dataset.drop(columns=['zero_count'])

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        smiles = self.data.iloc[idx, 0]
        label = np.array(self.data.iloc[idx, -1], dtype=np.float32)  
        esm_features = np.array(self.esm_features_ppi.iloc[idx].values, dtype=np.float32)
        fegs_features = np.array(self.fegs_features_ppi.iloc[idx].values, dtype=np.float32)
        gae_features = np.array(self.gae_features_ppi.iloc[idx].values, dtype=np.float32)

        input_ids = self.encoded_smiles["input_ids"][idx]
        attention_mask = self.encoded_smiles["attention_mask"][idx]

        # Retrieve precomputed RDKit Morgan fingerprints
        morgan_fingerprint = self.smiles_morgan_fingerprints.loc[self.smiles_morgan_fingerprints['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        chemical_descriptors = self.smiles_chemical_descriptors.loc[self.smiles_chemical_descriptors['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        
        return (smiles, esm_features, fegs_features, gae_features, 
                input_ids, attention_mask, morgan_fingerprint, chemical_descriptors, label)


## Train and Validate ##

Next we'll train the model on 80% of the data and validate it on 20% on the data in order to decide the exact number of epochs to train each fold in the training process using scaffold split.

### Load & Prepare the Dataset ###

In [16]:
ds_folder_path = os.path.join('datasets', 'test_dataset', 'train_test_5_0.75')
all_files = os.listdir(ds_folder_path)

PRINTM(f'Folder content:\n\n{all_files}')

--------------------------------------------------------------------------------
Folder content:

['train_fold5_5_0.75.csv', 'train_fold4_5_0.75.csv', 'train_fold3_5_0.75.csv', 'train_fold2_5_0.75.csv', 'train_fold1_5_0.75.csv', 'test_fold5_5_0.75.csv', 'test_fold4_5_0.75.csv', 'test_fold3_5_0.75.csv', 'test_fold2_5_0.75.csv', 'test_fold1_5_0.75.csv']
--------------------------------------------------------------------------------


In [17]:
dataframes = {}

# Read each CSV file into a dataframe and store it in the dictionary
for file in all_files:
    file_path = os.path.join(ds_folder_path, file)
    df = pd.read_csv(file_path)
    df_name = file.replace('_5_0.75.csv', '_df')
    dataframes[df_name] = df

In [18]:
uniprot_mapping = pd.read_csv(os.path.join('datasets', 'idmapping_unip.tsv'), delimiter = "\t")

In [19]:
for df_name in dataframes.keys():
    dataframes[df_name] = convert_uniprot_ids(dataframes[df_name], uniprot_mapping)
    #dataframes[df_name] = merge_datasets(dataframes[df_name], ppi_features_df)

# Access each dataframe using its name
train_fold1_df = dataframes['train_fold1_df']
train_fold2_df = dataframes['train_fold2_df']
train_fold3_df = dataframes['train_fold3_df']
train_fold4_df = dataframes['train_fold4_df']
train_fold5_df = dataframes['train_fold5_df']
test_fold1_df = dataframes['test_fold1_df']
test_fold2_df = dataframes['test_fold2_df']
test_fold3_df = dataframes['test_fold3_df']
test_fold4_df = dataframes['test_fold4_df']
test_fold5_df = dataframes['test_fold5_df']

PRINTM(f'Done inverse mapping & merging successfully !')

--------------------------------------------------------------------------------
Done inverse mapping & merging successfully !
--------------------------------------------------------------------------------


In [20]:
train_fold1_df.head(8)

Unnamed: 0,smiles,uniprot_id1,uniprot_id2,label
0,CC1(C)CCC(c2ccc(Cl)cc2)=C(CN2CCN(c3ccc(C(=O)NS...,P06756,P18084,0
1,CC[C@H](C)[C@H](NC(=O)[C@@H]1CCCN1C(=O)CNC(=O)...,Q96SW2,Q9UBN7,0
2,CCOc1ccc(C(C)=O)cc1Nc1cc(C(=O)OC)cc(C(=O)OC)c1,O00255,Q03164,0
3,Cc1ccc(NC(=O)CCCCCCCSc2nnc(Cc3nn(C)c(=O)c4cccc...,Q07820,Q07812,0
4,O=C1C(Cl)=C(NCCN2CCOCC2)C(=O)c2ccccc21,P10415,Q07812,0
5,C[C@H]1CC[C@@H](Oc2cccc(Sc3ccc(/C=C/C(=O)N4CCO...,P62942,,0
6,O=C1[C@H](c2ccc(Cl)cc2)N(Cc2ccc(Cl)cc2F)C(=O)c...,Q07817,Q92934,0
7,CC[C@H](N)C(=O)N[C@@H]1C(=O)N2[C@H](CC[C@H]2C(...,O00255,Q03164,0


In [21]:
def generate_model(batch_size,
                  dropout) -> nn.Module:
    checkpoints_path = os.path.join('pt_chemprop_checkpoint_r4_', 'fold_0', 'model_0', 'checkpoints', 'best-epoch=39-val_loss=0.39.ckpt')
    pretrained_chemprop_model = PretrainedChempropModel(checkpoints_path, batch_size)
    chemberta_model = ChemBERTaPT()
    model = AUVG_PPI(pretrained_chemprop_model, chemberta_model, dropout).to(device)

    PRINTM('Generated model successfully !')
    return model

In [22]:
def generate_model(batch_size,
                  dropout) -> nn.Module:
    model = AUVG_PPI(dropout).to(device)

    PRINTM('Generated model successfully !')
    return model

In [23]:
batch_size = 64

In [24]:
final_dataset = pd.read_csv(r'datasets/test_dataset/final_dataset_5_0.75_25_09_2024_without_long_uncategorized_PPIs.csv')
final_dataset = convert_uniprot_ids(final_dataset, uniprot_mapping)

final_dataset

Unnamed: 0,smiles,uniprot_id1,uniprot_id2,label
0,Clc1ccc(C(c2ccc(Cl)cc2)[n+]2ccn(CC(OCc3ccc(Cl)...,Q07817,Q16611,0
1,CCCCN(CCCC)C(=O)c1nn(c(C)c1Cl)c2ccc(cc2C(=O)N3...,O00255,Q03164,0
2,CCN(C)c1c(N[C@@H](Cc2ccc(NC(=O)c3c(Cl)cncc3Cl)...,P98170,Q9NR28,0
3,N=C(N)NCCCC(=O)N1CCC[C@@H]1C(=O)NCCC(=O)O,P25440,Q07912,0
4,CC[C@@H]1CC[C@@H](C(=O)N2CCN(C)CC2)N1C(=O)C1=C...,P04637,Q00987,1
...,...,...,...,...
80215,CCOc1cc2c(cc1OCC)C(c1ccc(Cl)cc1)N(c1ccc(Cl)cc1...,P04637,Q00987,1
80216,C[C@H](CNC(=O)c1c(O)c(O)cc2c(O)c(c(C)cc12)c3c(...,P06756,P05106,0
80217,O=C(NCc1ccccc1)c1cc(Cl)c(O)c(S(=O)(=O)N(Cc2ccc...,O00255,Q03164,0
80218,CCOC(=O)CSc1ccc(C(=O)c2[nH]c(Cl)c(Cl)c2-n2c(C(...,Q07817,O43521,1


### Train, val & test the model ###

The next step is to decide on the model’s hyperparameters. We'll train the model on 5 custom splits, and each split will be trained 5 times, with a newly initialized model each time. As a result, each fold will be bootstrapped 5 times, meaning it will be trained and validated on the validation set, where we will save the best model state during each epoch. After training (for up to 100 epochs, or until early stopping is triggered), the model will be tested on the test set. Both validation and testing will be performed using bootstrapping (resampling) 1,000 times, after which the mean AUC will be calculated. We will use a custom Butina splitter object to divide our data. In total, there will be 5 x 5 = 25 experiments: 5 folds, with each fold being trained 5 times.

In [40]:
def train_val_test_model(dataset, num_epochs, dropout, lr, weight_decay, criterion, 
                        batch_size=64, device='cuda', num_workers=5):
            all_folds_valid_aucs = []
            all_folds_test_aucs = []
            splits = []
            smiles_df = dataset[["smiles"]].drop_duplicates()
            smiles_col = 'smiles'
            
            # split dataset into 5 folds of (train, val, test) dataframes using custom butina splitter obj.
            butinaSplitter = CustomButinaSplitter()
            splits = butinaSplitter.split_dataset(dataset)
                    
            for fold_number, (train_subset, val_subset, test_subset) in enumerate(splits, 1):
                PRINTC()
                print(f"fold number {fold_number}")
                PRINTC()
                train_df, val_df, test_df = train_subset, val_subset, test_subset
                test_dataset = MoleculeDataset(test_df)
                test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
                
                bootstrap_valid_aucs = []
                bootstrap_test_aucs = []
                                
                for bootsrap in range(5):
                    best_val_auc = float('-inf')
                    no_improve_epochs = 0
                    early_stopping_patience = 5
                    best_model_state_dict = None
                    
                    print(f"bootsrap number: {bootsrap + 1}")
                    model = generate_model(batch_size=batch_size, dropout=dropout)
                    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
                    PRINTC()
                    
                    seed_train = fold_number*1000 + bootsrap + 1
                    labels_list = train_df['label'].values
                    train_b = resample(train_df, random_state=seed_train, stratify=labels_list)

                    train_dataset = MoleculeDataset(train_b)                    
                    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
                    val_dataset = MoleculeDataset(val_df)                    
                    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

                    last_test_auc = 0  # Initialize the last test AUC for this fold
                    for epoch in range(num_epochs):
                        start_time = time.time()
                        model.train()
                        epoch_loss = 0
                        all_preds = []
                        all_labels = []
                        running_loss = 0.0
                        for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                             batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in train_loader:

                            # Move all tensors to device
                            batches = [batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc, batch_labels]
                            batches = [batch.to(device) for batch in batches]
                            batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc, batch_labels = batches
                                
                            optimizer.zero_grad()
                            outputs = model(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                                           batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc) 
                
                            loss = criterion(outputs.squeeze(), batch_labels)
                            loss.backward()
                            optimizer.step()
                            epoch_loss += loss.item()
                
                            all_labels.extend(batch_labels.cpu().numpy())
                            all_preds.extend(outputs.squeeze().detach().cpu().numpy())
                            
                        train_auc = roc_auc_score(all_labels, all_preds)
                                
                        # Evaluate the model on the validation set
                        all_val_labels = []
                        all_val_outputs = []
                        model.eval()
                        with torch.no_grad():
                            for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                                 batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in val_loader:
                                
                                # Move all tensors to device
                                batches = [batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc, batch_labels]
                                batches = [batch.to(device) for batch in batches]
                                batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc, batch_labels = batches
                                                
                                outputs = model(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                                               batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)               
                                
                                all_val_labels.extend(batch_labels.cpu().numpy())
                                all_val_outputs.extend(outputs.squeeze().detach().cpu().numpy())
                                
                        all_val_labels = np.array(all_val_labels)
                        all_val_outputs = np.array(all_val_outputs)
                        
                        # Perform bootstrapping on predictions and labels (validation phase)
                        current_b_aucs = []
                        N_test = all_val_labels.shape[0]
                        for b in range(1000):
                            seed_value = epoch * 1000 + b + (bootsrap+1)*1000  # or any function of your parameters
                            np.random.seed(seed_value)
                            indices = np.random.randint(0, N_test, size=N_test)
                            y_valid_pred_b = all_val_outputs[indices]
                            y_valid_b = all_val_labels[indices]
                            valid_auc = roc_auc_score(y_valid_b, y_valid_pred_b)
                            current_b_aucs.append(valid_auc)
                
                        mean_val_auc = np.mean(current_b_aucs)
                        end_time = time.time()
                        epoch_time = (end_time - start_time) / 60
                        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {(epoch_loss/len(train_loader)):.4f}, Train AUC: {train_auc:.4f}, Mean Validation AUC: {mean_val_auc:.4f}, Epoch Time: {epoch_time:.4f}')

                        # Early stopping logic
                        if mean_val_auc > best_val_auc:
                            best_val_auc = mean_val_auc
                            epochs_without_improvement = 0
                            # Save the best model state dict
                            best_model_state_dict = copy.deepcopy(model.state_dict())
                        else:
                            epochs_without_improvement += 1
            
                        if epochs_without_improvement >= early_stopping_patience:
                            print("Early stopping triggered")
                            break
                            
                    # Load the best model in order to evaluate it on the test set
                    model.load_state_dict(best_model_state_dict)
                    
                    all_test_labels = []
                    all_test_outputs = []
                    model.eval()
                    with torch.no_grad():
                        for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                                batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in val_loader:
                                
                            # Move all tensors to device
                            batches = [batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc, batch_labels]
                            batches = [batch.to(device) for batch in batches]
                            batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc, batch_labels = batches
                                                
                            outputs = model(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                                           batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)               
                                
                            all_test_labels.extend(batch_labels.cpu().numpy())
                            all_test_outputs.extend(outputs.squeeze().detach().cpu().numpy())
                                
                    all_test_labels = np.array(all_test_labels)
                    all_test_outputs = np.array(all_test_outputs)                 

                    # Perform bootstrapping on predictions and labels (test phase)
                    current_b_aucs = []
                    N_test = all_test_labels.shape[0]
                    for b in range(1000):
                        seed_value = epoch * 1000 + b + (bootsrap+1)*1000  # or any function of your parameters
                        np.random.seed(seed_value)
                        indices = np.random.randint(0, N_test, size=N_test)
                        y_test_pred_b = all_test_outputs[indices]
                        y_test_b = all_test_labels[indices]
                        test_auc = roc_auc_score(y_test_b, y_test_pred_b)
                        current_b_aucs.append(test_auc)
                
                    mean_test_auc = np.mean(current_b_aucs)
                    print(f'Bootstrap {bootsrap}, Mean Test AUC: {mean_test_auc:.4f}')

                    # Store the best validation and test AUCs for this bootstrap
                    bootstrap_valid_aucs.append(best_val_auc)
                    bootstrap_test_aucs.append(mean_test_auc)
                
                # Compute mean validation and test AUCs for the current fold
                current_fold_mean_valid_auc = np.mean(bootstrap_valid_aucs)
                current_fold_mean_test_auc = np.mean(bootstrap_test_aucs)
                print(f"Fold {fold_number} Mean Validation AUC: {current_fold_mean_valid_auc:.4f}")
                print(f"Fold {fold_number} Mean Test AUC: {current_fold_mean_test_auc:.4f}")
                all_folds_valid_aucs.append(current_fold_mean_valid_auc)
                all_folds_test_aucs.append(current_fold_mean_test_auc)
            
            PRINTC()               
            print("Final Mean Validation AUC across all folds:", np.mean(all_folds_valid_aucs))
            print("Validation AUCs for all folds:", all_folds_valid_aucs)
            print("Final Mean Test AUC across all folds:", np.mean(all_folds_test_aucs))
            print("Test AUCs for all folds:", all_folds_test_aucs)                                    


In [45]:
train_val_test_model(dataset=final_dataset, num_epochs=100, dropout=0.3, lr=1e-5, weight_decay=1e-3,
                              criterion=nn.BCEWithLogitsLoss(), batch_size=64, device=device, num_workers=16)

Number of unique SMILES: 10720
Random Seed: 11
Train size: 61862 (77.17%), Valid size: 9148 (11.41%), Test size: 9149 (11.41%)
Random Seed: 7
Train size: 62524 (78.00%), Valid size: 8817 (11.00%), Test size: 8818 (11.00%)
Random Seed: 8
Train size: 62568 (78.05%), Valid size: 8795 (10.97%), Test size: 8796 (10.97%)
Random Seed: 1
Train size: 62576 (78.06%), Valid size: 8791 (10.97%), Test size: 8792 (10.97%)
Random Seed: 4
Train size: 62782 (78.32%), Valid size: 8688 (10.84%), Test size: 8689 (10.84%)
--------------------------------------------------------------------------------
fold number 1
--------------------------------------------------------------------------------
bootsrap number: 1
--------------------------------------------------------------------------------
All features without finetune & with protein attention (from last papaer)
--------------------------------------------------------------------------------
--------------------------------------------------------------

In [47]:
ls = [0.9564830945117251, 0.9487758212348123, 0.9604757489871449, 0.9519970364786623, 0.9399813714236013]
f'{np.var(ls):.15f}'

'0.000049146644058'

### Train & Validate the Model ###

#### Fold number 1 #####

In [None]:
model_f1 = generate_model(batch_size=64, dropout=0.3)
model_f1.train_val_model('fold1', num_epochs=50, dataset=train_fold1_df,
                              optimizer=optim.AdamW(model_f1.parameters(), lr=1e-5, weight_decay=1e-3),
                              criterion=nn.BCEWithLogitsLoss(),
                              batch_size=64, device=device, num_workers=16)

In [None]:
model_f1.test_model(test_fold1_df,
                    criterion= nn.BCEWithLogitsLoss() ,batch_size=64,
                    device=device, num_workers=16)

#### Fold number 2 #####

In [None]:
model_f2 = generate_model(batch_size=64, dropout=0.3)
model_f2.train_val_model('fold2', num_epochs=50, dataset=train_fold2_df,
                              optimizer=optim.AdamW(model_f2.parameters(), lr=1e-5, weight_decay=1e-3),
                              criterion=nn.BCEWithLogitsLoss(),                              
                              batch_size=64, device=device, num_workers=16)

#### Fold number 3 ####

In [None]:
model_f3 = generate_model(batch_size=64, dropout=0.3)
model_f3.train_val_model('fold3', num_epochs=50, dataset=train_fold3_df,
                              optimizer=optim.AdamW(model_f3.parameters(), lr=1e-5, weight_decay=1e-3),
                              criterion=nn.BCEWithLogitsLoss(),                              
                              batch_size=64, device=device, num_workers=16)

#### Folder number 4 ####

In [None]:
model_f4 = generate_model(batch_size=64, dropout=0.3)
model_f4.train_val_model('fold4', num_epochs=50, dataset=train_fold4_df,
                              optimizer=optim.AdamW(model_f4.parameters(), lr=1e-5, weight_decay=1e-3),
                              criterion=nn.BCEWithLogitsLoss(),                              
                              batch_size=64, device=device, num_workers=16)

#### Folder number 5 ####

In [None]:
model_f5 = generate_model(batch_size=64, dropout=0.3)
model_f5.train_val_model('fold5', num_epochs=50, dataset=train_fold5_df,
                              optimizer=optim.AdamW(model_f5.parameters(), lr=1e-5, weight_decay=1e-3),
                              criterion=nn.BCEWithLogitsLoss(),                              
                              batch_size=64, device=device, num_workers=16)

### TRAIN - Average AUC on 5 Expirements ###

In [None]:
def avg_expirements_auc(num_epochs_list, n):
    auc_dict = {f'exp{i+1}': [] for i in range(n)}
    for exp_num in range(1, 6):
        print(f"Starting Experiment {exp_num}")
        
        for fold_num in range(1, 6): 
            fold_name = f'fold{fold_num}'
            train_fold, test_fold = dataframes[f'train_{fold_name}_df'], dataframes[f'test_{fold_name}_df'] 
            num_epochs = num_epochs_list[fold_num - 1] 
            model = generate_model(batch_size=64, dropout=0.3)
            model.train_model(fold_name, num_epochs=num_epochs, dataset=train_fold,
                                          optimizer=optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-3),
                                          criterion=nn.BCEWithLogitsLoss(),
                                          batch_size=64, device=device, num_workers=16)
    
            auc = model.test_model(test_fold, 
                                      criterion=nn.BCEWithLogitsLoss(), batch_size=64, 
                                      device=device, num_workers=16)
    
            auc_dict[f'exp{exp_num}'].append(auc)
    return auc_dict

In [None]:
def train_expirements_auc(num_epochs, fold_num):
    fold_name = f'fold{fold_num}'
    train_fold, test_fold = dataframes[f'train_{fold_name}_df'], dataframes[f'test_{fold_name}_df'] 
    model = generate_model(batch_size=64, dropout=0.3)
    model.train_model(fold_name, num_epochs=num_epochs, dataset=train_fold,
                      optimizer=optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-3),
                      criterion=nn.BCEWithLogitsLoss(),
                      batch_size=64, device=device, num_workers=16)
    
    model.test_model(test_fold, 
                     criterion=nn.BCEWithLogitsLoss(), batch_size=64, 
                     device=device, num_workers=16)


In [None]:
# 12/10 - for ablation table - all features
class AUVG_PPI(AbstractModel):
    def __init__(self, pretrained_chemprop_model, chemberta_model, dropout):
        
        super(AUVG_PPI, self).__init__()
        self.pretrained_chemprop_model = pretrained_chemprop_model
        self.chemberta_model = chemberta_model
        self.dropout = dropout
        
        # PPI Features MLP layers: (esm, custom, fegs, gae)
        self.esm_mlp = nn.Sequential(
            nn.Linear(in_features=1280 + 1280 , out_features=1750),
            nn.ReLU(),
            nn.BatchNorm1d(1750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1750, out_features=1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1000, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        self.fegs_mlp = nn.Sequential(
            nn.Linear(in_features=578 + 578, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )        

        self.gae_mlp = nn.Sequential(
            nn.Linear(in_features=500 + 500, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        # MLP for ppi_features
        self.ppi_mlp = nn.Sequential(
            nn.Linear(in_features=512 * 3, out_features=1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512)
        )
        
        self.fp_mlp = nn.Sequential(
            nn.Linear(in_features=2100, out_features=1536),
            nn.ReLU(),
            nn.BatchNorm1d(1536),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1536, out_features=1024), 
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=384)
        )

        # Morgan fingerprints & chemical descriptors MLP layers
        self.mfp_cd_mlp = nn.Sequential(
            nn.Linear(in_features=1024 + 194, out_features= 750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=384)
        )

        # MLP for smiles_embeddings
        self.smiles_mlp = nn.Sequential(
            nn.Linear(in_features=384 * 3 , out_features= 750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        self.additional_layers = nn.Sequential(
            nn.Linear(in_features=512 + 512, out_features=768),
            nn.ReLU(),
            nn.BatchNorm1d(768),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=768, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=128, out_features=1),
        )
        
        #self.sigmoid = nn.Sigmoid()

    def forward(self, bmg, esm, fegs, gae,
                input_ids, attention_mask,
                morgan_fingerprints, chemical_descriptors):
        # Forward pass batch mol graph through pretrained chemprop model in order to get fingerprints embeddings
        # Afterwards, pass the fingerprints through MLP layer
        cp_fingerprints = self.pretrained_chemprop_model(bmg)
        cp_fingerprints = self.fp_mlp(cp_fingerprints)

        chemberta_embeddings = self.chemberta_model(input_ids, attention_mask)
        #chemberta_embeddings = self.chemberta_mlp(chemberta_embeddings)
        mfp_chem_descriptors = torch.cat([morgan_fingerprints, chemical_descriptors], dim=1)
        mfp_chem_descriptors = self.mfp_cd_mlp(mfp_chem_descriptors)
        
        # Concatenate all 3 smiles embeddings along a new dimension (3x384) 
        smiles_embeddings = torch.cat([cp_fingerprints,chemberta_embeddings, mfp_chem_descriptors], dim=1).to(device)  # shape ->> (batch_size, 3*384)
        smiles_embeddings = self.smiles_mlp(smiles_embeddings)

        # Pass all PPI features  through MLP layers, and then pass them all together into another MLP layer
        #ppi_features = proteins.to(device)
        esm_embeddings = self.esm_mlp(esm)
        fegs_embeddings = self.fegs_mlp(fegs)
        gae_embeddings = self.gae_mlp(gae)

        # Concatenate all 3 ppi embeddings along a new dimension (3x512) 
        ppi_embeddings = torch.cat([esm_embeddings, fegs_embeddings, gae_embeddings], dim=1).to(device)  # shape ->> (batch_size, 3*512)
        ppi_features = self.ppi_mlp(ppi_embeddings)

        combined_embeddings = torch.cat([smiles_embeddings, ppi_features], dim=1)
        output = self.additional_layers(combined_embeddings)
        
        return output

In [None]:
# 12/10 - for ablation table
class AbstractModel(ABC, nn.Module):
    def __init__(self):
        super(AbstractModel, self).__init__()
        self.early_stopping_patience = 5
        self.delta = 0.001

    @abstractmethod
    def forward(self, bmg, esm, fegs, gae,
                input_ids, attention_mask,
                morgan_fingerprints, chemical_descriptors):
        pass
        
    def train_model(self, fold, num_epochs, dataset, optimizer, criterion, 
                    batch_size=32, device='cuda', num_workers=5):
        
        train_dataset = MoleculeDataset(dataset)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        print(f'Start training {fold} for {num_epochs} epochs !')
        for epoch in range(num_epochs):
            start_time = time.time()
            self.train()
            running_loss = 0.0
            for (batch_smiles, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mask, batch_morgan, batch_chem_desc ,batch_labels) in train_loader:
                # Move tensors to the configured device
                batch_attention_mask = batch_attention_mask.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()
                outputs = self(batch_smiles, batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_input_ids, batch_attention_mask, batch_morgan, batch_chem_desc)               
                loss = criterion(outputs.squeeze(), batch_labels)
                running_loss += loss.item()
                loss.backward()
                optimizer.step()
            end_time = time.time()
            epoch_time = (end_time - start_time) / 60
            print(f"Epoch {epoch+1} Time: {epoch_time:.2f}")

        
    def train_val_model(self, fold, num_epochs, dataset, optimizer, criterion, 
                    batch_size=32, device='cuda', num_workers=5):
        best_val_auc = float('-inf')
        no_improve_epochs = 0

        X = dataset.copy().drop(columns=['smiles'])
        ids = dataset['smiles']
        # extract labels for scaffold split (in order to make sure we got balance train & validation sets)
        labels = dataset['label'].values
        dc_dataset = dc.data.DiskDataset.from_numpy(X=X ,y=labels ,w=np.zeros(len(X)),ids=ids)
        splitter = dc.splits.ScaffoldSplitter()
        train_idx, val_idx, _ = splitter.split(dc_dataset, frac_train=0.8, frac_valid=0.2, frac_test=0)

        #train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.2, stratify=labels, shuffle=shuffle)
        
        train_subset = dataset.iloc[train_idx].reset_index(drop=True)
        val_subset = dataset.iloc[val_idx].reset_index(drop=True)

        train_dataset = MoleculeDataset(train_subset)
        val_dataset = MoleculeDataset(val_subset)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        for epoch in range(num_epochs):
            start_time = time.time()
            self.train()
            running_loss = 0.0
            for (batch_smiles, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mask, batch_morgan, batch_chem_desc ,batch_labels) in train_loader:
                # Move tensors to the configured device
                batch_attention_mask = batch_attention_mask.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()
                outputs = self(batch_smiles, batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_input_ids, batch_attention_mask, batch_morgan, batch_chem_desc)                 
                loss = criterion(outputs.squeeze(), batch_labels)
                running_loss += loss.item()
                loss.backward()
                optimizer.step()

            val_loss, val_accuracy, val_auc = self.validate_model(val_loader, criterion, device)
            end_time = time.time()
            epoch_time = (end_time - start_time) / 60

            print(f"Epoch {epoch+1} - Validation Loss: {val_loss:.5f}, "
                  f"Validation Accuracy: {val_accuracy:.2f}, Validation AUC: {val_auc:.5f}, Epoch Time: {epoch_time:.2f}")
            # Check whether val_auc > best_val_auc + delta
            if val_auc > best_val_auc + self.delta:
                best_val_auc = val_auc
                train_epoch = epoch+1
                no_improve_epochs = 0 
                print(f"Current best val_auc -> {val_auc:.5f}, at epoch {epoch+1}")
            else:
                no_improve_epochs += 1
                if no_improve_epochs >= self.early_stopping_patience:
                    print(f"Stopping early at epoch {epoch+1}")
                    break

        print(f'Train the model for -> {train_epoch}, best validation auc: {best_val_auc:.5f}')
                
    def test_model(self, test_dataset, criterion, batch_size, device, num_workers):
        test_dataset = MoleculeDataset(test_dataset)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        self.eval()

        test_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []

        with torch.no_grad():
            for (batch_smiles, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mask, batch_morgan, batch_chem_desc ,batch_labels) in test_loader:
                # Move tensors to the configured device
                batch_attention_mask = batch_attention_mask.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                outputs = self(batch_smiles, batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_input_ids, batch_attention_mask, batch_morgan, batch_chem_desc)                 
                loss = criterion(outputs.squeeze(), batch_labels)
                test_loss += loss.item()

                all_labels.extend(batch_labels.cpu().numpy())
                all_outputs.extend(outputs.squeeze().cpu().numpy())

                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()

        test_loss /= len(test_loader)
        accuracy = correct / total
        test_auc = roc_auc_score(all_labels, all_outputs)
        PRINTC()
        print(f"Test AUC: {test_auc:.4f}")
        PRINTC()
        return round(test_auc, 4)

    def validate_model(self, val_loader, criterion, device):
        self.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []

        with torch.no_grad():
            for (batch_smiles, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mask, batch_morgan, batch_chem_desc ,batch_labels) in val_loader:
                # Move tensors to the configured device
                batch_attention_mask = batch_attention_mask.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)
                 
                outputs = self(batch_smiles, batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_input_ids, batch_attention_mask, batch_morgan, batch_chem_desc)                
                loss = criterion(outputs.squeeze(), batch_labels)
                val_loss += loss.item()

                all_labels.extend(batch_labels.cpu().numpy())
                all_outputs.extend(outputs.squeeze().cpu().numpy())

                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()

        val_loss /= len(val_loader)
        accuracy = correct / total
        val_auc = roc_auc_score(all_labels, all_outputs)
        return val_loss, accuracy, val_auc


In [None]:
# 12/10 - for ablation table - all features
class AUVG_PPI(AbstractModel):
    def __init__(self ,dropout):
        super(AUVG_PPI, self).__init__()
        self.dropout = dropout
        
        # PPI Features MLP layers: (esm, custom, fegs, gae)
        self.esm_mlp = nn.Sequential(
            nn.Linear(in_features=1280 + 1280 , out_features=1750),
            nn.ReLU(),
            nn.BatchNorm1d(1750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1750, out_features=1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1000, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )

        self.fegs_mlp = nn.Sequential(
            nn.Linear(in_features=578 + 578, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )        

        self.gae_mlp = nn.Sequential(
            nn.Linear(in_features=500 + 500, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )

        # MLP for ppi_features
        self.ppi_mlp = nn.Sequential(
            nn.Linear(in_features=256 * 3, out_features=512),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )
        
        self.fp_mlp = nn.Sequential(
            nn.Linear(in_features=1200, out_features=1024), 
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )

        # Morgan fingerprints & chemical descriptors MLP layers
        self.mfp_cd_mlp = nn.Sequential(
            nn.Linear(in_features=1024 + 194, out_features= 750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )

        self.chemberta_mlp = nn.Sequential(
            nn.Linear(in_features=384, out_features= 256)
        )

        # MLP for smiles_embeddings
        self.smiles_mlp = nn.Sequential(
            nn.Linear(in_features=256 * 3 , out_features= 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256)
        )

        self.additional_layers = nn.Sequential(
            nn.Linear(in_features=256 + 256, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=128, out_features=1),
        )
        
        #self.sigmoid = nn.Sigmoid()

    def forward(self, cpe, esm, fegs, gae, cbae, morgan_fingerprints, chemical_descriptors):
        # Forward pass batch mol graph through pretrained chemprop model in order to get fingerprints embeddings
        # Afterwards, pass the fingerprints through MLP layer
        cp_fingerprints = self.fp_mlp(cpe)
        cbae = self.chemberta_mlp(cbae)
        #chemberta_embeddings = self.chemberta_mlp(chemberta_embeddings)
        mfp_chem_descriptors = torch.cat([morgan_fingerprints, chemical_descriptors], dim=1)
        mfp_chem_descriptors = self.mfp_cd_mlp(mfp_chem_descriptors)
        
        # Concatenate all 3 smiles embeddings along a new dimension (3x384) 
        smiles_embeddings = torch.cat([cp_fingerprints, cbae, mfp_chem_descriptors], dim=1).to(device)  # shape ->> (batch_size, 3*384)
        smiles_embeddings = self.smiles_mlp(smiles_embeddings)

        # Pass all PPI features  through MLP layers, and then pass them all together into another MLP layer
        #ppi_features = proteins.to(device)
        esm_embeddings = self.esm_mlp(esm)
        fegs_embeddings = self.fegs_mlp(fegs)
        gae_embeddings = self.gae_mlp(gae)

        # Concatenate all 3 ppi embeddings along a new dimension (3x512) 
        ppi_embeddings = torch.cat([esm_embeddings, fegs_embeddings, gae_embeddings], dim=1).to(device)  # shape ->> (batch_size, 3*512)
        ppi_features = self.ppi_mlp(ppi_embeddings)

        combined_embeddings = torch.cat([smiles_embeddings, ppi_features], dim=1)
        output = self.additional_layers(combined_embeddings)
        
        return output

In [None]:
# 12/10 - for ablation table
class AbstractModel(ABC, nn.Module):
    def __init__(self):
        super(AbstractModel, self).__init__()
        self.early_stopping_patience = 5
        self.delta = 0.001

    @abstractmethod
    def forward(self, cpe, esm, fegs, gae, cbae, morgan_fingerprints, chemical_descriptors):
        pass
        
    def train_model(self, fold, num_epochs, dataset, optimizer, criterion, 
                    batch_size=32, device='cuda', num_workers=5):
        
        train_dataset = MoleculeDataset(dataset)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        print(f'Start training {fold} for {num_epochs} epochs !')
        for epoch in range(num_epochs):
            start_time = time.time()
            self.train()
            running_loss = 0.0
            for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in train_loader:
                # Move tensors to the configured device
                batch_chemprop_features = batch_chemprop_features.to(device)
                batch_chemberta_features = batch_chemberta_features.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()
                outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)               
                loss = criterion(outputs.squeeze(), batch_labels)
                running_loss += loss.item()
                loss.backward()
                optimizer.step()
            end_time = time.time()
            epoch_time = (end_time - start_time) / 60
            print(f"Epoch {epoch+1} Time: {epoch_time:.2f}")

        
    def train_val_model(self, fold, num_epochs, dataset, optimizer, criterion, 
                    batch_size=32, device='cuda', num_workers=5):
        best_val_auc = float('-inf')
        no_improve_epochs = 0

        X = dataset.copy().drop(columns=['smiles'])
        ids = dataset['smiles']
        # extract labels for scaffold split (in order to make sure we got balance train & validation sets)
        labels = dataset['label'].values
        dc_dataset = dc.data.DiskDataset.from_numpy(X=X ,y=labels ,w=np.zeros(len(X)),ids=ids)
        splitter = dc.splits.ScaffoldSplitter()
        train_idx, val_idx, _ = splitter.split(dc_dataset, frac_train=0.8, frac_valid=0.2, frac_test=0)

        #train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.2, stratify=labels, shuffle=shuffle)
        
        train_subset = dataset.iloc[train_idx].reset_index(drop=True)
        val_subset = dataset.iloc[val_idx].reset_index(drop=True)

        train_dataset = MoleculeDataset(train_subset)
        val_dataset = MoleculeDataset(val_subset)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        for epoch in range(num_epochs):
            start_time = time.time()
            self.train()
            running_loss = 0.0
            for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in train_loader:
                
                batch_chemprop_features = batch_chemprop_features.to(device)
                batch_chemberta_features = batch_chemberta_features.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()
                outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)                 
                loss = criterion(outputs.squeeze(), batch_labels)
                running_loss += loss.item()
                loss.backward()
                optimizer.step()

            val_loss, val_accuracy, val_auc = self.validate_model(val_loader, criterion, device)
            end_time = time.time()
            epoch_time = (end_time - start_time) / 60

            print(f"Epoch {epoch+1} - Validation Loss: {val_loss:.5f}, "
                  f"Validation Accuracy: {val_accuracy:.2f}, Validation AUC: {val_auc:.5f}, Epoch Time: {epoch_time:.2f}")
            # Check whether val_auc > best_val_auc + delta
            if val_auc > best_val_auc + self.delta:
                best_val_auc = val_auc
                train_epoch = epoch+1
                no_improve_epochs = 0 
                print(f"Current best val_auc -> {val_auc:.5f}, at epoch {epoch+1}")
            else:
                no_improve_epochs += 1
                if no_improve_epochs >= self.early_stopping_patience:
                    print(f"Stopping early at epoch {epoch+1}")
                    break

        print(f'Train the model for -> {train_epoch}, best validation auc: {best_val_auc:.5f}')
                
    def test_model(self, test_dataset, criterion, batch_size, device, num_workers):
        test_dataset = MoleculeDataset(test_dataset)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        self.eval()

        test_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []

        with torch.no_grad():
            for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in test_loader:
                
                batch_chemprop_features = batch_chemprop_features.to(device)
                batch_chemberta_features = batch_chemberta_features.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)               
                loss = criterion(outputs.squeeze(), batch_labels)
                test_loss += loss.item()

                all_labels.extend(batch_labels.cpu().numpy())
                all_outputs.extend(outputs.squeeze().cpu().numpy())

                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()

        test_loss /= len(test_loader)
        accuracy = correct / total
        test_auc = roc_auc_score(all_labels, all_outputs)
        PRINTC()
        print(f"Test AUC: {test_auc:.4f}")
        PRINTC()
        return round(test_auc, 4)

    def validate_model(self, val_loader, criterion, device):
        self.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []

        with torch.no_grad():
            for (batch_chemprop_features, batch_esm_features, batch_fegs_features, batch_gae_features,
                 batch_chemberta_features, batch_morgan, batch_chem_desc ,batch_labels) in val_loader:

                batch_chemprop_features = batch_chemprop_features.to(device)
                batch_chemberta_features = batch_chemberta_features.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                outputs = self(batch_chemprop_features , batch_esm_features, batch_fegs_features,
                               batch_gae_features, batch_chemberta_features, batch_morgan, batch_chem_desc)                
                loss = criterion(outputs.squeeze(), batch_labels)
                val_loss += loss.item()

                all_labels.extend(batch_labels.cpu().numpy())
                all_outputs.extend(outputs.squeeze().cpu().numpy())

                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()

        val_loss /= len(val_loader)
        accuracy = correct / total
        val_auc = roc_auc_score(all_labels, all_outputs)
        return val_loss, accuracy, val_auc


In [None]:
# For ablation table 17/10
class MoleculeDataset(Dataset):
    def __init__(self, ds_):
        self.data = ds_
        self.mapping_df = pd.read_csv(os.path.join('datasets', 'idmapping_unip.tsv'), delimiter = "\t")
        self.esm = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'esm_features.csv'))
        self.fegs = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'fegs_features.csv'))

        # In all predicted values, use zero vectors (after expirements that proved that)
        gae_path = f'GAE_FEATURES_WITH_PREDICTED_alpha_0.csv'            
        self.gae = pd.read_csv(os.path.join('datasets', 'GAE', gae_path))
        self.gae.loc[self.gae['predicted'] == 1, self.gae.columns[9:509]] = 0
        gae_features_columns = self.gae.iloc[:, 9:509]

        gae_uniprot_column = self.gae[['From']].rename(columns={'From': 'UniProt_ID'})
        self.gae = pd.concat([gae_uniprot_column, gae_features_columns], axis=1)
        self.gae_features_ppi = self.merge_datasets(self.data, self.gae).drop(columns=['smiles', 'label']).astype(np.float32)
        self.esm_features_ppi = self.merge_datasets(self.data, self.esm).drop(columns=['smiles', 'label']).astype(np.float32)
        self.fegs_features_ppi = self.merge_datasets(self.data, self.fegs).drop(columns=['smiles', 'label']).astype(np.float32)

         # SMILES RDKit features - Morgan Fingerprints (r=4, nbits=1024)  chemical descriptors, chemprop & chemBERTa
        self.smiles_morgan_fingerprints = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_morgan_fingerprints_dataset.csv'))
        self.smiles_chemical_descriptors = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_chem_descriptors_mapping_dataset.csv'))
        self.chemprop = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'chemprop_features.csv'))
        self.chemberta = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'chemBERTa_features.csv'))
        # Necessary features for ChemBERTa model
        self.smiles_list = self.data['smiles'].tolist()
        self.tokenizer = RobertaTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
        self.encoded_smiles = self.tokenizer(self.smiles_list, truncation=True, padding=True, return_tensors="pt")

    def merge_datasets(self, dataset, features_df):
        dataset = dataset.merge(features_df, how='left', left_on='uniprot_id1', right_on='UniProt_ID', suffixes=('', '_id1'))
        dataset = dataset.drop(columns=['UniProt_ID'])
        
        features_df_renamed = features_df.add_suffix('_id2')
        features_df_renamed = features_df_renamed.rename(columns={'UniProt_ID_id2': 'UniProt_ID'})
        dataset = dataset.merge(features_df_renamed, how='left', left_on='uniprot_id2', right_on='UniProt_ID', suffixes=('', '_id2'))
        dataset = dataset.drop(columns=['UniProt_ID', 'uniprot_id1', 'uniprot_id2'])
        
        # In order to avoid dropping duplicated rows that holds only zeros (in gae when there is zero vectors), which can be represents embeddings of ppi vector when
        # specifying to reset the rows to hold only zeros
        dataset['zero_count'] = (dataset == 0).any(axis=1).astype(int)
        count = 1
        for index in dataset.index:
            if dataset.at[index, 'zero_count'] == 1:
                dataset.at[index, 'zero_count'] = count
                count += 1
                
        # Fill null values with 0
        dataset.fillna(0, inplace=True)
        dataset.drop_duplicates(inplace=True)

        return dataset.drop(columns=['zero_count'])

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        smiles = self.data.iloc[idx, 0]
        label = np.array(self.data.iloc[idx, -1], dtype=np.float32)  
        esm_features = np.array(self.esm_features_ppi.iloc[idx].values, dtype=np.float32)
        fegs_features = np.array(self.fegs_features_ppi.iloc[idx].values, dtype=np.float32)
        gae_features = np.array(self.gae_features_ppi.iloc[idx].values, dtype=np.float32)

        input_ids = self.encoded_smiles["input_ids"][idx]
        attention_mask = self.encoded_smiles["attention_mask"][idx]

        # Retrieve precomputed RDKit Morgan fingerprints
        morgan_fingerprint = self.smiles_morgan_fingerprints.loc[self.smiles_morgan_fingerprints['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        chemical_descriptors = self.smiles_chemical_descriptors.loc[self.smiles_chemical_descriptors['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        chemprop_features = self.chemprop.loc[self.smiles_chemical_descriptors['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        chemberta_features = self.chemberta.loc[self.smiles_chemical_descriptors['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        
        return (chemprop_features, esm_features, fegs_features, gae_features, 
                chemberta_features, morgan_fingerprint, chemical_descriptors, label)
