In [1]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "../jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [2]:
import os
import sys

# Put the notebook's own directory at the front of the module search path and
# its parent directory (presumably the repository root providing `desc`) at
# the back, so local modules win and the package is still importable.
sys.path[:0] = [os.path.abspath(".")]
sys.path.extend([os.path.abspath("../")])

In [3]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [4]:
from desc.backend import jax

# jax.devices("cpu")

In [5]:
import numpy as np
# Print arrays in full: four decimals, fixed-point notation (no scientific
# notation for tiny values), no line wrapping, and no "..." summarization.
np.set_printoptions(
    precision=4,
    suppress=True,
    linewidth=np.inf,
    threshold=sys.maxsize,
)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [6]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic

# Record the DESC/JAX versions, float dtype and available hardware so the
# results in this notebook can be tied to a specific environment.
print_backend_info()

DESC version=0.13.0+1687.gf75ae6abf.dirty.
Using JAX backend: jax version=0.5.0, jaxlib version=0.5.0, dtype=float64.
CPU Info:  13th Gen Intel(R) Core(TM) i5-1335U CPU with 8.30 GB available memory


In [10]:
# Build an effective-ripple objective on a reduced-resolution W7-X equilibrium,
# wrap it in a ProximalProjection with force balance as the constraint, and
# evaluate the Jacobian method named by ``method`` at the current state.
method = "jac_scaled_error"  # name of the ProximalProjection method to call
spline = False
num_transit = 10

eq = desc.examples.get("W7-X")
# Lower spectral (L, M, N) and grid resolution so this debug run stays cheap.
eq.change_resolution(L=6, M=6, N=6, L_grid=12, M_grid=12, N_grid=12)

objective = ObjectiveFunction(
    EffectiveRipple(
        eq,
        grid=LinearGrid(
            rho=[0.4, 1.0], M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=False
        ),
        num_transit=num_transit,
        num_well=10 * num_transit,
        num_quad=16,
        spline=spline,
    ),
)
objective.build()
constraint = ObjectiveFunction(
    ForceBalance(eq),
)
prox = ProximalProjection(objective, constraint, eq)
# The equilibrium was already handed to the constructor above; build()'s first
# positional parameter is ``use_jit``, so ``eq`` must not be passed here.
prox.build()
x = prox.x(eq)
J = getattr(prox, method)(x, prox.constants).block_until_ready()
J.shape



Building objective: Effective ripple
Building objective: force
Precomputing transforms


(2, 184)

In [8]:
# Inspect which derivative mode the wrapped constraint (presumably the
# ForceBalance ObjectiveFunction passed to ProximalProjection) ended up using;
# the debug helper below branches on each sub-objective's ``_deriv_mode``.
prox._constraint._deriv_mode

'batched'

In [None]:
# TODO: this is for debugging purposes, must be deleted before merging!
# loops over objectives without using mpi
@jit_if_not_parallel
def _proximal_jvp_blocked_test(objective, vgs, xgs, op):
    """Debug variant of the blocked proximal JVP that loops over sub-objectives.

    For every sub-objective in ``objective`` the state vectors ``xgs`` and
    tangent vectors ``vgs`` of the things it depends on are gathered, the
    product of its ``op`` Jacobian with those tangents is formed, and the
    per-objective results are stacked.

    Parameters
    ----------
    objective : ObjectiveFunction
        Container whose ``objectives`` and ``constants`` are iterated in
        lockstep; ``_things_per_objective_idx[k]`` gives the indices into
        ``xgs``/``vgs`` used by the k-th sub-objective.
    vgs : sequence
        Tangent vectors, indexed by thing.
    xgs : sequence
        State vectors, indexed by thing.
    op : str
        Suffix selecting ``jac_<op>`` or ``jvp_<op>`` on each sub-objective.

    Returns
    -------
    out : array
        Negated concatenation of the per-objective products (``pconcat`` is
        used instead of ``jnp.concatenate`` in the multi-device case).
    """
    pieces = []
    pairs = zip(objective.objectives, objective.constants)
    for pos, (sub_obj, sub_const) in enumerate(pairs):
        # TODO: this is for debugging purposes, must be deleted before merging!
        if objective._is_multi_device:
            print(f"This should run on GPU id:{sub_obj._device_id}")
        thing_ids = objective._things_per_objective_idx[pos]
        x_sel = [xgs[j] for j in thing_ids]
        v_sel = [vgs[j] for j in thing_ids]
        if objective._is_multi_device:  # pragma: no cover
            # Inputs to a jitted call must live on one device, so move this
            # sub-objective's inputs onto the sub-objective's device first.
            x_sel = jax.device_put(x_sel, sub_obj._device)
            v_sel = jax.device_put(v_sel, sub_obj._device)
        # Both selections must be non-empty and the same length.
        assert 0 < len(x_sel) == len(v_sel)
        if sub_obj._deriv_mode != "rev":
            jvp_fn = getattr(sub_obj, "jvp_" + op)
            piece = jvp_fn(list(v_sel), x_sel, constants=sub_const).T
        else:
            # The sub-objective might not allow forward mode, so form the full
            # reverse-mode Jacobian and contract it with the tangents by hand.
            # Slightly wasteful, but in rev mode dim_f is usually << dim_x.
            jacs = getattr(sub_obj, "jac_" + op)(*x_sel, constants=sub_const)
            piece = jnp.array([jac @ v.T for jac, v in zip(jacs, v_sel)]).sum(axis=0)
        pieces.append(piece)
    stacked = (
        pconcat(pieces) if objective._is_multi_device else jnp.concatenate(pieces)
    )
    return -stacked