In [None]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [1]:
import sys
import os

# Put the current working directory at the front of the import search path so
# modules found there take precedence over any installed copies.
sys.path.insert(0, os.path.abspath("."))
# Add the parent directory at the end of the search path as a fallback.
# NOTE(review): presumably this is what makes the local `desc` checkout
# importable when the notebook lives in a subfolder of the repo -- confirm.
sys.path.append(os.path.abspath("../"))

# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# from desc import set_device
# set_device("gpu")

In [2]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [3]:
import numpy as np
np.set_printoptions(linewidth=np.inf, precision=4, suppress=True, threshold=sys.maxsize)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [4]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

print_backend_info()

DESC version=0.14.2+633.g5e938954b.dirty.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using device: CPU, with 10.84 GB available memory.


In [5]:
from diffrax import Euler, Heun
from desc.particles import *

In [None]:
def custom_B(coords, params):
    """Evaluate a custom field whose third Cartesian component is linear in X.

    Parameters
    ----------
    coords : array
        Evaluation points. They are converted with ``rpz2xyz`` and column 1 is
        passed on as ``phi``, so this is presumably an (n, 3) array of
        (R, phi, Z) positions -- TODO confirm with the caller.
    params : sequence of two values
        ``(B0, dB)``: the constant part of the field and the coefficient that
        multiplies the Cartesian X coordinate.

    Returns
    -------
    array
        The field after ``xyz2rpz_vec`` has been applied; same shape as
        ``coords`` before that conversion.
    """
    cartesian = rpz2xyz(coords)
    # Only X is needed; the three-way unpack is kept so that input without
    # exactly three columns still fails here.
    x_cart, _, _ = cartesian.T
    B0, dB = params
    # All components are zero except index 2, which is B0 + dB * X.
    field_xyz = jnp.zeros_like(coords).at[:, 2].set(B0 + dB * x_cart)
    return xyz2rpz_vec(field_xyz, phi=coords[:, 1])

In [None]:
from desc.particles import *
# Analytic toroidal field used as the background for the lab-frame trace.
field = ToroidalMagneticField(B0=1.0, R0=3.0)

# Two particles that differ only in their starting R value.
R0 = np.array([3.1, 4.0])
particles = ManualParticleInitializerLab(
    R0=R0,
    phi0=0,
    Z0=0,
    xi0=1,
    E=1e8,
)
model = VacuumGuidingCenterTrajectory(frame="lab")

# 1000 evenly spaced output times between 0 and 1e-6.
ts = np.linspace(0, 1e-6, 1000)

plot_particle_trajectories(
    field=field,
    model=model,
    initializer=particles,
    ts=ts,
    return_data=False,
)

In [None]:
# Boundary surface given by two Fourier coefficients each for R and Z.
# Every listed mode has a zero in its second column.
surf = FourierRZToroidalSurface(
    R_lmn=np.array([4, 1]),
    modes_R=np.array([[0, 0], [1, 0]]),
    Z_lmn=np.array([0, -1]),
    modes_Z=np.array([[0, 0], [-1, 0]]),
)

# Build an equilibrium on that surface and solve it with verbose output.
eq = Equilibrium(surface=surf, M=4, N=0, Psi=3)
eq.solve(verbose=3)

# Plot |B| for the solved equilibrium.
plot_section(eq, "|B|")

In [None]:
from diffrax import Euler, Heun
from desc.particles import *
# Two starting R values; the equilibrium is passed to the initializer as well.
R0 = np.array([4.2, 4.3])
particles = ManualParticleInitializerLab(
    R0=R0,
    phi0=0,
    Z0=0,
    xi0=1,
    E=1e8,
    eq=eq,
)
model = VacuumGuidingCenterTrajectory(frame="flux")

# 100 evenly spaced output times between 0 and 1e-9.
ts = np.linspace(0, 1e-9, 100)

# NOTE(review): max_steps is passed as the float 1e10 -- confirm that the
# callee (and the underlying solver) accepts a non-integer step limit.
plot_particle_trajectories(
    field=eq,
    model=model,
    initializer=particles,
    ts=ts,
    return_data=False,
    solver=Euler(),
    min_step_size=1e-20,
    max_steps=1e10,
)

In [6]:
def f1(x, eq, m, q, mu):
    """Time derivative of the guiding-center state in (rho, theta, zeta, vpar).

    A grid is built from the position part of ``x``, the field quantities are
    computed on it from ``eq``, and the guiding-center velocity is projected
    onto ``e^rho``, ``e^theta`` and ``e^zeta``.

    Parameters
    ----------
    x : array
        State unpacked as ``(rho, theta, zeta, vpar)``.
    eq : object
        Passed to ``get_transforms``, ``get_profiles`` and ``compute_fun``; its
        ``params_dict`` is used as the parameters. Presumably a DESC
        ``Equilibrium`` -- confirm.
    m, q : scalar
        Enter only through the ratio ``m / q`` in the drift term; presumably
        particle mass and charge -- confirm units with the caller.
    mu : scalar
        Multiplies ``|B|`` in the drift term and ``b . grad(|B|)`` in the
        parallel equation. NOTE(review): presumably the magnetic moment per
        unit mass -- confirm the definition.

    Returns
    -------
    array
        ``[rhodot, thetadot, zetadot, vpardot]`` reshaped to ``x.shape`` and
        then squeezed.
    """
    rho, theta, zeta, vpar = x

    # Grid made from the position entries, with spacing set to zeros.
    nodes = jnp.array([rho, theta, zeta]).T
    grid = Grid(nodes, spacing=jnp.zeros((3,)), jitable=True, sort=False)

    data_keys = ["B", "|B|", "grad(|B|)", "e^rho", "e^theta", "e^zeta", "b"]
    transforms = get_transforms(data_keys, eq, grid, jitable=True)
    profiles = get_profiles(data_keys, eq, grid)
    data = compute_fun(eq, data_keys, eq.params_dict, transforms, profiles)

    b_hat = data["b"]
    mod_B = data["|B|"]
    grad_mod_B = data["grad(|B|)"]

    # derivative of the guiding center position in R, phi, Z coordinates:
    # parallel streaming plus a drift along b x grad(|B|)
    drift_coeff = (m / q / mod_B**2) * ((mu * mod_B) + vpar**2)
    velocity = vpar * b_hat + drift_coeff * cross(b_hat, grad_mod_B)

    # Project the velocity onto each of the three basis vectors in turn.
    coord_rates = [
        dot(velocity, data[basis_key])
        for basis_key in ("e^rho", "e^theta", "e^zeta")
    ]
    vpar_rate = -mu * dot(b_hat, grad_mod_B)

    return jnp.array([*coord_rates, vpar_rate]).reshape(x.shape).squeeze()

def f2(x, field, m, q, mu, **kwargs):
    """Time derivative of the guiding-center state using a magnetic field object.

    Parameters
    ----------
    x : array
        State vector; the last entry is ``vpar`` and the preceding entries are
        the position handed to ``field.compute_magnetic_field``.
    field : object
        Must provide ``compute_magnetic_field(coords, **kwargs)``.
    m, q : scalar
        Enter only through the ratio ``m / q`` in the drift term; presumably
        particle mass and charge -- confirm units with the caller.
    mu : scalar
        Multiplies ``|B|`` in the drift term and ``b . grad(|B|)`` in the
        parallel equation. NOTE(review): presumably the magnetic moment per
        unit mass -- confirm the definition.
    **kwargs
        Forwarded unchanged to every ``field.compute_magnetic_field`` call.

    Returns
    -------
    array
        Position rates followed by ``vpardot``, reshaped to ``x.shape`` and
        then squeezed.
    """
    coord, vpar = x[:-1], x[-1]

    def mod_B_at(position):
        """Return the squeezed field magnitude at ``position``."""
        B_here = field.compute_magnetic_field(position, **kwargs)
        return jnp.linalg.norm(B_here, axis=-1).squeeze()

    # magnetic field vector in R, phi, Z coordinates
    B_vec = field.compute_magnetic_field(coord, **kwargs)
    mod_B = jnp.linalg.norm(B_vec, axis=-1)
    b_hat = B_vec / mod_B

    # Derivative of |B| with respect to the position entries.
    grad_mod_B = Derivative(mod_B_at, mode="grad")(coord)
    # factor of R from grad in cylindrical coordinates: entry 1 is divided by
    # coord[0] through safediv
    grad_mod_B = grad_mod_B.at[1].set(safediv(grad_mod_B[1], coord[0]))

    drift_coeff = m / q / mod_B**2 * (mu * mod_B + vpar**2)
    velocity = vpar * b_hat + drift_coeff * cross(b_hat, grad_mod_B)

    # Promote to 2-D so its transpose can be stacked next to the velocity.
    vpar_rate = jnp.atleast_2d(-mu * dot(b_hat, grad_mod_B))
    return jnp.hstack([velocity, vpar_rate.T]).reshape(x.shape).squeeze()

In [7]:
# Rebind `eq` to the object returned by get("precise_QA"); this replaces any
# `eq` created by an earlier cell.
# NOTE(review): `get` presumably comes from `desc.examples` -- confirm.
eq = get("precise_QA")
# Rebind `field` to the object loaded from the test-inputs folder. The path is
# relative, so it depends on the notebook's working directory.
field = desc.io.load("../tests/inputs/precise_QA_helical_coils.h5")

In [8]:
# 2000 evenly spaced starting rho values between 0.1 and 0.99.
r = jnp.linspace(0.1, 0.99, 2000)

# Every particle starts at theta = zeta = 0 and shares the same xi0 and E;
# only rho0 varies.
particles = ManualParticleInitializerFlux(
    rho0=r,
    theta0=0.0,
    zeta0=0.0,
    xi0=1.0,
    E=1e3,
    eq=eq,
)

# One trajectory model per frame, so both can be initialised from `particles`.
model = VacuumGuidingCenterTrajectory(frame="flux")
model_lab = VacuumGuidingCenterTrajectory(frame="lab")

In [9]:
# Initial states and extra arguments for the flux-frame model, built against
# the equilibrium. `args` is unpacked as (m, q, mu) in a later cell.
x0, args = particles.init_particles(model, eq)

In [10]:
# Initial states and extra arguments for the lab-frame model, built against
# the loaded `field` object.
# NOTE(review): `args_lab` is not read by any cell visible here; the later
# cells take m, q, mu from the flux-frame `args` -- confirm that is intended.
x0_lab, args_lab = particles.init_particles(model_lab, field)



In [11]:
# Take entry 0 of each of the three items in `args`, so m, q and mu hold the
# first element of each.
# NOTE(review): the same values are then used for every row of x0 / x0_lab in
# the vmapped calls -- confirm that is intended for mu.
m, q, mu = (per_particle[0] for per_particle in args)

In [12]:
def intf1(x):
    """Evaluate ``f1`` at ``x`` with the notebook's ``eq``, ``m``, ``q``, ``mu``.

    Those names are looked up when the function is called, not when it is
    defined.
    """
    return f1(x, eq, m, q, mu)


def intf2(x):
    """Evaluate ``f2`` at ``x`` with the notebook's ``field``, ``m``, ``q``, ``mu``.

    Those names are looked up when the function is called, not when it is
    defined.
    """
    return f2(x, field, m, q, mu)

In [13]:
# Map intf1 over the leading axis of x0, giving one f1 evaluation per entry.
s1 = jax.vmap(intf1)(x0)

In [14]:
# Map intf2 over the leading axis of x0_lab, giving one f2 evaluation per entry.
s2 = jax.vmap(intf2)(x0_lab)