In [1]:
import sys
import os

sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../"))

In [2]:
from desc.backend import jnp, jit
from desc.grid import ConcentricGrid
import numpy as np

DESC version 0.5.0+40.g5d59bae.dirty, using JAX backend, jax version=0.2.25, jaxlib version=0.1.76, dtype=float64
Using device: CPU, with 11.41 GB available memory


In [3]:
jnp.set_printoptions(precision=3, floatmode="fixed")
# This notebook is a randomized check: draw fresh entropy on every run, but
# record it so any failing case can be replayed exactly with
# `rng = np.random.default_rng(SEED)`.
SEED = np.random.SeedSequence().entropy
print("seed:", SEED)
rng = np.random.default_rng(SEED)

In [4]:
# Random grid parameters for this run. `rng.integers` excludes `high`, so each
# of L, M, N lies in [1, MAX_RESOLUTION - 1]. They are printed so a failing
# run can be inspected.
MAX_RESOLUTION = 100
L = rng.integers(low=1, high=MAX_RESOLUTION)
M = rng.integers(low=1, high=MAX_RESOLUTION)
N = rng.integers(low=1, high=MAX_RESOLUTION)
print(L, M, N)
grid = ConcentricGrid(L=L, N=N, M=M, node_pattern="jacobi")

81 15 75


In [5]:
# rho coordinate of every collocation node (column 0 of grid.nodes).
rho = grid.nodes[:, 0]
# Random per-node weights. Draw them from the notebook's `rng` rather than the
# legacy global np.random state, so that seeding `rng` alone makes the inputs
# reproducible.
num = rng.random(size=len(rho))
den = rng.random(size=len(rho))
# Sorted distinct rho values, i.e. one entry per constant-rho surface.
unique_rho = jnp.unique(rho)
# One output slot per surface.
iota_1 = np.zeros_like(unique_rho)
iota_2 = np.zeros_like(unique_rho)

In [6]:
# DESIRED ALGORITHM
# collect collocation indices for each constant rho flux surface
surfaces = dict()
for index, r in enumerate(rho):
    surfaces.setdefault(r, list()).append(index)
# flux surface average integration: sum(num) / sum(den) on each surface.
# Surfaces are visited in ascending rho so that iota_1[i] matches unique_rho[i]
# (jnp.unique sorts) whatever the node ordering. A fresh array is built,
# instead of accumulating with +=, so re-running this cell is idempotent.
iota_1 = np.array(
    [num[surfaces[r]].sum() / den[surfaces[r]].sum() for r in sorted(surfaces)]
)

In [7]:
# NO LOOP IMPLEMENTATION
# Bin edges are the sorted unique rho values plus a duplicate of the largest.
# numpy-style histograms close the final bin, so the zero-width last bin
# [rho_max, rho_max] collects exactly the outermost surface. Every other bin
# [r_k, r_{k+1}) holds exactly surface k. This does not rely on rho_max == 1.
bins = jnp.append(unique_rho, unique_rho[-1])
num_per_surface = jnp.histogram(rho, bins=bins, weights=num)[0]
den_per_surface = jnp.histogram(rho, bins=bins, weights=den)[0]
# Assign instead of using += and /=, so re-running this cell is idempotent.
iota_2 = num_per_surface / den_per_surface

In [8]:
print(iota_1)
print(iota_2)
# Unlike a bare `assert`, np.testing is not stripped under `python -O`. It also
# rejects shape mismatches and reports the offending entries on failure.
# The tolerances are jnp.allclose's defaults.
np.testing.assert_allclose(iota_1, iota_2, rtol=1e-5, atol=1e-8)

[1.029 0.939 0.962 0.978 1.043 1.016 1.022 1.004 1.008 1.020 1.023 1.020
 1.023 0.988 1.019 1.004 0.989 1.000 1.013 0.964 1.011 0.981 1.008 1.006
 0.999 1.002 1.015 1.012 0.977 1.008 0.995 0.991 0.992 1.004 0.988 1.005
 1.010 0.999 0.978 1.014 1.006]
[1.029 0.939 0.962 0.978 1.043 1.016 1.022 1.004 1.008 1.020 1.023 1.020
 1.023 0.988 1.019 1.004 0.989 1.000 1.013 0.964 1.011 0.981 1.008 1.006
 0.999 1.002 1.015 1.012 0.977 1.008 0.995 0.991 0.992 1.004 0.988 1.005
 1.010 0.999 0.978 1.014 1.006]


In [9]:
# @jit fails
def _where_change(x):
    return jnp.where(jnp.diff(x, prepend=jnp.nan))[0]


def get_unique_rho(nodes):
    change_zeta = _where_change(nodes[:, -1])
    stop_zeta = change_zeta[1] if len(change_zeta) > 1 else None
    change_rho = _where_change(nodes[:stop_zeta, 0])
    return nodes[change_rho, 0]


# np.testing is not stripped under `python -O` and checks shapes.
# The tolerances are jnp.allclose's defaults.
np.testing.assert_allclose(
    unique_rho, get_unique_rho(grid.nodes), rtol=1e-5, atol=1e-8
)