In [None]:
import jax.numpy as jnp

In [None]:
import os
import multiprocessing

# Number of logical CPUs reported by the OS (includes hyperthreads; may be None if it cannot be determined)
print("Logical cores:", os.cpu_count())


# CPU count as reported by the multiprocessing module (typically the same number as os.cpu_count())
print("multiprocessing.cpu_count():", multiprocessing.cpu_count())

In [None]:
import os


# Tell XLA to expose 3 virtual host (CPU) devices, so sharding can be tried without GPUs.
# NOTE: this has to be set before JAX initialises its backends to take effect.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=3'

# Optionally restrict which GPUs are visible to JAX (currently disabled):
# os.environ['CUDA_VISIBLE_DEVICES'] = '7, 8, 9'

# Make sure that JAX doesn't preallocate most of the device memory at start-up
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]   = "false"

import jax
# On a CPU-only host JAX should now list three CpuDevice entries
print(jax.devices())
# → e.g. [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2)]

In [None]:
# NBVAL_SKIP
import os

# Machine-specific fallbacks for the FSPS data directory and the Illustris API key.
# `setdefault` only fills in a value when the variable is not already exported, so
# a real API key or a different SPS_HOME set in the shell is no longer overwritten
# by this cell. Adjust the fallback path to your own FSPS checkout if needed.
os.environ.setdefault('SPS_HOME', '/home/hmack/.cache/fsps')
os.environ.setdefault('ILLUSTRIS_API_KEY', '')

# RUBIX pipeline

RUBIX is designed as a linear pipeline, in which the individual functions are chained together and called one after another. This allows us to execute the whole data transformation from a cosmological hydrodynamical simulation of a galaxy to an IFU cube in two lines of code. This notebook shows how to execute the pipeline on multiple machines. To see how the pipeline is executed in small individual steps, one function at a time, we refer to the notebook `rubix_pipeline_stepwise.ipynb`.

## How to use the Pipeline
1) Define a `config`
2) Setup the `pipeline yaml`
3) Run the RUBIX pipeline
4) Do science with the mock-data

## Step 1: Config

The `config` contains all the information needed to run the pipeline. These are run-specific configurations. Currently we only support Illustris as a simulation, but extensions to other simulations (e.g. NIHAO) are planned.

For the `config` you can choose the following options:
- `pipeline`: you specify the name of the pipeline that is stored in the yaml file in rubix/config/pipeline_config.yml
- `logger`: RUBIX has implemented a logger to report to the user, what is happening during the pipeline execution and give warnings
- `data - args - particle_type`: load only stars particle ("particle_type": ["stars"]) or only gas particle ("particle_type": ["gas"]) or both ("particle_type": ["stars","gas"])
- `data - args - simulation`: choose the Illustris simulation (e.g. "simulation": "TNG50-1")
- `data - args - snapshot`: which time step of the simulation (99 for present day)
- `data - args - save_data_path`: set the path to save the downloaded Illustris data
- `data - load_galaxy_args - id`: define, which Illustris galaxy is downloaded
- `data - load_galaxy_args - reuse`: if True and a file for this galaxy id already exists in the save_data_path directory, the download is skipped and the preexisting file is used
- `data - subset`: only a defined number of stars/gas particles is used and stored for the pipeline. This may be helpful for quick testing
- `simulation - name`: currently only IllustrisTNG is supported
- `simulation - args - path`: where the data is stored and how the file will be named
- `output_path`: where the hdf5 file is stored, which is then the input to the RUBIX pipeline
- `telescope - name`: define the telescope instrument that is observing the simulation. Some telescopes are predefined, e.g. MUSE. If your instrument does not exist predefined, you can easily define your instrument in rubix/telescope/telescopes.yaml
- `telescope - psf`: define the point spread function that is applied to the mock data
- `telescope - lsf`: define the line spread function that is applied to the mock data
- `telescope - noise`: define the noise that is applied to the mock data
- `cosmology`: specify the cosmology you want to use, standard for RUBIX is "PLANCK15"
- `galaxy - dist_z`: specify at which redshift the mock-galaxy is observed
- `galaxy - rotation`: specify the orientation of the galaxy. You can set the types edge-on or face-on or specify the angles alpha, beta and gamma as rotations around x-, y- and z-axis
- `ssp - template`: specify the simple stellar population lookup template to get the stellar spectrum for each stars particle. In RUBIX frequently "BruzualCharlot2003" is used.

In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt
from rubix.core.pipeline import RubixPipeline 
import os

config = {
    # Name of the pipeline definition to run (looked up in the pipeline yaml)
    "pipeline": {"name": "calc_ifu"},
    # Logging behaviour during pipeline execution
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    # Where the simulation data comes from and which galaxy to fetch
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {
            "id": 14,
            "reuse": True,
        },
        # Use only a subset of the particles to keep the run small
        "subset": {
            "use_subset": True,
            "subset_size": 30000,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": "data/galaxy-id-14.hdf5",
        },
    },
    "output_path": "output",
    # Instrument model applied to the mock observation
    "telescope": {
        "name": "MUSE",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    # Observation redshift and orientation of the galaxy
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    # Stellar population template and dust model
    "ssp": {
        "template": {"name": "FSPS"},
        "dust": {
            "extinction_model": "Cardelli89",
            "dust_to_gas_ratio": 0.01,
            "dust_to_metals_ratio": 0.4,
            "dust_grain_density": 3.5,
            "Rv": 3.1,
        },
    },
}

## Step 2: Pipeline yaml

To run the RUBIX pipeline, you need a yaml file (stored in `rubix/config/pipeline_config.yml`) that defines which functions are used during the execution of the pipeline. This shows the example pipeline yaml to compute a stellar IFU cube.

```yaml
calc_ifu:
  Transformers:
    rotate_galaxy:
      name: rotate_galaxy
      depends_on: null
      args: []
      kwargs:
        type: "face-on"
    filter_particles:
      name: filter_particles
      depends_on: rotate_galaxy
      args: []
      kwargs: {}
    spaxel_assignment:
      name: spaxel_assignment
      depends_on: filter_particles
      args: []
      kwargs: {}
    reshape_data:
      name: reshape_data
      depends_on: spaxel_assignment
      args: []
      kwargs: {}
    calculate_spectra:
      name: calculate_spectra
      depends_on: reshape_data
      args: []
      kwargs: {}
    scale_spectrum_by_mass:
      name: scale_spectrum_by_mass
      depends_on: calculate_spectra
      args: []
      kwargs: {}
    doppler_shift_and_resampling:
      name: doppler_shift_and_resampling
      depends_on: scale_spectrum_by_mass
      args: []
      kwargs: {}
    calculate_datacube:
      name: calculate_datacube
      depends_on: doppler_shift_and_resampling
      args: []
      kwargs: {}
    convolve_psf:
      name: convolve_psf
      depends_on: calculate_datacube
      args: []
      kwargs: {}
    convolve_lsf:
      name: convolve_lsf
      depends_on: convolve_psf
      args: []
      kwargs: {}
    apply_noise:
      name: apply_noise
      depends_on: convolve_lsf
      args: []
      kwargs: {}
```

There is one thing you have to know about the naming of the functions in this yaml: To use the functions inside the pipeline, the functions have to be called exactly the same as they are returned from the core module function!

In [None]:
#NBVAL_SKIP
# Build the pipeline object from the run configuration defined above
pipe = RubixPipeline(config)
# Prepare the pipeline input; the returned object exposes `.stars` and `.galaxy`
# attributes that the cells below read from
inputdata = pipe.prepare_data()

In [None]:
from jax.sharding import PartitionSpec as P, NamedSharding


In [None]:
    
# 1D device mesh over all available devices, with a single axis named 'x'
mesh = jax.make_mesh((jax.device_count(), ), ('x',))
# Sharding that splits the leading array axis across the 'x' mesh axis
shard = NamedSharding(mesh, P('x'))
# First attempt: place the whole prepared input object on the mesh in one go.
# NOTE(review): `inputdata` is the object returned by `prepare_data`, not a single
# array; presumably this fails because it is not a JAX pytree and/or because its
# fields do not all have a leading axis divisible by the device count — confirm.
data = jax.device_put(inputdata, shard)

Why does this not work?

Try a simpler approach for now. It is crude: build one large array of zeros, define index expressions for the column ranges we need, and use these indices to assign the values we want to the corresponding slices of the array.

In [None]:
# this function builds the data from the rubixdata object because that is easiest, but should not really be done imho. 
def build_data(input): 
    long_axis = input.stars.age.shape[0]
    data = jnp.zeros((long_axis, 6200), dtype=jnp.float32)
    inputdata.galaxy.redshift = jnp.float32(inputdata.galaxy.redshift)
    inputdata.galaxy.halfmassrad_stars = jnp.array(inputdata.galaxy.halfmassrad_stars, dtype=jnp.float32)
    inputdata.galaxy.center = jnp.array(inputdata.galaxy.center, dtype=jnp.float32)

    inputdata.stars.coords = jnp.array(inputdata.stars.coords, dtype=jnp.float32)
    inputdata.stars.age = jnp.array(inputdata.stars.age, dtype=jnp.float32)
    inputdata.stars.velocity = jnp.array(inputdata.stars.velocity, dtype=jnp.float32)
    inputdata.stars.metallicity = jnp.array(inputdata.stars.metallicity, dtype=jnp.float32)
    inputdata.stars.mass = jnp.array(inputdata.stars.mass, dtype=jnp.float32)
    # stars properties
    data = data.at[:, 0:3].set(inputdata.stars.coords)
    data = data.at[:, 3:6].set(inputdata.stars.velocity)
    data = data.at[:, 6].set(inputdata.stars.metallicity)
    data = data.at[:, 7].set(inputdata.stars.age)
    data = data.at[:, 8].set(inputdata.stars.mass)

    # galaxy properties
    data = data.at[:, 9].set(inputdata.galaxy.halfmassrad_stars)
    data = data.at[:, 10].set(inputdata.galaxy.redshift)
    data = data.at[:, 11:14].set(inputdata.galaxy.center)
    
    mesh = jax.make_mesh((jax.device_count(), ), ('x',))
    shard = NamedSharding(mesh, P('x'))

    data = jax.device_put(data, shard)

    return data

In [None]:
def stars(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the per-star block of the packed array, i.e. columns 0 through 8.
    """
    star_columns = slice(0, 9)
    return data[:, star_columns]

def galaxy(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the galaxy-property block of the packed array, i.e. columns 9 through 13.
    """
    galaxy_columns = slice(9, 14)
    return data[:, galaxy_columns]

In [None]:


def coords_idx():
    """Index expression for the coordinate columns (all rows, columns 0-2)."""
    return (slice(None), slice(0, 3))

def coords(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the coordinate columns of the packed array.
    """
    index = coords_idx()
    return data[index]

def velocity_idx():
    """Index expression for the velocity columns (all rows, columns 3-5)."""
    return (slice(None), slice(3, 6))

def velocity(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the velocity columns of the packed array.
    """
    index = velocity_idx()
    return data[index]

def metallicity_idx():
    """Index expression for the metallicity column (all rows, column 6)."""
    return (slice(None), 6)

def metallicity(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the metallicity column of the packed array as a 1D array.
    """
    index = metallicity_idx()
    return data[index]

def age_idx():
    """Index expression for the age column (all rows, column 7)."""
    return (slice(None), 7)

def age(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the age column of the packed array as a 1D array.
    """
    index = age_idx()
    return data[index]

def mass_idx():
    """Index expression for the mass column (all rows, column 8)."""
    return (slice(None), 8)

def mass(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the mass column of the packed array as a 1D array.
    """
    index = mass_idx()
    return data[index]

def halfmassrad_stars_idx():
    """Index expression for the stellar half-mass-radius column (all rows, column 9)."""
    return (slice(None), 9)

def halfmassrad_stars(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the stellar half-mass-radius column of the packed array as a 1D array.
    """
    index = halfmassrad_stars_idx()
    return data[index]


def redshift_idx():
    """Index expression for the redshift column (all rows, column 10)."""
    return (slice(None), 10)

def redshift(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the redshift column of the packed array as a 1D array.
    """
    index = redshift_idx()
    return data[index]

def center_idx():
    """Index expression for the galaxy-center columns (all rows, columns 11-13)."""
    return (slice(None), slice(11, 14))

def center(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the galaxy-center columns of the packed array.
    """
    index = center_idx()
    return data[index]

def mask_idx():
    """Index expression for the particle-mask column (all rows, column 14)."""
    return (slice(None), 14)

def mask(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the particle-mask column of the packed array as a 1D array.
    """
    index = mask_idx()
    return data[index]

def pixel_assignment_idx():
    """Index expression for the pixel-assignment column (all rows, column 15)."""
    return (slice(None), 15)

def pixel_assignment(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the pixel-assignment column of the packed array as a 1D array.
    """
    index = pixel_assignment_idx()
    return data[index]


def spectra_index():
    """Index expression for the spectra block (all rows, 5994 columns starting at column 16)."""
    first_column = 16
    # presumably the number of wavelength bins of the SSP template — TODO confirm
    n_spectral_columns = 5994
    return (slice(None), slice(first_column, first_column + n_spectral_columns))

def spectra(data: jnp.ndarray) -> jnp.ndarray:
    """
    Return the spectra block of the packed array.
    """
    index = spectra_index()
    return data[index]


In [None]:
# Pack the prepared input into one float32 array that is sharded across the devices
data = build_data(inputdata)

In [None]:
# Total size of the packed array in bytes
data.nbytes

In [None]:
# Show which device holds which part of the packed array
jax.debug.visualize_array_sharding(data)


Now try the sharding with the pipeline functions. Since the existing pipeline functions operate on other data structures, I don't use them directly, but build simplified versions here that only include stars. This means building the pipeline from the ground up in such a way that the data is sharded once and we don't have to touch it again afterwards.

TODO: make sure the functions have the correct static argnums, so that we don't have to worry about tracing problems.

In [None]:
from functools import partial
from pipe import Pipe

galaxy rotation

In [None]:
from rubix.galaxy.alignment import moment_of_inertia_tensor, rotation_matrix_from_inertia_tensor, apply_init_rotation, apply_rotation

def rotate_galaxy_impl(data: jnp.array, alpha, beta, gamma) -> jnp.array:
    """
    Rotate the coordinates and velocities stored in the packed array.

    An initial rotation matrix is derived from the moment-of-inertia tensor of
    the particles; coordinates and velocities are first rotated with that
    matrix and then with the angles ``alpha``, ``beta`` and ``gamma``.

    Args:
        data: packed particle array.
        alpha, beta, gamma: rotation angles forwarded to ``apply_rotation``.

    Returns:
        The packed array with updated coordinate and velocity columns.
    """
    inertia_tensor = moment_of_inertia_tensor(
        coords(data), mass(data), halfmassrad_stars(data)
    )
    init_rotation = rotation_matrix_from_inertia_tensor(inertia_tensor)

    def _rotate(vectors):
        # inertia-tensor alignment first, then the user-requested angles
        aligned = apply_init_rotation(vectors, init_rotation)
        return apply_rotation(aligned, alpha, beta, gamma)

    rotated_coords = _rotate(coords(data))
    rotated_velocity = _rotate(velocity(data))

    data = data.at[coords_idx()].set(rotated_coords)
    return data.at[velocity_idx()].set(rotated_velocity)

In [None]:
# Quick check with arbitrary angles: the result should still be one array of the same shape and dtype
r = rotate_galaxy_impl(data, 0.1, 0.2, 0.3)
type(r), r.shape, r.dtype, r.nbytes

In [None]:
# Bind the rotation angles so that the pipeline step only takes `data`.
# NOTE(review): the angles are hard-coded here rather than derived from
# config["galaxy"]["rotation"]; their unit depends on `apply_rotation` — confirm.
rotate_galaxy = partial(rotate_galaxy_impl, alpha=90.0, beta=0.0, gamma=0.0)

filter particles

In [None]:
from rubix.core.telescope import get_spatial_bin_edges
from rubix.telescope.utils import mask_particles_outside_aperture


def filter_particles_impl(data: jnp.ndarray, spatial_bin_edges) -> jnp.ndarray:
    """
    Store the aperture mask and zero out age, mass and metallicity where it is False.

    Args:
        data: packed particle array.
        spatial_bin_edges: bin edges passed on to ``mask_particles_outside_aperture``.

    Returns:
        The packed array with the mask written to the mask column and the
        age, mass and metallicity columns set to 0 for rows where the mask is False.
    """
    aperture_mask = mask_particles_outside_aperture(coords(data), spatial_bin_edges)
    data = data.at[mask_idx()].set(aperture_mask)

    for column in (age_idx(), mass_idx(), metallicity_idx()):
        kept_values = jnp.where(aperture_mask, data[column], 0)
        data = data.at[column].set(kept_values)

    return data

# Bind the spatial bin edges derived from the config so that the step only takes `data`
filter_particles = partial(filter_particles_impl, spatial_bin_edges=get_spatial_bin_edges(config))

In [None]:
# Shape of the bin-edge array (its length minus one is the number of bins per axis)
get_spatial_bin_edges(config).shape 

try it out

In [None]:
# Apply the filter step to the packed array (note: rebinds `data` to the filtered result)
data = filter_particles(data)

try out simple pipeline 

In [None]:
# Chain the steps with the `pipe` package: each `| Pipe(f)` applies f to the value on its left.
# Starts again from `inputdata`, so the previous contents of `data` are replaced.
data = inputdata | Pipe(build_data) | Pipe(rotate_galaxy) | Pipe(filter_particles)

In [None]:
# Shape, size in MiB and dtype of the result
data.shape, data.nbytes / 1024**2, data.dtype

In [None]:
# Check that the result is still sharded across the devices after the chained steps
jax.debug.visualize_array_sharding(data)

Try to compile the pipeline and run it, then check the sharding.

In [None]:
from rubix.core.pipeline import RubixPipeline 
from rubix.core.data import RubixData

In [None]:
@jax.jit
def pipeline(data: jnp.array) -> jnp.ndarray:
    """
    Jit-compiled composition of the first two steps: rotate the galaxy, then filter particles.
    """
    return filter_particles(rotate_galaxy(data))

In [None]:
# Rebuild the packed array from scratch and run the jit-compiled steps on it
data = build_data(inputdata)
data = pipeline(data)

In [None]:
# Shape, size in MiB and dtype of the jitted pipeline output
data.shape, data.nbytes / 1024**2, data.dtype

In [None]:
# Check how the output of the jitted pipeline is laid out across the devices
jax.debug.visualize_array_sharding(data)

spaxel assignment

In [None]:
def spaxel_assignment_square_impl(data: jnp.ndarray, spatial_bin_edges)-> jnp.ndarray:
    """
    Assign every particle to a pixel of a square grid and store the flat pixel index.

    The x and y coordinates are binned separately with the same bin edges,
    clamped to the valid bin range, and combined as ``x + n_bins * y``.

    Args:
        data: packed particle array.
        spatial_bin_edges: 1D sequence of bin edges used for both axes.

    Returns:
        The packed array with the flat pixel index written to the
        pixel-assignment column.
    """
    n_bins = len(spatial_bin_edges) - 1
    last_bin = n_bins - 1

    def _bin_index(values):
        # subtract 1 so that indexing starts at 0, then clamp to the valid range
        zero_based = jnp.digitize(values, spatial_bin_edges) - 1
        return jnp.clip(zero_based, 0, last_bin)

    xy = data[coords_idx()]
    x_bin = _bin_index(xy[:, 0])
    y_bin = _bin_index(xy[:, 1])

    # Flatten the 2D bin indices into a single pixel index
    flat_pixel = x_bin + (n_bins * y_bin)
    return data.at[pixel_assignment_idx()].set(jnp.round(flat_pixel))


# Bind the spatial bin edges derived from the config so that the step only takes `data`
spaxel_assignment = partial(spaxel_assignment_square_impl, spatial_bin_edges=get_spatial_bin_edges(config))


try it out again

In [None]:
# Same chain as before, extended by the spaxel assignment step; starts again from `inputdata`
data = inputdata | Pipe(build_data) | Pipe(rotate_galaxy) | Pipe(filter_particles) | Pipe(spaxel_assignment)

In [None]:
# Check the device layout after the spaxel assignment
jax.debug.visualize_array_sharding(data)

Calculate the spectra now. Since this block is so big, it would perhaps make sense to have a separate path for it instead of having to store it in the packed array and carry it through every step.

In [None]:
from rubix.core.ssp import get_ssp, get_lookup_interpolation

In [None]:
def calculate_spectra_impl(data: jnp.ndarray, lookup_interpolation) -> jnp.ndarray:
    """
    Look up one spectrum per particle and write it into the spectra block.

    Args:
        data: packed particle array.
        lookup_interpolation: callable taking ``(metallicity, age)`` and
            returning the spectra to store.

    Returns:
        The packed array with the spectra block filled in.
    """
    particle_metallicity = data[metallicity_idx()]
    particle_age = data[age_idx()]

    # this block is by far the largest part of the array and may not fit in
    # memory for serious data sizes
    particle_spectra = lookup_interpolation(particle_metallicity, particle_age)
    return data.at[spectra_index()].set(particle_spectra)


In [None]:
# Build the (metallicity, age) -> spectrum lookup for the configured SSP template
lookup_interpolation = get_lookup_interpolation(config)

# Bind the lookup so that the step only takes `data`
calculate_spectra = partial(calculate_spectra_impl, lookup_interpolation=lookup_interpolation)

In [None]:
# Full chain so far, now including the spectra lookup; starts again from `inputdata`
data = inputdata | Pipe(build_data) | Pipe(rotate_galaxy) | Pipe(filter_particles) | Pipe(spaxel_assignment) | Pipe(calculate_spectra)

In [None]:
# Type, shape, dtype and size in MiB after the spectra have been filled in
type(data), data.shape, data.dtype, data.nbytes / 1024**2

In [None]:
# Check the device layout after the spectra lookup
jax.debug.visualize_array_sharding(data)

scale spectrum by mass

In [None]:
def scale_spectrum_by_mass(data: jnp.ndarray) -> jnp.ndarray:
    """
    Multiply every particle's spectrum by that particle's mass.

    Args:
        data: packed particle array with the spectra block already filled in.

    Returns:
        The packed array with the spectra block scaled row-wise by the mass column.
    """
    spectra_block = spectra_index()
    # add a trailing axis so the per-particle mass broadcasts over the spectral axis
    mass_weights = data[mass_idx()][:, None]
    return data.at[spectra_block].set(data[spectra_block] * mass_weights)

In [None]:
# Scale the spectra in place of the previous result (note: rebinds `data`, so re-running scales again)
data = scale_spectrum_by_mass(data)

In [None]:
# Type, shape, dtype and size in MiB after scaling the spectra
type(data), data.shape, data.dtype, data.nbytes / 1024**2

So far, we barely need 710 MB for everything we do, and we are not efficient at all wrt memory. On multiple GPUs with overall 100GB, we should easily be able to process the required data sizes? 

In [None]:
# Check the device layout after scaling the spectra
jax.debug.visualize_array_sharding(data)

doppler shift

In [None]:
# Read from the RUBIX package configuration which velocity component is used for the Doppler shift
from rubix import config as rubix_config
velocity_direction = rubix_config["ifu"]["doppler"]["velocity_direction"]
# Map the axis name to the column offset inside the velocity block
directions = {"x": 0, "y": 1, "z": 2}

In [None]:
# Display the configured velocity direction
velocity_direction

In [None]:
def apply_doppler_impl(data: jnp.ndarray, wavelength, c, direction) -> jnp.ndarray:
    """
    Compute the Doppler-shifted wavelength grid of every particle.

    For each particle the velocity component ``direction`` is read from the
    packed array and the wavelength grid is scaled by ``exp(v / c)``.

    The earlier version subscripted ``data.at[...]`` a second time, which
    raises a ``TypeError`` because the ``.at`` helper only supports a single
    index followed by ``.set``. It also multiplied a ``(n_wavelength,)`` array
    with a ``(n_particles,)`` array without broadcasting, and aimed the result
    at a single velocity column. The packed array has no block reserved for a
    per-particle wavelength grid, so the shifted wavelengths are returned as
    their own array instead of being written into ``data``.

    Args:
        data: packed particle array; only the velocity columns are read.
        wavelength: 1D rest-frame wavelength grid.
        c: speed of light, in the same unit as the stored velocities
            (NOTE(review): confirm the velocity unit before choosing ``c``).
        direction: which velocity component to use (0, 1 or 2).

    Returns:
        Array of shape ``(n_particles, n_wavelength)`` holding the shifted
        wavelength grid of every particle. This is as large as the spectra
        block of the packed array.
    """
    wavelength = jnp.asarray(wavelength)
    line_of_sight_velocity = data[velocity_idx()][:, direction]
    doppler_factor = jnp.exp(line_of_sight_velocity / c)
    # (1, n_wavelength) * (n_particles, 1) -> (n_particles, n_wavelength)
    return wavelength[jnp.newaxis, :] * doppler_factor[:, jnp.newaxis]

# SSP object for the configured template; its wavelength grid is used for the Doppler step
ssp = get_ssp(config)
ssp_wave= ssp.wavelength
# Column offset of the configured velocity component
direction = directions[velocity_direction]
# Wavelength grid scaled by (1 + z) for the configured observation redshift.
# NOTE(review): computed here but not used by the cells shown so far.
cosmological_doppler_shift = (1 + config["galaxy"]["dist_z"]) * ssp.wavelength

# Bind wavelength grid, speed of light and direction so that the step only takes `data`.
# NOTE(review): c=3e8 looks like m/s; confirm that it matches the unit of the stored velocities.
apply_doppler = partial(apply_doppler_impl, wavelength=ssp_wave, c=3e8, direction=direction)

In [None]:
# Try the Doppler step on the current packed array (the result is only displayed, not stored)
apply_doppler(data)

resampling