In [None]:
# %matplotlib widget

In [None]:
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
from itertools import product
from tqdm import tqdm
from astropy.visualization import quantity_support
quantity_support() # required for numpy to not get annoyed when doing e.g. np.stack()
from joblib import Parallel, delayed

from spectrum_component_analyser.internals.readers import read_JWST_fits_all_spectra
from spectrum_component_analyser.internals.spectrum import spectrum
from spectrum_component_analyser.internals.spectral_grid  import spectral_grid

"""
one difference between this and main.ipynb is that here the spectra are processed by appending rows into the matrix A, whereas in main.ipynb spectra are appended as _columns_ into the matrix A
"""

from itertools import product
from pathlib import Path
import numpy as np
from astropy.visualization import quantity_support
quantity_support()
import os

from spectrum_component_analyser.internals.spectrum import spectrum
from spectrum_component_analyser.internals.spectral_grid import spectral_grid
from spectrum_component_analyser.helper import calc_result, get_optimality, plot_nicely
from spectrum_component_analyser.internals.readers import read_HARPS_fits

# Resolve input files relative to the current working directory (usually the
# folder the notebook is running from).
script_dir = Path.cwd()

external_spectrum_path = Path("../../assets/ADP.2016-02-04T01_02_52.843.fits")
external_spectrum_absolute_path = (script_dir / external_spectrum_path).resolve()
# Backwards-compatible alias: this name previously held the observed spectrum's
# path even though it is not a wavelength grid.
wavelength_grid_absolute_path = external_spectrum_absolute_path

spectrum_to_decompose : spectrum = read_HARPS_fits(external_spectrum_absolute_path, INTEGRATION_INDEX=0, verbose=False)
spectrum_to_decompose.plot()

# Keep only finite flux samples. `mask` is defined on the full (unmasked)
# sampling of the observed spectrum; it is also applied to the grid fluxes
# when the training set is built.
mask = np.isfinite(spectrum_to_decompose.Fluxes)
spectrum_to_decompose = spectrum_to_decompose[mask]

print("reading in hdf5")
spectral_grid_relative_path = Path("../../assets/HARPS_convolved_spectral_grid.hdf5")
spectral_grid_absolute_path = (script_dir / spectral_grid_relative_path).resolve()
spec_grid : spectral_grid = spectral_grid.from_hdf5(absolute_path=spectral_grid_absolute_path)
lookup_table = spec_grid.to_lookup_table()


In [None]:

# The observed spectrum (already masked to finite fluxes) sets the number of
# input features for the network.
spectrum_num_points = len(spectrum_to_decompose)

# Number of PHOENIX spectra available for train/val: one per
# (T_eff, FeH, log_g) combination in the spectral grid.
number_phoenix_spectra : int = len(spec_grid.T_effs) * len(spec_grid.FeHs) * len(spec_grid.Log_gs)

# Matrix layout: each ROW of X and of A describes one spectrum.
#   - the columns of X are that spectrum's flux values (the "ys")
#   - the columns of A are the parameters describing that spectrum
#
# Same idea as main.ipynb, but the meaning of every weight is explicit: the
# parameters are gathered up front into their own matrix
# (# of PHOENIX spectra rows x 3 columns).

# y_curves holds the flux arrays; params holds the matching [T_eff, FeH, log_g] lists.
y_curves, params = [], []

import astropy.units as u

def get_spectra_and_params(T_eff, FeH, log_g, mask, *, lookup=None) -> tuple[np.ndarray, list]:
    """Fetch one grid spectrum (masked) together with the parameters that label it.

    Parameters
    ----------
    T_eff :
        Effective-temperature grid value. It is used as part of the lookup key and
        must expose ``.value`` (presumably an astropy Quantity -- confirm).
    FeH, log_g :
        Metallicity and surface-gravity grid values, used as the rest of the key.
    mask : array of bool
        Selects which flux samples to keep.
    lookup : mapping, optional
        Table indexed by ``(T_eff, FeH, log_g)``. Defaults to the notebook-level
        ``lookup_table``.

    Returns
    -------
    tuple
        ``(fluxes, [T_eff.value, FeH, log_g])``, where ``fluxes`` is the masked flux array.
    """
    table = lookup_table if lookup is None else lookup
    fluxes = table[T_eff, FeH, log_g][mask]
    # Named to avoid shadowing the notebook-level `params` list.
    grid_params = [T_eff.value, FeH, log_g]

    return fluxes, grid_params

# Walk every (T_eff, FeH, log_g) grid point with a progress bar; the
# Parallel call below fetches each masked spectrum and its label using threads.
grid_points = tqdm(
    product(spec_grid.T_effs, spec_grid.FeHs, spec_grid.Log_gs),
    total=len(spec_grid.T_effs) * len(spec_grid.FeHs) * len(spec_grid.Log_gs),
    desc="appending values from data cube to y_curves, params lists",
)

fluxes_and_params = Parallel(n_jobs=-1, prefer="threads")(
    delayed(get_spectra_and_params)(T_eff, FeH, log_g, mask=mask)
    for T_eff, FeH, log_g in grid_points
)

# Transpose [(flux, params), ...] into two aligned tuples.
y_curves, params = zip(*fluxes_and_params)

In [None]:
# Sanity check: the first 100 grid spectra, then their [T_eff, FeH, log_g] labels.
for preview in (y_curves[:100], params[:100]):
    print(preview)

In [None]:
# Build synthetic training examples that mimic spots: each example is a weighted
# sum of `number_of_components_to_use` randomly chosen grid spectra plus
# uniform noise.

combination_ys = []
combination_parameters = []

number_of_combined_curves_to_generate : int = int(.1*len(y_curves))

number_of_components_to_use = 1

# Each example is labelled by (weight, T_eff, FeH, log_g) for every component.
number_parameters : int = number_of_components_to_use * 4

# Half-width of the uniform noise, as a fraction of the example's peak flux.
noise_fraction : float = 1 / 20

# Seeded generator so the synthetic training set is reproducible between runs.
combination_seed : int = 0
rng = np.random.default_rng(combination_seed)

import random

for _ in tqdm(range(number_of_combined_curves_to_generate), total=number_of_combined_curves_to_generate):
    weights = rng.dirichlet(alpha=np.ones(number_of_components_to_use))  # shape (n,)

    # Overwrite the first Dirichlet weight with 0.999 and renormalise; with a
    # single component this makes the weight exactly 1.
    weights[0] = 0.999
    weights /= np.sum(weights)

    y_combination = np.zeros_like(y_curves[0])  # same length (and unit) as a grid spectrum
    parameters = []  # grows to number_of_components_to_use lots of (weight, T_eff, FeH, log_g)

    for weight in weights:
        spectrum_index = int(rng.integers(len(y_curves)))  # uniform over [0, len(y_curves) - 1]

        y_combination += weight * y_curves[spectrum_index]
        parameters.extend([weight, *params[spectrum_index]])

    # Add uniform noise scaled to this example's peak flux.
    noise_max_amplitude = np.max(y_combination.value) * noise_fraction
    y_combination += rng.uniform(low=-noise_max_amplitude, high=noise_max_amplitude, size=len(y_combination)) * y_combination[0].unit

    combination_ys.append(y_combination)
    combination_parameters.append(parameters)


In [None]:
# Show the first two synthetic spectra, followed by their (weight, T_eff, FeH, log_g) labels.
display(combination_ys[:2], combination_parameters[:2])

In [None]:
# Each ROW of X is one synthetic spectrum (its flux values); the matching ROW
# of A holds that spectrum's (weight, T_eff, FeH, log_g) labels.
X = np.stack(combination_ys)
A = np.stack(combination_parameters)

from sklearn.preprocessing import StandardScaler

training_fraction : float = 0.7
validation_fraction : float = 1.0 - training_fraction

cutoff : int = int(len(X) * training_fraction)
assert 0 < cutoff < len(X), "training/validation split leaves one side empty"
print(number_phoenix_spectra)

# Fit the scalers on the training rows only, so no validation statistics leak
# into training. Validation rows are transformed with the training mean/std.
X_scaler = StandardScaler().fit(X[:cutoff])
A_scaler = StandardScaler().fit(A[:cutoff])

X_tensor = torch.tensor(X_scaler.transform(X), dtype=torch.float32)
A_tensor = torch.tensor(A_scaler.transform(A), dtype=torch.float32)
print(len(X_tensor))

X_train, X_val = X_tensor[:cutoff], X_tensor[cutoff:]
A_train, A_val = A_tensor[:cutoff], A_tensor[cutoff:]

display(X_train)
display(A_train)

display(X_val)
display(A_val)

In [None]:


class CurveRegressor(nn.Module):
    """Fully connected regressor from a flux vector to component parameters.

    Layout: Linear(input_dim -> hidden_dim), then ``num_hidden_layers`` further
    Linear(hidden_dim -> hidden_dim) layers. Every one of these is followed by a
    ReLU (and, optionally, a Dropout). A final Linear(hidden_dim -> output_dim)
    produces the output. The defaults reproduce the original network: 512 wide,
    with five hidden-to-hidden layers.

    Parameters
    ----------
    input_dim : int
        Number of flux samples per spectrum.
    output_dim : int
        Number of predicted parameters.
    hidden_dim : int, default 512
        Width of every hidden layer.
    num_hidden_layers : int, default 5
        Number of hidden-to-hidden Linear layers.
    dropout : float, default 0.0
        If > 0, an ``nn.Dropout(dropout)`` follows every ReLU. This is what
        gives MC-Dropout (repeated forward passes in train mode) a non-zero
        spread. With 0.0, no Dropout modules are added, so the layer indices
        and ``state_dict`` keys match the original network.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=512, num_hidden_layers=5, dropout=0.0):
        super().__init__()
        layers = []
        in_features = input_dim
        # The input layer plus `num_hidden_layers` hidden layers, each followed by a ReLU.
        for _ in range(num_hidden_layers + 1):
            layers.append(nn.Linear(in_features, hidden_dim))
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_features = hidden_dim
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# One input per flux sample and one output per label column (number_parameters).
# No dropout is configured here, so repeated forward passes on the same input give
# the same output, and the MC-Dropout passes in the inference cell will all agree.
model = CurveRegressor(spectrum_num_points, number_parameters)

In [None]:
import copy

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = 1e-4)#, weight_decay=1e-4)

max_epochs : int = 1600
progress_bar = tqdm(range(max_epochs), desc="Initialising...")

# Early stopping: stop after `patience` consecutive epochs without a new best
# validation loss, then roll the model back to the best-scoring weights.
best_val_loss = np.inf
best_state = copy.deepcopy(model.state_dict())
patience_counter = 0
patience = 10

for epoch in progress_bar:
    # One full-batch gradient step on the whole training set.
    model.train()
    optimizer.zero_grad()

    preds = model(X_train)
    loss = criterion(preds, A_train)

    loss.backward()
    optimizer.step()

    # validation loss (eval mode, no gradient tracking)
    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(X_val), A_val).item()

    progress_bar.set_description(f"Epoch {epoch:02d} | train loss: {loss.item():.4f} | val loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print("[ML] : early stopping")
        break

# Keep the weights with the lowest validation loss rather than the last
# (possibly worse) epoch's weights.
model.load_state_dict(best_state)

In [None]:

import matplotlib.pyplot as plt
import pandas as pd

from spectrum_component_analyser.helper import TEFF_COLUMN, FEH_COLUMN, LOGG_COLUMN, WAVELENGTH_COLUMN, FLUX_COLUMN

INTEGRATION_INDEX_COLUMN : str = "Integration Index"

WEIGHT_COLUMN : str = "Weight"

num_samples_mc = 5 # number of Monte Carlo forward passes used for MC-Dropout to estimate uncertainty.

# Train mode keeps any Dropout layers active for MC-Dropout. NOTE(review): if the
# network has no Dropout, all MC passes are identical and the std is zero.
model.train()

index = 0

all_observational_spectra = [spectrum_to_decompose]
target_name = "K2-18"

# A prediction is a flat vector of (weight, T_eff, FeH, log_g) repeated per component.
column_names = [WEIGHT_COLUMN, TEFF_COLUMN, FEH_COLUMN, LOGG_COLUMN]
n_variables = len(column_names)

# Collect per-spectrum frames and concatenate once at the end (avoids quadratic concat).
result_frames = []

for spectrum_to_decompose in tqdm(all_observational_spectra):
    # Drop non-finite flux samples. NOTE(review): this rebinds the notebook-level
    # `mask` to a mask over the (already masked) observed spectrum; later cells
    # index with `mask`, so the rebinding is kept as-is.
    mask = np.isfinite(spectrum_to_decompose.Fluxes)
    spectrum_to_decompose = spectrum_to_decompose[mask]

    # scale fluxes with the same X_scaler that was used for the training data
    x_input_scaled = X_scaler.transform(spectrum_to_decompose.Fluxes.reshape(1, -1))
    x_tensor = torch.tensor(x_input_scaled, dtype=torch.float32)

    with torch.no_grad():
        preds = torch.stack([model(x_tensor) for _ in range(num_samples_mc)])  # shape [num_samples_mc, 1, number_parameters]

    mean_prediction = preds.mean(dim = 0)
    std_prediction = preds.std(dim = 0)

    mean_pred_physical : np.ndarray = A_scaler.inverse_transform(mean_prediction.detach().numpy().reshape(1, -1))[0]
    # A spread only needs rescaling (no mean shift) to return to physical units.
    std_pred_physical = std_prediction.detach().numpy() * A_scaler.scale_

    # Split the flat vector into one row per component.
    data = {name: mean_pred_physical[i::n_variables] for i, name in enumerate(column_names)}

    result = pd.DataFrame(data)
    result[INTEGRATION_INDEX_COLUMN] = index
    result = result.sort_values(by=WEIGHT_COLUMN, ascending=False)

    print(result)

    # Plot the highest-weight component: the first row after sorting (positional,
    # since sort_values keeps the original index labels).
    plt.scatter(x=result[TEFF_COLUMN].iloc[0], y=result[WEIGHT_COLUMN].iloc[0], marker=">", color=plt.cm.Spectral_r(index/len(all_observational_spectra)), alpha=0.8)
    index += 1

    result_frames.append(result)

results = pd.concat(result_frames, ignore_index=True) if result_frames else pd.DataFrame()

plt.title(f"HARPS observational spectra decomposition for {target_name} for all integration indices")
plt.xlabel(r"$T_\mathrm{eff}$ / K")
plt.ylabel("Weight")
plt.show()

There are three ways to turn our fitted parameters into model spectra from which we can calculate a residual:

1. Interpolate onto the PHOENIX grid, i.e. snap the fitted parameters to the nearest grid point.
2. Interpolate the PHOENIX spectra onto the fitted parameters.
3. Make the model a classifier that chooses weights from all the PHOENIX grid options (i.e. make the ML output discrete).

In [None]:
%matplotlib widget

In [None]:
# 1. interpolate onto PHOENIX grid: snap the predicted parameters to the nearest grid point

# only the 1st integration index for now
display(results)

INTEGRATION_INDEX = 0
print(all_observational_spectra[INTEGRATION_INDEX].Fluxes[mask])

# Use the first row of this integration index's results, selected by position.
# `results` is indexed 0..N-1 across all integration indices, so label 0 only
# exists for the first one.
result_subset = results[results[INTEGRATION_INDEX_COLUMN] == INTEGRATION_INDEX]
if result_subset.empty:
    raise ValueError(f"no decomposition results for integration index {INTEGRATION_INDEX}")
best_component = result_subset.iloc[0]

closest_T_eff = min(spec_grid.T_effs, key=lambda x: abs(x.value - best_component[TEFF_COLUMN]))
closest_FeH = min(spec_grid.FeHs, key=lambda x: abs(x - best_component[FEH_COLUMN]))
closest_log_g = min(spec_grid.Log_gs, key=lambda x: abs(x - best_component[LOGG_COLUMN]))

print(closest_T_eff, closest_FeH, closest_log_g)

# get the ML prediction from the PHOENIX data
# NOTE(review): `mask` must have the same length as the grid spectrum here. If
# an earlier cell rebound `mask` to cover the already-masked observation, this
# only works when the original spectrum had no non-finite fluxes -- confirm.
normalised_phoenix_spectrum = spec_grid.get_spectrum(closest_T_eff, closest_FeH, closest_log_g)[mask]

observational_spectrum = all_observational_spectra[INTEGRATION_INDEX][mask]

plt.clf()
plt.plot(observational_spectrum.Wavelengths, observational_spectrum.Fluxes, label=f"observed HARPS spectrum for target {target_name}")
plt.plot(normalised_phoenix_spectrum.Wavelengths, normalised_phoenix_spectrum.Fluxes, label="found spectrum from ML")
plt.legend()
plt.show()

In [None]:
# Fractional residual of the ML-selected grid spectrum, relative to the observation.
observed_fluxes = observational_spectrum.Fluxes
predicted_fluxes = normalised_phoenix_spectrum.Fluxes
residual = (observed_fluxes - predicted_fluxes) / observed_fluxes

plt.clf()
plt.plot(observational_spectrum.Wavelengths, residual)
plt.title("(observation - prediction) / observation")
plt.show()