In [1]:
assert "jax" not in sys.modules, "jax already imported: you must restart your runtime - DO NOT RUN THIS FUNCTION TWICE"
import jax

import jax.numpy as jnp

import numpy as np

import sys
import os

In [2]:
def trilinearInterpolator(x, y, z, lengths, dims, values, query_points, *, fill_value = jnp.nan):
    idr = jnp.clip(
        jnp.floor(
            ((query_points / jnp.asarray(lengths)) + 0.5) * (jnp.asarray(dims) - 1)
        ).astype(jnp.int32),
        0, jnp.asarray(dims) - 2
    )

    wx = (query_points[:, 0] - x[idr[:, 0]]) / (x[idr[:, 0] + 1] - x[idr[:, 0]])
    wy = (query_points[:, 1] - y[idr[:, 1]]) / (y[idr[:, 1] + 1] - y[idr[:, 1]])
    wz = (query_points[:, 2] - z[idr[:, 2]]) / (z[idr[:, 2] + 1] - z[idr[:, 2]])

    return (
        values[idr[:, 0], idr[:, 1], idr[:, 2]] * (1 - wx) * (1 - wy) * (1 - wz) +
        values[idr[:, 0], idr[:, 1], idr[:, 2] + 1] * (1 - wx) * (1 - wy) * wz       +
        values[idr[:, 0], idr[:, 1] + 1, idr[:, 2]] * (1 - wx) * wy       * (1 - wz) +
        values[idr[:, 0], idr[:, 1] + 1, idr[:, 2] + 1] * (1 - wx) * wy       * wz       +
        values[idr[:, 0] + 1, idr[:, 1], idr[:, 2]] * wx       * (1 - wy) * (1 - wz) +
        values[idr[:, 0] + 1, idr[:, 1], idr[:, 2] + 1] * wx       * (1 - wy) * wz       +
        values[idr[:, 0] + 1, idr[:, 1] + 1, idr[:, 2]] * wx       * wy       * (1 - wz) +
        values[idr[:, 0] + 1, idr[:, 1] + 1, idr[:, 2] + 1] * wx       * wy       * wz
    )

dim = 3

x_local = jnp.linspace(-1, 1, dim)
y_local = jnp.linspace(-1, 1, dim)
z_local = jnp.linspace(-1, 1, dim)

x_l = x_local[-1] - x_local[0]
y_l = y_local[-1] - y_local[0]
z_l = z_local[-1] - z_local[0]

# copy = true MUST be set else x, y, z will be overwritten by changes to XX, YY, ZZ and vice versa
values, _, _ = jnp.meshgrid(x_local, y_local, z_local, indexing = 'ij', copy = True)

'''
a, b, c = values.shape
for i in range(a):
    for j in range(b):
        for k in range(c):
            values[i, j, k] = 1
'''

print(values)

print(x_local, y_local, z_local)

query_points = jnp.array([0.5, 0, 0])
query_points = query_points.reshape(1, 3)
result = trilinearInterpolator(x_local, y_local, z_local, (x_l, y_l, z_l), (dim, dim, dim), values, query_points)
print(result)
print(type(result))

[[[-1. -1. -1.]
  [-1. -1. -1.]
  [-1. -1. -1.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]

 [[ 1.  1.  1.]
  [ 1.  1.  1.]
  [ 1.  1.  1.]]]
[-1.  0.  1.] [-1.  0.  1.] [-1.  0.  1.]
[0.5]
<class 'jaxlib._jax.ArrayImpl'>


In [7]:
n_cells = 128
Np = 1

print("\nStarting cpu sharding test run with", n_cells, "cells and", Np, "rays.")

jax.config.update('jax_enable_x64', True)

extent_x = 5e-3
extent_y = 5e-3
extent_z = 10e-3

probing_extent = extent_z

lengths = 2 * np.array([extent_x, extent_y, extent_z])

from scipy.constants import c
from scipy.constants import e

class ScalarDomain():
    def __init__(self, lengths, dim):
        self.x_length, self.y_length, self.z_length = lengths[0], lengths[1], lengths[2]
        self.x_n, self.y_n, self.z_n = dim, dim, dim

        self.x = np.float32(np.linspace(-self.x_length / 2, self.x_length / 2, self.x_n))
        self.y = np.float32(np.linspace(-self.y_length / 2, self.y_length / 2, self.y_n))
        self.z = np.float32(np.linspace(-self.z_length / 2, self.z_length / 2, self.z_n))
        self.coordinates = np.stack([self.x, self.y, self.z], axis = 1)

        self.XX, self.YY, _ = np.meshgrid(self.x, self.y, self.z, indexing = 'ij', copy = True)
        self.ZZ = None

        self.XX = self.XX / 2e-3
        self.XX = 10 ** self.XX

        self.YY = self.YY / 1e-3
        self.YY = np.pi * self.YY
        self.YY = 2 * self.YY
        self.YY = np.cos(self.YY)
        self.YY = 1 + self.YY

        self.ne = self.XX * self.YY

        self.ne = 1e24 * self.ne

domain = ScalarDomain(lengths, n_cells)

lwl = 1064e-9

divergence = 5e-5
beam_size = extent_x
ne_extent = probing_extent
beam_type = 'circular'

def init_beam(Np, beam_size, divergence, ne_extent):
    s0 = np.zeros((9, Np))

    t  = 2 * np.pi * np.random.randn(Np)

    u  = np.random.randn(Np)

    ϕ = np.pi * np.random.randn(Np)
    χ = divergence * np.random.randn(Np)

    s0[0, :] = beam_size * u * np.cos(t)
    s0[1, :] = beam_size * u * np.sin(t)
    s0[2, :] = -ne_extent

    s0[3, :] = c * np.sin(χ) * np.cos(ϕ)
    s0[4, :] = c * np.sin(χ) * np.sin(ϕ)
    s0[5, :] = c * np.cos(χ)

    s0[6, :] = 1.0
    s0[8, :] = 0.0
    s0[7, :] = 0.0

    return s0

beam_definition = init_beam(Np, beam_size, divergence, ne_extent)

from scipy.integrate import odeint, solve_ivp
from time import time

def trilinearInterpolator(x, y, z, lengths, dims, values, query_points, *, fill_value = np.nan):
    idr = np.clip(
        np.floor(
            ((query_points / np.asarray(lengths)) + 0.5) * (np.asarray(dims) - 1)
        ).astype(np.int32),
        0, np.asarray(dims) - 2
    )

    wx = (query_points[:, 0] - x[idr[:, 0]]) / (x[idr[:, 0] + 1] - x[idr[:, 0]])
    wy = (query_points[:, 1] - y[idr[:, 1]]) / (y[idr[:, 1] + 1] - y[idr[:, 1]])
    wz = (query_points[:, 2] - z[idr[:, 2]]) / (z[idr[:, 2] + 1] - z[idr[:, 2]])

    return (
        values[idr[:, 0], idr[:, 1], idr[:, 2]] * (1 - wx) * (1 - wy) * (1 - wz) +
        values[idr[:, 0], idr[:, 1], idr[:, 2] + 1] * (1 - wx) * (1 - wy) * wz       +
        values[idr[:, 0], idr[:, 1] + 1, idr[:, 2]] * (1 - wx) * wy       * (1 - wz) +
        values[idr[:, 0], idr[:, 1] + 1, idr[:, 2] + 1] * (1 - wx) * wy       * wz       +
        values[idr[:, 0] + 1, idr[:, 1], idr[:, 2]] * wx       * (1 - wy) * (1 - wz) +
        values[idr[:, 0] + 1, idr[:, 1], idr[:, 2] + 1] * wx       * (1 - wy) * wz       +
        values[idr[:, 0] + 1, idr[:, 1] + 1, idr[:, 2]] * wx       * wy       * (1 - wz) +
        values[idr[:, 0] + 1, idr[:, 1] + 1, idr[:, 2] + 1] * wx       * wy       * wz
    )

def calc_dndr(ScalarDomain, lwl = 1064e-9):
    omega = 2 * np.pi * c / lwl
    nc = 3.14207787e-4 * omega ** 2

    return (np.array(ScalarDomain.ne / nc, dtype = np.float32), omega)

def dndr(r, ne, omega, coordinates, length, dim):
    grad = np.zeros_like(r)

    dndx = -0.5 * c ** 2 * np.gradient(ne / (3.14207787e-4 * omega ** 2), coordinates[:, 0], axis = 0)
    grad[0, :] = trilinearInterpolator(coordinates[:, 0], coordinates[:, 1], coordinates[:, 2], length, dim, dndx, r.T, fill_value = 0.0)
    del dndx

    dndy = -0.5 * c ** 2 * np.gradient(ne / (3.14207787e-4 * omega ** 2), coordinates[:, 1], axis = 1)
    grad[1, :] = trilinearInterpolator(coordinates[:, 0], coordinates[:, 1], coordinates[:, 2], length, dim, dndy, r.T, fill_value = 0.0)
    del dndy

    dndz = -0.5 * c ** 2 * np.gradient(ne / (3.14207787e-4 * omega ** 2), coordinates[:, 2], axis = 2)
    grad[2, :] = trilinearInterpolator(coordinates[:, 0], coordinates[:, 1], coordinates[:, 2], length, dim, dndz, r.T, fill_value = 0.0)
    del dndz

    return grad

def dsdt(s, ne, coordinates, omega, length, dim):
    s = np.reshape(s, (9, 1))
    sprime = np.zeros_like(s)

    r = s[:3, :]
    v = s[3:6, :]

    sprime[3:6, :] = dndr(r, ne, omega, coordinates, length, dim)
    print(sprime[3:6, :])
    sprime[:3, :] = v

    del r
    del v

    return sprime.flatten()

ne, omega = calc_dndr(domain)
for i in range(2):
    beam_definition[:, 0] = dsdt(beam_definition[:, 0], ne, domain.coordinates, omega, (domain.x_length, domain.y_length, domain.z_length), (domain.x_n, domain.y_n, domain.z_n))
    print(beam_definition[:, 0])


Starting cpu sharding test run with 128 cells and 1 rays.
[[-1.83172124e-10]
 [ 1.71050399e-09]
 [ 0.00000000e+00]]
[ 9.32790794e+02 -7.14624519e+03  2.99792458e+08 -1.83172124e-10
  1.71050399e-09  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00]
[[ -106496.]
 [-1835008.]
 [       0.]]
[-1.83172124e-10  1.71050399e-09  0.00000000e+00 -1.06496000e+05
 -1.83500800e+06  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00]


  ).astype(np.int32),
