In [None]:
# Optional: enable JAX's persistent compilation cache in ./jax-caches.
# The thresholds below (-1 bytes, 0 seconds) cache every compiled entry
# regardless of its size or compile time — presumably to speed up re-runs.
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [1]:
import sys
import os

# Make the current directory (highest priority) and its parent importable.
# Presumably this lets a local development checkout of `desc` be picked up — confirm.
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../"))

# Optional XLA/JAX memory and compiler settings; uncomment as needed.
# NOTE(review): these are environment variables, so they presumably must be
# set before JAX is initialized in order to take effect.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["XLA_FLAGS"] = (
#     "--xla_disable_hlo_passes=constant_folding"  # this disables constant folding
# )
# Select the GPU backend. NOTE(review): this runs before the bulk `desc`
# imports in later cells, presumably because the device must be chosen
# before JAX initializes — keep this ordering.
from desc import set_device
set_device("gpu")

In [2]:
# Alternative to the `set_device("gpu")` setup: run on CPU with `num_device`
# devices. Presumably this replaces (rather than supplements) the GPU cell,
# and `_set_cpu_count` must be called before JAX initializes — confirm.
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [3]:
import numpy as np
# Print full arrays on one wide line, with 4 decimals and no scientific
# notation for small values, so inspected arrays are never truncated.
np.set_printoptions(
    threshold=sys.maxsize,
    suppress=True,
    precision=4,
    linewidth=np.inf,
)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [4]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

# Record the DESC/JAX versions and the active device in the notebook output,
# so results can be tied to the environment that produced them.
print_backend_info()

DESC version=0.14.2+685.ga8c422804.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using device: NVIDIA GeForce RTX 4080 Laptop GPU (id=0), with 11.99 GB available memory.


In [5]:
from desc.particles import *

In [6]:
# Load the precise-QA example equilibrium and rescale it to the target
# minor radius ("a" -> 1.7044) and field strength ("<B>" -> 5.86).
# Presumably SI units (m, T) and "<B>" is the volume-averaged |B| — confirm.
eqi = get("precise_QA")
# eqi = load("eq_opt_low_mirror_8_07_20.h5")

eq = rescale(eq=eqi, L=("a", 1.7044), B=("<B>", 5.86), copy=True)
# Keep a separate copy of the rescaled equilibrium for later reference.
eqi_scaled = eq.copy()

# Show |B| on the rescaled boundary. Previously the figure was built but never
# displayed (the cell ended in an assignment), so the plotting work was wasted.
fig = plot_3d(eq, "|B|", alpha=0.5)

# eq.iota = eq.get_profile("iota", kind="power_series")

fig



In [None]:
from diffrax import Euler

# Seed the random initial conditions so Restart & Run All reproduces the same
# particles (the original unseeded np.random.rand calls changed on every run).
SEED = 0
rng = np.random.default_rng(SEED)

N = 5  # number of particles traced
RHO0 = 0.2 * np.ones(N)  # every particle starts on the rho = 0.2 surface

# Uniform in [0, 1); presumably the pitch parameter xi = v_parallel / v — confirm.
xi0 = rng.random(N)
# xi0 = 0.9
print(xi0)

model_flux = VacuumGuidingCenterTrajectory(frame="flux")
particles_flux = ManualParticleInitializerFlux(
    rho0=RHO0,
    theta0=rng.random(N) * 2 * np.pi,  # uniform poloidal angle in [0, 2*pi)
    zeta0=jnp.zeros_like(RHO0),
    xi0=xi0,
    # NOTE(review): presumably a 3.5 MeV alpha particle
    # (E in eV, m in proton masses, q in elementary charges) — confirm units.
    E=3.5e6,
    m=4.0,
    q=2.0,
)

# Minimum solver step: time for the first particle to cover 0.01 length units
# at its initial speed (presumably v0 is a speed in m/s — confirm).
dt = 0.01 / particles_flux.v0[0]
print(f"dt = {dt:.2e}")

# 100 output times spanning [0, 1e-6] (presumably seconds).
ts = np.linspace(0, 1e-6, 100)
fig = plot_3d(eq, "|B|", alpha=0.5)
plot_particle_trajectories(
    eq,
    model_flux,
    particles_flux,
    ts,
    bounds_R=(0, 1),
    fig=fig,
    solver=Euler(),
    min_step_size=dt,
)

In [None]:
# Particle-tracing objective for the same particles and model used in the
# trajectory plot. Reuse `ts` instead of re-creating the time grid, so the
# objective and the visualization cannot silently drift apart.
# NOTE(review): the plot also passed `min_step_size=dt`; check whether
# DirectParticleTracing needs the same setting for a consistent integration.
obj = ObjectiveFunction(
    DirectParticleTracing(
        eq,
        particles=particles_flux,
        model=model_flux,
        solver=Euler(),
        ts=ts,
    )
)
obj.build(use_jit=True)

In [None]:
# Scaled residuals of the tracing objective, evaluated at the state vector of `eq`.
state_vector = obj.x(eq)
obj.compute_scaled_error(state_vector)

In [None]:
# Gradient of the objective at the state vector of `eq`.
params_eq = obj.x(eq)
g = obj.grad(params_eq)

In [None]:
# Report the largest gradient component by magnitude. np.max/np.abs reduce the
# whole array at once, instead of the builtin max iterating element-by-element
# over a device array, and abs() keeps large negative components visible.
print(f"{np.max(np.abs(g)):.2e}")