In [None]:
from dask.distributed import LocalCluster

# Dask cluster configuration (tune to the machine running the notebook).
N_WORKERS = 16
# One thread per worker process, so parallelism comes from separate processes.
THREADS_PER_WORKER = 1

cluster = LocalCluster(n_workers=N_WORKERS, threads_per_worker=THREADS_PER_WORKER)
# NOTE(review): the fit cell uses scheduler="default", which presumably resolves
# to this client (dask.distributed registers new clients as default) — confirm.
client = cluster.get_client()

In [None]:
# Display the cluster summary (Jupyter renders a cell's last expression).
cluster

In [None]:
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt
import asdf
from tqdm.dask import TqdmCallback

import astropy.units as u
from astropy.visualization import quantity_support
quantity_support()

from sunraster.instr.spice import read_spice_l2_fits
from sospice.calibrate import spice_error

from astropy.modeling.fitting_parallel import parallel_fit_model_nd
from astropy.modeling import fitting


# Input L2 raster and the spectral window (HDU name) to analyse.
filename = "solo_L2_spice-n-ras_20230415T120519_V02_184549780-000.fits.gz"
window = 'N IV 765 ... Ne VIII 770 (Merged)'

# Load the raster and keep only the selected spectral window.
spice = read_spice_l2_fits(filename)
spice = spice[window]
# Treat non-positive data values as invalid.
spice.mask |= (spice.data <= 0)
# True for each index along axis 1 that has at least one unmasked value.
# NOTE(review): axis 1 appears to be the spectral axis (it becomes axis 0, the
# fitting axis, once the leading length-1 axis is dropped below) — confirm.
# Not used in the cells shown here.
include = ~np.all(spice.mask, axis=(0, 2, 3))

# spice_error works on the raw HDU, so find it by name and use it while the
# file is still open; the context manager guarantees the file is closed.
with fits.open(filename) as hdulist:
    # Exact name match; fail loudly instead of leaving `hdu` unbound (or, on a
    # re-run, silently reusing a stale `hdu` from a previous execution).
    hdu = next((h for h in hdulist if h.name == window), None)
    if hdu is None:
        raise KeyError(f"No HDU named {window!r} in {filename}")
    av_constant_noise_level, sigmadict = spice_error(hdu)
    sigma = sigmadict["Total"].value
# Also mask pixels whose uncertainty is undefined or non-positive.
spice.mask = spice.mask | np.isnan(sigma) | (sigma <= 0)
# drop leading length 1 dimension
spice = spice[0]
spice

In [None]:
# Load the model to fit (used as the starting point of every per-pixel fit).
# NOTE(review): presumably its parameter values are empirical starting guesses —
# confirm with whoever supplied spice-model.asdf.
with asdf.open("spice-model.asdf") as af:
    initial_model = af["spice-model"]

In [None]:
# Optional: uncomment to fit only a 50x50 spatial cutout of the cube
# (presumably for quicker test runs). Leave commented to fit the full raster.
#spice = spice[:, 250:300, 50:100]
spice

In [None]:
# World coordinates of the wavelength ("em.wl") axis, converted to Angstrom.
wave = spice.axis_world_coords("em.wl")[0].to(u.AA)

In [None]:
# Fit `initial_model` along axis 0 (world coordinates `wave`) independently for
# every spatial pixel, in parallel via the default Dask scheduler. Diagnostics
# for failed fits are written to ./diag.
#
# The fitter only drops non-finite samples (filter_non_finite=True); it does not
# see the cube's mask. Masked samples (non-positive data, invalid sigma) are
# therefore replaced by NaN so they are excluded from the fit.
fit_data = np.where(spice.mask, np.nan, spice.data)
# NOTE(review): the uncertainties in `sigma` are not used as fit weights;
# consider passing weights derived from them if a weighted fit is intended.
spice_model_fit = parallel_fit_model_nd(
    model=initial_model,
    fitter=fitting.TRFLSQFitter(),
    data=fit_data,
    fitting_axes=0,
    world={0: wave},
    diagnostics="failed",
    diagnostics_path="diag",
    fitter_kwargs={"filter_non_finite": True},
    chunk_n_max=50,
    scheduler="default",
)

In [None]:
# Evaluate the fitted models on the wavelength grid (plain floats in Angstrom),
# with two trailing length-1 axes so it broadcasts over the spatial axes
# (shape (n_wave, 1, 1), assuming `wave` is 1-D).
all_fits = spice_model_fit(wave.to_value(u.AA)[:, np.newaxis, np.newaxis])

In [None]:
# Compare the data, the initial guess and the fitted model at one spatial pixel.
# An earlier candidate pixel was np.s_[:, 284, 116].
interesting_looking_pixel = np.s_[:, 25, 25]

# Plot all curves against wavelength in nm; the model is evaluated in Angstrom.
wave_nm = wave.to(u.nm)

fig, ax = plt.subplots()
ax.set_title("Initial guess and fit")
ax.plot(wave_nm, spice.data[interesting_looking_pixel], "o-", label="Data")
ax.plot(wave_nm, initial_model(wave.to_value(u.AA)), "--", label="Initial Guess")
ax.plot(wave_nm, all_fits[interesting_looking_pixel], "--", label="Fit")
ax.set_xlabel("Wavelength [nm]")
# Trailing semicolon suppresses the Legend object's repr in the cell output.
ax.legend();