In [1]:
import os
import sys

# Make the in-repo `jax_dna` package (located in the parent of the current
# working directory) importable. The membership check keeps the cell idempotent:
# re-running it does not pile up duplicate sys.path entries.
_repo_root = os.path.join(os.getcwd(), "..")
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

In [2]:
import functools
import itertools
import logging
import os
from pathlib import Path
import typing
import jax
import jax.numpy as jnp
import jax_md
import optax
import ray
import sys
from tqdm import tqdm


import jax_dna
import jax_dna.energy as jdna_energy
import jax_dna.energy.dna1 as dna1_energy
import jax_dna.input.toml as toml_reader
import jax_dna.input.tree as jdna_tree
import jax_dna.observables as jd_obs
import jax_dna.optimization.simulator as jdna_simulator
import jax_dna.optimization.objective as jdna_objective
import jax_dna.optimization.optimization as jdna_optimization
import jax_dna.simulators.oxdna as oxdna
import jax_dna.simulators.io as jdna_sio
import jax_dna.utils.types as jdna_types
import jax_dna.ui.loggers.jupyter as jupyter_logger
import jax_dna.ui.loggers.console as console_logger
from jax_dna.input import topology, trajectory

# Point jax_dna's oxDNA simulator at a local oxDNA build, located relative to
# the current working directory.
_oxdna_build = Path("../../oxDNA/build")
os.environ[oxdna.BUILD_PATH_ENV_VAR] = str(_oxdna_build.resolve())
os.environ[oxdna.BIN_PATH_ENV_VAR] = str((_oxdna_build / "bin" / "oxDNA").resolve())

# Enable 64-bit (double-precision) arrays in JAX.
jax.config.update("jax_enable_x64", True)

In [3]:
def tree_mean(trees: "tuple[jdna_types.PyTree, ...]") -> "jdna_types.PyTree":
    """Leaf-wise mean of one or more pytrees that share the same structure.

    Args:
        trees: A non-empty sequence of pytrees with identical tree structure.

    Returns:
        A pytree whose leaves are the mean of the corresponding leaves in
        ``trees``. If only one tree is given, it is returned unchanged.

    Raises:
        ValueError: If ``trees`` is empty.
    """
    n_trees = len(trees)
    if n_trees == 0:
        raise ValueError("tree_mean requires at least one pytree")
    if n_trees == 1:
        return trees[0]
    # `jax.tree.map` passes one leaf per tree. Summing the variadic leaves works
    # for any number of trees; a binary `operator.add` would only work for two.
    return jax.tree.map(lambda *leaves: sum(leaves) / n_trees, *trees)

In [4]:
# Runtime environment shipped to every Ray worker: JAX settings as environment
# variables, plus the local `jax_dna` package so workers run the same code as
# this notebook.
_ray_runtime_env = {
    "env_vars": {
        "JAX_ENABLE_X64": "True",
        "JAX_PLATFORM_NAME": "cpu",
        "RAY_DEBUG": "1",
    },
    "py_modules": [jax_dna],
}
ray.init(ignore_reinit_error=True, log_to_driver=True, runtime_env=_ray_runtime_env)

2025-01-17 17:05:51,372	INFO worker.py:1634 -- Connecting to existing Ray cluster at address: 172.30.118.26:6379...
2025-01-17 17:05:51,391	INFO worker.py:1810 -- Connected to Ray cluster. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
2025-01-17 17:05:51,457	INFO packaging.py:600 -- Creating a file package for local module '/home/ryanhausen/repos/jax-dna/examples/../jax_dna'.
2025-01-17 17:05:51,543	INFO packaging.py:392 -- Pushing file package 'gcs://_ray_pkg_fcbc78e9eebb24f6.zip' (1.59MiB) to Ray cluster...
2025-01-17 17:05:51,575	INFO packaging.py:405 -- Successfully pushed file package 'gcs://_ray_pkg_fcbc78e9eebb24f6.zip'.


0,1
Python version:,3.11.0
Ray version:,2.39.0
Dashboard:,http://127.0.0.1:8265


In [5]:
# Objective and simulator both log at DEBUG level and overwrite their log file
# on each run ("w" mode); only the destination file differs.
objective_logging_config = dict(
    level=logging.DEBUG,
    filename="objective.log",
    filemode="w",
)
simulator_logging_config = {**objective_logging_config, "filename": "simulator.log"}

In [6]:
# Optimization-loop settings; `n_steps` bounds the training loop below.
optimization_config = dict(n_steps=1000, batch_size=1)

# Default dna1 simulation and energy settings shipped with jax_dna.
simulation_config = toml_reader.parse_toml("../jax_dna/input/dna1/default_simulation.toml")
energy_config = toml_reader.parse_toml("../jax_dna/input/dna1/default_energy.toml")
kT = simulation_config["kT"]  # the objective later uses beta = 1 / kT

energy_fns = dna1_energy.default_energy_fns()
energy_configs = dna1_energy.default_configs()

# One parameter dict per energy term, aligned with `energy_configs`. Only the
# stacking term is optimized (every other term gets an empty dict), and
# `ss_stack_weights` is excluded from optimization. New dicts are built instead
# of deleting the key in place, so the configuration objects' own `opt_params`
# are not mutated as a side effect.
opt_params = [
    {name: value for name, value in ec.opt_params.items() if name != "ss_stack_weights"}
    if isinstance(ec, dna1_energy.StackingConfiguration)
    else {}
    for ec in energy_configs
]


geometry = energy_config["geometry"]
# The `com_to_*` geometry entries, forwarded as keyword arguments to
# Nucleotide.from_rigid_body.
_site_offsets = {
    key: geometry[key] for key in ("com_to_backbone", "com_to_hb", "com_to_stacking")
}
transform_fn = functools.partial(dna1_energy.Nucleotide.from_rigid_body, **_site_offsets)

# Factory: called with a parameter set, it returns the model's energy function.
energy_fn_builder_fn = jdna_energy.energy_fn_builder(
    energy_fns=energy_fns,
    energy_configs=energy_configs,
    transform_fn=transform_fn,
)

topology_fname = "../data/templates/simple-helix/sys.top"
top = topology.from_oxdna_file(topology_fname)


def energy_fn_builder(params: jdna_types.Params) -> typing.Callable:
    """Build a function that evaluates the model energy for every state of a trajectory.

    Args:
        params: Energy parameters handed to ``energy_fn_builder_fn``.

    Returns:
        A ``jax.vmap``-ed callable. It takes a trajectory-like pytree with a
        ``rigid_body`` attribute, batched along its leading axis, and returns
        the energy function's output for each state.
    """
    energy_fn = energy_fn_builder_fn(params)
    # Topology-derived inputs are the same for every state, so build them once
    # here rather than inside the vmapped function.
    seq = jnp.array(top.seq)
    bonded_neighbors = top.bonded_neighbors
    unbonded_neighbors = top.unbonded_neighbors.T

    # Named `state` (not `trajectory`) to avoid shadowing the imported
    # `trajectory` module.
    def state_energy(state):
        return energy_fn(
            state.rigid_body,
            seq=seq,
            bonded_neighbors=bonded_neighbors,
            unbonded_neighbors=unbonded_neighbors,
        )

    return jax.vmap(state_energy)


### Simulator 

In [7]:
input_dir = "../data/templates/simple-helix"

# oxDNA-backed simulator for the dna1 model. `n_build_threads` is the number of
# parallel make jobs used when oxDNA is rebuilt with updated parameters.
simulator = oxdna.oxDNASimulator(
    sim_type=jdna_types.oxDNASimulatorType.DNA1,
    input_dir=input_dir,
    energy_configs=energy_configs,
    logger_config=simulator_logging_config,
    n_build_threads=4,
)

cwd = Path(os.getcwd())
# Directory where `simulator_fn` saves the most recent trajectory.
output_dir = cwd / "basic_trajectory"
trajectory_loc = output_dir / "trajectory.pkl"
# `exist_ok=True` already makes this safe to re-run; a separate exists() check
# is redundant and racy.
output_dir.mkdir(parents=True, exist_ok=True)

def simulator_fn(
    params: jdna_types.Params,
    meta: jdna_types.MetaData,
) -> list[Path]:
    """Run one oxDNA simulation with ``params`` and persist its outputs.

    The energies of the sampled trajectory under ``params`` are saved as
    ``energies/energies-<n>.npy``, where ``n`` is the number of ``.npy`` files
    already in that directory, so successive calls do not overwrite each other.
    The trajectory itself is saved to ``trajectory_loc``.

    Args:
        params: Energy parameters for this simulation.
        meta: Simulator metadata (unused here).

    Returns:
        A one-element list holding the path of the saved trajectory.
    """
    traj = simulator.run(params)

    energies_dir = Path("energies")
    energies_dir.mkdir(parents=True, exist_ok=True)
    n_saved = len(list(energies_dir.glob("*.npy")))
    # Save inside `energies_dir`: the count above only sees files there, so
    # saving to the CWD always yielded index 0 and overwrote the previous file.
    jnp.save(energies_dir / f"energies-{n_saved}.npy", energy_fn_builder(params)(traj))

    jdna_tree.save_pytree(traj, trajectory_loc)
    return [trajectory_loc]

# Name under which the simulator exposes its output; the objective requests the
# same key.
obs_trajectory = "trajectory"

trajectory_simulator = jdna_simulator.BaseSimulator(
    name="oxdna-sim",
    fn=simulator_fn,
    exposes=[obs_trajectory],
    meta_data={},
)

### Objective

In [8]:
# Hydrogen-bonded base pairs (pairs of nucleotide indices) on which propeller
# twist is measured.
_prop_twist_pairs = jnp.array([[1, 14], [2, 13], [3, 12], [4, 11], [5, 10], [6, 9]])
prop_twist_fn = jd_obs.propeller.PropellerTwist(
    rigid_body_transform_fn=transform_fn,
    h_bonded_base_pairs=_prop_twist_pairs,
)

def prop_twist_loss_fn(
    traj: jax_md.rigid_body.RigidBody,
    weights: jnp.ndarray,
    energy_model: jdna_energy.base.ComposedEnergyFunction,
) -> tuple[float, tuple[tuple[str, typing.Any], dict]]:
    """Absolute error between the weighted propeller twist and the oxDNA target.

    Args:
        traj: Sampled states on which the propeller-twist observable is computed.
        weights: Per-state weights; the expected observable is their dot product
            with the per-state values (presumably DiffTRe reweighting weights).
        energy_model: Current energy model (unused by this loss).

    Returns:
        ``(loss, (("prop_twist", expected_prop_twist), {}))``, where ``loss`` is
        ``|expected_prop_twist - TARGETS["oxDNA"]|``.
    """
    obs = prop_twist_fn(traj)
    expected_prop_twist = jnp.dot(weights, obs)
    # |x| rather than sqrt(x**2): same value, but sqrt's derivative is infinite
    # at 0, so the old form produced a NaN gradient when the prediction exactly
    # matched the target. abs has a finite gradient everywhere.
    loss = jnp.abs(expected_prop_twist - jd_obs.propeller.TARGETS["oxDNA"])
    return loss, (("prop_twist", expected_prop_twist), {})


propeller_twist_objective = jdna_objective.DiffTReObjective(
    name="DiffTRe",
    required_observables=[obs_trajectory],
    needed_observables=[obs_trajectory],
    logging_observables=["loss", "prop_twist", "neff"],
    grad_or_loss_fn=prop_twist_loss_fn,
    energy_fn_builder=energy_fn_builder,
    opt_params=opt_params,
    trajectory_key=obs_trajectory,
    min_n_eff_factor=0.95,
    beta=jnp.array(1 / kT),
    # 0 => presumably no leading states are discarded as equilibration.
    # NOTE(review): earlier comments here said oxDNA already handles the periodic
    # steps, and described tossing the first 1000 of 30000 steps (print interval
    # 100), with a warning "DONT DO THIS FOR OXDNA". Confirm that 0 is the
    # intended value for oxDNA trajectories.
    n_equilibration_steps=0,
)

In [9]:
# Parameters of interest: single names, or groups of related names (presumably
# meant to be plotted together by a logger).
params_to_log = [
    "eps_stack_base",
    "eps_stack_kt_coeff",
    ["dr_low_stack", "dr_high_stack", "a_stack", "dr0_stack", "dr_c_stack"],
    ["theta0_stack_4", "delta_theta_star_stack_4", "a_stack_4"],
    ["theta0_stack_5", "delta_theta_star_stack_5", "a_stack_5"],
    ["theta0_stack_6", "delta_theta_star_stack_6", "a_stack_6"],
    ["neg_cos_phi1_star_stack", "a_stack_1"],
    ["neg_cos_phi2_star_stack", "a_stack_2"],
]
# The same names flattened into one list, in order.
params_list_flat = [
    name
    for entry in params_to_log
    for name in ((entry,) if isinstance(entry, str) else entry)
]

### Make logger

In [None]:
# Console metrics logger with log_dir="logs". For live plots inside the
# notebook, `jupyter_logger.JupyterLogger` can be swapped in here instead.
logger = console_logger.ConsoleLogger(log_dir="logs")

### Run optimization

In [11]:
# Display the target propeller twist that prop_twist_loss_fn measures the error against.
jd_obs.propeller.TARGETS["oxDNA"]

21.7

In [None]:
objectives = [propeller_twist_objective]
simulators = [trajectory_simulator]

# Single-objective, single-simulator optimizer driven by Adam.
opt = jdna_optimization.SimpleOptimizer(
    objective=objectives[0],
    simulator=simulators[0],
    optimizer=optax.adam(learning_rate=1e-3),
    logger=logger,
)
log_every = 10

objective = objectives[0]
for step in range(optimization_config["n_steps"]):
    # One optimization step: returns the new optimizer state and parameters.
    opt_state, opt_params = opt.step(opt_params)
    log_values = objective.logging_observables()

    if step % log_every == 0:
        for metric_name, metric_value in log_values:
            logger.log_metric(metric_name, metric_value, step=step)
        # Also log the target so it appears alongside prop_twist.
        logger.log_metric("target_ptwist", jd_obs.propeller.TARGETS["oxDNA"], step=step)

    opt = opt.post_step(optimizer_state=opt_state, opt_params=opt_params)

DEBUG:jax_dna.simulators.oxdna.oxdna:cmake_bin: /bin/cmake
DEBUG:jax_dna.simulators.oxdna.oxdna:make_bin: /bin/make
INFO:jax_dna.simulators.oxdna.oxdna:Updating oxDNA parameters
DEBUG:jax_dna.simulators.oxdna.oxdna:build_dir: /home/ryanhausen/repos/oxDNA/build
DEBUG:jax_dna.simulators.oxdna.oxdna:running cmake: std_out->/home/ryanhausen/repos/oxDNA/build/jax_dna.cmake.std.log, std_err->/home/ryanhausen/repos/oxDNA/build/jax_dna.cmake.err.log
DEBUG:jax_dna.simulators.oxdna.oxdna:cmake completed
DEBUG:jax_dna.simulators.oxdna.oxdna:running make with 4 processes: std_out->/home/ryanhausen/repos/oxDNA/build/jax_dna.make.std.log, std_err->/home/ryanhausen/repos/oxDNA/build/jax_dna.make.err.log
INFO:jax_dna.simulators.oxdna.oxdna:oxDNA binary rebuilt
INFO:jax_dna.simulators.oxdna.oxdna:oxDNA input file: ../data/templates/simple-helix/input
INFO:jax_dna.simulators.oxdna.oxdna:Starting oxDNA simulation
DEBUG:jax_dna.simulators.oxdna.oxdna:oxDNA std_out->../data/templates/simple-helix/oxdna.out

Step: 0, loss: 0.9820481545404647
Step: 0, prop_twist: 20.717951845459535
Step: 0, neff: 0.9999999999999996
Step: 0, target_ptwist: 21.7
Step: 10, loss: 0.8209214624624472
Step: 10, prop_twist: 20.879078537537552
Step: 10, neff: 0.9969394742960719
Step: 10, target_ptwist: 21.7
Step: 20, loss: 0.8405092127126608
Step: 20, prop_twist: 20.85949078728734
Step: 20, neff: 0.9978779769545042
Step: 20, target_ptwist: 21.7
