In [1]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "../jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [2]:
import sys
import os

# Make modules in the current working directory take precedence on the import
# path, and make the parent directory importable as a last resort.
sys.path.insert(0, os.path.abspath(os.curdir))
sys.path.append(os.path.abspath(os.pardir))

In [3]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [4]:
# from desc.backend import jax
# jax.devices("cpu")

In [5]:
import numpy as np
np.set_printoptions(
    threshold=sys.maxsize,  # never summarize large arrays with "..."
    suppress=True,  # fixed-point notation; values below precision print as 0
    precision=4,
    linewidth=np.inf,  # never wrap long rows
)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [6]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic

# Report the DESC version, the JAX backend/dtype, and the active device
# (the saved output of this cell shows the format).
print_backend_info()

DESC version=0.13.0+1630.gb6b7fddcb.dirty.
Using JAX backend: jax version=0.5.0, jaxlib version=0.5.0, dtype=float64.
Using device: CPU, with 4.55 GB available memory.


In [None]:
# Build the fixed-boundary force-balance problem for the ATF example and form
# its scaled Jacobian ``J``, which the following linear-solve cells use.
eq = desc.examples.get("ATF")
objective = ObjectiveFunction(ForceBalance(eq))
# Fixed-boundary constraints, passed through maybe_add_self_consistency
# (presumably adds any missing self-consistency constraints -- confirm).
cons = get_fixed_boundary_constraints(eq)
cons = maybe_add_self_consistency(eq, cons)
constraint = ObjectiveFunction(cons)
# Wrap the objective so it is evaluated in terms of the constraint-projected
# (reduced) variables.
lcon = LinearConstraintProjection(objective, constraint)
lcon.build()
# Reduced state vector corresponding to ``eq``.
x = lcon.x(eq)
# Scaled Jacobian at ``x``; block_until_ready() waits for JAX's asynchronous
# dispatch so the array is fully computed when the cell finishes.
J = lcon.jac_scaled_error(x).block_until_ready()

In [None]:
# Least-squares solve of ``J @ res = e_0`` via a truncated SVD (pseudo-inverse).
# The right-hand side lives in ``rhs`` so that the reduced state vector ``x``
# built from ``lcon`` earlier in the notebook is not overwritten.
rhs = jnp.zeros(J.shape[0])
rhs = rhs.at[0].set(1.0)

# Relative singular-value cutoff: machine epsilon times the largest dimension.
cutoff = jnp.finfo(J.dtype).eps * max(J.shape)
uf, sf, vtf = jnp.linalg.svd(J, full_matrices=False)
print(len(sf), sf[0], sf[-1], cutoff * sf[0])
# SVD returns singular values in descending order, so this is the 2-norm
# condition number; indexing avoids Python-level max/min over a device array.
print(f"{sf[0] / sf[-1]:.2e}")

# Heuristic "regularization": if the smallest singular value is below the
# cutoff, shift *all* singular values up by sf[-1].
# NOTE(review): this also changes sf[0] (and so the cutoff used on the next
# line) and only lifts sf[-1] to 2*sf[-1]; it is not Tikhonov regularization
# -- confirm this is the intended behavior.
sf = jnp.where(sf[-1] < cutoff * sf[0], sf + sf[-1], sf)
# Invert only singular values above the relative cutoff; the rest are zeroed,
# giving the truncated pseudo-inverse.
sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf)
res = vtf.T @ (sfi * (uf.T @ rhs))

print(jnp.linalg.norm(J @ res - rhs))

In [None]:
# Least-squares solve of ``J @ res = e_0`` via a reduced QR factorization.
# The right-hand side lives in ``rhs`` so that the reduced state vector ``x``
# built from ``lcon`` earlier in the notebook is not overwritten.
rhs = jnp.zeros(J.shape[0]).at[0].set(1.0)

# ``qr`` / ``solve_triangular`` presumably come from the ``desc.backend`` star
# import. With J of shape (m, n), economic QR gives r of shape (min(m, n), n),
# so r is square upper-triangular only when m >= n; otherwise
# solve_triangular will fail. No column pivoting is done, so a rank-deficient
# J yields a poorly conditioned solve.
q, r = qr(J, mode="economic")
res = solve_triangular(r, q.T @ rhs)
print(jnp.linalg.norm(J @ res - rhs))

In [7]:
def run_qh_step(n, eq, maxiter=1, verbose=3):
    """Run 1 step of the precise QH optimization example from Landreman & Paul.

    The objective combines two-term quasi-helical symmetry with helicity
    ``(1, eq.NFP)``, evaluated on the rho = 0.6, 0.8, 1.0 surfaces, and an
    aspect-ratio target of 8 (weight 1e2). Only low-order boundary modes are
    free: every boundary mode whose largest absolute mode number exceeds
    ``n + 1`` is held fixed, and the R (0, 0, 0) mode is always fixed.
    Increasing ``n`` therefore frees more boundary modes.

    Parameters
    ----------
    n : int
        Zero-based step index; controls which boundary modes are free.
    eq : Equilibrium
        Starting equilibrium. ``copy=True`` is passed to ``eq.optimize``, so the
        optimized result is returned as a separate object.
    maxiter : int, optional
        Maximum number of optimizer iterations for this step (default 1, the
        value previously hard-coded here).
    verbose : int, optional
        Verbosity level forwarded to ``eq.optimize`` (default 3).

    Returns
    -------
    eq1 : Equilibrium
        The optimized equilibrium.
    """
    print(f"==========QH step {n+1}==========")
    grid = LinearGrid(
        M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.array([0.6, 0.8, 1.0]), sym=True
    )

    objective = ObjectiveFunction(
        (
            QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid),
            AspectRatio(eq=eq, target=8, weight=1e2),
        ),
    )

    # Fix every boundary mode whose largest |mode number| exceeds n + 1; the
    # R (0, 0, 0) coefficient is fixed in addition.
    R_modes = np.vstack(
        (
            [0, 0, 0],
            eq.surface.R_basis.modes[
                np.max(np.abs(eq.surface.R_basis.modes), 1) > n + 1, :
            ],
        )
    )
    Z_modes = eq.surface.Z_basis.modes[
        np.max(np.abs(eq.surface.Z_basis.modes), 1) > n + 1, :
    ]
    constraints = (
        ForceBalance(eq=eq),
        FixBoundaryR(eq=eq, modes=R_modes),
        FixBoundaryZ(eq=eq, modes=Z_modes),
        FixPressure(eq=eq),
        FixCurrent(eq=eq),
        FixPsi(eq=eq),
    )
    optimizer = Optimizer("proximal-lsq-exact")
    # The optimization history (second return value) is not needed here.
    eq1, _ = eq.optimize(
        objective=objective,
        constraints=constraints,
        optimizer=optimizer,
        maxiter=maxiter,
        verbose=verbose,
        copy=True,
        options={},
    )

    return eq1


# Initial boundary for the precise QH example (NFP = 4).
surf = FourierRZToroidalSurface(
    R_lmn=[1, 0.125, 0.1],
    Z_lmn=[-0.125, -0.1],
    modes_R=[[0, 0], [1, 0], [0, 1]],
    modes_Z=[[-1, 0], [0, -1]],
    NFP=4,
)

# Set to True to solve a fresh equilibrium from ``surf`` (runs a full
# continuation solve). When False, the ``eq`` already in the kernel is
# optimized instead -- i.e. the ATF example loaded in an earlier cell.
SOLVE_QH_FROM_SURF = False
if SOLVE_QH_FROM_SURF:
    eq = Equilibrium(M=5, N=5, Psi=0.04, surface=surf)
    eq = solve_continuation_automatic(eq, objective="force", bdry_step=0.5, verbose=3)[
        -1
    ]

eq1 = run_qh_step(0, eq)

Step 1
Spectral indexing: ansi
Spectral resolution (L,M,N)=(5,5,0)
Node resolution (L,M,N)=(10,10,0)
Boundary ratio = 0
Pressure ratio = 0
Current ratio = 1
Perturbation Order = 2
Objective: force
Optimizer: lsq-exact
Building objective: force
Precomputing transforms
Timer: Precomputing transforms = 3.01 sec
Timer: Objective build = 3.80 sec
Building objective: lcfs R
Building objective: lcfs Z
Building objective: fixed Psi
Building objective: fixed pressure
Building objective: fixed current
Building objective: fixed sheet current
Building objective: self_consistency R
Building objective: self_consistency Z
Building objective: lambda gauge
Building objective: axis R self consistency
Building objective: axis Z self consistency
Timer: Objective build = 2.35 sec
Timer: LinearConstraintProjection build = 10.3 sec
Number of parameters: 19
Number of objectives: 72
Timer: Initializing the optimization = 16.6 sec

Starting optimization
Using method: lsq-exact
   Iteration     Total nfev       

KeyboardInterrupt: 