In [1]:
# Install lovely-jax (used for readable array reprs below).
# NOTE(review): `uv pip install` targets whatever environment uv resolves (active venv / .venv),
# which is not necessarily this kernel's interpreter; `%pip install lovely-jax==<version>`
# targets the running kernel and pins the version.
! uv pip install lovely-jax

import os
# Set before `jax` is imported (directly, or presumably via lovely_jax) so the value is in place
# when JAX initialises its GPU backend. With preallocation enabled, XLA reserves this fraction
# (99%) of GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".99"

import lovely_jax as lj

# Patch JAX arrays with lovely-jax's summary repr (presumably also what provides the `.v`
# accessor used further down in this notebook).
lj.monkey_patch()

import jax
from openqdc.datasets import SpiceV2 as Spice

from physnetjax.data.datasets import process_in_memory
from physnetjax.models.model import EF
from physnetjax.training.training import train_model

# Configurable constants.
# NOTE(review): a later cell rebinds NATOMS, RANDOM_SEED and BATCH_SIZE (BATCH_SIZE -> 32);
# the batch sizing computed in the next cell uses the values assigned here.
NATOMS, RANDOM_SEED, BATCH_SIZE = 110, 42, 20
DEFAULT_DATA_KEYS = ["Z", "R", "D", "E", "F", "N"]

# # Environment configuration
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".99"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# JAX Configuration Check
def check_jax_configuration():
    """Print the devices JAX can see and the backend it will use by default."""
    report = (
        ("Devices:", jax.local_devices()),
        ("Default Backend:", jax.default_backend()),
        ("All Devices:", jax.devices()),
    )
    for label, value in report:
        print(label, value)


check_jax_configuration()





In [2]:
# Padding sizes for the "advanced" minibatching path. Both are derived from BATCH_SIZE and
# NATOMS as they are *when this cell runs*; a later cell rebinds BATCH_SIZE (20 -> 32)
# without recomputing them, so re-run this cell after changing either constant.
#   batch_shape: (BATCH_SIZE - 1) * NATOMS -- presumably the number of atom slots per batch.
#   nb_len:      NATOMS * (NATOMS - 1) * (BATCH_SIZE - 1) divided by 1.6 -- presumably the
#                number of neighbour-pair slots with a head-room heuristic; TODO confirm the
#                1.6 divisor never truncates pairs for the densest batches.
batch_kwargs = {
    "batch_shape": int((BATCH_SIZE - 1) * NATOMS),
    "nb_len": int((NATOMS * (NATOMS - 1) * (BATCH_SIZE - 1)) // 1.6),
}
print(batch_kwargs)

# Choose the batching implementation; the "advanced" path needs both sizes above.
batch_method = "advanced"
if (
    batch_method == "advanced"
    and isinstance(batch_kwargs, dict)
    and "batch_shape" in batch_kwargs
    and "nb_len" in batch_kwargs
):
    print("Using advanced batching method")
    from physnetjax.data.batches import prepare_batches_advanced_minibatching

    def _prepare_batches(x):
        """Build batches from a kwargs dict with keys key, data, batch_size, num_atoms, data_keys.

        The padding sizes are read from the module-level ``batch_kwargs`` at call time.
        """
        return prepare_batches_advanced_minibatching(
            x["key"],
            x["data"],
            x["batch_size"],
            batch_kwargs["batch_shape"],
            batch_kwargs["nb_len"],
            num_atoms=x["num_atoms"],
            data_keys=x["data_keys"],
        )
else:
    print("Using default batching method")
    from physnetjax.data.batches import get_prepare_batches_fn

    _prepare_batches = get_prepare_batches_fn()

In [3]:
# Constants for the SPICE experiment.
# NOTE(review): NATOMS and RANDOM_SEED repeat the first cell's values, but BATCH_SIZE is rebound
# from 20 to 32 here; `batch_kwargs` (computed earlier from BATCH_SIZE = 20) is not updated.
NATOMS = 110
RANDOM_SEED = 42
BATCH_SIZE = 32

# SpiceV2 holds 2_008_628 samples in total; only subsamples of this size are drawn.
NTRAIN = 100_000
NVALID = 500

# Keys passed as `data_keys` when batching (note: no "D", unlike DEFAULT_DATA_KEYS).
DATA_KEYS = ("Z", "R", "E", "F", "N")


# Dataset preparation
def prepare_spice_dataset(
    dataset, subsample_size, max_atoms, ignore_indices=None, key=jax.random.PRNGKey(42)
):
    """Subsample ``dataset`` and convert the selected samples with ``process_in_memory``.

    Args:
        dataset: object exposing ``subsample(size, seed=...)`` and integer indexing
            (an openqdc dataset in this notebook).
        subsample_size: number of indices requested from ``dataset.subsample``.
        max_atoms: forwarded to ``process_in_memory`` as ``max_atoms``.
        ignore_indices: optional collection of indices to drop from the subsample
            (e.g. the validation indices when building a training set). Dropped
            indices are not replaced, so fewer than ``subsample_size`` may remain.
        key: JAX PRNG key; only its first word is passed on as the subsample seed.

    Returns:
        ``(processed, indices)``: the ``process_in_memory`` result and the dataset
        indices it was built from.
    """
    # A raw PRNGKey is a 2-word array; only word 0 is used as the seed.
    # NOTE(review): with the default threefry impl, jax.random.PRNGKey(n) for n < 2**32 has
    # word 0 == 0, so the default key always seeds with 0. Kept as-is so existing train/valid
    # splits stay reproducible.
    key = key[0] if len(key) > 1 else key
    indices = dataset.subsample(subsample_size, seed=key)
    if ignore_indices is not None:
        # A set lookup keeps the filter linear; int() normalises numpy/JAX scalar indices.
        excluded = {int(i) for i in ignore_indices}
        indices = [i for i in indices if int(i) not in excluded]
    # Index the dataset that was passed in, not a module-level global.
    samples = [dict(dataset[i]) for i in indices]
    processed = process_in_memory(samples, max_atoms=max_atoms, openqdc=True)
    return processed, indices


# Load SPICE v2 (energies in eV, distances in Angstrom, JAX arrays) and run openqdc preprocessing.
ds = Spice(array_format="jax", distance_unit="ang", energy_unit="ev")
ds.read_preprocess()

# One subkey drives data subsampling/shuffling, the other model training.
data_key, train_key = jax.random.split(jax.random.PRNGKey(RANDOM_SEED), num=2)

# Validation set: NVALID sampled molecules, processed with max_atoms=NATOMS.
validation_set, validation_set_idxs = prepare_spice_dataset(
    ds, NVALID, NATOMS, key=data_key
)

# Training set (currently disabled; `training_set` is needed by the training cell at the end).
# Build it with ignore_indices=validation_set_idxs so no validation molecule is trained on.
# # get a new data key
# data_key, _ = jax.random.split(data_key, 2)
# # load the training set
# training_set, training_set_idxs = prepare_spice_dataset(
#     ds,
#     subsample_size=NTRAIN,
#     max_atoms=NATOMS,
#     key=data_key,
#     ignore_indices=validation_set_idxs,
# )


In [4]:
from argparse import ArgumentParser
from pathlib import Path
from physnetjax.analysis.plot_run import plot_run
import polars as pl
import matplotlib.pyplot as plt
from physnetjax.directories import LOGS_PATH, BASE_CKPT_DIR
from physnetjax.logger.tensorboard_interface import process_tensorboard_logs

# Training curves of one finished run: read its TensorBoard events, show them as a table,
# plot them, and keep a PNG + CSV copy under LOGS_PATH/<run name>.
logs_path = BASE_CKPT_DIR / "test-6276dc44-fdba-4835-960b-d42df1b1800a" / "tfevents"
# Run name (the checkpoint directory's name).
# NOTE(review): `key` is rebound to a PRNG key a few cells later.
key = logs_path.parent.name
df = process_tensorboard_logs(logs_path)

# Rich renders the polars frame as a formatted table.
from rich.console import Console

console = Console()
console.print(df)

fig, ax = plt.subplots(5, 2, figsize=(12, 12))
plot_run(df, ax, 1, key, log=True)

save_path = LOGS_PATH / key / "tf_logs.png"
save_path.parent.mkdir(parents=True, exist_ok=True)  # create LOGS_PATH/<run> on first use
fig.savefig(save_path, bbox_inches="tight")
df.write_csv(save_path.parent / "tf_logs.csv")


In [5]:
# Fresh subkey for shuffling the validation batches.
# NOTE(review): data_key was already used as the validation subsample seed; JAX advises against
# deriving more randomness from a key after it has been consumed.
key, shuffle_key = jax.random.split(data_key)

kwargs = dict(
    key=shuffle_key,
    data=validation_set,
    batch_size=BATCH_SIZE,
    num_atoms=NATOMS,
    data_keys=DATA_KEYS,
)
valid_batches = _prepare_batches(kwargs)

# Training batches (disabled until `training_set` is built in the data-loading cell):
# kwargs = {
#     "key": shuffle_key,
#     "data": training_set,
#     "batch_size": BATCH_SIZE,
#     "num_atoms": NATOMS,
#     "data_keys": DATA_KEYS,
# }
# train_batches = _prepare_batches(kwargs)

In [6]:
# First validation batch; display the fields it carries.
b = valid_batches[0]
b.keys()

In [7]:
# `.v` is presumably lovely-jax's verbose view (available after lj.monkey_patch()): summary plus values.
jax.numpy.array(b["F"]).v, b["batch_mask"]

In [8]:
# Visual check of batch `b`: one strip per field, stacked top to bottom
# (R, Z, F, atom_mask, batch_segments), presumably laid out along the padded atom axis.
n_rows = 5
fig = plt.figure(figsize=(200, 2.5))

# (image, matshow options) for each strip, in row order.
panels = (
    (b["R"].T, dict(cmap="bwr")),
    (b["Z"][..., None].T, dict(cmap="rainbow", vmin=0, vmax=35)),
    (b["F"].T, dict(cmap="bwr", vmin=-0.1, vmax=0.1)),
    (b["atom_mask"][None, ...], dict(cmap="Set2")),
    (b["batch_segments"][None, ...], dict(cmap="Set2")),
)

strip_axes = []
for row, (image, style) in enumerate(panels, start=1):
    strip_ax = plt.subplot(n_rows, 1, row)
    strip_ax.matshow(image, **style)
    strip_ax.set_axis_off()
    strip_axes.append(strip_ax)
ax1, ax2, ax3, ax4, ax5 = strip_axes

plt.tight_layout()
plt.show()

In [9]:
from physnetjax.restart.restart import get_last

# Resolve the checkpoint to restore from this run's directory (presumably the latest one).
restart = get_last(BASE_CKPT_DIR / "test-6276dc44-fdba-4835-960b-d42df1b1800a")
restart

In [10]:
from physnetjax.restart.restart import get_params_model

# Use the shared NATOMS constant rather than a literal so the restored model matches the data padding.
params, model = get_params_model(restart, natoms=NATOMS)

In [11]:
model.natoms = NATOMS
print(model)
from physnetjax.analysis.analysis import plot_stats

# Keep batch_size tied to the BATCH_SIZE that valid_batches were built with instead of a literal.
output = plot_stats(valid_batches, model, params, _set="Test",
                    do_kde=True, batch_size=BATCH_SIZE)


In [12]:
import numpy as np
# Reference energies of every validation batch, one row per batch
# (assumes every batch carries an "E" array of the same shape).
energies = np.array([batch["E"] for batch in valid_batches])

In [13]:
# Print predicted vs reference energy for every molecule with a non-zero prediction, tagged
# with its (batch, slot) position. Zero predictions are skipped (presumably padding slots;
# TODO confirm against plot_stats).
# The position comes from the index into the unfiltered arrays. Enumerating the filtered entries
# mislabels every molecule after the first skipped one. This assumes output["predEs"] is the
# flattened concatenation of BATCH_SIZE-slot batches; verify.
# Loop names avoid `b`, which holds the batch inspected by earlier cells.
_idx = output["predEs"].nonzero()[0]
for flat_idx, pred_e, ref_e in zip(_idx, output["predEs"][_idx], output["Es"][_idx]):
    print(flat_idx // BATCH_SIZE, flat_idx % BATCH_SIZE, pred_e, ref_e)

In [14]:
# Distribution of the non-zero "N" entries over all validation batches
# (zeros presumably come from padding).
ns = np.vstack([batch["N"] for batch in valid_batches])
nonzero = np.flatnonzero(ns)
plt.hist(ns.ravel()[nonzero])
ns, energies

In [12]:
# plt.hist(output2["E"])
# ds._e0s_dispatcher[output2["Z"][0]]
# Scratch check of the per-atom reference energies (e0) on the first SPICE sample.
# 0.0367492929 is about 1 / 27.2114, i.e. the eV -> Hartree factor; TODO confirm the intended direction.
# NOTE(review): 512.6264 is an unexplained reference value; record where it comes from.
EV_TO_HARTREE = 0.0367492929
sample0 = ds[0]  # fetch once instead of re-indexing the dataset for every print
print(sample0["energies"], 512.6264 * EV_TO_HARTREE)
print(sample0["energies"] - sample0["e0"].sum() * EV_TO_HARTREE)
print(sample0["e0"].sum() - 512.6264)
print(sample0)
sample0["e0"].sum() * EV_TO_HARTREE, [
    np.array(
        [ds._e0s_dispatcher[int(z)].mean for z in sample0["atomic_numbers"] if z != 0]
    ).sum()
]

In [13]:
# Histogram of the non-zero force components over all validation batches
# (zero entries, presumably padding, are dropped).
forces = np.vstack([batch["F"] for batch in valid_batches])
nonzero = np.flatnonzero(forces)
plt.hist(forces.ravel()[nonzero])

In [None]:
from physnetjax.utils import get_last, get_files, get_params_model
from physnetjax.analysis import plot_stats
# Re-evaluate a restored model on a set of batches.
# NOTE(review): `combined` is not defined anywhere in this notebook (this cell has never run);
# point it at the intended batches (e.g. valid_batches) first. The imports above also use
# different module paths (physnetjax.utils / physnetjax.analysis) than the earlier cells; verify.
NATOMS = 110
model.natoms = NATOMS

# `batch_size` was undefined here; use the notebook's BATCH_SIZE like the other evaluation cell.
output = plot_stats(combined, model, params, _set="Test",
                    do_kde=True, batch_size=BATCH_SIZE)

# Example training

In [1]:
# Model initialization.
# NOTE(review): this rebinds `model` (and `params` below), replacing the checkpoint-restored
# model used by the evaluation cells above.
model = EF(
    features=128,
    max_degree=0,
    num_iterations=5,
    num_basis_functions=16,
    cutoff=5.0,
    max_atomic_number=70,
    charges=False,
    natoms=NATOMS,
    total_charge=0,
    n_res=2,
    zbl=False,
)


# Model training.
# The train/valid data were previously the undefined names `output1` / `output2`. They now use
# this notebook's splits; `training_set` exists once the commented-out training-set block in the
# data-loading cell is enabled.
# NOTE(review): data_keys=DEFAULT_DATA_KEYS includes "D", while the batches built above use
# DATA_KEYS (no "D"). batch_kwargs was sized with BATCH_SIZE = 20, but BATCH_SIZE is 32 by now.
params = train_model(
    train_key,
    model,
    training_set,
    validation_set,
    num_epochs=100,
    learning_rate=0.001,
    energy_weight=1,
    schedule_fn="constant",
    optimizer="amsgrad",
    batch_size=BATCH_SIZE,
    num_atoms=NATOMS,
    data_keys=DEFAULT_DATA_KEYS,
    print_freq=1,
    objective="valid_loss",
    best=1e6,
    batch_method="advanced",
    batch_args_dict=batch_kwargs,
)