In [59]:
%%time

# from desc import set_device
# set_device("gpu")

# running a job array with SLURM
import os

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import desc.io
from desc.examples import get
from desc.grid import LinearGrid

# idx = int(os.environ["SLURM_ARRAY_TASK_ID"])
# s = ((idx % 50) + 1) / 50

# if idx < 50:
#     name = "QA1"
# elif idx < 100 and idx >= 50:
#     name = "QA2"
# elif idx < 150 and idx >= 100:
#     name = "QA3"
# elif idx < 200 and idx >= 150:
#     name = "QA4"
# elif idx < 250 and idx >= 200:
#     name = "QA4n"
# elif idx < 300 and idx >= 250:
#     name = "QA5"
# elif idx < 350 and idx >= 300:
#     name = "QA4-1"
# elif idx < 400 and idx >= 350:
#     name = "QA4-2"
# elif idx < 450 and idx >= 400:
#     name = "QA4-3"
# elif idx < 500 and idx >= 450:
#     name = "QA4-4"
# elif idx < 550 and idx >= 500:
#     name = "QA5-1"
# elif idx < 600 and idx >= 550:
#     name = "QA5-2"
# elif idx < 650 and idx >= 600:
#     name = "QA5-3"
# elif idx < 700 and idx >= 650:
#     name = "QA5-4"
s = 0.5
name = "QA1"
eq = desc.io.load(name + "_solved.h5")

# from desc.vmec import VMECIO
# eq = VMECIO.load("wout_QA4n.nc")
# eq.save("wout_QA4n.h5")
# print('loaded eq')

# eq = get("W7-X")


from desc.compute.utils import cross, dot
from desc.grid import Grid

# get iota at this surface to use for initial guess
iota = eq.compute("iota", grid=Grid(jnp.array([[jnp.sqrt(s), 0, 0]])))["iota"]

# keep steps within one field period consistent by multiplying by NFP
stepswithin1FP = 100
nfulltransits = 100
stepswithin2pi = stepswithin1FP * eq.NFP

coords = jnp.ones((stepswithin2pi * nfulltransits, 3))
coords = coords.at[:, 0].set(coords[:, 0] * jnp.sqrt(s))
coords = coords.at[:, 2].set(jnp.linspace(0, nfulltransits * 2 * jnp.pi, stepswithin2pi * nfulltransits))
guess = coords.copy()

alpha = 0
coords = coords.at[:, 1].set(coords[:, 1] * alpha)  # set which field line we want

# for initial guess, alpha = zeta + iota*theta*
# rearrange for theta* and approx theta* ~ theta
# theta = (alpha - zeta) / iota
# as initial guess
guess = guess.at[:, 1].set((alpha - guess[:, 2]) / iota)

print("starting map coords")
coords1 = eq.map_coordinates(
    coords=coords,
    inbasis=["rho", "alpha", "zeta"],
    outbasis=["rho", "theta", "zeta"],
    period=[jnp.inf, 2 * jnp.pi, jnp.inf],
    guess=guess,
)  # (2 * jnp.pi / eq.NFP)],
# )
# reason is alpha = zeta + iota * theta*
# if zeta is modded by 2pi/NFP, then after each field period, it is as if we are trying to
# find the theta* for the point (alpha, zeta=0), which is DIFFERENT from (alpha,zeta=2pi/NFP)

# print(coords1)

coords1 = coords1.at[:, 2].set(coords[:, 2])

print('mapped coords')

# print(jnp.any(jnp.isnan(coords1)))
# print(coords1)
# print(jnp.where(jnp.isnan(coords1)))

grid2 = Grid(coords1)
# print(grid2)


wellGamma_c = 0
bigGamma_c = 0

# compute important quantities in DESC.

psi = eq.Psi  # might need to be normalized by 2pi

data_names = [
    "|grad(psi)|",
    "grad(psi)",
    "|grad(zeta)|",
    "e^zeta",
    "|B|",
    "|B|_r",
    "e_theta",
    "kappa_g",
    "B^zeta",
    "B^zeta_r",
    "B_R",
    "psi_r",
    "B_phi",
    "B_Z",
    "iota_r",
    "X",
    "Y",
    "Z",
]
data = eq.compute(data_names, grid2)

print("eq.compute done")

grad_psi_mag = data["|grad(psi)|"]
grad_psi = data["grad(psi)"]
grad_zeta_mag = data["|grad(zeta)|"]
grad_zeta = data["e^zeta"]
e_theta = jnp.linalg.norm(data["e_theta"], axis=-1)
kappa_g = data["kappa_g"]
Bsupz = data["B^zeta"]
dBsupzdpsi = data["B^zeta_r"] * 2 * jnp.pi / data["psi_r"]
dBdpsi = data["|B|_r"] * 2 * jnp.pi / data["psi_r"]  # might need 2pi

Br = data["B_R"]
Bphi = data["B_phi"]
zeta = coords1[:, 2]
B = data["|B|"]
Bxyz = jnp.zeros((len(B), 3))
Bxyz = Bxyz.at[:, 0].set(Br * jnp.cos(zeta) - Bphi * jnp.sin(zeta))
Bxyz = Bxyz.at[:, 1].set(Br * jnp.sin(zeta) + Bphi * jnp.cos(zeta))
Bxyz = Bxyz.at[:, 2].set(data["B_Z"])

dVdb_t1 = data["iota_r"] * dot(cross(grad_psi, Bxyz), grad_zeta) / B

# finding basic arc length of each segment
x = data["X"]
y = data["Y"]
z = data["Z"]
ds = jnp.sqrt(
    jnp.add(jnp.add(jnp.square(jnp.diff(x)), jnp.square(jnp.diff(y))), jnp.square(jnp.diff(z)))
)

maxB = jnp.nanmax(B)
minB = jnp.nanmin(jnp.abs(B))

# integrating dl/b
dloverb = jnp.sum(ds/B[:-1])
ds = jnp.append(ds, 0) # need to look into this

bpstep = 80  # number of iterations through b'
deltabp = (maxB - minB) / (minB * bpstep)
bp = 1+ (deltabp*jnp.linspace(-0.5, bpstep-1.5, bpstep))

B_reflect = minB * bp


B_array = jnp.tile(B, bpstep)
bp_array = jnp.repeat(bp, len(B))
B_reflect_array = minB * bp_array

# array with values 0 or 1 and length bpstep*len(B)
# 0 means not in well, 1 is in well 
# First len(B) values are for bp = lowest, etc
in_well = jnp.where(B_array < B_reflect_array, jnp.ones(len(B)*bpstep), jnp.zeros(len(B)*bpstep))


# flattened grad_psi array with first len(B) values being bp = lowest etc
# "flattened" means every value in each well is reduced to the value with the index corresponding to the minimum in B along that well.
print("starting flattening")
grad_psi_flat = array_flatten_op(jnp.tile(grad_psi_mag, bpstep)*in_well, B_array)
e_theta_flat = array_flatten_op(jnp.tile(e_theta, bpstep)*in_well, B_array)
print("flattened")

ds_array = jnp.tile(ds, bpstep)
sqrt_bbb_array = jnp.sqrt(1 - B_array / B_reflect_array)

dIdb = jnp.nansum(((in_well * ds_array / bp_array / sqrt_bbb_array)).reshape([bpstep,len(B)]), axis = 1) / 2 / minB

dgdb = jnp.nansum((in_well * ds_array * grad_psi_flat * jnp.tile(kappa_g, bpstep) / bp_array / B_array * (sqrt_bbb_array + 1 / sqrt_bbb_array)
).reshape([bpstep,len(B)]),axis = 1) / 2

dBdpsi_array = jnp.tile(dBdpsi, bpstep)

dbigGdb = jnp.nansum((
    in_well * dBdpsi_array * ds_array / B_reflect_array / bp_array / B_array * (sqrt_bbb_array + 1 / sqrt_bbb_array)
        ).reshape([bpstep,len(B)]), axis = 1) / 2

dVdb = jnp.nansum((
    (in_well * jnp.tile(dVdb_t1, bpstep) - (2 * dBdpsi_array - B_array / jnp.tile(Bsupz, bpstep) * jnp.tile(dBsupzdpsi, bpstep)))
    * ds_array / B_array / B_reflect_array * sqrt_bbb_array
        ).reshape([bpstep,len(B)]), axis = 1) * 1.5

print("entering for loop")

well_start_inds = jnp.where(jnp.diff(in_well) == 1)[0] + 1

B_len = len(B)
bp_ind = well_start_inds // B_len
vrovervt = jnp.zeros_like(well_start_inds)
wellGamma_c = jnp.zeros_like(well_start_inds)
for i in range(len(well_start_inds)):
    temp = dgdb[bp_ind[i]] / grad_psi_flat[well_start_inds[i]] / dIdb[bp_ind[i]] / minB / e_theta_flat[well_start_inds[i]]
    vrovervt = temp / (dbigGdb[bp_ind[i]] / dIdb[bp_ind[i]] + 2 / 3 * dVdb[bp_ind[i]] / dIdb[bp_ind[i]])

    gamma_c = 2 * jnp.arctan(vrovervt) / jnp.pi
    wellGamma_c = wellGamma_c.at[i].set(gamma_c * gamma_c * dIdb[bp_ind[i]])

bigGamma_c = jnp.sum(wellGamma_c) * jnp.pi / 2 / jnp.sqrt(2) * deltabp / dloverb

print(bigGamma_c)

# file = name + "_10kSolved.txt"
# f = open(file, "a")
# f.write(f"{s:1.2f}, {bigGamma_c:1.3e}\n")
# f.close()


starting map coords
mapped coords
eq.compute done
starting flattening
flattened
entering for loop




0.17034895595292157
CPU times: user 4min 5s, sys: 40.6 s, total: 4min 46s
Wall time: 3min 49s


In [43]:
import jax.numpy as jnp

def seg_argmin(arr1, arr2): # return list of indices of minimum values in arr2 in each arr1 well
    min_indices = []
    for i in range(len(arr1)):
        if arr1[i] == 0:
            rolling_min = jnp.inf
            continue
        elif arr2[i] < rolling_min:
            rolling_min = arr2[i]
            rolling_min_indice = i
        if arr1[i+1] == 0:
            min_indices.append(rolling_min_indice)
    return min_indices



def array_flatten(arr1, arr2): # return arr1 with each well flattened to the value with the same indice as the minimum of arr2 along that well
    min_indices = seg_argmin(arr1, arr2)
    j = 0
    for i in range(len(arr1)):
        if arr1[i-1] == 0 and arr1[i] != 0:
            start_ind = i
        if arr1[i+1] == 0 and arr1[i] != 0:
            arr1 = arr1.at[start_ind:i+1].set(arr1[min_indices[j]])
            j += 1
    return arr1


# array1 = jnp.array([0, 1, 2, 3, 0, 5, 7, 9, 0])
# array2 = jnp.array([1, 5, 2, 6, 7, 8, 6, 3, 5])
# print(seg_argmin(array1, array2))
# print(array_flatten(array1, array2))
        

In [55]:
def array_flatten_op(arr1, arr2):
    start = jnp.where(jnp.logical_and(arr1 != 0, jnp.roll(arr1, 1) ==0))[0]
    end = jnp.where(jnp.logical_and(arr1 != 0, jnp.roll(arr1, -1) ==0))[0]
    min_indices = seg_argmin_op(arr1,arr2,start,end)
    for i in range(len(min_indices)):
        arr1 = arr1.at[start[i]:end[i]+1].set(arr1[min_indices[i]])
    return arr1

def seg_argmin_op(arr1, arr2, start, end):
    min_indices = jnp.zeros_like(start)
    for i in range(len(start)):
        ind_inwell = jnp.argmin(arr2[start[i]:end[i]+1])
        min_indices = min_indices.at[i].set(start[i] + ind_inwell)
    return min_indices



array1 = jnp.array([0, 1, 2, 3, 0, 5, 7, 9, 0])
array2 = jnp.array([1, 5, 2, 6, 7, 8, 6, 3, 5])

start = jnp.where(jnp.logical_and(array1 != 0, jnp.roll(array1, 1) ==0))[0]
end = jnp.where(jnp.logical_and(array1 != 0, jnp.roll(array1, -1) ==0))[0]
print(seg_argmin_op(array1, array2, start,end))
print(array_flatten(array1, array2))
print(array_flatten_op(array1, array2))

[2 7]
[0 2 2 2 0 9 9 9 0]
[0 2 2 2 0 9 9 9 0]


In [11]:
0.5323510129695246

0.5323510129695246

In [30]:
print(in_well)

[0. 0. 0. ... 1. 1. 1.]


In [32]:
print(bp_ind)

[ 1  2  3  4  5  6  7  7  8  8  8  9  9  9 10 10 10 11 11 11 12 12 12 13
 13 13 14 14 14 15 15 15 16 16 16 17 17 17 18 18 18 19 19 19 19 20 20 20
 21 21 21 22 22 22 23 23 23 24 24 24 25 25 25 26 26 26 27 27 27 28 28 28
 29 29 29 30 30 30 31 31 31 32 32 32 33 33 33 34 34 34 35 35 35 36 36 36
 37 37 37 38 38 38 39 39 39 40 40 40 41 41 41 42 42 42 43 43 43 44 44 44
 45 45 45 46 46 46 47 47 47 48 48 48 49 49 49 50 50 50 51 51 51 52 52 52
 53 53 53 54 54 54 54 55 55 55 56 56 56 57 57 57 58 58 58 59 59 59 60 60
 60 61 61 61 62 62 62 63 63 63 63 64 64 64 64 65 65 65 66 66 66 67 67 67
 68 68 68 69 69 69 70 70 70 70 71 71 71 71 71 72 72 72 72 72 72 73 73 73
 73 73 73 73 74 74 74 74 74 74 74 75 75 75 75 75 75 76 76 77 78 79]


In [31]:
print(dIdb)

[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan]


In [33]:
print(gamma_c)

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


In [34]:
print(wellGamma_c)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [35]:
bp_array

Array([0.9984231, 0.9984231, 0.9984231, ..., 1.247574 , 1.247574 ,
       1.247574 ], dtype=float64)

In [36]:
sqrt_bbb_array

Array([       nan,        nan,        nan, ..., 0.33280666, 0.36694597,
       0.39132203], dtype=float64)

In [37]:
dgdb

Array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan], dtype=float64)

In [50]:
%%time
grad_psi_flat = array_flatten(jnp.tile(grad_psi_mag, bpstep)*in_well, B_array)
e_theta_flat = array_flatten(jnp.tile(e_theta, bpstep)*in_well, B_array)

CPU times: user 5min 4s, sys: 251 ms, total: 5min 5s
Wall time: 5min 5s


In [51]:
%%time
grad_psi_flat = array_flatten_op(jnp.tile(grad_psi_mag, bpstep)*in_well, B_array)
e_theta_flat = array_flatten_op(jnp.tile(e_theta, bpstep)*in_well, B_array)

CPU times: user 1min 56s, sys: 130 ms, total: 1min 56s
Wall time: 1min 56s


In [56]:
%%time
grad_psi_flat = array_flatten_op(jnp.tile(grad_psi_mag, bpstep)*in_well, B_array)
e_theta_flat = array_flatten_op(jnp.tile(e_theta, bpstep)*in_well, B_array)

CPU times: user 12.9 s, sys: 19.9 ms, total: 12.9 s
Wall time: 12.9 s
