In [1]:
import sys
import os

sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../"))

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
from scipy.constants import mu_0
import numpy as np
from desc.backend import jnp
from desc.grid import QuadratureGrid, ConcentricGrid, LinearGrid
from desc.profiles import PowerSeriesProfile
import desc.io
from desc.compute._core import compute_rotational_transform
from desc.compute.utils import surface_averages, surface_integrals
from desc.compute import data_index
from desc.transform import Transform
from netCDF4 import Dataset
from desc.geometry import FourierRZCurve
from desc.equilibrium import Equilibrium
from desc.plotting import plot_1d

np.set_printoptions(precision=3, floatmode="fixed")
jnp.set_printoptions(precision=3, floatmode="fixed")

DESC version 0.5.1+74.g47683e6.dirty, using JAX backend, jax version=0.2.25, jaxlib version=0.1.76, dtype=float64
Using device: CPU, with 8.43 GB available memory


## Visualizing different grids

In [3]:
def print_grid(grid, stop=None, quantity=None):
    """
    Parameters
    ----------
    grid
        Prints nodes and spacing.
    stop : int
        Max number of nodes to print. Defaults to entire grid.
    quantity : ndarray
        A quantity to print alongside the grid. i.e. data["iota"]
    """
    assert (
        len(grid.nodes) // grid.num_zeta
        == jnp.where(~jnp.isclose(grid.nodes[:, 2], 0))[0][0]
    )
    print(grid.L, grid.M, grid.N, grid.NFP, grid.sym, grid.node_pattern)
    print(grid.num_rho, grid.num_theta, grid.num_zeta)
    print("nodes", "             ", "spacing")
    if stop is None:
        stop = len(grid.nodes)
    if quantity is None:
        for i, e in enumerate(zip(grid.nodes, grid.spacing)):
            a, b = e
            print(a, b)
            if i > stop:
                break
    else:
        for i, e in enumerate(zip(grid.nodes, grid.spacing, quantity)):
            a, b, c = e
            print(a, b, c)
            if i > stop:
                break


rng = np.random.default_rng()
L = rng.integers(low=1, high=100)
M = rng.integers(low=1, high=100)
N = rng.integers(low=1, high=100)
NFP = rng.integers(low=1, high=100)
sym = True if rng.integers(2) > 0 else False
random_grid = ConcentricGrid(L=L, N=N, M=M, NFP=NFP, sym=sym)
print_grid(random_grid, stop=10)

12 72 61 14 True jacobi
7 264 123
nodes               spacing
[0.056 0.349 0.000] [0.359 3.180 0.011]
[0.056 1.396 0.000] [0.359 3.180 0.011]
[0.056 2.443 0.000] [0.359 3.180 0.011]
[0.180 0.070 0.000] [0.450 0.636 0.011]
[0.180 0.279 0.000] [0.450 0.636 0.011]
[0.180 0.489 0.000] [0.450 0.636 0.011]
[0.180 0.698 0.000] [0.450 0.636 0.011]
[0.180 0.908 0.000] [0.450 0.636 0.011]
[0.180 1.117 0.000] [0.450 0.636 0.011]
[0.180 1.326 0.000] [0.450 0.636 0.011]
[0.180 1.536 0.000] [0.450 0.636 0.011]
[0.180 1.745 0.000] [0.450 0.636 0.011]


## Bulk flux surface averaging test
The tests pass. timeit also shows the _surface_sums no loop algorithm is faster.

In [4]:
heliotron = desc.io.load(
    "../../master-28d5b14/examples/DESC/" + "HELIOTRON" + "_output.h5"
)[-1]
random_integrand = np.random.random_sample(size=len(random_grid.nodes))

# print(flux_surface_function_integrand)
heliotron_grid = ConcentricGrid(
    L=heliotron.L_grid,
    M=heliotron.M_grid,
    N=heliotron.N_grid,
    NFP=heliotron.NFP,
    sym=heliotron.sym,
    node_pattern=heliotron.node_pattern,
)
flux_surface_function_integrand = np.asarray(
    heliotron.compute("p", grid=heliotron_grid)["p"]
)



In [5]:
# %%timeit

result_1 = np.empty(random_grid.num_rho)
# collect collocation node indices for each rho surface
surfaces = dict()
for index, rho in enumerate(random_grid.nodes[:, 0]):
    surfaces.setdefault(rho, list()).append(index)
# integration over non-contiguous elements
ds = random_grid.spacing[:, 1:].prod(axis=1)
for i, surface_indices in enumerate(surfaces.values()):
    result_1[i] = (ds * random_integrand)[surface_indices].sum()

In [6]:
# %%timeit

result_1_helio = np.empty(heliotron_grid.num_rho)
# collect collocation node indices for each rho surface
surfaces_helio = dict()
for index, rho in enumerate(heliotron_grid.nodes[:, 0]):
    surfaces_helio.setdefault(rho, list()).append(index)
# integration over non-contiguous elements
ds = (
    heliotron_grid.weights
    if heliotron_grid.num_rho == 1
    else heliotron_grid.spacing[:, 1:].prod(axis=1)
)  # for NFP bug
for i, surface_indices in enumerate(surfaces_helio.values()):
    result_1_helio[i] = (ds * flux_surface_function_integrand)[surface_indices].sum()

In [7]:
# %%timeit
result_2 = surface_integrals(random_grid, random_integrand)
result_2_helio = surface_integrals(heliotron_grid, flux_surface_function_integrand)

In [12]:
# must comment %%timeit to test assertion
print(result_1)
print(result_2)
assert jnp.allclose(result_1, result_2)
assert jnp.allclose(result_1_helio, result_2_helio)
sqrtg = heliotron.compute("sqrt(g)", grid=heliotron_grid)["sqrt(g)"]
assert jnp.allclose(
    flux_surface_function_integrand,
    surface_averages(
        heliotron_grid, flux_surface_function_integrand, sqrtg, match_grid=True
    ),
)
print_grid(heliotron_grid, stop=15, quantity=flux_surface_function_integrand)

[6.394 6.492 6.562 6.521 6.511 6.500 6.554]
[6.394 6.492 6.562 6.521 6.511 6.500 6.554]
36 18 6 19.0 1 jacobi
19 180 13
nodes               spacing
[0.009 2.094 0.000] [ 0.065 20.514  0.083] 17996.987305265866
[0.030 0.698 0.000] [0.088 6.838 0.083] 17966.641960604935
[0.030 2.793 0.000] [0.088 6.838 0.083] 17966.641960604935
[0.063 0.419 0.000] [0.125 4.103 0.083] 17856.02210540674
[0.063 1.676 0.000] [0.125 4.103 0.083] 17856.02210540674
[0.063 2.932 0.000] [0.125 4.103 0.083] 17856.02210540674
[0.107 0.299 0.000] [0.158 2.931 0.083] 17590.904428375157
[0.107 1.197 0.000] [0.158 2.931 0.083] 17590.904428375157
[0.107 2.094 0.000] [0.158 2.931 0.083] 17590.904428375157
[0.107 2.992 0.000] [0.158 2.931 0.083] 17590.904428375157
[0.160 0.233 0.000] [0.188 2.279 0.083] 17088.159331990475
[0.160 0.931 0.000] [0.188 2.279 0.083] 17088.159331990475
[0.160 1.629 0.000] [0.188 2.279 0.083] 17088.159331990475
[0.160 2.327 0.000] [0.188 2.279 0.083] 17088.159331990475
[0.160 3.025 0.000] [0.188

In [9]:
integrals_match_grid = surface_integrals(random_grid, random_integrand, match_grid=True)
for i, surface_indices in enumerate(surfaces.values()):
    for index in surface_indices:
        assert jnp.allclose(integrals_match_grid[index], result_1[i])
print_grid(random_grid, stop=15, quantity=integrals_match_grid)

12 72 61 14 True jacobi
7 264 123
nodes               spacing
[0.056 0.349 0.000] [0.359 3.180 0.011] 6.39439470385697
[0.056 1.396 0.000] [0.359 3.180 0.011] 6.39439470385697
[0.056 2.443 0.000] [0.359 3.180 0.011] 6.39439470385697
[0.180 0.070 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 0.279 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 0.489 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 0.698 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 0.908 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 1.117 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 1.326 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 1.536 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 1.745 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 1.955 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 2.164 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 2.374 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 2.583 0.000] [0.450 0.636 0.011] 6.491556777208184
[0.180 2.793 

## Axisymmetric, vacuum, no current test
Want to test if the returned rotational transform profile is 0 when toroidal current input is 0.
This should be a good test because the bulk of the computation lies on enforcing the zero toroidal current algorithm for the geometry of the device.
And when a non-zero toroidal current is specified we just add it to the numerator because that is the delta poloidal flux term.

In [10]:
def get_concentric_grid(eq):
    """
    Parameters
    ----------
    eq : Equilibrium
        The equilibrium.

    Returns
    -------
    ConcentricGrid
        Concentric grid used by the given equilibrium.
    """
    return ConcentricGrid(
        L=eq.L_grid,
        M=eq.M_grid,
        N=eq.N_grid,
        NFP=1,  # to avoid known grid bug
        sym=False,  # to avoid known grid bug
        node_pattern=eq.node_pattern,
    )


def get_linear_grid(eq, rho):
    """
    Parameters
    ----------
    eq : Equilibrium
        The equilibrium.
    rho: ndarray
        Grid rhos.

    Returns
    -------
    LinearGrid
        LinearGrid grid for the given equilibrium.
    """
    return LinearGrid(
        M=max(5, eq.M_grid),
        N=max(5, eq.N_grid),
        NFP=1,  # to avoid known grid bug
        sym=False,  # to avoid known grid bug
        rho=rho,
    )


def get_transform(eq, grid):
    """
    Parameters
    ----------
    eq : Equilibrium
        The equilibrium.
    grid : ConcentricGrid
        Concentric grid used by the given equilibrium.

    Returns
    -------
    Transform
         R_transform, Z_transform, L_transform of the given equilibrium.
    """
    r_derivs = data_index["sqrt(g)_rr"]["R_derivs"]
    l_derivs = jnp.vstack(
        (
            data_index["lambda_t"]["L_derivs"],
            data_index["lambda_rt"]["L_derivs"],
            data_index["lambda_rrt"]["L_derivs"],
            data_index["lambda_z"]["L_derivs"],
            data_index["lambda_rz"]["L_derivs"],
            data_index["lambda_rrz"]["L_derivs"],
        )
    )
    R_transform = Transform(grid, eq.R_basis, derivs=r_derivs, build=True)
    Z_transform = Transform(grid, eq.Z_basis, derivs=r_derivs, build=True)
    L_transform = Transform(grid, eq.L_basis, derivs=l_derivs, build=True)
    return R_transform, Z_transform, L_transform


def get_toroidal_current(eq, grid):
    """
    Parameters
    ----------
    eq : Equilibrium
        The equilibrium.
    grid : ConcentricGrid, LinearGrid
        The grid for the power series profile.

    Returns
    -------
    current : ndarray
        The toroidal current at each unique rho surface of the given equilibrium.
    """
    current = np.empty(grid.num_rho)
    for i, r in enumerate(grid.nodes[grid.unique_rho_indices, 0]):
        data = eq.compute("I", grid=LinearGrid(M=eq.M_grid, N=max(1, eq.N_grid), rho=r))
        current[i] = data["I"]
    # assert jnp.isfinite(current).all(), "compute_quasisymmetry_error produces nan values"
    return 2 * jnp.pi / mu_0 * current
    # return PowerSeriesProfile.from_values(x=rho, y=current, grid=grid)


def get_iota_data(eq, c_l, power_series, vmec_current=None):
    """
    Parameters
    ----------
    eq : Equilibrium
        The equilibrium.
    c_l : ndarray
        Coefficients of the current profile. (params of power_series).
    power_series : PowerSeriesProfile
        The power series to compute the toroidal current.
    vmec_current : ndarray
        uses vmec current as input to compute_rotational_transform if specified.
        Make sure power series has linear grid with rho = vmec plot rho.

    Returns
    -------
    data : dict
        The dictionary which contains at least iota.
    """
    current = None
    if c_l is None:
        current = (
            get_toroidal_current(eq, power_series.grid)
            if vmec_current is None
            else vmec_current
        )
    transform = get_transform(eq, power_series.grid)
    return compute_rotational_transform(
        eq.R_lmn,
        eq.Z_lmn,
        eq.L_lmn,
        eq.i_l,
        eq.c_l,
        eq.Psi,
        transform[0],
        transform[1],
        transform[2],
        iota=eq.iota,
        current=eq.current,
        toroidal_current_unique_rho=current,
    )

In [11]:
grid = get_concentric_grid(Equilibrium(N=5))
print(grid.nodes)
print_grid(grid)
torus = Equilibrium(
    current=PowerSeriesProfile(
        params=jnp.zeros(3), modes=jnp.arange(0, 6, step=2), grid=grid
    )
)
print(torus.current)
data = get_iota_data(torus, c_l=torus.c_l, power_series=torus.current)
print(data.keys())
assert jnp.allclose(data["iota"], 0)



[[0.355 0.000 0.000]
 [0.845 0.000 0.000]
 [0.845 1.257 0.000]
 [0.845 2.513 0.000]
 [0.845 3.770 0.000]
 [0.845 5.027 0.000]
 [0.355 0.000 0.299]
 [0.845 0.000 0.299]
 [0.845 1.257 0.299]
 [0.845 2.513 0.299]
 [0.845 3.770 0.299]
 [0.845 5.027 0.299]
 [0.355 0.000 0.598]
 [0.845 0.000 0.598]
 [0.845 1.257 0.598]
 [0.845 2.513 0.598]
 [0.845 3.770 0.598]
 [0.845 5.027 0.598]
 [0.355 0.000 0.898]
 [0.845 0.000 0.898]
 [0.845 1.257 0.898]
 [0.845 2.513 0.898]
 [0.845 3.770 0.898]
 [0.845 5.027 0.898]
 [0.355 0.000 1.197]
 [0.845 0.000 1.197]
 [0.845 1.257 1.197]
 [0.845 2.513 1.197]
 [0.845 3.770 1.197]
 [0.845 5.027 1.197]
 [0.355 0.000 1.496]
 [0.845 0.000 1.496]
 [0.845 1.257 1.496]
 [0.845 2.513 1.496]
 [0.845 3.770 1.496]
 [0.845 5.027 1.496]
 [0.355 0.000 1.795]
 [0.845 0.000 1.795]
 [0.845 1.257 1.795]
 [0.845 2.513 1.795]
 [0.845 3.770 1.795]
 [0.845 5.027 1.795]
 [0.355 0.000 2.094]
 [0.845 0.000 2.094]
 [0.845 1.257 2.094]
 [0.845 2.513 2.094]
 [0.845 3.770 2.094]
 [0.845 5.027

KeyError: 'iota'

## Compare to VMEC test
Want to test if the rotational transform profile from new compute function matches the rotational transform profile computed by VMEC.

In [None]:
def plot(x, y, prepend_title):
    """
    Parameters
    ----------
    x : ndarray
        plot x-axis.
    y : ndarray
        plot y-axis. will make multiple plots if y.ndim > 1
    prepend_title : str
        string to prepend to plot title
    """
    dot_size = 5 if len(x) > 64 else 10
    y = np.atleast_2d(y)
    fig, ax = plt.subplots(ncols=y.shape[0], figsize=(y.shape[0] * 8, 5))
    ax = np.atleast_1d(ax)

    for i in range(y.shape[0]):
        append = " " + str(i) + " derivative wrt rho"
        ax[i].scatter(x, y[i, :], s=dot_size)
        ax[i].plot(x, y[i, :])
        ax[i].set(
            xlabel=r"$\rho$",
            ylabel="iota" + append,
            yscale="symlog" if jnp.ptp(y[i, :]) > 1e3 else "linear",
            title=prepend_title + " iota" + append,
        )
        ax[i].grid()


def plot_overlay(x, y, prepend_title, x2, y2, prepend_title_2):
    """
    Parameters
    ----------
    x : ndarray
        plot x-axis.
    y : ndarray
        plot y-axis. will make multiple plots if y.ndim > 1
    prepend_title : str
        string to prepend to plot title
    """
    dot_size = 2 if len(x) > 128 else 5
    y = np.atleast_2d(y)
    y2 = np.atleast_2d(y2)
    fig, ax = plt.subplots(dpi=300)
    ax = np.atleast_1d(ax)

    for i in range(y.shape[0]):
        append = " " + str(i) + " derivative wrt rho"
        ax[i].scatter(x, y[i, :], s=dot_size, label=prepend_title)
        ax[i].plot(x, y[i, :])
        ax[i].scatter(x2, y2[i, :], s=dot_size, label=prepend_title_2)
        ax[i].plot(x2, y2[i, :])
        ax[i].set(
            xlabel=r"$\rho$",
            ylabel="iota" + append,
            yscale="symlog" if jnp.ptp(y[i, :]) > 1e3 else "linear",
            title="iota" + append,
        )
        ax[i].grid()
    fig.legend()


def get_desc_plot(
    eq,
    params,
    modes,
    ignore_params_use_eq_compute=False,
    vmec_rho=None,
    vmec_current=None,
):
    """
    Parameters
    ----------
    eq : Equilibrium
        The equilibrium.
    params : ndarray, None
        Coefficients of the current profile. (params of power_series).
    modes : ndarray
        Toroidal current power series profile modes.
        Should include only even modes to match VMEC AC input.
    ignore_params_use_eq_compute: bool
        The current is computed from eq.compute("I") rather than the power series.

    Returns
    -------
    rho : ndarray
        unique rho values. x-axis of plot
    iotas : ndarray
        iota, iota_r, iota_rr at the unique rho values. y-axis of plot
    """
    grid = get_linear_grid(eq, vmec_rho)
    data = get_iota_data(
        eq,
        I_l=None if ignore_params_use_eq_compute else params,
        power_series=PowerSeriesProfile(params=params, modes=modes, grid=grid),
        vmec_current=vmec_current,
    )
    assert len(grid.nodes) == len(data["iota"])
    rho = grid.nodes[grid.unique_rho_indices, 0]
    iota = data["iota"][grid.unique_rho_indices]
    # iota_r = data["iota_r"][grid.unique_rho_indices]
    # iota_rr = data["iota_rr"][grid.unique_rho_indices]
    print()
    return rho, iota  # iota_r, iota_rr])


def get_vmec_plot(name, return_iota=True):
    """
    Parameters
    ----------
    name : str
        Name of the equilibrium.
    return_iota: bool
        Returns iota for the y value if True, else current.

    Returns
    -------
    rho : ndarray
        rho values. x-axis of plot.
    iota : ndarray
        iota. y-axis of plot.
    """x

    f = Dataset("edu-vmec/input-iota/wout_" + name + ".nc")
    current = 2 * jnp.pi / mu_0 * jnp.array(f.variables["buco"])
    iota = jnp.asarray(f.variables["iotaf"])
    rho = jnp.sqrt(f.variables["phi"] / np.array(f.variables["phi"])[-1])
    return rho, iota if return_iota else current

In [None]:
# values are toroidal current profiles (AC input for VMEC)
stellarators = {
    "ATF": jnp.array([1, -1, 0]),
    "DSHAPE": jnp.array([2, -3, 0]),
    "SOLOVEV": jnp.array([4, -2, -1]),
    "HELIOTRON": jnp.array([2, -2, -1]),
    "AXISYM": jnp.array([1, -1, 0]),
}

In [None]:
# surface = FourierRZCurve(R_n=[1e1, -1e0, -3e-1, 3e-1],
#                          Z_n=[1e0, -3e-1, -3e-1],
#                          modes_R=[[0, 0], [1,0], [1, 1], [-1, -1]], # modes given as [m,n] for each coefficient
#                          modes_Z=[[-1, 0], [-1, 1], [1, -1]],
#                          NFP=1,
#                          sym=False,
#                         )
# pressure = PowerSeriesProfile(params=[1.8e4, -3.6e4, 1.8e4], modes=[0,2,4])
# current = PowerSeriesProfile(params=[2, -2, -1], modes=[0,2,4])
#
# heliotron = Equilibrium(Psi=1, NFP=1, sym=False, L=24, M=12, N=3, M_grid=18, N_grid=6,node_pattern="jacobi", pressure=pressure, current=current, spectral_indexing="fringe", objective="force", optimizer="lsq-exact")
# print(heliotron.compute("iota")["iota"])
# heliotron.solve()

In [None]:
# currents = dict()
# eqs = dict()
# for name in stellarators.keys():
#      rho, vmec_current = get_vmec_plot(name, return_iota=False)
#      eq = desc.io.load("../../master-28d5b14/examples/DESC/" + name + "_output.h5")[-1]
#      eqs[name] = eq
#      desc_current = get_toroidal_current(eq, LinearGrid(M=eq.M_grid,
#                                              N=max(1, eq.N_grid),
#                                              NFP=1,  # to avoid known grid bug
#                                              sym=False,  # to avoid known grid bug
#                                              rho=rho,
#                                             ))
#      currents[name] = rho, vmec_current, desc_current

In [None]:
# for name in stellarators.keys():
#     fig, ax = plt.subplots(dpi=300)
#     rho = currents[name][0]
#
#     ax.scatter(rho, currents[name][1], s=1)
#     ax.scatter(rho, currents[name][2], s=1)
#     ax.plot(rho, currents[name][1], label="VMEC")
#     ax.plot(rho, currents[name][2], label="DESC")
#     ax.plot(rho, -1 * currents[name][1], label="VMEC * -1")
#     ax.set(xlabel=r"$\rho$", ylabel="Toroidal Current (amperes)", title=name + " Toroidal Current")
#     ax.grid()
#     fig.legend()
#     fig.savefig(name + " current compare plot.png", facecolor='white')

## Note
VMEC says its HELIOTRON and AXISYM inputs may have convergence issues. Suggestion was to decrease DELT (step size). However, the error message remained.
DESC also says its HELIOTRON and AXISYM inputs exceed max function evaluations. This can be fixed by increasing `nfev` in the input files for HELIOTRON.
I did not change this to keep the input consistent with VMEC.

In [None]:
# plot DESC output
for name, params in stellarators.items():
    vmec_rho, vmec_current = get_vmec_plot(name, return_iota=False)
    _, vmec_iota = get_vmec_plot(name)
    assert jnp.allclose(_, vmec_rho)
    eq = desc.io.load("../../master-28d5b14/examples/DESC/" + name + "_output.h5")[-1]
    # modes are even to match VMEC psi modes
    rho, iota = get_desc_plot(
        eq,
        params,
        jnp.array([0, 2, 4]),
        ignore_params_use_eq_compute=True,
        vmec_rho=vmec_rho,
        vmec_current=-1 * vmec_current * mu_0 / 2 / jnp.pi,
    )
    scale = iota[-1]
    iota = iota / scale

    plot_overlay(
        rho,
        iota,
        "DESC " + name + " / (iota[-1]=" + str(scale) + ")",
        vmec_rho,
        vmec_iota,
        "VMEC " + name,
    )

In [None]:
# plot VMEC output
for name in stellarators.keys():
    rho, iota = get_vmec_plot(name)
    plot(rho, iota, "VMEC " + name)