# Build Research Final Model Pipeline #

## Import Libraries ##

In [1]:
import os
import time
import warnings
warnings.filterwarnings('ignore')

import chemprop
from chemprop import data, featurizers, models
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch import Tensor

# import necessary libraries for Chemberta model
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, AdamW, get_linear_schedule_with_warmup , BertModel

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score


import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, Draw
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect

In [2]:
def PRINT() -> None: print(f"{'-'*80}\nDone\n{'-'*80}")
def PRINTC() -> None: print(f"{'-'*80}")
def PRINTM(M) -> None: print(f"{'-'*80}\n{M}\n{'-'*80}")

## Verify GPU Availability ##

In [3]:
!nvidia-smi

Sat Sep 14 08:50:43 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX 6000 Ada Gene...    On  |   00000000:21:00.0 Off |                  Off |
| 30%   33C    P8             26W /  300W |       1MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

For this task, we'll use the BGU cluster GPU `NVIDIA RTX 6000 Ada Generation` to achieve better performance during the training of our pre-trained and fine-tuned models, allowing for more efficient processing of large datasets and complex computations.

In [5]:
if torch.cuda.is_available():
    PRINTM(f"GPU is available.")
    device = "cuda"
else:
    PRINTM(f"GPU is not available. Using CPU instead.")
    device = "cpu"

--------------------------------------------------------------------------------
GPU is available.
--------------------------------------------------------------------------------


In [6]:
PRINTM(f"PyTorch version: {torch.__version__}")
PRINTM(f"CUDA available: {torch.cuda.is_available()}")
PRINTM(f"CUDA version:  {torch.version.cuda}")
print(f"CUDA device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No CUDA'}")

--------------------------------------------------------------------------------
PyTorch version: 2.3.1+cu121
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CUDA available: True
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CUDA version:  12.1
--------------------------------------------------------------------------------
CUDA device: NVIDIA RTX 6000 Ada Generation


In [6]:
uniprot_mapping = pd.read_csv(os.path.join('datasets', 'idmapping_unip.tsv'), delimiter = "\t")
PRINT()

--------------------------------------------------------------------------------
Done
--------------------------------------------------------------------------------


In [7]:
def convert_uniprot_ids(dataset, mapping_df):
    # Create a dictionary from the mapping dataframe
    mapping_dict = mapping_df.set_index('From')['Entry'].to_dict()

    # Map the uniprot_id1 and uniprot_id2 columns to their respective Entry values
    dataset['uniprot_id1'] = dataset['uniprot_id1'].map(mapping_dict)
    dataset['uniprot_id2'] = dataset['uniprot_id2'].map(mapping_dict)
    return dataset.drop_duplicates()

In [9]:
class ChemBERTaPT(nn.Module):
    def __init__(self):
        super(ChemBERTaPT, self).__init__()
        self.model_name = "DeepChem/ChemBERTa-77M-MTR"
        self.chemberta = RobertaModel.from_pretrained(self.model_name)

    def forward(self, input_ids, attention_mask):
        bert_output = self.chemberta(input_ids=input_ids, attention_mask=attention_mask)
        return bert_output[1]

In [8]:
class PretrainedChempropModel(nn.Module):
    def __init__(self, checkpoints_path, batch_size):
        super(PretrainedChempropModel, self).__init__()
        self.mpnn = self.load_pretrained_model(checkpoints_path)
        self.featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
        self.batch_size = batch_size
        
    def forward(self, smiles):
        # Prepare the data in order to generate embeddings from modulators SMILES
        self.smiles_data = [data.MoleculeDatapoint.from_smi(smi) for smi in smiles]
        self.smiles_dset = data.MoleculeDataset(self.smiles_data, featurizer=self.featurizer)
        self.smiles_loader = data.build_dataloader(self.smiles_dset, batch_size=self.batch_size, shuffle=False)
        
        embeddings = [
            # Etract the embedding from the last FFN layer, i.e., before the final prediction(thus i=-1)
            self.mpnn.predictor.encode(self.fingerprints_from_batch_molecular_graph(batch, self.mpnn), i=-1) 
            for batch in self.smiles_loader
        ]
        #print(embeddings)
        if not embeddings:
             return torch.empty(0, device=device)
        embeddings = torch.cat(embeddings, 0).to(device)
        return embeddings

    def fingerprints_from_batch_molecular_graph(self, batch, mpnn):
        batch.bmg.to(device)
        H_v = mpnn.message_passing(batch.bmg, batch.V_d)
        H = mpnn.agg(H_v, batch.bmg.batch)
        H = mpnn.bn(H)
        fingerprints = H if batch.X_d is None else torch.cat((H, mpnn.batch.X_d_transform(X_d)), 1)
        return fingerprints

    def load_pretrained_model(self, checkpoints_path):
        mpnn = models.MPNN.load_from_checkpoint(checkpoints_path).to(device)
        return mpnn

In [None]:
def data_augmentation_with_uniprots_order_switchings(df):
    # generate a copy of the DataFrame with swapped uniprot_id1 and uniprot_id2
    swapped_df = df.copy()
    swapped_df[['uniprot_id1', 'uniprot_id2']] = swapped_df[['uniprot_id2', 'uniprot_id1']]

    # concatenate the original and swapped DataFrames & drop duplicated samples
    combined_df = pd.concat([df, swapped_df])
    combined_df = combined_df.drop_duplicates()

    return combined_df

# Models v2 (with Attention) #

In [11]:
class AbstractModel(ABC, nn.Module):
    def __init__(self):
    	super(AbstractModel, self).__init__()
    
    @abstractmethod
    def forward(self, bmg, bpsf1, bpsf2, esm, custom, fegs, gae,
    	input_ids, attention_mask,
    	morgan_fingerprints, chemical_descriptors):
    	pass

    def train_model(self, num_epochs, train_loader, val_loader, optimizer, criterion, device):
        PRINTM(f'Start training !')
        for epoch in range(num_epochs):
            start_time = time.time()
            self.train()
            running_loss = 0.0
            for (batch_smiles, batch_psf1, batch_psf2, batch_esm_features, batch_custom_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc, batch_labels) in train_loader:
                # Move tensors to the configured device
                batch_attention_mas = batch_attention_mas.to(device)
                batch_psf1 = batch_psf1.to(device)
                batch_psf2 = batch_psf2.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_custom_features = batch_custom_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                
                optimizer.zero_grad()
                outputs = self(batch_smiles, batch_psf1, batch_psf2, batch_esm_features,batch_custom_features,
                               batch_fegs_features, batch_gae_features, batch_input_ids, batch_attention_mas,batch_morgan, batch_chem_desc)

                loss = criterion(outputs.squeeze(), batch_labels)    
                loss.backward()
                optimizer.step()
    
            # Validate the model on the validation set
            val_loss, val_accuracy, val_auc = self.validate_model(val_loader, criterion, device)
            end_time = time.time()
            epoch_time = (end_time - start_time) / 60
            PRINTC()
            print(f"Epoch: {epoch+1}")
            print(f"Validation BCEWithLogitsLoss: {val_loss:.5f}")
            print(f"Validation Accuracy (>0.8): {val_accuracy:.2f}")
            print(f"Validation AUC: {val_auc:.5f}")
            print(f"Epoch time: {epoch_time:.2f} minutes")
            PRINTC()
    
        print("Finish training !")

    def test_model(self, test_dataset, criterion, batch_size, shuffle, device):
        test_dataset = MoleculeDataset(test_dataset)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=test_dataset.collate_fn)
        self.eval()
        
        test_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []
                
        with torch.no_grad():
            for (batch_smiles, batch_psf1, batch_psf2, batch_esm_features, batch_custom_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc, batch_labels) in test_loader:
                # Move tensors to the configured device
                batch_attention_mas = batch_attention_mas.to(device)
                batch_psf1 = batch_psf1.to(device)
                batch_psf2 = batch_psf2.to(device)                
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_custom_features = batch_custom_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)
    
                outputs = self(batch_smiles, batch_psf1, batch_psf2, batch_esm_features, batch_custom_features,
                               batch_fegs_features, batch_gae_features, batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc)
    
                loss = criterion(outputs.squeeze(), batch_labels)
                test_loss += loss.item()

                all_labels.extend(batch_labels.cpu().numpy())  
                all_outputs.extend(outputs.squeeze().cpu().numpy())  
    
                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
        
    
        test_loss /= len(test_loader)
        accuracy = correct / total
        test_auc = roc_auc_score(all_labels, all_outputs) 
        PRINTC() 
        print(f"Test AUC: {test_auc:.5f}")
        PRINTC()
        return test_auc

    def validate_model(self, val_loader, criterion, device):
        self.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []
        with torch.no_grad():
            for (batch_smiles, batch_psf1, batch_psf2, batch_esm_features, batch_custom_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc ,batch_labels) in val_loader:
                # Move tensors to the configured device
                batch_attention_mas = batch_attention_mas.to(device)
                batch_psf1 = batch_psf1.to(device)
                batch_psf2 = batch_psf2.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_custom_features = batch_custom_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)
    
                outputs = self(batch_smiles, batch_psf1, batch_psf2, batch_esm_features,batch_custom_features,
                               batch_fegs_features, batch_gae_features, batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc)
                loss = criterion(outputs.squeeze(), batch_labels)
                val_loss += loss.item()
    
                all_labels.extend(batch_labels.cpu().numpy())  
                all_outputs.extend(outputs.squeeze().cpu().numpy())  
    
                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
    
        val_loss /= len(val_loader)
        accuracy = correct / total
        val_auc = roc_auc_score(all_labels, all_outputs)  
        return val_loss, accuracy, val_auc

    def cross_validate(self, dataset, num_folds=5,num_epochs=10, batch_size=32, learning_rate=0.0001, weight_decay=1e-5, shuffle=True, device='cuda'):
        kf = KFold(n_splits=num_folds, shuffle=shuffle)
        
        fold_results = []
        
        for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
            
            print(f"Fold {fold+1}/{num_folds}")
            
            # Split dataset
            train_subset = dataset.iloc[train_idx].reset_index(drop=True)
            val_subset = dataset.iloc[val_idx].reset_index(drop=True)
            
            train_dataset = MoleculeDataset(train_subset)
            val_dataset = MoleculeDataset(val_subset)
            
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=train_dataset.collate_fn)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=val_dataset.collate_fn)
            
            criterion = nn.BCEWithLogitsLoss()
            optimizer = optim.AdamW(self.parameters(), lr=learning_rate, weight_decay=weight_decay)
            
            self.train_model(num_epochs, train_loader, val_loader, optimizer, criterion, device)
            
            # Validate the model
            val_loss, val_accuracy, val_auc = self.validate_model(val_loader, criterion, device)
            fold_results.append((val_loss, val_accuracy, val_auc))

            PRINTC()
            print(f"Fold {fold+1} - Validation BCEWithLogitsLoss: {val_loss:.5f}, Accuracy: {val_accuracy:.2f}, AUC: {val_auc:.5f}")
            PRINTC()
            
        avg_val_loss = sum([result[0] for result in fold_results]) / num_folds
        avg_val_accuracy = sum([result[1] for result in fold_results]) / num_folds
        avg_val_auc = sum([result[2] for result in fold_results]) / num_folds
        
        print(f"\nAverage Validation BCEWithLogitsLoss: {avg_val_loss:.5f}")
        print(f"Average Validation Accuracy: {avg_val_accuracy:.2f}")
        print(f"Average Validation AUC: {avg_val_auc:.5f}")
        
        return fold_results


In [10]:
class custom_self_attention(nn.Module):
    def __init__(self, embed_dim_, num_heads_, dropout_):
        super(custom_self_attention, self).__init__()
        self.self_attention = nn.MultiheadAttention(embed_dim=embed_dim_, num_heads=num_heads_, dropout= dropout_)
        self.norm_layer = nn.LayerNorm(embed_dim_)

    def forward(self, embeddings_mat):
        # Apply self-attention for PPI
        embeddings_mat = embeddings_mat.permute(1, 0, 2)  # Change to (num_heads, batch_size, embed_dim) for MultiheadAttention
        attn_output, attn_weights = self.self_attention(embeddings_mat, embeddings_mat, embeddings_mat)
        attn_output = attn_output.permute(1, 0, 2)  # shape ->> (batch_size, num_heads, embed_dim)

        # Add & Norm
        embeddings_mat = embeddings_mat.permute(1, 0, 2)  # Back to original shape for residual connection
        attn_output = (0.5*attn_output) + (0.5*embeddings_mat)  # Add (residual connection) & apply weighted residual connection 
        attn_output = self.norm_layer(attn_output)  # Apply LayerNorm

        # Flatten the output for the next MLP layer
        embeddings_mat = attn_output.flatten(start_dim=1)  # Shape: (batch_size, num_heads*embed_dim)
        
        return embeddings_mat

In [44]:
# With Attention (first version) - without structure features
class AUVG_PPI(AbstractModel):
    def __init__(self, pretrained_chemprop_model, chemberta_model, dropout):
        
        super(AUVG_PPI, self).__init__()
        self.pretrained_chemprop_model = pretrained_chemprop_model
        self.chemberta_model = chemberta_model
        self.dropout = dropout
        self.ppi_self_attention = custom_self_attention(512, 8, 0.2)
        self.smiles_self_attention = custom_self_attention(384, 4, 0.2)
        self.cross_attention = nn.MultiheadAttention(512, 8, 0.2)
        self.max_pool = nn.MaxPool1d(2)
        
        # PPI Features MLP layers: (esm, custom, fegs, gae)
        self.esm_mlp = nn.Sequential(
            nn.Linear(in_features=1280 + 1280 , out_features=1750),
            nn.ReLU(),
            nn.BatchNorm1d(1750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1750, out_features=1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1000, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        self.fegs_mlp = nn.Sequential(
            nn.Linear(in_features=578 + 578, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )        

        self.custom_mlp = nn.Sequential(
            nn.Linear(in_features=4700 + 4700 , out_features=8000),
            nn.ReLU(),
            nn.BatchNorm1d(8000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=8000, out_features=6500),
            nn.ReLU(),
            nn.BatchNorm1d(6500),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=6500, out_features=5000),
            nn.ReLU(),
            nn.BatchNorm1d(5000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=5000, out_features=3500),
            nn.ReLU(),
            nn.BatchNorm1d(3500),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=3500, out_features=2000),
            nn.ReLU(),
            nn.BatchNorm1d(2000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=2000, out_features=1028),
            nn.ReLU(),
            nn.BatchNorm1d(1028),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1028, out_features=512)
        )

        self.gae_mlp = nn.Sequential(
            nn.Linear(in_features=500 + 500, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        # MLP for ppi_features
        self.ppi_mlp = nn.Sequential(
            nn.Linear(in_features=512 * 4 , out_features= 1536),
            nn.ReLU(),
            nn.BatchNorm1d(1536),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1536, out_features=1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512)
        )
        
        self.fp_mlp = nn.Sequential(
            nn.Linear(in_features=2100, out_features=1536),
            nn.ReLU(),
            nn.BatchNorm1d(1536),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1536, out_features=1024), 
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=384)
        )

        # Morgan fingerprints & chemical descriptors MLP layers
        self.mfp_cd_mlp = nn.Sequential(
            nn.Linear(in_features=1024 + 194, out_features= 750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=384)
        )

        # MLP for smiles_embeddings
        self.smiles_mlp = nn.Sequential(
            nn.Linear(in_features=384 * 3 , out_features= 750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        self.additional_layers = nn.Sequential(
            nn.Linear(in_features=256 + 256, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=64, out_features=1)
        )
        
        #self.sigmoid = nn.Sigmoid()

    def forward(self, bmg, esm, custom, fegs, gae,
                input_ids, attention_mask,
                morgan_fingerprints, chemical_descriptors):
        # Forward pass batch mol graph through pretrained chemprop model in order to get fingerprints embeddings
        # Afterwards, pass the fingerprints through MLP layer
        cp_fingerprints = self.pretrained_chemprop_model(bmg)
        cp_fingerprints = self.fp_mlp(cp_fingerprints)

        chemberta_embeddings = self.chemberta_model(input_ids, attention_mask)
        #chemberta_embeddings = self.chemberta_mlp(chemberta_embeddings)
        mfp_chem_descriptors = torch.cat([morgan_fingerprints, chemical_descriptors], dim=1)
        mfp_chem_descriptors = self.mfp_cd_mlp(mfp_chem_descriptors)
        
        # Concatenate all 3 smiles embeddings along a new dimension (3x384) & pass them throw self-attention layer
        smiles_embeddings = torch.stack([cp_fingerprints, chemberta_embeddings, mfp_chem_descriptors], dim=1).to(device)  # shape ->> (batch_size, 3, 384)
        smiles_features = self.smiles_self_attention(smiles_embeddings)
        smiles_embeddings = self.smiles_mlp(smiles_features).unsqueeze(1)

        # Pass all PPI features  through MLP layers, and then pass them all together into another MLP layer
        #ppi_features = proteins.to(device)
        esm_embeddings = self.esm_mlp(esm)
        custom_embeddings = self.custom_mlp(custom)
        fegs_embeddings = self.fegs_mlp(fegs)
        gae_embeddings = self.gae_mlp(gae)

        # Concatenate all 4 ppi embeddings along a new dimension (4x512) & pass them throw self-attention layer
        ppi_embeddings = torch.stack([esm_embeddings, custom_embeddings, fegs_embeddings, gae_embeddings], dim=1).to(device)  # shape ->> (batch_size, 4, 320)
        ppi_features = self.ppi_self_attention(ppi_embeddings)
        ppi_features = self.ppi_mlp(ppi_features).unsqueeze(1)

        #Cross-attention between smiles and PPI to capture the interaction relationships
        ppi_QKV = ppi_features.permute(1, 0, 2)
        smiles_QKV = smiles_embeddings.permute(1, 0, 2)
        
        smiles_att, _ = self.cross_attention(smiles_QKV, ppi_QKV, ppi_QKV)
        ppi_att, _ = self.cross_attention(ppi_QKV, smiles_QKV, smiles_QKV)

        # permute attention outputrs to match (batch_size, embed_dim, num_heads) shape
        smiles_attn_output = (0.5* smiles_att.permute(1, 2, 0)) + (0.5* smiles_embeddings.permute(0, 2, 1))  # Add (residual connection) & apply weighted residual connection 
        ppi_attn_output = (0.5* ppi_att.permute(1, 2, 0)) + (0.5* ppi_features.permute(0, 2, 1))  # Add (residual connection) & apply weighted residual connection 

        # Drop the last dim in order to get (batch_size, embed_dim) & 
        # Pass cross-attention norm outputs throw max-pool layer before passing throw MLP layers
        smiles_att = self.max_pool(smiles_attn_output.squeeze(2))
        ppi_att = self.max_pool(ppi_attn_output.squeeze(2)) 
        combined_embeddings = torch.cat([smiles_att, ppi_att], dim=1)
        output = self.additional_layers(combined_embeddings)
        
        return output
        #return self.sigmoid(output)

In [12]:
class FeatureReducer_(nn.Module):
    # Feature reducer for joint attention in PPI structure feature - in order to reduce tensors dim for math operations
    # Use this class if |UniProt_NumOfAminoAcidComp| < 128
    def __init__(self, in_channels, out_channels):
        super(FeatureReducer_, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
    
    def forward(self, x):
        # x shape: [batch_size, sequence_length, in_channels]
        x = x.transpose(1, 2)  # Change shape to [batch_size, in_channels, sequence_length]
        x = self.conv(x)       
        x = x.transpose(1, 2)  # Change shape back to [batch_size, target_length, out_channels]
        return x

In [13]:
class FeatureReducer(nn.Module):
    # Feature reducer for joint attention in PPI structure feature - in order to reduce tensors dim for math operations
    # Use this class if |UniProt_NumOfAminoAcidComp| >= 128
    def __init__(self, in_channels, out_channels, target_length):
        super(FeatureReducer, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.pool = nn.AdaptiveAvgPool1d(target_length)
    
    def forward(self, x):
        # x shape: [batch_size, sequence_length, in_channels]
        x = x.transpose(1, 2)  # Change shape to [batch_size, in_channels, sequence_length]
        x = self.conv(x)    
        x = self.pool(x) 
        x = x.transpose(1, 2)  # Change shape back to [batch_size, target_length, out_channels]
        return x

In [14]:
# With Attention (second version) - wthi structure features
class AUVG_PPI(AbstractModel):
    def __init__(self, pretrained_chemprop_model, chemberta_model, dropout):
        
        super(AUVG_PPI, self).__init__()
        self.pretrained_chemprop_model = pretrained_chemprop_model
        self.chemberta_model = chemberta_model
        self.dropout = dropout
        self.ppi_self_attention = custom_self_attention(512, 8, 0.2)
        self.smiles_self_attention = custom_self_attention(384, 4, 0.2)
        self.cross_attention = nn.MultiheadAttention(512, 8, 0.2)
        self.max_pool = nn.MaxPool1d(2)
        self.compound_dim = 512
        self.W_p1, self.W_p2 = nn.Linear(self.compound_dim, self.compound_dim), nn.Linear(self.compound_dim, self.compound_dim)

        
        # PPI Features MLP layers: (esm, custom, fegs, gae)
        self.esm_mlp = nn.Sequential(
            nn.Linear(in_features=1280 + 1280 , out_features=1750),
            nn.ReLU(),
            nn.BatchNorm1d(1750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1750, out_features=1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1000, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        self.fegs_mlp = nn.Sequential(
            nn.Linear(in_features=578 + 578, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )        

        self.custom_mlp = nn.Sequential(
            nn.Linear(in_features=4700 + 4700 , out_features=8000),
            nn.ReLU(),
            nn.BatchNorm1d(8000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=8000, out_features=6500),
            nn.ReLU(),
            nn.BatchNorm1d(6500),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=6500, out_features=5000),
            nn.ReLU(),
            nn.BatchNorm1d(5000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=5000, out_features=3500),
            nn.ReLU(),
            nn.BatchNorm1d(3500),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=3500, out_features=2000),
            nn.ReLU(),
            nn.BatchNorm1d(2000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=2000, out_features=1028),
            nn.ReLU(),
            nn.BatchNorm1d(1028),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1028, out_features=512)
        )

        self.gae_mlp = nn.Sequential(
            nn.Linear(in_features=500 + 500, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        # MLP for ppi_features
        self.ppi_mlp = nn.Sequential(
            nn.Linear(in_features=512 * 5 , out_features= 1536),
            nn.ReLU(),
            nn.BatchNorm1d(1536),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1536, out_features=1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512)
        )
        
        self.fp_mlp = nn.Sequential(
            nn.Linear(in_features=2100, out_features=1536),
            nn.ReLU(),
            nn.BatchNorm1d(1536),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1536, out_features=1024), 
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=384)
        )

        # Morgan fingerprints & chemical descriptors MLP layers
        self.mfp_cd_mlp = nn.Sequential(
            nn.Linear(in_features=1024 + 194, out_features= 750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=384)
        )

        # MLP for smiles_embeddings
        self.smiles_mlp = nn.Sequential(
            nn.Linear(in_features=384 * 3 , out_features= 750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        self.additional_layers = nn.Sequential(
            nn.Linear(in_features=256 + 256, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=64, out_features=1)
        )
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        
        #self.sigmoid = nn.Sigmoid()
    # bptfs -> batch protein tuple feature structure
    def forward(self, bmg, bpsf1, bpsf2, esm, custom, fegs, gae,
                input_ids, attention_mask,
                morgan_fingerprints, chemical_descriptors):
        # Forward pass batch mol graph through pretrained chemprop model in order to get fingerprints embeddings
        # Afterwards, pass the fingerprints through MLP layer
        cp_fingerprints = self.pretrained_chemprop_model(bmg)
        cp_fingerprints = self.fp_mlp(cp_fingerprints)

        chemberta_embeddings = self.chemberta_model(input_ids, attention_mask)
        #chemberta_embeddings = self.chemberta_mlp(chemberta_embeddings)
        mfp_chem_descriptors = torch.cat([morgan_fingerprints, chemical_descriptors], dim=1)
        mfp_chem_descriptors = self.mfp_cd_mlp(mfp_chem_descriptors)
        
        # Concatenate all 3 smiles embeddings along a new dimension (3x384) & pass them throw self-attention layer
        smiles_embeddings = torch.stack([cp_fingerprints, chemberta_embeddings, mfp_chem_descriptors], dim=1).to(device)  # shape ->> (batch_size, 3, 384)
        smiles_features = self.smiles_self_attention(smiles_embeddings)
        smiles_embeddings = self.smiles_mlp(smiles_features).unsqueeze(1)

        # Pass all PPI features  through MLP layers, and then pass them all together into another MLP layer
        esm_embeddings = self.esm_mlp(esm)
        custom_embeddings = self.custom_mlp(custom)
        fegs_embeddings = self.fegs_mlp(fegs)
        gae_embeddings = self.gae_mlp(gae)
        
        # Structure features
        if bpsf1.shape[1] > 128: feature_reducer_p1 = FeatureReducer(in_channels=722, out_channels=512, target_length=128).to(device)
        else: feature_reducer_p1 = FeatureReducer_(in_channels=722, out_channels=512).to(device)
        if bpsf2.shape[1] > 128: feature_reducer_p2 = FeatureReducer(in_channels=722, out_channels=512, target_length=128).to(device)
        else: feature_reducer_p2 = FeatureReducer_(in_channels=722, out_channels=512).to(device)
        bpsf1 = feature_reducer_p1(bpsf1)
        bpsf2 = feature_reducer_p2(bpsf2)
        #print(f'bpsf1 -> {bpsf1.shape}, bpsf2 -> {bpsf2.shape}')
        inter_comp_prot = self.sigmoid(torch.einsum('bij,bkj->bik', self.W_p1(self.relu(bpsf1)), self.W_p2(self.relu(bpsf2))))
        #print(f'inter_comp_prot -> {inter_comp_prot.shape}')
        inter_comp_prot_sum = torch.einsum('bij->b', inter_comp_prot)
        inter_comp_prot = torch.einsum('bij,b->bij', inter_comp_prot, 1/inter_comp_prot_sum)
        #print(f'after, inter_comp_prot -> {inter_comp_prot.shape}')
        
        # compound-protein joint embedding
        cp_embedding = self.tanh(torch.einsum('bij,bkj->bikj', bpsf1, bpsf2))
        #print(cp_embedding.shape)
        cp_embedding = torch.einsum('bijk,bij->bk', cp_embedding, inter_comp_prot)
        #print(f'end, cp_embedding -> {cp_embedding.shape}')
        
        # Concatenate all 4 ppi embeddings along a new dimension (4x512) & pass them throw self-attention layer
        ppi_embeddings = torch.stack([cp_embedding, esm_embeddings, custom_embeddings, fegs_embeddings, gae_embeddings], dim=1).to(device)  # shape ->> (batch_size, 4, 320)
        ppi_features = self.ppi_self_attention(ppi_embeddings)
        ppi_features = self.ppi_mlp(ppi_features).unsqueeze(1)

        #Cross-attention between smiles and PPI to capture the interaction relationships
        ppi_QKV = ppi_features.permute(1, 0, 2)
        smiles_QKV = smiles_embeddings.permute(1, 0, 2)
        
        smiles_att, _ = self.cross_attention(smiles_QKV, ppi_QKV, ppi_QKV)
        ppi_att, _ = self.cross_attention(ppi_QKV, smiles_QKV, smiles_QKV)

        # permute attention outputrs to match (batch_size, embed_dim, num_heads) shape
        smiles_attn_output = (0.5* smiles_att.permute(1, 2, 0)) + (0.5* smiles_embeddings.permute(0, 2, 1))  # Add (residual connection) & apply weighted residual connection 
        ppi_attn_output = (0.5* ppi_att.permute(1, 2, 0)) + (0.5* ppi_features.permute(0, 2, 1))  # Add (residual connection) & apply weighted residual connection 

        # Drop the last dim in order to get (batch_size, embed_dim) & 
        # Pass cross-attention norm outputs throw max-pool layer before passing throw MLP layers
        smiles_att = self.max_pool(smiles_attn_output.squeeze(2))
        ppi_att = self.max_pool(ppi_attn_output.squeeze(2)) 
        combined_embeddings = torch.cat([smiles_att, ppi_att], dim=1)
        output = self.additional_layers(combined_embeddings)
        
        return output
        #return self.sigmoid(output)

## MoleculeDatasets ##

In [19]:
# for training with prob
class MoleculeDataset(Dataset):
    def __init__(self, ds_):
        # Initialize data and load other features
        self.data = ds_
        self.mapping_df = pd.read_csv(os.path.join('datasets', 'idmapping_unip.tsv'), delimiter = "\t")
        self.esm = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'esm_features.csv'))
        self.custom = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'custom_features.csv'))
        self.fegs = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'fegs_features.csv'))
        self.uniprots = self.data.drop(columns=['smiles', 'label'])

        gae_path = f'GAE_FEATURES_WITH_PREDICTED_alpha_0.25.csv'
        self.gae = pd.read_csv(os.path.join('datasets', 'GAE', gae_path))
        gae_features_columns = self.gae.iloc[:, 9:509]
        gae_uniprot_column = self.gae[['From']].rename(columns={'From': 'UniProt_ID'})
        self.gae = pd.concat([gae_uniprot_column, gae_features_columns], axis=1)
        
        self.uniprots = self.data.drop(columns=['smiles', 'label'])

        # Merge datasets
        self.esm_features_ppi = self.merge_datasets(self.data, self.esm).drop(columns=['smiles', 'label']).astype(np.float32)
        self.custom_features_ppi = self.merge_datasets(self.data, self.custom).drop(columns=['smiles', 'label']).astype(np.float32)
        self.fegs_features_ppi = self.merge_datasets(self.data, self.fegs).drop(columns=['smiles', 'label']).astype(np.float32)
        self.gae_features_ppi = self.merge_datasets(self.data, self.gae).drop(columns=['smiles', 'label']).astype(np.float32)

        # SMILES RDKit features - Morgan Fingerprints (r=4, nbits=1024) & chemical descriptors
        self.smiles_morgan_fingerprints = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_morgan_fingerprints_dataset.csv'))
        self.smiles_chemical_descriptors = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_chem_descriptors_mapping_dataset.csv'))

        # Necessary features for ChemBERTa model
        self.smiles_list = self.data['smiles'].tolist()
        self.tokenizer = RobertaTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
        self.encoded_smiles = self.tokenizer(self.smiles_list, truncation=True, padding=True, return_tensors="pt")

    def merge_datasets(self, dataset, features_df):
        # Existing merging logic
        dataset = dataset.merge(features_df, how='left', left_on='uniprot_id1', right_on='UniProt_ID', suffixes=('', '_id1'))
        dataset = dataset.drop(columns=['UniProt_ID'])
        
        features_df_renamed = features_df.add_suffix('_id2')
        features_df_renamed = features_df_renamed.rename(columns={'UniProt_ID_id2': 'UniProt_ID'})
        dataset = dataset.merge(features_df_renamed, how='left', left_on='uniprot_id2', right_on='UniProt_ID', suffixes=('', '_id2'))
        dataset = dataset.drop(columns=['UniProt_ID', 'uniprot_id1', 'uniprot_id2'])
        
        return dataset.drop_duplicates()

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        smiles = self.data.iloc[idx, 0]
        label = np.array(self.data.iloc[idx, -1], dtype=np.float32)  
        uniprots_tuple = (self.data.iloc[idx, 1], self.data.iloc[idx, 2]) # tuple that hold uniprot_id1 and uniprots_id2 -> for prob in testing phase
        esm_features = np.array(self.esm_features_ppi.iloc[idx].values, dtype=np.float32)
        custom_features = np.array(self.custom_features_ppi.iloc[idx].values, dtype=np.float32)
        fegs_features = np.array(self.fegs_features_ppi.iloc[idx].values, dtype=np.float32)
        gae_features = np.array(self.gae_features_ppi.iloc[idx].values, dtype=np.float32)

        input_ids = self.encoded_smiles["input_ids"][idx]
        attention_mask = self.encoded_smiles["attention_mask"][idx]

        # Retrieve precomputed RDKit Morgan fingerprints
        morgan_fingerprint = self.smiles_morgan_fingerprints.loc[self.smiles_morgan_fingerprints['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        chemical_descriptors = self.smiles_chemical_descriptors.loc[self.smiles_chemical_descriptors['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        
        return (smiles, uniprots_tuple, esm_features, custom_features, fegs_features, gae_features, 
                input_ids, attention_mask, morgan_fingerprint, chemical_descriptors, label)


In [20]:
# for training without prob
class MoleculeDataset(Dataset):
    def __init__(self, ds_):
        # Initialize data and load other features
        self.data = ds_
        self.mapping_df = pd.read_csv(os.path.join('datasets', 'idmapping_unip.tsv'), delimiter = "\t")
        self.esm = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'esm_features.csv'))
        self.custom = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'custom_features.csv'))
        self.fegs = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'fegs_features.csv'))

        gae_path = f'GAE_FEATURES_WITH_PREDICTED_alpha_0.25.csv'
        self.gae = pd.read_csv(os.path.join('datasets', 'GAE', gae_path))
        gae_features_columns = self.gae.iloc[:, 9:509]
        gae_uniprot_column = self.gae[['From']].rename(columns={'From': 'UniProt_ID'})
        self.gae = pd.concat([gae_uniprot_column, gae_features_columns], axis=1)
        

        # Merge datasets
        self.esm_features_ppi = self.merge_datasets(self.data, self.esm).drop(columns=['smiles', 'label']).astype(np.float32)
        self.custom_features_ppi = self.merge_datasets(self.data, self.custom).drop(columns=['smiles', 'label']).astype(np.float32)
        self.fegs_features_ppi = self.merge_datasets(self.data, self.fegs).drop(columns=['smiles', 'label']).astype(np.float32)
        self.gae_features_ppi = self.merge_datasets(self.data, self.gae).drop(columns=['smiles', 'label']).astype(np.float32)

        # SMILES RDKit features - Morgan Fingerprints (r=4, nbits=1024) & chemical descriptors
        self.smiles_morgan_fingerprints = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_morgan_fingerprints_dataset.csv'))
        self.smiles_chemical_descriptors = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_chem_descriptors_mapping_dataset.csv'))

        # Necessary features for ChemBERTa model
        self.smiles_list = self.data['smiles'].tolist()
        self.tokenizer = RobertaTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
        self.encoded_smiles = self.tokenizer(self.smiles_list, truncation=True, padding=True, return_tensors="pt")

    def merge_datasets(self, dataset, features_df):
        # Existing merging logic
        dataset = dataset.merge(features_df, how='left', left_on='uniprot_id1', right_on='UniProt_ID', suffixes=('', '_id1'))
        dataset = dataset.drop(columns=['UniProt_ID'])
        
        features_df_renamed = features_df.add_suffix('_id2')
        features_df_renamed = features_df_renamed.rename(columns={'UniProt_ID_id2': 'UniProt_ID'})
        dataset = dataset.merge(features_df_renamed, how='left', left_on='uniprot_id2', right_on='UniProt_ID', suffixes=('', '_id2'))
        dataset = dataset.drop(columns=['UniProt_ID', 'uniprot_id1', 'uniprot_id2'])
        
        return dataset.drop_duplicates()

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        smiles = self.data.iloc[idx, 0]
        label = np.array(self.data.iloc[idx, -1], dtype=np.float32)  
        esm_features = np.array(self.esm_features_ppi.iloc[idx].values, dtype=np.float32)
        custom_features = np.array(self.custom_features_ppi.iloc[idx].values, dtype=np.float32)
        fegs_features = np.array(self.fegs_features_ppi.iloc[idx].values, dtype=np.float32)
        gae_features = np.array(self.gae_features_ppi.iloc[idx].values, dtype=np.float32)

        input_ids = self.encoded_smiles["input_ids"][idx]
        attention_mask = self.encoded_smiles["attention_mask"][idx]

        # Retrieve precomputed RDKit Morgan fingerprints
        morgan_fingerprint = self.smiles_morgan_fingerprints.loc[self.smiles_morgan_fingerprints['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        chemical_descriptors = self.smiles_chemical_descriptors.loc[self.smiles_chemical_descriptors['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        
        return (smiles, esm_features, custom_features, fegs_features, gae_features, 
                input_ids, attention_mask, morgan_fingerprint, chemical_descriptors, label)


In [15]:
class MoleculeDataset(Dataset):
    def __init__(self, ds_):
        self.data = ds_
        self.mapping_df = pd.read_csv(os.path.join('datasets', 'idmapping_unip.tsv'), delimiter="\t")
        self.esm = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'esm_features.csv'))
        self.custom = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'custom_features.csv'))
        self.fegs = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'fegs_features.csv'))

        gae_path = f'GAE_FEATURES_WITH_PREDICTED_alpha_0.25.csv'
        self.gae = pd.read_csv(os.path.join('datasets', 'GAE', gae_path))
        gae_features_columns = self.gae.iloc[:, 9:509]
        gae_uniprot_column = self.gae[['From']].rename(columns={'From': 'UniProt_ID'})
        self.gae = pd.concat([gae_uniprot_column, gae_features_columns], axis=1)

        # Merge datasets
        self.esm_features_ppi = self.merge_datasets(self.data, self.esm).drop(columns=['smiles', 'label']).astype(np.float32)
        self.custom_features_ppi = self.merge_datasets(self.data, self.custom).drop(columns=['smiles', 'label']).astype(np.float32)
        self.fegs_features_ppi = self.merge_datasets(self.data, self.fegs).drop(columns=['smiles', 'label']).astype(np.float32)
        self.gae_features_ppi = self.merge_datasets(self.data, self.gae).drop(columns=['smiles', 'label']).astype(np.float32)

        # SMILES RDKit features - Morgan Fingerprints (r=4, nbits=1024) & chemical descriptors
        self.smiles_morgan_fingerprints = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_morgan_fingerprints_dataset.csv'))
        self.smiles_chemical_descriptors = pd.read_csv(os.path.join('datasets', 'MolDatasets', 'smiles_chem_descriptors_mapping_dataset.csv'))

        # Necessary features for ChemBERTa model
        self.smiles_list = self.data['smiles'].tolist()
        self.tokenizer = RobertaTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
        self.encoded_smiles = self.tokenizer(self.smiles_list, truncation=True, padding=True, return_tensors="pt")

        # Protein structure feature extraction
        self.uniprots = ds_[['uniprot_id1', 'uniprot_id2']]

        # Preload all protein structure feature files into a dictionary
        self.protein_structure_dict = self.load_protein_structure_features()

    def merge_datasets(self, dataset, features_df):
        # Existing merging logic
        dataset = dataset.merge(features_df, how='left', left_on='uniprot_id1', right_on='UniProt_ID', suffixes=('', '_id1'))
        dataset = dataset.drop(columns=['UniProt_ID'])
        
        features_df_renamed = features_df.add_suffix('_id2')
        features_df_renamed = features_df_renamed.rename(columns={'UniProt_ID_id2': 'UniProt_ID'})
        dataset = dataset.merge(features_df_renamed, how='left', left_on='uniprot_id2', right_on='UniProt_ID', suffixes=('', '_id2'))
        dataset = dataset.drop(columns=['UniProt_ID', 'uniprot_id1', 'uniprot_id2'])
        
        return dataset.drop_duplicates()
        
    def __len__(self):
        return len(self.data)

    def load_protein_structure_features(self):
        """Preload all protein structure feature CSVs into a dictionary."""
        protein_structure_dir = os.path.join('datasets', 'MolDatasets', 'ProteinStructureFeatures')
        protein_structure_files = os.listdir(protein_structure_dir)
        protein_structure_dict = {}
        
        for file in protein_structure_files:
            uniprot_id_key = file.replace('_ifeature_omega.csv', '')
            protein_structure_dict[uniprot_id_key] = pd.read_csv(os.path.join(protein_structure_dir, file)).iloc[:, 1:].astype(np.float32)
        return protein_structure_dict

    def __getitem__(self, idx):
        smiles = self.data.iloc[idx, 0]
        label = np.array(self.data.iloc[idx, -1], dtype=np.float32)  
        esm_features = np.array(self.esm_features_ppi.iloc[idx].values, dtype=np.float32)
        custom_features = np.array(self.custom_features_ppi.iloc[idx].values, dtype=np.float32)
        fegs_features = np.array(self.fegs_features_ppi.iloc[idx].values, dtype=np.float32)
        gae_features = np.array(self.gae_features_ppi.iloc[idx].values, dtype=np.float32)

        input_ids = self.encoded_smiles["input_ids"][idx]
        attention_mask = self.encoded_smiles["attention_mask"][idx]

        # Retrieve precomputed RDKit Morgan fingerprints
        morgan_fingerprint = self.smiles_morgan_fingerprints.loc[self.smiles_morgan_fingerprints['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)
        chemical_descriptors = self.smiles_chemical_descriptors.loc[self.smiles_chemical_descriptors['SMILES'] == smiles].iloc[0, 1:].values.astype(np.float32)

        # Protein structure feature extraction
        prot1_sfp, prot2_sfp = self.uniprots.iloc[idx, 0], self.uniprots.iloc[idx, 1]
        ans = self.checkProteins(prot1_sfp, prot2_sfp)
        
        if ans:
            prot1_sf = self.protein_structure_dict[ans]
            prot2_sf = self.protein_structure_dict[f'Q03164_with_{ans}']
        else:
            prot1_sf = self.protein_structure_dict[prot1_sfp]
            prot2_sf = self.protein_structure_dict[prot2_sfp]

        return (smiles, prot1_sf, prot2_sf, esm_features, custom_features, fegs_features, gae_features, 
                input_ids, attention_mask, morgan_fingerprint, chemical_descriptors, label)

    def checkProteins(self, unip1, unip2):
        if unip1 == 'Q03164':
            return unip2
        elif unip2 == 'Q03164':
            return unip1
        return None

    def collate_fn(self, batch):
        smiles, prot1_sfs, prot2_sfs, esm_features, custom_features, fegs_features, gae_features, input_ids, attention_masks, morgan_fingerprints, chemical_descriptors, labels = zip(*batch)

        # Convert lists of numpy arrays into tensors for protein structure features
        prot1_sfs = [torch.tensor(p.values) for p in prot1_sfs]
        prot2_sfs = [torch.tensor(p.values) for p in prot2_sfs]
        
        # Find the maximum number of rows (0th dimension) for prot1 and prot2 individually
        max_len_prot1 = max([p.shape[0] for p in prot1_sfs])
        max_len_prot2 = max([p.shape[0] for p in prot2_sfs])

        # Pad the protein structure features along the 0th dimension (rows) for each protein individually
        prot1_sfs_padded = torch.stack([torch.nn.functional.pad(p, (0, 0, 0, max_len_prot1 - p.shape[0]), "constant", 0) for p in prot1_sfs])
        prot2_sfs_padded = torch.stack([torch.nn.functional.pad(p, (0, 0, 0, max_len_prot2 - p.shape[0]), "constant", 0) for p in prot2_sfs])
    
        esm_features = torch.tensor(esm_features, dtype=torch.float32)
        custom_features = torch.tensor(custom_features, dtype=torch.float32)
        fegs_features = torch.tensor(fegs_features, dtype=torch.float32)
        gae_features = torch.tensor(gae_features, dtype=torch.float32)
        
        input_ids = torch.stack(input_ids)
        attention_masks = torch.stack(attention_masks)
        
        morgan_fingerprints = torch.tensor(morgan_fingerprints, dtype=torch.float32)
        chemical_descriptors = torch.tensor(chemical_descriptors, dtype=torch.float32)
        
        # Convert labels to a flat list of scalars and then to a tensor
        flattened_labels = [label.item() for label in labels]
        labels_tensor = torch.tensor(flattened_labels, dtype=torch.float32)
        
        return (smiles, prot1_sfs_padded, prot2_sfs_padded, esm_features, custom_features, fegs_features, gae_features, 
                input_ids, attention_masks, morgan_fingerprints, chemical_descriptors, labels_tensor)

In [16]:
def generate_model(checkpoint_path,
                   batch_size,
                  dropout) -> nn.Module:
    pretrained_chemprop_model = PretrainedChempropModel(checkpoints_path, batch_size)
    chemberta_model = ChemBERTaPT()
    ft_model = AUVG_PPI(pretrained_chemprop_model, chemberta_model, dropout).to(device)

    PRINTM('Generated combined model for fine-tuning successfully !')
    return ft_model

## Train the Model on multi_ppim_fold_2_0.8 Folds ##

### Load & Prepare the Dataset ###

In [17]:
ds_folder_path = os.path.join('datasets', 'finetune_dataset', 'multi_ppim_folds_2_0.8')
all_files = os.listdir(ds_folder_path)

PRINTM(f'Folder content:\n\n{all_files}')

--------------------------------------------------------------------------------
Folder content:

['train_fold1.csv', 'train_fold3.csv', 'test_fold3.csv', 'train_fold5.csv', 'test_fold5.csv', 'test_fold2.csv', 'test_fold4.csv', 'train_fold4.csv', 'test_fold1.csv', 'train_fold2.csv']
--------------------------------------------------------------------------------


In [18]:
dataframes = {}

# Read each CSV file into a dataframe and store it in the dictionary
for file in all_files:
    file_path = os.path.join(ds_folder_path, file)
    df = pd.read_csv(file_path)
    df_name = file.replace('.csv', '_df')
    dataframes[df_name] = df

In [19]:
uniprot_mapping = pd.read_csv(os.path.join('datasets', 'idmapping_unip.tsv'), delimiter = "\t")
ppi_features_df = pd.read_csv(os.path.join('datasets', 'merged_ppi_features.csv'))
PRINT()

--------------------------------------------------------------------------------
Done
--------------------------------------------------------------------------------


In [20]:
for df_name in dataframes.keys():
    dataframes[df_name] = convert_uniprot_ids(dataframes[df_name], uniprot_mapping)
    #dataframes[df_name] =data_augmentation_with_uniprots_order_switchings(dataframes[df_name])
    #dataframes[df_name] = merge_datasets(dataframes[df_name], ppi_features_df)

# Access each dataframe using its name
train_fold1_df = dataframes['train_fold1_df']
train_fold2_df = dataframes['train_fold2_df']
train_fold3_df = dataframes['train_fold3_df']
train_fold4_df = dataframes['train_fold4_df']
train_fold5_df = dataframes['train_fold5_df']
test_fold1_df = dataframes['test_fold1_df']
test_fold2_df = dataframes['test_fold2_df']
test_fold3_df = dataframes['test_fold3_df']
test_fold4_df = dataframes['test_fold4_df']
test_fold5_df = dataframes['test_fold5_df']

PRINTM(f'Done inverse mapping !')

--------------------------------------------------------------------------------
Done inverse mapping !
--------------------------------------------------------------------------------


In [None]:
checkpoints_path = os.path.join('pt_chemprop_checkpoint_r4_', 'fold_0', 'model_0', 'checkpoints', 'best-epoch=39-val_loss=0.39.ckpt')

### Train a Model & Test on Each Fold ###

#### Fold number 1 #####

In [None]:
ft_model_f1 = generate_model(checkpoints_path, batch_size=32, dropout=0.3)
f1_res = ft_model_f1.cross_validate(train_fold1_df, num_folds=5, num_epochs=2,
                     batch_size=32, learning_rate=1e-5, weight_decay=1e-3,
                     shuffle=True, device=device)

In [39]:
#13/09 - d=0.1, k=5, n=2, lr=1e-5, wd=1e-4 --> Overfit -> increase d
f1_auc = ft_model_f1.test_model(test_fold1_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.76603
--------------------------------------------------------------------------------


In [49]:
#13/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-4 
f1_auc = ft_model_f1.test_model(test_fold1_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.78468
--------------------------------------------------------------------------------


In [60]:
#13/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-3 
f1_auc = ft_model_f1.test_model(test_fold1_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.80617
--------------------------------------------------------------------------------


In [29]:
#13/09 - d=0.3, k=5, n=2, lr=1e-5, wd=1e-3 
f1_auc = ft_model_f1.test_model(test_fold1_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.70589
--------------------------------------------------------------------------------


#### Fold number 2 ####

In [None]:
ft_model_f2 = generate_model(checkpoints_path, batch_size=32, dropout=0.3)
f2_res = ft_model_f2.cross_validate(train_fold2_df, num_folds=5, num_epochs=2,
                     batch_size=32, learning_rate=1e-5, weight_decay=1e-3,
                     shuffle=True, device=device)

In [41]:
#13/09 - d=0.1, k=5, n=2, lr=1e-5, wd=1e-4 --> Overfit -> increase d
f2_auc = ft_model_f2.test_model(test_fold2_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.62343
--------------------------------------------------------------------------------


In [51]:
#13/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-4 
f2_auc = ft_model_f2.test_model(test_fold2_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.66849
--------------------------------------------------------------------------------


In [62]:
#13/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-3
f2_auc = ft_model_f2.test_model(test_fold2_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.68760
--------------------------------------------------------------------------------


In [31]:
#13/09 - d=0.3, k=5, n=2, lr=1e-5, wd=1e-3
f2_auc = ft_model_f2.test_model(test_fold2_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.72244
--------------------------------------------------------------------------------


#### Fold number 3 ####

In [26]:
ft_model_f3 = generate_model(checkpoints_path, batch_size=32, dropout=0.3)
f3_res = ft_model_f3.cross_validate(train_fold3_df, num_folds=5, num_epochs=2,
                     batch_size=32, learning_rate=1e-5, weight_decay=1e-3,
                     shuffle=True, device=device)

Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


--------------------------------------------------------------------------------
Generated combined model for fine-tuning successfully !
--------------------------------------------------------------------------------
Fold 1/5
--------------------------------------------------------------------------------
Start training !
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Epoch: 1
Validation BCEWithLogitsLoss: 0.74370
Validation Accuracy (>0.8): 0.67
Validation AUC: 0.60483
Epoch time: 3.98 minutes
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Epoch: 2
Validation BCEWithLogitsLoss: 0.70420
Validation Accuracy (>0.8): 0.71
Validation AUC: 0.68103
Epoch time: 4.05 minutes
--------------------------------------------------------------------------------
Finish train

In [43]:
#13/09 - d=0.1, k=5, n=2, lr=1e-5, wd=1e-4 --> Overfir -> increase d
f3_auc = ft_model_f3.test_model(test_fold3_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.72224
--------------------------------------------------------------------------------


In [53]:
#13/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-4 --> Overfit
f3_auc = ft_model_f3.test_model(test_fold3_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.77971
--------------------------------------------------------------------------------


In [64]:
#13/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-3 
f3_auc = ft_model_f3.test_model(test_fold3_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.75466
--------------------------------------------------------------------------------


In [27]:
#14/09 - d=0.3, k=5, n=2, lr=1e-5, wd=1e-3 
f3_auc = ft_model_f3.test_model(test_fold3_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.68496
--------------------------------------------------------------------------------


#### Folder number 4 ####

In [None]:
ft_model_f4 = generate_model(checkpoints_path, batch_size=32, dropout=0.3)
f4_res = ft_model_f4.cross_validate(train_fold4_df, num_folds=5, num_epochs=2,
                     batch_size=32, learning_rate=1e-5, weight_decay=1e-3,
                     shuffle=True, device=device)

Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


--------------------------------------------------------------------------------
Generated combined model for fine-tuning successfully !
--------------------------------------------------------------------------------
Fold 1/5
--------------------------------------------------------------------------------
Start training !
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Epoch: 1
Validation BCEWithLogitsLoss: 0.70503
Validation Accuracy (>0.8): 0.70
Validation AUC: 0.62226
Epoch time: 3.05 minutes
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Epoch: 2
Validation BCEWithLogitsLoss: 0.67812
Validation Accuracy (>0.8): 0.71
Validation AUC: 0.67787
Epoch time: 3.04 minutes
--------------------------------------------------------------------------------
Finish train

In [45]:
#13/09 - d=0.1, k=5, n=2, lr=1e-5, wd=1e-4
f4_auc = ft_model_f4.test_model(test_fold4_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.83464
--------------------------------------------------------------------------------


In [55]:
#13/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-4
f4_auc = ft_model_f4.test_model(test_fold4_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.78949
--------------------------------------------------------------------------------


In [25]:
#13/09 - d=0.1, k=5, n=2, lr=1e-5, wd=1e-4
f4_auc = ft_model_f4.test_model(test_fold4_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.84944
--------------------------------------------------------------------------------


In [23]:
#14/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-3 --> underfit
f4_auc = ft_model_f4.test_model(test_fold4_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.74260
--------------------------------------------------------------------------------


In [None]:
#14/09 - d=0.3, k=5, n=2, lr=1e-5, wd=1e-3 --> underfit
f4_auc = ft_model_f4.test_model(test_fold4_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

#### Folder number 5 ####

In [None]:
ft_model_f5 = generate_model(checkpoints_path, batch_size=32, dropout=0.3)
f5_res = ft_model_f5.cross_validate(train_fold5_df, num_folds=5, num_epochs=2,
                     batch_size=32, learning_rate=1e-5, weight_decay=1e-3,
                     shuffle=True, device=device)

In [47]:
#13/09 - d=0.1, k=5, n=2, lr=1e-5, wd=1e-4
f5_auc = ft_model_f5.test_model(test_fold5_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.77281
--------------------------------------------------------------------------------


In [57]:
#13/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-4
f5_auc = ft_model_f5.test_model(test_fold5_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.84392
--------------------------------------------------------------------------------


In [27]:
#13/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-4
f5_auc = ft_model_f5.test_model(test_fold5_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.80041
--------------------------------------------------------------------------------


In [25]:
#14/09 - d=0.2, k=5, n=2, lr=1e-5, wd=1e-3
f5_auc = ft_model_f5.test_model(test_fold5_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.80463
--------------------------------------------------------------------------------


In [None]:
#14/09 - d=0.3, k=5, n=2, lr=1e-5, wd=1e-3
f5_auc = ft_model_f5.test_model(test_fold5_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=32,
                         shuffle=True, device=device)

## Train the Model on Finetuned Datasets (DLIP_folds_2_0.8)  ##

In [29]:
ds_folder_path = os.path.join('datasets', 'finetune_dataset', 'DLIP_folds_2_0.8')
all_files = os.listdir(ds_folder_path)

PRINTM(f'Folder content:\n\n{all_files}')

--------------------------------------------------------------------------------
Folder content:

['test_fold5.csv', 'test_fold2.csv', 'test_fold4.csv', 'test_fold3.csv', 'test_fold1.csv', 'train_fold4.csv', 'train_fold2.csv', 'train_fold3.csv', 'train_fold1.csv', 'train_fold5.csv']
--------------------------------------------------------------------------------


In [30]:
dataframes = {}

# Read each CSV file into a dataframe and store it in the dictionary
for file in all_files:
    file_path = os.path.join(ds_folder_path, file)
    df = pd.read_csv(file_path)
    df_name = file.replace('.csv', '_df')
    dataframes[df_name] = df

In [31]:
for df_name in dataframes.keys():
    dataframes[df_name] = convert_uniprot_ids(dataframes[df_name], uniprot_mapping)
    #dataframes[df_name] = merge_datasets(dataframes[df_name], ppi_features_df)

# Access each dataframe using its name
train_fold1_df = dataframes['train_fold1_df']
train_fold2_df = dataframes['train_fold2_df']
train_fold3_df = dataframes['train_fold3_df']
train_fold4_df = dataframes['train_fold4_df']
train_fold5_df = dataframes['train_fold5_df']
test_fold1_df = dataframes['test_fold1_df']
test_fold2_df = dataframes['test_fold2_df']
test_fold3_df = dataframes['test_fold3_df']
test_fold4_df = dataframes['test_fold4_df']
test_fold5_df = dataframes['test_fold5_df']

PRINTM(f'Done inverse mapping & merging successfully !')

--------------------------------------------------------------------------------
Done inverse mapping & merging successfully !
--------------------------------------------------------------------------------


#### Fold number 1 ####

In [None]:
ft_model_f1_dlip = generate_model(checkpoints_path, batch_size=64, dropout=0.1)
f1_dlip_res = ft_model_f1_dlip.cross_validate(train_fold1_df, num_folds=5, num_epochs=2,
                     batch_size=64, learning_rate=0.0001, weight_decay=1e-3,
                     shuffle=True, device=device)

In [None]:
f1_auc = ft_model_f1_dlip.test_model(test_fold1_df,
                         criterion= nn.BCEWithLogitsLoss(),batch_size=64,
                         shuffle=True, device=device)

#### Fold number 2 ####

In [None]:
ft_model_f2_dlip = generate_model(checkpoints_path, batch_size=32, dropout=0.1)
f2_dlip_res = ft_model_f2_dlip.cross_validate(train_fold2_df, num_folds=2, num_epochs=15,
                     batch_size=32, learning_rate=0.0001, weight_decay=1e-5,
                     shuffle=True, device=device)

In [24]:
ft_model_f2_dlip.test_model(test_fold2_df,
                         criterion= nn.BCELoss() ,batch_size=32,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.75522
--------------------------------------------------------------------------------


#### Fold number 3 ####

In [None]:
ft_model_f3_dlip = generate_model(checkpoints_path, batch_size=64, dropout=0.1)
f3_dlip_res = ft_model_f3_dlip.cross_validate(train_fold3_df, num_folds=5, num_epochs=2,
                     batch_size=64, learning_rate=0.0001, weight_decay=1e-3,
                     shuffle=True, device=device)

In [None]:
f3_auc = ft_model_f3_dlip.test_model(test_fold3_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=64,
                         shuffle=True, device=device)

In [39]:
f3_auc = ft_model_f3_dlip.test_model(test_fold3_df,
                         criterion= nn.BCEWithLogitsLoss() ,batch_size=64,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.75724
--------------------------------------------------------------------------------


0.7572357382550335

#### Fold number 4 ####

In [None]:
ft_model_f4_dlip = generate_model(checkpoints_path, batch_size=64, dropout=0.1)
f4_dlip_res = ft_model_f4_dlip.cross_validate(train_fold4_df, num_folds=5, num_epochs=2,
                     batch_size=64, learning_rate=0.0001, weight_decay=1e-3,
                     shuffle=True, device=device)

In [None]:
ft_model_f4_dlip.test_model(test_fold4_df,
                         criterion= nn.BCELoss() ,batch_size=64,
                         shuffle=True, device=device)

#### Fold number 5 ####

In [None]:
ft_model_f5_dlip = generate_model(checkpoints_path, batch_size=64, dropout=0.1)
f5_dlip_res = ft_model_f5_dlip.cross_validate(train_fold5_df, num_folds=5, num_epochs=2,
                     batch_size=64, learning_rate=0.0001, weight_decay=1e-3,
                     shuffle=True, device=device)

In [None]:
ft_model_f5_dlip.test_model(test_fold5_df,
                         criterion= nn.BCELoss() ,batch_size=64,
                         shuffle=True, device=device)

## Train the Model on Finetuned Datasets (DLIP_folds_2_0.9)  ##

In [21]:
ds_folder_path = os.path.join('datasets', 'finetune_dataset', 'DLIP_folds_3_0.9')
all_files = os.listdir(ds_folder_path)

PRINTM(f'Folder content:\n\n{all_files}')

--------------------------------------------------------------------------------
Folder content:

['test_fold3.csv', 'test_fold2.csv', 'test_fold1.csv', 'test_fold5.csv', 'train_fold4.csv', 'train_fold2.csv', 'train_fold3.csv', 'train_fold5.csv', 'test_fold4.csv', 'train_fold1.csv']
--------------------------------------------------------------------------------


In [22]:
dataframes = {}

# Read each CSV file into a dataframe and store it in the dictionary
for file in all_files:
    file_path = os.path.join(ds_folder_path, file)
    df = pd.read_csv(file_path)
    df_name = file.replace('.csv', '_df')
    dataframes[df_name] = df

In [23]:
for df_name in dataframes.keys():
    dataframes[df_name] = convert_uniprot_ids(dataframes[df_name], uniprot_mapping)
    dataframes[df_name] = merge_datasets(dataframes[df_name], ppi_features_df)

# Access each dataframe using its name
train_fold1_df = dataframes['train_fold1_df']
train_fold2_df = dataframes['train_fold2_df']
train_fold3_df = dataframes['train_fold3_df']
train_fold4_df = dataframes['train_fold4_df']
train_fold5_df = dataframes['train_fold5_df']
test_fold1_df = dataframes['test_fold1_df']
test_fold2_df = dataframes['test_fold2_df']
test_fold3_df = dataframes['test_fold3_df']
test_fold4_df = dataframes['test_fold4_df']
test_fold5_df = dataframes['test_fold5_df']

PRINTM(f'Done inverse mapping & merging successfully !')

--------------------------------------------------------------------------------
Done inverse mapping & merging successfully !
--------------------------------------------------------------------------------


In [24]:
test_fold5_df.head()

Unnamed: 0,smiles,label,Feature_0,Feature_1,Feature_2,Feature_3,Feature_4,Feature_5,Feature_6,Feature_7,...,Feature_1270_id2,Feature_1271_id2,Feature_1272_id2,Feature_1273_id2,Feature_1274_id2,Feature_1275_id2,Feature_1276_id2,Feature_1277_id2,Feature_1278_id2,Feature_1279_id2
0,Nc1ccc(CNC(=O)NC[C@H](NC(=O)[C@@H]2CCCN2S(=O)(...,1,0.03022,-0.068394,-0.09304,0.142344,-0.101274,-0.057618,0.093081,-0.037972,...,0.052764,0.00723,-0.032889,0.013346,-0.002336,-0.061765,0.12224,-0.107535,-0.056424,0.064282
1,O=C(O)CCNC(=O)c1ccc2c(c1)C(=O)N(CCC1CCNCC1)C2,1,0.03022,-0.068394,-0.09304,0.142344,-0.101274,-0.057618,0.093081,-0.037972,...,-0.046108,0.018945,-0.160452,0.058708,-0.045173,-0.135907,0.036617,-0.136259,-0.023708,0.164988
2,Cc1cc(C)cc(S(=O)(=O)N2CCC[C@H]2C(=O)N[C@@H](CN...,1,0.03022,-0.068394,-0.09304,0.142344,-0.101274,-0.057618,0.093081,-0.037972,...,0.052764,0.00723,-0.032889,0.013346,-0.002336,-0.061765,0.12224,-0.107535,-0.056424,0.064282
3,O=C(O)CC(NC(=O)CCC(=O)Nc1ccc2c(c1)CNC2)c1ccccc1,1,0.013775,-0.061126,-0.020618,0.032968,-0.081779,-0.046594,0.086232,-0.010491,...,-0.046108,0.018945,-0.160452,0.058708,-0.045173,-0.135907,0.036617,-0.136259,-0.023708,0.164988
4,N=C(N)NCCC[C@@H]1NC(=O)[C@H]2COCCN2C(=O)[C@@H]...,1,0.016134,-0.060721,-0.051823,0.028336,-0.062652,-0.053298,0.08515,-0.030032,...,-0.036552,-0.016408,-0.189185,0.029432,-0.061862,-0.07556,0.090893,-0.115112,-0.029084,0.098593


In [None]:
for df_name, df in dataframes.items():
    null_counts = df.isnull().sum().sum()
    PRINTM(f'Number of nan values in {df_name} is -> {null_counts}')

#### Fold number 1 #####

In [34]:
batch_size=64

In [None]:
ft_model_f1_dlip_ = generate_model(checkpoints_path, batch_size=64, dropout=0.8)
f1_dlip_res = ft_model_f1_dlip_.cross_validate(train_fold1_df, num_folds=10, num_epochs=5,
                     batch_size=64, learning_rate=0.0001, weight_decay=1e-4,
                     shuffle=True, device=device)

In [33]:
ft_model_f1_dlip_.test_model(test_fold1_df,
                         criterion= nn.BCELoss() ,batch_size=64,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.38214
--------------------------------------------------------------------------------


#### Fold number 2 #####

In [None]:
ft_model_f2_dlip_ = generate_model(checkpoints_path, batch_size=64, dropout=0.3)
f2_dlip_res_ = ft_model_f2_dlip_.cross_validate(train_fold2_df, num_folds=10, num_epochs=5,
                     batch_size=64, learning_rate=0.0001, weight_decay=1e-4,
                     shuffle=True, device=device)

In [38]:
ft_model_f2_dlip_.test_model(test_fold2_df,
                         criterion= nn.BCELoss() ,batch_size=64,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.77614
--------------------------------------------------------------------------------


#### Fold number 3 #####

In [None]:
ft_model_f3_dlip_ = generate_model(checkpoints_path, batch_size=64, dropout=0.3)
f3_dlip_res_ = ft_model_f3_dlip_.cross_validate(train_fold3_df, num_folds=10, num_epochs=5,
                     batch_size=64, learning_rate=0.0001, weight_decay=1e-4,
                     shuffle=True, device=device)

In [27]:
ft_model_f3_dlip_.test_model(test_fold3_df,
                         criterion= nn.BCELoss() ,batch_size=64,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.71449
--------------------------------------------------------------------------------


#### Fold number 4 #####

In [None]:
ft_model_f4_dlip_ = generate_model(checkpoints_path, batch_size=64, dropout=0.3)
f4_dlip_res_ = ft_model_f4_dlip_.cross_validate(train_fold4_df, num_folds=10, num_epochs=5,
                     batch_size=64, learning_rate=0.0001, weight_decay=1e-4,
                     shuffle=True, device=device)

In [29]:
ft_model_f4_dlip_.test_model(test_fold4_df,
                         criterion= nn.BCELoss() ,batch_size=64,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.84679
--------------------------------------------------------------------------------


#### Fold number 5 #####

In [None]:
ft_model_f5_dlip_ = generate_model(checkpoints_path, batch_size=64, dropout=0.3)
f5_dlip_res_ = ft_model_f5_dlip_.cross_validate(train_fold5_df, num_folds=10, num_epochs=5,
                     batch_size=64, learning_rate=0.0001, weight_decay=1e-4,
                     shuffle=True, device=device)

In [31]:
ft_model_f5_dlip_.test_model(test_fold5_df,
                         criterion= nn.BCELoss() ,batch_size=64,
                         shuffle=True, device=device)

--------------------------------------------------------------------------------
Test AUC: 0.83778
--------------------------------------------------------------------------------


In [44]:
ds_folder_path = os.path.join('datasets', 'finetune_dataset', 'original_folds PPIMI')
all_files = os.listdir(ds_folder_path)

PRINTM(f'Folder content:\n\n{all_files}')

--------------------------------------------------------------------------------
Folder content:

['test_fold1.csv', 'test_fold3.csv', 'test_fold5.csv', 'train_val_fold2.csv', 'train_val_fold3.csv', 'train_val_fold5.csv', 'test_fold4.csv', 'train_val_fold4.csv', 'test_fold2.csv', 'train_val_fold1.csv']
--------------------------------------------------------------------------------


In [45]:
dataframes = {}

# Read each CSV file into a dataframe and store it in the dictionary
for file in all_files:
    file_path = os.path.join(ds_folder_path, file)
    df = pd.read_csv(file_path)
    df_name = file.replace('.csv', '_df')
    dataframes[df_name] = df

In [46]:
for df_name, df in dataframes.items():
    # Replace 'na' with np.nan if necessary
    df.replace('na', np.nan, inplace=True)
    
    # Identify rows where 'uniprot_id2' is NaN and replace them with 'uniprot_id1' values
    df.loc[df['uniprot_id2'].isna(), 'uniprot_id2'] = df['uniprot_id1']
    
    print(f'Updated DataFrame: {df_name}')

Updated DataFrame: test_fold1_df
Updated DataFrame: test_fold3_df
Updated DataFrame: test_fold5_df
Updated DataFrame: train_val_fold2_df
Updated DataFrame: train_val_fold3_df
Updated DataFrame: train_val_fold5_df
Updated DataFrame: test_fold4_df
Updated DataFrame: train_val_fold4_df
Updated DataFrame: test_fold2_df
Updated DataFrame: train_val_fold1_df


In [49]:
test_fold5_df = dataframes['test_fold5_df']
test_fold5_df.tail()

Unnamed: 0,smiles,uniprot_id1,uniprot_id2,label
6305,COc1cnc2n1C(C)(Cc1ccc(Br)cc1)C(=O)N2c1cc(Cl)cc...,P62942,P62942,0
6306,COc1cccc2c(C(=O)C(=O)N3CCN(C(=O)c4ccccc4)CC3)c...,P62942,P62942,0
6307,COc1cc(-c2cn(C)c(=O)c3cnccc23)cc(OC)c1CN(C)C,P62942,P62942,0
6308,CC(C)c1ccccc1Sc1ccc(-c2ccnc(N3CCCC3)c2)cc1C(F)...,P62942,P62942,0
6309,CNc1cccc(CCOc2ccc(C[C@H](NC(=O)c3c(Cl)cncc3Cl)...,P62942,P62942,0


In [50]:
for df_name in dataframes.keys():
    dataframes[df_name] = convert_uniprot_ids(dataframes[df_name], uniprot_mapping)
    dataframes[df_name] = merge_datasets(dataframes[df_name], esm_df)

# Access each dataframe using its name
train_fold1_df = dataframes['train_val_fold1_df']
train_fold2_df = dataframes['train_val_fold2_df']
train_fold3_df = dataframes['train_val_fold3_df']
train_fold4_df = dataframes['train_val_fold4_df']
train_fold5_df = dataframes['train_val_fold5_df']
test_fold1_df = dataframes['test_fold1_df']
test_fold2_df = dataframes['test_fold2_df']
test_fold3_df = dataframes['test_fold3_df']
test_fold4_df = dataframes['test_fold4_df']
test_fold5_df = dataframes['test_fold5_df']

PRINTM(f'Done inverse mapping & merging successfully !')

--------------------------------------------------------------------------------
Done inverse mapping & merging successfully !
--------------------------------------------------------------------------------


In [51]:
test_fold5_df.head()

Unnamed: 0,smiles,label,Feature_0,Feature_1,Feature_2,Feature_3,Feature_4,Feature_5,Feature_6,Feature_7,...,Feature_1270_id2,Feature_1271_id2,Feature_1272_id2,Feature_1273_id2,Feature_1274_id2,Feature_1275_id2,Feature_1276_id2,Feature_1277_id2,Feature_1278_id2,Feature_1279_id2
0,Nc1ccc(CNC(=O)NC[C@H](NC(=O)[C@@H]2CCCN2S(=O)(...,1,0.03022,-0.068394,-0.09304,0.142344,-0.101274,-0.057618,0.093081,-0.037972,...,0.052764,0.00723,-0.032889,0.013346,-0.002336,-0.061765,0.12224,-0.107535,-0.056424,0.064282
1,O=C(O)CCNC(=O)c1ccc2c(c1)C(=O)N(CCC1CCNCC1)C2,1,0.03022,-0.068394,-0.09304,0.142344,-0.101274,-0.057618,0.093081,-0.037972,...,-0.046108,0.018945,-0.160452,0.058708,-0.045173,-0.135907,0.036617,-0.136259,-0.023708,0.164988
2,Cc1cc(C)cc(S(=O)(=O)N2CCC[C@H]2C(=O)N[C@@H](CN...,1,0.03022,-0.068394,-0.09304,0.142344,-0.101274,-0.057618,0.093081,-0.037972,...,0.052764,0.00723,-0.032889,0.013346,-0.002336,-0.061765,0.12224,-0.107535,-0.056424,0.064282
3,O=C(O)CC(NC(=O)CCC(=O)Nc1ccc2c(c1)CNC2)c1ccccc1,1,0.013775,-0.061126,-0.020618,0.032968,-0.081779,-0.046594,0.086232,-0.010491,...,-0.046108,0.018945,-0.160452,0.058708,-0.045173,-0.135907,0.036617,-0.136259,-0.023708,0.164988
4,N=C(N)NCCC[C@@H]1NC(=O)[C@H]2COCCN2C(=O)[C@@H]...,1,0.016134,-0.060721,-0.051823,0.028336,-0.062652,-0.053298,0.08515,-0.030032,...,-0.036552,-0.016408,-0.189185,0.029432,-0.061862,-0.07556,0.090893,-0.115112,-0.029084,0.098593


In [None]:
for df_name, df in dataframes.items():
    null_counts = df.isnull().sum().sum()
    PRINTM(f'Number of nan values in {df_name} is -> {null_counts}')

In [None]:
ft_model_f1 = generate_model(checkpoints_path, batch_size=32, dropout=0.5)
f1_res = ft_model_f1.cross_validate(train_fold1_df, num_folds=5, num_epochs=10,
                     batch_size=32, learning_rate=0.0001, weight_decay=1e-5,
                     shuffle=True, device=device)

In [None]:
ft_model_f1.test_model(test_fold1_df,
                         criterion= nn.BCELoss() ,batch_size=32,
                         shuffle=True, device=device)

In [None]:
ft_model_f2 = generate_model(checkpoints_path, batch_size=32, dropout=0.5)
f2_res = ft_model_f2.cross_validate(train_fold2_df, num_folds=5, num_epochs=10,
                     batch_size=32, learning_rate=0.0001, weight_decay=1e-5,
                     shuffle=True, device=device)

In [None]:
ft_model_f2.test_model(test_fold2_df,
                         criterion= nn.BCELoss() ,batch_size=32,
                         shuffle=True, device=device)

In [None]:
ft_model_f3 = generate_model(checkpoints_path, batch_size=32, dropout=0.5)
f3_res = ft_model_f3.cross_validate(train_fold3_df, num_folds=5, num_epochs=10,
                     batch_size=32, learning_rate=0.0001, weight_decay=1e-5,
                     shuffle=True, device=device)

In [None]:
ft_model_f3.test_model(test_fold3_df,
                         criterion= nn.BCELoss() ,batch_size=32,
                         shuffle=True, device=device)

In [None]:
ft_model_f4 = generate_model(checkpoints_path, batch_size=32, dropout=0.5)
f4_res = ft_model_f4.cross_validate(train_fold4_df, num_folds=5, num_epochs=10,
                     batch_size=32, learning_rate=0.0001, weight_decay=1e-5,
                     shuffle=True, device=device)

In [None]:
ft_model_f4.test_model(test_fold4_df,
                         criterion= nn.BCELoss() ,batch_size=32,
                         shuffle=True, device=device)

In [None]:
ft_model_f5 = generate_model(checkpoints_path, batch_size=32, dropout=0.5)
f5_res = ft_model_f5.cross_validate(train_fold5_df, num_folds=5, num_epochs=10,
                     batch_size=32, learning_rate=0.0001, weight_decay=1e-5,
                     shuffle=True, device=device)

In [None]:
ft_model_f5.test_model(test_fold5_df,
                         criterion= nn.BCELoss() ,batch_size=32,
                         shuffle=True, device=device)

# TODO -> Section to keep ; models with data augmentation and prob in testing #

In [20]:
# without attention & with prob
class AUVG_PPI(nn.Module):
    def __init__(self, pretrained_chemprop_model, chemberta_model, dropout):
        super(AUVG_PPI, self).__init__()
        self.pretrained_chemprop_model = pretrained_chemprop_model
        self.chemberta_model = chemberta_model
        self.dropout = dropout
        
        self.esm_mlp = nn.Sequential(
            nn.Linear(in_features=1280 + 1280 , out_features=1280),
            nn.ReLU(),
            nn.BatchNorm1d(1280),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1280, out_features=640),
            nn.ReLU(),
            nn.BatchNorm1d(640),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=640, out_features=320)
        )

        self.fegs_mlp = nn.Sequential(
            nn.Linear(in_features=578 + 578, out_features=578),
            nn.ReLU(),
            nn.BatchNorm1d(578),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=578, out_features=320)
        )        

        self.custom_mlp = nn.Sequential(
            nn.Linear(in_features=4700 + 4700 , out_features=4700),
            nn.ReLU(),
            nn.BatchNorm1d(4700),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=4700, out_features=2560),
            nn.ReLU(),
            nn.BatchNorm1d(2560),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=2560, out_features=1280),
            nn.ReLU(),
            nn.BatchNorm1d(1280),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1280, out_features=640),
            nn.ReLU(),
            nn.BatchNorm1d(640),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=640, out_features=320), 
        )

        self.gae_mlp = nn.Sequential(
            nn.Linear(in_features=500 + 500, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=320)
        )

        # MLP for ppi_features
        self.ppi_mlp = nn.Sequential(
            nn.Linear(in_features=320 + 320 + 320 + 320 , out_features=480),
            nn.ReLU(),
            nn.BatchNorm1d(480),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=480, out_features=300),
        )
        
        self.fp_mlp = nn.Sequential(
            nn.Linear(in_features=2100, out_features=1050),
            nn.ReLU(),
            nn.BatchNorm1d(1050),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1050, out_features=600), 
            nn.ReLU(),
            nn.BatchNorm1d(600),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=600, out_features=350)
        )

        self.mfp_mlp = nn.Sequential(
            nn.Linear(in_features=1024, out_features= 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=350)
        )

        self.smiles_mlp = nn.Sequential(
            nn.Linear(in_features=350 + 384 + 350 + 194, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=300)
        )

        # Additional layrs in order to concatinate chemprop fingerprints, chemBERTa embeddings & ppi features all together
        self.additional_layers = nn.Sequential(
            nn.Linear(in_features=300 + 300, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=128, out_features=1)
        )
        self.sigmoid = nn.Sigmoid()


    def forward(self, bmg, esm, custom, fegs, gae,
                input_ids, attention_mask,
                morgan_fingerprints, chemical_descriptors):
        # Forward pass batch mol graph through pretrained chemprop model in order to get fingerprints embeddings
        # Afterwards, pass the fingerprints through MLP layer
        cp_fingerprints = self.pretrained_chemprop_model(bmg)
        cp_fingerprints = self.fp_mlp(cp_fingerprints)

        # Forward pass ids & attention mask in through chemBERTa pretrained model in order to get embeddings
        chemberta_embeddings = self.chemberta_model(input_ids, attention_mask)
        
        # Forward pass SMILES morgan fingerprints embeddings through MLP layer, then concate them with
        # chemprop embeddings and SMILES molecular descriptors embeddings, and pass all together through another
        # MLP layer
        mfp = self.mfp_mlp(morgan_fingerprints)
        combined_smiles_features_embeddings = torch.cat([cp_fingerprints,chemberta_embeddings, mfp, chemical_descriptors], dim=1).to(device)
        smiles_embeddings = self.smiles_mlp(combined_smiles_features_embeddings)

        # Pass all PPI features  through MLP layers, and then pass them all together into another MLP layer
        #ppi_features = proteins.to(device)
        esm_embeddings = self.esm_mlp(esm)
        custom_embeddings = self.custom_mlp(custom)
        fegs_embeddings = self.fegs_mlp(fegs)
        gae_embeddings = self.gae_mlp(gae)
        combined_ppi_features_embeddings = torch.cat([esm_embeddings, custom_embeddings, fegs_embeddings, gae_embeddings], dim=1).to(device)
        ppi_features = self.ppi_mlp(combined_ppi_features_embeddings)

        # Concatinate chemprop fingerprints embeddings, chemberta embeddings and PPI embeddings together into one tensor
        # Afterwards, pass them through MLP layer and make prediction
        combined_embeddings = torch.cat([smiles_embeddings, ppi_features], dim=1).to(device)
        output = self.additional_layers(combined_embeddings)
        output = self.sigmoid(output)
        return output

    def train_model(self, num_epochs, train_loader, val_loader, optimizer, criterion, device):
        PRINTM(f'Start training !')
        for epoch in range(num_epochs):
            start_time = time.time()
            self.train()
            running_loss = 0.0
            for (batch_smiles, _, batch_esm_features, batch_custom_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc, batch_labels) in train_loader:
                # Move tensors to the configured device
                batch_attention_mas = batch_attention_mas.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_custom_features = batch_custom_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                
                optimizer.zero_grad()
                outputs = self(batch_smiles, batch_esm_features,batch_custom_features,
                               batch_fegs_features, batch_gae_features, batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc)

                #outputs = self(batch_smiles, batch_protein_features, batch_input_ids, batch_attention_mas)
                loss = criterion(outputs.squeeze(), batch_labels)
    
                loss.backward()
                optimizer.step()
    
                #running_loss += loss.item()
                #if i % 100 == 99 and i > 0:
                    #print(f"Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss / 100:.4f}")
                    #running_loss = 0.0
    
            # Validate the model on the validation set
            val_loss, val_accuracy, val_auc = self.validate_model(val_loader, criterion, device)
            end_time = time.time()
            epoch_time = (end_time - start_time) / 60
            PRINTC()
            print(f"Epoch: {epoch+1}")
            print(f"Validation BCELoss: {val_loss:.5f}")
            print(f"Validation Accuracy (>0.8): {val_accuracy:.2f}")
            print(f"Validation AUC: {val_auc:.5f}")
            print(f"Epoch time: {epoch_time:.2f} minutes")
            PRINTC()
    
        print("Finish training !")

    def test_model(self, test_dataset, criterion, batch_size, shuffle, device):
        test_dataset = MoleculeDataset(test_dataset)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)
        self.eval()
        
        test_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []
        
        # Initialize prob_df
        prob_df = pd.DataFrame(columns=['uniprot_id1', 'uniprot_id2', 'output_prob', 'label'])
        
        with torch.no_grad():
            for (batch_smiles, batch_uniprots_tuple, batch_esm_features, batch_custom_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc, batch_labels) in test_loader:
                # Move tensors to the configured device
                batch_attention_mas = batch_attention_mas.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_custom_features = batch_custom_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)
    
                outputs = self(batch_smiles, batch_esm_features, batch_custom_features,
                               batch_fegs_features, batch_gae_features, batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc)
    
                loss = criterion(outputs.squeeze(), batch_labels)
                test_loss += loss.item()
        
                # Get the output probabilities and labels
                output_probs = outputs.squeeze().cpu().numpy()
                labels = batch_labels.cpu().numpy()
        
                # Add to prob_df
                for i in range(len(batch_uniprots_tuple)):
                    uniprot_id1, uniprot_id2 = batch_uniprots_tuple[i][0], batch_uniprots_tuple[i][1]
                    prob_df = prob_df.append({
                        'uniprot_id1': uniprot_id1,
                        'uniprot_id2': uniprot_id2,
                        'output_prob': output_probs[i],
                        'label': labels[i]
                    }, ignore_index=True)
        
                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
    
        # calc average probabilities for duplicate UniProt ID pairs
        prob_df['pair'] = prob_df.apply(lambda row: tuple(sorted([row['uniprot_id1'], row['uniprot_id2']])), axis=1)
        avg_prob_df = prob_df.groupby('pair').agg({'output_prob': 'mean', 'label': 'first'}).reset_index()
    
        # extract final all_labels and all_outputs
        all_labels = avg_prob_df['label'].tolist()
        all_outputs = avg_prob_df['output_prob'].tolist()
    
        test_loss /= len(test_loader)
        accuracy = correct / total
        test_auc = roc_auc_score(all_labels, all_outputs)
        
        print(f"Test BCELoss: {test_loss:.5f}")
        print(f"Test Accuracy: {accuracy:.2f}")
        print(f"Test AUC: {test_auc:.5f}")
    
        return prob_df, all_labels, all_outputs

    def validate_model(self, val_loader, criterion, device):
        self.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []
        with torch.no_grad():
            for (batch_smiles, _, batch_esm_features, batch_custom_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc ,batch_labels) in val_loader:
                # Move tensors to the configured device
                batch_attention_mas = batch_attention_mas.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_custom_features = batch_custom_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)
    
                outputs = self(batch_smiles, batch_esm_features,batch_custom_features,
                               batch_fegs_features, batch_gae_features, batch_input_ids, batch_attention_mas,
                              batch_morgan, batch_chem_desc)
                loss = criterion(outputs.squeeze(), batch_labels)
                val_loss += loss.item()
    
                all_labels.extend(batch_labels.cpu().numpy())  
                all_outputs.extend(outputs.squeeze().cpu().numpy())  
    
                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
    
        val_loss /= len(val_loader)
        accuracy = correct / total
        val_auc = roc_auc_score(all_labels, all_outputs)  
        return val_loss, accuracy, val_auc

    def cross_validate(self, dataset,
                       num_folds=5,num_epochs=10,
                       batch_size=32,
                       learning_rate=0.0001, weight_decay=1e-5,
                       shuffle=True, device='cuda'):
        kf = KFold(n_splits=num_folds, shuffle=shuffle)
        
        fold_results = []
        
        for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
            
            print(f"Fold {fold+1}/{num_folds}")
            
            # Split dataset
            train_subset = dataset.iloc[train_idx].reset_index(drop=True)
            val_subset = dataset.iloc[val_idx].reset_index(drop=True)
            
            train_dataset = MoleculeDataset(train_subset)
            val_dataset = MoleculeDataset(val_subset)
            
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
            
            criterion = nn.BCELoss()
            optimizer = optim.Adam(self.parameters(), lr=learning_rate, weight_decay=weight_decay)
            
            self.train_model(num_epochs, train_loader, val_loader, optimizer, criterion, device)
            
            # Validate the model
            val_loss, val_accuracy, val_auc = self.validate_model(val_loader, criterion, device)
            fold_results.append((val_loss, val_accuracy, val_auc))

            PRINTC()
            print(f"Fold {fold+1} - Validation BCELoss: {val_loss:.5f}, Accuracy: {val_accuracy:.2f}, AUC: {val_auc:.5f}")
            PRINTC()
            
        avg_val_loss = sum([result[0] for result in fold_results]) / num_folds
        avg_val_accuracy = sum([result[1] for result in fold_results]) / num_folds
        avg_val_auc = sum([result[2] for result in fold_results]) / num_folds
        
        print(f"\nAverage Validation BCELoss: {avg_val_loss:.5f}")
        print(f"Average Validation Accuracy: {avg_val_accuracy:.2f}")
        print(f"Average Validation AUC: {avg_val_auc:.5f}")
        
        return fold_results

In [None]:
# With Attention - v2_1 with prob
class AUVG_PPI(nn.Module):
    def __init__(self, pretrained_chemprop_model, chemberta_model, dropout):
        
        super(AUVG_PPI, self).__init__()
        self.pretrained_chemprop_model = pretrained_chemprop_model
        self.chemberta_model = chemberta_model
        self.dropout = dropout
        self.ppi_self_attention = custom_self_attention(512, 8, 0.2)
        self.smiles_self_attention = custom_self_attention(128, 4, 0.2)
        self.cross_attention = nn.MultiheadAttention(512, 8, 0.2)
        self.max_pool = nn.MaxPool1d(2)
        
        # PPI Features MLP layers: (esm, custom, fegs, gae)
        self.esm_mlp = nn.Sequential(
            nn.Linear(in_features=1280 + 1280 , out_features=1750),
            nn.ReLU(),
            nn.BatchNorm1d(1750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1750, out_features=1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1000, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        self.fegs_mlp = nn.Sequential(
            nn.Linear(in_features=578 + 578, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )        

        self.custom_mlp = nn.Sequential(
            nn.Linear(in_features=4700 + 4700 , out_features=8000),
            nn.ReLU(),
            nn.BatchNorm1d(8000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=8000, out_features=6500),
            nn.ReLU(),
            nn.BatchNorm1d(6500),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=6500, out_features=5000),
            nn.ReLU(),
            nn.BatchNorm1d(5000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=5000, out_features=3500),
            nn.ReLU(),
            nn.BatchNorm1d(3500),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=3500, out_features=2000),
            nn.ReLU(),
            nn.BatchNorm1d(2000),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=2000, out_features=1028),
            nn.ReLU(),
            nn.BatchNorm1d(1028),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1028, out_features=512)
        )

        self.gae_mlp = nn.Sequential(
            nn.Linear(in_features=500 + 500, out_features=750),
            nn.ReLU(),
            nn.BatchNorm1d(750),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=750, out_features=512)
        )

        # MLP for ppi_features
        self.ppi_mlp = nn.Sequential(
            nn.Linear(in_features=512 * 4 , out_features= 1536),
            nn.ReLU(),
            nn.BatchNorm1d(1536),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1536, out_features=1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512)
        )
        
        self.fp_mlp = nn.Sequential(
            nn.Linear(in_features=2100, out_features=1536),
            nn.ReLU(),
            nn.BatchNorm1d(1536),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1536, out_features=1024), 
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128)
        )

        self.mfp_mlp = nn.Sequential(
            nn.Linear(in_features=1024, out_features= 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128)
        )

        self.chemberta_mlp = nn.Sequential(
            nn.Linear(in_features = 384, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128)
        )

        self.chem_descriptors_mlp = nn.Sequential(
            nn.Linear(in_features = 194, out_features=128)
        )

        # Additional layrs in order to concatinate chemprop fingerprints, chemBERTa embeddings & ppi features all together
        self.additional_layers = nn.Sequential(
            nn.Linear(in_features=256 + 256, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=64, out_features=1)
        )
        
        self.sigmoid = nn.Sigmoid()

    def forward(self, bmg, esm, custom, fegs, gae,
                input_ids, attention_mask,
                morgan_fingerprints, chemical_descriptors):
        # Forward pass batch mol graph through pretrained chemprop model in order to get fingerprints embeddings
        # Afterwards, pass the fingerprints through MLP layer
        cp_fingerprints = self.pretrained_chemprop_model(bmg)
        cp_fingerprints = self.fp_mlp(cp_fingerprints)

        chemberta_embeddings = self.chemberta_model(input_ids, attention_mask)
        chemberta_embeddings = self.chemberta_mlp(chemberta_embeddings)
        mfp = self.mfp_mlp(morgan_fingerprints)
        chemical_descriptors = self.chem_descriptors_mlp(chemical_descriptors)
        
        # Concatenate all 4 smiles embeddings along a new dimension (4x194) & pass them throw self-attention layer
        smiles_embeddings = torch.stack([cp_fingerprints,chemberta_embeddings, mfp, chemical_descriptors], dim=1).to(device)  # shape ->> (batch_size, 4, 194)
        smiles_features = self.smiles_self_attention(smiles_embeddings)
        smiles_embeddings = smiles_features.unsqueeze(1)

        # Pass all PPI features  through MLP layers, and then pass them all together into another MLP layer
        #ppi_features = proteins.to(device)
        esm_embeddings = self.esm_mlp(esm)
        custom_embeddings = self.custom_mlp(custom)
        fegs_embeddings = self.fegs_mlp(fegs)
        gae_embeddings = self.gae_mlp(gae)

        # Concatenate all 4 ppi embeddings along a new dimension (4x320) & pass them throw self-attention layer
        ppi_embeddings = torch.stack([esm_embeddings, custom_embeddings, fegs_embeddings, gae_embeddings], dim=1).to(device)  # shape ->> (batch_size, 4, 320)
        ppi_features = self.ppi_self_attention(ppi_embeddings)
        ppi_features = self.ppi_mlp(ppi_features).unsqueeze(1)

        #Cross-attention between smiles and PPI to capture the interaction relationships
        ppi_QKV = ppi_features.permute(1, 0, 2)
        smiles_QKV = smiles_embeddings.permute(1, 0, 2)
        
        smiles_att, _ = self.cross_attention(smiles_QKV, ppi_QKV, ppi_QKV)
        ppi_att, _ = self.cross_attention(ppi_QKV, smiles_QKV, smiles_QKV)

        # permute attention outputrs to match (batch_size, embed_dim, num_heads) shape
        smiles_attn_output = (0.5* smiles_att.permute(1, 2, 0)) + (0.5* smiles_embeddings.permute(0, 2, 1))  # Add (residual connection) & apply weighted residual connection 
        ppi_attn_output = (0.5* ppi_att.permute(1, 2, 0)) + (0.5* ppi_features.permute(0, 2, 1))  # Add (residual connection) & apply weighted residual connection 

        # Drop the last dim in order to get (batch_size, embed_dim) & 
        # Pass cross-attention norm outputs throw max-pool layer before passing throw MLP layers
        smiles_att = self.max_pool(smiles_attn_output.squeeze(2))
        ppi_att = self.max_pool(ppi_attn_output.squeeze(2)) 
        combined_embeddings = torch.cat([smiles_att, ppi_att], dim=1)
        output = self.additional_layers(combined_embeddings)
        
        return self.sigmoid(output)

    def train_model(self, num_epochs, train_loader, val_loader, optimizer, criterion, device):
        PRINTM(f'Start training !')
        for epoch in range(num_epochs):
            start_time = time.time()
            self.train()
            running_loss = 0.0
            for (batch_smiles, _, batch_esm_features, batch_custom_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc, batch_labels) in train_loader:
                # Move tensors to the configured device
                batch_attention_mas = batch_attention_mas.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_custom_features = batch_custom_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)

                
                optimizer.zero_grad()
                outputs = self(batch_smiles, batch_esm_features,batch_custom_features,
                               batch_fegs_features, batch_gae_features, batch_input_ids, batch_attention_mas,batch_morgan, batch_chem_desc)

                loss = criterion(outputs.squeeze(), batch_labels)    
                loss.backward()
                optimizer.step()
    
            # Validate the model on the validation set
            val_loss, val_accuracy, val_auc = self.validate_model(val_loader, criterion, device)
            end_time = time.time()
            epoch_time = (end_time - start_time) / 60
            PRINTC()
            print(f"Epoch: {epoch+1}")
            print(f"Validation BCELoss: {val_loss:.5f}")
            print(f"Validation Accuracy (>0.8): {val_accuracy:.2f}")
            print(f"Validation AUC: {val_auc:.5f}")
            print(f"Epoch time: {epoch_time:.2f} minutes")
            PRINTC()
    
        print("Finish training !")

    def test_model(self, test_dataset, criterion, batch_size, shuffle, device):
        test_dataset = MoleculeDataset(test_dataset)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)
        self.eval()
        
        test_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []
        
        # Initialize prob_df
        prob_df = pd.DataFrame(columns=['uniprot_id1', 'uniprot_id2', 'output_prob', 'label'])
        
        with torch.no_grad():
            for (batch_smiles, batch_uniprots_tuple, batch_esm_features, batch_custom_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc, batch_labels) in test_loader:
                # Move tensors to the configured device
                batch_attention_mas = batch_attention_mas.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_custom_features = batch_custom_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)
    
                outputs = self(batch_smiles, batch_esm_features, batch_custom_features,
                               batch_fegs_features, batch_gae_features, batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc)
    
                loss = criterion(outputs.squeeze(), batch_labels)
                test_loss += loss.item()
        
                # Get the output probabilities and labels
                output_probs = outputs.squeeze().cpu().numpy()
                labels = batch_labels.cpu().numpy()
        
                # Add to prob_df
                for i in range(len(batch_uniprots_tuple)):
                    uniprot_id1, uniprot_id2 = batch_uniprots_tuple[i][0], batch_uniprots_tuple[i][1]
                    prob_df = prob_df.append({
                        'uniprot_id1': uniprot_id1,
                        'uniprot_id2': uniprot_id2,
                        'output_prob': output_probs[i],
                        'label': labels[i]
                    }, ignore_index=True)
        
                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
    
        # calc average probabilities for duplicate UniProt ID pairs (ensures only one pair is taken into consideration)
        prob_df['pair'] = prob_df.apply(lambda row: tuple(sorted([row['uniprot_id1'], row['uniprot_id2']])), axis=1)
        avg_prob_df = prob_df.groupby('pair').agg({'output_prob': 'mean', 'label': 'first'}).reset_index()
    
        # extract final all_labels and all_outputs
        all_labels = avg_prob_df['label'].tolist()
        all_outputs = avg_prob_df['output_prob'].tolist()
    
        test_loss /= len(test_loader)
        accuracy = correct / total
        test_auc = roc_auc_score(all_labels, all_outputs)
        
        print(f"Test BCELoss: {test_loss:.5f}")
        print(f"Test Accuracy: {accuracy:.2f}")
        print(f"Test AUC: {test_auc:.5f}")
    
        return prob_df, test_auc

    def validate_model(self, val_loader, criterion, device):
        self.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_outputs = []
        with torch.no_grad():
            for (batch_smiles, _, batch_esm_features, batch_custom_features, batch_fegs_features, batch_gae_features,
                 batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc ,batch_labels) in val_loader:
                # Move tensors to the configured device
                batch_attention_mas = batch_attention_mas.to(device)
                batch_input_ids = batch_input_ids.to(device)
                batch_esm_features = batch_esm_features.to(device)
                batch_custom_features = batch_custom_features.to(device)
                batch_fegs_features = batch_fegs_features.to(device)
                batch_gae_features = batch_gae_features.to(device)
                batch_morgan = batch_morgan.to(device)
                batch_chem_desc = batch_chem_desc.to(device)
                batch_labels = batch_labels.to(device)
    
                outputs = self(batch_smiles, batch_esm_features,batch_custom_features,
                               batch_fegs_features, batch_gae_features, batch_input_ids, batch_attention_mas, batch_morgan, batch_chem_desc)
                loss = criterion(outputs.squeeze(), batch_labels)
                val_loss += loss.item()
    
                all_labels.extend(batch_labels.cpu().numpy())  
                all_outputs.extend(outputs.squeeze().cpu().numpy())  
    
                predicted = (outputs.squeeze() > 0.8).float()
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
    
        val_loss /= len(val_loader)
        accuracy = correct / total
        val_auc = roc_auc_score(all_labels, all_outputs)  
        return val_loss, accuracy, val_auc

    def cross_validate(self, dataset, num_folds=5,num_epochs=10, batch_size=32, learning_rate=0.0001, weight_decay=1e-5, shuffle=True, device='cuda'):
        kf = KFold(n_splits=num_folds, shuffle=shuffle)
        
        fold_results = []
        
        for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
            
            print(f"Fold {fold+1}/{num_folds}")
            
            # Split dataset
            train_subset = dataset.iloc[train_idx].reset_index(drop=True)
            val_subset = dataset.iloc[val_idx].reset_index(drop=True)
            
            train_dataset = MoleculeDataset(train_subset)
            val_dataset = MoleculeDataset(val_subset)
            
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
            
            criterion = nn.BCELoss()
            optimizer = optim.Adam(self.parameters(), lr=learning_rate, weight_decay=weight_decay)
            
            self.train_model(num_epochs, train_loader, val_loader, optimizer, criterion, device)
            
            # Validate the model
            val_loss, val_accuracy, val_auc = self.validate_model(val_loader, criterion, device)
            fold_results.append((val_loss, val_accuracy, val_auc))

            PRINTC()
            print(f"Fold {fold+1} - Validation BCELoss: {val_loss:.5f}, Accuracy: {val_accuracy:.2f}, AUC: {val_auc:.5f}")
            PRINTC()
            
        avg_val_loss = sum([result[0] for result in fold_results]) / num_folds
        avg_val_accuracy = sum([result[1] for result in fold_results]) / num_folds
        avg_val_auc = sum([result[2] for result in fold_results]) / num_folds
        
        print(f"\nAverage Validation BCELoss: {avg_val_loss:.5f}")
        print(f"Average Validation Accuracy: {avg_val_accuracy:.2f}")
        print(f"Average Validation AUC: {avg_val_auc:.5f}")
        
        return fold_results