
# MolDiff: Complete 3D Molecular Generation Tutorial

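This notebook demonstrates a complete implementation of MolDiff, a diffusion model for 3D molecular generation that addresses the atom-bond inconsistency problem: when a model generates only atoms and positions and bonds are assigned afterwards, the inferred bonds often disagree with the generated geometry and atom types, producing invalid or unrealistic molecules.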

## Key Improvements Over Basic Implementations

1. **Joint Atom-Bond Modeling**: Explicit bond prediction alongside atoms and positions
2. **Categorical Diffusion**: Proper handling of discrete features (atom types, bond types)
3. **Bond-Aware Message Passing**: GNN that considers both atoms and bonds
4. **Chemical Validity**: Constraints and guidance for realistic molecules
5. **Different Noise Schedules**: Bonds diffuse faster than atoms (key MolDiff insight)

---

In [None]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
import logging
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from rdkit import Chem

# Import all improved components
from mol_diff_3d.data.datasets import ImprovedQm9MolecularDataset
from mol_diff_3d.models.diffusion import MolecularDiffusionModel, BondPredictor
from mol_diff_3d.models.noise_scheduler import NoiseScheduler
from mol_diff_3d.sampling.samplers import ImprovedDDPMQSampler, ImprovedDDPMPsampler
from mol_diff_3d.training.trainer import ImprovedDDPMTrainer
from mol_diff_3d.models.categorical_diffusion import CategoricalNoiseScheduler, CategoricalDiffusion
from mol_diff_3d.generation.generator import generate_molecules_with_bond_guidance
from mol_diff_3d.utils.checkpoints import save_checkpoint, load_checkpoint

# Set up logging and visualization.
# INFO level so informational log records (from this notebook's `logger` and
# any library module that logs) are displayed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup matplotlib for notebook.
# IPython magic (not plain-Python syntax): render figures inline in the notebook.
%matplotlib inline
# Reset to matplotlib's default style so plots don't inherit a style set
# earlier in the same kernel session.
plt.style.use('default')


In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Complete configuration for MolDiff training
config = {
    # Data Configuration
    'max_atoms': 25,              # Maximum atoms per molecule
    'max_samples': 10000,         # Number of training samples (increase for better results)
    'batch_size': 32,             # Batch size for training

    # Model Architecture
    'atom_dim': 15,               # Expanded atom feature dimension (10 elements + 5 properties)
    'bond_dim': 7,                # Bond feature dimension (5 types + 2 properties)
    'pos_dim': 3,                 # 3D coordinates
    'hidden_dim': 128,            # Hidden dimension for GNN
    'time_dim': 128,              # Time embedding dimension
    'num_gnn_layers': 4,          # Number of E3GNN layers

    # Training Configuration
    'num_timesteps': 1000,        # Diffusion timesteps
    'learning_rate': 5e-4,        # Learning rate (slightly higher for joint training)
    'epochs': 200,                # Training epochs (more needed for complex model)
    'log_interval': 10,           # Logging frequency

    # Loss Weights (Critical for good results)
    'atom_loss_weight': 1.0,      # Atom type prediction loss weight
    'pos_loss_weight': 1.0,       # Position prediction loss weight
    'bond_loss_weight': 2.0,      # Bond prediction loss weight (higher - bonds are crucial!)
    'guidance_loss_weight': 0.5,  # Bond predictor guidance loss weight

    # Generation Configuration
    'num_molecules_to_generate': 20,
    'guidance_steps': 200,        # Number of timesteps to apply bond guidance
    'temperature': 0.7,           # Sampling temperature
}

In [None]:
# Output directory, relative to the notebook's working directory. If the
# notebook runs from experiments/notebooks this resolves to
# experiments/results/moldiff_improved (depends on where Jupyter was started).
results_dir = Path("..") / "results" / "moldiff_improved"
results_dir.mkdir(parents=True, exist_ok=True)

print("Configuration:")
for name, setting in config.items():
    print(f"  {name}: {setting}")

In [None]:
# Data Preparation

print("="*50)
print("STEP 1: DATA PREPARATION")
print("="*50)

# Load the QM9-derived dataset, which is expected to attach explicit bond
# features (edge_attr) to every molecular graph.
dataset = ImprovedQm9MolecularDataset(max_atoms=config['max_atoms'])
info = dataset.get_dataset_info()

print("\nDataset Information:")
for key, value in info.items():
    print(f"  {key}: {value}")

# Create dataloader with bond features.
dataloader = dataset.create_dataloader(
    batch_size=config['batch_size'],
    shuffle=True,
    max_samples=config['max_samples'],
)

# Pull one batch to confirm that bond features actually reach the model.
sample_batch = next(iter(dataloader))
if getattr(sample_batch, 'edge_attr', None) is None:
    # Fail loudly: the joint atom-bond model cannot train without bond features.
    raise RuntimeError(
        "Sample batch has no edge_attr: the dataloader is not providing bond "
        "features, which the joint atom-bond model requires."
    )

print("\nSample Batch Verification:")
print(f"  Number of molecules: {sample_batch.num_graphs}")
print(f"  Atom features shape: {sample_batch.x.shape}")
print(f"  Bond features shape: {sample_batch.edge_attr.shape}")
print(f"  Positions shape: {sample_batch.pos.shape}")
print(f"  Edge index shape: {sample_batch.edge_index.shape}")

# Histogram of bond types in the sample. Only the bond-type channels are
# argmax'ed: edge_attr may also carry extra bond-property channels that must
# not be mistaken for types. Assumes the one-hot type channels come first
# (the config notes "5 types + 2 properties") — TODO confirm against the dataset.
num_bond_types = info['num_bond_types']
bond_types = sample_batch.edge_attr[:, :num_bond_types].argmax(dim=1)
# minlength keeps a slot for every bond type, even types absent from this batch.
print(f"  Bond type distribution in sample: {torch.bincount(bond_types, minlength=num_bond_types)}")


In [None]:
print("="*50)
print("STEP 2: MODEL INITIALIZATION")
print("="*50)

# NOTE: modules are constructed in a fixed order on purpose; reordering the
# constructors would change the random parameter initialisation.

# 1. Joint atom / bond / position denoiser. Feature widths come from the
#    dataset so they match what the dataloader yields.
model = MolecularDiffusionModel(
    atom_dim=info['atom_feature_dim'],
    bond_dim=info['bond_feature_dim'],
    pos_dim=config['pos_dim'],
    hidden_dim=config['hidden_dim'],
    time_dim=config['time_dim'],
    num_gnn_layers=config['num_gnn_layers'],
    max_atoms=config['max_atoms'],
).to(device)

# 2. Auxiliary bond predictor, later handed to the sampler for guidance.
bond_predictor = BondPredictor(
    atom_dim=info['atom_feature_dim'],
    pos_dim=config['pos_dim'],
    hidden_dim=config['hidden_dim'],
).to(device)

# 3a. Continuous noise schedule for atom positions.
continuous_scheduler = NoiseScheduler(num_timesteps=config['num_timesteps']).to(device)

# 3b. Schedule for the discrete atom/bond types. NOTE(review): a single
#     'linear' schedule is configured here; whether bonds are noised faster
#     than atoms is decided inside CategoricalNoiseScheduler /
#     CategoricalDiffusion — confirm there.
categorical_scheduler = CategoricalNoiseScheduler(
    num_timesteps=config['num_timesteps'],
    schedule_type='linear',
).to(device)

# 4. Discrete diffusion over atom-type and bond-type categories.
categorical_diffusion = CategoricalDiffusion(
    num_atom_types=info['num_atom_types'],
    num_bond_types=info['num_bond_types'],
    scheduler=categorical_scheduler,
)

# 5. Forward ("q") sampler, handed to the trainer below.
q_sampler = ImprovedDDPMQSampler(
    continuous_scheduler.get_parameters(),
    categorical_diffusion,
)

# 6. One optimizer with two parameter groups; the bond predictor trains at
#    half the base learning rate.
base_lr = config['learning_rate']
optimizer = AdamW([
    {'params': model.parameters(), 'lr': base_lr},
    {'params': bond_predictor.parameters(), 'lr': base_lr * 0.5},
])


def _count_parameters(module):
    """Return the total number of scalar parameters held by *module*."""
    return sum(tensor.numel() for tensor in module.parameters())


total_params = _count_parameters(model)
bond_params = _count_parameters(bond_predictor)
print("\nModel Architecture:")
print(f"  Main model parameters: {total_params:,}")
print(f"  Bond predictor parameters: {bond_params:,}")
print(f"  Total parameters: {total_params + bond_params:,}")

In [None]:
print("="*50)
print("STEP 3: TRAINING SETUP")
print("="*50)

# The trainer drives the joint atom / position / bond objective. It receives
# the full config; presumably the *_loss_weight entries set the loss mix —
# confirm in ImprovedDDPMTrainer.
trainer_kwargs = dict(
    model=model,
    bond_predictor=bond_predictor,
    q_sampler=q_sampler,
    categorical_diffusion=categorical_diffusion,
    optimizer=optimizer,
    device=device,
    config=config,
)
trainer = ImprovedDDPMTrainer(**trainer_kwargs)

# Echo the settings most relevant to training, in a fixed order.
print("Trainer Configuration:")
for label, setting_key in (
    ("Atom loss weight", 'atom_loss_weight'),
    ("Position loss weight", 'pos_loss_weight'),
    ("Bond loss weight", 'bond_loss_weight'),
    ("Guidance loss weight", 'guidance_loss_weight'),
    ("Training epochs", 'epochs'),
    ("Learning rate", 'learning_rate'),
):
    print(f"  {label}: {config[setting_key]}")

In [None]:
print("="*50)
print("STEP 4: JOINT TRAINING")
print("="*50)

print("Starting improved joint training...")
print("This will take significant time - monitor the loss components!")

# `losses` holds the total-loss history; `detailed_losses` maps each component
# name ('atom', 'pos', 'bond', 'guidance') to its own history.
losses, detailed_losses = trainer.train(dataloader, num_epochs=config['epochs'])

print("\nTraining completed!")
print(f"Final total loss: {losses[-1]:.4f}")
for label, component in (("atom", 'atom'), ("position", 'pos'), ("bond", 'bond')):
    print(f"Final {label} loss: {detailed_losses[component][-1]:.4f}")

# The guidance history may be empty, so only report it when it has entries.
guidance_history = detailed_losses['guidance']
if guidance_history:
    print(f"Final guidance loss: {guidance_history[-1]:.4f}")

In [None]:
print("="*50)
print("STEP 5: TRAINING ANALYSIS")
print("="*50)


def _style_loss_axis(ax, title):
    """Apply the shared title/label/grid styling to one loss panel."""
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True, alpha=0.3)


# 2x3 grid: total loss first, then the four components in row-major order;
# the sixth panel is unused.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
flat_axes = axes.flatten()

flat_axes[0].plot(losses, 'b-', linewidth=2)
_style_loss_axis(flat_axes[0], 'Total Training Loss')

loss_panels = [
    ('Atom Loss', 'atom', 'red'),
    ('Position Loss', 'pos', 'green'),
    ('Bond Loss', 'bond', 'blue'),
    ('Guidance Loss', 'guidance', 'orange'),
]
for ax, (title, component, color) in zip(flat_axes[1:], loss_panels):
    history = detailed_losses.get(component, [])
    if len(history) == 0:
        # e.g. the guidance history can be empty; say so instead of a blank plot.
        ax.text(0.5, 0.5, 'no data recorded', ha='center', va='center',
                transform=ax.transAxes)
    else:
        ax.plot(history, color=color, linewidth=2)
    _style_loss_axis(ax, title)

# Hide unused subplot.
flat_axes[5].axis('off')

plt.tight_layout()
plt.savefig(results_dir / 'improved_training_losses.png', dpi=300, bbox_inches='tight')
plt.show()

# Print training summary.
if len(losses) == 0:
    print("\nNo losses recorded; skipping training summary.")
else:
    # np.argmin works for both Python lists and arrays (list.index does not).
    best_idx = int(np.argmin(losses))
    first_loss, last_loss = losses[0], losses[-1]
    print("\nTraining Summary:")
    print(f"  Total epochs: {len(losses)}")
    print(f"  Best total loss: {losses[best_idx]:.4f} (epoch {best_idx + 1})")
    print(f"  Final total loss: {last_loss:.4f}")
    if first_loss != 0:
        print(f"  Loss reduction: {((first_loss - last_loss) / first_loss * 100):.1f}%")
    else:
        print("  Loss reduction: n/a (initial loss is zero)")

In [None]:
print("="*50)
print("STEP 6: SAVING CHECKPOINTS")
print("="*50)

# Main diffusion model plus optimizer state. The stored loss is the mean of
# the last (up to) 10 entries of `losses`, smoothing out end-of-run noise.
main_checkpoint_path = results_dir / "moldiff_improved_final.pth"
save_checkpoint(
    filepath=main_checkpoint_path,
    model=model,
    optimizer=optimizer,
    epoch=trainer.epoch,
    loss=np.mean(losses[-10:]),
    metrics=dict(detailed_losses=detailed_losses),
)

# The bond predictor is saved on its own: weights plus the constructor
# arguments needed to rebuild it.
bond_predictor_config = dict(
    atom_dim=info['atom_feature_dim'],
    pos_dim=config['pos_dim'],
    hidden_dim=config['hidden_dim'],
)
bond_checkpoint_path = results_dir / "bond_predictor_final.pth"
torch.save(
    dict(model_state_dict=bond_predictor.state_dict(), model_config=bond_predictor_config),
    bond_checkpoint_path,
)

print("Checkpoints saved:")
for label, saved_path in (("Main model", main_checkpoint_path),
                          ("Bond predictor", bond_checkpoint_path)):
    print(f"  {label}: {saved_path}")

In [None]:
print("="*50)
print("STEP 7: MOLECULE GENERATION WITH BOND GUIDANCE")
print("="*50)

# Switch both networks to inference mode before sampling so that layers whose
# behaviour differs between train and eval (dropout, batch-norm — if the
# architectures use them) don't inject training-time randomness. The trainer
# presumably leaves them in train mode; eval() is idempotent, so this is
# harmless if the generator also calls it. No torch.no_grad() here: bond
# guidance may need gradients.
model.eval()
bond_predictor.eval()

# Reverse sampler with bond guidance. The guidance strength is configurable
# via config['guidance_scale'] (defaults to 1.0) so it can be tuned without
# editing this cell.
p_sampler = ImprovedDDPMPsampler(
    scheduler_params=continuous_scheduler.get_parameters(),
    categorical_diffusion=categorical_diffusion,
    bond_predictor=bond_predictor,
    guidance_scale=config.get('guidance_scale', 1.0),
)

print("Starting guided molecule generation...")
print(f"Generating {config['num_molecules_to_generate']} molecules...")

# Generate molecules with bond guidance.
generated_mols, generated_smiles, generation_stats = generate_molecules_with_bond_guidance(
    model=model,
    bond_predictor=bond_predictor,
    p_sampler=p_sampler,
    num_molecules=config['num_molecules_to_generate'],
    max_atoms=config['max_atoms'],
    atom_dim=info['atom_feature_dim'],
    bond_dim=info['bond_feature_dim'],
    pos_dim=config['pos_dim'],
    device=device,
    guidance_steps=config['guidance_steps'],
    temperature=config['temperature'],
)

print("Generation completed!")

In [None]:
print("="*50)
print("STEP 8: RESULTS ANALYSIS")
print("="*50)

# Headline counters reported by the generator, each with its format spec.
print("GENERATION RESULTS:")
for label, stat_key, spec in (
    ("Success Rate", 'success_rate', '.2%'),
    ("Valid Molecules", 'successful_generations', ''),
    ("Failed Generations", 'failed_generations', ''),
    ("Invalid Chemistry", 'invalid_chemistry', ''),
    ("Guidance Applications", 'guidance_applications', ''),
):
    print(f"  {label}: {generation_stats[stat_key]:{spec}}")

if len(generated_smiles) > 0:
    print("\nGENERATED MOLECULES:")
    for rank, smi in enumerate(generated_smiles[:10], start=1):  # first 10 only
        print(f"  {rank:2d}: {smi}")
    remaining = len(generated_smiles) - 10
    if remaining > 0:
        print(f"  ... and {remaining} more")

    # Diversity: fraction of SMILES strings that are distinct.
    unique_smiles = set(generated_smiles)
    diversity = len(unique_smiles) / len(generated_smiles)
    print("\nDIVERSITY ANALYSIS:")
    print(f"  Total molecules: {len(generated_smiles)}")
    print(f"  Unique molecules: {len(unique_smiles)}")
    print(f"  Uniqueness rate: {diversity:.2%}")

    # Failed generations are falsy entries in generated_mols; skip them.
    valid_mols = [mol for mol in generated_mols if mol]

    mol_sizes = [mol.GetNumAtoms() for mol in valid_mols]
    if mol_sizes:
        print("\nSIZE ANALYSIS:")
        print(f"  Average atoms: {np.mean(mol_sizes):.1f}")
        print(f"  Size range: {min(mol_sizes)} - {max(mol_sizes)} atoms")
        print(f"  Standard deviation: {np.std(mol_sizes):.1f}")

    print("\nCHEMICAL PROPERTIES:")
    if valid_mols:
        bond_counts = [mol.GetNumBonds() for mol in valid_mols]
        ring_counts = [mol.GetRingInfo().NumRings() for mol in valid_mols]
        print(f"  Average bonds: {np.mean(bond_counts):.1f}")
        print(f"  Average rings: {np.mean(ring_counts):.1f}")

        from collections import Counter
        atom_counts = Counter(
            atom.GetSymbol() for mol in valid_mols for atom in mol.GetAtoms()
        )
        print(f"  Atom distribution: {dict(atom_counts.most_common(5))}")

else:
    print("\nWARNING: No valid molecules were generated!")
    print("Consider:")
    for hint in (
        "Training for more epochs",
        "Adjusting loss weights (increase bond_loss_weight)",
        "Reducing guidance scale",
        "Using a larger dataset (increase max_samples)",
        "Checking model architecture",
    ):
        print(f"  - {hint}")

In [None]:
print("="*50)
print("STEP 9: SAVING RESULTS")
print("="*50)

import json

# Save generated molecules as an SDF file. Only truthy (successfully built)
# molecules are written, and the reported count reflects exactly those.
if generated_mols:
    sdf_path = results_dir / "generated_molecules.sdf"
    writer = Chem.SDWriter(str(sdf_path))
    num_written = 0
    try:
        for mol in generated_mols:
            if mol:
                writer.write(mol)
                num_written += 1
    finally:
        # Always flush/close the file, even if a write fails midway.
        writer.close()
    print(f"Saved {num_written} molecules to: {sdf_path}")

# Save SMILES to a text file, one per line.
if generated_smiles:
    smiles_path = results_dir / "generated_smiles.txt"
    with open(smiles_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{smiles}\n" for smiles in generated_smiles)
    print(f"Saved {len(generated_smiles)} SMILES to: {smiles_path}")


def _json_default(obj):
    """Serialize tensor / NumPy values that ``json`` cannot handle natively."""
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        return obj.tolist()
    if isinstance(obj, np.generic):  # NumPy scalars such as np.int64 / np.bool_
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save generation statistics. Serialize fully before touching the file so a
# non-serializable value can't leave a truncated JSON file behind.
stats_path = results_dir / "generation_stats.json"
stats_text = json.dumps(generation_stats, indent=2, default=_json_default)
stats_path.write_text(stats_text, encoding='utf-8')
print(f"Saved generation statistics to: {stats_path}")

# Save training curves data.
curves_path = results_dir / "training_curves.npz"
np.savez(curves_path,
         total_losses=np.array(losses),
         atom_losses=np.array(detailed_losses['atom']),
         pos_losses=np.array(detailed_losses['pos']),
         bond_losses=np.array(detailed_losses['bond']),
         guidance_losses=np.array(detailed_losses['guidance'])
)
print(f"Saved training curves to: {curves_path}")

print(f"\nAll results saved to: {results_dir}")

In [None]:
print("="*50)
print("STEP 10: VISUAL COMPARISON")
print("="*50)

from rdkit.Chem import Draw, rdDepictor

# Draw up to four successfully built molecules. Failed generations are falsy
# entries in generated_mols, so filter them out first. Titles are derived from
# each molecule itself rather than indexing generated_smiles in parallel,
# which may not line up one-to-one with generated_mols.
drawable_mols = [mol for mol in generated_mols if mol][:4]

if drawable_mols:
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    axes = axes.flatten()

    for idx, (ax, mol) in enumerate(zip(axes, drawable_mols), start=1):
        # Compute 2D depiction coordinates on a copy so the generated 3D
        # conformer is left untouched.
        depiction = Chem.Mol(mol)
        rdDepictor.Compute2DCoords(depiction)
        img_array = np.array(Draw.MolToImage(depiction, size=(300, 300)))

        smiles = Chem.MolToSmiles(mol)
        label = smiles if len(smiles) <= 30 else f"{smiles[:30]}..."

        ax.imshow(img_array)
        ax.set_title(f"Generated Molecule {idx}\n{label}")
        ax.axis('off')

    # Hide panels left over when fewer than four molecules could be drawn.
    for ax in axes[len(drawable_mols):]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(results_dir / 'generated_molecules_visual.png', dpi=300, bbox_inches='tight')
    plt.show()

print("Tutorial completed successfully!")
print("\nNext steps:")
print("1. Experiment with different hyperparameters")
print("2. Try larger datasets for better results")
print("3. Implement additional guidance mechanisms")
print("4. Add property-based generation objectives")