In [18]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import MultiLabelBinarizer
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

Using device: cuda


In [19]:
# Load the main dataset
data_path = "scope_onside_common_v3.parquet"
main_df = pd.read_parquet(data_path)

print(f"Main dataset shape: {main_df.shape}")
print(f"Columns: {main_df.columns.tolist()}")
print("\nFirst few rows:")
print(main_df.head())

# Get unique counts
n_unique_drugs = main_df['drug_id'].nunique() if 'drug_id' in main_df.columns else main_df.iloc[:, 0].nunique()
n_unique_proteins = main_df['protein_id'].nunique() if 'protein_id' in main_df.columns else main_df.iloc[:, 1].nunique()

print(f"\nUnique drugs: {n_unique_drugs}")
print(f"Unique proteins: {n_unique_proteins}")
print(f"Total interactions: {len(main_df)}")

Main dataset shape: (34741, 7)
Columns: ['drug_chembl_id', 'target_uniprot_id', 'label', 'smiles', 'sequence', 'molfile_3d', 'rxcui']

First few rows:
  drug_chembl_id target_uniprot_id  label  \
0     CHEMBL1000            O15245      0   
1     CHEMBL1000            P08183      1   
2     CHEMBL1000            P35367      1   
3     CHEMBL1000            Q02763      0   
4     CHEMBL1000            Q12809      0   

                                        smiles  \
0  O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1   
1  O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1   
2  O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1   
3  O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1   
4  O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1   

                                            sequence  \
0  MPTVDDILEQVGESGWFQKQAFLILCLLSAAFAPICVGIVFLGFTP...   
1  MDLEGDRNGGAKKKNFFKLNNKSEKDKKEKKPTVSVFSMFRYSNWL...   
2  MSLPNSSCLLEDKMCEGNKTTMASPQLMPLVVVLSTICLVTVGLNL...   
3  MDSLASLVLCGVSLLLSGTVEGAMDLILINSLPLVSDAETSLTCIA... 

In [20]:
# CONFIGURATION - Select which embeddings to use
DRUG_ENCODING = 'chemberta'    # Options: 'smiles2vec', 'chemberta'
PROTEIN_ENCODING = 'esm'        # Options: 'esm' (add more options here when available)
USE_3D_FEATURES = True          # Whether to use 3D features (EGNN for drugs, GVP-GNN for proteins)

print(f"üî¨ EXPERIMENT CONFIGURATION:")
print(f"   Drug encoding: {DRUG_ENCODING}")
print(f"   Protein encoding: {PROTEIN_ENCODING}")
print(f"   3D features: {'Enabled' if USE_3D_FEATURES else 'Disabled'}")
print(f"   ADR encoding: TF-IDF (fixed)")

# Available embedding files
EMBEDDING_PATHS = {
    'drug': {
        'smiles2vec': "2. Drug_embeddings/smiles_embeddings_smiles2vec.parquet",
        'chemberta': "2. Drug_embeddings/smiles_embeddings_chemberta.parquet",
    },
    'protein': {
        'esm': "3. Protein_enbeddings/ESM_embeddings_(t33_650m model).parquet",
    },
    '3d_drug': "2. Drug_embeddings/EGNN_drug_embeddings_v2.parquet",
    '3d_protein': "3. Protein_enbeddings/GVP-GNN_protein_embeddings.parquet",
    'adr': "1. Adr_embeddings/TFIDF_ADR_vectors/"
}

# Load selected embeddings
print(f"\nüìÅ Loading embeddings...")

# 1. Load selected drug embeddings
drug_embeddings_df = pd.read_parquet(EMBEDDING_PATHS['drug'][DRUG_ENCODING])
print(f"Drug {DRUG_ENCODING} embeddings loaded: {drug_embeddings_df.shape}")

# 2. Load selected protein embeddings  
protein_embeddings_df = pd.read_parquet(EMBEDDING_PATHS['protein'][PROTEIN_ENCODING])
print(f"Protein {PROTEIN_ENCODING} embeddings loaded: {protein_embeddings_df.shape}")

# 3. Load 3D embeddings if enabled
if USE_3D_FEATURES:
    egnn_drug_df = pd.read_parquet(EMBEDDING_PATHS['3d_drug'])
    gvp_protein_df = pd.read_parquet(EMBEDDING_PATHS['3d_protein'])
    print(f"EGNN Drug 3D embeddings loaded: {egnn_drug_df.shape}")
    print(f"GVP-GNN Protein 3D embeddings loaded: {gvp_protein_df.shape}")
else:
    egnn_drug_df = None
    gvp_protein_df = None
    print("3D features disabled")

# 4. Load ADR TF-IDF data
print("\nüíä Loading ADR TF-IDF data...")

# Load all three splits 
adr_train_df = pd.read_parquet(f"{EMBEDDING_PATHS['adr']}train/tfidf_wide.parquet")
adr_val_df = pd.read_parquet(f"{EMBEDDING_PATHS['adr']}val/tfidf_wide.parquet") 
adr_test_df = pd.read_parquet(f"{EMBEDDING_PATHS['adr']}test/tfidf_wide.parquet")

print(f"ADR TF-IDF train: {adr_train_df.shape}")
print(f"ADR TF-IDF val: {adr_val_df.shape}")
print(f"ADR TF-IDF test: {adr_test_df.shape}")

# Load global stats
import json
with open(f"{EMBEDDING_PATHS['adr']}global_stats.json", 'r') as f:
    adr_stats = json.load(f)

print(f"ADR dimensions: {adr_stats['n_adrs_kept']} (from {adr_stats['n_adrs_original']} original)")

# Combine all splits 
adr_embeddings_df = pd.concat([adr_train_df, adr_val_df, adr_test_df], ignore_index=True)
print(f"Combined ADR data: {adr_embeddings_df.shape}")

# 5. Calculate embedding dimensions
print(f"\nüìè Calculating embedding dimensions...")

# Get base dimensions
sample_drug_emb = drug_embeddings_df['embedding'].iloc[0]
sample_protein_emb = protein_embeddings_df['embedding'].iloc[0]

BASE_DRUG_DIM = len(sample_drug_emb)
BASE_PROTEIN_DIM = len(sample_protein_emb)
ADR_EMBEDDING_DIM = adr_stats['n_adrs_kept']

# 3D dimensions
if USE_3D_FEATURES:
    sample_egnn_drug = egnn_drug_df['embedding'].iloc[0]
    sample_gvp_protein = gvp_protein_df['embedding'].iloc[0]
    EGNN_DRUG_3D_DIM = len(sample_egnn_drug)
    GVP_PROTEIN_3D_DIM = len(sample_gvp_protein)
else:
    EGNN_DRUG_3D_DIM = 0
    GVP_PROTEIN_3D_DIM = 0

# Total dimensions
DRUG_EMBEDDING_DIM = BASE_DRUG_DIM + EGNN_DRUG_3D_DIM
PROTEIN_EMBEDDING_DIM = BASE_PROTEIN_DIM + GVP_PROTEIN_3D_DIM
SHARED_DIM = 512

print(f"\nüìä EMBEDDING DIMENSIONS:")
print(f"Drug ({DRUG_ENCODING}): {BASE_DRUG_DIM} + {EGNN_DRUG_3D_DIM} (3D) = {DRUG_EMBEDDING_DIM}")
print(f"Protein ({PROTEIN_ENCODING}): {BASE_PROTEIN_DIM} + {GVP_PROTEIN_3D_DIM} (3D) = {PROTEIN_EMBEDDING_DIM}")
print(f"ADR (TF-IDF): {ADR_EMBEDDING_DIM}")
print(f"Shared space: {SHARED_DIM}")

# ID column mapping
ID_COLUMNS = {
    'drug': 'drug_chembl_id',
    'protein': 'id' if PROTEIN_ENCODING == 'esm' else 'target_uniprot_id',
    '3d_drug': 'drug_chembl_id',
    '3d_protein': 'uniprot_id'
}

üî¨ EXPERIMENT CONFIGURATION:
   Drug encoding: chemberta
   Protein encoding: esm
   3D features: Enabled
   ADR encoding: TF-IDF (fixed)

üìÅ Loading embeddings...
Drug chemberta embeddings loaded: (1028, 5)
Protein esm embeddings loaded: (2385, 4)
EGNN Drug 3D embeddings loaded: (1028, 3)
GVP-GNN Protein 3D embeddings loaded: (2385, 8)

üíä Loading ADR TF-IDF data...
EGNN Drug 3D embeddings loaded: (1028, 3)
GVP-GNN Protein 3D embeddings loaded: (2385, 8)

üíä Loading ADR TF-IDF data...
ADR TF-IDF train: (719, 4049)
ADR TF-IDF val: (154, 4049)
ADR TF-IDF test: (155, 4049)
ADR dimensions: 4048 (from 4817 original)
Combined ADR data: (1028, 4049)

üìè Calculating embedding dimensions...

üìä EMBEDDING DIMENSIONS:
Drug (chemberta): 384 + 256 (3D) = 640
Protein (esm): 1280 + 1024 (3D) = 2304
ADR (TF-IDF): 4048
Shared space: 512
ADR TF-IDF train: (719, 4049)
ADR TF-IDF val: (154, 4049)
ADR TF-IDF test: (155, 4049)
ADR dimensions: 4048 (from 4817 original)
Combined ADR data: (1028, 

In [21]:
# DATA PREPARATION AND COMBINATION
print(f"\nüîÑ Preparing data for training...")

# Create ID mappings
drug_ids = drug_embeddings_df[ID_COLUMNS['drug']].values
protein_ids = protein_embeddings_df[ID_COLUMNS['protein']].values

drug_id_to_idx = {drug_id: idx for idx, drug_id in enumerate(drug_ids)}
protein_id_to_idx = {protein_id: idx for idx, protein_id in enumerate(protein_ids)}

print(f"Available embeddings: {len(drug_ids)} drugs, {len(protein_ids)} proteins")

# Extract embedding matrices
drug_embedding_matrix = np.vstack(drug_embeddings_df['embedding'].values).astype(np.float32)
protein_embedding_matrix = np.vstack(protein_embeddings_df['embedding'].values).astype(np.float32)
adr_embedding_matrix = adr_embeddings_df.iloc[:, 1:].values.astype(np.float32)

print(f"Base embeddings extracted:")
print(f"  Drug: {drug_embedding_matrix.shape}")
print(f"  Protein: {protein_embedding_matrix.shape}")
print(f"  ADR: {adr_embedding_matrix.shape}")

# Add 3D features if enabled
if USE_3D_FEATURES:
    print("\nüß¨ Adding 3D features...")
    
    # Drug 3D features (EGNN)
    drug_3d_dict = {}
    for _, row in egnn_drug_df.iterrows():
        drug_id = row[ID_COLUMNS['3d_drug']]
        embedding = np.array(row['embedding'], dtype=np.float32)
        drug_3d_dict[drug_id] = embedding
    
    drug_3d_matrix = []
    drug_3d_success = 0
    for drug_id in drug_ids:
        if drug_id in drug_3d_dict:
            drug_3d_matrix.append(drug_3d_dict[drug_id])
            drug_3d_success += 1
        else:
            drug_3d_matrix.append(np.zeros(EGNN_DRUG_3D_DIM, dtype=np.float32))
    drug_3d_matrix = np.array(drug_3d_matrix)
    
    # Protein 3D features (GVP-GNN)
    protein_3d_dict = {}
    for _, row in gvp_protein_df.iterrows():
        protein_id = row[ID_COLUMNS['3d_protein']]
        embedding = np.array(row['embedding'], dtype=np.float32)
        protein_3d_dict[protein_id] = embedding
    
    protein_3d_matrix = []
    protein_3d_success = 0
    for protein_id in protein_ids:
        if protein_id in protein_3d_dict:
            protein_3d_matrix.append(protein_3d_dict[protein_id])
            protein_3d_success += 1
        else:
            protein_3d_matrix.append(np.zeros(GVP_PROTEIN_3D_DIM, dtype=np.float32))
    protein_3d_matrix = np.array(protein_3d_matrix)
    
    # Combine with base embeddings
    drug_embedding_matrix = np.concatenate([drug_embedding_matrix, drug_3d_matrix], axis=1)
    protein_embedding_matrix = np.concatenate([protein_embedding_matrix, protein_3d_matrix], axis=1)
    
    print(f"  3D Drug success: {drug_3d_success}/{len(drug_ids)} ({100*drug_3d_success/len(drug_ids):.1f}%)")
    print(f"  3D Protein success: {protein_3d_success}/{len(protein_ids)} ({100*protein_3d_success/len(protein_ids):.1f}%)")
    print(f"  Enhanced drug embeddings: {drug_embedding_matrix.shape}")
    print(f"  Enhanced protein embeddings: {protein_embedding_matrix.shape}")

# Filter main dataset to only include available embeddings
main_df_filtered = main_df[
    (main_df['drug_chembl_id'].isin(drug_ids)) & 
    (main_df['target_uniprot_id'].isin(protein_ids))
].copy()

print(f"\nFiltered dataset: {main_df_filtered.shape}")

# Map IDs to indices
main_df_filtered['drug_idx'] = main_df_filtered['drug_chembl_id'].map(drug_id_to_idx)
main_df_filtered['protein_idx'] = main_df_filtered['target_uniprot_id'].map(protein_id_to_idx)

# Get corresponding embeddings for each sample
sample_drug_embeddings = drug_embedding_matrix[main_df_filtered['drug_idx'].values]
sample_protein_embeddings = protein_embedding_matrix[main_df_filtered['protein_idx'].values]

# Map to ADR embeddings using rxcui
adr_drug_ids = adr_embeddings_df['rxcui'].values
adr_drug_to_idx = {drug_id: idx for idx, drug_id in enumerate(adr_drug_ids)}
main_df_filtered['adr_idx'] = main_df_filtered['rxcui'].map(adr_drug_to_idx)

# Filter out samples without ADR mapping
valid_mask = main_df_filtered['adr_idx'].notna()
main_df_filtered = main_df_filtered[valid_mask].copy()
sample_drug_embeddings = sample_drug_embeddings[valid_mask]
sample_protein_embeddings = sample_protein_embeddings[valid_mask]

sample_adr_embeddings = adr_embedding_matrix[main_df_filtered['adr_idx'].values.astype(int)]

print(f"Final dataset after ADR filtering: {len(main_df_filtered):,} samples")
print(f"Sample embeddings:")
print(f"  Drug: {sample_drug_embeddings.shape}")
print(f"  Protein: {sample_protein_embeddings.shape}")
print(f"  ADR: {sample_adr_embeddings.shape}")

# Prepare labels
dti_labels = main_df_filtered['label'].values.astype(np.float32)

# ADR labels with adaptive threshold
adr_values = sample_adr_embeddings.flatten()
adr_nonzero = adr_values[adr_values > 0]
adr_threshold = np.percentile(adr_nonzero, 80) if len(adr_nonzero) > 0 else 0.1
adr_labels = (sample_adr_embeddings > adr_threshold).astype(np.float32)

dti_positive_rate = dti_labels.mean()
adr_avg_labels = adr_labels.sum(axis=1).mean()

print(f"\nLabel statistics:")
print(f"  DTI positive rate: {dti_positive_rate:.3f}")
print(f"  ADR threshold: {adr_threshold:.4f}")
print(f"  Average ADR labels per sample: {adr_avg_labels:.2f}")

print(f"\n‚úÖ Data preparation complete!")
print(f"Ready for model training with {DRUG_ENCODING} + {PROTEIN_ENCODING} + {'3D' if USE_3D_FEATURES else 'no3D'}")


üîÑ Preparing data for training...
Available embeddings: 1028 drugs, 2385 proteins
Base embeddings extracted:
  Drug: (1028, 384)
  Protein: (2385, 1280)
  ADR: (1028, 4048)

üß¨ Adding 3D features...
  3D Drug success: 1028/1028 (100.0%)
  3D Protein success: 2385/2385 (100.0%)
  Enhanced drug embeddings: (1028, 640)
  Enhanced protein embeddings: (2385, 2304)

Filtered dataset: (34741, 7)
  3D Drug success: 1028/1028 (100.0%)
  3D Protein success: 2385/2385 (100.0%)
  Enhanced drug embeddings: (1028, 640)
  Enhanced protein embeddings: (2385, 2304)

Filtered dataset: (34741, 7)
Final dataset after ADR filtering: 34,741 samples
Sample embeddings:
  Drug: (34741, 640)
  Protein: (34741, 2304)
  ADR: (34741, 4048)
Final dataset after ADR filtering: 34,741 samples
Sample embeddings:
  Drug: (34741, 640)
  Protein: (34741, 2304)
  ADR: (34741, 4048)

Label statistics:
  DTI positive rate: 0.352
  ADR threshold: 0.1245
  Average ADR labels per sample: 17.49

‚úÖ Data preparation complet

In [22]:
# HOW TO ADD YOUR NEW ENCODINGS AND RUN EXPERIMENTS

print("üìù TO ADD YOUR NEW ENCODINGS:")
print("="*50)
print("1. Add your new encoding files to the available_embeddings dictionary:")
print("   embedding_config.available_embeddings['drug']['your_new_encoding'] = 'path/to/your/file.parquet'")
print("   embedding_config.available_embeddings['protein']['your_new_encoding'] = 'path/to/your/file.parquet'")
print()
print("2. Example of adding new encodings:")

# Example: Add new encodings (uncomment and modify paths when you have them)
# embedding_config.available_embeddings['drug']['morgan_fp'] = "path/to/morgan_fingerprints.parquet"
# embedding_config.available_embeddings['drug']['desc_2d'] = "path/to/2d_descriptors.parquet"
# embedding_config.available_embeddings['protein']['unirep'] = "path/to/unirep_embeddings.parquet"
# embedding_config.available_embeddings['protein']['protbert'] = "path/to/protbert_embeddings.parquet"

print("# Add new drug encodings")
print("embedding_config.available_embeddings['drug']['morgan_fp'] = 'path/to/morgan_fingerprints.parquet'")
print("embedding_config.available_embeddings['drug']['desc_2d'] = 'path/to/2d_descriptors.parquet'")
print()
print("# Add new protein encodings")  
print("embedding_config.available_embeddings['protein']['unirep'] = 'path/to/unirep_embeddings.parquet'")
print("embedding_config.available_embeddings['protein']['protbert'] = 'path/to/protbert_embeddings.parquet'")
print()

print("üöÄ READY TO RUN EXPERIMENTS!")
print("="*50)
print("Now you can run experiments with any combination:")
print()

print("# Example 1: Run current best combination")
print("result1 = run_complete_experiment('smiles2vec', 'esm', use_3d=True, num_epochs=20)")
print()

print("# Example 2: Test without 3D features")  
print("result2 = run_complete_experiment('smiles2vec', 'esm', use_3d=False, num_epochs=20)")
print()

print("# Example 3: When you add new encodings, test them")
print("# result3 = run_complete_experiment('morgan_fp', 'esm', use_3d=True, num_epochs=20)")
print("# result4 = run_complete_experiment('smiles2vec', 'protbert', use_3d=True, num_epochs=20)")
print()

print("üîç COMPARE RESULTS:")
print("="*50)
print("# After running multiple experiments, compare them:")
print("results_files = ['results_file1.json', 'results_file2.json', 'results_file3.json']")
print("comparison = create_comparison_report(results_files)")
print()

print("üìä JSON RESULTS STRUCTURE:")
print("="*50)
print("Each experiment saves a JSON file with:")
print("- config: experiment configuration")
print("- train_history: metrics for each training epoch")
print("- val_history: validation metrics for each epoch")  
print("- test_results: final test set performance")
print("- best_epoch: epoch with best validation performance")
print("- best_val_accuracy: best validation accuracy achieved")
print()

print("‚ú® KEY BENEFITS:")
print("="*50)
print("‚úÖ Modular: Easy to swap any encoding combination")
print("‚úÖ Automatic tracking: All metrics saved automatically")
print("‚úÖ Detailed epochs: Accuracy, AUC, F1, Precision, Recall per epoch")
print("‚úÖ Easy comparison: JSON files can be loaded and compared")
print("‚úÖ Reproducible: All configurations saved with results")
print("‚úÖ Extensible: Easy to add new encodings and metrics")
print()

print("üéØ NEXT STEPS:")
print("="*50) 
print("1. Add paths to your new encoding files in the available_embeddings dictionary")
print("2. Run experiments: result = run_complete_experiment('encoding1', 'encoding2', use_3d=True)")
print("3. Compare results using the generated JSON files")
print("4. Identify the best encoding combination for your task")
print("5. Use the best model for deployment!")

# Run a quick test to make sure everything works
print("\nüß™ TESTING MODULAR SYSTEM...")
try:
    # Test configuration creation
    test_config = embedding_config.create_experiment_config('smiles2vec', 'esm', use_3d=True)
    print(f"‚úÖ Configuration system working: {test_config['experiment_name']}")
    
    # Test results tracker
    test_tracker = ResultsTracker(test_config)
    print("‚úÖ Results tracking system working")
    
    print("üéâ ALL SYSTEMS READY! You can now run experiments with different encoding combinations.")
    
except Exception as e:
    print(f"‚ùå Error in system test: {e}")
    print("Please check the setup and try again.")

üìù TO ADD YOUR NEW ENCODINGS:
1. Add your new encoding files to the available_embeddings dictionary:
   embedding_config.available_embeddings['drug']['your_new_encoding'] = 'path/to/your/file.parquet'
   embedding_config.available_embeddings['protein']['your_new_encoding'] = 'path/to/your/file.parquet'

2. Example of adding new encodings:
# Add new drug encodings
embedding_config.available_embeddings['drug']['morgan_fp'] = 'path/to/morgan_fingerprints.parquet'
embedding_config.available_embeddings['drug']['desc_2d'] = 'path/to/2d_descriptors.parquet'

# Add new protein encodings
embedding_config.available_embeddings['protein']['unirep'] = 'path/to/unirep_embeddings.parquet'
embedding_config.available_embeddings['protein']['protbert'] = 'path/to/protbert_embeddings.parquet'

üöÄ READY TO RUN EXPERIMENTS!
Now you can run experiments with any combination:

# Example 1: Run current best combination
result1 = run_complete_experiment('smiles2vec', 'esm', use_3d=True, num_epochs=20)

# Exa

In [23]:
# Check the structure of new 3D embeddings - EGNN and GVP-GNN
print("=== CHECKING NEW 3D EMBEDDING FILES ===")

# Check EGNN drug embeddings
try:
    egnn_drug_df = pd.read_parquet("EGNN_drug_embeddings_v2.parquet")
    print(f"EGNN Drug embeddings shape: {egnn_drug_df.shape}")
    print(f"EGNN Drug columns: {egnn_drug_df.columns.tolist()}")
    print(f"EGNN Drug sample:")
    print(egnn_drug_df.head(2))
    
    # Check embedding dimension
    if 'embedding' in egnn_drug_df.columns:
        sample_egnn = egnn_drug_df['embedding'].iloc[0]
        print(f"EGNN embedding dimension: {len(sample_egnn) if hasattr(sample_egnn, '__len__') else 'scalar'}")
    elif len(egnn_drug_df.columns) > 1:
        # If embeddings are in separate columns
        embedding_cols = [col for col in egnn_drug_df.columns if col not in ['drug_chembl_id', 'drug_id']]
        print(f"EGNN embedding dimension (from columns): {len(embedding_cols)}")
    
except Exception as e:
    print(f"Error loading EGNN drug embeddings: {e}")

print("\n" + "="*50)

# Check GVP-GNN protein embeddings  
try:
    gvp_protein_df = pd.read_parquet("GVP-GNN_protein_embeddings.parquet")
    print(f"GVP-GNN Protein embeddings shape: {gvp_protein_df.shape}")
    print(f"GVP-GNN Protein columns: {gvp_protein_df.columns.tolist()}")
    print(f"GVP-GNN Protein sample:")
    print(gvp_protein_df.head(2))
    
    # Check embedding dimension
    if 'embedding' in gvp_protein_df.columns:
        sample_gvp = gvp_protein_df['embedding'].iloc[0]
        print(f"GVP-GNN embedding dimension: {len(sample_gvp) if hasattr(sample_gvp, '__len__') else 'scalar'}")
    elif len(gvp_protein_df.columns) > 1:
        # If embeddings are in separate columns
        embedding_cols = [col for col in gvp_protein_df.columns if col not in ['target_uniprot_id', 'protein_id']]
        print(f"GVP-GNN embedding dimension (from columns): {len(embedding_cols)}")
        
except Exception as e:
    print(f"Error loading GVP-GNN protein embeddings: {e}")

print("\n" + "="*50)

=== CHECKING NEW 3D EMBEDDING FILES ===
Error loading EGNN drug embeddings: [Errno 2] No such file or directory: 'EGNN_drug_embeddings_v2.parquet'

Error loading GVP-GNN protein embeddings: [Errno 2] No such file or directory: 'GVP-GNN_protein_embeddings.parquet'



In [24]:
# MODEL ARCHITECTURE
class ProjectionHead(nn.Module):
    """Maps embeddings to shared latent space"""
    
    def __init__(self, input_dim, output_dim, dropout=0.2):
        super(ProjectionHead, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(input_dim, output_dim * 2),
            nn.LayerNorm(output_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(output_dim * 2, output_dim),
            nn.LayerNorm(output_dim)
        )
        
    def forward(self, x):
        return F.normalize(self.projection(x), dim=1)

class DTIHead(nn.Module):
    """Drug-Target Interaction prediction (binary classification)"""
    
    def __init__(self, input_dim, dropout=0.3):
        super(DTIHead, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
    def forward(self, drug_emb, protein_emb):
        combined = torch.cat([drug_emb, protein_emb], dim=1)
        return self.classifier(combined)

class ADRHead(nn.Module):
    """Adverse Drug Reaction prediction (multi-label classification)"""
    
    def __init__(self, input_dim, num_adr_labels, dropout=0.3):
        super(ADRHead, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_adr_labels),
            nn.Sigmoid()
        )
        
    def forward(self, drug_emb, protein_emb=None):
        if protein_emb is not None:
            input_features = torch.cat([drug_emb, protein_emb], dim=1)
        else:
            input_features = drug_emb
        return self.classifier(input_features)

class ContrastiveLoss(nn.Module):
    """InfoNCE contrastive loss for cross-modal alignment"""
    
    def __init__(self, temperature=0.1):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        
    def forward(self, embeddings1, embeddings2):
        batch_size = embeddings1.size(0)
        sim_matrix = torch.matmul(embeddings1, embeddings2.T) / self.temperature
        labels = torch.arange(batch_size).to(embeddings1.device)
        return F.cross_entropy(sim_matrix, labels)

class MultimodalDTIADRModel(nn.Module):
    """Complete multimodal DTI-ADR model"""
    
    def __init__(self, drug_dim, protein_dim, adr_dim, shared_dim, num_adr_labels):
        super(MultimodalDTIADRModel, self).__init__()
        
        # Projection heads
        self.drug_projection = ProjectionHead(drug_dim, shared_dim)
        self.protein_projection = ProjectionHead(protein_dim, shared_dim)
        self.adr_projection = ProjectionHead(adr_dim, shared_dim)
        
        # Task heads
        self.dti_head = DTIHead(shared_dim * 2)
        self.adr_head = ADRHead(shared_dim * 2, num_adr_labels)
        
        # Contrastive loss
        self.contrastive_loss = ContrastiveLoss()
        
    def forward(self, drug_emb, protein_emb, adr_emb, mode='train'):
        # Project to shared space
        drug_shared = self.drug_projection(drug_emb)
        protein_shared = self.protein_projection(protein_emb)
        adr_shared = self.adr_projection(adr_emb)
        
        # Task predictions
        dti_pred = self.dti_head(drug_shared, protein_shared)
        adr_pred = self.adr_head(drug_shared, protein_shared)
        
        outputs = {
            'dti_pred': dti_pred,
            'adr_pred': adr_pred,
            'drug_shared': drug_shared,
            'protein_shared': protein_shared,
            'adr_shared': adr_shared
        }
        
        # Add contrastive losses during training
        if mode == 'train':
            outputs['contrastive_dp'] = self.contrastive_loss(drug_shared, protein_shared)
            outputs['contrastive_da'] = self.contrastive_loss(drug_shared, adr_shared)
            
        return outputs

print(f"\nüß† Creating model...")
print(f"Architecture:")
print(f"  Drug -> Shared: {DRUG_EMBEDDING_DIM} -> {SHARED_DIM}")
print(f"  Protein -> Shared: {PROTEIN_EMBEDDING_DIM} -> {SHARED_DIM}")
print(f"  ADR -> Shared: {ADR_EMBEDDING_DIM} -> {SHARED_DIM}")
print(f"  DTI Head: {SHARED_DIM * 2} -> 1")
print(f"  ADR Head: {SHARED_DIM * 2} -> {ADR_EMBEDDING_DIM}")


üß† Creating model...
Architecture:
  Drug -> Shared: 640 -> 512
  Protein -> Shared: 2304 -> 512
  ADR -> Shared: 4048 -> 512
  DTI Head: 1024 -> 1
  ADR Head: 1024 -> 4048


In [25]:
# CREATE AND TEST MODEL
model = MultimodalDTIADRModel(
    drug_dim=DRUG_EMBEDDING_DIM,
    protein_dim=PROTEIN_EMBEDDING_DIM,
    adr_dim=ADR_EMBEDDING_DIM,
    shared_dim=SHARED_DIM,
    num_adr_labels=ADR_EMBEDDING_DIM
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"\nü§ñ Model created:")
print(f"  Total parameters: {total_params:,}")
print(f"  Device: {device}")

# Model will be tested after data loaders are created
print(f"  ‚úÖ Model architecture ready!")

# Setup optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

# Loss functions
dti_criterion = nn.BCELoss()
adr_criterion = nn.BCELoss()

# Calculate class weights for DTI imbalance
pos_weight = (1 - dti_positive_rate) / dti_positive_rate
print(f"\n‚öñÔ∏è Training setup:")
print(f"  DTI positive rate: {dti_positive_rate:.3f}")
print(f"  DTI positive weight: {pos_weight:.3f}")
print(f"  Optimizer: AdamW (lr=5e-4, wd=1e-4)")
print(f"  Scheduler: ReduceLROnPlateau")

print(f"\nüöÄ Ready to train {DRUG_ENCODING} + {PROTEIN_ENCODING} + {'3D' if USE_3D_FEATURES else 'no3D'}!")


ü§ñ Model created:
  Total parameters: 9,860,945
  Device: cuda
  ‚úÖ Model architecture ready!

‚öñÔ∏è Training setup:
  DTI positive rate: 0.352
  DTI positive weight: 1.840
  Optimizer: AdamW (lr=5e-4, wd=1e-4)
  Scheduler: ReduceLROnPlateau

üöÄ Ready to train chemberta + esm + 3D!


In [26]:
# DATASET AND DATA LOADERS
class DTIADRDataset(Dataset):
    """PyTorch Dataset for DTI-ADR multimodal data"""
    
    def __init__(self, data_dict):
        self.drug_embeddings = torch.FloatTensor(data_dict['drug_embeddings'])
        self.protein_embeddings = torch.FloatTensor(data_dict['protein_embeddings'])
        self.adr_embeddings = torch.FloatTensor(data_dict['adr_embeddings'])
        self.dti_labels = torch.FloatTensor(data_dict['dti_labels'])
        self.adr_labels = torch.FloatTensor(data_dict['adr_labels'])
        
    def __len__(self):
        return len(self.dti_labels)
    
    def __getitem__(self, idx):
        return {
            'drug_embedding': self.drug_embeddings[idx],
            'protein_embedding': self.protein_embeddings[idx],
            'adr_embedding': self.adr_embeddings[idx],
            'dti_label': self.dti_labels[idx],
            'adr_label': self.adr_labels[idx]
        }

# Create train/validation/test splits
print("üîÑ Creating dataset splits...")

# Create data dictionary from prepared arrays
dataset_dict = {
    'drug_embeddings': sample_drug_embeddings,
    'protein_embeddings': sample_protein_embeddings, 
    'adr_embeddings': sample_adr_embeddings,
    'dti_labels': dti_labels.reshape(-1, 1),
    'adr_labels': adr_labels
}

# Split indices
total_samples = len(dataset_dict['dti_labels'])
train_size = int(0.7 * total_samples)
val_size = int(0.15 * total_samples)
test_size = total_samples - train_size - val_size

indices = np.random.permutation(total_samples)
train_indices = indices[:train_size]
val_indices = indices[train_size:train_size + val_size]
test_indices = indices[train_size + val_size:]

print(f"Split sizes - Train: {len(train_indices):,}, Val: {len(val_indices):,}, Test: {len(test_indices):,}")

# Create data dictionaries for each split
def create_split_data(indices, data_dict):
    return {
        'drug_embeddings': data_dict['drug_embeddings'][indices],
        'protein_embeddings': data_dict['protein_embeddings'][indices],
        'adr_embeddings': data_dict['adr_embeddings'][indices],
        'dti_labels': data_dict['dti_labels'][indices],
        'adr_labels': data_dict['adr_labels'][indices]
    }

train_data = create_split_data(train_indices, dataset_dict)
val_data = create_split_data(val_indices, dataset_dict)
test_data = create_split_data(test_indices, dataset_dict)

# Create datasets
train_dataset = DTIADRDataset(train_data)
val_dataset = DTIADRDataset(val_data) 
test_dataset = DTIADRDataset(test_data)

# Create data loaders
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"‚úÖ Data loaders created!")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")
print(f"   Test batches: {len(test_loader)}")

# Test model with a sample batch
print("\nüß™ Testing model with sample batch...")
with torch.no_grad():
    test_batch = next(iter(train_loader))
    outputs = model(
        test_batch['drug_embedding'].to(device),
        test_batch['protein_embedding'].to(device),
        test_batch['adr_embedding'].to(device)
    )
    print(f"  DTI predictions shape: {outputs['dti_pred'].shape}")
    print(f"  ADR predictions shape: {outputs['adr_pred'].shape}")
    print(f"  ‚úÖ Model working correctly!")

print(f"\nüöÄ Ready to train {DRUG_ENCODING} + {PROTEIN_ENCODING} + {'3D' if USE_3D_FEATURES else 'no3D'}!")

üîÑ Creating dataset splits...
Split sizes - Train: 24,318, Val: 5,211, Test: 5,212


‚úÖ Data loaders created!
   Batch size: 64
   Train batches: 380
   Val batches: 82
   Test batches: 82

üß™ Testing model with sample batch...
  DTI predictions shape: torch.Size([64, 1])
  ADR predictions shape: torch.Size([64, 4048])
  ‚úÖ Model working correctly!

üöÄ Ready to train chemberta + esm + 3D!


In [27]:
# TRAINING FUNCTION WITH DETAILED TRACKING
import json
from datetime import datetime

def train_model(model, train_loader, val_loader, num_epochs=50, save_results=True):
    """Train model with detailed metrics tracking and automatic result saving"""
    
    # Setup optimizer and loss functions
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
    dti_criterion = nn.BCELoss()
    adr_criterion = nn.BCELoss()
    
    # Create experiment name and results directory
    experiment_name = f"{DRUG_ENCODING}_{PROTEIN_ENCODING}_{'3D' if USE_3D_FEATURES else 'no3D'}"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f"results_{experiment_name}_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)
    
    print(f"üöÄ Starting training: {experiment_name}")
    print(f"üìÅ Results will be saved to: {results_dir}")
    print("-" * 60)
    
    # Initialize tracking
    training_history = {
        'experiment_name': experiment_name,
        'config': {
            'drug_encoding': DRUG_ENCODING,
            'protein_encoding': PROTEIN_ENCODING,
            'use_3d': USE_3D_FEATURES,
            'drug_dim': DRUG_EMBEDDING_DIM,
            'protein_dim': PROTEIN_EMBEDDING_DIM,
            'adr_dim': ADR_EMBEDDING_DIM,
            'shared_dim': SHARED_DIM,
            'batch_size': BATCH_SIZE,
            'num_epochs': num_epochs
        },
        'train_history': [],
        'val_history': [],
        'best_epoch': 0,
        'best_val_accuracy': 0.0
    }
    
    best_val_accuracy = 0.0
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_losses = {'dti': 0, 'adr': 0, 'contrastive': 0, 'total': 0}
        train_dti_preds, train_dti_labels = [], []
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False):
            drug_emb = batch['drug_embedding'].to(device)
            protein_emb = batch['protein_embedding'].to(device)
            adr_emb = batch['adr_embedding'].to(device)
            dti_labels_batch = batch['dti_label'].to(device)
            adr_labels_batch = batch['adr_label'].to(device)
            
            optimizer.zero_grad()
            
            outputs = model(drug_emb, protein_emb, adr_emb, mode='train')
            
            # Compute losses
            dti_loss = dti_criterion(outputs['dti_pred'], dti_labels_batch)
            adr_loss = adr_criterion(outputs['adr_pred'], adr_labels_batch)
            contrastive_loss = outputs.get('contrastive_dp', 0) + outputs.get('contrastive_da', 0)
            
            total_loss = dti_loss + 0.5 * adr_loss + 0.5 * contrastive_loss
            
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            # Track losses
            train_losses['dti'] += dti_loss.item()
            train_losses['adr'] += adr_loss.item()
            train_losses['contrastive'] += contrastive_loss.item() if isinstance(contrastive_loss, torch.Tensor) else contrastive_loss
            train_losses['total'] += total_loss.item()
            
            # Collect predictions for metrics
            train_dti_preds.extend(outputs['dti_pred'].detach().cpu().numpy())
            train_dti_labels.extend(dti_labels_batch.detach().cpu().numpy())
        
        # Calculate training metrics
        train_dti_preds = np.array(train_dti_preds).flatten()
        train_dti_labels = np.array(train_dti_labels).flatten()
        train_dti_pred_binary = (train_dti_preds > 0.5).astype(int)
        
        train_accuracy = accuracy_score(train_dti_labels, train_dti_pred_binary)
        train_auc = roc_auc_score(train_dti_labels, train_dti_preds) if len(np.unique(train_dti_labels)) > 1 else 0.0
        train_precision, train_recall, train_f1, _ = precision_recall_fscore_support(
            train_dti_labels, train_dti_pred_binary, average='binary', zero_division=0
        )
        
        # Validation phase
        model.eval()
        val_losses = {'dti': 0, 'adr': 0, 'total': 0}
        val_dti_preds, val_dti_labels = [], []
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]  ", leave=False):
                drug_emb = batch['drug_embedding'].to(device)
                protein_emb = batch['protein_embedding'].to(device)
                adr_emb = batch['adr_embedding'].to(device)
                dti_labels_batch = batch['dti_label'].to(device)
                adr_labels_batch = batch['adr_label'].to(device)
                
                outputs = model(drug_emb, protein_emb, adr_emb, mode='eval')
                
                dti_loss = dti_criterion(outputs['dti_pred'], dti_labels_batch)
                adr_loss = adr_criterion(outputs['adr_pred'], adr_labels_batch)
                total_loss = dti_loss + 0.5 * adr_loss
                
                val_losses['dti'] += dti_loss.item()
                val_losses['adr'] += adr_loss.item()
                val_losses['total'] += total_loss.item()
                
                val_dti_preds.extend(outputs['dti_pred'].cpu().numpy())
                val_dti_labels.extend(dti_labels_batch.cpu().numpy())
        
        # Calculate validation metrics
        val_dti_preds = np.array(val_dti_preds).flatten()
        val_dti_labels = np.array(val_dti_labels).flatten()
        val_dti_pred_binary = (val_dti_preds > 0.5).astype(int)
        
        val_accuracy = accuracy_score(val_dti_labels, val_dti_pred_binary)
        val_auc = roc_auc_score(val_dti_labels, val_dti_preds) if len(np.unique(val_dti_labels)) > 1 else 0.0
        val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(
            val_dti_labels, val_dti_pred_binary, average='binary', zero_division=0
        )
        
        # Average losses
        train_avg_losses = {k: v/len(train_loader) for k, v in train_losses.items()}
        val_avg_losses = {k: v/len(val_loader) for k, v in val_losses.items()}
        
        # Log metrics
        train_metrics = {
            'epoch': epoch + 1,
            'timestamp': datetime.now().isoformat(),
            'dti_accuracy': train_accuracy,
            'dti_auc': train_auc,
            'dti_precision': train_precision,
            'dti_recall': train_recall,
            'dti_f1': train_f1,
            **train_avg_losses
        }
        
        val_metrics = {
            'epoch': epoch + 1,
            'timestamp': datetime.now().isoformat(),
            'dti_accuracy': val_accuracy,
            'dti_auc': val_auc,
            'dti_precision': val_precision,
            'dti_recall': val_recall,
            'dti_f1': val_f1,
            **val_avg_losses
        }
        
        training_history['train_history'].append(train_metrics)
        training_history['val_history'].append(val_metrics)
        
        # Print epoch results
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  üéØ DTI Accuracy:  Train {train_accuracy:.4f} | Val {val_accuracy:.4f}")
        print(f"  üìà AUC-ROC:       Train {train_auc:.4f} | Val {val_auc:.4f}")
        print(f"  üéõÔ∏è  Precision:     Train {train_precision:.4f} | Val {val_precision:.4f}")
        print(f"  üîç Recall:        Train {train_recall:.4f} | Val {val_recall:.4f}")
        print(f"  ‚öñÔ∏è  F1-Score:      Train {train_f1:.4f} | Val {val_f1:.4f}")
        print(f"  üí• Loss:          Train {train_avg_losses['total']:.4f} | Val {val_avg_losses['total']:.4f}")
        
        # Learning rate scheduling
        scheduler.step(val_accuracy)
        
        # Track best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            training_history['best_val_accuracy'] = best_val_accuracy
            training_history['best_epoch'] = epoch + 1
            print(f"  üåü NEW BEST VAL ACCURACY: {val_accuracy:.4f} ({val_accuracy*100:.1f}%)")
            
            # Save best model
            if save_results:
                torch.save(model.state_dict(), f"{results_dir}/best_model.pth")
        
        print("-" * 60)
    
    print(f"\nüéâ Training completed!")
    print(f"   Best validation accuracy: {best_val_accuracy:.4f} ({best_val_accuracy*100:.1f}%)")
    
    # Save training history
    if save_results:
        results_file = f"{results_dir}/training_results.json"
        with open(results_file, 'w') as f:
            json.dump(training_history, f, indent=4)
        print(f"   üìä Results saved to: {results_file}")
        
        return training_history, results_file
    else:
        return training_history, None

print("üöÄ Training function ready!")
print("Usage: history, results_file = train_model(model, train_loader, val_loader, num_epochs=50)")

üöÄ Training function ready!
Usage: history, results_file = train_model(model, train_loader, val_loader, num_epochs=50)


In [28]:
# EVALUATION FUNCTION
def evaluate_model(model, test_loader):
    """Comprehensive evaluation on test set"""
    print(f"üî¨ Evaluating model on test set...")
    
    model.eval()
    all_dti_preds, all_dti_labels = [], []
    all_adr_preds, all_adr_labels = [], []
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            drug_emb = batch['drug_embedding'].to(device)
            protein_emb = batch['protein_embedding'].to(device)
            adr_emb = batch['adr_embedding'].to(device)
            dti_labels_batch = batch['dti_label'].to(device)
            adr_labels_batch = batch['adr_label'].to(device)
            
            outputs = model(drug_emb, protein_emb, adr_emb, mode='eval')
            
            all_dti_preds.extend(outputs['dti_pred'].cpu().numpy())
            all_dti_labels.extend(dti_labels_batch.cpu().numpy())
            all_adr_preds.extend(outputs['adr_pred'].cpu().numpy())
            all_adr_labels.extend(adr_labels_batch.cpu().numpy())
    
    # DTI Metrics
    all_dti_preds = np.array(all_dti_preds).flatten()
    all_dti_labels = np.array(all_dti_labels).flatten()
    dti_pred_binary = (all_dti_preds > 0.5).astype(int)
    
    dti_accuracy = accuracy_score(all_dti_labels, dti_pred_binary)
    dti_auc = roc_auc_score(all_dti_labels, all_dti_preds) if len(np.unique(all_dti_labels)) > 1 else 0.0
    dti_precision, dti_recall, dti_f1, _ = precision_recall_fscore_support(
        all_dti_labels, dti_pred_binary, average='binary', zero_division=0
    )
    
    # ADR Metrics
    all_adr_preds = np.array(all_adr_preds)
    all_adr_labels = np.array(all_adr_labels)
    adr_pred_binary = (all_adr_preds > 0.5).astype(int)
    adr_accuracy = accuracy_score(all_adr_labels.flatten(), adr_pred_binary.flatten())
    
    test_results = {
        'timestamp': datetime.now().isoformat(),
        'dti_accuracy': dti_accuracy,
        'dti_auc': dti_auc,
        'dti_precision': dti_precision,
        'dti_recall': dti_recall,
        'dti_f1': dti_f1,
        'adr_accuracy': adr_accuracy,
        'n_test_samples': len(all_dti_labels)
    }
    
    print(f"\nüéØ TEST RESULTS:")
    print("=" * 50)
    print(f"Drug-Target Interaction (DTI):")
    print(f"  üéØ Accuracy:  {dti_accuracy:.4f} ({dti_accuracy*100:.1f}%)")
    print(f"  üìà AUC-ROC:   {dti_auc:.4f}")
    print(f"  üéõÔ∏è  Precision: {dti_precision:.4f}")
    print(f"  üîç Recall:    {dti_recall:.4f}")
    print(f"  ‚öñÔ∏è  F1-Score:  {dti_f1:.4f}")
    print(f"\nAdverse Drug Reaction (ADR):")
    print(f"  üéØ Accuracy:  {adr_accuracy:.4f} ({adr_accuracy*100:.1f}%)")
    print(f"\nüìä Dataset: {len(all_dti_labels):,} test samples")
    
    return test_results

# EXAMPLE USAGE AND COMPARISON FUNCTION
def compare_results(results_files):
    """Compare results from multiple experiments"""
    if not results_files:
        print("No results files provided")
        return
    
    comparison_data = []
    
    for results_file in results_files:
        try:
            with open(results_file, 'r') as f:
                results = json.load(f)
            
            comparison_data.append({
                'experiment': results['experiment_name'],
                'drug_encoding': results['config']['drug_encoding'],
                'protein_encoding': results['config']['protein_encoding'],
                'use_3d': results['config']['use_3d'],
                'best_val_accuracy': results['best_val_accuracy'],
                'best_epoch': results['best_epoch'],
                'file': results_file
            })
        except Exception as e:
            print(f"Error loading {results_file}: {e}")
    
    if comparison_data:
        comparison_data.sort(key=lambda x: x['best_val_accuracy'], reverse=True)
        
        print("\nüèÜ EXPERIMENT COMPARISON:")
        print("=" * 80)
        print(f"{'Rank':<4} {'Experiment':<30} {'Drug':<12} {'Protein':<8} {'3D':<3} {'Best Acc':<9} {'Epoch':<5}")
        print("-" * 80)
        
        for i, exp in enumerate(comparison_data, 1):
            use_3d = "‚úì" if exp['use_3d'] else "‚úó"
            print(f"{i:<4} {exp['experiment']:<30} {exp['drug_encoding']:<12} {exp['protein_encoding']:<8} {use_3d:<3} {exp['best_val_accuracy']:.4f}    {exp['best_epoch']:<5}")
    
    return comparison_data

print("üî¨ Evaluation functions ready!")
print("\nüí° EXAMPLE USAGE:")
print("="*50)
print("# 1. Train model")
print("history, results_file = train_model(model, train_loader, val_loader, num_epochs=30)")
print()
print("# 2. Evaluate on test set")
print("test_results = evaluate_model(model, test_loader)")
print()
print("# 3. Compare multiple experiments")
print("compare_results(['results_file1.json', 'results_file2.json'])")
print()
print("# 4. Change configuration and run again")
print("# Just change DRUG_ENCODING or PROTEIN_ENCODING at the top and re-run!")

print(f"\nüéØ CURRENT EXPERIMENT:")
print(f"   {DRUG_ENCODING} + {PROTEIN_ENCODING} + {'3D' if USE_3D_FEATURES else 'no3D'}")
print(f"   Ready to train!")

üî¨ Evaluation functions ready!

üí° EXAMPLE USAGE:
# 1. Train model
history, results_file = train_model(model, train_loader, val_loader, num_epochs=30)

# 2. Evaluate on test set
test_results = evaluate_model(model, test_loader)

# 3. Compare multiple experiments
compare_results(['results_file1.json', 'results_file2.json'])

# 4. Change configuration and run again
# Just change DRUG_ENCODING or PROTEIN_ENCODING at the top and re-run!

üéØ CURRENT EXPERIMENT:
   chemberta + esm + 3D
   Ready to train!


In [29]:
# START TRAINING
print("üöÄ Starting training process...")
print(f"Configuration: {DRUG_ENCODING} + {PROTEIN_ENCODING} + {'3D' if USE_3D_FEATURES else 'no3D'}")
print(f"Device: {device}")
print(f"Training samples: {len(train_dataset):,}")
print(f"Validation samples: {len(val_dataset):,}")
print(f"Test samples: {len(test_dataset):,}")

# Start training with 20 epochs for faster initial results
history, results_file = train_model(model, train_loader, val_loader, num_epochs=20)

print(f"\n‚úÖ Training completed! Results saved to: {results_file}")

# Evaluate on test set
test_results = evaluate_model(model, test_loader)

print(f"\nüéâ Experiment completed successfully!")
print(f"üìä Best validation accuracy: {history['best_val_accuracy']:.4f}")
print(f"üìÅ Results saved and ready for comparison")

üöÄ Starting training process...
Configuration: chemberta + esm + 3D
Device: cuda
Training samples: 24,318
Validation samples: 5,211
Test samples: 5,212
üöÄ Starting training: chemberta_esm_3D
üìÅ Results will be saved to: results_chemberta_esm_3D_20251007_195822
------------------------------------------------------------


Epoch 1/20 [Train]:   0%|          | 0/380 [00:00<?, ?it/s]

                                                                     

Epoch 1/20:
  üéØ DTI Accuracy:  Train 0.7826 | Val 0.8096
  üìà AUC-ROC:       Train 0.8412 | Val 0.8869
  üéõÔ∏è  Precision:     Train 0.6824 | Val 0.6965
  üîç Recall:        Train 0.7114 | Val 0.8209
  ‚öñÔ∏è  F1-Score:      Train 0.6966 | Val 0.7536
  üí• Loss:          Train 2.7348 | Val 0.4258
  üåü NEW BEST VAL ACCURACY: 0.8096 (81.0%)
------------------------------------------------------------


                                                                     

Epoch 2/20:
  üéØ DTI Accuracy:  Train 0.8161 | Val 0.8309
  üìà AUC-ROC:       Train 0.8773 | Val 0.9006
  üéõÔ∏è  Precision:     Train 0.7345 | Val 0.7876
  üîç Recall:        Train 0.7449 | Val 0.7165
  ‚öñÔ∏è  F1-Score:      Train 0.7397 | Val 0.7504
  üí• Loss:          Train 2.4072 | Val 0.4001
  üåü NEW BEST VAL ACCURACY: 0.8309 (83.1%)
------------------------------------------------------------


                                                                     

Epoch 3/20:
  üéØ DTI Accuracy:  Train 0.8269 | Val 0.8411
  üìà AUC-ROC:       Train 0.8879 | Val 0.9088
  üéõÔ∏è  Precision:     Train 0.7493 | Val 0.7640
  üîç Recall:        Train 0.7611 | Val 0.7987
  ‚öñÔ∏è  F1-Score:      Train 0.7552 | Val 0.7810
  üí• Loss:          Train 2.3401 | Val 0.3785
  üåü NEW BEST VAL ACCURACY: 0.8411 (84.1%)
------------------------------------------------------------


                                                                     

Epoch 4/20:
  üéØ DTI Accuracy:  Train 0.8320 | Val 0.8449
  üìà AUC-ROC:       Train 0.8972 | Val 0.9134
  üéõÔ∏è  Precision:     Train 0.7617 | Val 0.7672
  üîç Recall:        Train 0.7584 | Val 0.8079
  ‚öñÔ∏è  F1-Score:      Train 0.7601 | Val 0.7870
  üí• Loss:          Train 2.2911 | Val 0.3840
  üåü NEW BEST VAL ACCURACY: 0.8449 (84.5%)
------------------------------------------------------------


                                                                     

Epoch 5/20:
  üéØ DTI Accuracy:  Train 0.8332 | Val 0.8411
  üìà AUC-ROC:       Train 0.9021 | Val 0.9159
  üéõÔ∏è  Precision:     Train 0.7666 | Val 0.7632
  üîç Recall:        Train 0.7541 | Val 0.8003
  ‚öñÔ∏è  F1-Score:      Train 0.7603 | Val 0.7813
  üí• Loss:          Train 2.2590 | Val 0.3604
------------------------------------------------------------


                                                                     

Epoch 6/20:
  üéØ DTI Accuracy:  Train 0.8385 | Val 0.8478
  üìà AUC-ROC:       Train 0.9070 | Val 0.9192
  üéõÔ∏è  Precision:     Train 0.7742 | Val 0.7720
  üîç Recall:        Train 0.7618 | Val 0.8101
  ‚öñÔ∏è  F1-Score:      Train 0.7680 | Val 0.7906
  üí• Loss:          Train 2.2313 | Val 0.3519
  üåü NEW BEST VAL ACCURACY: 0.8478 (84.8%)
------------------------------------------------------------


                                                                     

Epoch 7/20:
  üéØ DTI Accuracy:  Train 0.8407 | Val 0.8380
  üìà AUC-ROC:       Train 0.9112 | Val 0.9187
  üéõÔ∏è  Precision:     Train 0.7825 | Val 0.7377
  üîç Recall:        Train 0.7559 | Val 0.8431
  ‚öñÔ∏è  F1-Score:      Train 0.7690 | Val 0.7869
  üí• Loss:          Train 2.2047 | Val 0.3683
------------------------------------------------------------


                                                                     

Epoch 8/20:
  üéØ DTI Accuracy:  Train 0.8420 | Val 0.8542
  üìà AUC-ROC:       Train 0.9138 | Val 0.9238
  üéõÔ∏è  Precision:     Train 0.7833 | Val 0.7922
  üîç Recall:        Train 0.7599 | Val 0.7982
  ‚öñÔ∏è  F1-Score:      Train 0.7714 | Val 0.7951
  üí• Loss:          Train 2.1876 | Val 0.3395
  üåü NEW BEST VAL ACCURACY: 0.8542 (85.4%)
------------------------------------------------------------


                                                                     

Epoch 9/20:
  üéØ DTI Accuracy:  Train 0.8446 | Val 0.8551
  üìà AUC-ROC:       Train 0.9169 | Val 0.9255
  üéõÔ∏è  Precision:     Train 0.7912 | Val 0.8021
  üîç Recall:        Train 0.7569 | Val 0.7852
  ‚öñÔ∏è  F1-Score:      Train 0.7737 | Val 0.7935
  üí• Loss:          Train 2.1645 | Val 0.3372
  üåü NEW BEST VAL ACCURACY: 0.8551 (85.5%)
------------------------------------------------------------


                                                                      

Epoch 10/20:
  üéØ DTI Accuracy:  Train 0.8481 | Val 0.8553
  üìà AUC-ROC:       Train 0.9202 | Val 0.9283
  üéõÔ∏è  Precision:     Train 0.7972 | Val 0.7774
  üîç Recall:        Train 0.7605 | Val 0.8295
  ‚öñÔ∏è  F1-Score:      Train 0.7785 | Val 0.8026
  üí• Loss:          Train 2.1460 | Val 0.3319
  üåü NEW BEST VAL ACCURACY: 0.8553 (85.5%)
------------------------------------------------------------


                                                                      

Epoch 11/20:
  üéØ DTI Accuracy:  Train 0.8494 | Val 0.8576
  üìà AUC-ROC:       Train 0.9206 | Val 0.9288
  üéõÔ∏è  Precision:     Train 0.7994 | Val 0.8319
  üîç Recall:        Train 0.7618 | Val 0.7500
  ‚öñÔ∏è  F1-Score:      Train 0.7801 | Val 0.7888
  üí• Loss:          Train 2.1288 | Val 0.3304
  üåü NEW BEST VAL ACCURACY: 0.8576 (85.8%)
------------------------------------------------------------


                                                                      

Epoch 12/20:
  üéØ DTI Accuracy:  Train 0.8537 | Val 0.8536
  üìà AUC-ROC:       Train 0.9236 | Val 0.9296
  üéõÔ∏è  Precision:     Train 0.8109 | Val 0.8006
  üîç Recall:        Train 0.7604 | Val 0.7819
  ‚öñÔ∏è  F1-Score:      Train 0.7848 | Val 0.7911
  üí• Loss:          Train 2.1181 | Val 0.3288
------------------------------------------------------------


                                                                      

Epoch 13/20:
  üéØ DTI Accuracy:  Train 0.8556 | Val 0.8611
  üìà AUC-ROC:       Train 0.9263 | Val 0.9308
  üéõÔ∏è  Precision:     Train 0.8159 | Val 0.8283
  üîç Recall:        Train 0.7599 | Val 0.7673
  ‚öñÔ∏è  F1-Score:      Train 0.7869 | Val 0.7966
  üí• Loss:          Train 2.1049 | Val 0.3227
  üåü NEW BEST VAL ACCURACY: 0.8611 (86.1%)
------------------------------------------------------------


                                                                      

Epoch 14/20:
  üéØ DTI Accuracy:  Train 0.8565 | Val 0.8584
  üìà AUC-ROC:       Train 0.9272 | Val 0.9303
  üéõÔ∏è  Precision:     Train 0.8166 | Val 0.7921
  üîç Recall:        Train 0.7622 | Val 0.8144
  ‚öñÔ∏è  F1-Score:      Train 0.7885 | Val 0.8031
  üí• Loss:          Train 2.0929 | Val 0.3255
------------------------------------------------------------


                                                                      

Epoch 15/20:
  üéØ DTI Accuracy:  Train 0.8585 | Val 0.8586
  üìà AUC-ROC:       Train 0.9294 | Val 0.9304
  üéõÔ∏è  Precision:     Train 0.8218 | Val 0.7907
  üîç Recall:        Train 0.7619 | Val 0.8176
  ‚öñÔ∏è  F1-Score:      Train 0.7908 | Val 0.8039
  üí• Loss:          Train 2.0778 | Val 0.3262
------------------------------------------------------------


                                                                      

Epoch 16/20:
  üéØ DTI Accuracy:  Train 0.8593 | Val 0.8637
  üìà AUC-ROC:       Train 0.9313 | Val 0.9309
  üéõÔ∏è  Precision:     Train 0.8242 | Val 0.8391
  üîç Recall:        Train 0.7615 | Val 0.7619
  ‚öñÔ∏è  F1-Score:      Train 0.7916 | Val 0.7986
  üí• Loss:          Train 2.0715 | Val 0.3298
  üåü NEW BEST VAL ACCURACY: 0.8637 (86.4%)
------------------------------------------------------------


                                                                      

Epoch 17/20:
  üéØ DTI Accuracy:  Train 0.8617 | Val 0.8542
  üìà AUC-ROC:       Train 0.9314 | Val 0.9320
  üéõÔ∏è  Precision:     Train 0.8251 | Val 0.7603
  üîç Recall:        Train 0.7686 | Val 0.8598
  ‚öñÔ∏è  F1-Score:      Train 0.7958 | Val 0.8070
  üí• Loss:          Train 2.0640 | Val 0.3302
------------------------------------------------------------


                                                                      

Epoch 18/20:
  üéØ DTI Accuracy:  Train 0.8631 | Val 0.8542
  üìà AUC-ROC:       Train 0.9328 | Val 0.9322
  üéõÔ∏è  Precision:     Train 0.8270 | Val 0.8666
  üîç Recall:        Train 0.7711 | Val 0.6959
  ‚öñÔ∏è  F1-Score:      Train 0.7981 | Val 0.7719
  üí• Loss:          Train 2.0525 | Val 0.3342
------------------------------------------------------------


                                                                      

Epoch 19/20:
  üéØ DTI Accuracy:  Train 0.8657 | Val 0.8605
  üìà AUC-ROC:       Train 0.9349 | Val 0.9340
  üéõÔ∏è  Precision:     Train 0.8333 | Val 0.7876
  üîç Recall:        Train 0.7715 | Val 0.8306
  ‚öñÔ∏è  F1-Score:      Train 0.8012 | Val 0.8085
  üí• Loss:          Train 2.0411 | Val 0.3232
------------------------------------------------------------


                                                                      

Epoch 20/20:
  üéØ DTI Accuracy:  Train 0.8665 | Val 0.8593
  üìà AUC-ROC:       Train 0.9365 | Val 0.9351
  üéõÔ∏è  Precision:     Train 0.8322 | Val 0.8821
  üîç Recall:        Train 0.7760 | Val 0.6964
  ‚öñÔ∏è  F1-Score:      Train 0.8031 | Val 0.7783
  üí• Loss:          Train 2.0355 | Val 0.3197
------------------------------------------------------------

üéâ Training completed!
   Best validation accuracy: 0.8637 (86.4%)
   üìä Results saved to: results_chemberta_esm_3D_20251007_195822/training_results.json

‚úÖ Training completed! Results saved to: results_chemberta_esm_3D_20251007_195822/training_results.json
üî¨ Evaluating model on test set...


Testing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 82/82 [00:01<00:00, 75.63it/s]




üéØ TEST RESULTS:
Drug-Target Interaction (DTI):
  üéØ Accuracy:  0.8576 (85.8%)
  üìà AUC-ROC:   0.9285
  üéõÔ∏è  Precision: 0.8778
  üîç Recall:    0.6970
  ‚öñÔ∏è  F1-Score:  0.7770

Adverse Drug Reaction (ADR):
  üéØ Accuracy:  0.9991 (99.9%)

üìä Dataset: 5,212 test samples

üéâ Experiment completed successfully!
üìä Best validation accuracy: 0.8637
üìÅ Results saved and ready for comparison
