# Drug Discovery - Molecular Property Prediction
## Graph Neural Networks for Molecular Analysis

This notebook uses Graph Neural Networks to predict molecular properties for drug discovery.

**Estimated GPU Time:** 4-8 hours

In [None]:
# Check GPU: report the PyTorch build and, if a CUDA device is visible, its name and memory
import torch

cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    # Device index 0 is the only GPU inspected here
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")

In [None]:
# Mount Google Drive (Colab only) so the project folder below lives on Drive
from google.colab import drive
drive.mount('/content/drive')

import os
# Relative paths used later in the notebook (e.g. './training_history.png')
# resolve against this folder because it becomes the working directory.
project_dir = '/content/drive/MyDrive/medical-ai-project/drug-discovery'
os.makedirs(project_dir, exist_ok=True)  # create on first run; no-op when it already exists
os.chdir(project_dir)

In [None]:
# Install dependencies for molecular ML
# (PyTorch + PyTorch Geometric for the GNN, RDKit for SMILES parsing,
#  scikit-learn / matplotlib / seaborn / pandas / numpy for metrics and plots)
!pip install torch torchvision
!pip install torch-geometric
# RDKit's PyPI distribution is named `rdkit`; `rdkit-pypi` is the legacy name
# and has no releases for current Python versions, so installing it can fail.
!pip install rdkit
!pip install scikit-learn matplotlib seaborn
!pip install pandas numpy
!pip install deepchem  # Optional: for easy dataset loading

In [None]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_max_pool
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import warnings
warnings.filterwarnings('ignore')

# Fix the PyTorch and NumPy RNG seeds so repeated runs start from the same state
torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Dataset Preparation

Common molecular property datasets:
- **QM9**: Quantum mechanical properties of small molecules
- **ESOL**: Aqueous solubility prediction
- **FreeSolv**: Hydration free energy
- **Lipophilicity**: Octanol/water distribution coefficient
- **BACE**: Binding results (IC50 values and active/inactive labels) for inhibitors of human β-secretase 1 (BACE-1)

In [None]:
# Configuration
# -- optimisation: loader batch size, number of passes over the data, Adam step size
NUM_EPOCHS = 100
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
# -- network size: number of GAT layers and per-head hidden width passed to MolecularGNN
NUM_LAYERS = 3
HIDDEN_DIM = 128

In [None]:
# Molecular graph conversion
def smiles_to_graph(smiles):
    """
    Convert SMILES string to PyTorch Geometric graph

    Every atom becomes a node with six numeric features (atomic number,
    degree, formal charge, hybridization code, aromaticity flag, total
    hydrogen count). Every bond becomes two directed edges, one per
    direction, so the graph is effectively undirected.

    Args:
        smiles: SMILES string representation of molecule

    Returns:
        PyG Data object with ``x`` of shape [num_atoms, 6] and ``edge_index``
        of shape [2, 2 * num_bonds], or None when RDKit cannot parse the
        string or the parsed molecule contains no atoms
    """
    mol = Chem.MolFromSmiles(smiles)
    # A molecule without atoms (e.g. from an empty string) would give a graph
    # with no nodes, which cannot be convolved or pooled; treat it as invalid.
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    # Atom features (keep in sync with the model's num_node_features)
    atom_features = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
            atom.GetTotalNumHs(),
        ]
        for atom in mol.GetAtoms()
    ]
    x = torch.tensor(atom_features, dtype=torch.float)

    # Edge indices (bonds), stored in both directions
    edge_indices = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_indices.append([i, j])
        edge_indices.append([j, i])

    if edge_indices:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    else:
        # Single atoms / ions have no bonds. Transposing an empty 1-D tensor
        # would leave shape [0]; build the [2, 0] shape explicitly instead.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)


def prepare_dataset(smiles_list, labels):
    """
    Build a list of labelled molecular graphs from SMILES strings.

    SMILES and labels are paired positionally with ``zip``. A molecule for
    which ``smiles_to_graph`` returns None is skipped together with its label.

    Args:
        smiles_list: iterable of SMILES strings
        labels: iterable of target values, aligned with ``smiles_list``

    Returns:
        List of graph objects, each with ``y`` set to a float tensor of shape [1]
    """
    graphs = []

    for smi, target in zip(smiles_list, labels):
        mol_graph = smiles_to_graph(smi)
        if mol_graph is None:
            continue
        mol_graph.y = torch.tensor([target], dtype=torch.float)
        graphs.append(mol_graph)

    return graphs

In [None]:
# Load dataset (example using DeepChem)
# NOTE(review): featurizer='ECFP' presumably yields fingerprint vectors rather than
# the PyG graphs produced by prepare_dataset below — confirm before feeding these
# splits to the GNN.
# import deepchem as dc
# tasks, datasets, transformers = dc.molnet.load_esol(featurizer='ECFP')
# train_dataset, val_dataset, test_dataset = datasets

# OR load from CSV
# (expects a 'smiles' column and a 'target_property' column)
# df = pd.read_csv('./data/molecules.csv')
# smiles_list = df['smiles'].tolist()
# labels = df['target_property'].tolist()

# Convert to graphs, then hold out 20% of them for validation
# data_list = prepare_dataset(smiles_list, labels)
# train_data, val_data = train_test_split(data_list, test_size=0.2, random_state=42)

# Only the training loader is shuffled
# train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)

print("Dataset preparation functions ready")

## 2. Graph Neural Network Architecture

In [None]:
class MolecularGNN(nn.Module):
    """
    Graph Neural Network for molecular property prediction
    Uses Graph Attention Networks (GAT) for better performance

    Architecture: ``num_layers`` GATConv layers (four concatenated attention
    heads each, except a single-head final layer), each followed by batch
    normalisation, ReLU and dropout. Node embeddings are then pooled per
    graph with mean and max pooling, concatenated, and passed through a
    three-layer MLP that emits one value per graph.

    Args:
        num_node_features: number of input features per node
        hidden_dim: per-head output width of every GAT layer
        num_layers: number of GAT layers; must be at least 2 (one multi-head
            input layer plus the single-head output layer)
        dropout: probability used for GAT attention dropout and for the
            dropout layers between blocks

    Raises:
        ValueError: if ``num_layers`` is smaller than 2
    """
    def __init__(self, num_node_features, hidden_dim=128, num_layers=3, dropout=0.2):
        super().__init__()

        # The stack always contains a first and a last layer. With fewer than
        # two layers the conv and batch-norm lists would get different lengths
        # and widths, and forward() would fail with a shape error.
        if num_layers < 2:
            raise ValueError(f"num_layers must be >= 2, got {num_layers}")

        self.num_layers = num_layers
        heads = 4

        # Graph convolution layers
        self.convs = nn.ModuleList()
        self.convs.append(GATConv(num_node_features, hidden_dim, heads=heads, dropout=dropout))

        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_dim * heads, hidden_dim, heads=heads, dropout=dropout))

        self.convs.append(GATConv(hidden_dim * heads, hidden_dim, heads=1, dropout=dropout))

        # Batch normalization: multi-head layers output hidden_dim * heads
        # channels, the final single-head layer outputs hidden_dim
        self.batch_norms = nn.ModuleList([
            nn.BatchNorm1d(hidden_dim * heads) for _ in range(num_layers - 1)
        ])
        self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Fully connected layers (input is mean-pool and max-pool concatenated)
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 64)
        self.fc3 = nn.Linear(64, 1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """
        Predict one value per graph.

        Args:
            data: object exposing ``x`` (node features), ``edge_index`` and,
                for batched input, ``batch`` (graph id of every node)

        Returns:
            Tensor of shape [num_graphs, 1]
        """
        x, edge_index = data.x, data.edge_index
        batch = getattr(data, 'batch', None)
        if batch is None:
            # No batch vector (e.g. a single graph that was not collated by a
            # loader): assign every node to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Graph convolutions
        for conv, norm in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = norm(x)
            x = F.relu(x)
            x = self.dropout(x)

        # Global pooling
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)

        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)

        return x

In [None]:
# Initialize model
NUM_NODE_FEATURES = 6  # one input channel per atom feature built in smiles_to_graph

model = MolecularGNN(
    num_node_features=NUM_NODE_FEATURES,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=0.2,
)
model = model.to(device)  # move the weights to the selected device

# Count every parameter element of the model
param_count = sum(p.numel() for p in model.parameters())
print(f"Model loaded on {device}")
print(f"Total parameters: {param_count:,}")

## 3. Training Setup

In [None]:
# Loss and optimizer
criterion = nn.MSELoss()  # For regression tasks
# The model emits a single value per graph, so for binary classification use
# nn.BCEWithLogitsLoss(); nn.CrossEntropyLoss() needs one output per class.

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
# Halve the learning rate once the monitored value (validation loss) has not
# improved for more than `patience` scheduler steps.
# `verbose` is deliberately not passed: it is deprecated in recent PyTorch
# releases and rejected by the newest ones. To log the learning rate, read
# optimizer.param_groups[0]['lr'] after scheduler.step(...).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=10
)

In [None]:
# Training functions
def train_epoch(model, train_loader, criterion, optimizer, device):
    """
    Run one optimisation pass over ``train_loader``.

    Args:
        model: module called as ``model(batch)``
        train_loader: iterable of batches exposing ``to(device)``, ``y`` and
            ``num_graphs``
        criterion: loss callable ``criterion(prediction, target)``
        optimizer: optimizer updating ``model``'s parameters
        device: device the batches are moved to

    Returns:
        Loss averaged over the graphs processed in this pass

    Raises:
        ZeroDivisionError: if the loader yields no graphs
    """
    model.train()
    total_loss = 0.0
    graphs_seen = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        output = model(data)
        # Align the target with the prediction. A [B] target against a [B, 1]
        # prediction would be broadcast to [B, B] by element-wise losses such
        # as MSELoss, silently training on the wrong objective.
        target = data.y
        if target.numel() == output.numel():
            target = target.reshape(output.shape)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size so the epoch average is
        # per graph, and count graphs actually processed (correct even when
        # the loader drops or subsamples batches).
        total_loss += loss.item() * data.num_graphs
        graphs_seen += data.num_graphs

    return total_loss / graphs_seen


def validate_epoch(model, val_loader, criterion, device):
    """
    Evaluate ``model`` on ``val_loader`` without updating its parameters.

    Args:
        model: module called as ``model(batch)``
        val_loader: iterable of batches exposing ``to(device)``, ``y`` and
            ``num_graphs``
        criterion: loss callable ``criterion(prediction, target)``
        device: device the batches are moved to

    Returns:
        Tuple ``(avg_loss, mae, rmse, r2)``: the loss averaged over the graphs
        processed, followed by mean absolute error, root mean squared error
        and R² computed over all predictions

    Raises:
        ZeroDivisionError: if the loader yields no graphs
    """
    model.eval()
    total_loss = 0.0
    graphs_seen = 0
    predictions = []
    true_values = []

    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            output = model(data)
            # Align the target with the prediction. A [B] target against a
            # [B, 1] prediction would be broadcast to [B, B] by element-wise
            # losses such as MSELoss and report a meaningless value.
            target = data.y
            if target.numel() == output.numel():
                target = target.reshape(output.shape)
            loss = criterion(output, target)

            total_loss += loss.item() * data.num_graphs
            graphs_seen += data.num_graphs
            # Store both sides flattened so the metric inputs share one shape
            predictions.append(output.reshape(-1).cpu())
            true_values.append(data.y.reshape(-1).cpu())

    avg_loss = total_loss / graphs_seen
    predictions = torch.cat(predictions).numpy()
    true_values = torch.cat(true_values).numpy()

    # Calculate metrics
    mae = mean_absolute_error(true_values, predictions)
    rmse = np.sqrt(mean_squared_error(true_values, predictions))
    r2 = r2_score(true_values, predictions)

    return avg_loss, mae, rmse, r2

## 4. Training Loop

In [None]:
# Per-epoch metric log; each key gets its own list, appended to once per epoch
history = {
    metric: []
    for metric in ('train_loss', 'val_loss', 'val_mae', 'val_rmse', 'val_r2')
}

# Lowest validation loss seen so far; decides when a checkpoint is written
best_loss = float('inf')

# TODO: Uncomment when data is ready
# os.makedirs('./models', exist_ok=True)  # torch.save does not create missing folders
# for epoch in range(NUM_EPOCHS):
#     print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')
#     print('-' * 50)

#     train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
#     val_loss, val_mae, val_rmse, val_r2 = validate_epoch(model, val_loader, criterion, device)

#     history['train_loss'].append(train_loss)
#     history['val_loss'].append(val_loss)
#     history['val_mae'].append(val_mae)
#     history['val_rmse'].append(val_rmse)
#     history['val_r2'].append(val_r2)

#     print(f'Train Loss: {train_loss:.4f}')
#     print(f'Val Loss: {val_loss:.4f} | MAE: {val_mae:.4f} | RMSE: {val_rmse:.4f} | R²: {val_r2:.4f}')

#     scheduler.step(val_loss)

#     if val_loss < best_loss:
#         best_loss = val_loss
#         torch.save(model.state_dict(), './models/best_drug_discovery_model.pth')
#         print(f'✓ Best model saved with loss: {best_loss:.4f}')

print("Training setup complete")

## 5. Evaluation & Visualization

In [None]:
# Plot training history
def plot_history(history):
    """
    Draw a 2x2 grid of training curves, save it and show it.

    Panels: train/validation loss, validation MAE, validation RMSE and
    validation R², each plotted against the epoch index. The figure is
    written to './training_history.png' before being displayed.

    Args:
        history: mapping with the keys 'train_loss', 'val_loss', 'val_mae',
            'val_rmse' and 'val_r2', each holding one value per epoch
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # One row per panel: (axis, ((history key, legend label or None), ...), y-label, title)
    panels = (
        (axes[0, 0], (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')),
         'Loss', 'Training and Validation Loss'),
        (axes[0, 1], (('val_mae', None),), 'MAE', 'Validation MAE'),
        (axes[1, 0], (('val_rmse', None),), 'RMSE', 'Validation RMSE'),
        (axes[1, 1], (('val_r2', None),), 'R²', 'Validation R² Score'),
    )

    for ax, curves, y_label, title in panels:
        for key, label in curves:
            if label is None:
                ax.plot(history[key])
            else:
                ax.plot(history[key], label=label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        # Only the loss panel has labelled curves and therefore a legend
        if any(label is not None for _, label in curves):
            ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig('./training_history.png', dpi=300, bbox_inches='tight')
    plt.show()

# plot_history(history)

In [None]:
# Prediction vs actual plot
def plot_predictions(y_true, y_pred):
    """
    Scatter predictions against true values, save the figure and show it.

    A dashed red identity line spanning the range of ``y_true`` marks where
    perfect predictions would fall. The figure is written to
    './predictions.png' before being displayed.

    Args:
        y_true: true values; must provide ``min()`` and ``max()`` methods
        y_pred: predicted values, aligned with ``y_true``
    """
    plt.figure(figsize=(10, 10))
    plt.scatter(y_true, y_pred, alpha=0.5)

    # Identity line: x and y endpoints are the same pair of values
    diagonal = [y_true.min(), y_true.max()]
    plt.plot(diagonal, diagonal, 'r--', lw=2)

    plt.xlabel('True Values')
    plt.ylabel('Predictions')
    plt.title('Predicted vs Actual Molecular Properties')
    plt.grid(True)
    plt.savefig('./predictions.png', dpi=300, bbox_inches='tight')
    plt.show()

## 6. Molecular Visualization

Visualize molecules alongside their predicted and true property values.

In [None]:
from rdkit.Chem import Draw

def visualize_molecules(smiles_list, predictions, true_values, n_samples=6):
    """
    Draw the first ``n_samples`` molecules in a grid, three per row.

    Each cell is captioned with the predicted and true value of that
    molecule, formatted to two decimals.

    Args:
        smiles_list: sequence of SMILES strings
        predictions: sequence of predicted values, aligned with ``smiles_list``
        true_values: sequence of true values, aligned with ``smiles_list``
        n_samples: number of leading entries to draw

    Returns:
        The grid image produced by ``Draw.MolsToGridImage``
    """
    # Parse only the leading n_samples SMILES strings
    mols = list(map(Chem.MolFromSmiles, smiles_list[:n_samples]))

    # One caption per molecule, in the same order as `mols`
    legends = []
    for pred, true in zip(predictions[:n_samples], true_values[:n_samples]):
        legends.append(f'Pred: {pred:.2f}\nTrue: {true:.2f}')

    return Draw.MolsToGridImage(
        mols,
        molsPerRow=3,
        subImgSize=(300, 300),
        legends=legends,
    )

# Example usage:
# img = visualize_molecules(test_smiles, predictions, true_values)
# img.save('./molecule_predictions.png')