This notebook generates the 2D training data (see the file create_data_2d); it is written entirely in JAX to speed up the computation.

In [1]:
import jax
import jax.numpy as jnp
import equinox as eqx
from typing import Callable, List
import numpy as np
from jax import random
import matplotlib.pyplot as plt
import equinox as eqx
import optax
import scipy
from tqdm.autonotebook import tqdm

  from tqdm.autonotebook import tqdm


In [2]:
def generate_correlated_lognormal_field(
    key,
    shape=(100, 100),
    mean=10.0,
    length_scale=0.1,
    sigma_g=5.0,
    percentile=99
):
    """
    Generate a 2D log-normal random field with spatial correlations.

    A spectrum that is Gaussian in log-k (centred on k0 = 1, width 0.5) is
    combined with uniformly random phases, transformed back to real space,
    standardised, scaled by ``sigma_g`` and exponentiated. The result is
    rescaled so that its sample mean equals ``mean``.

    Args:
        key: JAX PRNG key.
        shape: tuple (Nx, Ny) of field dimensions. Must be plain Python ints
            because they drive Python-level branching below.
        mean: desired mean of the real-space log-normal field.
        length_scale: controls correlation (smaller = more small-scale structure).
        sigma_g: std dev of the Gaussian log field (controls contrast).
        percentile: used to return a mask of "top X%" regions.

    Returns:
        field: 2D log-normal field with spatial correlation and given mean.
        top_mask: boolean mask of pixels at or above the given percentile
            (e.g. percentile=99 selects the top 1%).
    """
    Nx, Ny = shape
    # Only the second half of the split is consumed (kept so that the random
    # stream is identical to earlier versions of this function).
    _, subkey = random.split(key)

    # --- Step 1: define k-space grid
    kx = jnp.fft.fftfreq(Nx) / length_scale
    ky = jnp.fft.fftfreq(Ny) / length_scale
    kx_grid, ky_grid = jnp.meshgrid(kx, ky, indexing='ij')
    k = jnp.sqrt(kx_grid**2 + ky_grid**2)

    # --- Step 2: Power spectrum (Gaussian in log-k)
    k0 = 1.0
    # Floor k to avoid log(0). jnp.maximum is used instead of
    # jnp.clip(..., a_min=...) because the `a_min` keyword is deprecated.
    log_k = jnp.log(jnp.maximum(k, 1e-6))
    log_k0 = jnp.log(k0)
    sigma_k = 0.5
    P_k = jnp.exp(-0.5 * ((log_k - log_k0) / sigma_k)**2)
    P_k = P_k.at[0, 0].set(0.0)  # zero DC

    # --- Step 3: Generate Gaussian field in Fourier space
    phases = jnp.exp(2j * jnp.pi * random.uniform(subkey, (Nx, Ny)))
    amplitude = jnp.sqrt(P_k)
    fft_field = amplitude * phases

    # Partial Hermitian symmetrisation: the Nyquist row/column (even sizes
    # only) is made real, and the entries at (-i, -j) for 0 <= i < Nx // 2,
    # 0 <= j < Ny // 2 are overwritten with the conjugate of (i, j).
    # Any asymmetry left over is discarded by taking the real part after the
    # inverse FFT in step 4.
    if Nx % 2 == 0:
        fft_field = fft_field.at[Nx // 2, :].set(fft_field[Nx // 2, :].real)
    if Ny % 2 == 0:
        fft_field = fft_field.at[:, Ny // 2].set(fft_field[:, Ny // 2].real)
    ix = jnp.arange(0, Nx // 2)
    iy = jnp.arange(0, Ny // 2)
    fft_field = fft_field.at[-ix[:, None], -iy[None, :]].set(
        jnp.conj(fft_field[ix[:, None], iy[None, :]])
    )

    # --- Step 4: Inverse FFT → correlated Gaussian field
    g = jnp.fft.ifft2(fft_field).real
    g = (g - jnp.mean(g)) / jnp.std(g)  # normalize to mean=0, std=1
    g = sigma_g * g

    # --- Step 5: Exponentiate to log-normal
    lognormal_field = jnp.exp(g)

    # --- Step 6: Rescale to desired mean
    current_mean = jnp.mean(lognormal_field)
    field = lognormal_field * (mean / current_mean)

    # --- Step 7: Create top-X% mask
    threshold = jnp.percentile(field, percentile)
    top_mask = field >= threshold

    return field, top_mask



In [3]:
def compute_intensity(
        shape=(100, 100),
        field=1,
        sigma_value=0.05,
        j_value=30.0
):
    """
    Compute the angle-averaged intensity on a 2D grid with an upwind sweep.

    The domain is the unit square. For each of 16 equally spaced directions
    theta in [0, 2*pi) the grid is swept cell by cell (ascending index when the
    direction cosine is >= 0, descending otherwise) and every cell is set to

        I[i, j] = (|mu_x| * I_up_x / dx + |mu_y| * I_up_y / dy + j[i, j])
                  / (|mu_x| / dx + |mu_y| / dy + kappa[i, j])

    with mu_x = cos(theta), mu_y = sin(theta), and I_up_x / I_up_y the values
    of the neighbours one step against the sign of mu_x / mu_y. If either of
    those neighbours lies outside the grid, both upwind values are taken as 0.
    The per-direction results are averaged.

    Args:
        shape: tuple (Nx, Ny) of grid dimensions (plain Python ints).
        field: opacity kappa. Either an array indexed as [i, j] on the grid,
            or a scalar, which is broadcast to a uniform opacity.
        sigma_value: width of the Gaussian emissivity centred on the domain.
        j_value: peak value of the Gaussian emissivity.

    Returns:
        The mean intensity, transposed (``J.T``), i.e. with shape (Ny, Nx).
    """
    # Grid and parameters
    Nx, Ny = shape
    Lx, Ly = 1.0, 1.0
    dx, dy = Lx / Nx, Ly / Ny
    # NOTE(review): the coordinates span [0, L] with N points (spacing
    # L / (N - 1)) while dx, dy use L / N -- confirm this is intended.
    x = jnp.linspace(0, Lx, Nx)
    y = jnp.linspace(0, Ly, Ny)
    X, Y = jnp.meshgrid(x, y, indexing='ij')

    # Opacity: a scalar (including the default) is expanded to the full grid so
    # that it can be indexed per cell; arrays are used as given.
    kappa = jnp.asarray(field)
    if kappa.ndim == 0:
        kappa = jnp.broadcast_to(kappa, (Nx, Ny))

    # Gaussian emissivity centred on the middle of the domain
    j0 = j_value
    xc, yc = Lx / 2, Ly / 2
    sigma = sigma_value
    j_emissivity = j0 * jnp.exp(-((X - xc)**2 + (Y - yc)**2) / (2 * sigma**2))

    # Angular discretization
    N_theta = 16
    theta_list = jnp.linspace(0, 2 * jnp.pi, N_theta, endpoint=False)

    def theta_loop(J_acc, theta):
        """Sweep the grid for one direction and add the result to J_acc."""
        mu_x = jnp.cos(theta)
        mu_y = jnp.sin(theta)
        sign_x = jnp.sign(mu_x)
        sign_y = jnp.sign(mu_y)
        I = jnp.zeros((Nx, Ny))

        # Sweep order follows the direction so that the upwind neighbour of a
        # cell has already been updated when the cell is visited.
        i_range = jax.lax.cond(
            mu_x >= 0,
            lambda _: jnp.arange(Nx),
            lambda _: jnp.arange(Nx - 1, -1, -1),
            None,
        )
        j_range = jax.lax.cond(
            mu_y >= 0,
            lambda _: jnp.arange(Ny),
            lambda _: jnp.arange(Ny - 1, -1, -1),
            None,
        )

        def inner_loop(state, j):
            """Update cell (i, j); i is carried unchanged through the scan."""
            I, i = state

            i_up = i - sign_x
            j_up = j - sign_y
            outside = (i_up < 0) | (i_up >= Nx) | (j_up < 0) | (j_up >= Ny)

            def branch_true(_):
                # An upwind neighbour is off-grid: no incoming intensity.
                return 0.0, 0.0

            def branch_false(_):
                I_up_x = jax.lax.dynamic_slice(I, (jnp.int32(i_up), j), (1, 1))[0, 0]
                I_up_y = jax.lax.dynamic_slice(I, (i, jnp.int32(j_up)), (1, 1))[0, 0]
                return I_up_x, I_up_y

            I_up_x, I_up_y = jax.lax.cond(outside, branch_true, branch_false, None)

            denom = jnp.abs(mu_x) / dx + jnp.abs(mu_y) / dy + kappa[i, j]
            I_avg = (jnp.abs(mu_x) * I_up_x / dx + jnp.abs(mu_y) * I_up_y / dy) / denom
            source = j_emissivity[i, j] / denom

            I = I.at[i, j].set(I_avg + source)

            return (I, i), None

        def outer_loop(I, i):
            """Sweep all j for a fixed i."""
            (I, _), _ = jax.lax.scan(inner_loop, (I, i), j_range)
            return I, None

        I, _ = jax.lax.scan(outer_loop, I, i_range)

        # Accumulate for mean intensity
        return J_acc + I, None

    J, _ = jax.lax.scan(theta_loop, jnp.zeros((Nx, Ny)), theta_list)

    # Compute mean intensity
    J = J / N_theta
    return J.T



In [33]:
# Builds one sample: the opacity field, its top-percentile mask and the
# intensity computed from that field.
def create_data(
        key,
        shape=(100, 100),
        mean=10.0,
        length_scale=0.1,
        sigma_g=5.0,
        percentile=99,
        sigma=0.05,
        j_value=30.0
):
    """
    Create a single (field, mask, intensity) sample.

    Args:
        key: JAX PRNG key used to draw the random field.
        shape: grid dimensions forwarded to both the field generator and
            the intensity computation.
        mean, length_scale, sigma_g, percentile: forwarded to
            generate_correlated_lognormal_field.
        sigma: width of the emissivity, forwarded to compute_intensity as
            ``sigma_value``.
        j_value: emissivity amplitude, forwarded to compute_intensity.

    Returns:
        A single array stacking field, mask and intensity along a new
        leading axis (the boolean mask is promoted to the common dtype).
    """
    # TODO (original note): maybe draw the field parameters randomly inside
    # this function instead of taking them as arguments.
    opacity, top_mask = generate_correlated_lognormal_field(
        key,
        shape=shape,
        mean=mean,
        length_scale=length_scale,
        sigma_g=sigma_g,
        percentile=percentile,
    )
    mean_intensity = compute_intensity(
        shape=shape,
        field=opacity,
        sigma_value=sigma,
        j_value=j_value,
    )

    return jnp.array([opacity, top_mask, mean_intensity])



# Batched version of create_data. Arguments must be passed positionally, in
# the order listed below; only the PRNG key and j_value vary per sample.
create_data_vmap = jax.vmap(
    create_data,
    in_axes=(
        0,     # key: one PRNG key per sample
        None,  # shape
        None,  # mean
        None,  # length_scale
        None,  # sigma_g
        None,  # percentile
        None,  # sigma
        0,     # j_value: one emissivity amplitude per sample
    ),
)  # so far only key and j is varied in creation - maybe adjust


# use jit?

In [None]:
# vmap and jit
# and pull in the code from the bottom cell so that the mesh is included directly in the training data


# should also return the r-grid since this will also be used as input for the network  
# check whether it is necessary to speed up the RTE (and then also vmap create_data) - if so, there are some comments above on how to speed it up

In [34]:
# One PRNG key and one emissivity amplitude (uniform in [20, 40)) per sample.
n_samples = 10
key = random.PRNGKey(0)
keys = random.split(key, n_samples)
key2 = random.PRNGKey(2)
j_values = random.uniform(key2, shape=(n_samples,), minval=20.0, maxval=40.0)

# Positional arguments: keys, shape, mean, length_scale, sigma_g, percentile,
# sigma, j_values.
# NOTE(review): percentile is 0 here (originally written `00`), which makes the
# returned mask True everywhere; the function default is 99 -- confirm intent.
results = create_data_vmap(keys, (100, 100), 10.0, 0.1, 5.0, 0, 0.05, j_values)

In [35]:
# Inspect the batched output: axis 0 indexes the samples, axis 1 holds the
# field, mask and intensity stacked by create_data.
results.shape

(10, 3, 100, 100)

In [None]:
# put this into create data maybe

# Network input, channel 0: the field. Sliced directly from the batch (keeping
# a length-1 channel axis) instead of looping over the samples in Python.
data_x = results[:, 0:1]

# Coordinate channels, sized from the data rather than a hard-coded 100.
n_rows, n_cols = data_x.shape[-2:]
x = jnp.linspace(0, 1, n_cols)
y = jnp.linspace(0, 1, n_rows)
X, Y = jnp.meshgrid(x, y)  # default 'xy' indexing -> arrays of shape (n_rows, n_cols)
X_shape_corrected = jnp.repeat(X[jnp.newaxis, jnp.newaxis, :], data_x.shape[0], axis=0)
Y_shape_corrected = jnp.repeat(Y[jnp.newaxis, jnp.newaxis, :], data_x.shape[0], axis=0)
data_x_with_mesh = jnp.concatenate((data_x, X_shape_corrected, Y_shape_corrected), axis=1)

# Network target: the intensity (index 2 along axis 1), again with a channel axis.
data_y = results[:, 2:3]


# NOTE(review): the split indices are fixed at 100 / 150 samples. With fewer
# samples the slices are silently truncated and the test set can end up empty.
train_x, test_x = data_x_with_mesh[:100], data_x_with_mesh[100:150]
train_y, test_y = data_y[:100], data_y[100:150]

# save training and testing arrays so I don't have to do the training again for the same setup

In [54]:
# Builds the stacked dataset: per sample the channels are
# (field, X mesh, Y mesh, intensity).
def create_training_data(
        keys,
        shape=(100, 100),
        mean=10.0,
        length_scale=0.1,
        sigma_g=5.0,
        percentile=99,
        sigma=0.05,
        j_values=30.0
):
    """
    Create a batch of samples and stack inputs and target along the channel axis.

    Args:
        keys: array of JAX PRNG keys, one per sample (leading axis = samples).
        shape: grid dimensions (plain Python ints) forwarded to create_data_vmap.
        mean, length_scale, sigma_g, percentile, sigma: forwarded unchanged to
            create_data_vmap (shared by all samples).
        j_values: emissivity amplitude per sample (leading axis = samples), or
            a scalar that is used for every sample.

    Returns:
        Array of shape (n_samples, 4, H, W) with channels
        [field, X mesh, Y mesh, intensity].
    """
    n_samples = keys.shape[0]
    # create_data_vmap maps j_value over axis 0, so a scalar (including the
    # default) is expanded to one value per sample first.
    j_values = jnp.broadcast_to(jnp.asarray(j_values), (n_samples,))

    # Forward the caller's parameters instead of hard-coded literals so that
    # the generated data and the mesh below always refer to the same grid.
    data = create_data_vmap(
        keys, shape, mean, length_scale, sigma_g, percentile, sigma, j_values
    )

    # Input channel: the field (index 0 along axis 1), keeping a channel axis.
    data_in = data[:, 0:1, :, :]

    # Coordinate channels on the unit square.
    # NOTE(review): with the default 'xy' indexing the mesh has shape
    # (shape[1], shape[0]); this matches the data only for square grids.
    x = jnp.linspace(0, 1, shape[0])
    y = jnp.linspace(0, 1, shape[1])
    X, Y = jnp.meshgrid(x, y)
    X_shape_corrected = jnp.repeat(X[jnp.newaxis, jnp.newaxis, :], n_samples, axis=0)
    Y_shape_corrected = jnp.repeat(Y[jnp.newaxis, jnp.newaxis, :], n_samples, axis=0)
    data_in_with_mesh = jnp.concatenate((data_in, X_shape_corrected, Y_shape_corrected), axis=1)

    # Target channel: the intensity (index 2 along axis 1).
    data_out = data[:, 2:3, :, :]
    data_stacked = jnp.concatenate((data_in_with_mesh, data_out), axis=1)

    # The train/test split is done later, in the training code.

    return data_stacked


In [55]:
# Build the stacked dataset for the keys / amplitudes drawn above.
training_data = create_training_data(
    keys,
    shape=(100, 100),
    mean=10.0,
    length_scale=0.1,
    sigma_g=5.0,
    percentile=0,
    sigma=0.05,
    j_values=j_values,
)

In [None]:
# Persist the stacked dataset so it does not have to be regenerated.
output_path = 'training_data_2d.npy'
jnp.save(output_path, training_data)