In [None]:
# %matplotlib widget

In [None]:
from itertools import product
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt
from astropy.visualization import quantity_support
from astropy.units import Quantity
quantity_support()
from tqdm import tqdm
import astropy.units as u
import os

from spectrum_component_analyser.internals.spectrum import spectrum
from spectrum_component_analyser.internals.readers import read_JWST_fits,read_JWST_fits_all_spectra
from spectrum_component_analyser.internals.spectral_grid import spectral_grid
from spectrum_component_analyser.minimisation import calc_fitted_spectrum, get_optimality, plot_nicely, get_main_components

# --- Target star -------------------------------------------------------------
star_name : str = "LTT 3780"
star_temperature : Quantity[u.K] = 3350 * u.K

# A notebook has no module-level `__file__`, so the current working directory stands in for it.
# All relative paths below are therefore resolved against wherever the kernel was started.
__file__ = os.getcwd()

# --- Observed JWST data ------------------------------------------------------
# The four x1dints segments of this observation sit side by side in one MAST download directory,
# so the directory is spelled out once and only the segment number varies.
jwst_observation_directory = __file__ / Path("../../observed_spectra/MAST_2025-10-26T11_57_04.058Z - LTT-3780/MAST_2025-10-26T11_57_04.058Z/JWST")

def _jwst_segment_path(segment_number : int) -> Path:
    """Return the resolved path of the x1dints FITS file for the given 1-based segment number."""
    return (jwst_observation_directory / f"jw03557004001_04101_00001-seg{segment_number:03d}_nis_x1dints.fits").resolve()

jwst_file_segment_001 = _jwst_segment_path(1)
jwst_file_segment_002 = _jwst_segment_path(2)
jwst_file_segment_003 = _jwst_segment_path(3)
jwst_file_segment_004 = _jwst_segment_path(4)

# A single integration (index 100 of the first segment) is read as the initial spectrum to decompose
spectrum_to_decompose : spectrum = read_JWST_fits(jwst_file_segment_001, INTEGRATION_INDEX=100, name=star_name, T_eff = star_temperature)

# Keep only the samples whose flux is finite; `mask` is reused by later cells
mask = np.isfinite(spectrum_to_decompose.Fluxes) # & (spectrum_to_decompose.Wavelengths < 1.8 * u.um)

spectrum_to_decompose = spectrum_to_decompose[mask]

spectrum_to_decompose.plot()

# --- Model spectral grid -----------------------------------------------------
print("[SPECTRUM COMPONENT ANALYSER] : reading in hdf5")
# spectral_grid_relative_path = Path("../../spectral_grids/JWST_convolved_spectral_grid.hdf5")
spectral_grid_relative_path = Path("../../spectral_grids/JWST_convolved_not_oversmoothed.hdf5")
spectral_grid_absolute_path = (__file__ / spectral_grid_relative_path).resolve()
spec_grid : spectral_grid = spectral_grid.from_hdf5(absolute_path=spectral_grid_absolute_path)
print("[SPECTRUM COMPONENT ANALYSER] : finished reading in hdf5")

In [None]:
# let's analyse the JWST fits file: let's make a transmission curve
from scipy.signal import medfilt

from spectrum_component_analyser.internals.readers import JWST_NORMALISING_POINT

all_segments = [jwst_file_segment_001, jwst_file_segment_002, jwst_file_segment_003, jwst_file_segment_004]

# Collect every spectrum returned for every segment into one flat list, in segment order
all_spectra = []
for segment in all_segments:
    all_spectra.extend(read_JWST_fits_all_spectra(segment, T_eff = None, name=star_name))

# One summed-flux value per spectrum; plotted against list position this gives the light curve.
# NOTE: each spectrum is modified in place here (masked, then median filtered), and the averaging
# further down relies on that.
total_fluxes = []
for spec in all_spectra:
    # this looks to remove quite a lot of information
    spec.Wavelengths = spec.Wavelengths[mask]
    spec.Fluxes = medfilt(spec.Fluxes[mask], kernel_size=3) * spec.Fluxes.unit
    total_fluxes.append(np.sum(spec.Fluxes).value)

    # less harsh option
    # total_fluxes.append(np.sum(spec.Fluxes[mask]).value)

plt.clf()
plt.plot(range(len(total_fluxes)), total_fluxes)
plt.xlabel("Integration Index")
plt.ylabel("Total Flux / arbitrary units")
plt.title("Transit Light Curve for " + star_name)
plt.show()

# Slice bounds (start inclusive, stop exclusive) of the spectra that are averaged together
min_included_integration_index = 0
max_included_integration_index = 300

all_fluxes=[spec.Fluxes for spec in all_spectra[min_included_integration_index:max_included_integration_index]]

# Replace the single-integration spectrum with the mean over the selected spectra.
# Wavelengths are taken from the first spectrum, which was masked in the loop above.
# NOTE(review): the name has no separator between "averaged" and the star name - confirm whether
# that is intended before changing it, as the string is handed to the spectrum object.
spectrum_to_decompose = spectrum(wavelengths=all_spectra[0].Wavelengths,
                                 fluxes=np.mean(all_fluxes, axis=0) * all_fluxes[0].unit,
                                 normalised_point=JWST_NORMALISING_POINT,
                                 temperature = star_temperature,
                                 observational_resolution=None,
                                 observational_wavelengths=None,
                                 name="averaged" + star_name)

# spectrum_to_decompose = spectrum_to_decompose[mask]

spectrum_to_decompose.plot()

In [None]:
# Candidate components: every grid temperature combined with a hand-picked pair of metallicities
# and a hand-picked pair of surface gravities (tuple order is T_eff, FeH, log g).
candidate_FeHs = [0, 1] * u.dimensionless_unscaled
candidate_log_gs = [4.5, 5] * u.dimensionless_unscaled
all_parameters = list(product(spec_grid.T_effs, candidate_FeHs, candidate_log_gs))

# Alternative candidate sets:
# all_parameters = list(product(spec_grid.T_effs, [0 * u.dimensionless_unscaled], spec_grid.Log_gs))
# all_parameters = list(product(spec_grid.T_effs, spec_grid.FeHs, [5 * u.dimensionless_unscaled]))
# all_parameters = list(product(spec_grid.T_effs, spec_grid.FeHs, spec_grid.Log_gs))

total_number_of_components = len(all_parameters)

# Fit the observed spectrum using every candidate; A and result are reused by the cells below
A, result = calc_fitted_spectrum(
    all_parameters,
    spec_grid=spec_grid,
    spectrum_to_decompose=spectrum_to_decompose,
    mask=mask,
    total_number_of_components=total_number_of_components,
    max_iterations=1000,
)

print(f"residual MSE, residual sum of squares = {get_optimality(A, result, spectrum_to_decompose)}")

In [None]:
%matplotlib widget

from spectrum_component_analyser.minimisation import FEH_COLUMN, LOGG_COLUMN, TEFF_COLUMN, WEIGHT_COLUMN

# Plot the fit; the returned table (indexed by column name, presumably a pandas DataFrame - confirm
# against plot_nicely) carries one row per component with its parameters and fitted weight.
hash_map = plot_nicely(A, result, all_parameters, spec_grid, spectrum_to_decompose, star_name)

weights = hash_map[WEIGHT_COLUMN].values

# Weight-averaged stellar parameters over all components
teff_avg, feh_avg, logg_avg = (
    np.average(hash_map[column], weights=weights)
    for column in (TEFF_COLUMN, FEH_COLUMN, LOGG_COLUMN)
)

print(f"T_eff avg = {teff_avg}")
print(f"FeH avg = {feh_avg}")
print(f"log g avg = {logg_avg}")

In [None]:
# now lets re-run that, but with only the top few components and see if the fit is better

# prolly want a new class for this
# or could be a list of phoenix spectra; maybe that would help the above code be a bit neater too

# Component counts to try: 1 to 19 inclusive
ns = np.arange(1, 20, 1)
optimalities = np.array([])

for number_of_components_to_keep in tqdm(ns):

    # Refit using only the selected main components
    main_components = get_main_components(hash_map, number_of_components_to_keep)

    A_restricted, result_restricted = calc_fitted_spectrum(
        main_components,
        spec_grid = spec_grid,
        spectrum_to_decompose=spectrum_to_decompose,
        mask=mask,
        total_number_of_components=len(main_components),
        verbose=False,
    )

    # First element of get_optimality is the residual MSE
    restricted_MSE = get_optimality(A_restricted, result_restricted, spectrum_to_decompose)[0]
    optimalities = np.append(optimalities, restricted_MSE)

    # _ = plot_nicely(A_restricted, result_restricted, main_components, spec_grid, spectrum_to_decompose)

# plt.clf()
plt.figure(figsize=(10,4))
plt.semilogy(ns, optimalities)

# Reference level: the fit that used every candidate component
full_grid_MSE = get_optimality(A, result, spectrum_to_decompose)[0]
plt.semilogy(ns, [full_grid_MSE] * len(ns), linestyle="dashed", label="optimality when using all PHOENIX spectra")

plt.title(f"Optimality vs Number of Components Used")
plt.xlabel("number of components considered")
plt.ylabel("Residual Mean Squared Error")  #of matrix minimisation method (lower is more optimal)
plt.xticks(ns)
plt.grid()
plt.legend()
plt.show()

In [None]:
# now carry out an effectively identical analysis to calibration.py : input some known phoenix spectra, and shift their wavelengths back and forth until the spectra line up. see if this significantly improves things

from matplotlib import pyplot as plt
import numpy as np
from astropy import units as u
from astropy.visualization import quantity_support
from astropy.modeling import models
import scipy as sp

from spectrum_component_analyser.minimisation import calc_fitted_spectrum_from_spectral_list

quantity_support()

from spectrum_component_analyser.internals.phoenix_spectrum import phoenix_spectrum
from spectrum_component_analyser.internals.readers import JWST_NORMALISING_POINT, JWST_RESOLUTION
from spectrum_component_analyser.internals.spectral_grid import download_spectrum, get_wavelength_grid
from spectrum_component_analyser.internals.spectral_list import spectral_list
from spectrum_component_analyser.internals.spectral_component import spectral_component

# Only the selected main components of the full fit are used for the wavelength-shift test
number_of_components_to_keep = 5

main_components : list[spectral_component] = get_main_components(hash_map, number_of_components_to_keep)


# Largest roll, in steps of roll_resolution, to try on either side of zero
max_roll_delta = 2
roll_resolution = .0005 * u.um

# Descending from +max_roll_delta to -max_roll_delta inclusive. A plain
# range(-max_roll_delta, max_roll_delta) would drop the +max_roll_delta step and
# leave the sweep lopsided.
rolls = list(range(max_roll_delta, -max_roll_delta - 1, -1))

def _plot_fit_and_residual(A, result):
    """Plot the fitted spectrum over the observed one, with the fractional residual underneath.

    Layout taken from minimisation.py's plot_nicely.

    Parameters
    ----------
    A
        Matrix returned by calc_fitted_spectrum_from_spectral_list; `A @ result.x` is used as the
        fitted fluxes.
    result
        Fit result returned alongside `A`; only its `.x` attribute is read here.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8), sharex=True)

    # --- First subplot: spectrum comparison ---
    ax1.set_title(star_name)

    determined_spectrum = spectrum(
        spectrum_to_decompose.Wavelengths,
        A @ result.x,
        normalised_point=None, # carry out no normalisation or resampling on the determined spectrum (as this spectrum is a sum of spectra from PHOENIX which should already be formatted in this way)
        observational_resolution=None,
        observational_wavelengths=None,
        temperature=None
    )

    ax1.plot(
        spectrum_to_decompose.Wavelengths,
        spectrum_to_decompose.Fluxes,
        label="Observational JWST spectrum"
    )
    ax1.plot(
        determined_spectrum.Wavelengths,
        determined_spectrum.Fluxes,
        label="Fitted Spectrum"
    )

    ax1.legend()

    # --- Second subplot: residuals ---
    residual = (determined_spectrum.Fluxes - spectrum_to_decompose.Fluxes) / spectrum_to_decompose.Fluxes

    ax2.plot(spectrum_to_decompose.Wavelengths, residual)
    ax2.set_ylabel(r"Residual = $\frac{\mathrm{Fitted\ Flux}-\mathrm{Observed\ Flux}}{\mathrm{Observed\ Flux}}$")
    # raw string: "\m" is not a valid escape sequence in a normal string literal
    ax2.set_xlabel(r"Wavelength / $\mu$m")

    plt.tight_layout()
    plt.show()


def get_MSE(
        shift : float, # in u.um (this isn't a quantity because scipy minimise prolly will complain)
        plot : bool = False
        ):
    """Fit the main components with the observed wavelength grid offset by `shift`, and return the MSE.

    A spectral list is built for `main_components` with `observational_wavelengths` set to the
    observed wavelengths plus `shift`; that list is then fitted to `spectrum_to_decompose`.

    Parameters
    ----------
    shift
        Wavelength offset as a bare number, interpreted as micrometres.
    plot
        When True, also plot the fitted spectrum against the observation together with the
        fractional residual.

    Returns
    -------
    The first element returned by get_optimality for this fit (the residual MSE).
    """
    # Attach the unit without modifying the caller's object in place
    shift = shift * u.um

    spec_list : spectral_list = spectral_list.from_internet(
        spectral_components=main_components,
        normalising_point=JWST_NORMALISING_POINT,
        observational_resolution=JWST_RESOLUTION,
        observational_wavelengths=spectrum_to_decompose.Wavelengths + shift,
        name="small spectral grid"
    )

    # can plot phoenix spectra for debugging
    # s : phoenix_spectrum
    # for s in spec_list.PhoenixSpectra:
    #     s.plot(clear=False, show=False)
    #     bb = models.BlackBody(temperature=s.T_eff)
    #     plt.plot(s.Wavelengths, (bb(s.Wavelengths) * 4 * np.pi * u.sr).to(s.Fluxes.unit) / (models.BlackBody(temperature=3500 *u.K)(JWST_NORMALISING_POINT) * 4 * np.pi * u.sr ).to(s.Fluxes.unit).value, linestyle='dashed')
    # plt.show()

    A, result = calc_fitted_spectrum_from_spectral_list(
        spec_list=spec_list,
        mask=mask,
        spectrum_to_decompose=spectrum_to_decompose,
        verbose=False
    )

    if plot:
        _plot_fit_and_residual(A, result)

    MSE, _ = get_optimality(A, result, spectrum_to_decompose)

    print(f"roll of {shift} gave MSE = {MSE}")
    return MSE

# Coarse manual sweep, kept for reference:
# for roll_step in rolls:
#     MSE = get_MSE(roll_step)
#     print(f"MSE = {MSE} : shift of {roll_step * roll_resolution}")

# sp.optimize.minimize(get_MSE, [0.0], bounds=[(-.01, .01)]) # gave me -0.0003748 * u.um

# %matplotlib widget
# Hand-entered shift in micrometres: shift is tiny if it does exist (~ resolution / 3)
best_shift_um = -0.0003656
shifted_MSE = get_MSE(best_shift_um, plot = True)
print(shifted_MSE)

# original_MSE = get_MSE(-0., plot = True)
# print(original_MSE)


In [None]:
# get wavelength grid and example observational spectrum, test for rolls (this is a cruder approach than the one a code block or two above)

phoenix_wavelengths = get_wavelength_grid()

# Keyword options for the download; no observational wavelength grid is supplied here
# (the interpolation onto the observed grid happens later, in get_MSE).
nearby_download_options = dict(
    lte=True,
    alphaM=0,
    phoenix_wavelengths=phoenix_wavelengths,
    normalising_point=JWST_NORMALISING_POINT,
    observational_resolution=JWST_RESOLUTION,
    observational_wavelengths=None,
    name="nearby phoenix spectrum",
)

# nearby as in: nearby in parameter space to the observed spectrum
# (the two bare numbers are presumably FeH and log g - confirm against download_spectrum's signature)
nearby_phoenix_spectrum : phoenix_spectrum = download_spectrum(
    3400 * u.K,
    0.0,
    5.0,
    **nearby_download_options
)

# NOTE(review): this rebinds the notebook-wide `mask`, now computed from the current
# spectrum_to_decompose; earlier cells that index with `mask` will see this version if re-run.
mask = np.isfinite(spectrum_to_decompose.Fluxes)

# Wavelength step represented by one integer roll
roll_resolution = 0.00001 * u.um

def get_MSE(roll : int = 0) -> float:
    """Score how well the nearby PHOENIX spectrum matches the observation for one wavelength roll.

    The PHOENIX wavelengths are offset by `roll * roll_resolution`, the fluxes are interpolated
    onto the observed wavelengths, and the fractional residual against the observed fluxes is
    evaluated on the masked samples only.

    Parameters
    ----------
    roll
        Number of `roll_resolution` steps to add to the PHOENIX wavelengths.

    Returns
    -------
    Square root of the mean squared fractional residual (an RMS value, despite the function name).
    """
    # new y = np.interp(new x | old x | old y)
    rolled_phoenix_wavelengths = nearby_phoenix_spectrum.Wavelengths + roll * roll_resolution
    phoenix_on_observed_grid = np.interp(
        spectrum_to_decompose.Wavelengths,
        rolled_phoenix_wavelengths,
        nearby_phoenix_spectrum.Fluxes,
    )

    # mask _after_ rolling & interpolating
    # the jwst data always seems to have some NaN; let's remove them
    observed_fluxes = spectrum_to_decompose[mask].Fluxes
    fractional_residual = (phoenix_on_observed_grid[mask] - observed_fluxes) / observed_fluxes

    return np.sqrt(np.mean(fractional_residual**2))

# how far to move the jwst spectrum to either side to check
max_roll_delta : int = 1000

# Descending from +max_roll_delta to -max_roll_delta inclusive. A plain
# range(-max_roll_delta, max_roll_delta) would leave out the +max_roll_delta step.
rolls = list(range(max_roll_delta, -max_roll_delta - 1, -1))

# One score per roll (get_MSE returns the RMS of the fractional residual)
MSEs = [get_MSE(roll) for roll in rolls]

plt.clf()
plt.cla()
# roll_resolution is in micrometres, so the x axis is the wavelength offset applied to the PHOENIX spectrum
plt.plot(rolls * roll_resolution, MSEs)
plt.xlabel(r"Wavelength shift applied to PHOENIX spectrum / $\mu$m")
plt.ylabel("RMS fractional residual")
plt.title("Residual vs wavelength shift for " + star_name)
plt.show()

# looks to be 0.0042 um off - but fits from main.ipynb suggest more like 0.0010 (just from eyeing it)