In [None]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [None]:
import sys
import os

# Put the notebook's own directory first on the import path, and append the
# directory two levels up (presumably the DESC repository root — verify) so the
# local `desc` package can be imported.
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../"))

# Optional XLA/JAX tweaks (memory fraction, allocator, disabling constant folding).
# NOTE(review): these are presumably only honored if set before JAX initializes,
# which is why they sit above the DESC import — keep them here if re-enabled.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["XLA_FLAGS"] = (
#     "--xla_disable_hlo_passes=constant_folding"  # this disables constant folding
# )
# Select the GPU device before the DESC backend/star imports in later cells run.
# (A commented-out multi-CPU alternative lives in the next cell.)
from desc import set_device
set_device("gpu")

In [None]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [None]:
import numpy as np
# Print arrays in full (never summarized), without line wrapping, with 4 decimal
# places in fixed-point notation.
np.set_printoptions(linewidth=np.inf, precision=4, suppress=True, threshold=sys.maxsize)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [None]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

# Sanity check: report which backend/device DESC is using after set_device("gpu").
print_backend_info()

In [None]:
from desc.particles import *
from diffrax import *

In [None]:
# Target values passed to `rescale` for the "a" length scale and the "<B>" field
# strength of the example equilibrium.
A_TARGET = 1.7044
B_AVG_TARGET = 5.86

eqi = get("precise_QA")
eq = rescale(eq=eqi, L=("a", A_TARGET), B=("<B>", B_AVG_TARGET), copy=True)
# Keep an untouched copy of the rescaled starting point; `eq` is later passed to
# eq.optimize and compared against this copy as "Initial" vs "Final".
eqi_scaled = eq.copy()

# Alternative start: resume from a previously saved optimization result.
# eq = desc.io.load("precise_qa_particle_1e-6_5iter.h5")
# eq.iota = eq.get_profile("iota", kind="power_series")
# eq.solve(verbose=3);

In [None]:
# Seed the sampler so the initial particle coordinates are reproducible across
# kernel restarts (the objective, optimization and orbit plots below all use them).
SEED = 0
particle_rng = np.random.default_rng(SEED)

N = 50  # particles traced
RHO0 = 0.3 * np.ones(N)  # every particle starts at rho = 0.3

model_flux = VacuumGuidingCenterTrajectory(frame="flux")
particles_flux = ManualParticleInitializerFlux(
    rho0=RHO0,
    theta0=particle_rng.random(N) * 2 * np.pi,  # uniform in [0, 2*pi)
    zeta0=particle_rng.random(N) * 2 * np.pi,  # uniform in [0, 2*pi)
    xi0=particle_rng.random(N),  # uniform in [0, 1); presumably the pitch — verify
    E=3.5e6,
)
# Step size scaled by the largest initial parallel velocity among the particles.
dt = 0.01 / np.max(particles_flux.vpar0)
print(f"dt = {dt:.2e}")

In [None]:
# Particle-tracing objective over t in [0, 1e-4] sampled at 1000 output times,
# using the particles/model/step size defined above.
tracing_term = DirectParticleTracing(
    eq,
    particles=particles_flux,
    model=model_flux,
    solver=Tsit5(),
    ts=np.linspace(0, 1e-4, 1000),
    min_step_size=dt,
    stepsize_controller=PIDController(rtol=1e-8, atol=1e-8, dtmin=dt),
    deriv_mode="rev",
)
obj = ObjectiveFunction(tracing_term)
obj.build()
# Evaluate once at the current state (presumably also triggers JIT compilation,
# which the next cell inspects with jax.log_compiles).
obj.compute_scaled_error(obj.x(eq))

In [None]:
# Re-evaluate the objective while logging every JAX compilation; presumably used
# to check whether a second evaluation recompiles anything.
# `jax` is imported explicitly instead of relying on it leaking through
# `from desc.backend import *`.
import jax

with jax.log_compiles():
    obj.compute_scaled_error(obj.x(eq))

In [None]:
# Optimize `eq` against the particle-tracing objective (shorter window than the
# previous cell) plus a heavily weighted penalty keeping R0/a at its current value.
AR = eq.compute("R0/a")["R0/a"]
obj = ObjectiveFunction(
    [
        DirectParticleTracing(
            eq,
            particles=particles_flux,
            model=model_flux,
            solver=Tsit5(),
            ts=np.linspace(0, 1e-5, 100),
            min_step_size=dt,
            stepsize_controller=PIDController(rtol=1e-8, atol=1e-8, dtmin=dt),
            deriv_mode="rev",
        ),
        AspectRatio(eq, target=AR, weight=1e3),
    ]
)
constraints = (ForceBalance(eq), FixPressure(eq), FixPsi(eq), FixCurrent(eq))

# Boundary modes whose largest absolute mode number exceeds k are held fixed;
# only the low-order boundary modes are free to change.
k = 2


def _modes_above(basis, max_order):
    """Return the rows of ``basis.modes`` whose largest absolute entry exceeds ``max_order``."""
    modes = basis.modes
    return modes[np.max(np.abs(modes), axis=1) > max_order, :]


R_modes = _modes_above(eq.surface.R_basis, k)
Z_modes = _modes_above(eq.surface.Z_basis, k)
bdry_constraints = (
    FixBoundaryR(eq=eq, modes=R_modes),
    FixBoundaryZ(eq=eq, modes=Z_modes),
)

eq.optimize(
    objective=obj,
    constraints=constraints + bdry_constraints,
    verbose=3,
    maxiter=2,
    options={"max_nfev": 3, "initial_trust_ratio": 1e-3},
);

In [None]:
# Side-by-side comparison of the rescaled starting equilibrium and the optimized one.
# Assigning the result keeps the notebook from echoing the returned objects.
comparison = plot_comparison([eqi_scaled, eq], labels=["Initial", "Final"])

In [None]:
# NOTE(review): the filename says "10iter", but the optimization cell above runs
# with maxiter=2 — confirm the intended name before relying on it.
output_path = "precise_qa_particle_1e-5_10iter.h5"
eq.save(output_path)

In [None]:
# Trace the same particles over a short window in both the optimized (`eq`, blue)
# and the initial rescaled (`eqi_scaled`, red) equilibria, overlaying the orbits
# on 3D |B| plots.
ts = np.linspace(0, 1e-6, 1000)


def _plot_orbits(equilibrium, color):
    """Plot |B| for ``equilibrium`` and overlay the traced particle orbits.

    Returns the ``(figure, trajectory_data)`` pair produced by
    ``plot_particle_trajectories`` with ``return_data=True``.
    """
    base_fig = plot_3d(equilibrium, "|B|", alpha=0.3)
    return plot_particle_trajectories(
        equilibrium,
        model_flux,
        particles_flux,
        ts,
        fig=base_fig,
        min_step_size=dt,
        color=color,
        return_data=True,
    )


fig, data1 = _plot_orbits(eq, "blue")
fig2, data2 = _plot_orbits(eqi_scaled, "red")

In [None]:
# Display the orbit figure for the optimized equilibrium (blue orbits).
# Uncomment the next line to also export it as a standalone HTML file.
# fig.write_html("optimized_particle_orbits.html")
fig

In [None]:
# Display the orbit figure for the initial rescaled equilibrium (red orbits).
# Uncomment the next line to also export it as a standalone HTML file.
# fig.write_html("unoptimized_particle_orbits.html")
fig2

In [None]:
# rho vs. phi for each particle traced in the optimized equilibrium (data1).
# The y-range is fixed to [0, 1], so excursions beyond rho = 1 fall off the plot.
fig_rho_opt, ax_rho_opt = plt.subplots()
for rhos, phis in zip(data1["rho"], data1["phi"]):
    ax_rho_opt.plot(phis, rhos)
ax_rho_opt.set(xlabel=r"$\phi$", ylabel=r"$\rho$", ylim=(0, 1), title="Optimized equilibrium");

In [None]:
# rho vs. phi for each particle traced in the initial rescaled equilibrium (data2).
# The y-range is fixed to [0, 1], so excursions beyond rho = 1 fall off the plot.
fig_rho_init, ax_rho_init = plt.subplots()
for rhos, phis in zip(data2["rho"], data2["phi"]):
    ax_rho_init.plot(phis, rhos)
ax_rho_init.set(xlabel=r"$\phi$", ylabel=r"$\rho$", ylim=(0, 1), title="Initial (rescaled) equilibrium");