In [None]:
# %matplotlib widget

In [None]:
# for now, just request a random composite spectrum from facula_and_spot_creator
# and try to decompose it - aka can we regenerate the w's

# eventually can read in external data or some training data from a large hdf5 file etc

from itertools import product
from pathlib import Path
import astropy
from astropy.table import QTable
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import scipy as sp
from astropy.visualization import quantity_support
quantity_support()
from tqdm import tqdm
import astropy.units as u
from scipy.interpolate import interp1d
from astropy.units import Quantity
from joblib import Parallel, delayed
import os

from spots_and_faculae_model.spectrum import spectrum
from spots_and_faculae_model.readers import read_JWST_fits
from spots_and_faculae_model.simpler_spectral_grid import simpler_spectral_grid

# Location of the observed JWST x1dints FITS file, relative to the notebook's working directory.
external_spectrum_path = Path("../../assets/MAST_2025-10-26T11_57_04.058Z - LTT-3780/MAST_2025-10-26T11_57_04.058Z/JWST/jw03557004001_04101_00001-seg001_nis_x1dints.fits")
script_dir = os.getcwd()  # usually the folder the notebook is running in
# `script_dir` is a plain str; `str / Path` works because Path implements the reflected `/` operator.
wavelength_grid_absolute_path = (script_dir / external_spectrum_path).resolve()

# Read a single integration (INTEGRATION_INDEX=1) from the file, then plot and print it as a sanity check.
spectrum_to_decompose : spectrum = read_JWST_fits(wavelength_grid_absolute_path, INTEGRATION_INDEX=1)
spectrum_to_decompose.plot()
print(spectrum_to_decompose)

# Boolean mask of finite flux samples. It is computed on the *unfiltered* spectrum and is
# also read as a global by calc_result() to select the same samples from the grid spectra.
# NOTE(review): this presumably requires the grid to share the observation's wavelength sampling - confirm.
mask = np.isfinite(spectrum_to_decompose.Fluxes)

# From here on `spectrum_to_decompose` only contains the finite samples.
spectrum_to_decompose = spectrum_to_decompose[mask]

print("reading in hdf5")
spectral_grid_relative_path = Path("../../assets/new_spectral_grid.hdf5")
spectral_grid_absolute_path = (script_dir / spectral_grid_relative_path).resolve()
spec_grid : simpler_spectral_grid = simpler_spectral_grid.from_hdf5(absolute_path=spectral_grid_absolute_path)
# Mapping indexed by (T_eff, FeH, log_g) - see the lookup below and calc_result().
lookup_table = spec_grid.to_lookup_table()
print("finished reading in hdf5")

# Sanity check: print the entry for T_eff = 2500 K, FeH = 0.0, log_g = 0.0.
print(lookup_table[2500 *u.K, 0.0 * u.dimensionless_unscaled, 0.0 * u.dimensionless_unscaled])

In [None]:
def get_optimality(A, result):
    """Score a fit against the global `spectrum_to_decompose`.

    The fitted spectrum is `A @ result.x`, wrapped in a `spectrum` on the
    observed wavelengths. The score is based on the fractional residual
    (fitted - observed) / observed.

    Parameters
    ----------
    A : design matrix whose columns are the component spectra.
    result : optimiser result; only its `x` (the weights) is used.

    Returns
    -------
    (rmse, rss) : root-mean-square and sum-of-squares of the fractional residual.
    """
    observed_fluxes = spectrum_to_decompose.Fluxes
    fitted_fluxes = spectrum(spectrum_to_decompose.Wavelengths, A @ result.x).Fluxes

    squared_fractional_error = ((fitted_fluxes - observed_fluxes) / observed_fluxes)**2

    return np.sqrt(np.mean(squared_fractional_error)), np.sum(squared_fractional_error)

In [None]:
from typing import Sequence, Tuple
from scipy.optimize._optimize import OptimizeResult

def calc_result(parameter_space, lookup_table, total_number_of_components : int = None, verbose : bool = True) -> Tuple[np.ndarray, OptimizeResult]:
    """Fit the observed spectrum as a bounded linear combination of grid spectra.

    Builds a design matrix A with one column per (T_eff, FeH, log_g) entry of
    `parameter_space`, then solves min ||A w - b|| with 0 <= w <= 1, where b holds
    the flux values of the global `spectrum_to_decompose`.

    Besides its arguments this function reads the notebook globals `spec_grid`
    (for the grid wavelengths), `mask` (which samples to keep) and
    `spectrum_to_decompose` (the target).

    Parameters
    ----------
    parameter_space : iterable of (T_eff, FeH, log_g) keys into `lookup_table`.
    lookup_table : mapping from (T_eff, FeH, log_g) to a flux array.
    total_number_of_components : optional length hint, only used for the progress bar.
    verbose : show the progress bar, the solver log and the final result.

    Returns
    -------
    (A, result) : the design matrix and the `lsq_linear` result (`result.x` are the weights).

    Raises
    ------
    ValueError : if `parameter_space` yields no components.
    """
    def force_to_janskys(T_eff : Quantity, FeH : Quantity, log_g : Quantity, wavelengths : Sequence[Quantity], mask):
        # spectral_density() needs the wavelengths to convert between per-wavelength and per-frequency flux density
        fluxes = lookup_table[T_eff, FeH, log_g]
        return fluxes.to(u.Jy, equivalencies=u.spectral_density(wavelengths))[mask]

    spectral_components = Parallel(n_jobs=-1, prefer="threads")(
        delayed(force_to_janskys)(T_eff, FeH, log_g, spec_grid.Wavelengths, mask)
        for T_eff, FeH, log_g in tqdm(parameter_space, total=total_number_of_components, desc="Appending values to A matrix...", disable=not verbose)
    )

    # np.column_stack would also raise ValueError here, but with an opaque message
    if len(spectral_components) == 0:
        raise ValueError("parameter_space is empty: at least one (T_eff, FeH, log_g) component is needed to build A")

    A = np.column_stack(spectral_components)

    if verbose:
        print("minimising")

    # Strip the units from the target in one vectorised step (instead of element by element).
    # NOTE(review): the columns of A are in Jy; this assumes the observed fluxes are in Jy too - confirm.
    observed_flux_values = u.Quantity(spectrum_to_decompose.Fluxes).value

    # assume that w \in [0,1] : but I think this will only be true for real data if normalisation has been done correctly (???)
    # (max_iter, tol=1e-10 and lsmr_tol=1e-5 were tried here previously)
    result : OptimizeResult = sp.optimize.lsq_linear(A, observed_flux_values, bounds = (0, 1), verbose = 2 if verbose else 0)

    if verbose:
        print(result)
        print(f"sum of weights={np.sum(result.x)}")

    return A, result

# Every (T_eff, FeH, log_g) combination available in the spectral grid.
all_parameters = list(product(spec_grid.T_effs, spec_grid.FeHs, spec_grid.Log_gs))

# Taken from the list itself so the progress-bar total can never disagree with it.
total_number_of_components = len(all_parameters)

# Fit the observation against the whole grid; `A` and `result` are reused by later cells.
A, result = calc_result(all_parameters, lookup_table, total_number_of_components)

# get_optimality returns (rmse, rss) of the fractional residual - label both explicitly
full_grid_rmse, full_grid_rss = get_optimality(A, result)
print(f"residual RMSE = {full_grid_rmse}, RSS = {full_grid_rss}")

In [None]:

# # # plot some data # # #
# dependent on the old spectrum_grid class, but it's fine for now (and it's just dependent on some arbitrary strings anyway)
from spots_and_faculae_model.spectrum_grid import TEFF_COLUMN, FEH_COLUMN, LOGG_COLUMN
# Column name under which plot_nicely() stores each component's fitted weight.
WEIGHT_COLUMN : str = "weight"

def plot_nicely(A, result, parameter_space, star_name : str = "TRAPPIST-1"):
    """Tabulate and plot the outcome of a fit.

    Prints the ten highest-weighted components, then draws three figures:
    a grid of (T_eff, FeH) scatter panels (one per log g in the global `spec_grid`)
    coloured by weight, the observed versus reconstructed spectrum, and the
    fractional residual between them.

    Reads the notebook globals `spec_grid` and `spectrum_to_decompose`.

    Parameters
    ----------
    A : design matrix used for the fit (columns are the component spectra).
    result : optimiser result; `result.x[i]` is the weight of the i-th entry of `parameter_space`.
    parameter_space : iterable of (T_eff, FeH, log_g), in the same order as the columns of `A`.
    star_name : text used for the figure titles.
        NOTE(review): the default is kept for backward compatibility, but the FITS path
        loaded by this notebook mentions LTT-3780 - confirm which name is intended.

    Returns
    -------
    pandas.DataFrame with one row per component (T_eff, FeH, log_g and weight columns).
    """
    # Build every row first and construct the frame once; concatenating row by row is quadratic.
    records = [
        {TEFF_COLUMN: T_eff, FEH_COLUMN: FeH, LOGG_COLUMN: log_g, WEIGHT_COLUMN: result.x[index]}
        for index, (T_eff, FeH, log_g) in enumerate(parameter_space)
    ]
    hash_map = pd.DataFrame(records, columns=[TEFF_COLUMN, FEH_COLUMN, LOGG_COLUMN, WEIGHT_COLUMN])

    print(hash_map.sort_values(WEIGHT_COLUMN, ascending=False).head(10).round(3))

    # --- weights over the (T_eff, FeH) plane, one panel per log g ---
    log_gs = list(spec_grid.Log_gs)
    n_cols = 4
    # Keep the 4x4 layout for up to 16 panels; add rows instead of failing beyond that.
    n_rows = max(4, -(-len(log_gs) // n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15 * n_rows / 4), sharex=True, sharey=True)
    axes = axes.ravel()

    sc = None
    for ax, log_g in zip(axes, log_gs):
        subset = hash_map[hash_map[LOGG_COLUMN] == log_g]
        x_vals = [a.value for a in subset[TEFF_COLUMN]]
        y_vals = subset[FEH_COLUMN]
        z_vals = subset[WEIGHT_COLUMN]

        # colour by weight ** 0.2 (presumably so that small weights remain distinguishable)
        sc = ax.scatter(x_vals, y_vals, c=z_vals**.2, cmap='plasma', vmin=0, vmax=1)

        ax.set_title(f"log_g={log_g}")
        ax.set_xlabel("Temperature / K")
        ax.set_ylabel("FeHs / relative to solar")

    # `sc` stays None when the grid has no log g values; there is nothing to attach a colourbar to then.
    if sc is not None:
        cbar = fig.colorbar(sc, ax=axes, orientation='vertical', fraction=0.05, pad=0.04)
        cbar.set_label("Weights (colour scale shows weight ** 0.2)")
    fig.suptitle(star_name)
    plt.show()

    # --- observed spectrum versus the fitted combination ---
    determined_spectrum = spectrum(spectrum_to_decompose.Wavelengths, A @ result.x)

    fig_fit, ax_fit = plt.subplots(figsize=(12, 8))
    ax_fit.set_title(star_name)
    ax_fit.plot(spectrum_to_decompose.Wavelengths, spectrum_to_decompose.Fluxes, label="observational JWST spectrum")
    ax_fit.plot(determined_spectrum.Wavelengths, determined_spectrum.Fluxes, label="numerically found solution (sum of spectral components)")
    ax_fit.legend()
    plt.show()

    # --- fractional residual ---
    # Drawn on its own figure: clearing the current figure would wipe the plot above on interactive backends.
    residual = (determined_spectrum.Fluxes - spectrum_to_decompose.Fluxes) / spectrum_to_decompose.Fluxes
    fig_residual, ax_residual = plt.subplots()
    ax_residual.set_title(f"{star_name}: fractional residual")
    ax_residual.set_ylabel("(fitted - observed) / observed")
    ax_residual.plot(spectrum_to_decompose.Wavelengths, residual)
    plt.show()

    return hash_map

# Plot the full-grid fit and keep the per-component weight table; later cells rank components with it.
hash_map = plot_nicely(A, result, all_parameters)

In [None]:
# now let's re-run that, but with only the top few components and see if the fit is better

# TODO: probably want a new class for this,
# or it could be a list of PHOENIX spectra; maybe that would help the above code be a bit neater too

# Numbers of top-weighted components to refit with: 1, 2, ..., 19.
ns = np.arange(1, 20, 1)
# Accumulates one RMSE per entry of `ns`; filled by the loop further down in this cell.
optimalities = np.array([])

def get_main_components(hash_map, number_of_top_components_to_use : int) -> list[Tuple[Quantity, Quantity, Quantity]]:
    """Return the parameters of the highest-weighted components.

    Parameters
    ----------
    hash_map : DataFrame with the T_eff, FeH, log_g and weight columns (as returned by plot_nicely).
    number_of_top_components_to_use : how many components to keep, taken in order of
        decreasing weight. If it exceeds the number of rows, every row is returned.

    Returns
    -------
    list of (T_eff, FeH, log_g) tuples, highest weight first.
    """
    # Slice with the *argument*: the previous version read the global loop variable
    # `number_of_components_to_keep`, so the requested count was silently ignored.
    top_rows = hash_map.sort_values(WEIGHT_COLUMN, ascending=False).iloc[:number_of_top_components_to_use]

    return [
        (row[TEFF_COLUMN], row[FEH_COLUMN], row[LOGG_COLUMN])
        for _, row in top_rows.iterrows()
    ]

# Refit using only the N highest-weighted components, for every N in `ns`,
# and record the RMSE of each restricted fit.
for number_of_components_to_keep in tqdm(ns):
    main_components = get_main_components(hash_map, number_of_components_to_keep)
    A_restricted, result_restricted = calc_result(
        main_components,
        lookup_table,
        len(main_components),
        verbose=False,
    )
    restricted_rmse = get_optimality(A_restricted, result_restricted)[0]
    optimalities = np.append(optimalities, restricted_rmse)

# RMSE of the fit that used every spectrum in the grid, drawn as a horizontal reference line.
baseline_rmse = get_optimality(A, result)[0]

plt.figure(figsize=(10,4))
plt.semilogy(ns, optimalities)
plt.semilogy(
    ns,
    [baseline_rmse] * len(ns),
    linestyle="dashed",
    label="optimality when using all PHOENIX spectra",
)
plt.xticks(ns)
plt.grid()
plt.legend()
plt.xlabel("number of components considered")
plt.ylabel("optimality of matrix minimisation method (lower is more optimal)")
plt.title("using the top N components from the minimisation of all the spectra.")
plt.show()

In [None]:
# Refit with the 7 highest-weighted components of the full-grid solution.
main_components = get_main_components(hash_map, 7)

# independent copy, so that rolling never touches the original lookup_table
shifted_lookup_table = {k: v.copy() for k, v in lookup_table.items()}

# A roll of 0 leaves the spectra unchanged; change the 0 to experiment with a pixel shift.
for key, flux in shifted_lookup_table.items():
    flux[:] = np.roll(flux, 0) # need the [:] for in place modification i.e. to change the reference

A_restricted, result_restricted = calc_result(main_components, shifted_lookup_table, len(main_components), verbose=False)

# get_optimality(...)[0] is the RMSE of the fractional residual (it was previously labelled "MSE")
print(f"Residual RMSE = {get_optimality(A_restricted, result_restricted)[0]}")

# The returned table is kept under a descriptive name rather than `_`, which IPython uses for cell outputs.
restricted_hash_map = plot_nicely(A_restricted, result_restricted, main_components)

In [None]:
# can try another meta-optimiser: roll the model spectra a few samples to the left & right
# (currently shifts of -10 ... +9) and refit each time
# there might be some small zero error on the phoenix grid or something (I wouldn't be surprised) - this would help check for that

# from experimenting: no rolling is best
main_comps = get_main_components(hash_map, 10)

shifts = np.arange(-10,10,1)

# independent copy - we don't want the shift to affect the original lookup_table
shifted_lookup_table = {k: v.copy() for k, v in lookup_table.items()}

optimalities = []

for shift in tqdm(shifts):
    # Roll from the pristine spectra every time instead of accumulating +1 rolls:
    # the result is the same, but each iteration no longer depends on the previous one.
    for key, flux in shifted_lookup_table.items():
        flux[:] = np.roll(lookup_table[key], shift) # need the [:] for in place modification i.e. to change the reference

    # Dedicated names: reusing `A` / `result` would overwrite the full-grid fit
    # that earlier cells use as their baseline.
    A_shifted, result_shifted = calc_result(main_comps, shifted_lookup_table, len(main_comps), verbose=False)
    optimalities.append(get_optimality(A_shifted, result_shifted)[0])

    # plot_nicely(A_shifted, result_shifted, main_comps)

# Fresh figure, so that whatever figure happens to be current is not cleared.
plt.figure()
plt.semilogy(shifts, optimalities)
plt.xlabel("roll / shift")
plt.ylabel("optimality of solution (using top 10 best components only)")
plt.show()

# now we can fit in sections too

# can also interpolate in different ways

# also can visualise other spectra; see what works

# also can use other spectra from lalitha