In [1]:
import sys
import os

# Let modules next to this notebook shadow installed ones, and make the directory
# two levels up importable (presumably the repository root, so `desc` resolves).
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath(os.path.join("..", "..")))

In [2]:
from desc.backend import jnp, jit
from desc.grid import ConcentricGrid
import numpy as np

DESC version 0.5.0+42.ga55ac71.dirty, using JAX backend, jax version=0.2.25, jaxlib version=0.1.76, dtype=float64
Using device: CPU, with 9.68 GB available memory


In [3]:
# Compact fixed-precision printing so the result vectors below are easy to compare by eye.
jnp.set_printoptions(precision=3, floatmode="fixed")
# Draw a fresh seed every run, so the check still fuzzes over random grids and weights.
# Print it so that any failing run can be replayed exactly with
# np.random.default_rng(<printed seed>).
SEED = np.random.SeedSequence().entropy
print("seed:", SEED)
rng = np.random.default_rng(SEED)

In [4]:
# Random radial (L), poloidal (M) and toroidal (N) resolutions, each in [1, 100),
# so repeated runs exercise different ConcentricGrid layouts.
L, M, N = (rng.integers(low=1, high=100) for _ in range(3))
print(L, M, N)
grid = ConcentricGrid(L=L, M=M, N=N, node_pattern="jacobi")

65 27 75


### Bulk flux surface averaging test

In [5]:
# Radial coordinate of every collocation node; nodes sharing a rho value belong
# to the same surface and are summed together below.
rho = grid.nodes[:, 0]
# Random per-node weights, drawn from the notebook's generator `rng` rather than
# numpy's legacy global RNG, so that all randomness in this test comes from one source.
weights = rng.random(size=len(rho))

In [6]:
# REFERENCE (LOOP) ALGORITHM
# Group collocation-node indices by their exact rho value (one group per surface).
surfaces = dict()
for index, r in enumerate(rho):
    surfaces.setdefault(r, list()).append(index)
# The grouping must find exactly the grid's number of rho surfaces.
assert len(surfaces) == grid.num_rho, "number of distinct rho values differs from grid.num_rho"
# Sum each group's weights in ascending-rho order. This matches the sorted bin
# order of the histogram implementation regardless of how the nodes are ordered
# (dict iteration would otherwise follow first-appearance order).
iota_1 = np.array([weights[surfaces[r]].sum() for r in sorted(surfaces)])

In [7]:
@jit
def _surface_sums(surf_label, unique_append_upperbound, weights):
    """
    Parameters
    ----------
    surf_label : ndarray
        The surface label. Elements of a coordinate in the collocation grid.
        i.e. grid.nodes[:, 0]
    unique_append_upperbound : ndarray
        Sorted unique elements of surf_label with the upper bound of that coordinate appended.
        i.e. grid.unique_rho + [1]
    weights : ndarray
        Node at surf_label[i]'s contribution to its surface's sum is weights[i].
        For an integral, this could be: ds * function_to_integrate.

    Returns
    -------
    ndarray
        An array of weighted sums over each surface.
        The returned array has length = len(unique_append_upperbound) - 1.
    """
    # DESIRED ALGORITHM
    # collect collocation node indices for each rho surface
    # surfaces = dict()
    # for index, rho in enumerate(surf_label):
    #     surfaces.setdefault(rho, list()).append(index)
    # integration over non-contiguous elements
    # for i, surface in enumerate(surfaces.values()):
    #     surface_sums[i] = weights[surface].sum()

    # NO LOOP IMPLEMENTATION
    # Separate collocation nodes into bins with boundaries at unique values of rho.
    # This groups nodes with identical rho values.
    # Each is assigned a weight of their contribution to the integral.
    # The elements of each bin are summed, performing the integration.
    return jnp.histogram(surf_label, bins=unique_append_upperbound, weights=weights)[0]

In [8]:
# VECTORIZED (HISTOGRAM) ALGORITHM
# Bin edges are the unique rho values closed on the right by rho = 1, so each bin
# total is one surface's weighted sum, to be compared with the loop result.
bins = jnp.append(grid.unique_rho, 1)
iota_2 = _surface_sums(surf_label=rho, unique_append_upperbound=bins, weights=weights)

In [9]:
# Compare the loop (reference) result with the histogram result.
print(iota_1)
print(iota_2)
# np.testing.assert_allclose rejects mismatched shapes, whereas jnp.allclose would
# broadcast (e.g. a length-1 result could pass). On failure it also reports the
# offending entries. The tolerances reproduce jnp.allclose's defaults.
np.testing.assert_allclose(iota_1, iota_2, rtol=1e-5, atol=1e-8)

[ 147.208  235.316  382.794  549.182  614.382  750.044  923.733  984.812
 1119.754 1295.693 1367.538 1525.841 1664.855 1756.368 1863.554 2073.164
 2119.459 2283.247 2445.757 2462.900 2681.401 2821.062 2870.304 3023.904
 3174.562 3270.566 3420.687 3563.704 3611.914 3756.860 3932.809 3997.768
 4112.891]
[ 147.208  235.316  382.794  549.182  614.382  750.044  923.733  984.812
 1119.754 1295.693 1367.538 1525.841 1664.855 1756.368 1863.554 2073.164
 2119.459 2283.247 2445.757 2462.900 2681.401 2821.062 2870.304 3023.904
 3174.562 3270.566 3420.687 3563.704 3611.914 3756.860 3932.809 3997.768
 4112.891]


In [10]:
# custom implementation of jnp.unique()
# @jit # still fails
def _where_change(x, size):
    return jnp.where(jnp.diff(x, prepend=jnp.nan), size=size)[0]


# NOTE(review): jit also reportedly fails here; `grid` is presumably not a
# JAX-compatible argument.
def get_unique_rho(grid):
    """Unique rho values of ``grid``, computed without ``jnp.unique``.

    The zeta column locates the end of the first zeta plane. The run starts of
    rho within that plane are then taken as the unique rho values, in order of
    appearance.

    NOTE(review): assumes column 0 of ``grid.nodes`` is rho and the last column
    is zeta, that nodes are grouped contiguously by zeta, and that within a zeta
    plane each rho value appears as one contiguous run. This holds for the
    ConcentricGrid checked below; confirm before reusing on other grids.
    """
    zeta_run_starts = _where_change(grid.nodes[:, -1], grid.num_zeta)
    if grid.num_zeta > 1:
        # The second zeta run begins right after the first zeta plane ends.
        first_plane_end = zeta_run_starts[1]
    else:
        first_plane_end = None
    rho_in_first_plane = grid.nodes[:first_plane_end, 0]
    rho_run_starts = _where_change(rho_in_first_plane, grid.num_rho)
    return grid.nodes[rho_run_starts, 0]


# Check the jnp.unique-free helper against the grid's own unique rho values.
# np.testing.assert_allclose rejects shape mismatches, which jnp.allclose would
# broadcast over. The tolerances reproduce jnp.allclose's defaults.
np.testing.assert_allclose(grid.unique_rho, get_unique_rho(grid), rtol=1e-5, atol=1e-8)