In [1]:
import jax
import jax.numpy as jnp

# Persistent compilation cache settings, applied in one pass.
# Per the JAX persistent-cache documentation, a minimum entry size of -1
# removes the size restriction and a minimum compile time of 0 lets every
# compilation be written to the cache directory.
_jax_cache_options = {
    "jax_compilation_cache_dir": "../../jax-caches",
    "jax_persistent_cache_min_entry_size_bytes": -1,
    "jax_persistent_cache_min_compile_time_secs": 0,
}
for _option_name, _option_value in _jax_cache_options.items():
    jax.config.update(_option_name, _option_value)

In [2]:
import sys
import os

# Resolve both locations against the current working directory first, then
# register them: the working directory goes to the front of the search path
# (highest priority), the directory two levels up goes to the back.
# NOTE(review): "../../" is presumably the repository root — confirm.
_working_dir = os.path.abspath(".")
_two_levels_up = os.path.abspath("../../")
sys.path.insert(0, _working_dir)
sys.path.append(_two_levels_up)

In [3]:
import numpy as np
# Print arrays on a single unwrapped line, without summarising long arrays,
# using 4 decimal places and no scientific notation for small values.
np.set_printoptions(
    precision=4,
    suppress=True,
    linewidth=np.inf,
    threshold=sys.maxsize,
)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [4]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic

DESC version 0.13.0+1106.g528c17c1d.dirty,using JAX backend, jax version=0.4.37, jaxlib version=0.4.36, dtype=float64
Using device: CPU, with 11.44 GB available memory


In [None]:
def create_parallel_force_obj(eq, num_device, rho_min=0.01, rho_max=1.0):
    """Build one ``ForceBalance`` objective per flux surface.

    ``num_device`` values of rho are spaced evenly over ``[rho_min, rho_max]``
    and each one gets its own single-surface ``LinearGrid``.  Surface ``i``
    (0-based) uses poloidal resolution ``int(eq.M_grid * i / num_device)``,
    toroidal resolution ``eq.N_grid`` and ``eq.NFP`` field periods.

    Parameters
    ----------
    eq : object
        Equilibrium-like object; only ``M_grid``, ``N_grid`` and ``NFP`` are
        read here, and the object itself is handed to ``ForceBalance``.
    num_device : int
        Number of surfaces, and therefore objectives, to create.
        (Presumably one objective per compute device — confirm with caller.)
    rho_min, rho_max : float, optional
        End points of the radial range.  The defaults (0.01 and 1.0) are the
        values that were previously hard-coded.

    Returns
    -------
    tuple
        The ``ForceBalance`` objectives, ordered from ``rho_min`` to
        ``rho_max``.
    """
    rhos = jnp.linspace(rho_min, rho_max, num_device)
    n_surfaces = len(rhos)
    # NOTE(review): with the i / n_surfaces scaling the first surface gets
    # M=0 and the last one never reaches eq.M_grid.  Confirm this is intended
    # (as opposed to (i + 1) / n_surfaces) before relying on the resolution.
    return tuple(
        ForceBalance(
            eq,
            grid=LinearGrid(
                rho=rhos[i],
                M=int(eq.M_grid * i / n_surfaces),
                N=eq.N_grid,
                NFP=eq.NFP,
            ),
        )
        for i in range(num_device)
    )

In [6]:
# Load the HELIOTRON example.
eq = get("HELIOTRON")

# Ten per-surface ForceBalance objectives, wrapped in one ObjectiveFunction.
objs = create_parallel_force_obj(eq, num_device=10)
obj = ObjectiveFunction(objs)

# Constraint set produced by get_fixed_boundary_constraints for this `eq`.
cons = get_fixed_boundary_constraints(eq)

In [None]:
# Two-iteration solve with the multi-objective function and the constraints
# from the previous cell, at verbosity level 3.
# NOTE(review): the zero ftol/gtol/xtol presumably keep the tolerance checks
# from stopping the run before maxiter is reached — confirm.
eq.solve(
    objective=obj,
    constraints=cons,
    maxiter=2,
    ftol=0,
    gtol=0,
    xtol=0,
    verbose=3,
)

In [None]:
# Plot one single-surface grid per rho value.  The grid arguments follow the
# same formula as create_parallel_force_obj with 10 surfaces.
rhos = jnp.linspace(0.01, 1.0, 10)
for i in range(rhos.size):
    grid = LinearGrid(
        rho=rhos[i],
        M=int(eq.M_grid * i / rhos.size),
        N=eq.N_grid,
        NFP=eq.NFP,
    )
    plot_grid(grid)

In [7]:
# Rebind `eq` to a newly loaded HELIOTRON example; an earlier cell called
# eq.solve on the previous `eq`.
eq = get("HELIOTRON")
# Library-provided per-surface force-balance objectives.
# NOTE(review): the positional 10 is presumably the device/objective count and
# check_device=False presumably skips a device-availability check — confirm
# against the DESC API.
objs = get_parallel_forcebalance(eq, 10, check_device=False)

In [8]:
# Wrap the objectives from the previous cell and build them before timing, so
# the build step is not part of the measured statement.
obj = ObjectiveFunction(objs)
obj.build()
# Timed statement: obj.x(eq) followed by jac_scaled_error on its result.
%timeit obj.jac_scaled_error(obj.x(eq))

Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
21.5 s ± 475 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
# Evaluate jac_scaled_error once outside %timeit so the result can be kept in
# `J`, then display its shape as the cell output.
J = obj.jac_scaled_error(obj.x(eq))
J.shape

(17056, 1977)

In [9]:
# Comparison case: a single ForceBalance objective on one LinearGrid that
# takes L, M, N and NFP from the equilibrium's own L_grid/M_grid/N_grid/NFP.
eq = get("HELIOTRON")
grid = LinearGrid(L=eq.L_grid, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP)
obji = ObjectiveFunction(ForceBalance(eq, grid=grid))
# Build before timing so the build step is not part of the measured statement.
obji.build()
%timeit obji.jac_scaled_error(obji.x(eq))

Building objective: force
Precomputing transforms
47.7 s ± 1.89 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
# Same inspection for the single-objective case.
# NOTE: this rebinds `J`, replacing the value computed from `obj` in the
# earlier cell.
J = obji.jac_scaled_error(obji.x(eq))
J.shape

(35594, 1977)