In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import healpy as hp
from nsb3.core import SceneComponents, ComponentType, ParameterSpec, Scene
from nsb3.core import InstrumentQuery, AtmosphereQuery, DiffuseQuery, CatalogQuery
from nsb3.core import render

from nsb3.atmosphere.scattering import rayleigh_phase, henyey_greenstein_phase, gradation_function
import astropy.units as u

# Creating a scene:

# Adding atmosphere:

In [2]:
param_scene = Scene()

# Setup base data
# Model wavelength grid in nm (jnp.linspace default: 50 samples).
wvl = jnp.linspace(250,650)
N_wvl = len(wvl)

# Sky directions: centres of all HEALPix pixels at nside = 2**4 (3072 pixels).
# With lonlat=True healpy returns lon/lat in DEGREES. `lat` is treated as the
# altitude above the horizon (the airglow cell uses zenith angle = 90° - lat).
lon, lat = hp.pix2ang(2**4, np.arange(hp.nside2npix(2**4)),
                      nest=False, lonlat=True)

# Keep only pixels strictly above the horizon. `lat` is in degrees, so the old
# radian threshold (`lat <= np.pi/2`, i.e. lat <= 1.57°) selected mostly
# below-horizon pixels. The horizon ring itself (lat == 0) is excluded because
# sec(z) diverges there.
# NOTE(review): the instrument placeholder weights (lon=0°, lat=0°) sit on the
# horizon ring, which this mask excludes — confirm how render() relates
# hp_pixels to the masked map.
latmask = lat > 0

# Placeholder scattering angles (1 rad everywhere) on a 2x2 image grid.
theta = np.ones((2, 2, np.sum(latmask)))
# NOTE(review): sec(Z) >= 1 physically; 0.1 looks like a placeholder — confirm.
sec_Z = np.ones((2, 2)) * 0.1
# Airmass sec(z) with zenith angle z = 90° - altitude, consistent with the
# Van Rhijn weighting in the airglow cell.
sec_z = 1/np.cos(np.deg2rad(90 - lat[latmask]))

# Add atmosphere generator
def atmosphere_generator(params, site_altitude_km=1.8):
    """Build an AtmosphereQuery for Rayleigh + Angstrom/Henyey-Greenstein aerosols.

    Parameters
    ----------
    params : dict
        Must provide 'r_scale_height' and 'm_scale_height' (km), 'AOD_440',
        'AE_440_870' and 'g' (see the ParameterSpecs registered below).
    site_altitude_km : float, optional
        Observer altitude in km. Both optical depths are multiplied by
        exp(-site_altitude_km / scale_height). Defaults to 1.8, the value that
        was previously hard-coded (twice) in this function.

    Reads the notebook globals ``wvl`` (nm), ``theta``, ``sec_Z`` and ``sec_z``.
    """
    # Rayleigh optical depth: 0.00879 * lambda[um]**-4.09 (wvl is in nm).
    t_r = 0.00879*(wvl/1e3)**-4.09 * jnp.exp(-site_altitude_km/params['r_scale_height'])
    # Aerosol optical depth from the 440 nm AOD and the 440-870 nm Angstrom exponent.
    t_m = params['AOD_440']*(wvl/440)**(-params['AE_440_870']) * jnp.exp(-site_altitude_km/params['m_scale_height'])
    tau = t_r + t_m
    # Optical-depth-weighted mix of the two phase functions; wavelength is the
    # last axis (assumes the phase functions preserve theta's shape — confirm).
    indicatrix = (t_r[None,:] * rayleigh_phase(theta)[...,None] +
                  t_m[None,:] * henyey_greenstein_phase(theta, params['g'])[...,None]) / tau
    gradation = gradation_function(tau[None,None,None,:], sec_Z[:,:,None,None], sec_z[None,None,:,None])
    return AtmosphereQuery(
        YiXi=[jnp.linspace(-1, 1, 2), jnp.linspace(-1, 1, 2)],
        tau=tau,
        extinction=jnp.exp(-tau[None,:] * sec_z[:,None]),
        scattering=indicatrix*gradation
    )

param_scene.add_generator(
    name="AngstromHG",
    component_type=ComponentType.ATMOSPHERE,
    generator_fn=atmosphere_generator,
    param_specs={
        'r_scale_height': ParameterSpec((1,), 8, description="Rayleigh scale height in km"),
        'AOD_440': ParameterSpec((1,), 0.1, description="Mie optical depth at 440 nm"),
        'AE_440_870': ParameterSpec((1,), 1, description="Angstrom exponent 440_870 nm"),
        'm_scale_height': ParameterSpec((1,), 1.8, description="Aerosol scale height in km"),
        'g': ParameterSpec((1,), 0.7, description="HG asymmetry parameter", bounds=(0, 1))
    }
)

# Adding instrument:

In [3]:
from nsb2.instrument import HESS

# H.E.S.S. CT1 camera model (from the nsb2 package).
hess1 = HESS.CT1()

# Per-pixel geometry and response stacked into dense arrays, so the generator
# below can offset / rescale them with traced JAX parameters.
centers = hess1.pix_pos
grid = np.stack(
    [np.vstack(pix.centers) for pix in hess1.pixels]
)
values = np.stack(
    [pix.values for pix in hess1.pixels]
)
# Spectral response sampled on the shared wavelength grid (nm).
bandpass = hess1.bandpass(wvl * u.nm)

# Placeholder HEALPix (nside = 2**4) interpolation weights: every one of the
# 960 camera pixels is mapped to the same direction, lon = 0°, lat = 0°.
vals = hp.pixelfunc.get_interp_weights(
    2**4, np.zeros(960), np.zeros(960),
    nest=False, lonlat=True,
)

# Add instrument generator
def instrument_generator(params):
    """Return the HESS CT1 InstrumentQuery for one parameter set.

    ``params['shift']`` (shape (2,)) is added to the pixel centres and to the
    sampling grid; ``params['flatfield']`` (shape (960,)) rescales each camera
    pixel's response values. Geometry, bandpass and HEALPix weights are the
    globals prepared in the instrument cell.
    """
    offset = params['shift']
    gains = params['flatfield']
    return InstrumentQuery(
        centers=centers + offset,
        hp_pixels=vals[0],
        hp_weight=vals[1],
        grid=grid + offset[None, :, None],
        values=values * gains[:, None, None],
        bandpass=bandpass,
    )

param_scene.add_generator(
    name="HESS 1",
    component_type=ComponentType.INSTRUMENT,
    generator_fn=instrument_generator,
    param_specs={
        # One offset shared by all pixels.
        'shift': ParameterSpec((2,), jnp.zeros(2), description="Pixel shift in rad"),
        # One multiplicative gain per camera pixel.
        'flatfield': ParameterSpec((960,), jnp.ones(960), description="Flatfielding values"),
    },
)

# Adding airglow:

In [24]:
from nsb2.emitter import airglow

# Airglow model from ESO SkyCalc; 87 km is presumably the emission-layer height.
# NOTE(review): the meaning of the second argument (100) is not visible here.
glow = airglow.from_eso_skycalc(87 * u.km, 100)

In [25]:
def vanrhjin(h, z, earth_radius=6378.0):
    """Van Rhijn factor: path-length enhancement of a thin emitting layer.

    Computes ``1 / sqrt(1 - (R / (R + h))**2 * sin(z)**2)``.

    Parameters
    ----------
    h : float or array_like
        Height of the emitting layer, in the same unit as ``earth_radius``
        (km with the default).
    z : float or array_like
        Zenith angle in radians.
    earth_radius : float, optional
        Earth radius R in km. Defaults to 6378 km (equatorial radius); the
        previously hard-coded 6738 was a digit transposition of that value.

    Returns
    -------
    float or ndarray
        The Van Rhijn factor: 1 at the zenith, increasing towards the horizon.
    """
    r_rh = earth_radius / (earth_radius + h)
    return 1 / np.sqrt(1 - r_rh**2 * np.sin(z) ** 2)

# Van Rhijn weight of a 90 km emitting layer for every unmasked pixel,
# with zenith angle = 90° - lat.
# NOTE(review): the airglow model above was loaded with 87 km — confirm which
# layer height is intended.
base_weight = vanrhjin(90, np.deg2rad(90 - lat[latmask]))

# Continuum: airglow flux resampled onto the model grid `wvl` (nm); `.value`
# strips the astropy unit.
base_spec = np.interp(
    wvl * u.nm,
    glow.spectral.wvl,
    glow.spectral.flx.flatten(),
).value

# Single-bin line template: 1000 in the wavelength bin nearest 440 nm, 0 elsewhere.
line_440 = np.zeros(len(wvl))
line_440[np.abs(wvl - 440).argmin()] = 1000

In [26]:
def diffuse_generator(params):
    """Airglow DiffuseQuery: Van Rhijn-weighted (continuum + 440 nm line) spectrum.

    The flux map is the outer product of the per-pixel weight ``base_weight``
    and the spectrum ``params['continuum'] * base_spec + params['440_nm'] * line_440``,
    i.e. one row per unmasked pixel and one column per wavelength.
    """
    spectrum = params['continuum'] * base_spec + params['440_nm'] * line_440
    return DiffuseQuery(flux_map=base_weight[:, None] * spectrum[None, :])

param_scene.add_generator(
    name="Airglow",
    component_type=ComponentType.DIFFUSE,
    generator_fn=diffuse_generator,
    param_specs={
        'continuum': ParameterSpec((1,), 1.0, description="Normalization of continuum"),
        '440_nm': ParameterSpec((1,), 1.0, description="Normalization of 440 nm line"),
    },
)

# Adding catalog:

In [27]:
N_sources=20000

# Placeholder catalog inputs, built ONCE outside the generator. The generator is
# traced by `jax.jit` (through `fit`), so an `np.random` call inside it would run
# only at trace time, freezing an unseeded, irreproducible sample into the
# compiled function (or drawing a new sample on every un-jitted call).
_catalog_rng = np.random.default_rng(42)
catalog_image_coords = _catalog_rng.uniform(-1, 1, size=(N_sources, 2))
# Unit flux for every unmasked HEALPix pixel; identical to
# np.ones((3072, N_wvl))[latmask] without hard-coding the pixel count.
catalog_base_flux_map = np.ones((np.count_nonzero(latmask), N_wvl))

def catalog_generator(params):
    """Build a placeholder CatalogQuery of ``N_sources`` identical sources.

    ``params['source_fluxes']`` scales every source spectrum and
    ``params['map_scaling']`` scales the flux map. Source positions are fixed,
    seeded image coordinates drawn uniformly from [-1, 1).
    """
    return CatalogQuery(
        sec_Z=jnp.ones(N_sources),
        image_coords=catalog_image_coords,
        flux_values=params['source_fluxes'] * np.ones((N_sources, N_wvl)),
        flux_map=catalog_base_flux_map * params['map_scaling']
    )

param_scene.add_generator(
    name="GaiaDR3",
    component_type=ComponentType.CATALOG,
    generator_fn=catalog_generator,
    param_specs={
        'source_fluxes': ParameterSpec((1,), 1.0, description="Individual source fluxes"),
        'map_scaling': ParameterSpec((1,), 1.0, description="Flux map scaling", bounds=(0, 10))
    }
)

In [32]:
# Get initial parameters
initial_params = param_scene.get_initial_parameters()

# Print parameter table
print("\nInitial Parameters:")
param_scene.print_parameters(format='table')


def _realize_and_render(parameters):
    """Realize the parametrised scene for `parameters` and render it."""
    scene_instance = param_scene.realize(parameters)
    return render(scene_instance)


# Checkpoint goes *inside* jit. Wrapping an already-jitted function in
# jax.checkpoint (the previous order) leaves the checkpoint wrapper itself
# un-jitted.
# NOTE(review): rematerialising the whole function likely saves no peak memory
# under grad; consider checkpointing inner stages of render() instead.
fit = jax.jit(jax.checkpoint(_realize_and_render))

# Gradient of the summed render w.r.t. all parameters. It is jitted so it is
# traced once instead of on every call.
grad_fit = jax.jit(jax.grad(lambda p: jnp.sum(fit(p))))

print("\n\nTesting JIT and gradient compilation...")
# Block on the results so the success messages are printed only after the
# asynchronously dispatched computations have actually finished.
result = jax.block_until_ready(fit(initial_params))
print("✓ JIT compilation successful")
gradients = jax.block_until_ready(grad_fit(initial_params))
print("✓ Gradient computation successful")


Initial Parameters:
┌────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
│ Component  │ Type       │ Parameter      │ Shape  │ Initial │ Current │ Bounds  │ Description                  │ 
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ AngstromHG │ atmosphere │ r_scale_height │ (1,)   │ 8.0000  │ -       │ -       │ Rayleigh scale height in km  │ 
│ AngstromHG │ atmosphere │ AOD_440        │ (1,)   │ 0.1000  │ -       │ -       │ Mie optical depth at 440 nm  │ 
│ AngstromHG │ atmosphere │ AE_440_870     │ (1,)   │ 1.0000  │ -       │ -       │ Angstrom exponent 440_870 nm │ 
│ AngstromHG │ atmosphere │ m_scale_height │ (1,)   │ 1.8000  │ -       │ -       │ Aerosol scale height in km   │ 
│ AngstromHG │ atmosphere │ g              │ (1,)   │ 0.7000  │ -       │ (0, 1)  │ HG asymmetry parameter       │ 
│ HESS 1     │ instrument │ shift          │ (2,)   │

In [30]:
%%time
result = fit(initial_params).block_until_ready()

CPU times: user 732 ms, sys: 50.1 ms, total: 782 ms
Wall time: 91.7 ms
