In [None]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [1]:
import sys
import os

sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../"))

# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["XLA_FLAGS"] = (
#     "--xla_disable_hlo_passes=constant_folding"  # this disables constant folding
# )
from desc import set_device
set_device("gpu")

In [2]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [3]:
import numpy as np
np.set_printoptions(linewidth=np.inf, precision=4, suppress=True, threshold=sys.maxsize)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [4]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

print_backend_info()

DESC version=0.15.0+518.g3f71da918.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using device: NVIDIA GeForce RTX 4080 Laptop GPU (id=0), with 10.90 GB available memory.


In [5]:
from desc.particles import *
from diffrax import *

In [6]:
def test_tracing_vacuum_tokamak():
    """Test particle tracing in a vacuum tokamak."""
    rmajor = 4.0
    rminor = 1.0
    ts = np.linspace(0, 1e-7, 10)
    R0 = rmajor + rminor / 2

    # Create a vacuum tokamak equilibrium with a FourierRZToroidalSurface
    surf = FourierRZToroidalSurface(
        R_lmn=np.array([rmajor, rminor]),
        modes_R=np.array([[0, 0], [1, 0]]),
        Z_lmn=np.array([0, -1]),
        modes_Z=np.array([[0, 0], [-1, 0]]),
    )
    eq = Equilibrium(surface=surf, L=8, M=8, N=0, Psi=3)
    eq.solve(verbose=1)

    particles = ManualParticleInitializerLab(R0=R0, phi0=0, Z0=0.0, xi0=0.9, E=1e3)
    model = VacuumGuidingCenterTrajectory(frame="flux")

    # Particle tracing compute the field on individual points as grid which
    # is not enough to compute iota profile. Instead find the iota profile before
    # and assign it to the equilibrium as a hack. For this test, not very
    # necessary since iota is 0.
    eq.iota = eq.get_profile("iota")

    # Initialize particles
    x0, args = particles.init_particles(model=model, field=eq)
    m, q, _ = args[0, :]
    # Ensure particles stay within the surface by bounds_R (not actually
    # needed here since the tracing time is chosen accordingly, but this
    # is the intended use case).
    rtz, vpar = trace_particles(
        field=eq,
        initializer=particles,
        model=model,
        ts=ts,
    )
    return rtz, vpar

In [7]:
def test_tracing_purely_toroidal_magnetic_field():
    """Test particle tracing within a purely toroidal magnetic field."""
    B0 = 1.0  # Constant magnetic field strength
    Rmajor = 3.0  # Major radius of the toroidal field
    R0 = 4.0  # Initial radial position of the particle
    ts = np.linspace(0, 1e-6, 100)
    # B_phi = B0 * Rmajor / r  # noqa : E800
    field = ToroidalMagneticField(B0=B0, R0=Rmajor)
    particles = ManualParticleInitializerLab(R0=R0, phi0=0, Z0=0, xi0=0.9, E=1e8)
    model = VacuumGuidingCenterTrajectory(frame="lab")
    x0, args = particles.init_particles(model=model, field=field)
    m, q, _ = args[0, :]
    rpz, vpar = trace_particles(
        field=field,
        initializer=particles,
        model=model,
        ts=ts,
    )
    return rpz, vpar

In [8]:
test_tracing_purely_toroidal_magnetic_field()

(Array([[[ 4.    ,  0.    ,  0.    ],
         [ 4.    ,  0.1573,  0.3048],
         [ 4.    ,  0.3146,  0.6095],
         [ 4.    ,  0.4719,  0.9143],
         [ 4.    ,  0.6292,  1.219 ],
         [ 4.    ,  0.7865,  1.5238],
         [ 4.    ,  0.9438,  1.8285],
         [ 4.    ,  1.1011,  2.1333],
         [ 4.    ,  1.2584,  2.438 ],
         [ 4.    ,  1.4157,  2.7428],
         [ 4.    ,  1.573 ,  3.0475],
         [ 4.    ,  1.7303,  3.3523],
         [ 4.    ,  1.8875,  3.657 ],
         [ 4.    ,  2.0448,  3.9618],
         [ 4.    ,  2.2021,  4.2665],
         [ 4.    ,  2.3594,  4.5713],
         [ 4.    ,  2.5167,  4.876 ],
         [ 4.    ,  2.674 ,  5.1808],
         [ 4.    ,  2.8313,  5.4855],
         [ 4.    ,  2.9886,  5.7903],
         [ 4.    ,  3.1459,  6.0951],
         [ 4.    ,  3.3032,  6.3998],
         [ 4.    ,  3.4605,  6.7046],
         [ 4.    ,  3.6178,  7.0093],
         [ 4.    ,  3.7751,  7.3141],
         [ 4.    ,  3.9324,  7.6188],
         [ 4

In [10]:
test_tracing_vacuum_tokamak()

Building objective: force
Precomputing transforms
Building objective: lcfs R
Building objective: lcfs Z
Building objective: fixed Psi
Building objective: fixed pressure
Building objective: fixed current
Building objective: fixed sheet current
Building objective: self_consistency R
Building objective: self_consistency Z
Building objective: lambda gauge
Building objective: axis R self consistency
Building objective: axis Z self consistency
Number of parameters: 48
Number of objectives: 162

Starting optimization
Using method: lsq-exact
`gtol` condition satisfied. (gtol=1.00e-08)
         Current function value: 2.988e-16
         Total delta_x: 2.391e-01
         Iterations: 13
         Function evaluations: 18
         Jacobian evaluations: 14
                                                                 Start  -->   End
Total (sum of squares):                                      5.940e-02  -->   2.988e-16, 
Maximum absolute Force error:                                7.152e+05  -->



(Array([[[0.5018, 0.    , 0.    ],
         [0.5018, 6.2832, 0.0005],
         [0.5018, 6.2832, 0.001 ],
         [0.5018, 6.2832, 0.0015],
         [0.5018, 6.2832, 0.0019],
         [0.5018, 6.2832, 0.0024],
         [0.5018, 6.2832, 0.0029],
         [0.5018, 6.2831, 0.0034],
         [0.5018, 6.2831, 0.0039],
         [0.5018, 6.2831, 0.0044]]], dtype=float64),
 Array([[[196975.1969],
         [196975.1969],
         [196975.1969],
         [196975.1969],
         [196975.1969],
         [196975.1969],
         [196975.1969],
         [196975.1969],
         [196975.1969],
         [196975.1969]]], dtype=float64))

In [13]:
B0 = 1.0  # Constant magnetic field strength
Rmajor = 3.0  # Major radius of the toroidal field
R0 = 4.0  # Initial radial position of the particle
ts = np.linspace(0, 1e-6, 100)
# B_phi = B0 * Rmajor / r  # noqa : E800
field = ToroidalMagneticField(B0=B0, R0=Rmajor)
particles = ManualParticleInitializerLab(R0=R0, phi0=0, Z0=0, xi0=0.9, E=1e8)
model = VacuumGuidingCenterTrajectory(frame="lab")
x0, args = particles.init_particles(model=model, field=field)
m, q, _ = args[0, :]

jax.make_jaxpr(model.vf)(0, x0, (args[0, :], field, field.params_dict, {}))

let _where = { [34;1mlambda [39;22m; a[35m:bool[][39m b[35m:i64[][39m c[35m:f64[][39m. [34;1mlet
    [39;22md[35m:f64[][39m = convert_element_type[new_dtype=float64 weak_type=False] b
    e[35m:f64[][39m = select_n a c d
  [34;1min [39;22m(e,) } in
{ [34;1mlambda [39;22m; f[35m:i64[][39m g[35m:f64[1,4][39m h[35m:f64[3][39m i[35m:f64[][39m j[35m:f64[][39m k[35m:f64[1][39m l[35m:f64[1][39m. [34;1mlet
    [39;22mm[35m:f64[4][39m = pjit[
      name=vf
      jaxpr={ [34;1mlambda [39;22m; f[35m:i64[][39m g[35m:f64[1,4][39m h[35m:f64[3][39m i[35m:f64[][39m j[35m:f64[][39m k[35m:f64[1][39m l[35m:f64[1][39m. [34;1mlet
          [39;22mn[35m:f64[4][39m = squeeze[dimensions=(0,)] g
          o[35m:f64[1][39m = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] h
          p[35m:f64[][39m = squeeze[dimensions=(0,)] o
          q[35m:f64[1][39m = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] h
          r[35m:f64[][3

In [16]:
rmajor = 4.0
rminor = 1.0
ts = np.linspace(0, 1e-7, 10)
R0 = rmajor + rminor / 2

# Create a vacuum tokamak equilibrium with a FourierRZToroidalSurface
surf = FourierRZToroidalSurface(
    R_lmn=np.array([rmajor, rminor]),
    modes_R=np.array([[0, 0], [1, 0]]),
    Z_lmn=np.array([0, -1]),
    modes_Z=np.array([[0, 0], [-1, 0]]),
)
eq = Equilibrium(surface=surf, L=8, M=8, N=0, Psi=3)
eq.solve(verbose=0)

particles = ManualParticleInitializerLab(R0=R0, phi0=0, Z0=0.0, xi0=0.9, E=1e3)
model = VacuumGuidingCenterTrajectory(frame="flux")

# Particle tracing compute the field on individual points as grid which
# is not enough to compute iota profile. Instead find the iota profile before
# and assign it to the equilibrium as a hack. For this test, not very
# necessary since iota is 0.
eq.iota = eq.get_profile("iota")

pr_out = jax.make_jaxpr(model.vf)(0, x0, (args[0, :], eq, eq.params_dict, {}))



In [17]:
print(pr_out)

let atleast_2d = { lambda ; a:f64[3]. let
    b:f64[1,3] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] a
  in (b,) } in
let _where = { lambda ; c:bool[1] d:f64[] e:f64[1]. let
    f:f64[] = convert_element_type[new_dtype=float64 weak_type=False] d
    g:f64[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] f
    h:f64[1] = select_n c e g
  in (h,) } in
let atleast_2d1 = { lambda ; i:f64[1]. let
    j:f64[1,1] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 1)
      sharding=None
    ] i
  in (j,) } in
let zernike_radial = { lambda ; k:f64[1,1] l:i64[25] m:i64[25]. let
    n:i64[25] = abs m
    o:f64[25] = convert_element_type[new_dtype=float64 weak_type=False] n
    p:f64[25] = convert_element_type[new_dtype=float64 weak_type=False] l
    q:f64[25] = sub p o
    r:f64[25] = pjit[
      name=floor_divide
      jaxpr={ lambda ; q:f64[25] s:i64[]. let
          t:f64[] = co