In [1]:
import sys
import os
import jax
import jax.numpy as np
import jax.tree as jtu
import jax.random as jr
import equinox as eqx
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Basic jax import
jax.config.update("jax_enable_x64", True)

# Check if running on remote, and set directory to where notebook is run
if jax.devices()[0].platform == "gpu":
    os.chdir("code/amigo_project/notebooks/calibration")

# Add parent directories
paths = [os.path.abspath(os.path.join(os.getcwd(), path)) for path in ['..', "../.."]]
for path in paths:
    if path not in sys.path:
        sys.path.insert(0, path)
        
# Plotting set up
%matplotlib inline
plt.rcParams["image.cmap"] = "inferno"
plt.rcParams["font.family"] = "serif"
plt.rcParams["image.origin"] = "lower"
plt.rcParams["figure.dpi"] = 120

inferno = mpl.colormaps["inferno"]
seismic = mpl.colormaps["seismic"]
coolwarm = mpl.colormaps["coolwarm"]

inferno.set_bad("k", 0.5)
seismic.set_bad("k", 0.5)
coolwarm.set_bad("k", 0.5)

def merge_cbar(ax, size="5%", pad=0.0):
    """Create a new axis attached to the right-hand edge of ``ax``, e.g. for a colorbar.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis to attach the new axis to.
    size : float or str, optional
        Width of the new axis, as accepted by ``AxesDivider.append_axes``; a percentage
        string is relative to ``ax``. Default "5%".
    pad : float or str, optional
        Gap between ``ax`` and the new axis, as accepted by ``append_axes``.
        Default 0.0 (flush, no gap).

    Returns
    -------
    matplotlib.axes.Axes
        The new axis, suitable to pass as ``cax`` to ``Figure.colorbar``.
    """
    return make_axes_locatable(ax).append_axes("right", size=size, pad=pad)

ERROR:2025-05-20 17:03:10,551:jax._src.xla_bridge:647: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/home/louis/miniconda3/envs/jax_gpu/lib/python3.13/site-packages/jax/_src/xla_bridge.py", line 645, in discover_pjrt_plugins
    plugin_module.initialize()
    ~~~~~~~~~~~~~~~~~~~~~~~~^^
  File "/home/louis/miniconda3/envs/jax_gpu/lib/python3.13/site-packages/jax_plugins/xla_cuda12/__init__.py", line 105, in initialize
    triton.register_compilation_handler(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'jax._src.lib.triton' has no attribute 'register_compilation_handler'


In [2]:
import zodiax as zdx
import dLux.utils as dlu
from copy import deepcopy


def get_spectra(wavels, filt_weights, spectra):
    """Tilt filter weights by a linear spectral slope and normalise them to unit sum.

    Parameters
    ----------
    wavels : array
        Wavelength samples; returned unchanged.
    filt_weights : array
        Filter throughput at each wavelength sample.
    spectra : float or array
        Slope coefficient. A value of 0 leaves the filter weights' shape unchanged.

    Returns
    -------
    (wavels, weights)
        The input wavelengths, and the tilted weights divided by their sum.
    """
    # Linear ramp across the bandpass: -1 at the first sample, +1 at the last
    ramp = np.linspace(-1, 1, len(wavels), endpoint=True)
    tilted = filt_weights * (1 + spectra * ramp)
    total = tilted.sum()
    return wavels, tilted / total


class SimpleModel(zdx.Base):
    """Minimal single-source model whose array fields are differentiated in this notebook.

    The array fields are pytree leaves, so they appear in ``ravel_pytree(model)``;
    ``downsample`` is a static field and is not part of the flattened vector. The field
    declaration order below must match the key order of ``struct`` (used to slice the
    Jacobian/Hessian per parameter).
    """

    positions: np.ndarray  # source offset in arcsec (converted via dlu.arcsec2rad)
    fluxes: np.ndarray  # log10 of the flux multiplier applied to the PSF
    spectra: np.ndarray  # spectral slope coefficient passed to get_spectra
    aberrations: np.ndarray  # values written to optics "pupil_mask.abb_coeffs"
    amplitudes: np.ndarray  # visibility amplitude terms, one per vis_model basis term
    phases: np.ndarray  # visibility phase terms, one per vis_model basis term
    downsample: int = eqx.field(static=True)  # integer downsampling factor of the output

    def __init__(self, optics, vis_model, downsample=3):
        """Initialise at a zero offset, log10-flux 0 and a flat spectrum.

        Aberrations start from ``optics.pupil_mask.abb_coeffs``; visibility terms start
        at zero with length ``vis_model.n_basis``.
        """
        n_basis = vis_model.n_basis
        self.positions = np.zeros(2)
        self.fluxes = np.zeros(1)
        self.spectra = np.zeros(1)
        self.aberrations = optics.pupil_mask.abb_coeffs
        self.amplitudes, self.phases = np.zeros(n_basis), np.zeros(n_basis)
        self.downsample = int(downsample)

    @eqx.filter_jit
    def model(self, optics, vis_model, filt):
        """Return the flattened, flux-scaled, downsampled PSF for filter ``filt``.

        Parameters
        ----------
        optics : optical model exposing ``filters``, ``set`` and ``propagate``
            (an AMIOptics instance in this notebook).
        vis_model : visibility model exposing ``model_vis``.
        filt : str
            Filter name used to look up ``optics.filters`` and passed to ``model_vis``.

        Returns
        -------
        array
            1-D array of the PSF after ``dlu.downsample`` with ``mean=False``.
        """
        source_pos = dlu.arcsec2rad(self.positions)
        wavels, weights = get_spectra(*optics.filters[filt], self.spectra)
        aberrated = optics.set("pupil_mask.abb_coeffs", self.aberrations)
        wavefronts = aberrated.propagate(wavels, source_pos, weights, return_wf=True)

        # Apply the visibility perturbations, then scale by 10**fluxes and bin pixels
        vis_psf = vis_model.model_vis(wavefronts, self.amplitudes, self.phases, filt).data
        scaled = (10 ** self.fluxes) * vis_psf
        binned = dlu.downsample(scaled, self.downsample, mean=False)
        return binned.flatten()


# Parameter names, in the same order as SimpleModel's field declarations: the next cell
# slices the flattened Jacobian/Hessian by walking these keys in order.
struct = {
    name: {}
    for name in ("positions", "fluxes", "spectra", "aberrations", "amplitudes", "phases")
}

# Per-filter results: the PSF directly, plus an independent {param: {filter: array}}
# tree for each derived quantity
outputs = {
    "psf": {},
    **{quantity: deepcopy(struct) for quantity in ("jacobian", "hessian", "covariance")},
}

In [3]:
from amigo.misc import tqdm
from amigo.optical_models import AMIOptics
from amigo.vis_models import LogVisModel
from jax.flatten_util import ravel_pytree
from amigo.stats import batched_jacobian, gauss_hessian

# Calibrated model state and visibility basis saved by earlier runs. Both are pickled
# objects stored in 0-d arrays, hence allow_pickle=True and .item().
state, vis_basis = (
    np.load(path, allow_pickle=True).item()
    for path in (
        "../../GPU_files/results/cal_model.npy",
        "../../GPU_files/results/vis_basis.npy",
    )
)

# Keep all 1300 basis terms so the Jacobian is computed for every basis term
vis_model = LogVisModel(vis_basis, n_basis=1300)

# Base optics: psf_upsample of 1, plus the calibrated transmission and aberrations
optics = AMIOptics().set(
    ["psf_upsample", "transmission", "pupil_mask.abb_coeffs"],
    [1, state["transmission"], state["aberrations"]],
)

# For each filter: build the model at the calibrated state, differentiate the flattened
# PSF w.r.t. every model parameter, form the Hessian and its inverse, and store the
# per-parameter diagonal blocks.
for filt in tqdm(["F380M", "F430M", "F480M"]):
    # Update the optics to the correct defocus and build the model
    optics = optics.set("defocus", state["defocus"][filt])
    model = SimpleModel(optics, vis_model)

    # Flatten every model parameter into a single vector X
    X, unravel_fn = ravel_pytree(model)

    # Defined per-iteration and only used within it, so late binding of the closure is safe
    def model_fn(X):
        return unravel_fn(X).model(optics, vis_model, filt)

    # Diagonal pixel covariance with variance equal to the model PSF (Poisson-like)
    psf = model_fn(X)
    J_cov = np.diag(psf)
    J = batched_jacobian(X, model_fn, n_batch=200)
    # NOTE(review): gauss_hessian presumably forms J C^-1 J^T (Gauss-Newton/Fisher) — confirm
    H = gauss_hessian(J, J_cov)
    cov = np.linalg.inv(H)

    # Offsets of each parameter inside X. This assumes struct's key order matches the order
    # ravel_pytree flattens SimpleModel's fields in; at least check the sizes cover all of X.
    params = list(struct.keys())
    sizes = [model.get(param).size for param in params]
    lengths = [0] + [int(end) for end in np.cumsum(np.array(sizes))]
    assert lengths[-1] == X.size, "struct keys do not cover every flattened model parameter"

    # Save the per-parameter blocks (rows of J, diagonal blocks of H and cov)
    for i, param in enumerate(params):
        s, e = lengths[i], lengths[i + 1]
        outputs["jacobian"][param][filt] = J[s:e, :]
        outputs["hessian"][param][filt] = H[s:e, s:e]
        outputs["covariance"][param][filt] = cov[s:e, s:e]
    outputs["psf"][filt] = psf

np.save("../../GPU_files/results/jac_outputs.npy", outputs, allow_pickle=True)

  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Shapes of every stored array, to sanity-check the per-parameter / per-filter slicing
jtu.map(lambda arr: arr.shape, outputs)

{'covariance': {'aberrations': {'F380M': (70, 70),
   'F430M': (70, 70),
   'F480M': (70, 70)},
  'amplitudes': {'F380M': (1300, 1300),
   'F430M': (1300, 1300),
   'F480M': (1300, 1300)},
  'fluxes': {'F380M': (1, 1), 'F430M': (1, 1), 'F480M': (1, 1)},
  'phases': {'F380M': (1300, 1300),
   'F430M': (1300, 1300),
   'F480M': (1300, 1300)},
  'positions': {'F380M': (2, 2), 'F430M': (2, 2), 'F480M': (2, 2)},
  'spectra': {'F380M': (1, 1), 'F430M': (1, 1), 'F480M': (1, 1)}},
 'hessian': {'aberrations': {'F380M': (70, 70),
   'F430M': (70, 70),
   'F480M': (70, 70)},
  'amplitudes': {'F380M': (1300, 1300),
   'F430M': (1300, 1300),
   'F480M': (1300, 1300)},
  'fluxes': {'F380M': (1, 1), 'F430M': (1, 1), 'F480M': (1, 1)},
  'phases': {'F380M': (1300, 1300),
   'F430M': (1300, 1300),
   'F480M': (1300, 1300)},
  'positions': {'F380M': (2, 2), 'F430M': (2, 2), 'F480M': (2, 2)},
  'spectra': {'F380M': (1, 1), 'F430M': (1, 1), 'F480M': (1, 1)}},
 'jacobian': {'aberrations': {'F380M': (70, 640

: 