In [1]:
import sys
import os

# Make modules in the current directory importable (highest priority), and
# append (lowest priority) a directory three levels up -- presumably the DESC
# repository root; TODO confirm. Because it is appended, an installed `desc`
# package would still take precedence over that checkout.
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))

# Let JAX preallocate the whole GPU memory instead of its default fraction.
# This only takes effect if set before JAX initializes its GPU backend, i.e.
# before any desc/JAX import below.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
# Alternative: allocate exactly on demand instead of preallocating (slow; per
# the JAX docs mainly useful for debugging out-of-memory failures).
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
from desc import set_device

# Select the GPU before the rest of desc is imported in the following cells.
set_device("gpu")

In [2]:
import numpy as np
np.set_printoptions(
    threshold=sys.maxsize,  # never summarize large arrays with "..."
    suppress=True,  # fixed-point notation; tiny values print as zero
    precision=4,  # digits after the decimal point
    linewidth=np.inf,  # never wrap long rows
)
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# for sake of simplicity, we will just import everything from desc
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

# Report the DESC and JAX versions and the device in use (see output below).
print_backend_info()

DESC version=0.16.0+578.ga5e3fa78f.dirty.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using device: NVIDIA GeForce RTX 4080 Laptop GPU (id=0), with 11.46 GB available memory.


In [4]:
# includes all particle tracing capabilities and particle initialization
from desc.particles import *
# includes different solvers, stepsize controller and bunch of other tools for ODE integration
from diffrax import *

# Alpha Particle Trajectory Optimization in DESC

To run this notebook, check out the `rc/particles` branch:

```bash
git checkout rc/particles
git pull
```

The results presented in this notebook are not converged: to get better results, run it with more particles, a longer integration time, and possibly better-tuned objective weights. The primary aim of this notebook is to describe how to use particle tracing inside an objective. Please reach out to Yigit Gunsur Elmacioglu (yigit.elma@princeton.edu) with any further questions.

In [5]:
from desc.objectives.objective_funs import _Objective
from desc.particles import _trace_particles


class DirectParticleTracing(_Objective):
    """Confinement metric for radial transport from direct particle tracing.

    Traces particles in flux coordinates within the equilibrium and returns,
    for each particle, a metric of how far its trajectory drifts from its
    initial flux surface. The radial deviation ``rho(t) - rho0`` is fitted
    with a straight line in time, and the slope of that line multiplied by
    the final time ``ts[-1]`` is the metric.

    Parameters
    ----------
    eq : Equilibrium
        Equilibrium that will be optimized to satisfy the Objective.
    particles : ParticleInitializer
        Initializer for the traced particles. It should initialize them in
        flux coordinates. The initial conditions are computed once in
        ``build`` and reused for every evaluation.
    model : TrajectoryModel
        Trajectory model; must have ``frame == "flux"``. Should be either a
        vacuum or a slowing-down guiding-center model.
    solver : diffrax solver, optional
        ODE solver. Defaults to ``Tsit5()``. On CPU,
        ``Tsit5(scan_kind="bounded")`` is recommended.
    ts : array-like, optional
        1D array (at least 2 entries) of times at which trajectories are
        saved and the line is fitted. Defaults to 100 equally spaced times
        in ``[0, 1e-3]``.
    stepsize_controller : diffrax step size controller, optional
        Defaults to ``PIDController(rtol=1e-4, atol=1e-4, dtmin=min_step_size,
        pcoeff=0.3, icoeff=0.3, dcoeff=0)``.
    adjoint : diffrax adjoint, optional
        How to differentiate through the solve. Defaults to
        ``RecursiveCheckpointAdjoint()`` (reverse mode).
    max_steps : int, optional
        Maximum number of solver steps. Defaults to
        ``int((ts[-1] - ts[0]) / min_step_size)`` (at least 1).
    min_step_size : float, optional
        Minimum step size passed to the tracer. Also used for the default
        ``max_steps`` and ``stepsize_controller``.
    particle_chunk_size : int, optional
        Passed to the tracer as ``chunk_size`` (presumably the number of
        particles traced per batch; ``None`` for no chunking).

    Raises
    ------
    ValueError
        If ``ts`` is not a 1D array with at least 2 entries.

    """

    __doc__ = __doc__.rstrip() + collect_docs(
        target_default="``target=0``.", bounds_default="``target=0``."
    )
    # Attribute names presumably treated as static (non-array) data by the
    # DESC objective machinery.
    _static_attrs = _Objective._static_attrs + [
        "_trace_particles",
        "_max_steps",
        "_stepsize_controller",
        "_adjoint",
        "_event",
        "_particle_chunk_size",
        "_solver",
    ]

    _coordinates = "rtz"
    _units = "(dimensionless)"
    _print_value_fmt = "Particle Confinement error: "

    def __init__(
        self,
        eq,
        particles,
        model,
        solver=Tsit5(),  # on CPU, Tsit5(scan_kind="bounded") is recommended
        ts=jnp.linspace(0, 1e-3, 100),
        stepsize_controller=None,
        adjoint=RecursiveCheckpointAdjoint(),
        max_steps=None,
        min_step_size=1e-8,
        particle_chunk_size=None,
        target=None,
        bounds=None,
        weight=1,
        normalize=False,
        normalize_target=False,
        loss_function=None,
        deriv_mode="auto",
        name="Particle Confinement",
        jac_chunk_size=None,
    ):
        if target is None and bounds is None:
            target = 0
        self._ts = jnp.asarray(ts)
        # A straight-line fit needs at least two save times.
        if self._ts.ndim != 1 or self._ts.size < 2:
            raise ValueError(
                "ts must be a 1D array with at least 2 times, "
                f"got shape {self._ts.shape}."
            )
        self._adjoint = adjoint
        if max_steps is None:
            # Enough steps to cover the whole time span at the minimum step size.
            max_steps = max(1, int((self._ts[-1] - self._ts[0]) / min_step_size))
        self._max_steps = max_steps
        self._min_step_size = min_step_size
        self._stepsize_controller = (
            stepsize_controller
            if stepsize_controller is not None
            else PIDController(
                rtol=1e-4,
                atol=1e-4,
                dtmin=min_step_size,
                pcoeff=0.3,
                icoeff=0.3,
                dcoeff=0,
            )
        )
        assert model.frame == "flux", "can only trace in flux coordinates"
        self._model = model
        self._particles = particles
        self._solver = solver
        self._particle_chunk_size = particle_chunk_size
        # Field representation the particles are traced in; refitted to the
        # current parameters on every compute.
        self._interpolator = FourierChebyshevField(L=eq.L_grid, M=eq.M_grid, N=eq.N_grid)
        super().__init__(
            things=eq,
            target=target,
            bounds=bounds,
            weight=weight,
            normalize=normalize,
            normalize_target=normalize_target,
            loss_function=loss_function,
            deriv_mode=deriv_mode,
            name=name,
            jac_chunk_size=jac_chunk_size,
        )

    def build(self, use_jit=True, verbose=1):
        """Build constant arrays.

        Initializes the particles, the terminating event and the field
        interpolator.

        Parameters
        ----------
        use_jit : bool, optional
            Whether to just-in-time compile the objective and derivatives.
        verbose : int, optional
            Level of output.

        """
        eq = self.things[0]
        self._x0, self._model_args = self._particles.init_particles(
            model=self._model, field=eq
        )

        # one metric per particle
        self._dim_f = self._x0.shape[0]

        # The tracer integrates in Cartesian-like coordinates internally
        # (presumably y[0], y[1] ~ rho*cos(theta), rho*sin(theta)). The
        # terminating event therefore recovers rho by conversion and stops a
        # particle once it leaves the last closed flux surface (rho > 1).
        def default_event(t, y, args, **kwargs):
            i = jnp.sqrt(y[0] ** 2 + y[1] ** 2)
            return i > 1.0

        self._event = Event(default_event)

        timer = Timer()
        if verbose > 0:
            print("Precomputing transforms")
        timer.start("Precomputing transforms")
        self._interpolator.build(eq)
        self._interpolator.fit(eq.params_dict, {"iota": eq.iota, "current": eq.current})

        timer.stop("Precomputing transforms")
        if verbose > 1:
            timer.disp("Precomputing transforms")

        super().build(use_jit=use_jit, verbose=verbose)

    def compute(self, params, constants=None):
        """Compute particle tracing metric errors.

        Parameters
        ----------
        params : dict
            Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict
        constants : dict
            Dictionary of constant data, eg transforms, profiles etc. Defaults to
            self.constants

        Returns
        -------
        f : ndarray, shape (N_particles,)
            Radial drift of each particle from a linear fit of
            ``rho(t) - rho0``: the fitted slope times ``ts[-1]``.
        """
        eq = self.things[0]
        # NOTE(review): this refits state stored on self inside compute, and
        # the iota/current profiles come from ``eq`` rather than ``params``.
        # Confirm this is intended under jit/AD.
        self._interpolator.fit(params, {"iota": eq.iota, "current": eq.current})
        rpz, _ = _trace_particles(
            field=self._interpolator,
            y0=self._x0,
            model=self._model,
            model_args=self._model_args,
            ts=self._ts,
            params=None,
            stepsize_controller=self._stepsize_controller,
            saveat=SaveAt(ts=self._ts),
            max_steps=self._max_steps,
            min_step_size=self._min_step_size,
            solver=self._solver,
            adjoint=self._adjoint,
            event=self._event,
            options={},
            chunk_size=self._particle_chunk_size,
            throw=False,
            return_aux=False,
        )

        # rpz is shape [N_particles, N_time, 3]; take just the rho component
        rhos = rpz[:, :, 0]
        # assumes x0[:, 0] is the initial rho (flux-frame initialization) -- TODO confirm
        rho0s = self._x0[:, 0]
        ts = self._ts

        def fit_line(y):
            """Return the slope of a least-squares line through (ts, y)."""
            # Saves the tracer did not fill (e.g. after the event terminated a
            # lost particle) are non-finite: NaN, or inf as used by diffrax for
            # padding. Move BOTH t and y of those points to the initial point
            # (0, 0), which only re-weights the starting point instead of
            # pinning late times to zero drift. The mask must be computed
            # before y is overwritten, otherwise t would never be masked.
            bad = ~jnp.isfinite(y)
            y = jnp.where(bad, 0.0, y)
            t = jnp.where(bad, 0.0, ts)
            coeffs = jnp.polyfit(t, y, 1)
            return coeffs[0]

        # slope * final time = fitted radial drift accumulated by ts[-1]
        slopes = vmap(fit_line)(rhos - rho0s[:, None]) * ts[-1]
        return slopes

In [6]:
import desc.io  # make sure the desc.io submodule is loaded for desc.io.load

name = "precise_QA"
eq_path = f"eqs/{name}_vacuum_scaled_solved.h5"
if os.path.exists(eq_path):
    # Reuse the cached, already-solved vacuum equilibrium. A corrupt or
    # unreadable file now raises instead of being silently regenerated.
    eq = desc.io.load(eq_path)
else:
    # Build it from scratch: rescale the example equilibrium to the target
    # minor radius ("a") and field strength ("<B>"), make it a vacuum
    # equilibrium (zero pressure and current), re-solve it and cache it.
    eqi = get(name)
    eq = rescale(eq=eqi, L=("a", 1.7044), B=("<B>", 5.86), copy=True)
    eq.pressure = 0
    eq.current = 0
    eq.solve(ftol=1e-4, verbose=1)
    os.makedirs(os.path.dirname(eq_path), exist_ok=True)
    eq.save(eq_path)

# Separate copy of the starting (scaled, solved) equilibrium, presumably kept
# for comparison after optimization.
eqi_scaled = eq.copy()

# The Vacuum Guiding Center model assumes a constant pressure profile and zero current
# If the equilibrium does not satisfy these conditions, raise an error.
if (eq.p_l[1:] != 0).any():
    raise ValueError("Equilibrium doesn't have constant pressure, please use a vacuum equilibrium.")
if (eq.c_l != 0).any():
    raise ValueError("Equilibrium has non-zero current, please use a vacuum equilibrium.")

In [7]:
# Create N particles with random initial positions between rho=0.1 and rho=0.3.
# A fixed seed makes the particle ensemble (and therefore every objective value
# below) reproducible across kernel restarts.
SEED = 0
rng = np.random.default_rng(SEED)
N = 800  # number of particles traced
RHO0 = 0.1 + rng.random(N) * 0.2

model_flux = VacuumGuidingCenterTrajectory(frame="flux")
particles_flux = ManualParticleInitializerFlux(
    rho0=RHO0,
    theta0=rng.random(N) * 2 * np.pi,
    zeta0=rng.random(N) * 2 * np.pi,
    # TODO: xi0 is only sampled in [0, 1); also sample the negative region.
    xi0=rng.random(N),
    E=3.5e6,  # presumably the 3.5 MeV alpha birth energy, in eV -- confirm units
)

In [8]:
def _make_tracing_objective(*, min_step_size, stepsize_controller):
    """Wrap a DirectParticleTracing objective with the settings shared by obj1/obj2.

    Only the minimum step size and the step size controller differ between
    the two objectives built below.
    """
    tracing = DirectParticleTracing(
        eq,
        particles=particles_flux,
        model=model_flux,
        solver=Tsit5(),
        ts=np.linspace(0, 1e-4, 100),
        min_step_size=min_step_size,
        max_steps=1000,
        stepsize_controller=stepsize_controller,
        adjoint=ForwardMode(),  # default is RecursiveCheckpointAdjoint() (reverse mode)
        deriv_mode="fwd",
    )
    return ObjectiveFunction(tracing)


# Adaptive stepping with a PID controller.
obj1 = _make_tracing_objective(
    min_step_size=1e-8,
    stepsize_controller=PIDController(
        rtol=1e-4,
        atol=1e-4,
        dtmin=1e-8,
        pcoeff=0.3,
        icoeff=0.3,
        dcoeff=0,
    ),
)
# Constant step size (presumably min_step_size acts as the fixed step here).
obj2 = _make_tracing_objective(
    min_step_size=5e-7,
    stepsize_controller=ConstantStepSize(),
)
obj1.build()
obj2.build()

Building objective: Particle Confinement
Precomputing transforms
Building objective: Particle Confinement
Precomputing transforms


In [9]:
# First evaluation of the adaptive-step objective (also triggers compilation).
x_obj1 = obj1.x(eq)
f1 = obj1.compute_scaled_error(x_obj1)

In [10]:
# First evaluation of the constant-step objective (also triggers compilation).
x_obj2 = obj2.x(eq)
f2 = obj2.compute_scaled_error(x_obj2)

In [11]:
# Time one evaluation of each objective, after the warm-up evaluations above.
# block_until_ready() waits for the asynchronous JAX result, so the timing covers
# the actual computation (including the obj.x(eq) call).
%timeit obj1.compute_scaled_error(obj1.x(eq)).block_until_ready()
%timeit obj2.compute_scaled_error(obj2.x(eq)).block_until_ready()

4.09 s ± 53.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.74 s ± 107 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Gradient of the adaptive-step objective. jax.log_compiles() logs every
# compilation triggered inside the block (a debugging aid for spotting
# recompiles; remove it for a clean run).
with jax.log_compiles():
    g1 = obj1.grad(obj1.x(eq), obj1.constants)

In [None]:
# Gradient of the constant-step objective.
x_obj2 = obj2.x(eq)
g2 = obj2.grad(x_obj2, obj2.constants)

In [12]:
# Directional derivative (JVP) of the scaled error along the first
# optimization variable only. Note: this reuses the name g1 from the gradient cell.
x_obj1 = obj1.x(eq)
v = jnp.zeros_like(x_obj1).at[0].set(1.0)
with jax.log_compiles():
    g1 = obj1.jvp_scaled_error(v, x_obj1, obj1.constants)

