## An example optimization using a single objective and multiple simulators


In [1]:
import os

# Move three directories up from the notebook's starting directory so that the
# relative paths used later (e.g. "jax_dna/input/...", "data/templates/...")
# resolve. The starting directory is captured only once per kernel session so
# that re-running this cell does not keep climbing further up the tree.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, "../../.."))
os.getcwd()

'/home/ryan/repos/jax-dna'

In [2]:
import functools
import itertools
import logging
import os
from pathlib import Path
import shutil
import typing
import warnings
import jax
import jax.numpy as jnp
import jax_md
import optax
import ray

import jax_dna.energy as jdna_energy
import jax_dna.energy.dna1 as dna1_energy
import jax_dna.input.toml as toml_reader
import jax_dna.input.trajectory as jdna_traj
import jax_dna.input.topology as jdna_top
import jax_dna.input.tree as jdna_tree
import jax_dna.observables as jd_obs
import jax_dna.optimization.simulator as jdna_simulator
import jax_dna.optimization.objective as jdna_objective
import jax_dna.optimization.optimization as jdna_optimization
import jax_dna.simulators.oxdna as oxdna
import jax_dna.simulators.io as jdna_sio
import jax_dna.utils.types as jdna_types
import jax_dna.ui.loggers.jupyter as jupyter_logger

# Enable 64-bit floating point support in JAX for this (driver) process.
jax.config.update("jax_enable_x64", True)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
# This is a function that combines multiple gradients into a single gradient.
# Our example only uses a single objective, so in practice it acts as the
# identity function, but it also averages correctly when given several trees.
def tree_mean(trees: "tuple[jdna_types.PyTree, ...]") -> "jdna_types.PyTree":
    """Return the leaf-wise arithmetic mean of a sequence of PyTrees.

    Args:
        trees: PyTrees sharing the same structure (e.g. one gradient per
            objective).

    Returns:
        The single tree unchanged when only one is given; otherwise a PyTree
        with the same structure whose leaves are the mean of the matching
        leaves across ``trees``.

    Raises:
        IndexError: if ``trees`` is empty.
    """
    if len(trees) <= 1:
        # Nothing to average; indexing an empty sequence raises IndexError.
        return trees[0]
    n_trees = len(trees)
    # jax.tree.map passes one leaf per input tree to the callback, so the
    # callback has to accept an arbitrary number of leaves, not just two.
    return jax.tree.map(lambda *leaves: sum(leaves) / n_trees, *trees)

In [4]:
# Start the Ray runtime used by the simulator and objective actors below.
# NOTE(review): ignore_reinit_error presumably makes re-running this cell
# harmless when Ray is already initialized - confirm against the Ray docs.
ray.init(
    ignore_reinit_error=True,
    log_to_driver=True,
    runtime_env={
        "env_vars": {
            # Environment variables passed through Ray's runtime_env:
            # 64-bit mode to match the jax.config.update call in the imports
            # cell, and the CPU platform.
            "JAX_ENABLE_X64": "True",
            "JAX_PLATFORM_NAME": "cpu",
        }
    }
)

  self.pid = _fork_exec(
2025-01-30 16:26:13,206	INFO worker.py:1816 -- Started a local Ray instance.


0,1
Python version:,3.11.0
Ray version:,2.38.0


In [5]:
# Tunable settings for the whole run, gathered in one place.
optimization_config = dict(
    n_steps=50,  # number of optimization steps in the main loop
    oxdna_build_threads=4,  # forwarded to the oxDNA simulator as n_build_threads
    log_every=5,  # metrics are logged every `log_every` steps
    n_oxdna_runs=3,  # number of simulators created
)

Set up the experiment

In [6]:
# Thermal energy and nucleotide geometry come from the packaged DNA1 defaults.
simulation_defaults = toml_reader.parse_toml("jax_dna/input/dna1/default_simulation.toml")
kT = simulation_defaults["kT"]
energy_defaults = toml_reader.parse_toml("jax_dna/input/dna1/default_energy.toml")
geometry = energy_defaults["geometry"]

# Template simulation directory that is copied once per simulator below.
template_dir = Path("data/templates/simple-helix")
topology_fname = template_dir.joinpath("sys.top")

# Working directory at this point; simulator outputs are written beneath it.
cwd = Path.cwd()

Set up the energy function

In [7]:
energy_fns = dna1_energy.default_energy_fns()
energy_configs = []
opt_params = []

# Only the stacking term is optimized. Every other configuration is kept
# as-is and contributes an empty parameter dict, so `opt_params` stays aligned
# index-for-index with `energy_configs`.
for config in dna1_energy.default_energy_configs():
    is_stacking = isinstance(config, dna1_energy.StackingConfiguration)
    if is_stacking:
        # Within the stacking term, ss_stack_weights and kt are excluded from
        # the optimizable parameters.
        config = config.replace(
            non_optimizable_required_params=("ss_stack_weights", "kt"),
        )
    energy_configs.append(config)
    opt_params.append(config.opt_params if is_stacking else {})

# Converts a rigid body into nucleotide sites using the default geometry offsets.
transform_fn = functools.partial(
    dna1_energy.Nucleotide.from_rigid_body,
    **{
        offset_name: geometry[offset_name]
        for offset_name in ("com_to_backbone", "com_to_hb", "com_to_stacking")
    },
)

energy_fn_builder_fn = jdna_energy.energy_fn_builder(
    transform_fn=transform_fn,
    energy_configs=energy_configs,
    energy_fns=energy_fns,
)

# Topology of the template system; used by the energy function and trajectory loading.
top = jdna_top.from_oxdna_file(topology_fname)
def energy_fn_builder(params: jdna_types.Params) -> callable:
    """Build a batched per-nucleotide energy function for ``params``.

    Args:
        params: Energy parameters forwarded to ``energy_fn_builder_fn``.

    Returns:
        A ``jax.vmap``-ed callable that takes a trajectory object with a
        ``rigid_body`` attribute and returns, for each mapped state, the
        energy divided by ``top.n_nucleotides``.
    """

    def energy_per_nucleotide(trajectory):
        # The energy function is built inside the mapped function, as in the
        # original lambda, so `params` is consumed at trace time.
        energy_fn = energy_fn_builder_fn(params)
        total_energy = energy_fn(
            trajectory.rigid_body,
            seq=jnp.array(top.seq),
            bonded_neighbors=top.bonded_neighbors,
            unbonded_neighbors=top.unbonded_neighbors.T,
        )
        return total_energy / top.n_nucleotides

    return jax.vmap(energy_per_nucleotide)

Set up the simulators

In [8]:
# Shared Ray actor used as a flag to coordinate the oxDNA build between simulators.
run_flag = oxdna.oxDNABinarySemaphoreActor.remote()

# Each simulator gets its own sub-directory under this folder.
sim_outputs_dir = cwd.joinpath("sim_outputs")
sim_outputs_dir.mkdir(exist_ok=True, parents=True)

def make_simulator(id:str, disable_build:bool) -> jdna_simulator.BaseSimulator:
    """Create a Ray simulator actor that runs oxDNA in its own directory.

    The template directory is copied into ``sim_outputs_dir / id``, an
    ``oxDNASimulator`` is configured on that directory, and the simulator is
    wrapped in a ``SimulatorActor`` that exposes the observable ``traj-<id>``.

    Args:
        id: Name of the simulator; also used as its output sub-directory name.
        disable_build: Forwarded to ``oxDNASimulator(disable_build=...)``.
            NOTE(review): presumably this stops the simulator from rebuilding
            the oxDNA binary itself - confirm against ``oxDNASimulator``.

    Returns:
        The handle of the remote ``SimulatorActor``.
    """
    sim_dir = sim_outputs_dir / id

    if sim_dir.exists():
        # stacklevel=2 attributes the warning to the caller of make_simulator.
        warnings.warn(f"Directory {sim_dir} already exists. Assuming that's fine.", stacklevel=2)

    sim_dir.mkdir(parents=True, exist_ok=True)

    # Seed the simulation directory with the template input files.
    for f in template_dir.iterdir():
        shutil.copy(f, sim_dir)

    simulator = oxdna.oxDNASimulator(
        input_dir=sim_dir,
        sim_type=jdna_types.oxDNASimulatorType.DNA1,
        energy_configs=energy_configs,
        n_build_threads=optimization_config["oxdna_build_threads"],
        disable_build=disable_build,
        # The shared flag actor tells the simulators when the build is ready.
        check_build_ready=lambda: ray.get(run_flag.check.remote()),
        set_build_ready=run_flag.set.remote,
    )

    output_dir = sim_dir / "trajectory"
    trajectory_loc = output_dir / "trajectory.pkl"
    # exist_ok=True already covers the "directory exists" case.
    output_dir.mkdir(parents=True, exist_ok=True)

    def simulator_fn(
        params: jdna_types.Params,
        meta: jdna_types.MetaData,
    ) -> list[Path]:
        """Run oxDNA with ``params`` and return where the trajectory was saved.

        ``meta`` is accepted but not used here.
        """
        simulator.run(params)

        ox_traj = jdna_traj.from_file(
            sim_dir / "output.dat",
            strand_lengths=top.strand_counts,
        )
        traj = jdna_sio.SimulatorTrajectory(
            rigid_body=ox_traj.state_rigid_body,
        )

        jdna_tree.save_pytree(traj, trajectory_loc)
        # One path per exposed observable (here only "traj-<id>").
        return [trajectory_loc]

    return jdna_simulator.SimulatorActor.options(
        runtime_env={
            "env_vars": {
                # Locations of the oxDNA binary and build directory, resolved
                # relative to the current working directory.
                oxdna.BIN_PATH_ENV_VAR: str(Path("../oxDNA/build/bin/oxDNA").resolve()),
                oxdna.BUILD_PATH_ENV_VAR: str(Path("../oxDNA/build").resolve()),
            },
        },
    ).remote(
        name=id,
        fn=simulator_fn,
        exposes=[f"traj-{id}",],
        meta_data={},
    )


# One simulator, and one exposed trajectory observable, per configured oxDNA run.
sim_ids = [f"sim{index}" for index in range(optimization_config["n_oxdna_runs"])]
traj_ids = [f"traj-{sim_id}" for sim_id in sim_ids]

# Only the first simulator is created with disable_build=False; the rest get True.
simulators = [
    make_simulator(sim_id, index > 0)
    for index, sim_id in enumerate(sim_ids)
]



Set up the objective

In [9]:
# Observable used by the loss below: propeller twist over the listed
# hydrogen-bonded base pairs.
# NOTE(review): the pairs are presumably nucleotide indices into the
# simple-helix topology - confirm against the template's sys.top.
prop_twist_fn = jd_obs.propeller.PropellerTwist(
    rigid_body_transform_fn=transform_fn,
    h_bonded_base_pairs=jnp.array([[1, 14], [2, 13], [3, 12], [4, 11], [5, 10], [6, 9]]),
)

def prop_twist_loss_fn(
    traj: jax_md.rigid_body.RigidBody,
    weights: jnp.ndarray,
    energy_model: jdna_energy.base.ComposedEnergyFunction,
) -> tuple[float, tuple[tuple[str, typing.Any], dict]]:
    """Absolute error between the weighted propeller twist and its target.

    Args:
        traj: Trajectory passed to ``prop_twist_fn`` to compute the per-state
            propeller twist.
        weights: Per-state weights; the expected propeller twist is their dot
            product with the per-state observable.
        energy_model: Accepted for interface compatibility; unused here.

    Returns:
        ``(loss, aux)`` where ``loss`` is ``|expected - target|`` and ``aux``
        is ``(("prop_twist", expected), {})``.
    """
    obs = prop_twist_fn(traj)
    expected_prop_twist = jnp.dot(weights, obs)
    # Same value as sqrt((x - target) ** 2), but jnp.abs keeps the gradient
    # finite when the expected value hits the target exactly (sqrt has an
    # unbounded derivative at 0, which yields NaN gradients there).
    loss = jnp.abs(expected_prop_twist - jd_obs.propeller.TARGETS["oxDNA"])
    return loss, (("prop_twist", expected_prop_twist), {})

# Remote objective actor for the propeller-twist loss. It consumes the
# trajectories exposed by every simulator (`traj_ids`) and is built from the
# loss function, the energy-function builder and the initial parameters.
# NOTE(review): min_n_eff_factor, n_equilibration_steps and max_valid_opt_steps
# are passed through unchanged; their exact semantics are defined by
# DiffTReObjectiveActor - confirm there before tuning.
propeller_twist_objective = jdna_objective.DiffTReObjectiveActor.remote(
    name="prop_twist",
    required_observables=traj_ids,
    needed_observables=traj_ids,
    logging_observables=["loss", "prop_twist"],
    grad_or_loss_fn=prop_twist_loss_fn,
    energy_fn_builder=energy_fn_builder,
    opt_params=opt_params,
    min_n_eff_factor=0.95,
    # Inverse temperature from the default simulation kT.
    beta=jnp.array(1 / kT, dtype=jnp.float64),
    n_equilibration_steps=0,
    max_valid_opt_steps=10,
)

In [10]:

# Parameters to plot. A bare string gets its own entry; a nested list groups
# several parameters into a single entry.
params_to_log = [
    "eps_stack_base",
    "eps_stack_kt_coeff",
    ["dr_low_stack", "dr_high_stack", "a_stack", "dr0_stack", "dr_c_stack"],
    ["theta0_stack_4", "delta_theta_star_stack_4", "a_stack_4"],
    ["theta0_stack_5", "delta_theta_star_stack_5", "a_stack_5"],
    ["theta0_stack_6", "delta_theta_star_stack_6", "a_stack_6"],
    ["neg_cos_phi1_star_stack", "a_stack_1"],
    ["neg_cos_phi2_star_stack", "a_stack_2"],
]

# Flat list of every parameter name above, used to select which parameters
# are logged during the optimization loop.
params_list_flat = [
    param_name
    for entry in params_to_log
    for param_name in ([entry] if isinstance(entry, str) else entry)
]

# Notebook logger widget tracking the simulators, their trajectory observables
# and the single objective. The metrics are the loss, the propeller twist
# together with its target, and the parameter entries from `params_to_log`.
logger = jupyter_logger.JupyterLogger(
    simulators=sim_ids,
    observables=traj_ids,
    objectives=["prop_twist"],
    metrics_to_log=["loss", ["prop_twist", "target_ptwist"]] + params_to_log,
    max_opt_steps=optimization_config["n_steps"],
    plots_size_px=(900, 1400),
    plots_nrows_ncols = (3, 1)
)
# Render the widget in the cell output.
logger.show()



VBox(children=(Label(value='Optimization Status'), HBox(children=(IntProgress(value=0, bar_style='info', descr…

Make the optimizer


In [11]:
# Optimization =============================================================
# Tie everything together: the objective(s), the simulator actors, an Adam
# optimizer, the gradient-aggregation function and the notebook logger.
objectives = [propeller_twist_objective]

opt = jdna_optimization.Optimization(
    objectives=objectives,
    simulators=simulators,
    optimizer = optax.adam(learning_rate=1e-3),
    aggregate_grad_fn=tree_mean,
    logger=logger,
)
# ==========================================================================

In [12]:
# Display the target propeller twist that the loss compares against.
jd_obs.propeller.TARGETS["oxDNA"]

21.7

Run the optimization

In [None]:
for step in range(optimization_config["n_steps"]):
    opt_state, opt_params = opt.step(opt_params)

    should_log = step % optimization_config["log_every"] == 0
    if should_log:
        # Metrics reported by each objective, plus the constant target so the
        # propeller twist can be plotted against it.
        for objective_actor in opt.objectives:
            reported = ray.get(objective_actor.logging_observables.remote())
            for metric_name, metric_value in reported:
                logger.log_metric(metric_name, metric_value, step=step)
            logger.log_metric("target_ptwist", jd_obs.propeller.TARGETS["oxDNA"], step=step)

        # Current values of the parameters selected for logging.
        for param_group in opt_params:
            for param_name, param_value in param_group.items():
                if param_name in params_list_flat:
                    logger.log_metric(param_name, param_value, step=step)

    opt = opt.post_step(
        optimizer_state=opt_state,
        opt_params=opt_params,
    )
    logger.increment_prog_bar()
    # block the oxdna builds so that the simulator that builds can do so
    run_flag.set.remote(False)


[36m(SimulatorActor pid=260952)[0m An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[36m(SimulatorActor pid=260951)[0m 2025-01-30 16:26:20,639 INFO:jax_dna.simulators.oxdna.oxdna:Updating oxDNA parameters
[36m(SimulatorActor pid=260951)[0m 2025-01-30 16:27:20,546 INFO:jax_dna.simulators.oxdna.oxdna:oxDNA binary rebuilt
[36m(SimulatorActor pid=260951)[0m 2025-01-30 16:27:20,546 INFO:jax_dna.simulators.oxdna.oxdna:oxDNA input file: /home/ryan/repos/jax-dna/sim_outputs/sim0/input
[36m(SimulatorActor pid=260951)[0m 2025-01-30 16:27:20,548 INFO:jax_dna.simulators.oxdna.oxdna:Starting oxDNA simulation
[36m(SimulatorActor pid=260951)[0m An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.[32m [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/us

loss 0.5974519179044435
prop_twist 21.102548082095556
loss 0.5533786833253629
prop_twist 21.146621316674636
loss 0.2547327670868391
prop_twist 21.95473276708684
loss 0.25742190627375194
prop_twist 21.95742190627375
loss 0.3128843817610587
prop_twist 21.38711561823894
loss 0.296586557641735
prop_twist 21.403413442358264
loss 0.1523194918130315
prop_twist 21.85231949181303
loss 0.1489189911686104
prop_twist 21.84891899116861
loss 0.5399510642636507
prop_twist 22.23995106426365
loss 0.352566844834751
prop_twist 21.34743315516525
