In [1]:
from trainer import *
import jax
jax.devices()

lovely_jax enabled for enhanced array visualization


[CudaDevice(id=0)]

In [2]:
class NotebookArgs:
    """
    Helper class to mimic argparse.Namespace for Jupyter notebooks.
    Example:
        args = NotebookArgs(
            train_efd="../preclassified_data/energies_forces_dipoles_train.npz",
            train_grid="../preclassified_data/grids_esp_train.npz",
            valid_efd="../preclassified_data/energies_forces_dipoles_valid.npz",
            valid_grid="../preclassified_data/grids_esp_valid.npz",
            batch_size=16,
            epochs=500,
            n_dcm=3,
            verbose=True,
        )
    """

    def __init__(self, **kwargs):
        # Default values — should match argparse defaults
        defaults = dict(
            train_efd=None,
            train_grid=None,
            valid_efd=None,
            valid_grid=None,
            features=32,
            max_degree=2,
            num_iterations=2,
            num_basis_functions=32,
            cutoff=10.0,
            n_dcm=3,
            include_pseudotensors=False,
            batch_size=32,
            epochs=100,
            learning_rate=0.001,
            esp_weight=10000.0,
            seed=42,
            restart=None,
            name='co2_dcmnet',
            output_dir='./checkpoints',
            print_freq=10,
            save_freq=5,
            verbose=True,
        )

        # Update defaults with user-specified values
        defaults.update(kwargs)

        # Assign all attributes
        for key, val in defaults.items():
            setattr(self, key, val)

    def as_dict(self):
        """Return arguments as a plain dict (useful for logging)."""
        return vars(self)

In [3]:
from trainer import main, load_co2_data


args = NotebookArgs(
    train_efd=Path("../preclassified_data/energies_forces_dipoles_train.npz"),
    train_grid=Path("../preclassified_data/grids_esp_train.npz"),
    valid_efd=Path("../preclassified_data/energies_forces_dipoles_valid.npz"),
    valid_grid=Path("../preclassified_data/grids_esp_valid.npz"),
    output_dir=Path("./output"),
    batch_size=1000,
    epochs=5,
    learning_rate=5e-4,
)



In [4]:
# Now you can call your functions directly
train_data = load_co2_data(args.train_efd, args.train_grid)
valid_data = load_co2_data(args.valid_efd, args.valid_grid)

# Or if your `main()` function expects args like from argparse:
# 

len(valid_data["R"])

1000

In [5]:
?prepare_datasets
?train_model

[31mSignature:[39m
train_model(
    key,
    model,
    train_data,
    valid_data,
    num_epochs,
    learning_rate,
    batch_size,
    writer,
    ndcm,
    esp_w=[32m1.0[39m,
    chg_w=[32m0.01[39m,
    restart_params=[38;5;28;01mNone[39;00m,
    ema_decay=[32m0.999[39m,
    num_atoms=[32m60[39m,
    use_grad_clip=[38;5;28;01mFalse[39;00m,
    grad_clip_norm=[32m2.0[39m,
    mono_imputation_fn=[38;5;28;01mNone[39;00m,
)
[31mDocstring:[39m
Train DCMNet model with ESP and monopole losses.

Performs full training loop with validation, logging, and checkpointing.
Uses exponential moving average (EMA) for parameter smoothing and saves
best parameters based on validation loss.

Parameters
----------
key : jax.random.PRNGKey
    Random key for training
model : MessagePassingModel
    DCMNet model instance
train_data : dict
    Training dataset dictionary
valid_data : dict
    Validation dataset dictionary
num_epochs : int
    Number of training epochs
learning_rate : 

In [6]:
from train_charge_predictor import load_charge_data

# # Use only HF level
# R, Z, mono = load_charge_data(csv_file, scheme='Hirshfeld', level='hf')

# # Use only MP2 level
# R, Z, mono = load_charge_data(csv_file, scheme='Hirshfeld', level='mp2')

# # Use all levels (default)
# R, Z, mono = load_charge_data(csv_file, scheme='Hirshfeld', level=None)

In [7]:
#!/usr/bin/env python3
"""
Quick example: Train charge predictor on CO2 data

This script demonstrates how to train the gradient boosting charge predictor
using the CO2 charge data.
"""

from pathlib import Path
from train_charge_predictor import load_charge_data, train_charge_predictor

# Path to your data
data_file = Path("../detailed_charges/df_charges_long.csv")

print("="*70)
print("CO2 Charge Predictor Training Example")
print("="*70)

# Load data - you can choose different schemes: Hirshfeld, VDD, Becke, etc.
# and levels: hf, mp2
print("\nLoading data...")
R, Z, mono = load_charge_data(data_file, scheme='Hirshfeld', level='hf')

# Train models
print("\nTraining models...")
results = train_charge_predictor(
    R=R,
    Z=Z,
    mono=mono,
    test_size=0.2,
    random_state=42,
    n_estimators=100,
    learning_rate=0.1,
    max_depth=5,
    save_path="charge_predictor_hirshfeld.pkl"
)

print("\n" + "="*70)
print("Training Complete!")
print("="*70)
print(f"\nModel saved to: charge_predictor_hirshfeld.pkl")
print("\nTo use with DCMNet training:")
print("  from train_charge_predictor_usage import create_mono_imputation_fn_from_gb")
print("  mono_imputation_fn = create_mono_imputation_fn_from_gb('charge_predictor_hirshfeld.pkl')")
print("  train_model(..., mono_imputation_fn=mono_imputation_fn)")



CO2 Charge Predictor Training Example

Loading data...
Loaded 27540 rows from ../detailed_charges/df_charges_long.csv
Available schemes: ['Hirshfeld' 'VDD' 'Becke' 'ADCH' 'CHELPG' 'MK' 'CM5' 'MBIS' 'MBIS_raw']
Available levels: ['hf' 'mp2']
Using level: hf, 13770 rows
Using scheme: Hirshfeld, 1530 rows
Found 510 unique geometry+level combinations

Prepared data:
  R shape: (510, 3, 3)
  Z shape: (510, 3)
  mono shape: (510, 3)

Training models...


NameError: name '__file__' is not defined

In [None]:
from train_charge_predictor_usage import create_mono_imputation_fn_from_gb
mono_imputation_fn = create_mono_imputation_fn_from_gb('MBIS_raw')
mono_imputation_fn?

In [None]:
train_data["R"][0].shape, train_data["R"][0].shape

In [None]:
train_data.keys()

In [13]:

# Start training
print(f"\n{'='*70}")
print("STARTING TRAINING")
print(f"{'='*70}\n")

try:
    final_params = train_model(
        key=key,
        model=model,
        train_data=train_data,
        valid_data=valid_data,
        num_epochs=args.epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        esp_w=args.esp_weight,
        restart_params=restart_params,
        writer=None,
        ndcm=args.n_dcm,
        # tag=args.name,
        # output_dir=args.output_dir,
        # print_freq=args.print_freq,
        # save_freq=args.save_freq,
    )
    
    # Save final model
    final_path = args.output_dir / f"{args.name}_final.pkl"
    with open(final_path, 'wb') as f:
        pickle.dump(final_params, f)
    
    print(f"\n{'='*70}")
    print("✅ TRAINING COMPLETE!")
    print(f"{'='*70}")
    print(f"\nFinal parameters saved to: {final_path}")
    print(f"\nTo use the trained model:")
    print(f"  from mmml.dcmnet.dcmnet.modules import MessagePassingModel")
    print(f"  import pickle")
    print(f"  ")
    print(f"  # Load parameters")
    print(f"  with open('{final_path}', 'rb') as f:")
    print(f"      params = pickle.load(f)")
    print(f"  ")
    print(f"  # Create model and predict")
    print(f"  model = MessagePassingModel(...)")
    print(f"  mono, dipo = model.apply(params, Z, R, dst_idx, src_idx)")
    print(f"  ")
    print(f"  # Calculate ESP")
    print(f"  from mmml.dcmnet.dcmnet.electrostatics import calc_esp")
    print(f"  esp_pred = calc_esp(mono, dipo, R, vdw_surface)")
    
except KeyboardInterrupt:
    print(f"\n\n⚠️  Training interrupted by user")
    print(f"Checkpoints saved to: {args.output_dir}")
    sys.exit(0)
except Exception as e:
    print(f"\n\n❌ Training failed with error:")
    print(f"  {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)


STARTING TRAINING

Preparing batches
..................
Training
..................


❌ Training failed with error:
  'mono'


Traceback (most recent call last):
  File "/scratch/boitti0000/slurm-job.61203577/ipykernel_2188079/908527186.py", line 7, in <module>
    final_params = train_model(
                   ^^^^^^^^^^^^
  File "/scicore/home/meuwly/boitti0000/mmml/mmml/dcmnet/dcmnet/training.py", line 337, in train_model
    train_esp_targets = []
                           
  File "/scicore/home/meuwly/boitti0000/.conda/envs/mmml-gpu/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/scicore/home/meuwly/boitti0000/.conda/envs/mmml-gpu/lib/python3.12/site-packages/jax/_src/pjit.py", line 263, in cache_miss
    executable, pgle_profiler, const_args) = _python_pjit_helper(
                                             ^^^^^^^^^^^^^^^^^^^^
  File "/scicore/home/meuwly/boitti0000/.conda/envs/mmml-gpu/lib/python3.12/site-packages/jax/_src/pjit.py", line 136, in _python_pjit_helper
    p, a

SystemExit: 1