In [None]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [1]:
import os
import sys

# Put this notebook's directory first on the import path (so local modules
# take precedence) and make the parent directory importable as a fallback.
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../"))

# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["XLA_FLAGS"] = (
#     "--xla_disable_hlo_passes=constant_folding "  # this disables constant folding
#     # "--xla_cpu_use_thunk_runtime=false "
# )
# from desc import set_device
# set_device("gpu")

In [2]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [3]:
import numpy as np
np.set_printoptions(linewidth=np.inf, precision=4, suppress=True, threshold=sys.maxsize)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [4]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

# Report the DESC/JAX versions and the active device; the compile/run timings
# recorded later in this notebook depend on them.
print_backend_info()

DESC version=0.15.0+533.g89d0c704c.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using device: CPU, with 17.81 GB available memory.


In [5]:
from desc.particles import *
from diffrax import *

In [None]:
class VacuumGuidingCenterTrajectoryTest(VacuumGuidingCenterTrajectory):
    """Scratch variant of ``VacuumGuidingCenterTrajectory`` for the flux frame.

    The parent class builds its own grid and transforms for each evaluation.
    This subclass instead reuses a caller-supplied ``transforms`` dict built on
    a one-node grid. Before calling ``compute_fun``, it overwrites that grid's
    node in place with the current particle position.

    This appears to be an experiment in avoiding per-step setup work. It is
    not a general replacement for the parent class.
    """

    def _compute_flux_coordinates(self, x, eq, params, m, q, mu, **kwargs):
        """Right-hand side of the guiding-center ODE in flux coordinates.

        Parameters
        ----------
        x : array-like, 4 entries
            State ``[xp, yp, zeta, vpar]``. ``(xp, yp) = (rho cos(theta),
            rho sin(theta))`` are Cartesian-like coordinates that stay regular
            at rho = 0.
        eq : Equilibrium
            Equilibrium whose ``iota`` is used unless overridden.
        params : dict
            Equilibrium parameters forwarded to ``compute_fun``.
        m, q : float
            Particle mass and charge.
        mu : float
            Magnetic moment (presumably ``m v_perp^2 / (2 |B|)`` -- confirm
            against the parent class).
        **kwargs
            transforms : dict, required
                Transforms whose ``"grid"`` entry holds exactly one node. The
                grid and the ``"R"``, ``"Z"``, ``"L"`` transforms are mutated
                in place.
            iota : Profile, optional
                Rotational transform profile; defaults to ``eq.iota``.

        Returns
        -------
        dxdt : jax.Array
            ``[xpdot, ypdot, zetadot, vpardot]``, reshaped like ``x`` and
            squeezed.

        Raises
        ------
        ValueError
            If ``transforms`` is missing or its grid does not hold one node.
        """
        xp, yp, zeta, vpar = x
        rho = jnp.sqrt(xp**2 + yp**2)
        theta = jnp.arctan2(yp, xp)
        # compute functions are not correct for very small rho
        rho = jnp.where(rho < 1e-6, 1e-6, rho)

        iota = kwargs.get("iota", None)
        transforms = kwargs.get("transforms", None)
        if transforms is None:
            raise ValueError(
                "VacuumGuidingCenterTrajectoryTest needs a 'transforms' dict "
                "(built on a single-node grid) passed through the model args."
            )
        grid = transforms["grid"]
        nodes = jnp.array([[rho, theta, zeta]])
        # Shapes are static under JAX tracing, so this plain Python check also
        # works inside jit; unlike ``assert`` it is not stripped by ``python -O``.
        if grid._nodes.shape != nodes.shape:
            raise ValueError(
                f"transforms['grid'] must have nodes of shape {nodes.shape}, "
                f"got {grid._nodes.shape}."
            )
        # HACK: move the shared grid to the current position and rebuild the
        # R/Z/L transforms on it. This mutates the caller's ``transforms``.
        # NOTE(review): under jit, grid._nodes holds a tracer after this call;
        # reusing ``transforms`` outside that trace may raise a leaked-tracer
        # error -- confirm.
        grid._nodes = nodes
        for key in ("R", "Z", "L"):
            transforms[key]._built = False
            transforms[key]._grid = grid
            transforms[key].build()

        data_keys = [
            "B",
            "|B|",
            "grad(|B|)",
            "e^rho",
            "e^theta",
            "e^zeta",
            "b",
        ]
        profiles = {"iota": eq.iota if iota is None else iota}
        data = compute_fun(eq, data_keys, params, transforms, profiles)

        # Derivative of the guiding center position in R, phi, Z coordinates.
        # It is parallel streaming plus the drift term
        # (m / (q |B|^2)) (mu |B| / m + vpar^2) b x grad|B|.
        Rdot = vpar * data["b"] + (
            (m / q / data["|B|"] ** 2)
            * ((mu * data["|B|"] / m) + vpar**2)
            * cross(data["b"], data["grad(|B|)"])
        )
        # Dotting with e^rho, e^theta, e^zeta gives the rates of change of
        # rho, theta and zeta.
        rhodot = dot(Rdot, data["e^rho"])
        thetadot = dot(Rdot, data["e^theta"])
        zetadot = dot(Rdot, data["e^zeta"])
        # Chain rule for xp = rho cos(theta), yp = rho sin(theta).
        xpdot = rhodot * jnp.cos(theta) - rho * thetadot * jnp.sin(theta)
        ypdot = rhodot * jnp.sin(theta) + rho * thetadot * jnp.cos(theta)
        # Derivative of the parallel velocity (mirror force).
        vpardot = -mu / m * dot(data["b"], data["grad(|B|)"])
        dxdt = jnp.array([xpdot, ypdot, zetadot, vpardot]).reshape(x.shape)
        return dxdt.squeeze()

In [None]:
# Scratch setup: the stock flux-frame model and the experimental subclass,
# both on the DSHAPE example equilibrium.
eq = get("DSHAPE")
model = VacuumGuidingCenterTrajectory(frame="flux")
modeltest = VacuumGuidingCenterTrajectoryTest(frame="flux")

# One particle launched at (rho, theta, zeta) = (0.4, 0, 0).
# xi0 is presumably the pitch vpar/v.
particles = ManualParticleInitializerFlux(
    rho0=0.4, theta0=0, zeta0=0, xi0=0.3, E=1e4, m=4, q=2
)
ts = np.linspace(0, 3e-6, 10)

# Quantities the guiding-center right-hand side needs at each evaluation point.
data_keys = ["B", "|B|", "grad(|B|)", "e^rho", "e^theta", "e^zeta", "b"]
# One-node grid; VacuumGuidingCenterTrajectoryTest overwrites its node in
# place on every call.
grid = Grid(jnp.array([0.5018, 0, 0]), spacing=jnp.zeros(3), jitable=True)
transforms = get_transforms(data_keys, eq, grid, jitable=True)
x0, margs = particles.init_particles(model, eq)

In [None]:
def cond_fn(state):
    """``lax.while_loop`` predicate: continue while fewer than 3 steps were taken."""
    _, step = state
    return step < 3


def body(state):
    """Advance ``(x, i)`` by one forward-Euler step (dt = 1e-8) of ``modeltest``.

    The first argument to ``vf`` is presumably the time; it is fixed at 0.
    """
    x, i = state
    vf_kwargs = {"transforms": transforms, "iota": eq.iota}
    dx = modeltest.vf(0, x, (margs[0], eq, eq.params_dict, vf_kwargs))
    return (x + dx * 1e-8, i + 1)

# @jax.jit  (left disabled in this experiment)
def run_loop(x0):
    """Iterate the test-model ``body`` under ``lax.while_loop`` from particle 0."""
    init_state = (x0[0], 0)
    return jax.lax.while_loop(cond_fn, body, init_state)


# Log each compilation triggered while running the loop.
with jax.log_compiles():
    run_loop(x0)

In [None]:
def cond_fn(state):
    """``lax.while_loop`` predicate: continue while fewer than 5 steps were taken."""
    _, step = state
    return step < 5

def body(state):
    """Advance ``(x, i)`` by one forward-Euler step (dt = 1e-8) of the stock ``model``.

    The first argument to ``vf`` is presumably the time; it is fixed at 0.
    """
    x, i = state
    vf_kwargs = {"iota": eq.iota, "transforms": transforms}
    dx = model.vf(0, x, (margs[0], eq, eq.params_dict, vf_kwargs))
    return (x + dx * 1e-8, i + 1)

# Run the stock-model body under lax.while_loop (traced, not jitted) until
# cond_fn stops it; the final (x, i) state is the cell's displayed output.
jax.lax.while_loop(cond_fn, body, (x0[0], 0))

In [None]:
def test(x0, n_steps=5, dt=1e-8):
    """Take forward-Euler steps of ``modeltest`` eagerly (no ``lax`` loop).

    Parameters
    ----------
    x0 : array-like
        Initial particle states; only ``x0[0]`` is advanced.
    n_steps : int, optional
        Number of Euler steps (default 5).
    dt : float, optional
        Step size (default 1e-8).

    Returns
    -------
    The state after ``n_steps`` steps, so the result can be inspected or
    compared with the ``lax.while_loop`` runs.
    """
    # The model arguments do not change between steps, so build them once.
    vf_args = (
        margs[0],
        eq,
        eq.params_dict,
        {"grid": transforms["grid"], "transforms": transforms, "iota": eq.iota},
    )
    state = x0[0]
    for _ in range(n_steps):
        state = state + modeltest.vf(0, state, vf_args) * dt
    return state


test(x0)

In [6]:
rmajor = 4.0
rminor = 1.0
ts = np.linspace(0, 1e-6, 100)
# Launch point: half a minor radius outboard of R = rmajor.
R0 = rmajor + rminor / 2

# Create a vacuum tokamak equilibrium with a FourierRZToroidalSurface
surf = FourierRZToroidalSurface(
    R_lmn=np.array([rmajor, rminor]),
    modes_R=np.array([[0, 0], [1, 0]]),
    Z_lmn=np.array([0, -1]),
    modes_Z=np.array([[0, 0], [-1, 0]]),
)
eq = Equilibrium(surface=surf, L=8, M=8, N=0, Psi=3)
eq.solve(verbose=1)

particles = ManualParticleInitializerLab(R0=R0, phi0=0, Z0=0.0, xi0=0.9, E=1e6)
model = VacuumGuidingCenterTrajectory(frame="flux")

# Particle tracing evaluates the field on individual points (one-point grids),
# which is not enough to compute the iota profile. As a workaround, compute
# the iota profile once here and assign it to the equilibrium. For this test
# it hardly matters, since iota is 0.
eq.iota = eq.get_profile("iota")

# Initialize particles; each row of ``args`` starts with (m, q, ...).
x0, args = particles.init_particles(model=model, field=eq)
# Slice rather than unpacking into ``_``: assigning ``_`` in a notebook stops
# IPython from updating its ``_``/``__``/``___`` output history.
m, q = args[0, :2]

Building objective: force
Precomputing transforms
Building objective: lcfs R
Building objective: lcfs Z
Building objective: fixed Psi
Building objective: fixed pressure
Building objective: fixed current
Building objective: fixed sheet current
Building objective: self_consistency R
Building objective: self_consistency Z
Building objective: lambda gauge
Building objective: axis R self consistency
Building objective: axis Z self consistency
Number of parameters: 48
Number of objectives: 162

Starting optimization
Using method: lsq-exact
`gtol` condition satisfied. (gtol=1.00e-08)
         Current function value: 2.988e-16
         Total delta_x: 2.391e-01
         Iterations: 13
         Function evaluations: 18
         Jacobian evaluations: 14
                                                                 Start  -->   End
Total (sum of squares):                                      5.940e-02  -->   2.988e-16, 
Maximum absolute Force error:                                7.152e+05  -->



In [7]:
# Rebuild the one-node grid and its transforms for the current ``eq``.
data_keys = ["B", "|B|", "grad(|B|)", "e^rho", "e^theta", "e^zeta", "b"]
grid = Grid(jnp.array([0.5018, 0, 0]), spacing=jnp.zeros(3), jitable=True)
transforms = get_transforms(data_keys, eq, grid, jitable=True)

In [8]:
# Trace the particle with a Tsit5 solver, logging every JAX compilation.
with jax.log_compiles():
    rtz, vpar = trace_particles(
        model=model,
        field=eq,
        initializer=particles,
        solver=Tsit5(scan_kind="bounded"),
        ts=ts,
    )
# Observed wall time of this cell by version. These are presumably JAX
# versions (cf. the backend info printed above). Machine-dependent.
#   0.4.38 -> 30 seconds
#   0.5.0  -> 30 seconds
#   0.5.3  -> 36 seconds
#   0.6.0  -> 42 seconds
#   0.6.1  -> 45 seconds
#   0.6.2  -> 200 seconds
#   0.7.2  -> 3.8 seconds



In [None]:
# rtz[0] is presumably the first particle's (rho, theta, zeta) history.
# Convert it to lab-frame (R, phi, Z).
trace_grid = Grid(rtz[0, :, :], jitable=True)
rpz = eq.compute("x", grid=trace_grid)["x"]

# The vacuum tokamak field has the form |B| = B0 * r00 / R. B0 * r00 is
# constant, so evaluate it at any single point (here the first node of the
# rho = 0.5 surface). B0 and r00 only need to come from the same point.
grid = LinearGrid(rho=0.5, M=eq.M_grid, N=eq.N_grid)
data = eq.compute(["|B|", "x"], grid=grid)
B0 = grid.compress(data["|B|"])[0]
r00 = grid.compress(data["x"])[0, 0]

# Vertical drift speed: (vpar^2 + v^2) / 2 = vpar^2 + vperp^2 / 2.
vd = m / (q * B0 * r00) * (particles.vpar0[0] ** 2 + particles.v0[0] ** 2) / 2
z_exact = vd * ts
# Angular velocity is constant and given by vpar0 / R0 in radians per second,
# so phi(t) = vpar0 / R0 * t, with vpar0 the initial parallel velocity and R0
# the initial R.
phi_exact = particles.vpar0[0] / R0 * ts

# np.testing.assert_allclose reports where the arrays differ and is not
# stripped by ``python -O``. rtol=1e-5 matches np.allclose's default.
np.testing.assert_allclose(rpz[:, 2], z_exact, rtol=1e-5, atol=1e-12)
# There is no radial drift, R should remain constant
np.testing.assert_allclose(rpz[:, 0], R0, rtol=1e-5, atol=1e-12)
# There is no mirror force, parallel velocity should be the same
np.testing.assert_allclose(
    vpar[0, :, 0], particles.vpar0[0], rtol=1e-5, atol=1e-12
)
# The phi position should be given by the angular velocity
np.testing.assert_allclose(rpz[:, 1], phi_exact, rtol=1e-5, atol=1e-12)

In [None]:
# Vertical drift: traced Z(t) vs. the analytic z_exact = vd * t.
fig, ax = plt.subplots()
ax.plot(ts, rpz[:, 2], label="traced")
ax.plot(ts, z_exact, "--", label="exact")
ax.set(xlabel="t [s]", ylabel="Z", title="Vertical drift: traced vs. exact")
ax.legend();

In [None]:
# Major radius: with no radial drift, R(t) should stay at the launch value R0.
fig, ax = plt.subplots()
ax.plot(ts, rpz[:, 0], label="traced")
ax.plot(ts, np.full_like(ts, R0), "--", label="exact")
ax.set(xlabel="t [s]", ylabel="R", title="Major radius: traced vs. exact")
ax.legend();

In [None]:
# Toroidal angle: phi(t) should grow linearly at vpar0 / R0.
fig, ax = plt.subplots()
ax.plot(ts, rpz[:, 1], label="traced")
ax.plot(ts, phi_exact, "--", label="exact")
ax.set(xlabel="t [s]", ylabel="phi [rad]", title="Toroidal angle: traced vs. exact")
ax.legend();