In [None]:
# NBVAL_SKIP


import os
import multiprocessing
import matplotlib.pyplot as plt


In [None]:
# NBVAL_SKIP

# Number of logical CPUs (includes hyperthreads), as reported by the os module.
print("Logical cores:", os.cpu_count())


# CPU count as reported by the multiprocessing module.
print("multiprocessing.cpu_count():", multiprocessing.cpu_count())

In [None]:
# NBVAL_SKIP
# use dotenv to handle env variables
import os
from dotenv import load_dotenv
env_loaded = load_dotenv(dotenv_path='./data.env')
# Fail early if ./data.env could not be loaded (it presumably provides ILLUSTRIS_API_KEY,
# which the config reads below). An explicit raise is used instead of `assert`, because
# asserts are skipped when Python runs with -O.
if not env_loaded:
    raise RuntimeError("Failed to load .env file")

import jax.numpy as jnp
import jax
from jax.sharding import PartitionSpec as P, NamedSharding

from rubix.core.pipeline import RubixPipeline 


In [None]:
# NBVAL_SKIP

# List the devices JAX can see; build_data below shards the particle array over all of them.
print(jax.devices())


# RUBIX pipeline

RUBIX is designed as a linear pipeline in which the individual functions are called and assembled into a pipeline. This allows us to execute the whole data transformation, from a cosmological hydrodynamical simulation of a galaxy to an IFU cube, in two lines of code. This notebook shows how to execute the pipeline sharded across multiple devices. To see how the pipeline is executed step by step, one function at a time, we refer to the notebook `rubix_pipeline_stepwise.ipynb`.

## How to use the Pipeline
1) Define a `config`
2) Setup the `pipeline yaml`
3) Run the RUBIX pipeline
4) Do science with the mock-data

## Step 1: Config

The `config` contains all the information needed to run the pipeline. These are run-specific settings. Currently only Illustris is supported as a simulation, but extensions to other simulations (e.g. NIHAO) are planned.

For the `config` you can choose the following options:
- `pipeline`: you specify the name of the pipeline that is stored in the yaml file in rubix/config/pipeline_config.yml
- `logger`: RUBIX has implemented a logger to report to the user, what is happening during the pipeline execution and give warnings
- `data - args - particle_type`: load only star particles ("particle_type": ["stars"]), only gas particles ("particle_type": ["gas"]), or both ("particle_type": ["stars","gas"])
- `data - args - simulation`: choose the Illustris simulation (e.g. "simulation": "TNG50-1")
- `data - args - snapshot`: which time step of the simulation (99 for present day)
- `data - args - save_data_path`: set the path to save the downloaded Illustris data
- `data - load_galaxy_args - id`: defines which Illustris galaxy is downloaded
- `data - load_galaxy_args - reuse`: if True and a file for this galaxy id already exists in the save_data_path directory, the download is skipped and the existing file is used
- `data - subset`: only a defined number of stars/gas particles is used and stored for the pipeline. This may be helpful for quick testing
- `simulation - name`: currently only IllustrisTNG is supported
- `simulation - args - path`: where the data is stored and how the file will be named
- `output_path`: where the hdf5 file is stored, which is then the input to the RUBIX pipeline
- `telescope - name`: define the telescope instrument that is observing the simulation. Some telescopes are predefined, e.g. MUSE. If your instrument does not exist predefined, you can easily define your instrument in rubix/telescope/telescopes.yaml
- `telescope - psf`: define the point spread function that is applied to the mock data
- `telescope - lsf`: define the line spread function that is applied to the mock data
- `telescope - noise`: define the noise that is applied to the mock data
- `cosmology`: specify the cosmology you want to use, standard for RUBIX is "PLANCK15"
- `galaxy - dist_z`: specify at which redshift the mock-galaxy is observed
- `galaxy - rotation`: specify the orientation of the galaxy. You can set the types edge-on or face-on or specify the angles alpha, beta and gamma as rotations around x-, y- and z-axis
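- `ssp - template`: specify the simple stellar population lookup template used to get the stellar spectrum for each star particle. RUBIX frequently uses "BruzualCharlot2003"; this notebook uses "FSPS".
- `ssp - dust`: parameters of the dust extinction model (extinction law, dust-to-gas ratio, dust-to-metals ratio, dust grain density and R_V)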

In [None]:
#NBVAL_SKIP


# Run configuration of the RUBIX IFU pipeline (the options are explained in the list above).
config = {
    # pipeline defined in rubix/config/pipeline_config.yml
    "pipeline": {"name": "calc_ifu"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # read from the environment (populated from ./data.env above); None if unset
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {
            "id": 14,
            "reuse": True,
        },
        # only subset_size particles are used (quick testing)
        "subset": {
            "use_subset": True,
            "subset_size": 30000,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        # NOTE(review): the file name presumably has to match load_galaxy_args["id"]
        "args": {
            "path": "data/galaxy-id-14.hdf5",
        },
    },
    "output_path": "output",
    "telescope": {
        "name": "MUSE",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    "ssp": {
        "template": {"name": "FSPS"},
        "dust": {
            "extinction_model": "Cardelli89",
            "dust_to_gas_ratio": 0.01,
            "dust_to_metals_ratio": 0.4,
            "dust_grain_density": 3.5,
            "Rv": 3.1,
        },
    },
}

## Step 2: Pipeline yaml

To run the RUBIX pipeline, you need a yaml file (stored in `rubix/config/pipeline_config.yml`) that defines which functions are used during the execution of the pipeline. This shows the example pipeline yaml to compute a stellar IFU cube.

```yaml
calc_ifu:
  Transformers:
    rotate_galaxy:
      name: rotate_galaxy
      depends_on: null
      args: []
      kwargs:
        type: "face-on"
    filter_particles:
      name: filter_particles
      depends_on: rotate_galaxy
      args: []
      kwargs: {}
    spaxel_assignment:
      name: spaxel_assignment
      depends_on: filter_particles
      args: []
      kwargs: {}
    reshape_data:
      name: reshape_data
      depends_on: spaxel_assignment
      args: []
      kwargs: {}
    calculate_spectra:
      name: calculate_spectra
      depends_on: reshape_data
      args: []
      kwargs: {}
    scale_spectrum_by_mass:
      name: scale_spectrum_by_mass
      depends_on: calculate_spectra
      args: []
      kwargs: {}
    doppler_shift_and_resampling:
      name: doppler_shift_and_resampling
      depends_on: scale_spectrum_by_mass
      args: []
      kwargs: {}
    calculate_datacube:
      name: calculate_datacube
      depends_on: doppler_shift_and_resampling
      args: []
      kwargs: {}
    convolve_psf:
      name: convolve_psf
      depends_on: calculate_datacube
      args: []
      kwargs: {}
    convolve_lsf:
      name: convolve_lsf
      depends_on: convolve_psf
      args: []
      kwargs: {}
    apply_noise:
      name: apply_noise
      depends_on: convolve_lsf
      args: []
      kwargs: {}
```

Note on naming: to be usable inside the pipeline, each function must be named in this yaml exactly as it is returned by the corresponding core module function.

# Data organization

We try a simple approach for now: allocate one large array of zeros, define index helpers for each quantity, and use these indices to write the values into slices of the array.

In [None]:
# NBVAL_SKIP


# this function builds the data from the rubixdata object because that is easiest, but should not really be done imho. 
def build_data(inputdata): 
    long_axis = inputdata.stars.age.shape[0]
    data = jnp.zeros((long_axis, 6200), dtype=jnp.float32)
    inputdata.galaxy.redshift = jnp.float32(inputdata.galaxy.redshift)
    inputdata.galaxy.halfmassrad_stars = jnp.array(inputdata.galaxy.halfmassrad_stars, dtype=jnp.float32)
    inputdata.galaxy.center = jnp.array(inputdata.galaxy.center, dtype=jnp.float32)

    inputdata.stars.coords = jnp.array(inputdata.stars.coords, dtype=jnp.float32)
    inputdata.stars.age = jnp.array(inputdata.stars.age, dtype=jnp.float32)
    inputdata.stars.velocity = jnp.array(inputdata.stars.velocity, dtype=jnp.float32)
    inputdata.stars.metallicity = jnp.array(inputdata.stars.metallicity, dtype=jnp.float32)
    inputdata.stars.mass = jnp.array(inputdata.stars.mass, dtype=jnp.float32)
    # stars properties
    data = data.at[:, 0:3].set(inputdata.stars.coords)
    data = data.at[:, 3:6].set(inputdata.stars.velocity)
    data = data.at[:, 6].set(inputdata.stars.metallicity)
    data = data.at[:, 7].set(inputdata.stars.age)
    data = data.at[:, 8].set(inputdata.stars.mass)

    # galaxy properties
    data = data.at[:, 9].set(inputdata.galaxy.halfmassrad_stars)
    data = data.at[:, 10].set(inputdata.galaxy.redshift)
    data = data.at[:, 11:14].set(inputdata.galaxy.center)
    
    mesh = jax.make_mesh((jax.device_count(), ), ('x',))
    shard = NamedSharding(mesh, P('x'))

    data = jax.device_put(data, shard)

    return data

In [None]:
# NBVAL_SKIP

def stars(data: jnp.ndarray) -> jnp.ndarray:
    """Return the star columns 0:9 (coords, velocity, metallicity, age, mass) of ``data``."""
    star_columns = slice(0, 9)
    return data[:, star_columns]

def gas(data: jnp.ndarray) -> jnp.ndarray:
    # Placeholder: the packed layout has no gas columns yet, so the full array is returned.
    # Adjust the index once gas properties are added.
    return data

def galaxy(data: jnp.ndarray) -> jnp.ndarray:
    """Return the galaxy columns 9:14 (half-mass radius, redshift, center) of ``data``."""
    galaxy_columns = slice(9, 14)
    return data[:, galaxy_columns]

In [None]:
# NBVAL_SKIP

def coords_idx():
    """Index of the star coordinate columns (0:3) in the packed data array."""
    return jnp.s_[:, 0:3]

def coords(data: jnp.ndarray) -> jnp.ndarray:
    """Return the star coordinates, shape (n_rows, 3)."""
    return data[coords_idx()]

def velocity_idx():
    """Index of the star velocity columns (3:6)."""
    return jnp.s_[:, 3:6]

def velocity(data: jnp.ndarray) -> jnp.ndarray:
    """Return the star velocities, shape (n_rows, 3)."""
    return data[velocity_idx()]

def metallicity_idx():
    """Index of the metallicity column (6)."""
    return jnp.s_[:, 6]

def metallicity(data: jnp.ndarray) -> jnp.ndarray:
    """Return the star metallicities, shape (n_rows,)."""
    return data[metallicity_idx()]

def age_idx():
    """Index of the stellar age column (7)."""
    return jnp.s_[:, 7]

def age(data: jnp.ndarray) -> jnp.ndarray:
    """Return the stellar ages, shape (n_rows,)."""
    return data[age_idx()]

def mass_idx():
    """Index of the stellar mass column (8)."""
    return jnp.s_[:, 8]

def mass(data: jnp.ndarray) -> jnp.ndarray:
    """Return the stellar masses, shape (n_rows,)."""
    return data[mass_idx()]

def halfmassrad_stars_idx():
    """Index of the galaxy's stellar half-mass radius column (9)."""
    return jnp.s_[:, 9]

def halfmassrad_stars(data: jnp.ndarray) -> jnp.ndarray:
    """Return the stellar half-mass radius column (the same value is stored on every row)."""
    idx = halfmassrad_stars_idx()
    return data[idx]


def redshift_idx():
    """Index of the galaxy redshift column (10)."""
    return jnp.s_[:, 10]

def redshift(data: jnp.ndarray) -> jnp.ndarray:
    """Return the galaxy redshift column."""
    idx = redshift_idx()
    return data[idx]

def center_idx():
    """Index of the galaxy center columns (11:14)."""
    return jnp.s_[:, 11:14]

def center(data: jnp.ndarray) -> jnp.ndarray:
    """Return the galaxy center, one 3-vector per row."""
    idx = center_idx()
    return data[idx]

def mask_idx():
    """Index of the aperture-mask column (14), written by ``filter_particles``."""
    return jnp.s_[:, 14]

def mask(data: jnp.ndarray) -> jnp.ndarray:
    """Return the aperture-mask column."""
    return data[mask_idx()]

def pixel_assignment_idx():
    """Index of the spaxel-assignment column (15), written by ``spaxel_assignment``."""
    return jnp.s_[:, 15]

def pixel_assignment(data: jnp.ndarray) -> jnp.ndarray:
    """Return the flat spaxel index of every row (stored as float)."""
    return data[pixel_assignment_idx()]


def spectra_index(n_wavelengths: int = 5994):
    """Index of the spectra columns ``16 : 16 + n_wavelengths``.

    Args:
        n_wavelengths: Number of wavelength bins of the stored spectra. Defaults to 5994,
            which presumably matches the wavelength grid of the FSPS template configured in
            this notebook (TODO confirm). Pass the template's grid length when using a
            different SSP template.
    """
    return jnp.s_[:, 16:(16 + n_wavelengths)]

def spectra(data: jnp.ndarray, n_wavelengths: int = 5994) -> jnp.ndarray:
    """Return the per-row spectra, shape (n_rows, n_wavelengths)."""
    return data[spectra_index(n_wavelengths)]


Now try the sharding with the pipeline functions. The library's pipeline functions operate on other data structures, so they are not used directly; instead, simplified star-only versions are built here. The pipeline is rebuilt from the ground up so that the data is sharded once and never has to be touched again.

TODO: make sure the functions have the correct static argnums so that we don't have to worry about tracing issues.

In [None]:
# NBVAL_SKIP

from functools import partial
from pipe import Pipe
from rubix.galaxy.alignment import moment_of_inertia_tensor, rotation_matrix_from_inertia_tensor, apply_init_rotation, apply_rotation
from rubix.core.telescope import get_spatial_bin_edges
from rubix.telescope.utils import mask_particles_outside_aperture
from rubix.core.pipeline import RubixPipeline 
from rubix.core.data import RubixData
from rubix.core.telescope import get_telescope
from jax import random as jrandom
from rubix.core.ssp import get_ssp, get_lookup_interpolation
from rubix.telescope.psf.kernels import gaussian_kernel_2d
from jax.scipy.signal import convolve2d
from rubix.telescope.lsf.lsf import _get_kernel
from jax.scipy.signal import convolve
from rubix import config as rubix_config

## galaxy rotation

In [None]:
# NBVAL_SKIP

def rotate_galaxy_impl(data: jnp.array, alpha, beta, gamma)->jnp.array:
    """Rotate the star coordinates and velocities of the packed array.

    An initial rotation is derived from the mass-weighted moment of inertia tensor. The
    rotation given by the angles ``alpha``, ``beta`` and ``gamma`` is then applied on top.
    Coordinates and velocities are transformed identically; other columns are unchanged.
    """
    inertia = moment_of_inertia_tensor(coords(data), mass(data), halfmassrad_stars(data),)
    init_rot = rotation_matrix_from_inertia_tensor(inertia)

    def _rotate(vectors):
        aligned = apply_init_rotation(vectors, init_rot)
        return apply_rotation(aligned, alpha, beta, gamma)

    rotated_coords = _rotate(coords(data))
    data = data.at[coords_idx()].set(rotated_coords)
    rotated_velocity = _rotate(velocity(data))
    data = data.at[velocity_idx()].set(rotated_velocity)
    return data

# TODO: generalize, get these numbers from the config (presumably angles in degrees; confirm)
rotate_galaxy = partial(rotate_galaxy_impl, alpha=90.0, beta=0.0, gamma=0.0)

## filter particles

In [None]:

# NBVAL_SKIP

def filter_particles_impl(data: jnp.ndarray, spatial_bin_edges) -> jnp.ndarray:
    """Apply the aperture mask to the packed array.

    The mask from ``mask_particles_outside_aperture`` is stored in the mask column. Age,
    mass and metallicity are set to 0 on every row where the mask is False. Coordinates and
    velocities are left untouched.
    """
    keep = mask_particles_outside_aperture(coords(data), spatial_bin_edges)
    data = data.at[mask_idx()].set(keep)

    for idx in (age_idx(), mass_idx(), metallicity_idx()):
        data = data.at[idx].set(jnp.where(keep, data[idx], 0))
    return data

filter_particles = partial(filter_particles_impl, spatial_bin_edges=get_spatial_bin_edges(config))

## spaxel assignment

In [None]:
# NBVAL_SKIP

def spaxel_assignment_square_impl(data: jnp.ndarray, spatial_bin_edges)-> jnp.ndarray:
    """Assign every row to a spaxel of a square grid and store the flat spaxel index.

    The x and y coordinates are binned separately with ``spatial_bin_edges``. Values outside
    the edges are clipped onto the outermost bins. The flat index ``x_bin + n_bins * y_bin``
    is written to the pixel-assignment column.
    """
    n_bins = len(spatial_bin_edges) - 1
    positions = data[coords_idx()]

    def _bin(values):
        # digitize numbers the first bin 1, hence the -1 to start indexing at 0
        bin_idx = jnp.digitize(values, spatial_bin_edges) - 1
        return jnp.clip(bin_idx, 0, n_bins - 1)

    x_bin = _bin(positions[:, 0])
    y_bin = _bin(positions[:, 1])
    flat_index = x_bin + n_bins * y_bin
    return data.at[pixel_assignment_idx()].set(jnp.round(flat_index))


spaxel_assignment = partial(spaxel_assignment_square_impl, spatial_bin_edges=get_spatial_bin_edges(config))


## Calculate spectra

Now calculate the spectra. Since the spectra are so large, it might make sense to handle them in a separate path instead of storing them in the data array and dragging them through every step.

In [None]:
# NBVAL_SKIP

# NOTE: memory hot spot (see the step-by-step measurements further down). The lookup is
# evaluated for every row at once and the result fills the whole spectra block.
def calculate_spectra_impl(data: jnp.ndarray, lookup_interpolation) -> jnp.ndarray:
    """Look up the SSP spectrum of every row from its metallicity and age.

    The result is written into the spectra columns of ``data``.
    """
    # Debug output. When traced under jax.jit (as in transform_data), these prints run at
    # trace time only.
    print("Calculating spectra")
    print("Data shape:", data.shape)
    print("lookup type: ", type(lookup_interpolation))
    # NOTE(review): assumes the interpolator object exposes a `.shape` attribute; confirm.
    print("lookup shape: ", lookup_interpolation.shape)
    metallicity_col = data[metallicity_idx()]
    age_col = data[age_idx()]
    # can be gigantic for realistic particle counts
    particle_spectra = lookup_interpolation(metallicity_col, age_col)
    return data.at[spectra_index()].set(particle_spectra)

# Building the interpolator accesses the SSP template files; keep it off the hot path.
lookup_interpolation = get_lookup_interpolation(config)
calculate_spectra = partial(calculate_spectra_impl, lookup_interpolation=lookup_interpolation)

## scale spectrum by mass

In [None]:
# NBVAL_SKIP

def scale_spectrum_by_mass(data: jnp.ndarray) -> jnp.ndarray:
    """Multiply every row's spectrum by that row's mass."""
    row_mass = data[mass_idx()][:, None]
    scaled = data[spectra_index()] * row_mass
    return data.at[spectra_index()].set(scaled)

## doppler shift

In [None]:
# NBVAL_SKIP

# Doppler settings from the RUBIX package config: which velocity component (x, y or z)
# drives the Doppler shift, and its column offset within the velocity block.
velocity_direction = rubix_config["ifu"]["doppler"]["velocity_direction"]
directions = dict(x=0, y=1, z=2)

In [None]:
# NBVAL_SKIP
# TODO: fuse with the resampling step so that the large (n_rows, n_wavelengths) temporary
# array is never materialised.
def apply_doppler_impl(data: jnp.ndarray, wavelength, c, direction) -> jnp.ndarray:
    """Doppler-shift a wavelength grid with each row's velocity.

    Args:
        data: Packed particle array; the velocity components are columns 3:6.
        wavelength: 1-D wavelength grid to shift.
        c: Speed of light, in the same unit as the stored velocities.
        direction: Velocity component to use (0, 1 or 2 for x, y, z).

    Returns:
        Array of shape (n_rows, len(wavelength)) whose row i is
        ``wavelength * exp(v_i / c)``.
    """
    v = data[:, 3 + direction]  # column 3 is the first velocity component
    factor = jnp.exp(v / c)
    return factor[:, None] * wavelength

ssp = get_ssp(config)
ssp_wave = ssp.wavelength
direction = directions[velocity_direction]
# The SSP wavelengths are presumably rest-frame. Shift them by (1 + z) into the observed
# frame of a galaxy at redshift dist_z before they are resampled onto the telescope grid.
cosmological_doppler_shift = (1 + config["galaxy"]["dist_z"]) * ssp_wave

# The per-particle velocity Doppler shift is applied on top of the cosmological redshift.
# Passing the rest-frame ssp_wave here would leave the galaxy redshift out of the mock cube.
# NOTE(review): c=3e8 is the speed of light in m/s. Confirm that the stored particle
# velocities use the same unit (Illustris velocities are typically given in km/s).
apply_doppler = partial(apply_doppler_impl, wavelength=cosmological_doppler_shift, c=3e8, direction=direction)

## resampling

In [None]:
# NBVAL_SKIP
def calculate_diff(
    vec, pad_with_zero: bool = True
):
    """
    Return the successive differences of ``vec``.

    Args:
        vec (array-like): The input vector.
        pad_with_zero (bool, optional): If True (default), ``vec[0]`` is prepended before
            differencing. The result then has the same length as ``vec`` and starts with 0.
            If False, the result is one element shorter.

    Returns:
        The element-wise differences (array-like).
    """
    if not pad_with_zero:
        return jnp.diff(vec)
    return jnp.diff(vec, prepend=vec[0])


def resample_spectrum_impl(init_spectrum: jnp.ndarray, initial_wavelength, target_wavelength) -> jnp.ndarray:
    """Interpolate a spectrum onto ``target_wavelength`` while conserving its integrated flux.

    The flux of the input spectrum is integrated only over the part of its grid that lies
    inside the target range. The interpolated spectrum is rescaled to carry that same flux.
    """
    lo = jnp.min(target_wavelength)
    hi = jnp.max(target_wavelength)
    covered = (initial_wavelength >= lo) & (initial_wavelength <= hi)

    # bin widths of the input grid, zero outside the target range
    input_widths = calculate_diff(initial_wavelength) * covered
    flux_before = jnp.sum(init_spectrum * input_widths)

    resampled = jnp.interp(target_wavelength, initial_wavelength, init_spectrum)
    flux_after = jnp.sum(resampled * calculate_diff(target_wavelength))

    # 0/0 gives NaN when flux_after is zero; treat that as a zero scale factor
    factor = jnp.nan_to_num(flux_before / flux_after, nan=0.0)
    return resampled * factor

# Index helpers for the spectra block after resampling onto the telescope grid.
def rs_spectra_index(out_size: int):
    """Index of the first ``out_size`` spectra columns (``16 : 16 + out_size``), which hold
    the resampled spectra."""
    return jnp.s_[:, 16:(16 + out_size)]

def diff_spectra_index(in_size: int, out_size: int):
    """Index of the stale spectra columns left over after resampling.

    Before resampling the spectra occupy ``16 : 16 + in_size``. Afterwards the resampled
    spectra occupy ``16 : 16 + out_size``, so the stale tail is ``16 + out_size : 16 + in_size``.
    Starting the slice at column 16 would overwrite the beginning of the freshly resampled
    spectra. The slice is empty when ``out_size >= in_size``.
    """
    return jnp.s_[:, (16 + out_size):(16 + in_size)]

def rs_spectra(data: jnp.ndarray, out_size: int) -> jnp.ndarray:
    """Return the resampled spectra, shape (n_rows, out_size)."""
    return data[rs_spectra_index(out_size)]

def doppler_and_resample(data: jnp.array, target_wavelength: jnp.array,  out_size: int) -> jnp.ndarray:
    """
    Doppler-shift every row's spectrum and resample it onto ``target_wavelength``.

    The resampled spectra are written to the first ``out_size`` spectra columns. The columns
    selected by ``diff_spectra_index`` are then set to zero.
    """
    shifted_wave = apply_doppler(data)
    resample_rows = jax.vmap(resample_spectrum_impl, in_axes=(0, 0, None))
    resampled = resample_rows(data[spectra_index()], shifted_wave, target_wavelength)

    data = data.at[rs_spectra_index(out_size)].set(resampled)
    stale = diff_spectra_index(ssp_wave.shape[0], out_size)
    return data.at[stale].set(0.0)

# Telescope setup for the resampling step.
telescope = get_telescope(config)
telescope_wavelength = telescope.wave_seq
num_spaxels = int(telescope.sbin)
out_size = int(telescope_wavelength.shape[0])

resample = partial(doppler_and_resample, target_wavelength=telescope_wavelength, out_size=out_size)

Get all the telescope data and build the `resample` partial (this repeats the setup at the end of the previous cell).

In [None]:
# NBVAL_SKIP
# Rebuild the telescope objects and the `resample` partial so that this cell can be re-run
# on its own.
telescope = get_telescope(config)
telescope_wavelength = telescope.wave_seq
num_spaxels, out_size = int(telescope.sbin), int(telescope_wavelength.shape[0])
resample = partial(
    doppler_and_resample,
    target_wavelength=telescope_wavelength,
    out_size=out_size,
)

## apply extinction

In [None]:
# NBVAL_SKIP
from rubix.telescope.utils import calculate_spatial_bin_edges
from rubix.core.cosmology import get_cosmology
from rubix.spectra.dust.extinction_models import Rv_model_dict, Cardelli89, Gordon23


In [None]:
# NBVAL_SKIP
# Inputs for the dust-extinction step (apply_extinction below is still a stub).
telescope = get_telescope(config)
telescope_wavelength = telescope.wave_seq
num_spaxels = int(telescope.sbin)
galaxy_dist_z = config["galaxy"]["dist_z"]
cosmology = get_cosmology(config)

# extinction law and R_V from the dust section of the config
ext_model = config["ssp"]["dust"]["extinction_model"]
Rv = config["ssp"]["dust"]["Rv"]
ext_model_class = Rv_model_dict[ext_model]
ext = ext_model_class(Rv=Rv)


In [None]:
# NBVAL_SKIP
# Size of one spatial bin (units as returned by calculate_spatial_bin_edges) and the
# corresponding spaxel area; the bin edges themselves are not needed here.
_, spatial_bin_size = calculate_spatial_bin_edges(
    fov=telescope.fov,
    spatial_bins=telescope.sbin,
    dist_z=galaxy_dist_z,
    cosmology=cosmology,
)
spaxel_area = spatial_bin_size**2


In [None]:
# NBVAL_SKIP

def apply_extinction(data: jnp.ndarray, wavelength, spaxel_area, n_spaxels, ext) -> jnp.ndarray:
    """Placeholder for the dust-extinction step; currently returns ``data`` unchanged.

    Extinction needs gas, which the packed data array does not contain yet, so the step is
    skipped for now. The input is returned (instead of the implicit ``None``) so that the
    function can already be used as a pipeline step.

    The library's dust_extinction module performs config lookups inside the function and
    sorts the particles. Both should be avoided here if possible; it is not yet clear why
    the sorting is needed.
    """
    return data
    

## calculate datacube

In [None]:
# NBVAL_SKIP
def calculate_datacube_impl(data: jnp.ndarray, num_spaxels: int, out_size: int) -> jnp.ndarray:
    """Sum the resampled spectra of all rows into a square datacube.

    Rows are grouped by their flat spaxel index (pixel-assignment column). Their first
    ``out_size`` spectra columns are summed per spaxel.

    Args:
        data: Packed particle array.
        num_spaxels: Number of spaxels along each spatial axis.
        out_size: Number of wavelength bins of the resampled spectra.

    Returns:
        Array of shape (num_spaxels, num_spaxels, out_size).
    """
    rs_spectra_block = data[rs_spectra_index(out_size)]
    spaxel_ids = data[pixel_assignment_idx()].astype('int32')
    flat_cube = jax.ops.segment_sum(
        rs_spectra_block,
        spaxel_ids,
        num_segments=num_spaxels**2,
    )
    # Reshape with the out_size argument (not the global telescope grid) so that the shape
    # always matches the summed spectra.
    return flat_cube.reshape((num_spaxels, num_spaxels, out_size))

calculate_datacube = partial(calculate_datacube_impl, num_spaxels= int(telescope.sbin), out_size=out_size)

## convolve psf

In [None]:
# NBVAL_SKIP
# Square Gaussian PSF kernel (size x size) from the telescope config.
# NOTE: `sigma` and `kernel` are reassigned for the LSF further down. apply_psf binds
# `kernel` when its partial is created, so create it before the LSF cells run.
m = n = config["telescope"]["psf"]["size"]
sigma = config["telescope"]["psf"]["sigma"]
kernel = gaussian_kernel_2d(m, n, sigma)

In [None]:
# NBVAL_SKIP
def apply_psf_impl(cube: jnp.ndarray, kernel) -> jnp.ndarray:
    """Convolve every wavelength slice ``cube[:, :, k]`` with the 2-D PSF ``kernel``.

    mode="same" keeps each slice's spatial shape, so the result has the shape of ``cube``.
    """
    conv_same = partial(convolve2d, mode="same")
    # vmap over the last (wavelength) axis; the results are stacked along a new leading axis
    per_slice = jax.vmap(conv_same, in_axes=(2, None))(cube, kernel)
    # move the stacked wavelength axis back to the end
    return jnp.moveaxis(per_slice, 0, -1)
# The PSF kernel is bound now; later reassignments of the global `kernel` (e.g. the LSF
# kernel below) do not affect apply_psf.
apply_psf = partial(apply_psf_impl, kernel=kernel)

## convolve lsf

In [None]:
# NBVAL_SKIP
# Line-spread-function kernel. Note that `sigma` and `kernel` overwrite the PSF values of
# the same name.
telescope = get_telescope(config)
wave_resolution = telescope.wave_res
sigma = config["telescope"]["lsf"]["sigma"]
# passed to _get_kernel as `factor` and to apply_lsf below
extend_factor = 12
kernel = _get_kernel(sigma, wave_resolution, factor=extend_factor)

In [None]:
# NBVAL_SKIP
def apply_lsf_impl(cube: jnp.ndarray, kernel: jnp.array, extend_factor: int) -> jnp.ndarray:
    """Convolve every spectrum (last axis of ``cube``) with the 1-D LSF ``kernel``.

    Each spectrum is convolved in "full" mode and the centred ``n_wavelengths`` samples are
    kept. This is the centred ("same") convolution, so the result has the shape of ``cube``.

    Args:
        cube: Datacube whose last axis is wavelength.
        kernel: 1-D LSF kernel.
        extend_factor: Factor the kernel was built with. It is not needed for cropping: the
            crop offset is derived from the kernel length and equals ``extend_factor`` when
            ``len(kernel) == 2 * extend_factor + 1``.
    """
    spectra_2d = cube.reshape(-1, cube.shape[-1])
    n_wave = spectra_2d.shape[1]
    convolved = jax.vmap(partial(convolve, mode="full"), in_axes=(0, None))(spectra_2d, kernel)
    # Centre offset of a "full" convolution. Deriving it from the kernel length keeps the
    # output length equal to n_wave for any kernel size; cropping by extend_factor only does
    # so when len(kernel) == 2 * extend_factor + 1.
    start = (kernel.shape[0] - 1) // 2
    return convolved[:, start:start + n_wave].reshape(cube.shape)

# Binds the LSF kernel and extend_factor from the previous cell (at this point the global
# `kernel` holds the LSF kernel, not the PSF kernel).
apply_lsf = partial(apply_lsf_impl, kernel=kernel, extend_factor=extend_factor)

## apply noise

In [None]:
# NBVAL_SKIP
# Noise settings from the telescope config.
signal_to_noise = config["telescope"]["noise"]["signal_to_noise"]
# NOTE(review): noise_distribution is read but not used by apply_noise_impl below, which
# always draws Gaussian noise.
noise_distribution = config["telescope"]["noise"]["noise_distribution"]

In [None]:
# NBVAL_SKIP
def calculate_S2N(cube: jnp.ndarray, observation_s2n: float)->jnp.ndarray:
    """Per-spaxel relative noise amplitude for a target signal-to-noise ratio.

    With F the flux summed over the last axis, spaxels with F > 0 get
    ``sqrt(median(F_pos)) / observation_s2n / sqrt(F)``, where ``F_pos`` are the positive
    fluxes only. Spaxels with F <= 0 get 0.

    Args:
        cube: Datacube of rank 3 (spatial, spatial, wavelength).
        observation_s2n: Target signal-to-noise ratio.

    Returns:
        Array of shape ``cube.shape[:2]``.
    """
    flux_image = jnp.sum(cube, axis=2)
    positive = flux_image > 0
    # Median over positive spaxels only; nanmedian on a NaN-masked copy stays jit-compatible.
    # A plain median over where(positive, flux, 0.) would still count empty spaxels as zeros,
    # which drags the median (and the noise) to 0 when most spaxels are empty.
    median_flux = jnp.nanmedian(jnp.where(positive, flux_image, jnp.nan))
    # avoid sqrt/division of non-positive values in the branch that where() discards
    safe_flux = jnp.where(positive, flux_image, 1.0)
    return jnp.where(positive, (jnp.sqrt(median_flux) / observation_s2n) / jnp.sqrt(safe_flux), 0)

def apply_noise_impl(cube: jnp.array, signal_to_noise: float, key=None) -> jnp.ndarray:
    """Add Gaussian noise proportional to the cube values, scaled per spaxel by calculate_S2N.

    Args:
        cube: Datacube of rank 3 (spatial, spatial, wavelength).
        signal_to_noise: Target signal-to-noise ratio.
        key: Optional ``jax.random`` key. Defaults to ``PRNGKey(0)``, i.e. the same noise
            realisation on every call.
    """
    if key is None:
        key = jrandom.PRNGKey(0)
    s2n = calculate_S2N(cube, signal_to_noise)
    noise = jrandom.normal(key, cube.shape)
    return cube + cube * noise * s2n[:, :, None]

# apply_noise(cube) uses the signal-to-noise target from the telescope config.
apply_noise = partial(apply_noise_impl, signal_to_noise=signal_to_noise)


## build pipelines

It looks like everything is in place now, so we can build pipelines for the data transformations and for the cube transformations. This separation is only done for the sake of debugging; in production it is not needed.

In [None]:
# NBVAL_SKIP
@jax.jit
def transform_data(inputdata: jnp.ndarray) -> jnp.ndarray:
    """Run the particle-level part of the pipeline on the packed, sharded data array.

    The steps follow the calc_ifu pipeline yaml: rotate_galaxy, filter_particles,
    spaxel_assignment, calculate_spectra, scale_spectrum_by_mass, and finally the Doppler
    shift + resampling onto the telescope wavelength grid (``resample``). The output is the
    input expected by calculate_datacube / compute_cube.
    """
    data = rotate_galaxy(inputdata)
    data = filter_particles(data)
    data = spaxel_assignment(data)
    data = calculate_spectra(data)
    data = scale_spectrum_by_mass(data)
    # doppler_shift_and_resampling from the yaml. Without it, calculate_datacube would read
    # the raw SSP spectra columns instead of spectra on the telescope grid.
    data = resample(data)
    return data

This pipeline construction and data preparation should eventually be removed.

In [None]:
# NBVAL_SKIP
# Use the regular RUBIX pipeline only to load and prepare the input galaxy data.
pipe = RubixPipeline(config)
inputdata = pipe.prepare_data()

In [None]:
# NBVAL_SKIP
# Same as build_data(inputdata), written with the `pipe` package's `|` syntax.
data = inputdata | Pipe(build_data)

In [None]:
# NBVAL_SKIP
# Show how the packed array is distributed across devices.
jax.debug.visualize_array_sharding(data)

In [None]:
# NBVAL_SKIP
# Run the jitted particle pipeline; the first call includes tracing and compilation.
data = transform_data(data)

In [None]:
# NBVAL_SKIP
# JAX dispatches asynchronously; wait until the result is actually computed.
data.block_until_ready();

In [None]:
# NBVAL_SKIP

# (shape, size in whole MiB, size in GiB) of the packed data array
data.shape, data.nbytes// 1024**2, data.nbytes/1024**3

In [None]:
# NBVAL_SKIP

# Check whether the output of transform_data is still sharded along the particle axis.
jax.debug.visualize_array_sharding(data)

The data array is still correctly sharded. yay!

When working with the cube pipeline, we have to reshard the cube first and then either index into the padded cube or pad all the other data as well. `reshard_cube` below implements the first approach (padding the cube); note that `compute_cube` does not call it yet.

In [None]:
# NBVAL_SKIP
def reshard_cube(cube: jnp.ndarray,) -> jnp.ndarray:
    d = cube.shape[2]

    # we can only go upwards to not loose
    while d % jax.device_count() != 0:
        d += 1
    d
    padding = d - cube.shape[2]
    mesh = jax.make_mesh((jax.device_count(), ), ('devices',))
    shard = NamedSharding(mesh, P(None, None, 'devices'))

    cube = jax.device_put(jnp.pad(cube, ((0, 0), (0, 0), (0, padding))), shard)
    return cube

def compute_cube(inputdata: jnp.ndarray) -> jnp.ndarray:
    """Build the datacube from the particle data, then apply PSF, LSF and noise in that order.

    The cube is not resharded here (reshard_cube is not called).
    """
    cube = calculate_datacube(inputdata)
    # NOTE(review): unclear whether these steps preserve any sharding of `cube`
    for step in (apply_psf, apply_lsf, apply_noise):
        cube = step(cube)
    return cube
    

The plain cube is not sharded:

In [None]:
# NBVAL_SKIP
# Sum the particle spectra into a cube and show the sharding of its (spaxel x wavelength) view.
cube = calculate_datacube(data)
jax.debug.visualize_array_sharding(cube.reshape(cube.shape[0]* cube.shape[1], cube.shape[2]))

In [None]:
# NBVAL_SKIP
# Pad and shard the cube along the wavelength axis, then visualize it again.
cube = reshard_cube(cube)
jax.debug.visualize_array_sharding(cube.reshape(cube.shape[0]* cube.shape[1], cube.shape[2]))

I have not applied this to the computation yet because it is messy to do and it is not the main objective; this datacube is tiny by comparison. What one would have to do is pad the data that takes part in the cube pipeline to the size of the cube; then the sharding should be fine. Indexing into the cube apparently destroys the sharding again; for this tiny cube it then ends up distributed over all devices. Not good...

In [None]:
# NBVAL_SKIP
# Run the full cube pipeline (datacube, PSF, LSF, noise) and inspect the result's sharding.
final_cube = compute_cube(data)
final_cube.block_until_ready()
jax.debug.visualize_array_sharding(final_cube.reshape(final_cube.shape[0]* final_cube.shape[1], final_cube.shape[2]))

The final cube is not sharded correctly... :/

In [None]:
# NBVAL_SKIP

# (shape, size in MiB, dtype) of the final cube
final_cube.shape, final_cube.nbytes / 1024**2, final_cube.dtype

... but it is also really small, so that might be the reason.

## memory usage 

The main point: which function causes memory explosion and why? 

So far we need barely 710 MB for the packed particle data array (30,000 particles × 6,200 float32 columns), and we are not using memory efficiently at all. On multiple GPUs with O(100) GB in total, we should easily be able to process the required data sizes.

**Expectation:**
For 500k particles this would amount to roughly (500/30) × 710 MB ≈ 11,833 MB, i.e. about 12 GB. Even with double the number of wavelength bins we should easily be able to run this on a 4090. Up to ~800k particles on a single GPU with the current number of wavelength bins should also be doable, and that is without any sharding.

When we have gas, this goes down by half. At any rate, how can this computation cause memory issues on this gpu?

**Observation**
However, something temporarily causes a huge number of allocations of temporary arrays, which drives memory usage up to 40 GB or more. This is the critical element; I don't think that the sharding as such is the problem.

Experiments above show that it's happening when processing the data itself, the cube computations are harmless.

Check each function of the pipeline with htop/nvtop or similar tools: `htop -d 3` updates every 0.3 seconds.

In [None]:
# NBVAL_SKIP
# Memory check, one step per cell: block after each step so that memory can be watched per step.
data = build_data(inputdata)
data.block_until_ready(); # not the culprit

In [None]:
# NBVAL_SKIP
# Memory check: rotation.
data = rotate_galaxy(data)
data.block_until_ready(); #not the culprit

In [None]:
# NBVAL_SKIP
# Memory check: aperture filtering.
data = filter_particles(data)
data.block_until_ready(); #not the culprit

In [None]:
# NBVAL_SKIP
# Memory check: spaxel assignment.
data = spaxel_assignment(data)
data.block_until_ready(); #not the culprit

In [None]:
# NBVAL_SKIP
# Memory check: spectra lookup.
# NOTE(review): lookup_interpolation is evaluated for all rows at once, and its
# (n_rows, n_wavelengths) result is then written into `data` via .at[].set. Temporaries of
# the interpolation are a plausible source of the peak; evaluating in chunks (e.g. with
# jax.lax.map and batch_size) may help. To be confirmed.
data = calculate_spectra(data)
data.block_until_ready(); # very much the culprit! increases memory size to > 40 GB even though the input is only ~0.7 - 0.8 GB

In [None]:
# NBVAL_SKIP
# Memory check: scaling the spectra by mass.
data = scale_spectrum_by_mass(data)
data.block_until_ready(); #not the culprit

In [None]:
# NBVAL_SKIP
# Memory check: Doppler shift and resampling onto the telescope grid.
data = resample(data)
data.block_until_ready(); # moderate increase, not beyond a manageable size

In [None]:
# NBVAL_SKIP
# Memory check: summing the spectra into the datacube.
cube = calculate_datacube(data)
cube.block_until_ready();  #not the culprit

Just to be sure: check the cube computation again.

In [None]:
# NBVAL_SKIP
# Memory check: full cube pipeline (datacube, PSF, LSF, noise).
final_cube = compute_cube(data)
final_cube.block_until_ready();  #not the culprit at all

There is a big problem in the spectra calculation that causes an enormous temporary memory issue. 