# Drug-Target Interaction (DTI) Training Pipeline

This notebook walks through the complete end-to-end process of training a deep learning model to predict interactions between drugs and genes.

**Pipeline Steps:**
1.  **Setup**: Import libraries and load configuration.
2.  **Data Loading & Preprocessing**: Load the raw relationship data, filter for drug-gene pairs, and enrich it with protein sequences and chemical SMILES strings.
3.  **Data Splitting**: Split the dataset into training, validation, and test sets using a scaffold-based approach to prevent data leakage.
4.  **Featurization & Dataset Creation**: Convert SMILES strings into 3D molecular graphs and protein sequences into embeddings using an ESM model, and wrap each split in a PyTorch Dataset.
5.  **DataLoaders**: Batch the Datasets with PyTorch DataLoaders for training and evaluation.
6.  **Model Training**: Initialize and train the DTI prediction model.
7.  **Evaluation**: Evaluate the trained model on the test set and visualize performance metrics.
8.  **Inference**: Demonstrate how to use the trained model for predicting interactions on new pairs.

In [None]:
import os
import yaml
import json
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import DataLoader

# Import from our source code
from src.utils.api_clients import DataEnricher
from src.preprocessing.data_enricher import process_raw_data
from src.preprocessing.feature_engineer import (
    scaffold_split,
    ProteinEmbedder,
    DTIDataset,
    collate_fn
)
from src.models.dti_model import DTIModel, Trainer, evaluate_model, predict_interaction
from src.molecular_3d.visualization import plot_training_history, plot_evaluation_metrics

# Record the runtime environment in the notebook output so results can be
# traced back to the PyTorch build / accelerator they were produced with.
for _env_label, _env_value in (
    ('PyTorch version', torch.__version__),
    ('CUDA available', torch.cuda.is_available()),
):
    print(f'{_env_label}: {_env_value}')

## 1. Setup: Load Configuration and Set Seed

In [None]:
def load_config(config_path='../config/config.yaml'):
    """Load the pipeline configuration from a YAML file.

    Args:
        config_path: Path (str or ``Path``) to the YAML config file. The
            default is relative to the notebook's working directory.

    Returns:
        dict: The parsed configuration mapping.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the file is empty or its top level is not a mapping.
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # mis-decode non-ASCII characters in the config.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty file; fail here with a clear message
    # instead of a TypeError on the first config['...'] lookup downstream.
    if not isinstance(config, dict):
        raise ValueError(
            f'Config file {config_path} must contain a YAML mapping at the top '
            f'level, got {type(config).__name__}'
        )
    return config

def set_seed(seed=42):
    """Seed the random number generators used by the pipeline for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (CPU,
    plus all CUDA devices when CUDA is available).

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    import random  # local import keeps this cell self-contained

    # Seed the stdlib RNG too, in case downstream helpers (e.g. split
    # utilities) draw from it rather than from NumPy/PyTorch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Prefer deterministic cuDNN kernels over autotuned (faster but
        # run-to-run variable) kernel selection.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

config = load_config()
set_seed(42)  # fixed seed so splits, initialization and shuffling are repeatable

# Create the working directories up front. `parents=True` so nested paths in
# the config do not raise FileNotFoundError when an intermediate directory
# is missing.
for _dir_key in ('cache_dir', 'output_dir', 'checkpoint_dir'):
    Path(config['paths'][_dir_key]).mkdir(parents=True, exist_ok=True)

print('Configuration loaded successfully:')
# default=str: YAML can produce values (e.g. dates) that json cannot serialize.
# NOTE(review): if the config carries API keys for the enricher, they will be
# printed into the saved notebook output; consider redacting them.
print(json.dumps(config, indent=2, default=str))

## 2. Data Loading, Enrichment, and Preprocessing

In [None]:
# Initialize the data enricher (used to add protein sequences and SMILES to
# the raw pairs, per the pipeline overview above).
enricher = DataEnricher(config)

# Load, filter and enrich the raw relationship data.
df_clean = process_raw_data(config, enricher)

if len(df_clean) == 0:
    # Every later cell depends on df_clean; stop here explicitly rather than
    # letting subsequent cells fail with a confusing NameError.
    raise RuntimeError('Critical error: No data left after processing. Halting execution.')

print(f'\nSuccessfully processed data. Total valid pairs: {len(df_clean)}')
enriched_path = Path(config['paths']['cache_dir']) / 'enriched_dataset.csv'
df_clean.to_csv(enriched_path, index=False)
print('Enriched dataset saved to cache.')

## 3. Data Splitting

In [None]:
if len(df_clean) == 0:
    # Without data there are no splits, and every later cell would fail with
    # a NameError on df_train/df_val/df_test; fail loudly here instead.
    raise RuntimeError('No data available to split: df_clean is empty.')

# NOTE(review): this unpacking assumes scaffold_split returns
# (train, test, val). Confirm against its implementation — many splitters
# return (train, val, test), and a mismatch would silently swap the
# validation and test sets.
df_train, df_test, df_val = scaffold_split(
    df_clean,
    train_size=config['data_split']['train_size'],
    test_size=config['data_split']['test_size'],
    val_size=config['data_split']['val_size']
)
print(f'Split sizes - train: {len(df_train)}, val: {len(df_val)}, test: {len(df_test)}')

# Persist the splits so the exact partition can be reused / inspected later.
split_cache_dir = Path(config['paths']['cache_dir'])
for split_name, split_df in (('train', df_train), ('test', df_test), ('val', df_val)):
    split_df.to_csv(split_cache_dir / f'{split_name}_set.csv', index=False)
print('Data splits saved to cache.')

## 4. Featurization & Dataset Creation

In [None]:
# Initialize the protein embedder (this will download the ESM model)
embedder = ProteinEmbedder(config)

# Create one PyTorch dataset per split. DTIDataset receives the embedder, so
# featurization presumably happens during construction — TODO confirm.
print('\nCreating training dataset...')
train_dataset = DTIDataset(df_train, config, embedder)

print('\nCreating validation dataset...')
val_dataset = DTIDataset(df_val, config, embedder)

print('\nCreating test dataset...')
test_dataset = DTIDataset(df_test, config, embedder)

print(
    f'\nDataset sizes:'
    f'\nTrain: {len(train_dataset)}'
    f'\nVal: {len(val_dataset)}'
    f'\nTest: {len(test_dataset)}'
)

# The evaluation cell matches predictions to split rows by position, so flag
# any split whose dataset ended up with a different number of items.
for _split_name, _split_df, _split_ds in (
    ('train', df_train, train_dataset),
    ('val', df_val, val_dataset),
    ('test', df_test, test_dataset),
):
    if len(_split_ds) != len(_split_df):
        print(
            f'WARNING: {_split_name} dataset has {len(_split_ds)} items but the '
            f'split has {len(_split_df)} rows; some pairs were not featurized.'
        )

## 5. Create DataLoaders

In [None]:
def _make_loader(dataset, shuffle):
    """Wrap a dataset in a DataLoader using the configured batch size and collate_fn."""
    return DataLoader(
        dataset,
        batch_size=config['training']['batch_size'],
        shuffle=shuffle,
        collate_fn=collate_fn,
    )


# Only training data is shuffled; validation/test keep a fixed order so
# predictions can be matched back to their source rows.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

print('DataLoaders created.')

## 6. Model Training

In [None]:
# Initialize the model
model = DTIModel(config)

# Initialize the trainer
trainer = Trainer(model, train_loader, val_loader, config)

# Start training
trainer.train_loop()

## 7. Evaluation

In [None]:
# Plot the loss/metric curves recorded by the trainer.
history_plot_path = Path(config['paths']['output_dir']) / 'training_history.png'
plot_training_history(trainer.history, history_plot_path)

# Evaluate on the held-out test set.
print('\nEvaluating model on the test set...')
test_metrics, test_preds, test_labels = evaluate_model(model, test_loader, config)

# Summary table; labels are padded to a fixed width so the values line up.
print('\nTEST SET RESULTS')
print('=' * 30)
for _metric_label, _metric_key in (
    ('ROC AUC', 'roc_auc'),
    ('PR AUC', 'pr_auc'),
    ('Accuracy', 'accuracy'),
    ('F1 Score', 'f1'),
):
    print(f'{_metric_label + ":":<13}{test_metrics[_metric_key]:.4f}')
print('=' * 30)

# Plot evaluation metrics
eval_plot_path = Path(config['paths']['output_dir']) / 'evaluation_metrics.png'
plot_evaluation_metrics(test_metrics, eval_plot_path)

# Save per-pair test predictions next to their gene/chemical IDs.
num_preds = len(test_preds)
if num_preds > len(df_test):
    raise ValueError(
        f'Got {num_preds} predictions but df_test has only {len(df_test)} rows; '
        'cannot pair predictions with IDs.'
    )
if num_preds != len(df_test):
    # NOTE(review): predictions are paired with df_test rows by position
    # (test_loader is not shuffled). If DTIDataset skipped rows anywhere but
    # the end, truncating df_test mis-pairs IDs and predictions; the dataset
    # should expose the row indices it kept so this join can be exact.
    print(
        f'WARNING: {num_preds} predictions vs {len(df_test)} test rows; '
        'IDs are aligned by position and may be mismatched.'
    )
# reset_index so the ID columns carry a clean 0..n-1 index like the arrays.
df_test_synced = df_test.iloc[:num_preds].reset_index(drop=True)
results_df = pd.DataFrame({
    'gene_id': df_test_synced['gene_id'],
    'chem_id': df_test_synced['chem_id'],
    'true_label': test_labels,
    'predicted_prob': test_preds
})
results_path = Path(config['paths']['output_dir']) / 'test_predictions.csv'
results_df.to_csv(results_path, index=False)
print(f'Test predictions saved to {results_path}')

## 8. Inference on a New Pair

In [None]:
# Example of predicting a single interaction. A test-set pair is used for the
# demonstration so the prediction can be compared with its known label.
if len(df_test) > 0:
    sample = df_test.iloc[0]
    gene_id_to_predict = sample['gene_id']
    chem_id_to_predict = sample['chem_id']

    result = predict_interaction(
        model,
        gene_id_to_predict,
        chem_id_to_predict,
        enricher,
        embedder,
        config
    )

    # A falsy result means no prediction was produced (presumably the pair
    # could not be enriched/featurized — TODO confirm in predict_interaction).
    if result:
        print(f'\nPrediction for {gene_id_to_predict} - {chem_id_to_predict}:')
        print(f'  Probability: {result["probability"]:.4f}')
        print(f'  True Label in test set was: {sample["label"]}')
    else:
        print(f'\nNo prediction could be made for {gene_id_to_predict} - {chem_id_to_predict}.')