In [None]:
# NBVAL_SKIP
import os

# FSPS data directory. `setdefault` keeps an SPS_HOME the user already exported
# (so the notebook also runs on other machines) and only falls back to this
# machine-specific path when none is set.
# NOTE(review): presumably needs to be set before FSPS is first imported below — confirm.
os.environ.setdefault('SPS_HOME', '/mnt/storage/annalena_data/sps_fsps')

In [None]:
# NBVAL_SKIP
import logging

# Disable all logging messages
#logging.disable(logging.CRITICAL)

# Load the SSP template from FSPS

In [None]:
# NBVAL_SKIP
from rubix.spectra.ssp.factory import get_ssp_template
# Build the FSPS SSP template; its `age` and `metallicity` grids are read in the
# next cell and used to pick start/target values.
ssp_fsps = get_ssp_template("FSPS")

In [None]:
# NBVAL_SKIP
# Pull the age and metallicity grids out of the FSPS template and report their sizes
# (age grid first, then metallicity grid, one per line).
age_values = ssp_fsps.age
metallicity_values = ssp_fsps.metallicity
print(age_values.shape, metallicity_values.shape, sep="\n")

In [None]:
# NBVAL_SKIP
# Grid indices into `age_values` / `metallicity_values`.
# Target: the grid point whose datacube the optimizer should recover.
index_age, index_metallicity = 90, 7
# Start: the grid point used as the optimizer's initial guess.
initial_metallicity_index, initial_age_index = 5, 70

# Step sizes and convergence tolerance.
# NOTE(review): learning_age / learning_metallicity are not referenced anywhere else
# in the visible notebook — likely leftovers from a gradient-descent variant.
learning_age, learning_metallicity = 0.5, 1e-3
tol = 1e-8

print(f"start age: {age_values[initial_age_index]}, start metallicity: {metallicity_values[initial_metallicity_index]}")
print(f"target age: {age_values[index_age]}, target metallicity: {metallicity_values[index_metallicity]}")

# Configure pipeline

In [None]:
# NBVAL_SKIP
from rubix.core.pipeline import RubixPipeline
import os

# Pipeline configuration consumed by RubixPipeline.
config = {
    "pipeline": {"name": "calc_gradient"},

    # Verbose logging to the console (no log file).
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },

    # Input galaxy: TNG50-1, snapshot 99, galaxy id 14; the API key is read from the
    # environment. `subset_size` 2 presumably keeps two star particles — matches the
    # two-element arrays set later in the notebook.
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {"id": 14, "reuse": True},
        "subset": {"use_subset": True, "subset_size": 2},
    },

    "simulation": {
        "name": "IllustrisTNG",
        "args": {"path": "data/galaxy-id-14.hdf5"},
    },

    "output_path": "output",

    # Instrument model: PSF, LSF and noise settings.
    "telescope": {
        "name": "TESTGRADIENT",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 1.2},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },

    "cosmology": {"name": "PLANCK15"},

    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },

    # Stellar population template and dust extinction parameters.
    "ssp": {
        "template": {"name": "FSPS"},
        "dust": {
            "extinction_model": "Cardelli89",
            "dust_to_gas_ratio": 0.01,
            "dust_to_metals_ratio": 0.4,
            "dust_grain_density": 3.5,
            "Rv": 3.1,
        },
    },
}

In [None]:
# NBVAL_SKIP
# Build the pipeline from `config`, prepare its input data and run it once.
# `rubixdata` (the result) is reused below as the container whose stellar
# properties are overwritten for the target / initial setups.
pipe = RubixPipeline(config)
inputdata = pipe.prepare_data()
rubixdata = pipe.run(inputdata)

# Set target values

In [None]:
# NBVAL_SKIP
import jax.numpy as jnp

# Two identical star particles carrying the *target* (age, metallicity) grid values,
# unit mass (as a 1 x 2 array, like the other setup cells), zero velocity, at the origin.
# NOTE(review): mass is (1, 2) while age/metallicity are (2,) — confirm the pipeline expects this.
rubixdata.stars.age = jnp.array([age_values[index_age]] * 2)
rubixdata.stars.metallicity = jnp.array([metallicity_values[index_metallicity]] * 2)
rubixdata.stars.mass = jnp.ones((1, 2))
rubixdata.stars.velocity = jnp.zeros((2, 3))
rubixdata.stars.coords = jnp.zeros((2, 3))

In [None]:
# NBVAL_SKIP
# Run the pipeline on the two-star target setup; its datacube is what the fit tries to reproduce.
targetdata = pipe.run(rubixdata)

# Set initial datacube

In [None]:
# NBVAL_SKIP
# Same two-star setup as for the target, but with the *starting* (age, metallicity)
# grid values: unit mass (1 x 2 array), zero velocity, at the origin.
rubixdata.stars.age = jnp.array([age_values[initial_age_index]] * 2)
rubixdata.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index]] * 2)
rubixdata.stars.mass = jnp.ones((1, 2))
rubixdata.stars.velocity = jnp.zeros((2, 3))
rubixdata.stars.coords = jnp.zeros((2, 3))

In [None]:
# NBVAL_SKIP
# Datacube at the starting (age, metallicity) grid point.
# NOTE(review): `initialdata` is not referenced later in the visible notebook.
initialdata = pipe.run(rubixdata)

# Levenberg–Marquardt optimizer

In [None]:
# NBVAL_SKIP
from rubix.pipeline import linear_pipeline as pipeline

# Second pipeline instance whose compiled function `pipeline_instance.func` is
# called directly by `residual` (bypassing `RubixPipeline.run`).
# NOTE(review): this relies on private members of RubixPipeline (`_pipeline`,
# `_get_pipeline_functions`) and may break on a rubix upgrade.
pipeline_instance = RubixPipeline(config)

# Assemble the linear transformer pipeline from the configured stage functions
# and compile it into a single callable.
pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(
    pipeline_instance.pipeline_config, 
    pipeline_instance._get_pipeline_functions()
)
pipeline_instance._pipeline.assemble()
pipeline_instance.func = pipeline_instance._pipeline.compile_expression()

In [None]:
# NBVAL_SKIP
import copy
import jax
import jax.numpy as jnp
from jaxopt import LevenbergMarquardt

# Define the sigmoid and its inverse (logit) for mapping.
def sigmoid(u):
    """Logistic function mapping any real `u` (scalar or array) into (0, 1).

    Uses `jax.nn.sigmoid`, which is numerically stable. The naive
    ``1 / (1 + exp(-u))`` overflows ``exp(-u)`` to ``inf`` for very negative
    ``u`` (below about -88 in float32). Its value is still 0 there, but its
    gradient becomes ``0 * inf = NaN``, which poisons the Levenberg-Marquardt
    Jacobian as soon as an unconstrained parameter wanders that far.
    """
    return jax.nn.sigmoid(u)

def inverse_transform(p, p_min, p_max):
    """
    Maps a parameter value p in [p_min, p_max] to an unconstrained value u.

    This inverts the sigmoid-based bound mapping:
    u = logit((p - p_min) / (p_max - p_min)). Values at or outside the bounds are
    clipped just inside (0, 1), so the result is always finite.

    Args:
        p: scalar or array of parameter values.
        p_min, p_max: lower and upper bounds (p_max > p_min).

    Returns:
        JAX array with the same shape as `p`.
    """
    # Normalize p to [0, 1]; convert to a JAX array first so the clip margin below
    # is chosen for the dtype the computation actually runs in (float32 by default).
    p_norm = jnp.asarray((p - p_min) / (p_max - p_min))
    # The margin must be representable next to 1 in that dtype: a fixed 1e-8 makes
    # `1 - 1e-8` round to exactly 1.0 in float32, so p == p_max gave log(1 / 0) = inf.
    epsilon = jnp.finfo(p_norm.dtype).eps
    p_norm = jnp.clip(p_norm, epsilon, 1 - epsilon)
    # Inverse of the sigmoid is the logit function.
    return jnp.log(p_norm / (1 - p_norm))

# Physical bounds of the two fit parameters.
# NOTE(review): confirm these cover the start/target grid values; inverse_transform
# clips anything outside them.
age_min, age_max = (0.0, 10.3)
metal_min, metal_max = (1e-4, 0.05)

def transform_params(u):
    """
    Map unconstrained optimizer parameters back into the physical bounds.

    u: 1D array [u_age, u_metallicity] of unconstrained values.
    Returns a JAX array [age, metallicity] with age in (age_min, age_max) and
    metallicity in (metal_min, metal_max).
    """
    squashed = sigmoid(u)  # each entry now lies in (0, 1)
    age_span = age_max - age_min
    metal_span = metal_max - metal_min
    return jnp.array([
        age_min + squashed[0] * age_span,
        metal_min + squashed[1] * metal_span,
    ])

def residual(u, data, target):
    """
    Residual vector for Levenberg-Marquardt: model datacube minus target datacube.

    u:      unconstrained parameters [u_age, u_metallicity] (see transform_params).
    data:   pipeline input; it is deep-copied, so the caller's object is not modified.
    target: pipeline output whose `stars.datacube` is the reference.
    Returns the flattened difference of the two datacubes.
    """
    age, metallicity = transform_params(u)
    trial = copy.deepcopy(data)
    # NOTE(review): age/metallicity are set with shape (1, 2) here, while the cells that
    # built the target datacube use shape (2,). Confirm the pipeline treats both the same.
    trial.stars.age = jnp.array([[age, age]])
    trial.stars.metallicity = jnp.array([[metallicity, metallicity]])
    model_cube = pipeline_instance.func(trial).stars.datacube
    return jnp.ravel(model_cube - target.stars.datacube)


In [None]:
# NBVAL_SKIP
# Initialize the unconstrained parameters by mapping the starting grid point through
# the inverse of `transform_params`.
u_age_init = inverse_transform(age_values[initial_age_index], age_min, age_max)
u_metal_init = inverse_transform(metallicity_values[initial_metallicity_index], metal_min, metal_max)
init_params = jnp.array([u_age_init, u_metal_init])

# Instantiate the Levenberg-Marquardt optimizer. `tol` is the tolerance set in the
# parameter cell; without passing it, jaxopt's own default tolerance would apply.
lm_optimizer = LevenbergMarquardt(residual, maxiter=1000, tol=tol)

# Fit on `rubixdata`: the same two-star setup (unit mass, at rest, at the origin) that
# produced `targetdata`. `inputdata` holds the prepared galaxy particles (unless the
# pipeline mutated it in place), so using it would compare datacubes of two different
# stellar setups.
fit_data = rubixdata

# Step the solver manually instead of calling `lm_optimizer.run`, so the loss and the
# physical parameters can be recorded after every iteration. The history plots below
# read `loss_history`, `age_history` and `metallicity_history`.
lm_step = jax.jit(
    lambda params, state: lm_optimizer.update(params, state, data=fit_data, target=targetdata)
)

loss_history, age_history, metallicity_history = [], [], []


def _record_history(params, state):
    """Append 0.5 * ||state.residual||^2 and the constrained (age, metallicity) of `params`."""
    age, metallicity = transform_params(params)
    loss_history.append(float(0.5 * jnp.sum(state.residual ** 2)))
    age_history.append(float(age))
    metallicity_history.append(float(metallicity))


params = init_params
state = lm_optimizer.init_state(params, data=fit_data, target=targetdata)
_record_history(params, state)  # entry 0 is the starting guess

# Stop at `maxiter` or once the solver's error drops below `tol`.
# NOTE(review): mirrors jaxopt's `run` stopping rule — confirm for the installed version.
while int(state.iter_num) < lm_optimizer.maxiter and float(state.error) > lm_optimizer.tol:
    params, state = lm_step(params, state)
    _record_history(params, state)

print(f"Stopped after {int(state.iter_num)} iterations (error = {float(state.error):.3e})")

# Retrieve and print the raw optimized unconstrained parameters.
optimized_params = params
print("Raw Optimized u_age:", optimized_params[0])
print("Raw Optimized u_metallicity:", optimized_params[1])

# Transform back to the constrained space.
constrained_params = transform_params(optimized_params)
print("Optimized Age:", constrained_params[0])
print("Optimized Metallicity:", constrained_params[1])


## Optimization history (loss, age, metallicity)

In [None]:
# NBVAL_SKIP
import matplotlib.pyplot as plt
import numpy as np

# Convert the recorded losses (list or JAX array) to NumPy and plot them per iteration.
loss_history_np = np.asarray(loss_history)
indices = np.arange(len(loss_history_np))

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(indices, loss_history_np, marker='o', linestyle='-')
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.set_title("Loss History")
ax.grid(True)
plt.show()

In [None]:
# NBVAL_SKIP
import matplotlib.pyplot as plt
import numpy as np

# Plot the age recorded at each optimizer iteration against the target age.
age_history_np = np.array(age_history)  # works for lists and JAX arrays
indices = np.arange(len(age_history_np))

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(indices, age_history_np, marker='o', linestyle='-', label="Optimizer")
# axhline spans the whole x-axis, so the target line matches the data range
# (0 .. len - 1) instead of overshooting it by one step.
ax.axhline(y=float(age_values[index_age]), color='r', linestyle='-', label="Target")
ax.set_xlabel("Iteration")
ax.set_ylabel("Age")
ax.set_title("Age History")
ax.grid(True)
ax.legend()
plt.show()

In [None]:
# NBVAL_SKIP
# Plot the metallicity recorded at each optimizer iteration against the target metallicity.
metallicity_history_np = np.array(metallicity_history)  # works for lists and JAX arrays
indices = np.arange(len(metallicity_history_np))

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(indices, metallicity_history_np, marker='o', linestyle='-', label="Optimizer")
# axhline spans the whole x-axis, so the target line matches the data range
# (0 .. len - 1) instead of overshooting it by one step.
ax.axhline(y=float(metallicity_values[index_metallicity]), color='r', linestyle='-', label="Target")
ax.set_xlabel("Iteration")
ax.set_ylabel("Metallicity")
ax.set_title("Metallicity History")
ax.grid(True)
ax.legend()
plt.show()

In [None]:
# NBVAL_SKIP
# Re-run the forward model at the FIRST entry of the recorded optimizer history and
# overlay its spectrum on the target spectrum (the [0, 0, :] slice of each datacube).
import copy

import jax.numpy as jnp
import matplotlib.pyplot as plt

i = 0  # first entry of the history

# Work on a copy so `rubixdata` is not overwritten; re-running this cell (or the next
# one) then starts from the same state every time.
comparisondata = copy.deepcopy(rubixdata)
comparisondata.stars.age = jnp.array([age_history[i], age_history[i]])
comparisondata.stars.metallicity = jnp.array([metallicity_history[i], metallicity_history[i]])
comparisondata.stars.mass = jnp.array([[1.0, 1.0]])
comparisondata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
# Reset coords as well, matching the setup that produced `targetdata`.
comparisondata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

# Reuse the existing pipeline (it is already run several times above).
optimizeddata = pipe.run(comparisondata)

wave = pipe.telescope.wave_seq
spectra_target = targetdata.stars.datacube
spectra_optimized = optimizeddata.stars.datacube

fig, ax = plt.subplots()
ax.plot(wave, spectra_target[0, 0, :], label=f"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}")
ax.plot(wave, spectra_optimized[0, 0, :], label=f"Optimized age = {age_history[i]:.2f}, metal. = {metallicity_history[i]:.4f}")
ax.set_xlabel("Wavelength [Å]")
ax.set_ylabel("Luminosity [L/Å]")
# The two spectra are overlaid, not subtracted, so the title says exactly that.
ax.set_title("Target vs. optimized spectra")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# NBVAL_SKIP
# Re-run the forward model at the LAST entry of the recorded optimizer history (the
# final parameters) and overlay its spectrum on the target spectrum (the [0, 0, :]
# slice of each datacube).
import copy

import jax.numpy as jnp
import matplotlib.pyplot as plt

i = -1  # last entry of the history

# Work on a copy so `rubixdata` is not overwritten; re-running this cell then starts
# from the same state every time.
comparisondata = copy.deepcopy(rubixdata)
comparisondata.stars.age = jnp.array([age_history[i], age_history[i]])
comparisondata.stars.metallicity = jnp.array([metallicity_history[i], metallicity_history[i]])
comparisondata.stars.mass = jnp.array([[1.0, 1.0]])
comparisondata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
# Reset coords as well, matching the setup that produced `targetdata`.
comparisondata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

# Reuse the existing pipeline (it is already run several times above).
optimizeddata = pipe.run(comparisondata)

wave = pipe.telescope.wave_seq
spectra_target = targetdata.stars.datacube
spectra_optimized = optimizeddata.stars.datacube

fig, ax = plt.subplots()
ax.plot(wave, spectra_target[0, 0, :], label=f"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}")
ax.plot(wave, spectra_optimized[0, 0, :], label=f"Optimized age = {age_history[i]:.2f}, metal. = {metallicity_history[i]:.4f}")
ax.set_xlabel("Wavelength [Å]")
ax.set_ylabel("Luminosity [L/Å]")
# The two spectra are overlaid, not subtracted, so the title says exactly that.
ax.set_title("Target vs. optimized spectra")
ax.legend()
ax.grid(True)
plt.show()