In [None]:
# Import desc.plotting before the other desc modules; otherwise the Jupyter kernel crashes.

from desc.plotting import plot_section
from desc.plotting import (
    plot_grid,
    plot_boozer_modes,
    plot_boozer_surface,
    plot_qs_error,
    plot_boundaries,
    plot_boundary,
)
import desc
import matplotlib.pyplot as plt


from desc.transform import Transform
from desc.equilibrium import Equilibrium
from desc.basis import ChebyshevZernikeBasis, chebyshev_z
from desc.compute import compute
from desc.grid import LinearGrid, ConcentricGrid, QuadratureGrid, Grid
from desc.compute.utils import get_transforms
from desc.objectives import (
    FixEndCapLambda,
    FixEndCapR,
    FixEndCapZ,
    FixBoundaryR,
    FixBoundaryZ,
    FixPsi,
    FixPressure,
    FixIota,
    FixAnisotropy,
    ForceBalance,
    ForceBalanceAnisotropic,
    ObjectiveFunction,
    CurrentDensity,
    MatchEndCapR,
    MatchEndCapZ,
    MatchEndCapLambda,
)
import numpy as np
from desc.optimize import Optimizer
from desc.profiles import PowerSeriesProfile
from desc.geometry import FourierRZToroidalSurface
from scipy.constants import mu_0
from mayavi import mlab


def chebygrid(N_grid):
    """Return ``N_grid`` Chebyshev-Gauss nodes mapped onto [0, 2*pi], plus both endpoints.

    Interior node k (k = 0 .. N_grid - 1) is ``pi * (1 - cos((2k + 1) * pi / (2 * N_grid)))``.
    The endpoints ``0`` and ``2*pi`` are prepended and appended, so the result has
    ``N_grid + 2`` entries in increasing order.
    """
    k = np.arange(N_grid)
    interior = (-np.cos((2 * k + 1) * np.pi / (2 * N_grid)) + 1) * np.pi
    return np.concatenate(([0], interior, [2 * np.pi]))


def grid_gen(L_grid, M_grid, N_grid, node_pattern="jacobi"):
    """Build a tensor-product Grid: concentric (rho, theta) nodes x Chebyshev zeta nodes.

    Parameters
    ----------
    L_grid, M_grid : int
        Radial / poloidal resolution passed to ``ConcentricGrid`` (with ``N=0``).
    N_grid : int
        Number of interior Chebyshev nodes in zeta (see ``chebygrid``); the zeta
        axis gets ``N_grid + 2`` values including 0 and 2*pi.
    node_pattern : str
        Node pattern forwarded to ``ConcentricGrid``.

    Returns
    -------
    Grid
        Nodes ordered zeta-major: the full (rho, theta) set is repeated once per
        zeta value.
    """
    rho_theta = ConcentricGrid(
        L=L_grid, M=M_grid, N=0, node_pattern=node_pattern
    ).nodes[:, :2]
    zeta_values = chebygrid(N_grid)
    n_rho_theta = rho_theta.shape[0]

    nodes = np.column_stack(
        (
            np.tile(rho_theta, (zeta_values.size, 1)),
            np.repeat(zeta_values, n_rho_theta),
        )
    )

    # RG: weights and spacing are placeholders so that the Grid can be built;
    # they have NOT been validated for quadrature. TODO: check before integrating.
    weights = np.ones(nodes.shape[0])
    spacing = np.ones_like(nodes)
    # Differences along the flattened node list (row 0 keeps the placeholder 1.0).
    spacing[1:, 1] = np.diff(nodes[:, 1])
    spacing[1:, 2] = np.diff(nodes[:, 2])

    return Grid(nodes, spacing=spacing, weights=weights)


def get_lm_mode(basis, coeff, zeta, L, M, func_zeta=chebyshev_z):
    """Sum the zeta dependence of a single (l, m) = (L, M) mode.

    For every basis mode (l, m, n) with l == L and m == M, this adds
    ``func_zeta(zeta, n) * coeff[i]``, where ``i`` is that mode's index in
    ``basis.modes``.

    Returns
    -------
    The accumulated sum, or ``0`` when no mode matches.
    """
    total = 0
    for idx, (l_idx, m_idx, n_idx) in enumerate(basis.modes):
        if not (l_idx == L and m_idx == M):
            continue
        total += func_zeta(zeta, n_idx) * coeff[idx]
    return total


# NOTE: byte-identical second copies of ``chebygrid``, ``grid_gen`` and
# ``get_lm_mode`` used to be defined here. They only shadowed the definitions
# above and are removed so that each helper has a single source of truth.


# Larger fonts for every figure produced by this notebook.
_FONT_SIZES = {
    "font.size": 16,  # default text size
    "axes.titlesize": 18,  # subplot titles
    "axes.labelsize": 16,  # x/y axis labels
    "xtick.labelsize": 14,  # tick labels
    "ytick.labelsize": 14,
    "legend.fontsize": 14,
}
plt.rcParams.update(_FONT_SIZES)


## test
# Analytic test boundary with mirror=True. NOTE: ``surf`` is overwritten by the
# ``FourierRZToroidalSurface.load`` call in the next cell, so this definition is
# only used if that cell is skipped.
surf = FourierRZToroidalSurface(
    R_lmn=[10, 1, -0.5],
    modes_R=[[0, 0], [1, 0], [1, 2]],
    Z_lmn=[0, -1, 0.5],
    modes_Z=[[0, 0], [-1, 0], [-1, 2]],
    NFP=1,
    sym=False,
    mirror=True,
)

# Rotational transform profile with all coefficients zero.
iota = PowerSeriesProfile(params=[0.0, 0])

# Isotropic pressure profile. NOTE(review): ``p`` is not passed to the Equilibrium
# built later (it uses ``p_perp``); confirm whether it is still needed.
p = PowerSeriesProfile(params=[0.001 / mu_0, -0.001 / mu_0], modes=[0, 2])
# Presumably overwrites the rho^0 and rho^2 coefficients with -1e-3 and +1e-3.
# This flips the signs of the constructor values above. TODO confirm intent.
p.set_params(0, a=-1e-3)
p.set_params(2, a=1e-3)

In [None]:
# Boundary surface to solve for. The file names suggest a low (~2.6) vs a high
# (~11.1) mirror ratio. Swap the commented path in to use the other case.
SURFACE_PATH = "../Isotropic/SFLM_lowMR_2p6.hdf5"
# SURFACE_PATH = "../Isotropic/SFLM_highMR_11p1.hdf5"
surf = FourierRZToroidalSurface.load(SURFACE_PATH)


In [None]:
from desc.profiles import PerpendicularPressureProfile, ParallelPressureProfile

In [None]:
from desc.objectives._equilibrium import (
    ForceBalanceAnisotropic_3D,
    B_on_Axis,
    ForceBalanceAnisotropic_3D,
    ForceBalanceAnisotropic_3D_explicit,
)

In [None]:
# One tuple per species. Field meanings are defined by PerpendicularPressureProfile.
# NOTE(review): document what (1, 50.0, 100) stands for; it is not visible here.
species_1 = (1, 50.0, 100)
coeffs = [species_1]

In [None]:
# Shape of the perpendicular pressure: coefficients 1e2 (rho^0) and -1e2 (rho^2),
# presumably 100 * (1 - rho^2) for a power series in rho. Confirm against the DESC docs.
psi_p_perp_profile = PowerSeriesProfile(params=[1e2, -1e2], modes=[0, 2])

In [None]:
# Anisotropic pressure: p_perp is built from the psi profile(s) and the species
# coefficients; p_parallel is built on top of p_perp via ``d_coeffs``.
# (An earlier trial with d_coeffs=[-500] was immediately overwritten, so only the
# active value is constructed now.)
p_perp = PerpendicularPressureProfile(
    psi_profiles=[psi_p_perp_profile], coeffs=coeffs
)
p_para = ParallelPressureProfile(p_perp, d_coeffs=[-50])

In [None]:
# Quick look at the boundary: sample it on a coarse grid and plot the nodes
# projected onto the X-Y plane (a 2D projection, not a 3D view).
grid = grid_gen(4, 10, 40)
data = surf.compute(["X", "Y", "Z"], grid=grid)

x = data["X"]
y = data["Y"]
z = data["Z"]

plt.figure()
plt.plot(y, x, ".")
plt.xlabel("Y")
plt.ylabel("X")
plt.title("Boundary nodes projected onto the X-Y plane");

In [None]:
# Give the Equilibrium its own copy of the boundary rather than ``surf`` itself
# (assuming ``copy`` is a deep copy, ``surf`` stays untouched by the solve).
surf_init = surf.copy()

In [None]:
# Non-symmetric mirror equilibrium with anisotropic pressure.
# L / M / N are the radial / poloidal / zeta spectral resolutions.
eq = Equilibrium(
    surface=surf_init,
    mirror=True,
    sym=False,
    L=5,
    M=2,
    N=8,
    pressure=p_perp,
    anisotropy=p_para,
    iota=iota,
)

In [None]:
# The pressure profiles need a linked Equilibrium when they are called directly
# (outside eq.compute), so attach eq to both of them.
for _profile in (p_perp, p_para):
    _profile.set_equilibrium(eq)

In [None]:
# ---- Line along the axis: rho = surface (0), theta = 0, zeta in [0, 2pi] ----
surface = 0
zeta_axis = np.linspace(0, 2 * np.pi, 100)
theta_axis = np.zeros_like(zeta_axis)
rho_axis = np.full_like(zeta_axis, surface)
nodes = np.column_stack((rho_axis, theta_axis, zeta_axis))
grid_axis = Grid(
    nodes,
    spacing=np.ones_like(nodes),
    weights=np.ones_like(zeta_axis),
)

In [None]:
# Normalize the field so that max |B| on axis is 1. This assumes |B| scales
# linearly with Psi at fixed geometry, so Psi is rescaled by the current maximum.
# Dividing the *current* Psi (rather than assigning 1 / max|B|) is correct for any
# starting Psi and keeps the cell idempotent: re-running it leaves Psi unchanged.
data_axis = eq.compute(["|B|"], grid=grid_axis)
B_axis_t = data_axis["|B|"]
eq.Psi = eq.Psi / np.max(B_axis_t)

data_axis = eq.compute(["|B|"], grid=grid_axis)
B_axis = data_axis["|B|"]
plt.figure()
# Plot against the grid's own zeta nodes so the x-axis always matches the data.
plt.plot(grid_axis.nodes[:, 2], B_axis)
plt.xlabel(r"$\zeta$")
plt.ylabel(r"$|B|$ on axis")
plt.title("MR = {}".format(np.max(B_axis) / np.min(B_axis)))

In [None]:
# Pressure components along the axis line (rho = surface, theta = 0), plotted
# against zeta instead of the array index so that the x-label is correct.
p_eq = eq.compute(["p_perp", "p_parallel", "|B|"], grid=grid_axis)
zeta_nodes = grid_axis.nodes[:, 2]

plt.figure()
plt.plot(zeta_nodes, p_eq["p_perp"], label="eq.compute['p_perp']")
plt.plot(zeta_nodes, p_eq["p_parallel"], label=r"eq.compute['p_parallel']")

plt.xlabel("zeta")
plt.ylabel("p")
plt.legend()
plt.title(r"$\rho$ = {}, $\theta$ = 0 ".format(surface))

In [None]:
# ---- Radial line at the midplane: rho in [0, 1], theta = 0, zeta = pi ----
rho_axis = np.linspace(0, 1, 100)
theta_axis = np.zeros_like(rho_axis)
zeta_axis = np.full_like(rho_axis, np.pi)
nodes = np.column_stack((rho_axis, theta_axis, zeta_axis))
grid = Grid(nodes, spacing=np.ones_like(nodes), weights=np.ones_like(zeta_axis))

In [None]:
# Radial profiles of the pressure components along the radial line built above.
# The x-axis is rho (not zeta), and the title reports the line's fixed zeta.
p_eq = eq.compute(["p_perp", "p_parallel", "|B|"], grid=grid)
rho_nodes = grid.nodes[:, 0]
zeta_fixed = float(grid.nodes[0, 2])

plt.figure()
plt.plot(rho_nodes, p_eq["p_perp"], label="eq.compute['p_perp']")
plt.plot(rho_nodes, p_eq["p_parallel"], label=r"eq.compute['p_parallel']")

plt.xlabel(r"$\rho$")
plt.ylabel("p")
plt.legend()
plt.title(r"$\zeta$ = {:.3f}, $\theta$ = 0".format(zeta_fixed))

In [None]:
# Fixed-boundary anisotropic solve: the boundary (R, Z), pressure, iota, Psi and
# anisotropy are held fixed by the constraints below.
# Alternative constraints tried previously (kept for reference):
#   FixEndCapLambda / FixEndCapR / FixEndCapZ (at zeta = 0 and 2*np.pi),
#   MatchEndCapR / MatchEndCapZ / MatchEndCapLambda,
#   MatchMidplaneLambda / MatchMidplaneR / MatchMidplaneZ
constraints = (
    FixBoundaryR(eq=eq),
    FixBoundaryZ(eq=eq),
    FixPressure(eq=eq),
    FixIota(eq=eq),
    FixPsi(eq=eq),
    FixAnisotropy(eq=eq),
)
optimizer = Optimizer("lsq-exact")
grid = grid_gen(20, 20, 20)

objective1 = ForceBalanceAnisotropic_3D_explicit(eq=eq, grid=grid)
# Optional on-axis field objective. It was constructed but never included in the
# solve, so it is left here only as a recipe:
# objective3 = B_on_Axis(eq=eq, Nz=100, monotonic=True, MR=True, MR_target=11, weight=1)
objectives = [objective1]
obj = ObjectiveFunction(objectives=objectives)

eq.solve(
    objective=obj,
    constraints=constraints,
    optimizer=optimizer,
    ftol=1e-6,
    xtol=1e-16,
    gtol=1e-16,
    maxiter=25,
    verbose=3,
)

In [None]:
# Normalized force-balance error: <|F|>_vol / <|grad(|B|^2)|/2mu0>_vol.
force_avg = eq.compute("<|F|>_vol")["<|F|>_vol"]
magnetic_pressure_grad_avg = eq.compute("<|grad(|B|^2)|/2mu0>_vol")[
    "<|grad(|B|^2)|/2mu0>_vol"
]
f0 = force_avg / magnetic_pressure_grad_avg

print(f"Force error after solve: {f0:.4e}")

In [None]:
# Inspect the anisotropy profile parameters after the solve.
eq.anisotropy.params

In [None]:
from desc.plotting import *

In [None]:
# Cross-sections of the normalized force error |F| on a log scale.
plot_section(eq, "|F|", norm_F=True, log=True)

In [None]:
### plotting tools

In [None]:
### define a grid

In [None]:
# Coarse tensor-product grid for the diagnostics below
# (L_grid=2, M_grid=10, N_grid=40 interior Chebyshev nodes in zeta).
grid = grid_gen(2, 10, 40)

In [None]:
# Geometry, field, pressure, force and current-density diagnostics on the coarse grid.
diagnostic_names = [
    "X", "Y", "Z",
    "|B|",
    "<beta>_vol",
    "p_parallel",
    "|F|",
    "p_perp",
    "grad(p)",
    "J", "|J|", "J_rho", "J_theta", "J_zeta", "J*B", "J_parallel",
]
data = eq.compute(diagnostic_names, grid=grid)

In [None]:
# ---- Axis line (rho = 0, theta = 0, zeta from 0 to 2pi) ----
Nz = 100
zeta_axis = np.linspace(0, 2 * np.pi, Nz)
rho_axis = np.zeros_like(zeta_axis)
theta_axis = np.zeros_like(zeta_axis)
grid_axis = Grid(np.stack([rho_axis, theta_axis, zeta_axis], axis=1))

# |B| and pressures on axis; "Y" is kept for plotting against axial position.
data_axis = eq.compute(["|B|", "p_perp", "p_parallel", "Y"], grid=grid_axis)
B_axis = data_axis["|B|"]
p_perp_axis = data_axis["p_perp"]
p_par_axis = data_axis["p_parallel"]

plt.figure()
plt.plot(zeta_axis, B_axis, label=r"$|B|$")
plt.xlabel(r"$\zeta$")
plt.ylabel(r"$|B|$ on axis")
plt.title("MR = {:.2f}".format(np.max(B_axis) / np.min(B_axis)))
plt.legend();

In [None]:
# On-axis pressure components against the Cartesian Y coordinate of the axis nodes.
plt.figure()
plt.plot(data_axis["Y"], p_perp_axis, label=r"$p_\perp$")
plt.plot(data_axis["Y"], p_par_axis, label=r"$p_\parallel$")
plt.xlabel(r"$Y$")
plt.ylabel("pressure (Pa)")
plt.legend();

In [None]:
### LCFS metrics: evaluate quantities on an ordered (theta, zeta) mesh of points on the last closed flux surface

In [None]:
# %% Field lines on the LCFS (rho = rho_surf): sample a (theta, zeta) mesh in
# straight-field-line coordinates, map it to DESC coordinates, then evaluate
# geometry and |B| at those points.
Nz = 200
Nt = 200
rho_surf = 1

theta = np.linspace(0, 2 * np.pi, Nt)
zeta = np.linspace(0, 2 * np.pi, Nz, endpoint=True)

# "ij" indexing gives arrays of shape (Nt, Nz): axis 0 = theta, axis 1 = zeta.
theta_mesh, zeta_mesh = np.meshgrid(theta, zeta, indexing="ij")
shape = theta_mesh.shape
theta_flat = theta_mesh.ravel()
coords_sfl = np.column_stack(
    (np.full_like(theta_flat, rho_surf), theta_flat, zeta_mesh.ravel())
)

coords_geo = eq.compute_theta_coords(coords_sfl)
lcfs_grid = Grid(nodes=coords_geo, sort=False)
coords_data = eq.compute(["R", "Z", "zeta", "p", "|B|"], grid=lcfs_grid)

# (3, Nt, Nz) -> (Nt, 3, Nz): one (R, zeta, Z) polyline per theta value.
lines = np.array([coords_data["R"], coords_data["zeta"], coords_data["Z"]])
lines = np.swapaxes(lines.reshape(3, *shape), 0, 1)

B_on_lines = coords_data["|B|"].reshape(shape)

In [None]:
from scipy.constants import mu_0

# Diagnostics on the LCFS (theta, zeta) mesh from the field-line cell above
# (``coords_geo`` = DESC coordinates of that mesh), plus axial and radial cuts.
data_mesh = eq.compute(
    [
        "p_parallel",
        "theta",
        "X",
        "Y",
        "Z",
        "|B|",
        "B",
        "p_perp",
        "J",
        "e^theta",
        "e^rho",
        "e^zeta",
        "grad(alpha)",
    ],
    grid=Grid(nodes=coords_geo, sort=False),
)
p_par_mesh = data_mesh["p_parallel"]

### axial pressure profile
mu0 = 4 * np.pi * 1e-7  # NOTE: not used in this cell; the scipy constant ``mu_0`` is

# ---- Axis line (rho = 0, theta = 0, zeta from 0 to 2pi) ----
zeta_axis = np.linspace(0, 2 * np.pi, 100)
rho_axis = np.zeros_like(zeta_axis)
theta_axis = np.zeros_like(zeta_axis)
grid_axis = Grid(np.stack([rho_axis, theta_axis, zeta_axis], axis=1))

data_axis = eq.compute(["|B|", "p_perp", "p_parallel"], grid=grid_axis)
B_axis = data_axis["|B|"]
p_perp_axis = data_axis["p_perp"]
p_par_axis = data_axis["p_parallel"]

# ---- Midplane line (zeta = pi, theta = 0, rho from 0 to 1) ----
rho_mid = np.linspace(0, 1, 100)
theta_mid = np.zeros_like(rho_mid)
zeta_mid = np.ones_like(rho_mid) * np.pi
grid_mid = Grid(np.stack([rho_mid, theta_mid, zeta_mid], axis=1))

data_mid = eq.compute(["p_perp", "|B|", "p_parallel"], grid=grid_mid)
B_mid = data_mid["|B|"]
p_perp_mid = data_mid["p_perp"]
p_par_mid = data_mid["p_parallel"]


def _unit(vectors):
    """Normalize an (..., 3) array of vectors along its last axis."""
    return vectors / np.linalg.norm(vectors, axis=-1)[..., None]


# The LCFS nodes come from np.meshgrid(theta, zeta, indexing="ij"), so flattened
# quantities reshape to (Nt, Nz): axis 0 = theta, axis 1 = zeta. The previous
# (Nz, Nt) reshape only produced the right layout by accident while Nz == Nt.
mesh_shape = (Nt, Nz)
vec_shape = mesh_shape + (3,)

X2d = np.reshape(data_mesh["X"], mesh_shape)
Y2d = np.reshape(data_mesh["Y"], mesh_shape)
Z2d = np.reshape(data_mesh["Z"], mesh_shape)
B2d = np.reshape(data_mesh["|B|"], mesh_shape)
J2d = np.reshape(data_mesh["J"], vec_shape)
theta2d = np.reshape(data_mesh["theta"], mesh_shape)

etheta = np.reshape(data_mesh["e^theta"], vec_shape)
ezeta = np.reshape(data_mesh["e^zeta"], vec_shape)
erho = np.reshape(data_mesh["e^rho"], vec_shape)

e_theta_n = _unit(etheta)
e_zeta_n = _unit(ezeta)
e_rho_n = _unit(erho)

J2d_n = _unit(J2d)

# Cosine of the angle between J and each of e^theta, e^zeta, e^rho.
comp_theta = np.sum(J2d_n * e_theta_n, axis=-1)
comp_zeta = np.sum(J2d_n * e_zeta_n, axis=-1)
comp_rho = np.sum(J2d_n * e_rho_n, axis=-1)

Bv = np.reshape(data_mesh["B"], vec_shape)
B2d_n = _unit(Bv)

# Cosine of the angle between J and B (1 means the current is parallel to B).
JdotB = np.sum(J2d_n * B2d_n, axis=-1)
JdotB2d_n = JdotB

# To also project onto grad(p), add "grad(p)" to the compute list above and
# reuse ``_unit`` / np.sum(..., axis=-1) as for the other vectors.

gradalpha = np.reshape(data_mesh["grad(alpha)"], vec_shape)
gradalpha_n = _unit(gradalpha)

Jdotgradalpha_n = np.sum(J2d_n * gradalpha_n, axis=-1)

# beta = 2 mu0 (p_perp + p_parallel) / |B|^2; the 1e-20 guards against |B| = 0.
# NOTE(review): this uses the sum p_perp + p_parallel rather than an average;
# confirm the intended beta definition for anisotropic pressure.
beta = 2 * (data_mesh["p_perp"] + p_par_mesh) * mu_0 / (data_mesh["|B|"] ** 2 + 1e-20)
beta_axis = (
    2 * (data_axis["p_perp"] + p_par_axis) * mu_0 / (data_axis["|B|"] ** 2 + 1e-20)
)
beta_midplane = (
    2 * (data_mid["p_perp"] + p_par_mid) * mu_0 / (data_mid["|B|"] ** 2 + 1e-20)
)

beta2d = np.reshape(beta, mesh_shape)

In [None]:
fig, axs = plt.subplots(3, 3, figsize=(16, 11))

nlevels = 500
cmap = "coolwarm"

# --- Row 0: LCFS field lines coloured by |B|, then axial and radial pressure ---
# Every 10th theta = const line; ``lines`` and ``B_on_lines`` share the first index.
line_stride = 10
for i in range(0, len(lines), line_stride):
    line = lines[i]
    fieldlines = axs[0, 0].scatter(line[1], line[2], c=B_on_lines[i], s=0.2)
fig.colorbar(fieldlines, ax=axs[0, 0], label=r"|B|")

axs[0, 1].plot(zeta_axis, p_par_axis, label=r"$p_{\parallel}$")
axs[0, 1].plot(zeta_axis, p_perp_axis, label=r"$p_{\perp}$")
axs[0, 1].legend()

axs[0, 2].plot(rho_mid, p_par_mid, label=r"$p_{\parallel}$")
axs[0, 2].plot(rho_mid, p_perp_mid, label=r"$p_{\perp}$")
axs[0, 2].legend()

# --- Row 1: beta on the LCFS, on axis and along the midplane ---
im = axs[1, 0].contour(Y2d, theta2d, beta2d, levels=nlevels, cmap=cmap)
fig.colorbar(im, ax=axs[1, 0], label=r"$\beta$")

axs[1, 1].plot(zeta_axis, beta_axis)
axs[1, 2].plot(rho_mid, beta_midplane)

# --- Row 2: direction of the current density on the LCFS (normalized vectors) ---
im = axs[2, 1].contour(Y2d, theta2d, comp_rho, levels=nlevels, cmap=cmap)
fig.colorbar(im, ax=axs[2, 1])

im = axs[2, 0].contour(Y2d, theta2d, Jdotgradalpha_n, levels=nlevels, cmap=cmap)
fig.colorbar(im, ax=axs[2, 0])

im = axs[2, 2].contour(Y2d, theta2d, JdotB2d_n, levels=nlevels, cmap=cmap)
fig.colorbar(im, ax=axs[2, 2])

# titles
axs[0, 0].set_title(r"$\rho = {}$ Field lines".format(rho_surf))
axs[0, 1].set_title("axial profile")
axs[0, 2].set_title("radial profile (midplane)")

axs[1, 0].set_title(r"$\beta$ at $\rho = {}$ ".format(rho_surf))
axs[1, 1].set_title(r"axial $\beta$")
axs[1, 2].set_title(r"radial $\beta$")

axs[2, 2].set_title(r"$\hat J \cdot \hat b$")
axs[2, 0].set_title(r"$\hat J \cdot \widehat{\nabla \alpha}$")
axs[2, 1].set_title(r"$\hat J \cdot \hat e^\rho$")

# axis labels (every panel gets both)
axs[0, 0].set_xlabel(r"$\zeta$")
axs[0, 0].set_ylabel(r"$Z$")
axs[0, 1].set_xlabel(r"$\zeta$")
axs[0, 1].set_ylabel("pressure (Pa)")
axs[0, 2].set_xlabel(r"$\rho$")
axs[0, 2].set_ylabel("pressure (Pa)")

axs[1, 0].set_xlabel(r"$Y$")
axs[1, 0].set_ylabel(r"$\theta$")
axs[1, 1].set_xlabel(r"$\zeta$")
axs[1, 1].set_ylabel(r"$\beta$")
axs[1, 2].set_xlabel(r"$\rho$")
axs[1, 2].set_ylabel(r"$\beta$")

for ax_ in axs[2]:
    ax_.set_xlabel(r"$Y$")
    ax_.set_ylabel(r"$\theta$")

for ax_ in (axs[0, 0], axs[1, 1], axs[1, 2], *axs[2]):
    ax_.grid()

fig.tight_layout()

In [None]:
# Create a figure and axis manually
fig, ax = plt.subplots(figsize=(12, 8))

# Plot on the custom axis
desc.plotting.plot_1d(eq, name="magnetic well", ax=ax, lw=2)

In [None]:
# 3D

In [None]:
# %% Coarser LCFS mesh for the 3D (mayavi) views below.
Nz = 50
Nt = 50
Nr = 1

theta = np.linspace(0, 2 * np.pi, Nt)
zeta = np.linspace(0, 2 * np.pi, Nz, endpoint=False)
rho = np.full(Nr, 1.0)  # every sampled surface is rho = 1 (the LCFS)

# The (theta, zeta) mesh does not depend on rho, so build it once.
# "ij" indexing gives shape (Nt, Nz).
theta_mesh, zeta_mesh = np.meshgrid(theta, zeta, indexing="ij")
shape = theta_mesh.shape
theta_flat = theta_mesh.reshape(-1)
zeta_flat = zeta_mesh.reshape(-1)

full_coords = []
for r in rho:
    coords_sfl = np.stack(
        [np.full_like(theta_flat, r), theta_flat, zeta_flat], axis=-1
    )
    # Map the straight-field-line coordinates to DESC coordinates on this surface.
    coords_geo = eq.compute_theta_coords(coords_sfl)
    full_coords.append(coords_geo)

full_coords = np.array(full_coords)

In [None]:
# Flatten (Nr, Nt*Nz, 3) -> (Nr*Nt*Nz, 3), keeping rho-major node order.
coords_geo = full_coords.reshape(-1, 3)
grid = Grid(nodes=coords_geo, sort=False)

In [None]:
# Quantities on the coarse LCFS mesh for the 3D views. "B" is included because
# the vector-field (quiver) cell below reads data["B"]; without it that cell
# raises KeyError.
data = eq.compute(
    [
        "isodynamicity",
        "X",
        "Y",
        "Z",
        "|B|",
        "B",
        "beta_a",
        "|F|",
        "p",
        "grad(p)",
        "J",
        "|J|",
        "J_rho",
        "J_theta",
        "J_zeta",
        "J*B",
        "J_parallel",
    ],
    grid=grid,
)

In [None]:
# beta = 2 mu0 p / |B|^2 on the (Nt, Nz) LCFS mesh; 1e-20 guards against |B| = 0.
B_flat = data["|B|"]
beta = 2 * (data["p"] * mu_0 / (B_flat**2 + 1e-20)).reshape(Nt, Nz)
B = B_flat.reshape(Nt, Nz)

In [None]:
# 3D view of the LCFS: surface coloured by |B| plus every other theta = const line.
x, y, z = (data[key].reshape(Nt, Nz) for key in ("X", "Y", "Z"))

fig = mlab.figure(bgcolor=(1, 1, 1), fgcolor=(0.0, 0.0, 0.0))

# Shape (Nt, Nz, 3): lines[i] is the polyline at the i-th theta value.
lines = np.stack([x, y, z], axis=-1)
for line in lines[::2]:
    mlab.plot3d(*line.T, tube_radius=0.002)

mlab.mesh(x, y, z, scalars=B, opacity=0.9, colormap="viridis")
mlab.colorbar(title="|B|")

mlab.view(azimuth=105, elevation=75, distance=10, focalpoint="auto")  # adjust as needed

mlab.show()

In [None]:
# 3D scatter of the LCFS nodes coloured by the "isodynamicity" diagnostic.
x = data["X"]
y = data["Y"]
z = data["Z"]
b = data["isodynamicity"]  # use data["|B|"] to colour by field strength instead

fig = mlab.figure(bgcolor=(1, 1, 1), fgcolor=(0.0, 0.0, 0.0))

points = mlab.points3d(x, y, z, b, scale_factor=0.01, scale_mode="none")

mlab.colorbar(title="isodynamicity")

mlab.show()

In [None]:
# 3D view of the B (default colour) and J (magenta) directions at the LCFS
# nodes. Arrows have a fixed length because scale_mode="none".
x = data["X"]
y = data["Y"]
z = data["Z"]
bv = data["B"]


def _quiver_field(vec, **extra):
    """Draw fixed-length arrows for the (N, 3) array ``vec`` at the (x, y, z) nodes."""
    return mlab.quiver3d(
        x,
        y,
        z,
        vec[:, 0],
        vec[:, 1],
        vec[:, 2],
        scale_factor=0.1,
        line_width=10,
        scale_mode="none",
        **extra,
    )


fig = mlab.figure(bgcolor=(1, 1, 1), fgcolor=(0.0, 0.0, 0.0))

points = mlab.points3d(x, y, z, scale_factor=0.01, scale_mode="none")
arrows = _quiver_field(bv)

bv = data["J"]
arrows1 = _quiver_field(bv, color=(1, 0, 1))

# Other vectors can be overlaid the same way (add the names to the compute list first):
#   _quiver_field(data["e^zeta"], color=(1, 0, 0))
#   _quiver_field(data["e^theta"], color=(0, 0, 1))

mlab.show()

In [None]:
###### From the Ballooning Example

In [None]:
# Field-aligned grid for the ideal-ballooning stability calculation.
surfaces = np.array([0.01, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0])  # flux surfaces (rho)
alpha = np.linspace(0, 2*np.pi, num=16, endpoint=False)  # field-line labels
nturns = 1  # toroidal transits followed along each field line
N0 = nturns * 200  # points per field line in ballooning space
zeta = np.linspace(0, 2*np.pi * nturns, num=N0)  # ballooning coordinate range

# coordinates="raz" tells DESC the meshgrid axes are (rho, alpha, zeta).
raz_axes = [surfaces, alpha, zeta]
grid = Grid.create_meshgrid(raz_axes, coordinates="raz")

In [None]:
# Ideal-ballooning eigenvalues and eigenfunctions on the field-aligned grid.
data = eq.compute(
    ["ideal ballooning lambda", "ideal ballooning eigenfunction"], grid=grid
)

In [None]:
# Newcomb metric, reusing the already-computed quantities passed in via ``data``.
data = eq.compute(["Newcomb ballooning metric"], grid=grid, data=data)


In [None]:
# Newcomb ballooning metric on each flux surface.
_label_font = {"fontsize": 18}
_tick_font = {"fontsize": 16}
plt.plot(surfaces, data["Newcomb ballooning metric"], "-or", ms=4)
plt.xlabel(r"$\rho$", **_label_font)
plt.ylabel("Newcomb metric", **_label_font)
plt.xticks(**_tick_font)
plt.yticks(**_tick_font);

In [None]:
print("Growth rate and eigenfunction calculation finished!")
eigenvals = data["ideal ballooning lambda"]
eigenfuns = data["ideal ballooning eigenfunction"]

# Per surface: the largest eigenvalue over (alpha, zeta0, eigen-index), and its
# eigenfunction normalized to peak +1. It fills the interior points 1:-1; the
# two end points stay 0.
n_surf = surfaces.size
lambda_max0 = np.zeros(n_surf)
eigenfunc_max0 = np.zeros((n_surf, N0))
for j in range(n_surf):
    per_surface = eigenvals[j]
    alpha_idx, zeta0_idx, eigval_idx = np.unravel_index(
        np.argmax(per_surface), per_surface.shape
    )
    lambda_max0[j] = eigenvals[j, alpha_idx, zeta0_idx, eigval_idx]

    X0 = eigenfuns[j, alpha_idx, zeta0_idx, :, eigval_idx]
    peak_value = X0[np.argmax(np.abs(X0))]
    eigenfunc_max0[j, 1:-1] = X0 / np.max(np.abs(X0)) * np.sign(peak_value)

In [None]:
# Maximum ballooning eigenvalue per surface, then the corresponding eigenfunction
# on the surface closest to RHO_EIGENFUNC. The index is looked up by value, so
# editing ``surfaces`` does not silently select a different rho.
RHO_EIGENFUNC = 0.4

plt.figure()
plt.plot(surfaces, lambda_max0, "-or", ms=4)
plt.xlabel(r"$\rho$", fontsize=18)
plt.ylabel(r"$\lambda_{\mathrm{max}}$", fontsize=18)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

i_rho = int(np.argmin(np.abs(surfaces - RHO_EIGENFUNC)))
plt.figure()
plt.plot(zeta, eigenfunc_max0[i_rho])
plt.title(r"$\rho$ = {:.2f}".format(surfaces[i_rho]))
plt.xlabel(r"$\zeta$", fontsize=18)
plt.ylabel(r"$X_{\mathrm{max}}$", fontsize=18)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16);

In [None]:
# Cartesian coordinates of the field-aligned ballooning grid nodes (for a sanity plot).
data_test  = eq.compute(
    ["X","Y","Z"], grid=grid
)

In [None]:
# Sanity check: ballooning-grid nodes projected onto the X-Z plane.
plt.figure()
plt.plot(data_test["X"], data_test["Z"], ".")
plt.xlabel("X")
plt.ylabel("Z")
plt.title("Ballooning grid nodes (X-Z projection)");

In [None]:
### from the neoclassical example

In [None]:
import numpy as np
from matplotlib import pyplot as plt

from desc.integrals import Bounce2D
from desc.examples import get
from desc.grid import LinearGrid
from desc.optimize import Optimizer
from desc.objectives import (
    ForceBalance,
    FixPsi,
    FixBoundaryR,
    FixBoundaryZ,
    GenericObjective,
    FixPressure,
    FixIota,
    AspectRatio,
    EffectiveRipple,
    ObjectiveFunction,
)

In [None]:
def plot_wells(
    eq,
    grid,
    theta,
    Y_B=None,
    num_transit=1,
    num_well=None,
    num_pitch=10,
):
    """Plot bounce points along field lines to help pick a tight ``num_well``.

    Parameters
    ----------
    eq : Equilibrium
        Equilibrium to compute on.
    grid : Grid
        Tensor-product grid in (ρ, θ, ζ) with uniformly spaced nodes
        (θ, ζ) ∈ [0, 2π) × [0, 2π/NFP). The numbers of poloidal and toroidal
        nodes are preferably rounded down to powers of two. Determines the flux
        surfaces to compute on and the resolution of the FFTs.
    theta : jnp.ndarray
        Shape (num rho, X, Y). DESC coordinates θ sourced from the Clebsch
        coordinates ``FourierChebyshevSeries.nodes(X,Y,rho,domain=(0,2*jnp.pi))``.
        Obtain it with ``Bounce2D.compute_theta``. ``X`` and ``Y`` are preferably
        rounded down to powers of two.
    Y_B : int
        Resolution used by the bounce-point search. Default is double ``Y``.
    num_transit : int
        Number of toroidal transits to follow each field line. In an axisymmetric
        device one poloidal transit captures a surface average. In 3D, more
        transits approximate surface averages on irrational surfaces better, with
        diminishing returns.
    num_well : int
        Maximum number of wells to detect per pitch and field line. ``None``
        detects all wells, but with current JAX limitations this is slower. A
        tight upper bound improves performance. Per toroidal transit the number of
        wells is bounded by ``Aι+B``, where ``A`` and ``B`` are the poloidal and
        toroidal Fourier resolutions of B in straight-field-line PEST coordinates
        and ι is the rotational transform normalized by 2π. A bound tighter than
        ``num_well=(Aι+B)*num_transit`` is preferable. The ``check_points`` and
        ``plot`` methods of ``desc.integrals.Bounce2D`` help choose a value.
    num_pitch : int
        Number of pitch angles.

    Returns
    -------
    plots
        Matplotlib (fig, ax) tuples, one 1D plot per field line.

    """
    field_names = Bounce2D.required_names + ["min_tz |B|", "max_tz |B|"]
    data = eq.compute(field_names, grid=grid)
    bounce = Bounce2D(grid, data, theta, Y_B, num_transit=num_transit)

    # Pitch quadrature between the compressed (per-surface, presumably) min and
    # max of |B|.
    B_min = grid.compress(data["min_tz |B|"])
    B_max = grid.compress(data["max_tz |B|"])
    pitch_inv, _ = Bounce2D.get_pitch_inv_quad(B_min, B_max, num_pitch)

    bounce_points = bounce.points(pitch_inv, num_well)
    return bounce.check_points(bounce_points, pitch_inv)

In [None]:
# Work on a copy so the neoclassical analysis below does not modify ``eq``.
eq0 = eq.copy()

In [None]:
# Flux surfaces and a tensor-product grid for the effective-ripple calculation.
rho = np.linspace(0.01, 1, 5)
grid = LinearGrid(rho=rho, M=eq0.M_grid, N=eq0.N_grid, NFP=eq0.NFP, sym=False)

# ---------- How to pick resolution? ----------
# Use the optional plot_wells cells below to check the bounce points by eye.
# The resolutions here were judged sufficient that way. The integration itself
# uses more pitch angles. NOTE(review): an earlier comment mentioned 3 toroidal
# transits, but num_transit is 1 here.
X = 16  # resolution of the Clebsch-coordinate theta map (see plot_wells docstring)
Y = 32
theta = Bounce2D.compute_theta(eq0, X, Y, rho)
num_transit = 1
Y_B = 32  # resolution of the bounce-point search

In [None]:
# Optional diagnostic: set PLOT_WELLS = True to plot bounce points with a generous
# ``num_well`` bound (one figure per field line) and check that the chosen
# resolution resolves every well. It is disabled by a flag instead of a bare
# triple-quoted string, which would echo the code as cell output.
PLOT_WELLS = False
if PLOT_WELLS:
    plot_wells(
        eq0,
        grid,
        theta,
        Y_B=Y_B,
        num_transit=num_transit,
        num_well=10 * num_transit,
    )

In [None]:
# Optional diagnostic: set PLOT_WELLS_LOW_NUM_WELL = True to see that some wells
# are ignored when ``num_well`` is too low. It is disabled by a flag instead of a
# bare triple-quoted string, which would echo the code as cell output.
PLOT_WELLS_LOW_NUM_WELL = False
if PLOT_WELLS_LOW_NUM_WELL:
    plot_wells(
        eq0,
        grid,
        theta,
        Y_B=Y_B,
        num_transit=num_transit,
        num_well=1 * num_transit,
    )

In [None]:
# Effective ripple epsilon on each flux surface in ``rho``.
num_transit = 1
num_well = 10 * num_transit
num_quad = 32
num_pitch = 45
ripple_options = dict(
    theta=theta,
    Y_B=Y_B,
    num_transit=num_transit,
    num_well=num_well,
    num_quad=num_quad,
    num_pitch=num_pitch,
    # ``pitch_batch_size`` may also be given: it sets how many pitch values are
    # computed at once. Reduce it if memory is insufficient. If a memory shortage
    # is detected early, the code exits and returns ε = 0 everywhere; otherwise
    # the usual OOM errors occur.
)
data = eq0.compute("effective ripple", grid=grid, **ripple_options)

eps = grid.compress(data["effective ripple"])
fig, ax = plt.subplots()
ax.plot(rho, eps, marker="o")
ax.set(xlabel=r"$\rho$", ylabel=r"$\epsilon$", title="")
plt.tight_layout()
plt.show()