In [None]:
import dkist
import dkist.net
from sunpy.net import Fido, attrs as a

In [None]:
from astropy.modeling.models import Lorentz1D, Const1D, Gaussian1D
import astropy.units as u

In [None]:
import matplotlib.pyplot as plt
from astropy.visualization import quantity_support

In [None]:
from astropy.modeling.fitting import TRFLSQFitter
import numpy as np

In [None]:
# Register astropy's Quantity support with matplotlib so unit-carrying arrays
# can be passed straight to plotting calls.
quantity_support()

In [None]:
# Select the interactive "widget" matplotlib backend for figures in this notebook.
%matplotlib widget

In [None]:
# Query Fido for the DKIST dataset whose ID is "ALDLJ".
res = Fido.search(a.dkist.Dataset("ALDLJ"))

In [None]:
# Destination template for the fetched files. The "{dataset_id}" field is
# presumably filled in by Fido.fetch from each search result -- TODO confirm.
# NOTE(review): this is an absolute, machine-specific path; adjust it for your
# own environment before running.
dataset_path = "/data/dkist/prod/pid_2_114/{dataset_id}"
# Fetch the search results to that location; the return value is what
# dkist.load_dataset() is given in the next cell.
asdf_file = Fido.fetch(res, path=dataset_path)

In [None]:
# Open the fetched file(s) as a DKIST dataset object; every later cell slices `visp`.
visp = dkist.load_dataset(asdf_file)

In [None]:
# Optional: with a Globus Connect Personal endpoint running, uncomment the line
# below to download the actual data files for this dataset.
#visp.files.download()

In [None]:
# Quick-look plot of a single slice: index 0 on the first axis, everything
# along the second axis, index 1000 on the third axis.
visp[0, :, 1000].plot()

In [None]:
# World coordinates of the "em.wl" (wavelength) axis for the same slice that
# was plotted above; [0] picks the first entry of the returned collection.
wave = visp[0, :, 1000].axis_world_coords("em.wl")[0]

In [None]:
# Initial guesses for three line profiles, each modelled as a Lorentzian with
# a negative amplitude (a dip below the background level).
#
# Every profile also gets a "constrained" copy whose centre (x_0) is bounded
# to a window around the initial guess; the unconstrained originals are left
# untouched so the two fits can be compared later.
#
# All bounds are written with explicit nm units so that they match the x_0
# values they constrain (previously only line 1's bounds carried units).
line_1 = Lorentz1D(amplitude=-0.6*u.ct, fwhm=0.1*u.nm, x_0=854.3*u.nm)
line_1_constrained = line_1.copy()
line_1_constrained.x_0.min = 854.25*u.nm
line_1_constrained.x_0.max = 854.35*u.nm

line_2 = Lorentz1D(amplitude=-0.25*u.ct, fwhm=0.01*u.nm, x_0=853.98*u.nm)
line_2_constrained = line_2.copy()
line_2_constrained.x_0.min = 853.95*u.nm
line_2_constrained.x_0.max = 854.00*u.nm

line_3 = Lorentz1D(amplitude=-0.15*u.ct, fwhm=0.01*u.nm, x_0=854.08*u.nm)
line_3_constrained = line_3.copy()
line_3_constrained.x_0.min = 854.05*u.nm
line_3_constrained.x_0.max = 854.13*u.nm

In [None]:
# Unconstrained model: a flat 1 ct background plus the three free Lorentzians.
model = (
    Const1D(1 * u.ct)
    + line_1
    + line_2
    + line_3
)

# Same composition, built from the copies whose line centres are bounded.
model_constrained = (
    Const1D(1 * u.ct)
    + line_1_constrained
    + line_2_constrained
    + line_3_constrained
)

In [None]:
# Spectrum at index [0, :, 1000], evaluated once via .compute() and tagged
# with count units so it is compatible with the models' amplitudes.
spectrum_1000 = visp[0, :, 1000].data.compute() * u.ct

# Fit both models to that one spectrum, each with its own fitter instance.
fit_constrained = TRFLSQFitter()(model_constrained, wave, spectrum_1000)
fit = TRFLSQFitter()(model, wave, spectrum_1000)

In [None]:
# Overlay the data, the initial guess and both fitted models on one axis.
fig, ax = plt.subplots()
ax.set_title("VISP")

# (curve, legend label) pairs, drawn in this order against wavelength.
# The first entry is the mean of visp[0] taken along axis 1.
curves = [
    (np.mean(visp[0, :, :].data, axis=1), "slit average"),
    (model(wave), "initial guess"),
    (fit(wave), "fit (unconstrained)"),
    (fit_constrained(wave), "fit (constrained)"),
]
for values, label in curves:
    ax.plot(wave, values, label=label)

ax.legend()

In [None]:
from astropy.modeling.fitting_parallel import parallel_fit_model_nd

In [None]:
# Fit the constrained model to every spectrum in the first step of the
# dataset (visp[0:1] keeps the leading axis, so the data stay 3-D).
visp_model_fit = parallel_fit_model_nd(
    model=model_constrained,
    fitter=TRFLSQFitter(),
    data=visp[0:1].data,
    # Axis 1 is the wavelength axis (the same axis `wave` was read from),
    # so the model is fitted along it and iterated over the other axes.
    fitting_axes=1,
    # The WCS of the same slice supplies the world coordinates for the fit.
    world=visp[0:1].wcs,
    # Write diagnostics only for fits that raised an exception. The accepted
    # values are None, "error", "error+warn" and "all"; the previous value
    # "failed" is not one of them.
    diagnostics="error",
    diagnostics_path="diag",
    # Passed through to the fitter call for every individual fit.
    fitter_kwargs={"filter_non_finite": True},
    chunk_n_max=50,
    # Run in the current process rather than a parallel scheduler.
    scheduler="synchronous",
)