In [None]:
# %matplotlib widget

In [None]:
# Decompose an observed spectrum (here JWST data of LTT 3780) into spectral-grid components,
# i.e. check whether we can recover the component weights (the w's).
# NOTE: an earlier plan was to request a random composite spectrum from facula_and_spot_creator instead.

# Eventually this could read in other external data, or training data from a large hdf5 file.

from itertools import product
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt
from astropy.visualization import quantity_support
quantity_support()
from tqdm import tqdm
import astropy.units as u
import os

from spectrum_component_analyser.internals.spectrum import spectrum
from spectrum_component_analyser.internals.readers import read_JWST_fits,read_JWST_fits_all_spectra
from spectrum_component_analyser.internals.spectral_grid import spectral_grid
from spectrum_component_analyser.helper import calc_fitted_spectrum, get_optimality, plot_nicely, get_main_components

# A notebook has no __file__, so relative data paths are resolved against the working directory.
script_dir = Path.cwd()

observed_jwst_dir = script_dir / Path("../../observed_spectra/MAST_2025-10-26T11_57_04.058Z - LTT-3780/MAST_2025-10-26T11_57_04.058Z/JWST")

# The time series is split into four segment files that differ only in their seg001 ... seg004 number.
jwst_file_segment_001, jwst_file_segment_002, jwst_file_segment_003, jwst_file_segment_004 = (
    (observed_jwst_dir / f"jw03557004001_04101_00001-seg{segment_number:03d}_nis_x1dints.fits").resolve()
    for segment_number in range(1, 5)
)

spectrum_to_decompose : spectrum = read_JWST_fits(jwst_file_segment_001, INTEGRATION_INDEX=100, name="LTT 3780")

# Keep finite flux samples below 2 um. This single boolean mask (built from integration 100 of
# segment 1) is reused for every integration in the next cell and passed to the fitter later on.
mask = np.isfinite(spectrum_to_decompose.Fluxes) & (spectrum_to_decompose.Wavelengths < 2.00 * u.um)

spectrum_to_decompose = spectrum_to_decompose[mask]

spectrum_to_decompose.plot()

print("[SPECTRUM COMPONENT ANALYSER] : reading in hdf5")
# Alternative grid: Path("../../spectral_grids/JWST_convolved_spectral_grid.hdf5")
spectral_grid_relative_path = Path("../../spectral_grids/JWST_convolved_not_oversmoothed.hdf5")
spectral_grid_absolute_path = (script_dir / spectral_grid_relative_path).resolve()
spec_grid : spectral_grid = spectral_grid.from_hdf5(absolute_path=spectral_grid_absolute_path)
lookup_table = spec_grid.to_lookup_table()
print("[SPECTRUM COMPONENT ANALYSER] : finished reading in hdf5")

In [None]:
# Let's analyse the full JWST time series: build a band-integrated transit light curve, then
# average a chosen range of integrations into one spectrum to decompose.
from scipy.signal import medfilt

from spectrum_component_analyser.internals.readers import JWST_normalising_point

all_segments = [jwst_file_segment_001, jwst_file_segment_002, jwst_file_segment_003, jwst_file_segment_004]

# Every integration from every segment, in segment order.
all_spectra = []
for segment in all_segments:
    all_spectra.extend(read_JWST_fits_all_spectra(segment, name="LTT 3780"))

# Apply the mask from the first cell to each integration, median-filter the flux (kernel 3),
# and record the summed flux as one light-curve point.
# NOTE(review): `mask` encodes the NaN positions of a single integration only; a NaN at another
# pixel in another integration would survive here and make that integration's sum NaN — confirm.
total_fluxes = []
for spec in all_spectra:
    flux_unit = spec.Fluxes.unit
    spec.Wavelengths = spec.Wavelengths[mask]
    # The median filter looks to remove quite a lot of information; the less harsh option is to
    # sum spec.Fluxes[mask] directly without filtering.
    spec.Fluxes = medfilt(spec.Fluxes[mask], kernel_size=3) * flux_unit
    total_fluxes.append(np.sum(spec.Fluxes).value)

plt.clf()
plt.plot(np.arange(len(total_fluxes)), total_fluxes)
plt.xlabel("Integration Index")
plt.ylabel("Total Flux / arbitrary units")
plt.title("Transit Light Curve for LTT 3780")
plt.show()

min_included_integration_index = 0
max_included_integration_index = 250  # exclusive upper bound of the slice below

all_fluxes = [spec.Fluxes for spec in all_spectra[min_included_integration_index:max_included_integration_index]]

# u.Quantity stacks the per-integration fluxes in the first one's unit (converting/checking the
# others), so the mean keeps its unit instead of relying on np.mean over a list of Quantities.
mean_flux = u.Quantity(all_fluxes).mean(axis=0)

spectrum_to_decompose = spectrum(wavelengths=all_spectra[0].Wavelengths,
                                 fluxes=mean_flux,
                                 normalised_point=JWST_normalising_point,
                                 observational_resolution=None,
                                 observational_wavelengths=None,
                                 name="averaged LTT 3780")

# No need to apply `mask` again: every spectrum was already masked in the loop above.

spectrum_to_decompose.plot()

In [None]:
# Fit against every T_eff in the grid at fixed [Fe/H] = 0 and log g = 5.
# Each parameter tuple is (T_eff, [Fe/H], log g); swap in one of the commented lines to free those too.
all_parameters = list(product(spec_grid.T_effs, [0 * u.dimensionless_unscaled], [5 * u.dimensionless_unscaled]))
# all_parameters = list(product(spec_grid.T_effs, spec_grid.FeHs, [5 * u.dimensionless_unscaled]))
# all_parameters = list(product(spec_grid.T_effs, spec_grid.FeHs, spec_grid.Log_gs))

total_number_of_components = len(all_parameters)

A, result = calc_fitted_spectrum(all_parameters,
                                 lookup_table,
                                 spec_grid=spec_grid,
                                 spectrum_to_decompose=spectrum_to_decompose,
                                 mask=mask,
                                 total_number_of_components=total_number_of_components,
                                 max_iterations=1000)

# Take [0] like every other get_optimality call in this notebook, so only the MSE is printed.
print(f"residual MSE = {get_optimality(A, result, spectrum_to_decompose)[0]}")

In [None]:
from spectrum_component_analyser.helper import FEH_COLUMN, LOGG_COLUMN, TEFF_COLUMN, WEIGHT_COLUMN

# Plot the full fit and get back a per-component table
# (presumably a pandas DataFrame, given the `.values` access below — confirm in helper).
hash_map = plot_nicely(A, result, all_parameters, spec_grid, spectrum_to_decompose)

weights = hash_map[WEIGHT_COLUMN].values

# Weight-averaged stellar parameters across the fitted components.
teff_avg = np.average(hash_map[TEFF_COLUMN], weights=weights)
feh_avg  = np.average(hash_map[FEH_COLUMN],  weights=weights)
logg_avg = np.average(hash_map[LOGG_COLUMN], weights=weights)

print(f"weighted mean T_eff  = {teff_avg}")
print(f"weighted mean [Fe/H] = {feh_avg}")
print(f"weighted mean log g  = {logg_avg}")

In [None]:
# Now re-run the fit using only the top-N components of the full fit (N = 1 ... 19) and see
# whether the residual improves relative to using every PHOENIX spectrum.

# TODO: a dedicated class (or a list of PHOENIX spectra) would make this bookkeeping neater,
# and could tidy up the code above as well.

ns = np.arange(1, 20, 1)
restricted_optimalities = []

for number_of_components_to_keep in tqdm(ns):

    main_components = get_main_components(hash_map, number_of_components_to_keep)

    A_restricted, result_restricted = calc_fitted_spectrum(main_components,
                                                  lookup_table,
                                                  spectrum_to_decompose=spectrum_to_decompose,
                                                  spec_grid = spec_grid,
                                                  mask=mask,
                                                  total_number_of_components=len(main_components),
                                                  verbose=False)

    # Collect in a list: np.append inside a loop copies the whole array on every iteration.
    restricted_optimalities.append(get_optimality(A_restricted, result_restricted, spectrum_to_decompose)[0])

optimalities = np.array(restricted_optimalities)

# Reference level: residual of the full-grid fit (A, result) computed earlier.
full_fit_optimality = get_optimality(A, result, spectrum_to_decompose)[0]

fig, ax = plt.subplots(figsize=(10, 4))
ax.semilogy(ns, optimalities, label="optimality when using the top N components")
ax.axhline(full_fit_optimality, color="grey", linestyle="dashed", label="optimality when using all PHOENIX spectra")
ax.set_xticks(ns)
ax.grid()
ax.legend()
ax.set_xlabel("number of components considered")
ax.set_ylabel("Residual Mean Squared Error")  # of the matrix minimisation method (lower is more optimal)
ax.set_title("Optimality vs Number of Components Used")
plt.show()

In [None]:
# Refit using only the top components (as ranked by get_main_components) against a copy of the
# model grid rolled by `roll_shift` pixels. roll_shift = 0 reproduces the unshifted fit.
n_main_components = 7
roll_shift = 0

main_components = get_main_components(hash_map, n_main_components)

# np.roll returns new arrays, so lookup_table itself is never modified; no copy + in-place write needed.
# NOTE: np.roll wraps around, so |roll_shift| samples from one end re-enter at the other.
shifted_lookup_table = {key: np.roll(flux, roll_shift) for key, flux in lookup_table.items()}

A_restricted, result_restricted = calc_fitted_spectrum(main_components,
                                              shifted_lookup_table,
                                              spectrum_to_decompose=spectrum_to_decompose,
                                              spec_grid=spec_grid,
                                              mask=mask,
                                              total_number_of_components=len(main_components),
                                              verbose=False)

print(f"Residual MSE = {get_optimality(A_restricted, result_restricted, spectrum_to_decompose)[0]}")

_ = plot_nicely(A_restricted, result_restricted, main_components, spec_grid, spectrum_to_decompose)

In [None]:
# Meta-optimiser: roll every model spectrum by a whole number of samples (here -10 ... +9) and refit.
# There might be a small zero-point offset in the PHOENIX grid (I wouldn't be surprised); this checks for that.
# From experimenting: no rolling is best.
n_scan_components = 10
main_comps = get_main_components(hash_map, n_scan_components)

shifts = np.arange(-10, 10, 1)
optimalities = []

for shift in tqdm(shifts):
    # Fresh rolled copy of the pristine table for each shift; np.roll does not modify lookup_table.
    # NOTE: np.roll wraps around, so |shift| samples from one end of each model re-enter at the other.
    shifted_lookup_table = {key: np.roll(flux, shift) for key, flux in lookup_table.items()}

    # Distinct names so the full-grid fit (A, result) used by earlier cells is not overwritten.
    A_shifted, result_shifted = calc_fitted_spectrum(main_comps,
                            shifted_lookup_table,
                            spectrum_to_decompose=spectrum_to_decompose,
                            spec_grid=spec_grid,
                            mask=mask,
                            total_number_of_components=len(main_comps),
                            verbose=False)
    optimalities.append(get_optimality(A_shifted, result_shifted, spectrum_to_decompose)[0])

plt.clf()
plt.semilogy(shifts, optimalities)
plt.xlabel("roll / shift")
plt.ylabel(f"optimality of solution (using top {n_scan_components} best components only)")
plt.show()

# TODO / ideas:
# - fit the spectrum in sections
# - interpolate in different ways
# - visualise other spectra to see what works
# - use other spectra from Lalitha