In [None]:
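# Goal: take a composite spectrum and try to decompose it into a weighted sum of
# model spectra, i.e. can we recover the weights w?
# Original note: "for now, just request a random composite spectrum from facula_and_spot_creator".
# NOTE: the code in this cell reads a spectrum from a JWST FITS file rather than calling facula_and_spot_creator.

# Eventually this could read in external data, or some training data from a large HDF5 file, etc.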

from itertools import product
from pathlib import Path
import astropy
from astropy.table import QTable
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import scipy as sp
from astropy.visualization import quantity_support
quantity_support()
from tqdm import tqdm
import astropy.units as u
from scipy.interpolate import interp1d
from astropy.units import Quantity
from joblib import Parallel, delayed
import os

from spots_and_faculae_model.spectrum import spectrum
from spots_and_faculae_model.readers import read_JWST_fits
from spots_and_faculae_model.simpler_spectral_grid import simpler_spectral_grid

# --- Load the spectrum to decompose and the model spectral grid -------------
# Paths are given relative to the current working directory and then resolved
# to absolute paths.
external_spectrum_path = Path("../../assets/MAST_2025-10-26T11_57_04.058Z - LTT/MAST_2025-10-26T11_57_04.058Z/JWST/jw03557004001_04101_00001-seg001_nis_x1dints.fits")
# os.getcwd() returns a str (usually the folder the notebook is running in);
# `str / Path` below still yields a Path because the right-hand operand is a Path.
script_dir = os.getcwd()
# NOTE: despite its name, this is the absolute path of the external spectrum
# FITS file that is passed to read_JWST_fits below.
wavelength_grid_absolute_path = (script_dir / external_spectrum_path).resolve()

spectrum_to_decompose : spectrum = read_JWST_fits(wavelength_grid_absolute_path)

print(spectrum_to_decompose)

# True where the flux is finite (neither NaN nor +/-inf). The same mask is
# reused later in the notebook to subset the model spectra.
mask = np.isfinite(spectrum_to_decompose.Fluxes)

# Keep only the masked samples. This relies on the project's `spectrum` type
# supporting mask indexing (its implementation is not shown in this notebook).
spectrum_to_decompose = spectrum_to_decompose[mask]

print("reading in hdf5")
spectral_grid_relative_path = Path("../../assets/new_spectral_grid.hdf5")
spectral_grid_absolute_path = (script_dir / spectral_grid_relative_path).resolve()
spec_grid : simpler_spectral_grid = simpler_spectral_grid.from_hdf5(absolute_path=spectral_grid_absolute_path)
# The lookup table is indexed below (and in force_to_janskys) by a
# (T_eff, FeH, log_g) triple of astropy quantities.
lookup_table = spec_grid.to_lookup_table()
print("finished reading in hdf5")

# Sanity check: print the entry for T_eff = 2500 K with the other two parameters at 0.0.
print(lookup_table[2500 *u.K, 0.0 * u.dimensionless_unscaled, 0.0 * u.dimensionless_unscaled])

In [None]:
from typing import Sequence

# Empty placeholder for the matrix A; it is overwritten further down in this
# cell by np.column_stack(...) once the model spectra have been collected.
A = np.empty((0, 0))

def force_to_janskys(T_eff : Quantity, FeH : Quantity, log_g : Quantity, wavelengths : Sequence[Quantity], mask):
    """Return the grid spectrum for one parameter triple, labelled in janskys and masked.

    Parameters
    ----------
    T_eff, FeH, log_g
        Key used to index the notebook-level ``lookup_table``.
    wavelengths
        Passed to ``u.spectral_density`` to build the unit equivalency.
    mask
        Index applied to the resulting fluxes before they are returned.

    Returns
    -------
    The looked-up fluxes, carrying ``u.Jy``, indexed by ``mask``.

    Notes
    -----
    NOTE(review): the unit of the first flux element is divided out and the
    values are then multiplied by ``u.Jy``, i.e. the numbers are relabelled
    rather than converted. The ``.to(u.Jy, ...)`` that follows therefore looks
    like a no-op - confirm that this is the intended normalisation.
    """
    model_fluxes = lookup_table[T_eff, FeH, log_g]
    native_unit = model_fluxes[0].unit
    relabelled_fluxes = (model_fluxes / native_unit) * u.Jy
    fluxes_in_janskys = relabelled_fluxes.to(u.Jy, equivalencies=u.spectral_density(wavelengths))
    return fluxes_in_janskys[mask]

# Evaluate force_to_janskys for every (T_eff, FeH, log_g) combination of the
# grid. joblib runs the calls with a thread-preferring backend on all available
# workers (n_jobs=-1) and returns the results in submission order; tqdm shows
# progress as the tasks are handed to joblib.
normalised_and_converted_spectral_components : list[list[Quantity]] = Parallel(n_jobs=-1, prefer="threads")(
    delayed(force_to_janskys)(teff, feh, logg, spec_grid.Wavelengths, mask)
    for teff, feh, logg in tqdm(
        product(spec_grid.T_effs, spec_grid.FeHs, spec_grid.Log_gs),
        total=len(spec_grid.T_effs) * len(spec_grid.FeHs) * len(spec_grid.Log_gs),
        desc="Appending values to A matrix...",
    )
)

# One column per grid point, in the same order as the product(...) above.
A = np.column_stack(normalised_and_converted_spectral_components)

In [None]:
print("minimising")
# Bounded linear least squares: find weights w minimising ||A @ w - observed fluxes||.
# Assumes every w lies in [0, 1]; for real data this is presumably only true if
# the normalisation has been done correctly (open question from the author).
result = sp.optimize.lsq_linear(
    A,
    [flux.value for flux in spectrum_to_decompose.Fluxes],
    bounds=(0, 1),
    verbose=2,
)  # solver options that were tried earlier: max_iter=600, tol=1e-10, lsmr_tol=1e-5
print(result)
print(f"sum of weights={np.sum(result.x)}")

In [None]:

# # # plot some data # # #

from spots_and_faculae_model.spectrum_grid import TEFF_COLUMN, FEH_COLUMN, LOGG_COLUMN
# Map every (T_eff, FeH, log_g) grid point to its position in result.x.
# The grid is walked with the same product(...) order that was used to build
# the columns of A, so position k here belongs to column k of A.
grid_points = list(product(spec_grid.T_effs, spec_grid.FeHs, spec_grid.Log_gs))
result_map = {grid_point: index for index, grid_point in enumerate(grid_points)}

WEIGHT_COLUMN : str = "weight"

# Collect all rows first and build the DataFrame once. Growing a frame with
# pd.concat inside a loop copies the whole frame on every iteration (quadratic
# cost) and concatenates against an empty frame, whose dtype handling is
# deprecated in recent pandas versions.
weight_records = [
    {
        TEFF_COLUMN: T_eff,
        FEH_COLUMN: FeH,
        LOGG_COLUMN: log_g,
        WEIGHT_COLUMN: result.x[result_map[(T_eff, FeH, log_g)]],
    }
    for T_eff, FeH, log_g in grid_points
]
# Passing `columns` keeps the column order fixed and yields a correctly shaped
# (empty) frame even if the grid has no points.
hash_map = pd.DataFrame(weight_records, columns=[TEFF_COLUMN, FEH_COLUMN, LOGG_COLUMN, WEIGHT_COLUMN])

# Show the ten grid points with the largest fitted weights.
print(hash_map.sort_values(WEIGHT_COLUMN, ascending=False).head(10).round(3))

# One scatter panel per log_g value: T_eff on x, FeH on y, colour = weight ** 0.2.
N_PANEL_COLUMNS = 4
n_panels = len(spec_grid.Log_gs)
# Keep the original 4x4 layout whenever it is large enough and only add rows
# when there are more than 16 log_g values (a fixed 4x4 grid would otherwise
# raise IndexError on axes[i]). -(-a // b) is ceiling division.
n_panel_rows = max(4, -(-n_panels // N_PANEL_COLUMNS))
fig, axes = plt.subplots(
    n_panel_rows,
    N_PANEL_COLUMNS,
    figsize=(15, 15 * n_panel_rows / 4),  # 15 x 15 for the default four rows
    sharex=True,
    sharey=True,
)
axes = axes.ravel()

sc = None  # stays None if there is nothing to plot, so the colorbar can be skipped
for i, log_g in enumerate(spec_grid.Log_gs):
    subset = hash_map[hash_map[LOGG_COLUMN] == log_g]
    x_vals = [a.value for a in subset[TEFF_COLUMN]]
    y_vals = subset[FEH_COLUMN]
    z_vals = subset[WEIGHT_COLUMN]

    # Colours use weight ** 0.2 on a fixed 0..1 scale (presumably so that small
    # weights remain distinguishable from zero).
    sc = axes[i].scatter(x_vals, y_vals, c=z_vals**.2, cmap='plasma', vmin=0, vmax=1)

    axes[i].set_title(f"log_g={log_g}")
    axes[i].set_xlabel("Temperature / K")
    axes[i].set_ylabel("FeHs / relative to solar")

# NOTE(review): the input spectrum path earlier in the notebook contains "LTT";
# confirm that this title names the star that was actually observed.
STAR_NAME : str = "TRAPPIST-1"
if sc is not None:
    cbar = fig.colorbar(sc, ax=axes, orientation='vertical', fraction=0.05, pad=0.04)
    # The plotted quantity is weight ** 0.2, not the raw weight, so say so.
    cbar.set_label("Weights ** 0.2")
fig.suptitle(STAR_NAME)
plt.show()

# Compare the observed spectrum with the fitted combination A @ result.x.
# Print the units of both sides first as a quick consistency check.
print(spectrum_to_decompose.Fluxes[0].unit)
print(A[0][0].unit)

# Model spectrum on the observed wavelength grid.
determined_spectrum = spectrum(spectrum_to_decompose.Wavelengths, A @ result.x)

fit_fig, fit_ax = plt.subplots(figsize=(12, 8))
fit_ax.set_title(STAR_NAME)
fit_ax.plot(
    spectrum_to_decompose.Wavelengths,
    spectrum_to_decompose.Fluxes,
    label="experimental spectrum",
)
fit_ax.plot(
    determined_spectrum.Wavelengths,
    determined_spectrum.Fluxes,
    label="numerically found solution (sum of spectral components)",
)

fit_ax.legend()
plt.show()