# Generate molecules with a GNN diffusion model

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from diffusion_gnn.models import create_diffusion_model, create_noise_scheduler
from diffusion_gnn.data.deepchem import DeepChemMolecularDataset
from diffusion_gnn.utils.generation import sample_from_model, features_to_atom_types

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.3.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/jantinebroek/miniconda3/envs/diff_gnn/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/jantinebroek/miniconda3/envs/diff_gnn/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/jantinebroek/miniconda3/envs/diff_gnn/lib/python3.11/site-packages/ipykernel/kernelapp.py

In [2]:
# Load trained model

# Recreate the exact setup from your training
dataset = DeepChemMolecularDataset('tox21', max_atoms=30)
atom_dim = dataset._get_atom_feature_dim()
bond_dim = dataset._get_bond_feature_dim()

print(f"Dataset loaded: {len(dataset.dataset)} molecules")
print(f"Atom dim: {atom_dim}, Bond dim: {bond_dim}")



INFO:diffusion_gnn.data.deepchem:Loading DeepChem dataset: tox21
INFO:deepchem.data.datasets:Loading dataset from disk.
INFO:deepchem.data.datasets:Loading dataset from disk.
INFO:deepchem.data.datasets:Loading dataset from disk.
INFO:diffusion_gnn.data.deepchem:Loaded tox21 dataset with 6258 molecules


Dataset loaded: 6258 molecules
Atom dim: 37, Bond dim: 10


In [3]:
# Train model (or Load if saved)

# Since you don't have saved model, let's do quick training
from diffusion_gnn.utils.mol_diff_gnn import plot_training_metrics

# Create model and scheduler (same as your training)
model = create_diffusion_model(
    atom_dim=atom_dim,
    bond_dim=bond_dim,
    hidden_dim=128,
    num_layers=3,
    gnn_type='gat'
).to(device)

scheduler = create_noise_scheduler(num_timesteps=200).to(device)

# Quick training (reduced epochs for demo)
dataloader = dataset.create_dataloader(batch_size=8, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

print("Quick training for generation demo...")
model.train()
for epoch in range(5):  # Just 5 epochs for demo
    epoch_losses = []
    for batch in dataloader:
        batch = batch.to(device)
        batch_size = batch.batch.max().item() + 1

        optimizer.zero_grad()
        t = torch.randint(0, scheduler.num_timesteps, (batch_size,), device=device)
        noise = torch.randn_like(batch.x)
        x_noisy = scheduler.add_noise(batch.x, t[batch.batch], noise)
        noise_pred = model(x_noisy, batch.edge_index, batch.edge_attr, batch.batch, t)
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f"Epoch {epoch+1}: Loss = {np.mean(epoch_losses):.4f}")

print("Training complete!")

INFO:diffusion_gnn.data.deepchem:Converting 6258 molecules to graphs...
INFO:diffusion_gnn.data.deepchem:Processed 0/6258 molecules...
INFO:diffusion_gnn.data.deepchem:Processed 1000/6258 molecules...
INFO:diffusion_gnn.data.deepchem:Processed 2000/6258 molecules...
INFO:diffusion_gnn.data.deepchem:Processed 3000/6258 molecules...
INFO:diffusion_gnn.data.deepchem:Processed 4000/6258 molecules...
INFO:diffusion_gnn.data.deepchem:Processed 5000/6258 molecules...
INFO:diffusion_gnn.data.deepchem:Processed 6000/6258 molecules...
INFO:diffusion_gnn.data.deepchem:Successfully converted 5856/6258 molecules (93.6%)
INFO:diffusion_gnn.data.deepchem:Failed conversions: 402


Quick training for generation demo...
Epoch 1: Loss = 0.4555
Epoch 2: Loss = 0.2561
Epoch 3: Loss = 0.2262
Epoch 4: Loss = 0.2067
Epoch 5: Loss = 0.1948
Training complete!


In [4]:
# Generate molecular features
print("Generating molecules...")
generated_features = sample_from_model(
    model, scheduler,
    num_molecules=5,
    max_atoms=15,
    atom_dim=atom_dim,
    bond_dim=bond_dim,
    device=device
)

print(f"Generated features shape: {generated_features.shape}")

# Analyze generated atomic elements
for mol_idx in range(generated_features.shape[0]):
    mol_features = generated_features[mol_idx]
    atom_types = features_to_atom_types(mol_features)
    print(f"Molecule {mol_idx+1}: {atom_types}")

Generating molecules...
Generated features shape: torch.Size([5, 15, 37])


RuntimeError: Numpy is not available

In [5]:
# Visualize feature distributions
def plot_generation_analysis(features, dataset):
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Atomic element distribution
    all_features = features.reshape(-1, features.shape[-1])
    atomic_probs = torch.softmax(all_features[:, :11], dim=1)
    predicted_elements = torch.argmax(atomic_probs, dim=1)

    elements = ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'Other']
    unique_elements, counts = torch.unique(predicted_elements, return_counts=True)

    element_names = [elements[i] for i in unique_elements.cpu()]
    axes[0,0].bar(element_names, counts.cpu().numpy())
    axes[0,0].set_title('Generated Atomic Elements')
    axes[0,0].tick_params(axis='x', rotation=45)

    # Feature value distributions
    axes[0,1].hist(all_features[:, :11].cpu().numpy().flatten(), bins=50, alpha=0.7)
    axes[0,1].set_title('Atomic Number Feature Distribution')
    axes[0,1].set_xlabel('Feature Value')

    # Compare with real data sample
    real_batch = next(iter(dataset.create_dataloader(batch_size=16)))
    real_elements = torch.argmax(real_batch.x[:, :11], dim=1)
    real_unique, real_counts = torch.unique(real_elements, return_counts=True)

    axes[1,0].bar([elements[i] for i in real_unique.cpu()], real_counts.cpu().numpy(), alpha=0.7, label='Real')
    axes[1,0].bar(element_names, counts.cpu().numpy(), alpha=0.7, label='Generated')
    axes[1,0].set_title('Real vs Generated Elements')
    axes[1,0].legend()
    axes[1,0].tick_params(axis='x', rotation=45)

    # Feature magnitude comparison
    axes[1,1].hist(real_batch.x.cpu().numpy().flatten(), bins=50, alpha=0.7, label='Real', density=True)
    axes[1,1].hist(all_features.cpu().numpy().flatten(), bins=50, alpha=0.7, label='Generated', density=True)
    axes[1,1].set_title('Feature Magnitude Distribution')
    axes[1,1].legend()

    plt.tight_layout()
    plt.show()

plot_generation_analysis(generated_features, dataset)

RuntimeError: Numpy is not available

<Figure size 1000x600 with 0 Axes>