In [1]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [2]:
import sys
import os

# Make the working directory (highest import priority) and its parent
# importable. Written to be idempotent: re-running this cell moves "." to the
# front instead of prepending another copy, and appends the parent only once.
_cwd = os.path.abspath(".")
if _cwd in sys.path:
    sys.path.remove(_cwd)
sys.path.insert(0, _cwd)
_parent = os.path.abspath("../")
if _parent not in sys.path:
    sys.path.append(_parent)
del _cwd, _parent

# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["XLA_FLAGS"] = (
#     "--xla_disable_hlo_passes=constant_folding "  # this disables constant folding
#     # "--xla_cpu_use_thunk_runtime=false "
# )
# Select the GPU as DESC's compute device (the backend info printed below
# confirms the NVIDIA GPU is in use).
# NOTE(review): presumably this must run before JAX initializes a backend,
# i.e. before the desc/jax imports in the following cells -- keep it here.
from desc import set_device
set_device("gpu")

In [3]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_explain_cache_misses", True)

In [4]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [5]:
import numpy as np
# Array printing: 4 decimals, fixed-point for small values, and never wrap
# lines or elide entries (note: large arrays will be printed in full).
np.set_printoptions(
    precision=4,
    suppress=True,
    linewidth=np.inf,
    threshold=sys.maxsize,
)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [6]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *
from desc.io import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *
from desc.particles import *
from diffrax import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

# Record the DESC/JAX versions and the active device in the cell output so
# readers can see which environment produced the results.
print_backend_info()

DESC version=0.16.0+29.g7aa703f28.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using device: NVIDIA GeForce RTX 4080 Laptop GPU (id=0), with 11.09 GB available memory.


In [7]:
# Load the DSHAPE example equilibrium and give it 10 field periods; the
# plasma grid in the objective cell picks up this NFP via eq.NFP.
eq = get("DSHAPE")
eq.change_resolution(NFP=10)
# Initial coil set: one modular coil, r_over_a=2.5 (coil size relative to the
# plasma minor radius), re-expressed as a FourierXY curve with N=10.
coils = initialize_modular_coils(
    eq,
    num_coils=1,
    r_over_a=2.5,
).to_FourierXY(N=10)

In [8]:
# 3D |B| on the plasma boundary with the initial (unoptimized) coil overlaid.
fig = plot_3d(eq, "|B|")
fig = plot_coils(coils, fig=fig)
fig

In [9]:
# Grid along the coil parameter on which every coil is sampled.
coil_grid = LinearGrid(N=50)
# Grid on the plasma boundary where the normal-field (B*n) error is evaluated;
# it reuses the equilibrium's field-period count and symmetry.
plasma_grid = LinearGrid(M=25, N=25, NFP=eq.NFP, sym=eq.sym)

# Relative weight of each term in the combined least-squares objective.
weights = {
    "quadratic flux": 200,
    "coil-coil min dist": 100,
    "plasma-coil min dist": 10,
    "coil curvature": 750,
    "coil length": 20,
}
obj = ObjectiveFunction(
    (
        # Normal-field error B*n on the plasma boundary due to the coils.
        QuadraticFlux(
            eq,
            field=coils,
            eval_grid=plasma_grid,
            field_grid=coil_grid,
            # vacuum=True skips the plasma's own contribution to B.
            # NOTE(review): the last cell of this notebook computes the plasma
            # current density J, so that contribution may not be zero for
            # DSHAPE -- confirm vacuum=True is intended.
            vacuum=True,
            weight=weights["quadratic flux"],
            bs_chunk_size=10,  # presumably chunks Biot-Savart to limit memory
        ),
        # Coil-coil separation of at least 10% of the minor radius.
        CoilSetMinDistance(
            coils,
            bounds=(0.1, np.inf),
            normalize_target=False,  # bounds are already in normalized units
            grid=coil_grid,
            weight=weights["coil-coil min dist"],
            dist_chunk_size=2,  # chunk the distance computation to cut peak memory
        ),
        # Plasma-coil separation of at least 25% of the minor radius.
        PlasmaCoilSetMinDistance(
            eq,
            coils,
            bounds=(0.25, np.inf),
            normalize_target=False,  # bounds are already in normalized units
            plasma_grid=plasma_grid,
            coil_grid=coil_grid,
            # The equilibrium is held fixed; a single-stage optimization
            # would set this to False.
            eq_fixed=True,
            weight=weights["plasma-coil min dist"],
            dist_chunk_size=2,  # chunk the distance computation to cut peak memory
        ),
        # Signed curvature: positive when the coil bends toward its centroid,
        # so a circle is positive. In normalized units ~1 means circular; the
        # upper bound of 2 allows somewhat more strongly shaped coils.
        CoilCurvature(
            coils,
            bounds=(-1, 2),
            normalize_target=False,  # bounds are already in normalized units
            grid=coil_grid,
            weight=weights["coil curvature"],
        ),
        # Upper bound on coil length.
        # NOTE(review): compute("length")["length"] is already a length, so the
        # extra 2*pi factor allows ~6.3x the initial length -- confirm intended.
        CoilLength(
            coils,
            bounds=(0, 2 * np.pi * (coils[0].compute("length")["length"])),
            normalize_target=True,  # the bound is in meters, so normalize it
            grid=coil_grid,
            weight=weights["coil length"],
        ),
    )
)
obj.build()

# Fix the current of the first coil; any remaining coils keep free currents.
# NOTE(review): presumably this stops the optimizer from shrinking B*n by
# driving every current to zero.
coil_indices_to_fix_current = [False for _ in coils]
coil_indices_to_fix_current[0] = True
constraints = (FixCoilCurrent(coils, indices=coil_indices_to_fix_current),)

Building objective: Quadratic flux
Precomputing transforms
Building objective: coil-coil minimum distance
Building objective: plasma-coil minimum distance
Building objective: coil curvature
Precomputing transforms
Building objective: coil length
Precomputing transforms


In [10]:
# Exact-Jacobian least-squares trust-region optimizer.
optimizer = Optimizer("lsq-exact")

# copy=True requests an optimized copy rather than updating `coils` in place,
# so the initial coil set stays available. The first return value is a tuple
# of the optimized things (just the coil set here).
# NOTE(review): binding `_` shadows IPython's last-output variable.
(optimized_coilset,), _ = optimizer.optimize(
    coils,
    objective=obj,
    constraints=constraints,
    ftol=1e-4,
    maxiter=100,
    copy=True,
    verbose=3,
)

Building objective: fixed coil current
Building objective: fixed shift
Building objective: fixed rotation
Timer: Objective build = 358 ms
Timer: LinearConstraintProjection build = 1.78 sec
Number of parameters: 46
Number of objectives: 1468
Timer: Initializing the optimization = 2.20 sec

Starting optimization
Using method: lsq-exact
Solver options:
------------------------------------------------------------
Maximum Function Evaluations       : 501
Maximum Allowed Total Δx Norm      : inf
Scaled Termination                 : True
Trust Region Method                : qr
Initial Trust Radius               : 3.746e+01
Maximum Trust Radius               : inf
Minimum Trust Radius               : 2.220e-16
Trust Radius Increase Ratio        : 2.000e+00
Trust Radius Decrease Ratio        : 2.500e-01
Trust Radius Increase Threshold    : 7.500e-01
Trust Radius Decrease Threshold    : 2.500e-01
------------------------------------------------------------ 

   Iteration     Total nfev        Co

In [11]:
# Residual normal field B*n from the optimized coils, drawn over the whole
# torus (NFP=1; endpoint=True presumably closes the surface without a seam),
# with the optimized coils overlaid.
plot_grid = LinearGrid(M=20, N=40, NFP=1, endpoint=True)
fig = plot_coils(
    optimized_coilset,
    fig=plot_3d(
        eq.surface,
        "B*n",
        field=optimized_coilset,
        field_grid=coil_grid,
        grid=plot_grid,
    ),
)
fig.show()

In [12]:
# Cartesian positions `x` and first derivatives `x_s` (dx1=True) of the
# optimized coils on a 1D coil grid; the next cell treats them as per-coil
# arrays of sample points -- TODO confirm exact shape.
# NOTE(review): `_compute_position` is a private DESC method and may change.
x, x_s = optimized_coilset._compute_position(
    grid=LinearGrid(N=30),
    basis="xyz",
    dx1=True,
)

In [13]:
# Flatten the per-coil samples into one (N, 3) point cloud of coil positions
# and one (N, 3) array of tangent vectors.
xyz = np.concatenate(x)
xyz_vec = np.concatenate(x_s)
# Normalize to unit tangents. Use NumPy rather than jnp so the whole cell stays
# on the host instead of round-tripping through the JAX device.
xyz_vec /= np.linalg.norm(xyz_vec, axis=1)[:, None]
# Each coil's current repeated once per sample point on that coil. np.repeat
# raises if the number of currents and coils disagree (zip would truncate).
Js = np.repeat(
    np.asarray(optimized_coilset._all_currents()),
    [xi.shape[0] for xi in x],
)
# Unit tangent scaled by the coil current -> current vector at each point.
xyz_vec *= Js[:, None]

# Column views of the current vectors and of the positions.
Jx_coil, Jy_coil, Jz_coil = xyz_vec.T
X_coil, Y_coil, Z_coil = xyz.T

# What you will need

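- `xyz` is an (N, 3) array of the X, Y, Z positions of points sampled along the coils.
- `xyz_vec` is an (N, 3) array of the X, Y, Z components of the coil current vector at each `xyz` point (the unit tangent scaled by that coil's current). `Jx_coil`, `Jy_coil`, `Jz_coil` are its columns; `X_coil`, `Y_coil`, `Z_coil` are the columns of `xyz`.
- The cell below computes the plasma current density `J` (components `Jx`, `Jy`, `Jz`) at the plasma points `X`, `Y`, `Z`.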

Once you have computed these arrays, you can save them to text files and load them from there on later runs instead of recomputing them:

```python
np.savetxt("a.txt", a)
a = np.loadtxt("a.txt")
```

In [14]:
# NOTE(review): this mutates `eq` in place. Earlier cells read eq.NFP (plasma
# grid) and eq.surface (B*n plot), so re-running them after this cell would
# presumably use NFP=1 instead of 10.
eq.change_resolution(NFP=1)
# Volume grid over the full torus: no symmetry reduction, one field period.
grid = LinearGrid(L=eq.L_grid, M=eq.M_grid, N=10, sym=False, NFP=1)
data = eq.compute(["J", "X", "Y", "Z"], grid=grid, basis="xyz")
# Cartesian components of the plasma current density, and the point positions.
Jx, Jy, Jz = data["J"].T
X, Y, Z = (data[name] for name in ("X", "Y", "Z"))