# Day 21: Neural Message Passing for Quantum Chemistry

> Gilmer, Schoenholz, Riley, Vinyals, Dahl (2017) — [arXiv:1704.01212](https://arxiv.org/abs/1704.01212)

This notebook walks through the MPNN framework from the paper:
1. Build molecular graphs from atom and bond data
2. Implement the three design choices: message function, update function, readout
3. Train an MPNN on synthetic molecular data
4. Visualize message passing and compare architecture variants
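
For reference, the three design choices in item 2 correspond to the equations of Section 2 of the paper. Here $N(v)$ is the set of neighbors of node $v$, $e_{vw}$ is the edge feature, and $T$ is the number of message passing rounds:

$$
m_v^{t+1} = \sum_{w \in N(v)} M_t(h_v^t, h_w^t, e_{vw}), \qquad
h_v^{t+1} = U_t(h_v^t, m_v^{t+1}), \qquad
\hat{y} = R(\{h_v^T \mid v \in G\})
$$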

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from implementation import (
    MPNN,
    MolecularGraph,
    BatchedGraph,
    ATOM_TYPES,
    BOND_TYPES,
    generate_dataset,
    generate_synthetic_molecule,
    batch_graphs,
    train_epoch,
    evaluate,
    summarize_model,
    SimpleMessage,
    MatrixMessage,
    EdgeNetwork,
    SumReadout,
    Set2SetReadout,
    scatter_sum,
)

torch.manual_seed(42)
np.random.seed(42)
print(f'PyTorch {torch.__version__}')
print(f'Device: {"cuda" if torch.cuda.is_available() else "cpu"}')

## 1. Molecular Graphs

A molecule is represented as a graph (Section 2):
- **Nodes** = atoms, with features (one-hot atom type: H, C, N, O, F)
- **Edges** = bonds, with features (one-hot bond type: single, double, triple, aromatic)
- **Target** = property values to predict (energy, dipole moment, etc.)
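
To make this representation concrete, the sketch below encodes formaldehyde (CH2O) by hand. The `ATOMS` and `BONDS` vocabularies and the `[2, num_edges]` layout of `edge_index` are assumptions chosen to mirror the description above; the `MolecularGraph` class in `implementation` may store things differently.

```python
import torch

# Assumed vocabularies mirroring the feature lists above (not imported from `implementation`).
ATOMS = ['H', 'C', 'N', 'O', 'F']
BONDS = ['single', 'double', 'triple', 'aromatic']

def one_hot(index, size):
    """Return a length-`size` float vector with a 1 at position `index`."""
    vec = torch.zeros(size)
    vec[index] = 1.0
    return vec

# Formaldehyde: C double-bonded to O and single-bonded to two H atoms.
atoms = ['C', 'O', 'H', 'H']
bonds = [(0, 1, 'double'), (0, 2, 'single'), (0, 3, 'single')]

node_features = torch.stack([one_hot(ATOMS.index(a), len(ATOMS)) for a in atoms])

# Each undirected bond becomes two directed edges so messages can flow both ways.
src, dst, feats = [], [], []
for i, j, kind in bonds:
    for a, b in ((i, j), (j, i)):
        src.append(a)
        dst.append(b)
        feats.append(one_hot(BONDS.index(kind), len(BONDS)))

edge_index = torch.tensor([src, dst])   # [2, num_directed_edges]
edge_features = torch.stack(feats)      # [num_directed_edges, len(BONDS)]

print(node_features.shape, edge_index.shape, edge_features.shape)
```

Three bonds give six directed edges, which is why the cell below reports bonds as "directed".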

Let's generate a synthetic molecule and inspect it.

In [None]:
mol = generate_synthetic_molecule(min_atoms=6, max_atoms=10, n_targets=3, seed=42)

print(f'Atoms: {mol.num_nodes}')
print(f'Bonds (directed): {mol.num_edges}')
print(f'Node features shape: {mol.node_features.shape}')
print(f'Edge index shape: {mol.edge_index.shape}')
print(f'Edge features shape: {mol.edge_features.shape}')
print(f'Target: {mol.target}')

# Decode atom types
atom_indices = mol.node_features.argmax(dim=1)
atom_names = [ATOM_TYPES[i] for i in atom_indices.tolist()]
print(f'\nAtom types: {atom_names}')

## 2. Message Functions (Sections 2 and 5.1)

The message function $M(h_v, h_w, e_{vw})$ determines how information flows
from neighbor $w$ to node $v$. This notebook compares three variants:

| Variant | Formula | Paper Section |
|---------|---------|---------------|
| Simple | $M = h_w$ | 2 (simplified from Duvenaud et al., whose message also concatenates $e_{vw}$) |
| Matrix | $M = A_{e_{vw}} h_w$ | 2 (Li et al., GG-NN) |
| Edge Network | $M = A(e_{vw}) h_w$ | 5.1 (this paper) |
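
The three formulas in the table can be sketched in a few lines of PyTorch. These classes (`SimpleMsg`, `MatrixMsg`, `EdgeNetMsg`) are illustrative stand-ins, not the `SimpleMessage`, `MatrixMessage`, and `EdgeNetwork` classes imported from `implementation`; the MLP width of 32 and the initialization scale are arbitrary assumptions.

```python
import torch
import torch.nn as nn

class SimpleMsg(nn.Module):
    """M = h_w: pass the neighbor state through unchanged."""
    def forward(self, h_v, h_w, e_vw):
        return h_w

class MatrixMsg(nn.Module):
    """M = A_{e_vw} h_w: one learned d x d matrix per discrete bond type."""
    def __init__(self, hidden_dim, n_bond_types):
        super().__init__()
        self.A = nn.Parameter(0.1 * torch.randn(n_bond_types, hidden_dim, hidden_dim))

    def forward(self, h_v, h_w, e_vw):
        bond_type = e_vw.argmax(dim=1)                 # one-hot -> index, shape [E]
        A = self.A[bond_type]                          # [E, d, d]
        return torch.bmm(A, h_w.unsqueeze(-1)).squeeze(-1)

class EdgeNetMsg(nn.Module):
    """M = A(e_vw) h_w: a small network maps edge features to a d x d matrix."""
    def __init__(self, hidden_dim, edge_dim, mlp_width=32):  # mlp_width is an assumption
        super().__init__()
        self.hidden_dim = hidden_dim
        self.net = nn.Sequential(
            nn.Linear(edge_dim, mlp_width), nn.ReLU(),
            nn.Linear(mlp_width, hidden_dim * hidden_dim),
        )

    def forward(self, h_v, h_w, e_vw):
        A = self.net(e_vw).view(-1, self.hidden_dim, self.hidden_dim)
        return torch.bmm(A, h_w.unsqueeze(-1)).squeeze(-1)

# Usage on random per-edge inputs (E edges, hidden size d).
d, n_bond_types, E = 8, 4, 5
h_v, h_w = torch.randn(E, d), torch.randn(E, d)
e_vw = torch.eye(n_bond_types)[torch.randint(0, n_bond_types, (E,))]  # one-hot rows
outputs = {}
for fn in (SimpleMsg(), MatrixMsg(d, n_bond_types), EdgeNetMsg(d, n_bond_types)):
    outputs[type(fn).__name__] = fn(h_v, h_w, e_vw)
    print(type(fn).__name__, tuple(outputs[type(fn).__name__].shape))
```

The matrix variant only works with discrete edge labels because it indexes a lookup table, whereas the edge network accepts any real-valued edge vector.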

In [None]:
hidden_dim = 16
edge_dim = len(BOND_TYPES)  # 4
n_edges = 10

# Create fake data
h_v = torch.randn(n_edges, hidden_dim)
h_w = torch.randn(n_edges, hidden_dim)
edge_feats = torch.zeros(n_edges, edge_dim)
edge_feats.scatter_(1, torch.randint(0, edge_dim, (n_edges, 1)), 1.0)

# Compare message functions
for name, cls in [('Simple', SimpleMessage), ('Matrix', MatrixMessage), ('EdgeNetwork', EdgeNetwork)]:
    msg_fn = cls(hidden_dim, edge_dim)
    out = msg_fn(h_v, h_w, edge_feats)
    print(f'{name:12s} output shape: {out.shape}, mean: {out.mean():.4f}, std: {out.std():.4f}')

## 3. Full MPNN Training

Now we train a complete MPNN on synthetic molecular data.
Architecture: edge network messages, GRU update, and Set2Set readout (the paper's best-performing combination, Table 2).
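
Before training the full model, it may help to see a single message passing round in isolation. The sketch below is a minimal version under stated assumptions: `aggregate` and `message_passing_step` are hypothetical helpers (not the `scatter_sum` or `MPNN` from `implementation`), the message function is the identity on $h_w$, and the update is a plain `nn.GRUCell` with the aggregated message as input and the current node state as hidden state.

```python
import torch
import torch.nn as nn

def aggregate(messages, dst, num_nodes):
    """Sum the per-edge messages into one vector per receiving node."""
    out = messages.new_zeros(num_nodes, messages.size(1))
    return out.index_add_(0, dst, messages)

def message_passing_step(h, edge_index, message_fn, gru):
    """One round: m_v = sum_w M(h_v, h_w), then h_v <- GRU(input=m_v, hidden=h_v)."""
    src, dst = edge_index                    # each edge carries a message src (w) -> dst (v)
    messages = message_fn(h[dst], h[src])    # [E, d]
    m = aggregate(messages, dst, h.size(0))  # [num_nodes, d]
    return gru(m, h)

torch.manual_seed(0)
d = 8
# Path graph 0 - 1 - 2, stored as directed edges in both directions.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
h = torch.randn(3, d)
gru = nn.GRUCell(input_size=d, hidden_size=d)

def identity_message(h_v, h_w):
    """The 'simple' message function: forward the neighbor state."""
    return h_w

h_next = message_passing_step(h, edge_index, identity_message, gru)
print(h_next.shape)
```

Repeating this step $T$ times and then applying a readout gives the full forward pass.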

In [None]:
# Generate dataset
dataset = generate_dataset(n_molecules=400, n_targets=3, seed=42)
train_data = dataset[:320]
test_data = dataset[320:]

print(f'Train: {len(train_data)}, Test: {len(test_data)}')

# Create model
model = MPNN(
    node_dim=len(ATOM_TYPES),
    edge_dim=len(BOND_TYPES),
    hidden_dim=64,
    output_dim=3,
    n_messages=3,
    message_type='edge_network',
    readout_type='set2set',
    set2set_steps=6
)

print(summarize_model(model))

In [None]:
# Train
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 30

train_losses = []
test_maes = []

for epoch in range(1, n_epochs + 1):
    loss = train_epoch(model, train_data, optimizer, batch_size=32)
    metrics = evaluate(model, test_data, batch_size=32)
    train_losses.append(loss)
    test_maes.append(metrics['mae'])
    if epoch % 5 == 0 or epoch == 1:
        print(f'Epoch {epoch:2d}: train_loss={loss:.4f}, test_mae={metrics["mae"]:.4f}')

print(f'\nFinal test MAE: {test_maes[-1]:.4f}')

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(train_losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Loss (MAE)')
axes[0].set_title('Training Loss')
axes[0].grid(True, alpha=0.3)

axes[1].plot(test_maes, color='orange')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Test MAE')
axes[1].set_title('Test MAE')
axes[1].grid(True, alpha=0.3)

fig.suptitle('MPNN Training (Gilmer et al. 2017)', fontsize=13)
plt.tight_layout()
plt.show()

## 4. Comparing Message Function Variants

Let's compare the three message function variants to see how they affect
learning. On real QM9 data, the edge network consistently outperforms
simpler variants (Table 2). On synthetic data, differences may be smaller.
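
One reason the variants can behave differently is their parameter budget. The arithmetic below is a rough sketch: it assumes the matrix variant stores one $d \times d$ matrix per bond type and that the edge network is a two-layer MLP with biases and a hypothetical width of 64. The classes in `implementation` may be sized differently, so `summarize_model` remains the authoritative count.

```python
def matrix_message_params(hidden_dim, n_bond_types):
    """One hidden_dim x hidden_dim matrix per discrete bond type."""
    return n_bond_types * hidden_dim * hidden_dim

def edge_network_params(hidden_dim, edge_dim, mlp_width):
    """Two-layer MLP (with biases) mapping edge features to a flattened d x d matrix."""
    first_layer = edge_dim * mlp_width + mlp_width
    second_layer = mlp_width * hidden_dim * hidden_dim + hidden_dim * hidden_dim
    return first_layer + second_layer

hidden_dim, edge_dim = 32, 4  # values used in the comparison cell below
print('simple      :', 0)      # M = h_w has no learned parameters
print('matrix      :', matrix_message_params(hidden_dim, edge_dim))
print('edge network:', edge_network_params(hidden_dim, edge_dim, mlp_width=64))  # width assumed
```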

In [None]:
comparison_results = {}

for msg_type in ['simple', 'matrix', 'edge_network']:
    # Reseed so every variant starts from the same RNG state
    # (NumPy too, in case batching shuffles with it)
    torch.manual_seed(42)
    np.random.seed(42)
    m = MPNN(
        node_dim=len(ATOM_TYPES), edge_dim=len(BOND_TYPES),
        hidden_dim=32, output_dim=3, n_messages=3,
        message_type=msg_type, readout_type='set2set'
    )
    opt = torch.optim.Adam(m.parameters(), lr=1e-3)
    losses = []
    for _ in range(20):
        epoch_loss = train_epoch(m, train_data, opt, batch_size=32)
        losses.append(epoch_loss)
    final_mae = evaluate(m, test_data)['mae']
    comparison_results[msg_type] = {'losses': losses, 'final_mae': final_mae}
    print(f'{msg_type:15s}: final MAE = {final_mae:.4f}')

# Plot
fig, ax = plt.subplots(figsize=(8, 4))
for msg_type, result in comparison_results.items():
    ax.plot(result['losses'], label=msg_type)
ax.set_xlabel('Epoch')
ax.set_ylabel('Training Loss')
ax.set_title('Message Function Comparison (Sections 2 and 5.1)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 5. Message Passing Receptive Field

After $T$ message passing rounds, each node has received information from
all nodes within $T$ hops. This is the "receptive field" of the node.

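For QM9 molecules (diameter typically 3-5 bonds), $T = 6$ is enough for every
atom to exchange information with every other atom (Section 2). Note that the
models trained above set `n_messages=3`, so with $T = 3$ rounds their receptive
field is 3 hops.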
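
The receptive field can be computed directly with a breadth-first search over the bond graph. The sketch below uses only the standard library; `hop_distances`, `receptive_field`, and the 6-atom chain are hypothetical illustrations, not part of `implementation` or `visualization`.

```python
from collections import deque

def hop_distances(n_nodes, bonds, source):
    """Breadth-first search: number of bonds between `source` and each reachable atom."""
    neighbors = {v: [] for v in range(n_nodes)}
    for i, j in bonds:
        neighbors[i].append(j)
        neighbors[j].append(i)
    dist = {source: 0}
    queue = deque([source])
    while queue:
        v = queue.popleft()
        for w in neighbors[v]:
            if w not in dist:
                dist[w] = dist[v] + 1
                queue.append(w)
    return dist

def receptive_field(n_nodes, bonds, source, T):
    """Atoms whose initial features can influence `source` after T rounds."""
    dist = hop_distances(n_nodes, bonds, source)
    return sorted(v for v, d in dist.items() if d <= T)

# Hypothetical 6-atom chain 0-1-2-3-4-5, whose diameter is 5 bonds.
chain = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
for T in (1, 3, 5):
    print(f'T={T}: atom 0 sees {receptive_field(6, chain, source=0, T=T)}')
```

For this chain, atom 0 only sees the whole molecule once $T$ reaches the diameter of 5.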

In [None]:
try:
    from visualization import plot_message_passing_steps, plot_molecule
    mol_vis = generate_synthetic_molecule(min_atoms=8, max_atoms=12, seed=42)
    plot_molecule(mol_vis, title=f'Synthetic Molecule ({mol_vis.num_nodes} atoms)')
    plt.show()
    plot_message_passing_steps(mol_vis, n_steps=3)
    plt.show()
except ImportError as e:
    # Either the local `visualization` module or one of its dependencies is missing
    print(f'Skipping visualization (missing dependency, e.g. networkx): {e}')

## Key Takeaways

1. **MPNN is a unifying framework**: message function $M$, update function $U$, readout $R$ — three choices that define the model (Section 2)
2. **Edge network handles continuous features**: maps edge attributes to $d \times d$ transformation matrices (Section 5.1)
3. **Set2Set beats sum pooling**: iterative attention captures the distribution of node states, not just their sum (Section 5.3)
4. **Chemical accuracy on 11/13 targets**: with 3D coordinates as input; 5/13 without coordinates (Table 2)
5. **The framework became standard**: much of the GNN literature since 2017 describes models in MPNN terms
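
Takeaway 3 can be illustrated with a single-graph sketch of the Set2Set readout. It follows the description in Vinyals et al. (2015), using dot-product attention and one `nn.LSTMCell`. `Set2SetSketch` is a hypothetical class written for this note; the `Set2SetReadout` in `implementation` is not shown in this notebook and may differ, for example in how it handles batches.

```python
import torch
import torch.nn as nn

class Set2SetSketch(nn.Module):
    """Order-invariant readout: an LSTM repeatedly attends over the node states."""
    def __init__(self, dim, steps=3):
        super().__init__()
        self.dim, self.steps = dim, steps
        self.lstm = nn.LSTMCell(input_size=2 * dim, hidden_size=dim)

    def forward(self, h):                                   # h: [num_nodes, dim]
        q_star = h.new_zeros(1, 2 * self.dim)               # [query, attention readout]
        state = (h.new_zeros(1, self.dim), h.new_zeros(1, self.dim))
        for _ in range(self.steps):
            state = self.lstm(q_star, state)
            q = state[0]                                     # query, [1, dim]
            attn = torch.softmax(h @ q.t(), dim=0)           # weights over nodes, [num_nodes, 1]
            r = (attn * h).sum(dim=0, keepdim=True)          # weighted sum, [1, dim]
            q_star = torch.cat([q, r], dim=1)
        return q_star.squeeze(0)                             # [2 * dim]

torch.manual_seed(0)
readout = Set2SetSketch(dim=8, steps=3)
h = torch.randn(5, 8)                    # five node states from a hypothetical molecule
perm = torch.randperm(5)

out = readout(h)
out_perm = readout(h[perm])              # same nodes in a different order
sum_readout = h.sum(dim=0)               # the sum-pooling baseline, for comparison

print(out.shape, sum_readout.shape)
print(torch.allclose(out, out_perm, atol=1e-5))
```

Both readouts ignore node order, but Set2Set weights nodes by learned attention over several steps, whereas the sum treats every node identically.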