In [1]:
!export JAX_ENABLE_X64=0
!export XLA_PYTHON_CLIENT_PREALLOCATE=true
!export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95
!export XLA_FLAGS="--xla_gpu_enable_triton_gemm=true --xla_gpu_enable_triton_softmax_fusion=true"

import jax
jax.devices()

[CudaDevice(id=0)]

In [2]:
from mmml import *

In [3]:
#!/usr/bin/env python3
"""
PhysNet Training Script (Memmap + JAX-accelerated graph build)

Usage example:
  python train_physnet_memmap.py \
    --data_path openqdc_packed_memmap \
    --batch_size 256 \
    --num_epochs 10 \
    --learning_rate 1e-4 \
    --num_atoms 128 \
    --graph_mode dense   # or sparse
"""

import os
import time
from pathlib import Path
from typing import Iterator, Dict, Optional, Tuple

import numpy as np
import jax
import jax.numpy as jnp

# Your project imports (keep as-is)
import e3x
from flax.training import orbax_utils, train_state
from mmml.physnetjax.physnetjax.models.model import EF
from mmml.physnetjax.physnetjax.training.trainstep import train_step
from mmml.physnetjax.physnetjax.training.evalstep import eval_step
from mmml.physnetjax.physnetjax.training.optimizer import get_optimizer
from mmml.physnetjax.physnetjax.restart.restart import orbax_checkpointer
from mmml.physnetjax.physnetjax.directories import BASE_CKPT_DIR

# ---------------------------
# Prefetcher (threaded)
# ---------------------------
import threading, queue
class Prefetcher:
    def __init__(self, iterator, n_prefetch: int = 2):
        self._it = iter(iterator)
        self._q = queue.Queue(maxsize=n_prefetch)
        self._done = object()
        def _worker():
            try:
                for x in self._it:
                    self._q.put(x)
            finally:
                self._q.put(self._done)
        threading.Thread(target=_worker, daemon=True).start()
    def __iter__(self):
        return self
    def __next__(self):
        x = self._q.get()
        if x is self._done:
            raise StopIteration
        return x

# ===== FIX: make dense graph JIT-safe via fixed-size nonzero; ensure fixed B by drop_last =====
import jax
import jax.numpy as jnp
import numpy as np

# ---------------------------
# Packed Memmap Loader (drop_last for static B)
# ---------------------------
class PackedMemmapLoader:
    """Yields padded (B, A, ...) numpy batches; mask is from Z>0. Drops last short batch."""
    def __init__(self, path: str, batch_size: int, shuffle: bool=True, bucket_size: int=8192, seed: int=0):
        from pathlib import Path
        self.path = path
        self.batch_size = int(batch_size)
        self.shuffle = shuffle
        self.bucket_size = int(bucket_size)
        self.rng = np.random.default_rng(seed)

        # metadata
        self.offsets = np.load(os.path.join(path, "offsets.npy"))
        self.n_atoms = np.load(os.path.join(path, "n_atoms.npy"))
        self.N = int(self.n_atoms.shape[0])
        sumA = int(self.offsets[-1])

        # packed arrays
        self.Z_pack = np.memmap(os.path.join(path, "Z_pack.int32"), dtype=np.int32,  mode="r", shape=(sumA,))
        self.R_pack = np.memmap(os.path.join(path, "R_pack.f32"),  dtype=np.float32, mode="r", shape=(sumA, 3))
        self.F_pack = np.memmap(os.path.join(path, "F_pack.f32"),  dtype=np.float32, mode="r", shape=(sumA, 3))
        self.E      = np.memmap(os.path.join(path, "E.f64"),       dtype=np.float64, mode="r", shape=(self.N,))
        self.Qtot   = np.memmap(os.path.join(path, "Qtot.f64"),    dtype=np.float64, mode="r", shape=(self.N,))

        self.indices = np.arange(self.N, dtype=np.int64)

    def _yield_indices_bucketed(self):
        order = self.indices.copy()
        if self.shuffle:
            self.rng.shuffle(order)
        # ensure total count is a multiple of batch_size (drop_last)
        limit = (len(order) // self.batch_size) * self.batch_size
        order = order[:limit]
        for s in range(0, limit, self.bucket_size):
            chunk = order[s:s+self.bucket_size]
            # sort by size within bucket to minimize padding
            idx = np.argsort(self.n_atoms[chunk], kind="mergesort")
            chunk = chunk[idx]
            for b in range(0, len(chunk), self.batch_size):
                yield chunk[b:b+self.batch_size]  # always length == batch_size

    def _slice(self, k: int):
        a0, a1 = int(self.offsets[k]), int(self.offsets[k+1])
        return (
            self.Z_pack[a0:a1],
            self.R_pack[a0:a1],
            self.F_pack[a0:a1],
            self.E[k],
            self.Qtot[k],
        )

    def batches(self, num_atoms: int):
        """Yield fixed-shape (B,A,...) numpy dicts."""
        A = int(num_atoms)
        for idx in self._yield_indices_bucketed():
            B = len(idx)                      # fixed = batch_size
            Z = np.zeros((B, A), dtype=np.int32)
            R = np.zeros((B, A, 3), dtype=np.float32)
            F = np.zeros((B, A, 3), dtype=np.float32)
            E = np.zeros((B,), dtype=np.float64)
            Qtot = np.zeros((B,), dtype=np.float64)
            for j, k in enumerate(idx):
                z, r, f, e, q = self._slice(int(k))
                a = min(z.shape[0], A)
                Z[j, :a] = z[:a]
                R[j, :a] = r[:a]
                F[j, :a] = f[:a]
                E[j] = e
                Qtot[j] = q
            yield {"Z": Z, "R": R, "F": F, "E": E, "Qtot": Qtot}

# ------------------ Dense graph (JIT) with fixed-size nonzero ------------------
@jax.jit
def _build_graph_dense_fixed(Z: jnp.ndarray, R: jnp.ndarray, cutoff: float):
    """
    Returns padded edges with fixed size Emax = B*A*(A-1).
    Also returns edge_mask (Emax,) marking real edges.
    """
    B, A = Z.shape
    valid = (Z > 0)                                  # (B,A)
    dR = R[:, :, None, :] - R[:, None, :, :]         # (B,A,A,3)
    D  = jnp.linalg.norm(dR, axis=-1)                # (B,A,A)
    adj = (D < cutoff) & (valid[:, :, None] & valid[:, None, :]) & (~jnp.eye(A, dtype=bool)[None])

    # Fixed-size nonzero with padding
    Emax = B * A * (A - 1)                           # directed edges, no self
    b, i, j = jnp.nonzero(adj, size=Emax, fill_value=0)  # (Emax,)
    edge_mask = adj[b, i, j]                         # (Emax,) True for real edges, False for padded

    src_idx = (b * A + i).astype(jnp.int32)          # (Emax,)
    dst_idx = (b * A + j).astype(jnp.int32)          # (Emax,)
    batch_segments = jnp.repeat(jnp.arange(B, dtype=jnp.int32), A)  # (B*A,)
    return src_idx, dst_idx, batch_segments, edge_mask

import jax
import jax.numpy as jnp
from jax import lax

import jax
import jax.numpy as jnp
from jax import lax
import functools, jax, jax.numpy as jnp
from jax import lax

@functools.partial(jax.jit, static_argnames=('tile_j',))
def _build_graph_dense_tiled_static(Z: jnp.ndarray, R: jnp.ndarray, cutoff: float, tile_j: int = 32):
    B, A = Z.shape
    A_pad = ((A + tile_j - 1) // tile_j) * tile_j
    padJ  = A_pad - A

    valid     = (Z > 0)
    R_pad     = jnp.pad(R,     ((0,0),(0,padJ),(0,0)), constant_values=0.0)
    valid_pad = jnp.pad(valid, ((0,0),(0,padJ)),       constant_values=False)

    adj0 = jnp.zeros((B, A_pad, A_pad), dtype=bool)

    def body_fun(t, adj):
        j0 = t * tile_j  # dynamic int tracer
        # slice j-block (static size = tile_j)
        Rj = lax.dynamic_slice_in_dim(R_pad,     start_index=j0, slice_size=tile_j, axis=1)  # (B,tile_j,3)
        Vj = lax.dynamic_slice_in_dim(valid_pad, start_index=j0, slice_size=tile_j, axis=1)  # (B,tile_j)

        dR_block = R_pad[:, :, None, :] - Rj[:, None, :, :]   # (B,A_pad,tile_j,3)
        D_block  = jnp.linalg.norm(dR_block, axis=-1)         # (B,A_pad,tile_j)

        Vi = valid_pad[:, :, None]                            # (B,A_pad,1)
        Vj_ = Vj[:, None, :]                                  # (B,1,tile_j)
        mask_block = (D_block < cutoff) & (Vi & Vj_)          # (B,A_pad,tile_j)

        # no self-edges on diagonal positions of this tile
        i_idx = jnp.arange(A_pad)
        j_idx = j0 + jnp.arange(tile_j)                       # static length; values traced
        diag_mask = (i_idx[:, None] != j_idx[None, :])
        mask_block = mask_block & diag_mask[None, :, :]

        # ✨ dynamic update instead of slicing with dynamic j0
        adj = lax.dynamic_update_slice(adj, mask_block, (0, 0, j0))
        return adj

    n_tiles = A_pad // tile_j
    adj = lax.fori_loop(0, n_tiles, body_fun, adj0)

    # fixed-size nonzero (padded)
    Emax = B * A_pad * (A_pad - 1)
    b, i, j   = jnp.nonzero(adj, size=Emax, fill_value=0)
    edge_mask = adj[b, i, j]

    # clamp padded rows/cols to 0 (masked anyway)
    i_safe = jnp.where((i < A) & edge_mask, i, 0)
    j_safe = jnp.where((j < A) & edge_mask, j, 0)

    src_idx = (b * A + i_safe).astype(jnp.int32)
    dst_idx = (b * A + j_safe).astype(jnp.int32)

    batch_segments = jnp.repeat(jnp.arange(B, dtype=jnp.int32), A)
    return src_idx, dst_idx, batch_segments, edge_mask

@functools.partial(jax.jit, static_argnames=('tile_j',))
def preprocess_batch_dense(Z, R, F, E, Qtot, cutoff, tile_j: int = 32):
    src_idx, dst_idx, batch_segments, edge_mask = _build_graph_dense_tiled_static(Z, R, cutoff, tile_j)
    B, A = Z.shape
    Zf = Z.reshape(B * A); Rf = R.reshape(B * A, 3); Ff = F.reshape(B * A, 3)
    atom_mask  = (Zf > 0)
    batch_mask = edge_mask  # (Emax,)
    return {
        "Z": Zf, "R": Rf, "F": Ff, "E": E, "Qtot": Qtot,
        "atom_mask": atom_mask,
        "batch_mask": batch_mask,
        "batch_segments": batch_segments,
        "src_idx": src_idx, "dst_idx": dst_idx,
    }
# ------------------ Sparse graph (eager) unchanged from prior working version ------------------
def _per_sample_sparse_eager(rs: jnp.ndarray, z: jnp.ndarray, b: int, cutoff: float):
    A = int(z.shape[0])
    valid_mask = (z > 0)
    a = int(valid_mask.sum())
    if a <= 1:
        return jnp.array([], dtype=jnp.int32), jnp.array([], dtype=jnp.int32)
    valid_idx = jnp.where(valid_mask)[0]          # (a,)
    rs_v = rs[valid_idx, :]                       # (a,3)
    diff = rs_v[:, None, :] - rs_v[None, :, :]    # (a,a,3)
    dist = jnp.linalg.norm(diff, axis=-1)         # (a,a)
    adj  = (dist < cutoff) & (~jnp.eye(a, dtype=bool))
    i_loc, j_loc = jnp.where(adj)                 # (e,)
    if i_loc.size == 0:
        return jnp.array([], dtype=jnp.int32), jnp.array([], dtype=jnp.int32)
    src = b * A + valid_idx[i_loc]
    dst = b * A + valid_idx[j_loc]
    return src.astype(jnp.int32), dst.astype(jnp.int32)

def _build_graph_sparse_eager(Z: jnp.ndarray, R: jnp.ndarray, cutoff: float):
    B, A = map(int, Z.shape)
    src_chunks, dst_chunks = [], []
    for b in range(B):
        s, d = _per_sample_sparse_eager(R[b], Z[b], b, cutoff)
        src_chunks.append(s); dst_chunks.append(d)
    if len(src_chunks) == 0 or src_chunks[0].size == 0:
        src_idx = jnp.array([], dtype=jnp.int32)
        dst_idx = jnp.array([], dtype=jnp.int32)
    else:
        src_idx = jnp.concatenate(src_chunks, axis=0)
        dst_idx = jnp.concatenate(dst_chunks, axis=0)
    batch_segments = jnp.repeat(jnp.arange(B, dtype=jnp.int32), A)
    return src_idx, dst_idx, batch_segments

@jax.jit
def _preprocess_common_masks(Z: jnp.ndarray, R: jnp.ndarray, F: jnp.ndarray, E: jnp.ndarray, Qtot: jnp.ndarray):
    # same as _preprocess_common but kept separate to avoid confusion
    B, A = Z.shape
    Zf = Z.reshape(B * A)
    Rf = R.reshape(B * A, 3)
    Ff = F.reshape(B * A, 3)
    atom_mask  = (Zf > 0)
    batch_segments = jnp.repeat(jnp.arange(B, dtype=jnp.int32), A)
    return Zf, Rf, Ff, atom_mask, batch_segments

def preprocess_batch_sparse(Z, R, F, E, Qtot, cutoff):
    src_idx, dst_idx, batch_segments = _build_graph_sparse_eager(Z, R, cutoff)
    Zf, Rf, Ff, atom_mask, _ = _preprocess_common_masks(Z, R, F, E, Qtot)
    # edge-wise mask of ones for the edges we actually have
    batch_mask = jnp.ones((src_idx.shape[0],), dtype=bool)
    return {
        "Z": Zf, "R": Rf, "F": Ff, "E": E, "Qtot": Qtot,
        "atom_mask": atom_mask,
        "batch_mask": batch_mask,            # (E,)
        "batch_segments": batch_segments,    # (B*A,)
        "src_idx": src_idx, "dst_idx": dst_idx,
    }

# ------------------ Public dispatcher ------------------
def preprocess_batch(Z, R, F, E, Qtot, cutoff, graph_mode: str):
    if graph_mode == "dense":
        # return preprocess_batch_dense(Z, R, F, E, Qtot, cutoff)
        pre = preprocess_batch_dense(Z, R, F, E, Qtot, cutoff, tile_j=16)   # 8..32; 16 is a safe start
        return pre
    else:
        return preprocess_batch_sparse(Z, R, F, E, Qtot, cutoff)

def preprocess_batch(Z, R, F, E, Qtot, cutoff, graph_mode: str):
    if graph_mode == "dense":
        return preprocess_batch_dense(Z, R, F, E, Qtot, cutoff, tile_j=16)  # try 8–32
    else:
        return preprocess_batch_sparse(Z, R, F, E, Qtot, cutoff)

# ------------------ Iteration helper (prefetch) ------------------
def iter_loader(loader: PackedMemmapLoader, num_atoms: int, prefetch: int = 2):
    import queue, threading
    def _gen():
        for b in loader.batches(num_atoms=num_atoms):
            yield b
    class _Pref:
        def __init__(self, it, n=2):
            self._it = iter(it); self._q = queue.Queue(maxsize=n); self._done = object()
            def _w():
                try:
                    for x in self._it: self._q.put(x)
                finally: self._q.put(self._done)
            threading.Thread(target=_w, daemon=True).start()
        def __iter__(self): return self
        def __next__(self):
            x = self._q.get()
            if x is self._done: raise StopIteration
            return x
    return _Pref(_gen(), prefetch)

# ------------------ Train / Validate (unchanged API) ------------------
def train_epoch(
    model, params, ema_params, opt_state, transform_state, optimizer,
    loader, energy_weight: float, forces_weight: float, num_atoms: int,
    cutoff: float, graph_mode: str,
):
    train_loss = 0.0; train_energy_mae = 0.0; train_forces_mae = 0.0
    for i, batch_np in enumerate(iter_loader(loader, num_atoms=num_atoms)):
        Z = jnp.asarray(batch_np["Z"]); R = jnp.asarray(batch_np["R"]); F = jnp.asarray(batch_np["F"])
        E = jnp.asarray(batch_np["E"]); Qtot = jnp.asarray(batch_np["Qtot"])
        B = int(Z.shape[0])
        pre = preprocess_batch(Z, R, F, E, Qtot, cutoff, graph_mode)
        (params, ema_params, opt_state, transform_state,
         loss, energy_mae, forces_mae, _) = train_step(
            model_apply=model.apply,
            optimizer_update=optimizer.update,
            transform_state=transform_state,
            batch=pre, batch_size=B,
            energy_weight=energy_weight, forces_weight=forces_weight,
            dipole_weight=0.0, charges_weight=0.0, opt_state=opt_state,
            doCharges=False, params=params, ema_params=ema_params, debug=False,
         )
        train_loss += (loss - train_loss) / (i + 1)
        train_energy_mae += (energy_mae - train_energy_mae) / (i + 1)
        train_forces_mae += (forces_mae - train_forces_mae) / (i + 1)
        if i % 50 == 0:
            print(f"  Batch {i}: Loss={float(loss):.6f}  E_MAE={float(energy_mae):.6f}  F_MAE={float(forces_mae):.6f}")
    return params, ema_params, opt_state, transform_state, float(train_loss), float(train_energy_mae), float(train_forces_mae)

def validate(
    model, ema_params, loader, energy_weight: float, forces_weight: float,
    num_atoms: int, cutoff: float, graph_mode: str,
):
    valid_loss = 0.0; valid_energy_mae = 0.0; valid_forces_mae = 0.0
    for i, batch_np in enumerate(iter_loader(loader, num_atoms=num_atoms)):
        Z = jnp.asarray(batch_np["Z"]); R = jnp.asarray(batch_np["R"]); F = jnp.asarray(batch_np["F"])
        E = jnp.asarray(batch_np["E"]); Qtot = jnp.asarray(batch_np["Qtot"])
        B = int(Z.shape[0])
        pre = preprocess_batch(Z, R, F, E, Qtot, cutoff, graph_mode)
        loss, energy_mae, forces_mae, _ = eval_step(
            model_apply=model.apply, batch=pre, batch_size=B,
            energy_weight=energy_weight, forces_weight=forces_weight,
            dipole_weight=0.0, charges_weight=0.0, charges=False, params=ema_params,
        )
        valid_loss += (loss - valid_loss) / (i + 1)
        valid_energy_mae += (energy_mae - valid_energy_mae) / (i + 1)
        valid_forces_mae += (forces_mae - valid_forces_mae) / (i + 1)
    return float(valid_loss), float(valid_energy_mae), float(valid_forces_mae)

# ---------------------------
def main(args):
    print("="*80)
    print("PhysNet Training (Memmap + JAX graph)")
    print("="*80)
    for k,v in sorted(vars(args).items()):
        print(f"{k}: {v}")
    print("="*80)

    key = jax.random.PRNGKey(args.seed)

    # Build loaders and split
    full_loader = PackedMemmapLoader(
        args.data_path, batch_size=args.batch_size, shuffle=True, bucket_size=args.bucket_size, seed=args.seed
    )
    n_total = full_loader.N
    n_valid = int(n_total * args.valid_split)
    n_train = n_total - n_valid
    print(f"Total molecules: {n_total} | Train: {n_train} | Valid: {n_valid}")

    # train loader
    train_loader = PackedMemmapLoader(
        args.data_path, batch_size=args.batch_size, shuffle=True, bucket_size=args.bucket_size, seed=args.seed
    )
    train_loader.indices = train_loader.indices[:n_train]
    train_loader.N = n_train

    # valid loader
    valid_loader = PackedMemmapLoader(
        args.data_path, batch_size=args.batch_size, shuffle=False, bucket_size=args.bucket_size, seed=args.seed+1
    )
    valid_loader.indices = valid_loader.indices[n_train:]
    valid_loader.N = n_valid

    # Model
    model = EF(
        features=args.features,
        max_degree=args.max_degree,
        num_iterations=args.num_iterations,
        num_basis_functions=args.num_basis_functions,
        cutoff=args.cutoff,
        max_atomic_number=118,
        charges=False,
        natoms=args.num_atoms,
        total_charge=0.0,
        n_res=args.n_res,
        zbl=False,
        debug=False,
        efa=False,
    )

    # Init params (use e3x dense pair indices for the fixed A)
    init_key = jax.random.PRNGKey(args.seed + 1)
    dst_idx_e3x, src_idx_e3x = e3x.ops.sparse_pairwise_indices(args.num_atoms)
    # Get a small sample (one batch) for shapes
    sample_np = next(train_loader.batches(num_atoms=args.num_atoms))
    params = model.init(
        init_key,
        atomic_numbers=jnp.asarray(sample_np["Z"][0]),   # (A,)
        positions=jnp.asarray(sample_np["R"][0]),        # (A,3)
        dst_idx=dst_idx_e3x,
        src_idx=src_idx_e3x,
    )
    n_params = sum(int(x.size) for x in jax.tree_util.tree_leaves(params))
    print(f"Model initialized with {n_params:,} parameters")

    # Optimizer
    optimizer, transform, schedule_fn, optimizer_kwargs = get_optimizer(
        learning_rate=args.learning_rate,
        schedule_fn="constant",
        optimizer="amsgrad",
        transform=None,
    )
    ema_params = params
    opt_state = optimizer.init(params)
    transform_state = transform.init(params)

    best_valid = float("inf")
    ckpt_dir = Path(args.ckpt_dir) if args.ckpt_dir else BASE_CKPT_DIR
    save_root = ckpt_dir / args.name
    save_root.mkdir(parents=True, exist_ok=True)

    print("\nStarting training...")
    print("="*80)
    for epoch in range(1, args.num_epochs + 1):
        t0 = time.time()
        print(f"\nEpoch {epoch}/{args.num_epochs}")
        print("-"*80)

        # Train
        params, ema_params, opt_state, transform_state, tr_loss, tr_e_mae, tr_f_mae = train_epoch(
            model, params, ema_params, opt_state, transform_state, optimizer,
            loader=train_loader,
            energy_weight=args.energy_weight,
            forces_weight=args.forces_weight,
            num_atoms=args.num_atoms,
            cutoff=args.cutoff,
            graph_mode=args.graph_mode,
        )

        # Validate
        va_loss, va_e_mae, va_f_mae = validate(
            model, ema_params, valid_loader,
            energy_weight=args.energy_weight,
            forces_weight=args.forces_weight,
            num_atoms=args.num_atoms,
            cutoff=args.cutoff,
            graph_mode=args.graph_mode,
        )

        dt = time.time() - t0
        print(f"\nEpoch {epoch} Results:")
        print(f"  Train: Loss={tr_loss:.6f}  E_MAE={tr_e_mae:.6f}  F_MAE={tr_f_mae:.6f}")
        print(f"  Valid: Loss={va_loss:.6f}  E_MAE={va_e_mae:.6f}  F_MAE={va_f_mae:.6f}")
        print(f"  Time: {dt:.2f}s")

        if va_loss < best_valid:
            best_valid = va_loss
            print("  → New best validation loss. Saving checkpoint...")
            state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
            ckpt = {
                "model": state,
                "model_attributes": model.return_attributes(),
                "transform_state": transform_state,
                "ema_params": ema_params,
                "params": params,
                "epoch": epoch,
                "opt_state": opt_state,
                "best_loss": best_valid,
                "train_loss": tr_loss,
                "valid_loss": va_loss,
            }
            save_args = orbax_utils.save_args_from_target(ckpt)
            orbax_checkpointer.save(save_root / f"epoch-{epoch}", ckpt, save_args=save_args)

    print("\nDone.")
    print(f"Best validation loss: {best_valid:.6f}")
    print(f"Checkpoints at: {save_root}")


In [4]:
# ---------------------------
# CLI
# ---------------------------
def build_args():
    import argparse
    p = argparse.ArgumentParser("PhysNet Memmap Trainer (dense/sparse graph)")
    # Data
    p.add_argument("--data_path", type=str, required=True)
    p.add_argument("--valid_split", type=float, default=0.125)
    # Model
    p.add_argument("--features", type=int, default=128)
    p.add_argument("--max_degree", type=int, default=0)
    p.add_argument("--num_iterations", type=int, default=5)
    p.add_argument("--num_basis_functions", type=int, default=128)
    p.add_argument("--cutoff", type=float, default=5.0)
    p.add_argument("--num_atoms", type=int, default=128)
    p.add_argument("--n_res", type=int, default=3)
    # Train
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_epochs", type=int, default=10)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--energy_weight", type=float, default=1000.0)
    p.add_argument("--forces_weight", type=float, default=0.1)
    p.add_argument("--bucket_size", type=int, default=8192)
    p.add_argument("--graph_mode", choices=["dense","sparse"], default="dense")
    # Checkpointing
    p.add_argument("--name", type=str, default="physnet_memmap")
    p.add_argument("--ckpt_dir", type=str, default=None)
    p.add_argument("--seed", type=int, default=0)
    return p.parse_args()

# ---------------------------
# Notebook helper
# ---------------------------
def make_dummy_args(**overrides):
    defaults = {
        "data_path": "openqdc_packed_memmap",
        "valid_split": 0.125,
        "features": 128,
        "max_degree": 0,
        "num_iterations": 5,
        "num_basis_functions": 128,
        "cutoff": 5.0,
        "num_atoms": 128,
        "n_res": 3,
        "batch_size": 64,
        "num_epochs": 10,
        "learning_rate": 1e-4,
        "energy_weight": 1000.0,
        "forces_weight": 0.1,
        "bucket_size": 8192,
        "graph_mode": "dense",   # try "sparse" too
        "name": "physnet_memmap",
        "ckpt_dir": None,
        "seed": 0,
    }
    defaults.update(overrides)
    for k,v in defaults.items():
        print(k,v)
    return type("Args", (), defaults)


In [5]:
BS = 124
MODE = "dense"
args = make_dummy_args(name=f"physnet{BS}{MODE}", batch_size=BS)

data_path openqdc_packed_memmap
valid_split 0.125
features 128
max_degree 0
num_iterations 5
num_basis_functions 128
cutoff 5.0
num_atoms 128
n_res 3
batch_size 124
num_epochs 10
learning_rate 0.0001
energy_weight 1000.0
forces_weight 0.1
bucket_size 8192
graph_mode dense
name physnet124dense
ckpt_dir None
seed 0


In [None]:
main(args)


PhysNet Training (Memmap + JAX graph)
__dict__: <attribute '__dict__' of 'Args' objects>
__doc__: None
__module__: __main__
__weakref__: <attribute '__weakref__' of 'Args' objects>
batch_size: 124
bucket_size: 8192
ckpt_dir: None
cutoff: 5.0
data_path: openqdc_packed_memmap
energy_weight: 1000.0
features: 128
forces_weight: 0.1
graph_mode: dense
learning_rate: 0.0001
max_degree: 0
n_res: 3
name: physnet124dense
num_atoms: 128
num_basis_functions: 128
num_epochs: 10
num_iterations: 5
seed: 0
valid_split: 0.125
Total molecules: 998243 | Train: 873463 | Valid: 124780
Model initialized with 428,280 parameters

Starting training...

Epoch 1/10
--------------------------------------------------------------------------------
  Batch 0: Loss=121267288.000000  E_MAE=461.572144  F_MAE=23.954142
