# PhysNet Training: Energies and Forces with Memmap Data

Train PhysNet on large datasets using memory-mapped data loading.

**Hardware**: A100 GPU


In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import e3x
from mmml.data.packed_memmap_loader import PackedMemmapLoader, split_loader
from mmml.physnetjax.physnetjax.models.model import EF
from mmml.physnetjax.physnetjax.training.trainstep import train_step
from mmml.physnetjax.physnetjax.training.evalstep import eval_step
from mmml.physnetjax.physnetjax.training.optimizer import get_optimizer

# Report which devices JAX can see and which backend it selected.
devices = jax.devices()
backend = jax.default_backend()
print(f"JAX devices: {devices}")
print(f"Default backend: {backend}")


## Configuration


In [None]:
# --- Dataset location -------------------------------------------------------
DATA_PATH = "openqdc_packed_memmap"  # point this at your own data

# --- Training schedule ------------------------------------------------------
BATCH_SIZE = 64  # molecules per batch; sized with an A100 in mind
NUM_ATOMS = 60  # maximum number of atoms per molecule
NUM_EPOCHS = 10  # increase for a real training run
LEARNING_RATE = 1e-3

# --- Model hyperparameters --------------------------------------------------
FEATURES = 128
NUM_ITERATIONS = 3
CUTOFF = 5.0

# --- Loss term weights ------------------------------------------------------
ENERGY_WEIGHT = 1.0
# NOTE(review): the original comment called this a "kcal/mol conversion";
# confirm what the value 52.91 is meant to represent.
FORCES_WEIGHT = 52.91


## Load Data


In [None]:
# Build the loader over the packed memmap data at DATA_PATH.
loader = PackedMemmapLoader(
    path=DATA_PATH,
    batch_size=BATCH_SIZE,
    shuffle=True,
    bucket_size=8192,
    seed=42,
)

# Summarise the dataset: molecule count and per-molecule atom-count statistics.
atom_counts = loader.n_atoms
print(f"Total molecules: {loader.N}")
for label, value in (("Max", atom_counts.max()), ("Min", atom_counts.min())):
    print(f"{label} atoms: {value}")
print(f"Mean atoms: {atom_counts.mean():.1f}")


In [None]:
# Carve the loader into a 90% training part and a validation remainder.
train_loader, valid_loader = split_loader(loader, train_fraction=0.9, seed=42)

# Report the size of each part.
for label, subset in (("Training", train_loader), ("Validation", valid_loader)):
    print(f"{label} molecules: {subset.N}")


## Create Model


In [None]:
# Collect the hyperparameters in one mapping so they can be inspected or
# logged before the model is constructed.
model_kwargs = dict(
    features=FEATURES,
    max_degree=2,
    num_iterations=NUM_ITERATIONS,
    num_basis_functions=16,
    cutoff=CUTOFF,
    max_atomic_number=118,
    charges=False,
    natoms=NUM_ATOMS,
    n_res=3,
    zbl=True,
)
model = EF(**model_kwargs)

print("Model created")


## Initialize Parameters


In [None]:
# Seed the PRNG and reserve a sub-key for parameter initialisation.
key, init_key = jax.random.split(jax.random.PRNGKey(42))

# One training batch supplies example inputs; only its first molecule is used.
sample_batch = next(train_loader.batches(num_atoms=NUM_ATOMS))
# Pair indices for a single molecule with NUM_ATOMS atoms.
dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(NUM_ATOMS)

init_inputs = dict(
    atomic_numbers=sample_batch["Z"][0],
    positions=sample_batch["R"][0],
    dst_idx=dst_idx,
    src_idx=src_idx,
)
params = model.init(init_key, **init_inputs)

# Total number of scalars across every leaf of the parameter tree.
param_leaves = jax.tree_util.tree_leaves(params)
n_params = sum(leaf.size for leaf in param_leaves)
print(f"Model initialized: {n_params:,} parameters")


## Setup Optimizer


In [None]:
# Only the learning rate is specified; schedule, optimizer and transform are
# passed as None (presumably get_optimizer falls back to its defaults — confirm).
optimizer_parts = get_optimizer(
    learning_rate=LEARNING_RATE,
    schedule_fn=None,
    optimizer=None,
    transform=None,
)
optimizer, transform, schedule_fn, _ = optimizer_parts

# Initial optimizer/transform states; the EMA copy starts as the raw parameters.
opt_state = optimizer.init(params)
transform_state = transform.init(params)
ema_params = params

print("Optimizer ready")


## Training Loop


In [None]:
def _prepare_batch(batch):
    """Flatten a loader batch and attach the masks the step functions read.

    Args:
        batch: Dict yielded by the loader; must contain "Z", "R", "F" and
            "dst_idx". It is not modified; a shallow copy is returned.

    Returns:
        A ``(prepared_batch, batch_size)`` tuple, where ``batch_size`` is the
        leading dimension of ``batch["Z"]`` before flattening.
    """
    # The molecule count has to be read before "Z" is flattened.
    batch_size = int(batch["Z"].shape[0])

    prepared = dict(batch)
    # Flatten batch arrays (model expects flattened)
    prepared["Z"] = batch["Z"].reshape(-1)
    prepared["R"] = batch["R"].reshape(-1, 3)
    prepared["F"] = batch["F"].reshape(-1, 3)

    # Entries with Z <= 0 are masked out (presumably padding atoms — confirm).
    prepared["atom_mask"] = (prepared["Z"] > 0).astype(jnp.float32)
    # Every pair index is marked as active.
    prepared["batch_mask"] = jnp.ones_like(batch["dst_idx"], dtype=jnp.float32)
    return prepared, batch_size


# "\n" (not "\\n") so that a blank line is printed instead of a literal
# backslash-n.
print("Starting training...\n")

for epoch in range(1, NUM_EPOCHS + 1):
    # Train: metrics are kept as running means over the batches of this epoch.
    train_loss = 0.0
    train_e_mae = 0.0
    train_f_mae = 0.0

    for i, raw_batch in enumerate(train_loader.batches(num_atoms=NUM_ATOMS)):
        batch, batch_size = _prepare_batch(raw_batch)

        (
            params,
            ema_params,
            opt_state,
            transform_state,
            loss,
            energy_mae,
            forces_mae,
            _,
        ) = train_step(
            model_apply=model.apply,
            optimizer_update=optimizer.update,
            transform_state=transform_state,
            batch=batch,
            batch_size=batch_size,
            energy_weight=ENERGY_WEIGHT,
            forces_weight=FORCES_WEIGHT,
            dipole_weight=0.0,
            charges_weight=0.0,
            opt_state=opt_state,
            doCharges=False,
            params=params,
            ema_params=ema_params,
            debug=False,
        )

        # Incremental mean: new_mean = old_mean + (x - old_mean) / count.
        train_loss += (loss - train_loss) / (i + 1)
        train_e_mae += (energy_mae - train_e_mae) / (i + 1)
        train_f_mae += (forces_mae - train_f_mae) / (i + 1)

    # Validate using the EMA parameters rather than the raw ones.
    valid_loss = 0.0
    valid_e_mae = 0.0
    valid_f_mae = 0.0

    for i, raw_batch in enumerate(valid_loader.batches(num_atoms=NUM_ATOMS)):
        batch, batch_size = _prepare_batch(raw_batch)

        loss, energy_mae, forces_mae, _ = eval_step(
            model_apply=model.apply,
            batch=batch,
            batch_size=batch_size,
            energy_weight=ENERGY_WEIGHT,
            forces_weight=FORCES_WEIGHT,
            dipole_weight=0.0,
            charges_weight=0.0,
            charges=False,
            params=ema_params,
        )

        valid_loss += (loss - valid_loss) / (i + 1)
        valid_e_mae += (energy_mae - valid_e_mae) / (i + 1)
        valid_f_mae += (forces_mae - valid_f_mae) / (i + 1)

    # Print results
    print(f"Epoch {epoch}/{NUM_EPOCHS}:")
    print(f"  Train: Loss={train_loss:.6f}, E_MAE={train_e_mae:.6f}, F_MAE={train_f_mae:.6f}")
    print(f"  Valid: Loss={valid_loss:.6f}, E_MAE={valid_e_mae:.6f}, F_MAE={valid_f_mae:.6f}")

print("\nTraining complete!")


In [None]:
# Take one validation batch and compare predictions made with ema_params
# against the reference energies and forces.
test_batch = next(valid_loader.batches(num_atoms=NUM_ATOMS))
# The molecule count has to be read before "Z" is flattened below.
batch_size = int(test_batch["Z"].shape[0])

# Flatten arrays
test_batch["Z"] = test_batch["Z"].reshape(-1)
test_batch["R"] = test_batch["R"].reshape(-1, 3)
test_batch["F"] = test_batch["F"].reshape(-1, 3)

# Predict energies and forces
outputs = model.apply(
    ema_params,
    atomic_numbers=test_batch["Z"],
    positions=test_batch["R"],
    dst_idx=test_batch["dst_idx"],
    src_idx=test_batch["src_idx"],
    batch_segments=test_batch["batch_segments"],
    batch_size=batch_size,
)

print(f"Predicted energies: {outputs['energy']}")
print(f"True energies: {test_batch['E']}")
# NOTE(review): this subtraction assumes outputs['energy'] and test_batch['E']
# have matching shapes; a (B,) vs (B, 1) pair would broadcast silently — confirm.
# "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n.
print(f"\nEnergy MAE: {jnp.mean(jnp.abs(outputs['energy'] - test_batch['E'])):.6f} kcal/mol")

# Check forces
mask = (test_batch["Z"] > 0).astype(jnp.float32)
forces_pred = outputs['forces']
forces_true = test_batch["F"]
forces_diff = (forces_pred - forces_true) * mask[:, None]

# Average only over the 3 components of unmasked atoms. A plain jnp.mean would
# also count the zeroed-out masked rows in the denominator and bias the MAE low.
# The maximum() guards against division by zero when every atom is masked.
n_force_components = 3.0 * jnp.maximum(mask.sum(), 1.0)
forces_mae = jnp.sum(jnp.abs(forces_diff)) / n_force_components

print(f"Forces MAE: {forces_mae:.6f} kcal/mol/Å")
