In [None]:
# NBVAL_SKIP
import os

# Configure the environment before JAX initialises its backend (jax is first
# imported in a later cell) and before the FSPS-based SSP templates are used.
#
# Respect an SPS_HOME the user has already exported; otherwise fall back to
# this machine's FSPS checkout (adjust the path for your own setup).
os.environ.setdefault('SPS_HOME', '/export/home/aschaibl/fsps')

# GPUs to expose, as a plain comma-separated index list (no whitespace).
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4'


In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt
from rubix.core.pipeline import RubixPipeline 
import os
# Pipeline configuration for the "calc_ifu" run; the SSP lookup cells below
# only consume the "ssp" section, the rest mirrors the full pipeline setup.
config = {
    "pipeline": {"name": "calc_ifu"},
    # Logging
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    # Input galaxy: TNG50-1 subhalo 14 at snapshot 99, fetched via the API
    # (key taken from the environment, never hard-coded).
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {
            "id": 14,
            "reuse": True,
        },
        "subset": {
            "use_subset": True,
            "subset_size": 400000,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": "data/galaxy-id-14.hdf5",
        },
    },
    "output_path": "output",
    # Instrument model
    "telescope": {
        "name": "MUSE",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    # Stellar population templates and dust attenuation
    "ssp": {
        "template": {"name": "FSPS"},
        "dust": {
            "extinction_model": "Cardelli89",
            "dust_to_gas_ratio": 0.01,
            "dust_to_metals_ratio": 0.4,
            "dust_grain_density": 3.5,
            "Rv": 3.1,
        },
    },
}

In [None]:
# NBVAL_SKIP
import jax
import jax.numpy as jnp
from jax.experimental import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Synthetic SSP query grid: one (age, metallicity) pair per particle.
# Kept equal to config["data"]["subset"]["subset_size"].
n_particles = 400_000

age = jnp.linspace(0, 20, n_particles)
metallicity = jnp.linspace(0.0, 0.05, n_particles)

# Shard the particle axis across every visible device.
devices = jax.devices()
# NamedSharding over axis 0 needs the axis length to split evenly across the
# mesh; fail early with a readable message instead of inside device_put.
if n_particles % len(devices) != 0:
    raise ValueError(
        f"n_particles={n_particles} must be divisible by the number of "
        f"devices ({len(devices)}) to shard along 'N_particles'"
    )
mesh = Mesh(devices, axis_names=('N_particles',))
sharding = NamedSharding(mesh, P('N_particles'))

age = jax.device_put(age, sharding)
metallicity = jax.device_put(metallicity, sharding)

In [None]:
# NBVAL_SKIP
# Defensive normalisation: make sure both query arrays are at least 1-D before
# they are mapped over element-wise below.
age, metallicity = (jnp.atleast_1d(arr) for arr in (age, metallicity))

In [None]:
# NBVAL_SKIP
from rubix.core.ssp import get_lookup_interpolation

In [None]:
# NBVAL_SKIP
# Build the SSP lookup/interpolation object from the pipeline `config`
# (presumably driven by config["ssp"] -- TODO confirm in rubix.core.ssp).
# It is later called as lookup_interpolation(age, metallicity) per particle.
lookup_interpolation = get_lookup_interpolation(config)

In [None]:
# NBVAL_SKIP
# Sanity check: show what get_lookup_interpolation returned before mapping it
# over every particle.
print("lookup_interpolation", lookup_interpolation)

In [None]:
# NBVAL_SKIP
# Number of (age, metallicity) pairs jax.lax.map evaluates per step. 1 gives a
# fully sequential map (lowest peak memory); larger values vmap that many
# particles at once, trading memory for throughput.
LOOKUP_BATCH_SIZE = 1


def lookup_interpolation_lax(age_metallicity):
    """Evaluate the SSP lookup for one ``(age, metallicity)`` pair.

    Adapter for ``jax.lax.map`` / ``jax.lax.scan``, which hand each mapped
    element to the function as a single pytree (here a 2-tuple) rather than
    as separate positional arguments.

    Args:
        age_metallicity: 2-tuple ``(age, metallicity)`` for one particle.

    Returns:
        Whatever ``lookup_interpolation(age, metallicity)`` returns.
    """
    # Distinct local names so the notebook-level `age` / `metallicity`
    # arrays are not shadowed inside the function.
    particle_age, particle_metallicity = age_metallicity
    return lookup_interpolation(particle_age, particle_metallicity)


interpolation = jax.lax.map(
    lookup_interpolation_lax, (age, metallicity), batch_size=LOOKUP_BATCH_SIZE
)

In [None]:
# NBVAL_SKIP
def _scan_body(carry, age_metallicity):
    """Stateless scan step: pass ``carry`` through unchanged and emit the
    SSP lookup for one ``(age, metallicity)`` pair."""
    return carry, lookup_interpolation_lax(age_metallicity)


# Same per-particle lookup as the lax.map cell, expressed as a scan with no
# carried state; the stacked outputs overwrite `interpolation`.
_, interpolation = jax.lax.scan(_scan_body, None, (age, metallicity))