# Temporal Drug Interaction Prediction - Data Exploration

This notebook provides an exploration of the temporal drug interaction prediction system, demonstrating the key components and their functionality.

In [None]:
# Setup imports and environment
import sys
from pathlib import Path

# Add src to path
sys.path.append(str(Path().parent / "src"))

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import Draw
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("Environment setup complete!")

## 1. Configuration and Setup

In [None]:
from temporal_drug_interaction_prediction_with_heterogeneous_molecular_graphs.utils.config import Config, load_config

# Load configuration
config = load_config("../configs/default.yaml")
print("Configuration loaded successfully!")
print(f"Device: {config.device}")
print(f"Hidden dim: {config.model.hidden_dim}")
print(f"Number of layers: {config.model.num_layers}")

## 2. Data Loading and Preprocessing

In [None]:
from temporal_drug_interaction_prediction_with_heterogeneous_molecular_graphs.data.loader import DrugInteractionDataLoader
from temporal_drug_interaction_prediction_with_heterogeneous_molecular_graphs.data.preprocessing import MolecularGraphPreprocessor

# Initialize data loader
data_loader = DrugInteractionDataLoader(config.data)
mol_preprocessor = MolecularGraphPreprocessor()

print("Data loader initialized!")

In [None]:
# Sample SMILES for common drugs
sample_drugs = {
    "Aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
    "Caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
    "Ibuprofen": "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",
    "Warfarin": "CC1=C(C2=C(C=C1)OC(=O)C(C2=O)C(C3=CC=CC=C3)O)C",
    "Metformin": "CN(C)C(=N)NC(=N)N"
}

# Display molecular structures
mols = []
legends = []
for name, smiles in sample_drugs.items():
    mol = Chem.MolFromSmiles(smiles)
    if mol:
        mols.append(mol)
        legends.append(name)

# Create a grid of molecular structures
img = Draw.MolsToGridImage(mols, molsPerRow=3, subImgSize=(200, 200), legends=legends)
display(img)

In [None]:
# Process molecules to graph representations
sample_smiles = list(sample_drugs.values())
molecular_features = data_loader.preprocess_molecules(sample_smiles, cache_key="exploration")

print(f"Processed {len(molecular_features)} molecules")
print(f"Feature dimension: {list(molecular_features.values())[0].shape}")

# Visualize molecular descriptors
feature_matrix = torch.stack(list(molecular_features.values())).numpy()
drug_names = list(sample_drugs.keys())

# Plot first 20 molecular descriptors
plt.figure(figsize=(12, 8))
sns.heatmap(feature_matrix[:, :20], 
            xticklabels=[f"Desc_{i+1}" for i in range(20)],
            yticklabels=drug_names,
            annot=True, fmt=".2f", cmap="viridis")
plt.title("Molecular Descriptors (First 20 Features)")
plt.tight_layout()
plt.show()

## 3. Drug Interaction Pair Generation

In [None]:
# Generate drug interaction pairs
drug_pairs, interaction_labels = data_loader.create_drug_interaction_pairs(
    sample_smiles, interaction_probability=0.3, seed=42
)

print(f"Generated {len(drug_pairs)} drug pairs")
print(f"Interaction rate: {np.mean(interaction_labels):.2%}")

# Create interaction matrix
interaction_matrix = np.zeros((len(sample_drugs), len(sample_drugs)))
smiles_to_idx = {smiles: idx for idx, smiles in enumerate(sample_smiles)}

for (drug1, drug2), label in zip(drug_pairs, interaction_labels):
    idx1 = smiles_to_idx[drug1]
    idx2 = smiles_to_idx[drug2]
    interaction_matrix[idx1, idx2] = label
    interaction_matrix[idx2, idx1] = label

# Plot interaction matrix
plt.figure(figsize=(8, 6))
sns.heatmap(interaction_matrix, 
            xticklabels=drug_names, 
            yticklabels=drug_names,
            annot=True, fmt=".0f", cmap="RdYlBu_r")
plt.title("Drug-Drug Interaction Matrix")
plt.tight_layout()
plt.show()

## 4. Temporal Features

In [None]:
# Generate temporal features (pharmacokinetic profiles)
temporal_features = data_loader.create_temporal_features(drug_pairs[:5])  # Use first 5 pairs

# Plot temporal concentration profiles
time_points = np.linspace(0, 24, 25)  # 24 hours

plt.figure(figsize=(12, 8))
for i, (drug_smiles, concentrations) in enumerate(list(temporal_features.items())[:5]):
    drug_name = [name for name, smiles in sample_drugs.items() if smiles == drug_smiles]
    label = drug_name[0] if drug_name else f"Drug {i+1}"
    
    plt.plot(time_points, concentrations.numpy(), 'o-', label=label, linewidth=2, markersize=4)

plt.xlabel("Time (hours)")
plt.ylabel("Drug Concentration (normalized)")
plt.title("Temporal Drug Concentration Profiles")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 5. Heterogeneous Graph Construction

In [None]:
from temporal_drug_interaction_prediction_with_heterogeneous_molecular_graphs.data.preprocessing import TemporalGraphConstructor

# Create heterogeneous graph
graph_constructor = TemporalGraphConstructor()
hetero_graph = graph_constructor.construct_heterogeneous_graph(
    drug_pairs[:5], molecular_features, interaction_labels[:5]
)

print("Heterogeneous Graph Structure:")
print(f"Node types: {list(hetero_graph.x_dict.keys())}")
print(f"Edge types: {len(hetero_graph.edge_index_dict)}")

for node_type, features in hetero_graph.x_dict.items():
    print(f"{node_type}: {features.shape[0]} nodes, {features.shape[1]} features")

print("\nEdge types:")
for edge_type, edges in hetero_graph.edge_index_dict.items():
    print(f"{edge_type}: {edges.shape[1]} edges")

## 6. Model Architecture

In [None]:
from temporal_drug_interaction_prediction_with_heterogeneous_molecular_graphs.models.model import TemporalDrugInteractionGNN

# Define model configuration
node_type_dims = {
    'drug': 265,  # Molecular descriptor dimension
    'metabolite': 64,
    'target': 64
}

edge_types = [
    ('drug', 'metabolizes_to', 'metabolite'),
    ('metabolite', 'metabolized_from', 'drug'),
    ('drug', 'targets', 'target'),
    ('target', 'targeted_by', 'drug'),
    ('metabolite', 'affects', 'target'),
    ('target', 'affected_by', 'metabolite'),
    ('drug', 'interacts', 'drug')
]

# Create model
model = TemporalDrugInteractionGNN(
    node_type_dims=node_type_dims,
    edge_types=edge_types,
    hidden_dim=128,  # Smaller for exploration
    num_layers=2,
    num_heads=4,
    dropout=0.1,
    temporal_attention_dim=64,
    metabolite_pathway_dim=32,
    max_time_steps=25,
)

print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Model architecture:")
print(model)

## 7. Forward Pass Demonstration

In [None]:
# Run forward pass
model.eval()
with torch.no_grad():
    outputs = model(hetero_graph)

print("Model outputs:")
for key, value in outputs.items():
    if torch.is_tensor(value):
        print(f"{key}: {value.shape}")
    else:
        print(f"{key}: {type(value)}")

# Visualize interaction predictions
interaction_probs = torch.sigmoid(outputs['interaction_logits']).numpy().flatten()
pathway_probs = outputs['pathway_probs'].numpy()

print(f"\nInteraction probabilities: {interaction_probs}")
print(f"Mean interaction probability: {np.mean(interaction_probs):.3f}")

## 8. Evaluation Metrics

In [None]:
from temporal_drug_interaction_prediction_with_heterogeneous_molecular_graphs.evaluation.metrics import DrugInteractionMetrics

# Initialize metrics calculator
metrics = DrugInteractionMetrics(config.target_metrics)

# Generate synthetic predictions for demonstration
n_samples = 100
synthetic_predictions = torch.sigmoid(torch.randn(n_samples) * 2)
synthetic_targets = torch.randint(0, 2, (n_samples,)).float()

# Update metrics
metrics.update(synthetic_predictions, synthetic_targets)

# Compute and display metrics
all_metrics = metrics.compute_all_metrics()

print("Evaluation Metrics:")
for metric_name, value in all_metrics.items():
    print(f"{metric_name}: {value:.4f}")

In [None]:
# Plot ROC and PR curves
from sklearn.metrics import roc_curve, precision_recall_curve, auc

fpr, tpr, _ = roc_curve(synthetic_targets.numpy(), synthetic_predictions.numpy())
precision, recall, _ = precision_recall_curve(synthetic_targets.numpy(), synthetic_predictions.numpy())

roc_auc = auc(fpr, tpr)
pr_auc = auc(recall, precision)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# ROC Curve
ax1.plot(fpr, tpr, linewidth=2, label=f'ROC Curve (AUC = {roc_auc:.3f})')
ax1.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
ax1.set_xlabel('False Positive Rate')
ax1.set_ylabel('True Positive Rate')
ax1.set_title('ROC Curve')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Precision-Recall Curve
ax2.plot(recall, precision, linewidth=2, label=f'PR Curve (AUC = {pr_auc:.3f})')
ax2.set_xlabel('Recall')
ax2.set_ylabel('Precision')
ax2.set_title('Precision-Recall Curve')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Target Metrics Comparison

In [None]:
# Compare with target metrics
target_comparison = metrics.get_target_metric_comparison()

# Create comparison visualization
metrics_names = list(target_comparison.keys())
current_values = [comp['current'] for comp in target_comparison.values()]
target_values = [comp['target'] for comp in target_comparison.values()]
achieved = [comp['achieved'] for comp in target_comparison.values()]

x = np.arange(len(metrics_names))
width = 0.35

fig, ax = plt.subplots(figsize=(12, 6))
bars1 = ax.bar(x - width/2, current_values, width, label='Current', alpha=0.8)
bars2 = ax.bar(x + width/2, target_values, width, label='Target', alpha=0.8)

# Color bars based on achievement
for i, (bar, is_achieved) in enumerate(zip(bars1, achieved)):
    bar.set_color('green' if is_achieved else 'red')

ax.set_xlabel('Metrics')
ax.set_ylabel('Value')
ax.set_title('Current vs Target Metrics')
ax.set_xticks(x)
ax.set_xticklabels([name.replace('_', ' ').title() for name in metrics_names], rotation=45)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print achievement summary
achieved_count = sum(achieved)
total_count = len(achieved)
achievement_rate = achieved_count / total_count * 100

print(f"\nTarget Achievement Summary:")
print(f"Achieved: {achieved_count}/{total_count} metrics ({achievement_rate:.1f}%)")
for name, comp in target_comparison.items():
    status = "‚úÖ" if comp['achieved'] else "‚ùå"
    print(f"{status} {name}: {comp['current']:.3f} (target: {comp['target']:.3f})")

## 10. Molecular Feature Analysis

In [None]:
# Analyze molecular feature distributions
feature_df = pd.DataFrame(feature_matrix, index=drug_names)

# Basic statistics
print("Molecular Feature Statistics:")
print(feature_df.describe().iloc[:, :10])  # First 10 features

# Feature correlation heatmap
plt.figure(figsize=(10, 8))
correlation_matrix = feature_df.iloc[:, :20].corr()  # First 20 features
sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', center=0)
plt.title('Molecular Feature Correlations (First 20 Features)')
plt.tight_layout()
plt.show()

## 11. System Performance Summary

In [None]:
print("üéØ TEMPORAL DRUG INTERACTION PREDICTION SYSTEM")
print("="*60)
print(f"üìä Dataset: {len(molecular_features)} molecules processed")
print(f"üî¨ Drug Pairs: {len(drug_pairs)} interaction pairs generated")
print(f"üß† Model: {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"‚ö° Architecture: {config.model.num_layers} layers, {config.model.hidden_dim}D hidden")

print("\nüéØ Key Features:")
print("‚Ä¢ Heterogeneous graph neural networks")
print("‚Ä¢ Temporal attention mechanisms")
print("‚Ä¢ Metabolite pathway prediction")
print("‚Ä¢ Multi-task learning framework")
print("‚Ä¢ Comprehensive evaluation metrics")

print("\nüìà Target Metrics:")
for metric, target in config.target_metrics.items():
    print(f"‚Ä¢ {metric.replace('_', ' ').title()}: {target:.3f}")

print("\n‚úÖ Exploration completed successfully!")

## Next Steps

1. **Training**: Use the training script to train the model on larger datasets
2. **Evaluation**: Run comprehensive evaluation using the evaluation script
3. **Hyperparameter Tuning**: Experiment with different model configurations
4. **Real Data**: Apply to real drug interaction datasets
5. **Production**: Deploy for real-world pharmacovigilance applications

---

This notebook demonstrates the core functionality of the temporal drug interaction prediction system. The system combines state-of-the-art graph neural networks with temporal modeling to predict adverse drug-drug interactions, enabling early detection of dangerous combinations before clinical trials.