In [1]:
from trainer import *
import jax
jax.devices()

lovely_jax enabled for enhanced array visualization


[CudaDevice(id=0)]

In [2]:
class NotebookArgs:
    """
    Helper class to mimic argparse.Namespace for Jupyter notebooks.
    Example:
        args = NotebookArgs(
            train_efd="../preclassified_data/energies_forces_dipoles_train.npz",
            train_grid="../preclassified_data/grids_esp_train.npz",
            valid_efd="../preclassified_data/energies_forces_dipoles_valid.npz",
            valid_grid="../preclassified_data/grids_esp_valid.npz",
            batch_size=16,
            epochs=500,
            n_dcm=3,
            verbose=True,
        )
    """

    def __init__(self, **kwargs):
        # Default values ‚Äî should match argparse defaults
        defaults = dict(
            train_efd=None,
            train_grid=None,
            valid_efd=None,
            valid_grid=None,
            features=32,
            max_degree=2,
            num_iterations=2,
            num_basis_functions=32,
            cutoff=10.0,
            n_dcm=3,
            include_pseudotensors=False,
            batch_size=32,
            epochs=100,
            learning_rate=0.001,
            esp_weight=10000.0,
            seed=42,
            restart=None,
            name='co2_dcmnet',
            output_dir='./checkpoints',
            print_freq=10,
            save_freq=5,
            verbose=True,
        )

        # Update defaults with user-specified values
        defaults.update(kwargs)

        # Assign all attributes
        for key, val in defaults.items():
            setattr(self, key, val)

    def as_dict(self):
        """Return arguments as a plain dict (useful for logging)."""
        return vars(self)

In [3]:
from trainer import main, load_co2_data


args = NotebookArgs(
    train_efd=Path("../preclassified_data/energies_forces_dipoles_train.npz"),
    train_grid=Path("../preclassified_data/grids_esp_train.npz"),
    valid_efd=Path("../preclassified_data/energies_forces_dipoles_valid.npz"),
    valid_grid=Path("../preclassified_data/grids_esp_valid.npz"),
    output_dir=Path("./output"),
    batch_size=1000,
    epochs=5,
    learning_rate=5e-4,
)



In [4]:
# Now you can call your functions directly
train_data = load_co2_data(args.train_efd, args.train_grid)
valid_data = load_co2_data(args.valid_efd, args.valid_grid)

# Or if your `main()` function expects args like from argparse:
# 

len(valid_data["R"])

1000

In [5]:
?prepare_datasets
?train_model

[31mSignature:[39m
train_model(
    key,
    model,
    train_data,
    valid_data,
    num_epochs,
    learning_rate,
    batch_size,
    writer,
    ndcm,
    esp_w=[32m1.0[39m,
    chg_w=[32m0.01[39m,
    restart_params=[38;5;28;01mNone[39;00m,
    ema_decay=[32m0.999[39m,
    num_atoms=[32m60[39m,
    use_grad_clip=[38;5;28;01mFalse[39;00m,
    grad_clip_norm=[32m2.0[39m,
    mono_imputation_fn=[38;5;28;01mNone[39;00m,
)
[31mDocstring:[39m
Train DCMNet model with ESP and monopole losses.

Performs full training loop with validation, logging, and checkpointing.
Uses exponential moving average (EMA) for parameter smoothing and saves
best parameters based on validation loss.

Parameters
----------
key : jax.random.PRNGKey
    Random key for training
model : MessagePassingModel
    DCMNet model instance
train_data : dict
    Training dataset dictionary
valid_data : dict
    Validation dataset dictionary
num_epochs : int
    Number of training epochs
learning_rate : 

In [6]:
from train_charge_predictor import load_charge_data

# # Use only HF level
# R, Z, mono = load_charge_data(csv_file, scheme='Hirshfeld', level='hf')

# # Use only MP2 level
# R, Z, mono = load_charge_data(csv_file, scheme='Hirshfeld', level='mp2')

# # Use all levels (default)
# R, Z, mono = load_charge_data(csv_file, scheme='Hirshfeld', level=None)

In [7]:
#!/usr/bin/env python3
"""
Quick example: Train charge predictor on CO2 data

This script demonstrates how to train the gradient boosting charge predictor
using the CO2 charge data.
"""

from pathlib import Path
from train_charge_predictor import load_charge_data, train_charge_predictor

# Path to your data
data_file = Path("../detailed_charges/df_charges_long.csv")

print("="*70)
print("CO2 Charge Predictor Training Example")
print("="*70)

# Load data - you can choose different schemes: Hirshfeld, VDD, Becke, etc.
# and levels: hf, mp2
print("\nLoading data...")
R, Z, mono = load_charge_data(data_file, scheme='MBIS_raw', level='mp2')

# Train models
print("\nTraining models...")
results = train_charge_predictor(
    R=R,
    Z=Z,
    mono=mono,
    test_size=0.2,
    random_state=42,
    n_estimators=100,
    learning_rate=0.1,
    max_depth=5,
    save_path="charge_predictor_MBIS_raw.pkl"
)

print("\n" + "="*70)
print("Training Complete!")
print("="*70)
print(f"\nModel saved to: charge_predictor_hirshfeld.pkl")
print("\nTo use with DCMNet training:")
print("  from train_charge_predictor_usage import create_mono_imputation_fn_from_gb")
print("  mono_imputation_fn = create_mono_imputation_fn_from_gb('charge_predictor_hirshfeld.pkl')")
print("  train_model(..., mono_imputation_fn=mono_imputation_fn)")



CO2 Charge Predictor Training Example

Loading data...
Loaded 27540 rows from ../detailed_charges/df_charges_long.csv
Available schemes: ['Hirshfeld' 'VDD' 'Becke' 'ADCH' 'CHELPG' 'MK' 'CM5' 'MBIS' 'MBIS_raw']
Available levels: ['hf' 'mp2']
Using level: mp2, 13770 rows
Using scheme: MBIS_raw, 1530 rows
Found 510 unique geometry+level combinations

Prepared data:
  R shape: (510, 3, 3)
  Z shape: (510, 3)
  mono shape: (510, 3)

Training models...

Training Gradient Boosting Charge Predictors

Computing molecular features...
Feature matrix shape: (510, 12)

Training model for atom 0 (Z=6)
  Train samples: 408
  Test samples: 102
  Charge range: [0.7880, 1.4690]
  Charge mean: 1.1948, std: 0.1349

  Training...
      Iter       Train Loss      OOB Improve   Remaining Time 
         1           0.0149           0.0029            0.17s
         2           0.0114          -0.0006            0.15s
         3           0.0096           0.0039            0.15s
         4           0.0080     

In [8]:
from train_charge_predictor_usage import create_mono_imputation_fn_from_gb
mono_imputation_fn = create_mono_imputation_fn_from_gb('charge_predictor_MBIS_raw.pkl')
mono_imputation_fn?

[31mSignature:[39m mono_imputation_fn(batch: Dict) -> jax.Array
[31mDocstring:[39m
Impute monopoles for a batch.

Parameters
----------
batch : dict
    Batch dictionary containing 'Z', 'R', 'dst_idx', 'src_idx', 'batch_segments'
    
Returns
-------
jnp.ndarray
    Atomic monopoles with shape (batch_size * num_atoms,)
[31mFile:[39m      ~/mmml/examples/co2/dcmnet_train/train_charge_predictor_usage.py
[31mType:[39m      function

In [9]:
train_data["R"][0].shape, train_data["R"][0].shape

((60, 3), (60, 3))

In [10]:
train_data.keys()

dict_keys(['R', 'Z', 'N', 'esp', 'vdw_surface', 'Dxyz', 'E'])

In [11]:
# Initialize JAX random key
key = jax.random.PRNGKey(args.seed)

In [12]:
import sys
import argparse
from pathlib import Path
import numpy as np
import pickle
from typing import Dict, Tuple, Optional, Any, Mapping

# Add mmml to path
# repo_root = Path(__file__).parent / "../../.."
# sys.path.insert(0, str(repo_root.resolve()))

import jax
import jax.numpy as jnp
from mmml.dcmnet.dcmnet.modules import MessagePassingModel
from mmml.dcmnet.dcmnet.training import train_model
# from mmml.dcmnet.dcmnet.data import prepare_datasets

# Validate input files
for fname, fpath in [
    ('Train EFD', args.train_efd),
    ('Train Grid', args.train_grid),
    ('Valid EFD', args.valid_efd),
    ('Valid Grid', args.valid_grid)
]:
    if not fpath.exists():
        print(f"‚ùå Error: {fname} file not found: {fpath}")
        raise FileNotFoundError(f"{fname} file not found: {fpath}")

print(f"\nüìÅ Data Files:")
print(f"  Train EFD:  {args.train_efd}")
print(f"  Train Grid: {args.train_grid}")
print(f"  Valid EFD:  {args.valid_efd}")
print(f"  Valid Grid: {args.valid_grid}")

# Setup output directory
args.output_dir.mkdir(exist_ok=True, parents=True)
print(f"  Output: {args.output_dir / args.name}")

# Load data
print(f"\n{'#'*70}")
print("# Loading Data")
print(f"{'#'*70}")

if args.verbose:
    print(f"\nLoading training data...")
train_data_raw = load_co2_data(args.train_efd, args.train_grid)

if args.verbose:
    print(f"Loading validation data...")
valid_data_raw = load_co2_data(args.valid_efd, args.valid_grid)

print(f"\n‚úÖ Data loaded:")
print(f"  Training samples: {len(train_data_raw['R'])}")
print(f"  Validation samples: {len(valid_data_raw['R'])}")
print(f"  Data keys: {list(train_data_raw.keys())}")

# Prepare datasets (convert to DCMnet format with edge lists, etc.)
print(f"\nPreparing datasets (computing edge lists, etc.)...")
# train_data, valid_data = prepare_datasets(
#     train_data_raw,
#     valid_data_raw,
#     num_valid = 
#     # cutoff=args.cutoff,
#     # batch_size=args.batch_size,
# )
train_data = train_data_raw
valid_data = valid_data_raw
print(f"‚úÖ Datasets prepared")
print(f"  Training batches: {len(train_data)}")
print(f"  Validation batches: {len(valid_data)}")

# Build model
print(f"\n{'#'*70}")
print("# Building Model")
print(f"{'#'*70}")

print(f"\nModel hyperparameters:")
print(f"  Features: {args.features}")
print(f"  Max degree: {args.max_degree}")
print(f"  Message passing iterations: {args.num_iterations}")
print(f"  Basis functions: {args.num_basis_functions}")
print(f"  Cutoff: {args.cutoff} √Ö")
print(f"  Distributed multipoles per atom: {args.n_dcm}")
print(f"  Include pseudotensors: {args.include_pseudotensors}")

model = MessagePassingModel(
    features=args.features,
    max_degree=args.max_degree,
    num_iterations=args.num_iterations,
    num_basis_functions=args.num_basis_functions,
    cutoff=args.cutoff,
    n_dcm=args.n_dcm,
    include_pseudotensors=args.include_pseudotensors,
)

print(f"\n‚úÖ Model created: DCMNet (n_dcm={args.n_dcm})")

# Training setup
print(f"\n{'#'*70}")
print("# Training Setup")
print(f"{'#'*70}")

print(f"\nTraining hyperparameters:")
print(f"  Batch size: {args.batch_size}")
print(f"  Epochs: {args.epochs}")
print(f"  Learning rate: {args.learning_rate}")
print(f"  ESP weight: {args.esp_weight}")
print(f"  Random seed: {args.seed}")

# Load restart parameters if provided
restart_params = None
if args.restart:
    print(f"\nüìÇ Loading restart checkpoint: {args.restart}")
    with open(args.restart, 'rb') as f:
        restart_params = pickle.load(f)
    print(f"‚úÖ Checkpoint loaded")


üìÅ Data Files:
  Train EFD:  ../preclassified_data/energies_forces_dipoles_train.npz
  Train Grid: ../preclassified_data/grids_esp_train.npz
  Valid EFD:  ../preclassified_data/energies_forces_dipoles_valid.npz
  Valid Grid: ../preclassified_data/grids_esp_valid.npz
  Output: output/co2_dcmnet

######################################################################
# Loading Data
######################################################################

Loading training data...
Loading validation data...

‚úÖ Data loaded:
  Training samples: 8000
  Validation samples: 1000
  Data keys: ['R', 'Z', 'N', 'esp', 'vdw_surface', 'Dxyz', 'E']

Preparing datasets (computing edge lists, etc.)...
‚úÖ Datasets prepared
  Training batches: 7
  Validation batches: 7

######################################################################
# Building Model
######################################################################

Model hyperparameters:
  Features: 32
  Max degree: 2
  Message passing iter

In [13]:
prepare_datasets?

[31mSignature:[39m
prepare_datasets(
    key,
    num_train,
    num_valid,
    filename,
    natoms=[32m60[39m,
    clean=[38;5;28;01mFalse[39;00m,
    esp_mask=[38;5;28;01mFalse[39;00m,
    clip_esp=[38;5;28;01mFalse[39;00m,
)
[31mDocstring:[39m
Prepare datasets for training and validation.

Wrapper function that calls prepare_multiple_datasets and then
creates train/validation splits and dictionaries.

Parameters
----------
key : jax.random.PRNGKey
    Random key for dataset shuffling
num_train : int
    Number of training samples
num_valid : int
    Number of validation samples
filename : str or list
    Filename(s) to load datasets from
clean : bool, optional
    Whether to filter failed calculations, by default False
esp_mask : bool, optional
    Whether to create ESP masks, by default False
clip_esp : bool, optional
    Whether to clip ESP to first 1000 points, by default False
natoms : int, optional
    Maximum number of atoms per system, by default 60

Returns
----

In [14]:
train_data["esp"]

array([[ 0.0052505,  0.0033848,  0.037617 , ...,  0.0093954, -0.023834 ,
         0.0059184],
       [ 0.0043249,  0.0058309, -0.0074244, ..., -0.017754 ,  0.013946 ,
        -0.006544 ],
       [ 0.018399 ,  0.011036 ,  0.0095978, ..., -0.011879 , -0.010322 ,
         0.0057994],
       ...,
       [ 0.042213 ,  0.032696 , -0.011879 , ..., -0.0024751,  0.032268 ,
         0.015244 ],
       [-0.0083788,  0.0076804, -0.010421 , ..., -0.010508 ,  0.0033204,
        -0.0011256],
       [-0.031544 , -0.0094387, -0.0063055, ..., -0.014822 , -0.002533 ,
        -0.0051037]], shape=(8000, 3000), dtype=float32)

In [15]:
train_model?

[31mSignature:[39m
train_model(
    key,
    model,
    train_data,
    valid_data,
    num_epochs,
    learning_rate,
    batch_size,
    writer,
    ndcm,
    esp_w=[32m1.0[39m,
    chg_w=[32m0.01[39m,
    restart_params=[38;5;28;01mNone[39;00m,
    ema_decay=[32m0.999[39m,
    num_atoms=[32m60[39m,
    use_grad_clip=[38;5;28;01mFalse[39;00m,
    grad_clip_norm=[32m2.0[39m,
    mono_imputation_fn=[38;5;28;01mNone[39;00m,
)
[31mDocstring:[39m
Train DCMNet model with ESP and monopole losses.

Performs full training loop with validation, logging, and checkpointing.
Uses exponential moving average (EMA) for parameter smoothing and saves
best parameters based on validation loss.

Parameters
----------
key : jax.random.PRNGKey
    Random key for training
model : MessagePassingModel
    DCMNet model instance
train_data : dict
    Training dataset dictionary
valid_data : dict
    Validation dataset dictionary
num_epochs : int
    Number of training epochs
learning_rate : 

In [None]:

# Start training
print(f"\n{'='*70}")
print("STARTING TRAINING")
print(f"{'='*70}\n")

try:
    final_params = train_model(
        key=key,
        model=model,
        train_data=train_data,
        valid_data=valid_data,
        num_epochs=args.epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        esp_w=args.esp_weight,
        restart_params=restart_params,
        writer=None,
        ndcm=args.n_dcm,
        mono_imputation_fn=mono_imputation_fn,
        # num_atoms = 3,
        # tag=args.name,
        # output_dir=args.output_dir,
        # print_freq=args.print_freq,
        # save_freq=args.save_freq,
    )
    
    # Save final model
    final_path = args.output_dir / f"{args.name}_final.pkl"
    with open(final_path, 'wb') as f:
        pickle.dump(final_params, f)
    
    print(f"\n{'='*70}")
    print("‚úÖ TRAINING COMPLETE!")
    print(f"{'='*70}")
    print(f"\nFinal parameters saved to: {final_path}")
    print(f"\nTo use the trained model:")
    print(f"  from mmml.dcmnet.dcmnet.modules import MessagePassingModel")
    print(f"  import pickle")
    print(f"  ")
    print(f"  # Load parameters")
    print(f"  with open('{final_path}', 'rb') as f:")
    print(f"      params = pickle.load(f)")
    print(f"  ")
    print(f"  # Create model and predict")
    print(f"  model = MessagePassingModel(...)")
    print(f"  mono, dipo = model.apply(params, Z, R, dst_idx, src_idx)")
    print(f"  ")
    print(f"  # Calculate ESP")
    print(f"  from mmml.dcmnet.dcmnet.electrostatics import calc_esp")
    print(f"  esp_pred = calc_esp(mono, dipo, R, vdw_surface)")
    
except KeyboardInterrupt:
    print(f"\n\n‚ö†Ô∏è  Training interrupted by user")
    print(f"Checkpoints saved to: {args.output_dir}")
    sys.exit(0)
except Exception as e:
    print(f"\n\n‚ùå Training failed with error:")
    print(f"  {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)


STARTING TRAINING

