In [None]:
# ! 
import mmml
import matplotlib.pyplot as plt
import patchworklib as pw
import os
from pathlib import Path
import numpy as np
# GPU configuration. These variables are set before `import jax` below.
# XLA_PYTHON_CLIENT_MEM_FRACTION is described in the JAX GPU memory
# allocation docs as the fraction of GPU memory the XLA client preallocates.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
# Expose only CUDA device 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import jax
# Sanity check: print the devices and the backend JAX actually selected.
devices = jax.local_devices()
print(devices)
print(jax.default_backend())
print(jax.devices())

# %%
from mmml import dcmnet
# No trained parameters yet. `params` is handed to train_model as
# `restart_params` and replaced by its return value in the training loop.
params = None
# %%
import jax
import jax.numpy as jnp
from mmml.dcmnet.dcmnet.data import prepare_datasets
from mmml.dcmnet.dcmnet.modules import MessagePassingModel
from mmml.dcmnet.dcmnet.training import train_model, train_model_dipo

# Base PRNG key. The visible training loop derives its own per-round keys
# from `seed` rather than using this one.
key = jax.random.PRNGKey(0)

# Base seed; multiplied by the bootstrap round index to seed each round.
seed = 42
# %%
# Passed to the model as `n_dcm`.
# NOTE(review): presumably the number of distributed charges per atom - confirm.
NDCM = 4
model = MessagePassingModel(
    features=128, max_degree=2, num_iterations=2,
    num_basis_functions=32, cutoff=8.0, n_dcm=NDCM,
    include_pseudotensors=False,
)


# Candidate locations of the dataset, tried in order; the first existing
# path is used.
data_path_candidates = [
    Path('/scicore/home/meuwly/boitti0000/test.npz'),
    Path('/pchem-data/meuwly/boittier/home/test.npz'),
]
data_path_resolved = next(
    (candidate for candidate in data_path_candidates if candidate.exists()),
    None,
)
if data_path_resolved is None:
    # Report every location that was tried, not just the last one.
    tried = ", ".join(str(candidate) for candidate in data_path_candidates)
    raise FileNotFoundError(f"Data file not found; tried: {tried}")

# NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
# execute arbitrary code. Only load files from a trusted source, or drop the
# flag if the archive contains plain numeric arrays.
data_loaded = np.load(data_path_resolved, allow_pickle=True)

# Print each stored array's name followed by its shape.
for k in data_loaded.keys():
    print(k)
    shape = data_loaded[k].shape
    print(shape)

n_sample = 1000  # Number of ESP grid points to keep per molecule
Nboot = 10  # Number of bootstrap rounds (iterations of the loop below)
for i in range(Nboot):
    # Per-round PRNG key. Note that round 0 always uses PRNGKey(0), because
    # i*seed is 0 regardless of the value of `seed`.
    data_key = jax.random.PRNGKey(i*seed)

    # Build the train/validation datasets for this round from the .npz file.
    # NOTE(review): natoms=18 presumably is the padded atom count per
    # structure (R has shape (n, 18, 3) in the file) - confirm.
    train_data, valid_data = prepare_datasets(
        data_key, num_train=1000, num_valid=100,
        filename=[data_path_resolved],
        clean=False, esp_mask=False,
        natoms=18,
        clip_esp=False,
    )

    def random_sample_esp(esp, esp_grid, n_sample, seed=i*seed):
        """Randomly subsample ESP values and their grid points, per molecule.

        For every molecule only grid points whose ESP value lies strictly
        between -2 and 2 and is not exactly 0 are eligible; ``n_sample`` of
        them are drawn without replacement.

        Args:
            esp: Sequence of per-molecule 1-D arrays of ESP values.
            esp_grid: Sequence of per-molecule arrays of grid points, aligned
                with ``esp`` along the first axis of each entry.
            n_sample: Number of points to draw for each molecule.
            seed: Seed for NumPy's global RNG. The default is evaluated once,
                when this ``def`` statement runs, using the enclosing loop's
                current ``i`` and the module-level ``seed``.

        Returns:
            Tuple ``(sampled_esp, sampled_grid)`` of NumPy arrays stacked over
            molecules.

        Raises:
            ValueError: If a molecule has fewer than ``n_sample`` eligible
                grid points.
        """
        # Re-seeds the global NumPy RNG on every call, so two calls with the
        # same seed draw the same index sequence.
        np.random.seed(seed)
        sampled_esp = []
        sampled_grid = []

        for mol_index, (mol_esp, mol_grid) in enumerate(zip(esp, esp_grid)):
            mol_esp = np.asarray(mol_esp)
            mol_grid = np.asarray(mol_grid)

            in_range = (mol_esp < 2) & (mol_esp > -2) & (mol_esp != 0.0)
            # Positions (into the full grid) of the points that pass the filter.
            eligible = np.flatnonzero(in_range)
            if eligible.size < n_sample:
                raise ValueError(
                    f"Molecule {mol_index} has only {eligible.size} eligible "
                    f"ESP points; cannot sample {n_sample} without replacement"
                )

            # Draw among the eligible points, then map the draws back to
            # positions in the full grid. Indexing with the boolean mask
            # values themselves (as before) only ever selected elements 0/1.
            draws = np.random.choice(eligible.size, n_sample, replace=False)
            chosen = eligible[draws]

            sampled_esp.append(np.take(mol_esp, chosen))
            sampled_grid.append(np.take(mol_grid, chosen, axis=0))

        return np.array(sampled_esp), np.array(sampled_grid)

    # Subsample the ESP grid of every molecule down to n_sample points.
    train_data["esp"], train_data["esp_grid"] = random_sample_esp(
        train_data["esp"], train_data["esp_grid"], n_sample
    )
    valid_data["esp"], valid_data["esp_grid"] = random_sample_esp(
        valid_data["esp"], valid_data["esp_grid"], n_sample
    )

    # NOTE(review): 0.0016 looks like a unit conversion applied to the ESP
    # values - confirm the source and target units.
    valid_data["esp"] = 0.0016 * valid_data["esp"]
    train_data["esp"] = 0.0016 * train_data["esp"]

    # The sampled grid is also exposed under the "vdw_surface" key, with one
    # grid-point count ("n_grid") per structure.
    train_data["vdw_surface"] = train_data["esp_grid"]
    valid_data["vdw_surface"] = valid_data["esp_grid"]
    train_data["n_grid"] = np.full(train_data["Z"].shape[0], n_sample)
    valid_data["n_grid"] = np.full(valid_data["Z"].shape[0], n_sample)

    # Fixed monopole targets: +0.1 where Z == 1, -0.2 where Z == 8, and 0 for
    # every other entry of Z.
    Hs_train = train_data["Z"] == 1.0
    Os_train = train_data["Z"] == 8.0
    Hs_valid = valid_data["Z"] == 1.0
    Os_valid = valid_data["Z"] == 8.0

    train_data["mono"] = Hs_train * 0.1 + Os_train * -0.2
    valid_data["mono"] = Hs_valid * 0.1 + Os_valid * -0.2

    # Number of atoms per structure = number of non-zero atomic numbers.
    train_data["N"] = np.count_nonzero(train_data["Z"], axis=1)
    valid_data["N"] = np.count_nonzero(valid_data["Z"], axis=1)

    # Print the shapes of the first training sample as a sanity check.
    print("After fixes:")
    batch = {k: v[0:1] if len(v.shape) > 0 else v for k, v in train_data.items()}
    # The loop variable is deliberately not called `key`, so the module-level
    # PRNG `key` is not overwritten with a string.
    for field in ['mono', 'esp', 'vdw_surface', 'n_grid', 'N', 'R', 'Z']:
        if field in batch:
            print(f"{field}: {batch[field].shape}")

    # Also check the specific values
    print(f"\nmono values: {batch['mono']}")
    print(f"N values: {batch['N']}")
    print(f"n_grid values: {batch['n_grid']}")

    esp_data = train_data["esp"]

    # Train for this round, warm-starting from the previous round's
    # parameters (None on the first round). Across rounds the ESP weight
    # ramps from 1000/Nboot up to 1000, while the charge weight decays from
    # 1 down to 1/Nboot.
    params, valid_loss = train_model(
        key=data_key, model=model,
        writer=None,
        train_data=train_data, valid_data=valid_data,
        num_epochs=50, learning_rate=1e-4, batch_size=1,
        restart_params=params,
        ndcm=model.n_dcm, esp_w=1000.0*((i+1)/Nboot), chg_w=1.0/((i+1)),
        use_grad_clip=True, grad_clip_norm=1.0,
    )
    new_params = params.copy()

from mmml.dcmnet.dcmnet.analysis import dcmnet_analysis, prepare_batch
from mmml.dcmnet.dcmnet.data import prepare_batches
from mmml.dcmnet.dcmnet.analysis import dcmnet_analysis

def prepare_batch_for_analysis(data, index=0, num_atoms=18):
    """Prepare a single-sample batch for ``dcmnet_analysis``.

    Args:
        data: Mapping from field name to an array whose first axis indexes
            samples.
        index: Index of the sample to extract.
        num_atoms: Value forwarded to ``prepare_batches`` as ``num_atoms``.
            Defaults to 18, the value previously hard-coded here.

    Returns:
        The first (and only) batch returned by ``prepare_batches``, with
        zero-valued ``"com"`` and ``"Dxyz"`` entries added.
    """
    # Index with a list so the leading batch dimension (size 1) is kept.
    single = {name: np.array(values[[index]]) for name, values in data.items()}

    # Build one batch of size 1; sample ids are not included (include_id=False).
    batch = prepare_batches(
        jax.random.PRNGKey(0), single, batch_size=1,
        include_id=False, num_atoms=num_atoms,
    )[0]
    # NOTE(review): placeholder zero vectors; presumably centre of mass and a
    # dipole-related field expected by dcmnet_analysis - confirm.
    batch["com"] = np.array([0,0,0])
    batch["Dxyz"] = np.array([0,0,0])
    return batch

# Evaluate the parameters from the last bootstrap round on the first sample
# of that round's training data and report the resulting errors.
batch = prepare_batch_for_analysis(train_data, index=0)
# NOTE(review): the literal 18 presumably is the padded atom count, matching
# natoms=18 used when the datasets were prepared - confirm.
output = dcmnet_analysis(params, model, batch, 18)
print(f"RMSE: {output['rmse_model']:.6f}")
print(f"RMSE (masked): {output['rmse_model_masked']:.6f}")





[CudaDevice(id=0)]
gpu
[CudaDevice(id=0)]
R
(1983, 18, 3)
D
(1983, 3)
Q
(1983, 3, 3)
Z
(1983, 18)
esp
(1983, 44510)
esp_grid
(1983, 44510, 3)
F
(1983, 18, 3)
shape (1983, 18, 3)
D (1983, 3)
Q 3 (1983, 3, 3) 1983
Q (1983, 3, 3)
R (1983, 18, 3)
(1983, 18, 3)
['R', 'Z', 'F', 'esp', 'D', 'esp_grid', 'Q']
1983
0 R 1983 (1983, 18, 3)
1 Z 1983 (1983, 18)
2 F 1983 (1983, 18, 3)
3 esp 1983 (1983, 44510)
4 D 5949 (5949, 1)
5 esp_grid 1983 (1983, 44510, 3)
6 Q 1983 (1983, 3, 3)
After fixes:
mono: (1, 18)
esp: (1, 1000)
vdw_surface: (1, 1000, 3)
n_grid: (1,)
N: (1,)
R: (1, 18, 3)
Z: (1, 18)

mono values: [[-0.2  0.1  0.1 -0.2  0.1  0.1  0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0. ]]
N values: [6]
n_grid values: [1000]
Preparing batches
..................
Training
..................

Epoch   1 Statistics
Metric                         Train           Valid      Difference
--------------------------------------------------------------------------------
loss                    6.512354