In [12]:
!pip install rdkit --pre deepchem torch_geometric triton



In [21]:
"""
DeepChem DMPNN Model Optimization with PyTorch

This script demonstrates how to optimize a DeepChem DMPNN model using various PyTorch techniques:
1. Using torch.compile for improved performance
2. Analyzing the compilation process and Inductor backend
3. Identifying bottlenecks in the training pipeline
4. Applying optimization techniques like mixed precision, quantization, and knowledge distillation
"""

import deepchem as dc
import numpy as np
import torch
import time
import matplotlib.pyplot as plt
from torch.profiler import profile, record_function, ProfilerActivity
from deepchem.models.torch_models import DMPNN
from deepchem.models.optimizers import Adam
import deepchem.models.losses as losses
import os
import datetime

from torch_geometric.data import Batch
from torch_geometric.data import Data

In [2]:
# Set random seed for reproducibility
np.random.seed(123)
# torch.manual_seed(123)

In [3]:
# Step 1: Load PCBA dataset
print("Loading PCBA dataset...")
pcba_tasks, pcba_datasets, transformers = dc.molnet.load_tox21(featurizer=dc.feat.DMPNNFeaturizer(), splitter='random')
train_dataset, valid_dataset, test_dataset = pcba_datasets

Loading PCBA dataset...


In [None]:
# Step 2: Define DMPNN model architecture
n_tasks = len(pcba_tasks)
print(f"Dataset has {n_tasks} tasks")

# Create DMPNN model
model = dc.models.DMPNNModel(
    n_tasks=n_tasks,
    mode='classification',
    n_atom_feat=75,  # Default for DMPNN
    n_pair_feat=14,  # Default for DMPNN
    n_hidden=300,    # Hidden size
    n_graph_feat=300, # Graph feature size
    dropout=0.2,
    learning_rate=0.001
)

Dataset has 12 tasks


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [7]:
model = dc.models.DMPNNModel(n_tasks=len(pcba_tasks))

In [9]:
# Step 3: Baseline training and evaluation
print("\n--- Baseline Model Performance ---")
start_time = time.time()
# Train for a small number of epochs for demonstration
model.fit(train_dataset, nb_epoch=10)
baseline_train_time = time.time() - start_time

# Evaluate model
baseline_scores = model.evaluate(test_dataset, metrics=[dc.metrics.roc_auc_score])
print(f"Baseline ROC-AUC: {baseline_scores}")
print(f"Baseline training time: {baseline_train_time:.2f} seconds")


--- Baseline Model Performance ---
Baseline ROC-AUC: {'metric-1': 0.8073138353020622}
Baseline training time: 337.48 seconds


In [10]:
# Evaluate model with proper binary classification handling
def evaluate_binary(model, dataset):
    predictions = model.predict(dataset)
    # For binary classification, we take the positive class probability
    if len(predictions.shape) == 3:  # If predictions are [batch_size, n_tasks, 2]
        predictions = predictions[:, :, 1]  # Take the positive class probability
    y_true = dataset.y
    # Ensure y_true is binary (0 or 1)
    y_true = (y_true > 0).astype(float)
    return dc.metrics.roc_auc_score(y_true, predictions)

baseline_roc_auc = evaluate_binary(model, test_dataset)
print(f"Baseline ROC-AUC: {baseline_roc_auc}")
print(f"Baseline training time: {baseline_train_time:.2f} seconds")

# Step 4: Apply torch.compile to the model
print("\n--- Applying torch.compile ---")
# Get the PyTorch model from the DeepChem wrapper
torch_model = model.model

# Create a compiled version of the model
if hasattr(torch, 'compile'):  # Check if torch.compile is available (PyTorch 2.0+)
    compiled_model = torch.compile(torch_model, backend="inductor")
    model.model = compiled_model

    # Train with compiled model
    start_time = time.time()
    model.fit(train_dataset, nb_epoch=10)
    compiled_train_time = time.time() - start_time

    # Evaluate compiled model
    compiled_roc_auc = evaluate_binary(model, test_dataset)
    print(f"Compiled model ROC-AUC: {compiled_roc_auc}")
    print(f"Compiled model training time: {compiled_train_time:.2f} seconds")
    print(f"Speedup: {baseline_train_time / compiled_train_time:.2f}x")
else:
    print("torch.compile not available in this PyTorch version. Requires PyTorch 2.0+")

Baseline ROC-AUC: 0.8073138353020622
Baseline training time: 337.48 seconds

--- Applying torch.compile ---


W0418 00:43:17.028000 3199 torch/_dynamo/exc.py:304] [0/0] Backend compiler failed with a fake tensor exception at 
W0418 00:43:17.028000 3199 torch/_dynamo/exc.py:304] [0/0]   File "/usr/local/lib/python3.11/dist-packages/deepchem/models/torch_models/dmpnn.py", line 449, in forward
W0418 00:43:17.028000 3199 torch/_dynamo/exc.py:304] [0/0]     return final_output
W0418 00:43:17.028000 3199 torch/_dynamo/exc.py:304] [0/0] Adding a graph break.
W0418 00:43:17.263000 3199 torch/_dynamo/exc.py:304] [0/0_1] Backend compiler failed with a fake tensor exception at 
W0418 00:43:17.263000 3199 torch/_dynamo/exc.py:304] [0/0_1]   File "/usr/local/lib/python3.11/dist-packages/deepchem/models/torch_models/dmpnn.py", line 449, in forward
W0418 00:43:17.263000 3199 torch/_dynamo/exc.py:304] [0/0_1]     return final_output
W0418 00:43:17.263000 3199 torch/_dynamo/exc.py:304] [0/0_1] Adding a graph break.
W0418 00:43:24.426000 3199 torch/_inductor/utils.py:1137] [2/0] Not enough SMs to use max_autotu

Compiled model ROC-AUC: 0.8149910506114614
Compiled model training time: 354.92 seconds
Speedup: 0.95x


In [28]:
def time_torch_function(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

track_dict = {}
prev_time_dict = {}

def get_time_track_callback(track_dict, track_name, track_interval):
    track_dict[track_name] = []
    prev_time_dict[track_name] = datetime.datetime.now()
    def callback(model, step):
        if step % track_interval == 0:
            elapsed_time = datetime.datetime.now() - prev_time_dict[track_name]
            track_dict[track_name].append(elapsed_time.total_seconds())
            prev_time_dict[track_name] = datetime.datetime.now()
    return callback

In [18]:
model = dc.models.DMPNNModel()
model_compiled = dc.models.DMPNNModel()
model_compiled.compile(mode='reduce-overhead')

track_interval = 20
eager_dict_name = "eager_train"
compiled_dict_name = "compiled_train"

eager_train_callback = get_time_track_callback(track_dict, eager_dict_name, track_interval)
model.fit(train_dataset, nb_epoch=10, callbacks=[eager_train_callback])

compiled_train_callback = get_time_track_callback(track_dict, compiled_dict_name, track_interval)
model_compiled.fit(train_dataset, nb_epoch=10, callbacks=[compiled_train_callback])

  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_loss(output,
  return torch.nn.functional.mse_l

RuntimeError: TODO: graph recording observed an input tensor deallocate during graph  recording that did not occur during replay. Please file an issue.

In [None]:
eager_train_times = track_dict[eager_dict_name]
compiled_train_times = track_dict[compiled_dict_name]

print(f"Eager Times (first 15): {[f'{t:.3f}' for t in eager_train_times[:15]]}")
print(f"Compiled Times (first 15): {[f'{t:.3f}' for t in compiled_train_times[:15]]}")
print(f"Total Eager Time: {sum(eager_train_times)}")
print(f"Total Compiled Time: {sum(compiled_train_times)}")
print(f"Eager Median: {np.median(eager_train_times)}")
print(f"Compiled Median: {np.median(compiled_train_times)}")
print(f"Median Speedup: {((np.median(eager_train_times) / np.median(compiled_train_times)) - 1) * 100:.2f}%")

In [None]:
# Step 6: Apply mixed precision training
print("\n--- Mixed Precision Training ---")
if torch.cuda.is_available():
    # Reset model
    model = DMPNN(
        n_tasks=n_tasks,
        mode='classification',
        n_atom_feat=75,
        n_pair_feat=14,
        n_hidden=300,
        n_graph_feat=300,
        dropout=0.2,
        learning_rate=0.001
    )

    # Move model to GPU
    model.model = model.model.cuda()

    # Apply mixed precision
    from torch.amp import autocast, GradScaler

    # Train with mixed precision
    scaler = GradScaler('cuda')
    start_time = time.time()

    # Use BCEWithLogitsLoss since it's safe with autocast
    loss_fn = torch.nn.BCEWithLogitsLoss()

    # Create a PyTorch optimizer for the model parameters
    optimizer = torch.optim.Adam(model.model.parameters(), lr=0.001)

    for epoch in range(10):
        for X_batch, y_batch, w_batch, ids_batch in train_dataset.iterbatches(batch_size=32, pad_batches=True):
            X_tensor = torch.tensor(X_batch, dtype=torch.float32).cuda()
            y_tensor = torch.tensor(y_batch, dtype=torch.float32).cuda()

            # Forward pass with autocast
            with autocast('cuda'):
                output = model.model(X_tensor)
                loss = loss_fn(output, y_tensor)

            # Backward pass with gradient scaling
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

    mixed_precision_time = time.time() - start_time

    # Evaluate model
    mixed_precision_scores = model.evaluate(test_dataset, metrics=[dc.metrics.Metric(dc.metrics.roc_auc_score)])
    print(f"Mixed Precision ROC-AUC: {mixed_precision_scores['roc_auc_score']}")
    print(f"Mixed Precision training time: {mixed_precision_time:.2f} seconds")
    print(f"Speedup vs baseline: {baseline_train_time / mixed_precision_time:.2f}x")
else:
    print("CUDA not available for mixed precision training")

In [None]:
# Compare model sizes
def get_model_size(model):
    torch.save(model.state_dict(), "temp_model.pt")
    size = os.path.getsize("temp_model.pt") / (1024 * 1024)  # Size in MB
    os.remove("temp_model.pt")
    return size

In [None]:
# Step 7: Quantization
print("\n--- Model Quantization ---")
# Reset model
model = DMPNN(
    n_tasks=n_tasks,
    mode='classification',
    n_atom_feat=75,
    n_pair_feat=14,
    n_hidden=300,
    n_graph_feat=300,
    dropout=0.2,
    learning_rate=0.001
)

# Train on GPU
if torch.cuda.is_available():
    model.model = model.model.cuda()
    print("Training on GPU...")
else:
    print("CUDA not available. Exiting quantization...")
    exit()

# Train the model
start_time = time.time()
model.fit(train_dataset, nb_epoch=10)
training_time = time.time() - start_time
print(f"Training completed in {training_time:.2f} seconds")

# Convert model to FP16
print("Converting model to FP16...")
model.model = model.model.half()  # Convert all parameters to float16

# Create a new model for evaluation with the FP16 model
quantized_eval_model = DMPNN(
    n_tasks=n_tasks,
    mode='classification',
    n_atom_feat=75,
    n_pair_feat=14,
    n_hidden=300,
    n_graph_feat=300,
    dropout=0.2,
    learning_rate=0.001
)

# Set the FP16 model
quantized_eval_model.model = model.model

# Evaluate quantized model
print("Evaluating quantized model...")
def evaluate_fp16(model, dataset):
    model.model.eval()
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for X_batch, y_batch, w_batch, ids_batch in dataset.iterbatches(batch_size=32, pad_batches=True):
            # Convert input to float16 and move to GPU
            X_tensor = torch.tensor(X_batch, dtype=torch.float16).cuda()
            y_tensor = torch.tensor(y_batch, dtype=torch.float32).cuda()

            # Forward pass
            predictions = model.model(X_tensor)
            all_predictions.append(predictions.float().cpu().numpy())
            all_targets.append(y_tensor.cpu().numpy())

    predictions = np.concatenate(all_predictions)
    targets = np.concatenate(all_targets)
    return dc.metrics.roc_auc_score(targets, predictions)

# Evaluate using our FP16-aware function
quantized_roc_auc = evaluate_fp16(quantized_eval_model, test_dataset)
print(f"Quantized model ROC-AUC: {quantized_roc_auc}")

original_size = get_model_size(model.model)
print(f"Original model size: {original_size:.2f} MB")
print(f"FP16 model size: ~{original_size/2:.2f} MB")
print(f"Size reduction: ~50%")

# Compare inference times
print("\nComparing inference times...")
def measure_inference_time(model, dataset, name, num_batches=50):
    # Ensure model is on GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.model = model.model.to(device)
    model.model.eval()

    batch_size = 32

    # Determine input dtype based on model type
    if name == "Quantized Model":
        input_dtype = torch.float16
        # Ensure model is in float16
        model.model = model.model.half()
    else:
        input_dtype = torch.float32
        # Ensure model is in float32
        model.model = model.model.float()

    # Warmup
    for i, (X_batch, y_batch, w_batch, ids_batch) in enumerate(dataset.iterbatches(batch_size=batch_size, pad_batches=True)):
        X_tensor = torch.tensor(X_batch, dtype=input_dtype).to(device)
        with torch.no_grad():
            _ = model.model(X_tensor)
        if i >= 5:
            break

    # Measure inference time
    start_time = time.time()
    with torch.no_grad():
        for i, (X_batch, y_batch, w_batch, ids_batch) in enumerate(dataset.iterbatches(batch_size=batch_size, pad_batches=True)):
            X_tensor = torch.tensor(X_batch, dtype=input_dtype).to(device)
            _ = model.model(X_tensor)
            if i >= num_batches:
                break

    inference_time = (time.time() - start_time) / (num_batches + 1)
    print(f"{name} average inference time per batch: {inference_time*1000:.2f} ms")
    return inference_time

# Compare original and quantized model inference times
original_time = measure_inference_time(model, test_dataset, "Original Model")
quantized_time = measure_inference_time(quantized_eval_model, test_dataset, "Quantized Model")
print(f"Speedup: {original_time/quantized_time:.2f}x faster")

In [None]:
# Step 8: Knowledge Distillation
print("\n--- Knowledge Distillation ---")
# Create a smaller student model
student_model = DMPNN(
    n_tasks=n_tasks,
    mode='classification',
    n_atom_feat=75,
    n_pair_feat=14,
    n_hidden=150,    # Smaller hidden size
    n_graph_feat=150, # Smaller graph feature size
    dropout=0.2,
    learning_rate=0.001
)

# Use the original model as teacher
teacher_model = model
teacher_model.model = torch_model  # Use the original non-quantized model

# Move models to the same device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teacher_model.model = teacher_model.model.to(device)
student_model.model = student_model.model.to(device)

# Implement knowledge distillation
def distill(teacher, student, dataset, temperature=3.0, alpha=0.5, epochs=10):
    student_model = student.model
    teacher_model = teacher.model
    teacher_model.eval()  # Teacher in eval mode

    # KL divergence loss for soft targets
    kl_criterion = torch.nn.KLDivLoss(reduction='batchmean')
    # Hard target loss
    ce_criterion = torch.nn.BCEWithLogitsLoss()

    optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)

    start_time = time.time()
    for epoch in range(epochs):
        total_loss = 0
        batches = 0

        for X_batch, y_batch, w_batch, ids_batch in dataset.iterbatches(batch_size=32, pad_batches=True):
            # Move data to the same device as models
            X_tensor = torch.tensor(X_batch, dtype=torch.float32).to(device)
            y_tensor = torch.tensor(y_batch, dtype=torch.float32).to(device)

            # Forward pass for teacher (no grad)
            with torch.no_grad():
                teacher_output = teacher_model(X_tensor)

            # Forward pass for student
            student_output = student_model(X_tensor)

            # Compute soft targets with temperature
            soft_targets = torch.sigmoid(teacher_output / temperature)
            soft_student = torch.sigmoid(student_output / temperature)

            # Distillation loss (soft targets)
            distillation_loss = kl_criterion(
                torch.log(soft_student + 1e-8),  # Add small epsilon to avoid log(0)
                soft_targets
            ) * (temperature ** 2)

            # Student loss on hard targets
            student_loss = ce_criterion(student_output, y_tensor)

            # Combined loss
            loss = alpha * distillation_loss + (1 - alpha) * student_loss

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batches += 1

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/batches:.4f}")

    distill_time = time.time() - start_time
    print(f"Distillation completed in {distill_time:.2f} seconds")

# Perform distillation
distill(teacher_model, student_model, train_dataset)

# Evaluate student model
student_scores = student_model.evaluate(test_dataset, metrics=[dc.metrics.Metric(dc.metrics.roc_auc_score)])
print(f"Student model ROC-AUC: {student_scores['roc_auc_score']}")
print(f"Teacher model ROC-AUC: {baseline_scores['roc_auc_score']}")

# Compare model sizes
student_size = get_model_size(student_model.model)
print(f"Teacher model size: {original_size:.2f} MB")
print(f"Student model size: {student_size:.2f} MB")
print(f"Size reduction: {(1 - student_size/original_size) * 100:.2f}%")

In [None]:
# Step 9: Visualization of results
print("\n--- Visualization ---")
# Plot inference times
models = ["Original", "Quantized", "Student"]
times = [original_time, quantized_time, measure_inference_time(student_model, test_dataset, "Student Model")]

if hasattr(torch, 'compile'):
    models.append("Compiled")
    times.append(compiled_time)

plt.figure(figsize=(10, 6))
plt.bar(models, times)
plt.title("Inference Time Comparison")
plt.ylabel("Time per batch (seconds)")
plt.savefig("dmpnn_inference_times.png")
print("Saved inference time comparison to 'dmpnn_inference_times.png'")

# Plot model sizes
sizes = [original_size, original_size/2, student_size]
model_types = ["Original", "FP16", "Student"]

plt.figure(figsize=(10, 6))
plt.bar(model_types, sizes)
plt.title("Model Size Comparison")
plt.ylabel("Size (MB)")
plt.savefig("dmpnn_model_sizes.png")
print("Saved model size comparison to 'dmpnn_model_sizes.png'")

print("\nOptimization study complete!")