In [1]:
import os

import jax.numpy as np
from tqdm.notebook import tqdm

from optimisation import optimise_fisher



dLux: Jax is running in 32-bit, to enable 64-bit visit: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision


In [2]:
# Seed values iterated over by the optimisation loops
seeds = np.arange(5)

# Marginal parameters: complete set
full_params = [
    # -- Source parameters --
    "separation",      # arcseconds
    "x_position",      # arcseconds
    "y_position",      # arcseconds
    "position_angle",  # degrees
    "log_flux",        # log10(photons(/s?))
    "contrast",        # ratio
    "wavelengths",     # m

    # -- Instrument parameters --
    "psf_pixel_scale",           # arcseconds
    "aberrations.coefficients",  # Zernikes
]

In [3]:
# Common filename prefix for the arrays written by this cell
path = "data/gradient_energy/full_params"

# Make sure the output directory exists before any work is done: np.save
# does not create missing directories, so a failure on the first save
# would otherwise discard the result that had just been computed.
os.makedirs(os.path.dirname(path), exist_ok=True)

# Results accumulated across seeds, in iteration order
all_losses, all_coefficients = [], []
for seed in tqdm(seeds):
    losses, coefficients = optimise_fisher(seed, full_params)

    all_losses.append(losses)
    all_coefficients.append(coefficients)

    # Iteratively re-save everything gathered so far, so progress can be
    # examined live while the remaining seeds are still running
    np.save(f"{path}_losses.npy", np.array(all_losses))
    np.save(f"{path}_coefficients.npy", np.array(all_coefficients))

  0%|          | 0/5 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Marginal parameters: reduced set (no "wavelengths" or "psf_pixel_scale")
reduced_params = [
    # -- Source parameters --
    "separation",      # arcseconds
    "x_position",      # arcseconds
    "y_position",      # arcseconds
    "position_angle",  # degrees
    "log_flux",        # log10(photons(/s?))
    "contrast",        # ratio

    # -- Instrument parameters --
    "aberrations.coefficients",  # Zernikes
]

In [None]:
# Common filename prefix for the arrays written by this cell
path = "data/gradient_energy/reduced_params"

# Make sure the output directory exists before any work is done: np.save
# does not create missing directories, so a failure on the first save
# would otherwise discard the result that had just been computed.
os.makedirs(os.path.dirname(path), exist_ok=True)

# Results accumulated across seeds, in iteration order
all_losses, all_coefficients = [], []
for seed in tqdm(seeds):
    losses, coefficients = optimise_fisher(seed, reduced_params)

    all_losses.append(losses)
    all_coefficients.append(coefficients)

    # Iteratively re-save everything gathered so far, so progress can be
    # examined live while the remaining seeds are still running
    np.save(f"{path}_losses.npy", np.array(all_losses))
    np.save(f"{path}_coefficients.npy", np.array(all_coefficients))