In [1]:
from trainer import *
import jax
jax.devices()

lovely_jax enabled for enhanced array visualization


[CudaDevice(id=0)]

In [2]:
class NotebookArgs:
    """
    Helper class to mimic argparse.Namespace for Jupyter notebooks.
    Example:
        args = NotebookArgs(
            train_efd="../preclassified_data/energies_forces_dipoles_train.npz",
            train_grid="../preclassified_data/grids_esp_train.npz",
            valid_efd="../preclassified_data/energies_forces_dipoles_valid.npz",
            valid_grid="../preclassified_data/grids_esp_valid.npz",
            batch_size=16,
            epochs=500,
            n_dcm=3,
            verbose=True,
        )
    """

    def __init__(self, **kwargs):
        # Default values ‚Äî should match argparse defaults
        defaults = dict(
            train_efd=None,
            train_grid=None,
            valid_efd=None,
            valid_grid=None,
            features=32,
            max_degree=2,
            num_iterations=2,
            num_basis_functions=32,
            cutoff=10.0,
            n_dcm=3,
            include_pseudotensors=False,
            batch_size=32,
            epochs=100,
            learning_rate=0.001,
            esp_weight=10000.0,
            seed=42,
            restart=None,
            name='co2_dcmnet',
            output_dir='./checkpoints',
            print_freq=10,
            save_freq=5,
            verbose=True,
        )

        # Update defaults with user-specified values
        defaults.update(kwargs)

        # Assign all attributes
        for key, val in defaults.items():
            setattr(self, key, val)

    def as_dict(self):
        """Return arguments as a plain dict (useful for logging)."""
        return vars(self)

In [29]:
from trainer import main, load_co2_data


args = NotebookArgs(
    train_efd=Path("../preclassified_data/energies_forces_dipoles_train.npz"),
    train_grid=Path("../preclassified_data/grids_esp_train.npz"),
    valid_efd=Path("../preclassified_data/energies_forces_dipoles_valid.npz"),
    valid_grid=Path("../preclassified_data/grids_esp_valid.npz"),
    output_dir=Path("./output"),
    batch_size=1000,
    epochs=5,
    learning_rate=5e-4,
)



In [30]:
# Now you can call your functions directly
train_data = load_co2_data(args.train_efd, args.train_grid)
valid_data = load_co2_data(args.valid_efd, args.valid_grid)

# Or if your `main()` function expects args like from argparse:
# 

len(valid_data["R"])

1000

In [31]:
print("="*70)
print("DCMNet Training - CO2 ESP Data")
print("="*70)

# Validate input files
for fname, fpath in [
    ('Train EFD', args.train_efd),
    ('Train Grid', args.train_grid),
    ('Valid EFD', args.valid_efd),
    ('Valid Grid', args.valid_grid)
]:
    if not fpath.exists():
        print(f"‚ùå Error: {fname} file not found: {fpath}")
        sys.exit(1)

print(f"\nüìÅ Data Files:")
print(f"  Train EFD:  {args.train_efd}")
print(f"  Train Grid: {args.train_grid}")
print(f"  Valid EFD:  {args.valid_efd}")
print(f"  Valid Grid: {args.valid_grid}")

# Setup output directory
args.output_dir.mkdir(exist_ok=True, parents=True)
print(f"  Output: {args.output_dir / args.name}")

# Load data
print(f"\n{'#'*70}")
print("# Loading Data")
print(f"{'#'*70}")

if args.verbose:
    print(f"\nLoading training data...")
train_data_raw = load_co2_data(args.train_efd, args.train_grid)

if args.verbose:
    print(f"Loading validation data...")
valid_data_raw = load_co2_data(args.valid_efd, args.valid_grid)

print(f"\n‚úÖ Data loaded:")
print(f"  Training samples: {len(train_data_raw['R'])}")
print(f"  Validation samples: {len(valid_data_raw['R'])}")
print(f"  Data keys: {list(train_data_raw.keys())}")

# Prepare datasets (convert to DCMnet format with edge lists, etc.)
print(f"\nPreparing datasets (computing edge lists, etc.)...")
train_data, valid_data = prepare_datasets(
    train_data_raw,
    valid_data_raw,
    cutoff=args.cutoff,
    batch_size=args.batch_size,
)

print(f"‚úÖ Datasets prepared")
print(f"  Training batches: {len(train_data)}")
print(f"  Validation batches: {len(valid_data)}")

# Build model
print(f"\n{'#'*70}")
print("# Building Model")
print(f"{'#'*70}")

print(f"\nModel hyperparameters:")
print(f"  Features: {args.features}")
print(f"  Max degree: {args.max_degree}")
print(f"  Message passing iterations: {args.num_iterations}")
print(f"  Basis functions: {args.num_basis_functions}")
print(f"  Cutoff: {args.cutoff} √Ö")
print(f"  Distributed multipoles per atom: {args.n_dcm}")
print(f"  Include pseudotensors: {args.include_pseudotensors}")

model = MessagePassingModel(
    features=args.features,
    max_degree=args.max_degree,
    num_iterations=args.num_iterations,
    num_basis_functions=args.num_basis_functions,
    cutoff=args.cutoff,
    n_dcm=args.n_dcm,
    include_pseudotensors=args.include_pseudotensors,
)

print(f"\n‚úÖ Model created: DCMNet (n_dcm={args.n_dcm})")

# Training setup
print(f"\n{'#'*70}")
print("# Training Setup")
print(f"{'#'*70}")

print(f"\nTraining hyperparameters:")
print(f"  Batch size: {args.batch_size}")
print(f"  Epochs: {args.epochs}")
print(f"  Learning rate: {args.learning_rate}")
print(f"  ESP weight: {args.esp_weight}")
print(f"  Random seed: {args.seed}")

# Load restart parameters if provided
restart_params = None
if args.restart:
    print(f"\nüìÇ Loading restart checkpoint: {args.restart}")
    with open(args.restart, 'rb') as f:
        restart_params = pickle.load(f)
    print(f"‚úÖ Checkpoint loaded")

# Initialize JAX random key
key = jax.random.PRNGKey(args.seed)

# Start training
print(f"\n{'='*70}")
print("STARTING TRAINING")
print(f"{'='*70}\n")

try:
    final_params = train_model(
        key=key,
        model=model,
        train_data=train_data,
        valid_data=valid_data,
        num_epochs=args.epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        esp_w=args.esp_weight,
        restart_params=restart_params,
        name=args.name,
        output_dir=args.output_dir,
        print_freq=args.print_freq,
        save_freq=args.save_freq,
    )
    
    # Save final model
    final_path = args.output_dir / f"{args.name}_final.pkl"
    with open(final_path, 'wb') as f:
        pickle.dump(final_params, f)
    
    print(f"\n{'='*70}")
    print("‚úÖ TRAINING COMPLETE!")
    print(f"{'='*70}")
    print(f"\nFinal parameters saved to: {final_path}")
    print(f"\nTo use the trained model:")
    print(f"  from mmml.dcmnet.dcmnet.modules import MessagePassingModel")
    print(f"  import pickle")
    print(f"  ")
    print(f"  # Load parameters")
    print(f"  with open('{final_path}', 'rb') as f:")
    print(f"      params = pickle.load(f)")
    print(f"  ")
    print(f"  # Create model and predict")
    print(f"  model = MessagePassingModel(...)")
    print(f"  mono, dipo = model.apply(params, Z, R, dst_idx, src_idx)")
    print(f"  ")
    print(f"  # Calculate ESP")
    print(f"  from mmml.dcmnet.dcmnet.electrostatics import calc_esp")
    print(f"  esp_pred = calc_esp(mono, dipo, R, vdw_surface)")
    
except KeyboardInterrupt:
    print(f"\n\n‚ö†Ô∏è  Training interrupted by user")
    print(f"Checkpoints saved to: {args.output_dir}")
    sys.exit(0)
except Exception as e:
    print(f"\n\n‚ùå Training failed with error:")
    print(f"  {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

DCMNet Training - CO2 ESP Data

üìÅ Data Files:
  Train EFD:  ../preclassified_data/energies_forces_dipoles_train.npz
  Train Grid: ../preclassified_data/grids_esp_train.npz
  Valid EFD:  ../preclassified_data/energies_forces_dipoles_valid.npz
  Valid Grid: ../preclassified_data/grids_esp_valid.npz
  Output: output/co2_dcmnet

######################################################################
# Loading Data
######################################################################

Loading training data...
Loading validation data...

‚úÖ Data loaded:
  Training samples: 8000
  Validation samples: 1000
  Data keys: ['R', 'Z', 'N', 'esp', 'vdw_surface', 'Dxyz', 'E']

Preparing datasets (computing edge lists, etc.)...


TypeError: prepare_datasets() got an unexpected keyword argument 'cutoff'