In [None]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "../jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [None]:
import sys
import os

# Put the notebook's own directory at the front of the import path, and the
# repository root (two directories up) at the end as a fallback location.
sys.path[:0] = [os.path.abspath(os.curdir)]
sys.path.extend([os.path.abspath(os.path.join(os.pardir, os.pardir))])

# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# from desc import set_device
# set_device("gpu")

In [None]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [None]:
import numpy as np
np.set_printoptions(
    precision=4,  # 4 digits after the decimal point for floats
    suppress=True,  # fixed-point notation instead of scientific for small values
    threshold=sys.maxsize,  # never summarize large arrays with "..."
    linewidth=np.inf,  # never wrap array rows onto new lines
)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [None]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

# Report which numerical backend DESC is running on (presumably JAX vs. numpy,
# with version/device details — confirm) so the notebook output records the
# execution environment.
print_backend_info()

In [None]:
from desc.particles import *

In [None]:
# field = VerticalMagneticField(B0=1.0)
eq = get("precise_QA")

# Energy / species settings shared by both initializers below.
# NOTE(review): units of E, m and q are whatever the initializers expect — confirm.
common_kwargs = dict(E=1e-3, m=4.0, q=1.0, eq=eq)

# Lab-frame start points: two major radii on the phi = 0, Z = 0 plane.
R0 = jnp.array([1.2, 1.25])
particles = ManualParticleInitializerLab(
    R0=R0,
    phi0=jnp.zeros_like(R0),
    Z0=jnp.zeros_like(R0),
    xi0=0.7 * jnp.ones_like(R0),
    **common_kwargs,
)

# Flux-frame start point: a single particle at rho = 0.5, theta = zeta = 0.
RHO0 = jnp.array([0.5])
particles_flux = ManualParticleInitializerFlux(
    rho0=RHO0,
    theta0=jnp.zeros_like(RHO0),
    zeta0=jnp.zeros_like(RHO0),
    xi0=0.1 * jnp.ones_like(RHO0),
    **common_kwargs,
)

# The same vacuum guiding-center model, presumably integrated in lab (R, phi, Z)
# vs. flux (rho, theta, zeta) coordinates depending on `frame`.
model = VacuumGuidingCenterTrajectory(frame="lab")
model_flux = VacuumGuidingCenterTrajectory(frame="flux")

In [None]:
# Save times for the short compile-logging trace below. np.linspace(0, 1e-10, 1)
# is just [0.], a zero-length window whose 1e-10 endpoint is silently dropped;
# two points give the window [0, 1e-10], whose length matches the
# min_step_size=1e-10 passed to trace_particles.
ts = np.linspace(0, 1e-10, 2)
# Build the initial state vector and extra arguments for the flux-frame model.
x0, args = particles_flux.init_particles(field=eq, model=model_flux)
# The first three extra arguments are presumably per-particle masses, charges
# and magnetic moments — confirm against init_particles.
ms, qs, mus = args[0:3]
args

In [None]:
# Log every XLA compilation triggered by the trace, to spot unexpected recompiles.
with jax.log_compiles():
    rpz, _ = trace_particles(
        eq,
        x0,
        ms,
        qs,
        mus,
        model=model_flux,
        ts=ts,
        min_step_size=1e-10,
    )
rpz

In [None]:
# Particles and field lines below are traced in the coil `field`; the equilibrium
# is used to pick starting (R, Z) points, is passed to the initializer, and
# provides the |B| surface in the 3D figure.
eq = desc.examples.get("precise_QA")
# Start points on rho = 0.5 and rho = 1.0, presumably at theta = zeta = 0
# (LinearGrid defaults) — consistent with phi0 = 0 below.
grid_trace = desc.grid.LinearGrid(rho=np.linspace(0.5, 1.0, 2))
data_trace = eq.compute(["R", "Z"], grid=grid_trace)
r0, z0 = data_trace["R"], data_trace["Z"]

fig = plot_3d(eq, "|B|", alpha=0.5)

particles = ManualParticleInitializerLab(
    R0=r0,
    phi0=jnp.zeros_like(r0),
    Z0=z0,
    xi0=0.7 * jnp.ones_like(r0),
    E=1e-1,
    m=4.0,
    q=1.0,
    eq=eq,
)

# Coil set that produces the vacuum field used for tracing.
field = desc.io.load("../../tests/inputs/precise_QA_helical_coils.h5")

plot_field_lines(field, r0, z0, ntransit=2, color="red", fig=fig)
plot_particle_trajectories(
    field, model, particles, ts=np.linspace(0, 1e-2, 1000), fig=fig
)
# Optional: export the interactive figure.
# fig.write_html("plot_particle_trajectories.html")
fig