In [1]:
import mmml
import matplotlib.pyplot as plt
import patchworklib as pw
import os
from pathlib import Path
import numpy as np
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import jax
devices = jax.local_devices()
print(devices)
print(jax.default_backend())
print(jax.devices())

# Import DCMNET components
from mmml.dcmnet.dcmnet.models import DCM1, DCM2, DCM3, DCM4, dcm1_params, dcm2_params, dcm3_params, dcm4_params

from mmml.dcmnet.dcmnet.models import models, model_params
from mmml.dcmnet.dcmnet.loss import esp_mono_loss
from mmml.dcmnet.dcmnet.electrostatics import calc_esp
from mmml.dcmnet.dcmnet.utils import apply_model

import numpy as np
import jax
import jax.numpy as jnp
import sys
import os


from mmml.dcmnet.dcmnet_mcts import optimize_dcmnet_combination, DCMNETSelectionEnv

# make bwr the default colormap
plt.set_cmap('bwr')


[CudaDevice(id=0)]
gpu
[CudaDevice(id=0)]


<Figure size 100x100 with 0 Axes>

In [2]:
def rotate(XYZ, rot):
    # rotate XYZ by rot degrees around the y-axis
    return np.dot(XYZ, np.array([[np.cos(rot), 0, np.sin(rot)],
                                 [0, 1, 0],
                                 [-np.sin(rot), 0, np.cos(rot)]]))

def  plot_esp_in_3d(output, rot=False, KEY="esp_target"):
    # 3d plot of ESP    
    ax = plt.axes(projection='3d')

    # remove far away points
    mask = np.linalg.norm(output['initial_vdw_surface'], axis=-1) < 10

    if rot and type(rot) == float:
        XYZ = rotate(output['initial_vdw_surface'], rot)
    else:
        XYZ = output['initial_vdw_surface']

    s = ax.scatter(*XYZ[mask].T, 
    c=output[KEY][mask].flatten(), s=15, vmin=-0.1, vmax=0.1,    )


    MIN_CO = np.min(XYZ[mask])
    MAX_CO = np.max(XYZ[mask])

    # even xyz ranges
    ax.set_xlim(MIN_CO, MAX_CO)
    ax.set_ylim(MIN_CO, MAX_CO)
    ax.set_zlim(MIN_CO, MAX_CO)
    plt.colorbar(s)
    return ax






In [3]:
# ax = plt.axes()
# # ax.axis('equal')
# MinV, MaxV = np.min(output['esp_target']), np.max(output['esp_target'])
# MinV, MaxV = -0.05, 0.05
# ax.scatter(output['esp_target'], 0.5 * output['esp_pred'])
# ax.set_xlim(MinV, MaxV)
# ax.set_ylim(MinV, MaxV)
# ax.plot([0, 1], [0, 1], 'k--', transform=ax.transAxes)

# ax.set_xlabel('Target ESP')
# ax.set_ylabel('Predicted ESP')
# ax.set_title('ESP Prediction vs Target')
# ax



In [4]:
import jax

from mmml.dcmnet.dcmnet.data import prepare_datasets
from mmml.dcmnet.dcmnet.modules import MessagePassingModel
from mmml.dcmnet.dcmnet.training import train_model, train_model_dipo

key = jax.random.PRNGKey(0)


lovely_jax enabled for enhanced array visualization


In [5]:
NDCM = 7
model = MessagePassingModel(
    features=32, max_degree=2, num_iterations=2,
    num_basis_functions=32, cutoff=10.0, n_dcm=NDCM,
    include_pseudotensors=False,
)

## Data

In [6]:
current_path = Path.cwd()
print(current_path)
data_path = Path("/home/ericb")
data_files = list(data_path.glob("*npz"))
for i in range(len(data_files)):
    print(i, data_files[i])

/home/ericb/mmml/notebooks/dcmnet
0 /home/ericb/esp2000.npz
1 /home/ericb/RZ.npz
2 /home/ericb/beta-diketones_71208.npz


In [7]:
# for k in data_loaded.keys():
#     print(k)
#     shape = data_loaded[k].shape
#     print(shape
#     )
#     if len(shape) < 3:
#         d = data_loaded[k]
#         d = d.flatten()
#         plt.hist(d)
#         title = f"{k}: {d.min()} - {d.max()}"
#         plt.title(title)
#         plt.show()

In [8]:
train_data, valid_data = prepare_datasets(
    key, 1800, 200, ["/home/ericb/esp2000.npz"],
    esp_mask=True,
)

shape (2000, 60, 3)
R (2000, 60, 3)
(2000, 60, 3)
['R', 'Z', 'N', 'mono', 'esp', 'vdw_surface', 'n_grid', 'espMask']
2000
0 R 2000 (2000, 60, 3)
1 Z 2000 (2000, 60)
2 N 2000 (2000, 1)
3 mono 2000 (2000, 60)
4 esp 2000 (2000, 4953)
5 vdw_surface 2000 (2000, 4953, 3)
6 n_grid 2000 (2000,)
7 espMask 2000 (2000, 4953)


## Training (1)

In [9]:
train_model?

[31mSignature:[39m
train_model(
    key,
    model,
    train_data,
    valid_data,
    num_epochs,
    learning_rate,
    batch_size,
    writer,
    ndcm,
    esp_w=[32m1.0[39m,
    chg_w=[32m0.01[39m,
    restart_params=[38;5;28;01mNone[39;00m,
    ema_decay=[32m0.999[39m,
    num_atoms=[32m60[39m,
    use_grad_clip=[38;5;28;01mFalse[39;00m,
    grad_clip_norm=[32m2.0[39m,
)
[31mDocstring:[39m
Train DCMNet model with ESP and monopole losses.

Performs full training loop with validation, logging, and checkpointing.
Uses exponential moving average (EMA) for parameter smoothing and saves
best parameters based on validation loss.

Parameters
----------
key : jax.random.PRNGKey
    Random key for training
model : MessagePassingModel
    DCMNet model instance
train_data : dict
    Training dataset dictionary
valid_data : dict
    Validation dataset dictionary
num_epochs : int
    Number of training epochs
learning_rate : float
    Learning rate for optimization
batch_siz

In [10]:
# NDCM = 7
models = []
paramsco = []

for NDCM in range(1,8):

    model = MessagePassingModel(
        features=64, max_degree=2, num_iterations=2,
        num_basis_functions=32, cutoff=10.0, n_dcm=NDCM,
        include_pseudotensors=False,
    )
    new_params = None
    
    params, valid_loss = train_model(
        key=key, model=model,
        writer=None,
        train_data=train_data, valid_data=valid_data,
        num_epochs=100, learning_rate=1e-3, batch_size=1,
        ndcm=model.n_dcm, esp_w=10000.0,
    )
    new_params, valid_loss = train_model(
        key=key, model=model,
        writer=None,
        train_data=train_data, valid_data=valid_data,
        num_epochs=100, learning_rate=5e-4, batch_size=1,
        ndcm=model.n_dcm, esp_w=10000.0,
        restart_params=params if new_params is None else new_params,
    )
    new_params, valid_loss = train_model(
        key=key, model=model,
        writer=None,
        train_data=train_data, valid_data=valid_data,
        num_epochs=200, learning_rate=1e-4, batch_size=1,
        ndcm=model.n_dcm, esp_w=10000.0,
        restart_params=params if new_params is None else new_params,
    )
    np.save(f"modelB{NDCM}", new_params)
    models.append(model)
    paramsco.append(new_params)


Preparing batches
..................
Training
..................


TypeError: cannot reshape array of shape (60, 3, 1) (size 180) into shape (1, 18, 3) (size 54)

In [None]:
def fig(output, batch):
    """ plot results of dcmnet_analysis """
    import patchworklib as pw
    VMAX = 0.01
    xy_ax = pw.Brick()
    xy_ax.scatter(batch["esp"], output['esp_pred'], s=1)
    max_val = np.sqrt(max(np.max(batch["esp"]**2), np.max(output['esp_pred']**2)))
    xy_ax.plot(np.linspace(-max_val, max_val, 100), np.linspace(-max_val, max_val, 100))
    xy_ax.set_aspect('equal')

    ax_true = pw.Brick()
    Npoints = 4150
    vdw_surface_min = np.min(batch["vdw_surface"][0], axis=0)
    vdw_surface_max = np.max(batch["vdw_surface"][0], axis=0)

    ax_true.scatter(
        batch["vdw_surface"][0][:Npoints,0], 
    batch["vdw_surface"][0][:Npoints,1], 
    c=batch["esp"][0][:Npoints],
    s=0.01,
        vmin=-VMAX, vmax=VMAX
    )
    max_val = np.sqrt(max(np.max(batch["esp"]**2), np.max(output['esp_pred']**2)))
    # ax.plot(np.linspace(-max_val, max_val, 100), np.linspace(-max_val, max_val, 100))
    ax_true.set_aspect('equal')

    ax_pred = pw.Brick()

    ax_pred.scatter(
        batch["vdw_surface"][0][:Npoints,0], 
    batch["vdw_surface"][0][:Npoints,1], 
    c=output['esp_pred'][:Npoints],
    s=0.01,
        vmin=-VMAX, vmax=VMAX
    )
    max_val = np.sqrt(max(np.max(batch["esp"]**2), np.max(output['esp_pred']**2)))
    # ax.plot(np.linspace(-max_val, max_val, 100), np.linspace(-max_val, max_val, 100))
    ax_pred.set_aspect('equal')


    ax_diff = pw.Brick()
    ax_diff.scatter(
        batch["vdw_surface"][0][:Npoints,0], 
    batch["vdw_surface"][0][:Npoints,1], 
    c=batch["esp"][0][:Npoints] - output['esp_pred'][:Npoints],
    s=0.01,
        vmin=-VMAX, vmax=VMAX
    )
    ax_diff.set_aspect('equal')

    for _ in [ax_pred, ax_true, ax_diff]:
        _.set_xlim(vdw_surface_min[0], -vdw_surface_min[0])
        _.set_ylim(vdw_surface_min[1], -vdw_surface_min[0])

    charge_ax = pw.Brick()
    charge_ax.matshow(output["mono"][0][:int(batch["N"])],vmin=-1,vmax=1)
    scharge_ax = pw.Brick()
    scharge_ax.matshow(output["mono"][0][:int(batch["N"])].sum(axis=-1)[:, None],vmin=-1,vmax=1)
    # scharge_ax.add_colorbar(vmin=-1,vmax=1)
    scharge_ax.axis("off")
    f = xy_ax | ((ax_pred | ax_true | ax_diff) /  (scharge_ax | charge_ax))
    f.add_colorbar(vmin=-1,vmax=1)
    return f



In [None]:
valid_data.keys()

In [None]:
def get_3d_views(output, batch):
    R = output["dipo"][:int(batch["N"])*NDCM]
    Z = np.array([1 if _ > 0 else 1 for _ in output["mono"][0][:int(batch["N"])].flatten()])
    R.shape, Z.shape
    dcm_atoms = ase.Atoms(Z, R)
    view(dcm_atoms,  viewer="x3d",
    viewer_kwargs={"width": 1000, "height": 1000, "show_unit_cell": 1})
    import ase
    from ase.visualize import view
    atoms = ase.Atoms(batch["Z"][:int(batch["N"])], 
    batch["R"][:int(batch["N"])])
    view(atoms, viewer="x3d")

In [None]:
batch


In [None]:
print("1. Molecular System:")
print(f"   - {len(molecular_data['Z'])} atoms")
print(f"   - Atomic numbers: {molecular_data['Z']}")
print(f"   - Atom positions:")
for i, pos in enumerate(molecular_data['R']):
    print(f"     Atom {i}: {pos}")

print("\n2. Available Models and Charges:")
total_charges = 0
for model_id, charges in model_charges.items():
    n_charges = charges.shape[1]
    total_charges += n_charges
    print(f"   - DCM{model_id+1}: {n_charges} charges per atom")
    print(f"     Example charges for atom 0: {charges[0].tolist()}")
    print(f"     Example positions for atom 0:")
    for j, pos in enumerate(model_positions[model_id][0]):
        print(f"       Charge {j}: {pos}")

print(f"\n   Total charges per atom across all models: {total_charges}")
print(f"   Total possible charge combinations: {total_charges ** len(molecular_data['Z'])}")

print("\n3. Creating DCMNET Selection Environment...")
env = DCMNETSelectionEnv(molecular_data, esp_target, vdw_surface, model_charges, model_positions)

print(f"   - Environment state shape: {env.selected_charges.shape}")
print(f"   - Charge mapping: {env.charge_mapping}")
print(f"   - Legal actions: {len(env.legal_actions())} possible (atom_idx, charge_idx) pairs")

print("\n4. Running MCTS Optimization...")
best_selection, best_loss = optimize_dcmnet_combination(
    molecular_data=molecular_data,
    esp_target=esp_target,
    vdw_surface=vdw_surface,
    model_charges=model_charges,
    model_positions=model_positions,
    n_simulations=100,  # Small number for demo
    temperature=1.0
)

In [None]:
best_selection, best_loss