# StellarRegGAE

# Stellar Parameter Regression with MLP-ResNet

We construct a multitask regressor that maps high-resolution stellar spectra to effective temperature (Teff), metallicity ([Fe/H]), and surface gravity (log g). The notebook lays out an end-to-end workflow that can be reproduced or adapted for collaborative astrophysical modeling projects.

The narrative walks through preparing the data products, engineering spectroscopy-aware augmentations, configuring a residual multilayer perceptron, and interpreting the resulting predictions with diagnostics so the full methodology can be shared confidently in a public repository.

In [None]:
# ===============================
# Standard library
# ===============================
import random
from pathlib import Path
from typing import Dict, Optional, Tuple, Literal  # type hints

# ===============================
# Third-party libraries
# ===============================
import numpy as np  # numerical computing
import matplotlib.pyplot as plt  # plotting

from scipy.stats import gaussian_kde, norm, multivariate_normal  # statistics & densities

from sklearn.metrics import mean_absolute_error, r2_score  # evaluation metrics
from sklearn.preprocessing import StandardScaler, RobustScaler  # feature scaling
from sklearn.decomposition import PCA  # dimensionality reduction

import tensorflow as tf  # deep learning backend
from keras import layers, Model, regularizers  # Keras layers, base Model, regularization
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau  # training utilities

import keras_tuner as kt  # hyperparameter tuning

import joblib  # model/pipeline persistence


## Data Loading
The experiments rely on pre-generated NumPy `.npy` tensors stored inside the `datasets/` directory.

- `X_raw_*` arrays hold continuum-normalized flux values sampled on a common wavelength grid; each row represents a single stellar spectrum and each column corresponds to a wavelength bin.
- `y_raw_*` arrays store the associated stellar labels stacked as `[Teff, [Fe/H], log g]` in their native physical units (Kelvin and dex).

Place the training and test splits in `datasets/` before running the notebook. If you generate alternative calibrations or splits, update the filenames in the next cell to point to the new assets.

In [None]:
import os

# Folder holding the .npy splits. Defaults to ./datasets (relative to the working
# directory); set the STELLAR_DATA_DIR environment variable to point elsewhere
# without editing the notebook.
DATA_DIR = Path(os.environ.get('STELLAR_DATA_DIR', 'datasets'))

if not DATA_DIR.exists():
    raise FileNotFoundError(
        f"Dataset directory '{DATA_DIR}' was not found. Place the required .npy files in this "
        "folder or set the STELLAR_DATA_DIR environment variable."
    )

X_raw_train = np.load(DATA_DIR / 'X_raw_resampled_filtered_train.npy')
y_raw_train = np.load(DATA_DIR / 'y_raw_loaded_filtered_train.npy')

X_raw_test = np.load(DATA_DIR / 'X_raw_resampled_filtered_test.npy')
y_raw_test = np.load(DATA_DIR / 'y_raw_loaded_filtered_test.npy')


In [None]:
# Report array shapes, then the min/max of every label column for each split.
for _name, _arr in (('X_raw_train', X_raw_train), ('y_raw_train', y_raw_train),
                    ('X_raw_test', X_raw_test), ('y_raw_test', y_raw_test)):
    print(f'{_name} shape: {_arr.shape}')

for _split_name, _labels in (('y_raw_train', y_raw_train), ('y_raw_test', y_raw_test)):
    for _col, _param in enumerate(('Teff', '[Fe/H]', 'log g')):
        _col_vals = _labels[:, _col]
        print(f'{_param} range in {_split_name}: {_col_vals.min():.2f} - {_col_vals.max():.2f}')


## Preprocessing

In [None]:
def filter_stellar_sample(X, y, plot_hr_diagram=True, return_indices=False):
    """
    Apply broad physical sanity checks and IQR-based outlier removal to labels and fluxes.

    Keeps stars from different evolutionary stages while discarding obvious outliers.
    Rows with non-finite flux or label values are removed *first*, before any IQR
    statistic is computed: ``np.percentile`` propagates NaN, so a single bad pixel
    left in the pool would turn the flux bounds into NaN and reject every spectrum.

    Parameters
    ----------
    X : ndarray, shape (n_stars, n_pixels)
        Spectra, one row per star (rows aligned with ``y``).
    y : ndarray, shape (n_stars, >=3)
        Labels; columns 0, 1, 2 are read as Teff, [Fe/H] and log g.
    plot_hr_diagram : bool
        If True, draw the before/after Teff-log g diagram via ``plot_hr_filtering``.
    return_indices : bool
        If True, also return the indices (into ``X``/``y``) of the retained rows.

    Returns
    -------
    X_filtered, y_filtered : ndarray
        The retained rows.
    indices : ndarray of int
        Only when ``return_indices`` is True.
    """
    teff = y[:, 0]
    feh = y[:, 1]
    logg = y[:, 2]

    print("Applying physical sanity filters...")
    print(f"Initial sample size: {len(y)}")

    total = len(y)
    mask = np.ones(total, dtype=bool)

    # Non-finite rows go first so they cannot poison the percentile estimates below.
    finite_mask = np.isfinite(X).all(axis=1) & np.isfinite(y).all(axis=1)
    mask &= finite_mask
    print(f"After removing rows with non-finite values: {np.sum(mask)}")

    temp_mask = (teff >= 2500) & (teff <= 10000)
    mask &= temp_mask
    print(f"After temperature sanity check (2500 <= Teff <= 10000 K): {np.sum(mask)}")

    logg_mask = (logg >= 0.0) & (logg <= 6.0)
    mask &= logg_mask
    print(f"After surface gravity sanity check (0 <= log g <= 6): {np.sum(mask)}")

    metallicity_mask = (feh >= -5.0) & (feh <= 1.0)
    mask &= metallicity_mask
    print(f"After metallicity sanity check (-5.0 <= [Fe/H] <= +1.0): {np.sum(mask)}")

    def iqr_outlier_mask(values, base_mask, multiplier=1.5):
        """Keep values inside [Q1 - m*IQR, Q3 + m*IQR], quartiles taken over values[base_mask].

        Keeps everything when fewer than 2 reference values exist or the IQR is zero.
        """
        subset = values[base_mask]
        if subset.size < 2:
            return np.ones_like(values, dtype=bool)
        q1 = np.percentile(subset, 25)
        q3 = np.percentile(subset, 75)
        iqr = q3 - q1
        if iqr == 0:
            return np.ones_like(values, dtype=bool)
        lower = q1 - multiplier * iqr
        upper = q3 + multiplier * iqr
        return (values >= lower) & (values <= upper)

    # All three label masks use the same reference set (rows passing the checks above).
    teff_outlier_mask = iqr_outlier_mask(teff, mask)
    feh_outlier_mask = iqr_outlier_mask(feh, mask)
    logg_outlier_mask = iqr_outlier_mask(logg, mask)

    mask &= teff_outlier_mask
    print(f"After IQR temperature outlier removal: {np.sum(mask)}")
    mask &= feh_outlier_mask
    print(f"After IQR metallicity outlier removal: {np.sum(mask)}")
    mask &= logg_outlier_mask
    print(f"After IQR gravity outlier removal: {np.sum(mask)}")

    def flux_iqr_mask(spectra, base_mask, multiplier=1.5, tol_frac=0.01):
        """Keep a spectrum if >= (1 - tol_frac) of its pixels lie inside the pooled IQR bounds.

        The bounds are computed over every pixel of spectra[base_mask] at once.
        """
        n_pixels = spectra.shape[1]
        subset = spectra[base_mask]
        if subset.size == 0 or n_pixels == 0:
            return np.ones(spectra.shape[0], dtype=bool)
        q1 = np.percentile(subset, 25)
        q3 = np.percentile(subset, 75)
        iqr = q3 - q1
        if iqr == 0:
            return np.ones(spectra.shape[0], dtype=bool)
        lower, upper = q1 - multiplier * iqr, q3 + multiplier * iqr
        within = (spectra >= lower) & (spectra <= upper)
        return (within.sum(axis=1) / n_pixels) >= (1 - tol_frac)

    flux_outlier_mask = flux_iqr_mask(X, mask)
    mask &= flux_outlier_mask
    print(f"After IQR flux outlier removal: {np.sum(mask)}")

    retained = int(np.sum(mask))
    retention_rate = (retained / total * 100.0) if total else 0.0
    print(f"Final sample after quality filtering: {retained} stars")
    print(f"Retention rate: {retention_rate:.1f}%")

    if plot_hr_diagram:
        plot_hr_filtering(teff, logg, mask)

    X_filtered = X[mask]
    y_filtered = y[mask]

    if return_indices:
        indices = np.where(mask)[0]
        return X_filtered, y_filtered, indices

    return X_filtered, y_filtered


def plot_hr_filtering(teff, logg, keep_mask):
    """
    Draw two side-by-side Teff vs. log g scatter panels.

    Left: every star in gray. Right: stars rejected by ``keep_mask`` in red and the
    retained ones in blue. Both panels have inverted x and y axes.
    """
    fig = plt.figure(figsize=(12, 8))
    dropped = ~keep_mask

    left_ax = fig.add_subplot(1, 2, 1)
    left_ax.scatter(teff, logg, alpha=0.5, s=1, c='gray', label='All stars')

    right_ax = fig.add_subplot(1, 2, 2)
    right_ax.scatter(teff[dropped], logg[dropped], alpha=0.3, s=1, c='red', label='Removed')
    right_ax.scatter(teff[keep_mask], logg[keep_mask], alpha=0.7, s=2, c='blue', label='Retained')

    panels = ((left_ax, 'Original Sample'), (right_ax, 'After Quality Filtering'))
    for panel, heading in panels:
        panel.set_xlabel('Effective Temperature (K)', fontsize=20)
        panel.set_ylabel('log g (surface gravity)', fontsize=20)
        panel.set_title(heading, fontsize=24)
        panel.invert_xaxis()  # hotter stars toward the left
        panel.invert_yaxis()  # larger log g toward the bottom
        panel.legend()
        panel.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()

def improved_preprocessing_pipeline(X_train_aug, X_test,
                                    y_train_aug, y_test,
                                    use_robust_scaling=True,
                                    apply_physical_filters_train=True,
                                    apply_physical_filters_test=False):
    """
    Preprocess spectra and labels without PCA (the full normalized spectra are kept).

    Steps
    -----
    1. Optional physical/IQR filtering via ``filter_stellar_sample`` (train and/or test).
    2. Per-spectrum z-score normalization: each row has its own mean subtracted and is
       divided by its own standard deviation (+1e-8), so no statistics cross splits.
    3. Removal of rows whose normalized flux or labels are non-finite.
    4. Per-target scaling (RobustScaler or StandardScaler), one scaler per label
       column, fit on the training labels only and applied to both splits.

    Parameters
    ----------
    X_train_aug, X_test : ndarray, shape (n_samples, n_pixels)
    y_train_aug, y_test : ndarray, shape (n_samples, n_targets)
    use_robust_scaling : bool
        RobustScaler when True, StandardScaler otherwise.
    apply_physical_filters_train, apply_physical_filters_test : bool
        Whether to run ``filter_stellar_sample`` on each split.

    Returns
    -------
    processed_data : dict
        'X_train', 'X_test' (normalized spectra), 'y_train', 'y_test' (scaled labels),
        'y_train_original', 'y_test_original' (unscaled labels after filtering/cleaning).
    scalers : dict
        {'target_scalers': list with one fitted scaler per label column}.

    Raises
    ------
    ValueError
        If no training samples remain to fit the target scalers.
    """

    print("=" * 50)
    print("IMPROVED PREPROCESSING PIPELINE (NO PCA)")
    print("=" * 50)

    # 1. Apply physical sanity filters if requested
    if apply_physical_filters_train:
        print("\n1. Applying physical sanity filters to training...")
        X_train_aug, y_train_aug = filter_stellar_sample(
            X_train_aug, y_train_aug, plot_hr_diagram=True
        )
    if apply_physical_filters_test:
        print("\n1. Applying physical sanity filters to test...")
        X_test, y_test = filter_stellar_sample(
            X_test, y_test, plot_hr_diagram=True
        )

    # 2. Spectral normalization (per spectrum)
    print("\n2. Normalizing spectra...")

    def normalize_spectra(X, eps=1e-8):
        """Z-score each row with its own mean/std; eps avoids division by zero for flat rows."""
        mu = np.mean(X, axis=1, keepdims=True)
        std = np.std(X, axis=1, keepdims=True)
        return (X - mu) / (std + eps)

    X_train_norm = normalize_spectra(X_train_aug)
    X_test_norm = normalize_spectra(X_test)

    # 3. Check for and handle NaN/inf values
    print("\n3. Checking for invalid values...")

    def clean_data(X, y, name):
        """Drop rows where either the spectrum or the labels contain NaN/inf."""
        X_mask = np.isfinite(X).all(axis=1)
        y_mask = np.isfinite(y).all(axis=1)
        combined_mask = X_mask & y_mask
        if not combined_mask.all():
            print(f"   {name}: Removing {np.sum(~combined_mask)} samples with invalid values")
            X = X[combined_mask]
            y = y[combined_mask]
        else:
            print(f"   {name}: No invalid values found")
        return X, y

    X_train_norm, y_train_aug = clean_data(X_train_norm, y_train_aug, "Training")
    X_test_norm, y_test = clean_data(X_test_norm, y_test, "Test")

    if len(y_train_aug) == 0:
        raise ValueError(
            "No training samples remain after filtering/cleaning; cannot fit target scalers."
        )

    # 4. Target scaling (physics-aware)
    print("\n4. Scaling target parameters...")
    n_targets = y_train_aug.shape[1]
    if use_robust_scaling:
        scalers_y = [RobustScaler() for _ in range(n_targets)]
        print("   Using RobustScaler (better for outliers)")
    else:
        scalers_y = [StandardScaler() for _ in range(n_targets)]
        print("   Using StandardScaler")

    # Known display names for the first three columns; generic names for any extra column.
    default_names = ['Teff (K)', '[Fe/H] (dex)', 'log g (dex)']
    param_names = [
        default_names[i] if i < len(default_names) else f'target {i}'
        for i in range(n_targets)
    ]

    y_train_scaled = np.zeros_like(y_train_aug, dtype=float)
    y_test_scaled = np.zeros_like(y_test, dtype=float)

    for i in range(n_targets):
        # Fit on training labels only so no test statistics leak into the scaling.
        scalers_y[i].fit(y_train_aug[:, i].reshape(-1, 1))
        y_train_scaled[:, i] = scalers_y[i].transform(y_train_aug[:, i].reshape(-1, 1)).flatten()
        y_test_scaled[:, i] = scalers_y[i].transform(y_test[:, i].reshape(-1, 1)).flatten()

        print(f"   {param_names[i]}:")
        print(f"     Original range: [{y_train_aug[:, i].min():.2f}, {y_train_aug[:, i].max():.2f}]")
        print(f"     Scaled range: [{y_train_scaled[:, i].min():.2f}, {y_train_scaled[:, i].max():.2f}]")

    # 5. Final data summary
    print("\n5. Final data shapes:")
    print(f"   X_train: {X_train_norm.shape}")
    print(f"   X_test:  {X_test_norm.shape}")
    print(f"   y_train: {y_train_scaled.shape}")
    print(f"   y_test:  {y_test_scaled.shape}")

    processed_data = {
        'X_train': X_train_norm,
        'X_test': X_test_norm,
        'y_train': y_train_scaled,
        'y_test': y_test_scaled,
        'y_train_original': y_train_aug,
        'y_test_original': y_test
    }

    scalers = {'target_scalers': scalers_y}

    return processed_data, scalers


def run_complete_preprocessing(X_train_aug, X_test, y_train_aug, y_test):
    """
    Run the default no-PCA preprocessing and plot original vs. scaled training targets.

    Uses robust target scaling and applies the physical filters to the training split
    only. Returns the ``(processed_data, scalers)`` pair produced by
    ``improved_preprocessing_pipeline``.
    """
    pipeline_options = dict(
        use_robust_scaling=True,
        apply_physical_filters_train=True,
        apply_physical_filters_test=False,
    )
    processed_data, scalers = improved_preprocessing_pipeline(
        X_train_aug, X_test, y_train_aug, y_test, **pipeline_options
    )

    unscaled_targets = processed_data['y_train_original']
    scaled_targets = processed_data['y_train']
    plot_parameter_distributions(unscaled_targets, scaled_targets)

    rule = "=" * 50
    print("\n" + rule)
    print("PREPROCESSING COMPLETE! (NO PCA)")
    print(rule)
    return processed_data, scalers


def plot_parameter_distributions(y_original, y_scaled, param_names=('Teff (K)', '[Fe/H] (dex)', 'log g (dex)')):
    """
    Plot per-target histograms before (top row) and after (bottom row) target scaling.

    Parameters
    ----------
    y_original : ndarray, shape (n_samples, n_params)
        Labels in physical units.
    y_scaled : ndarray, shape (n_samples, n_params)
        The same labels after target scaling.
    param_names : sequence of str
        One display name per column, optionally ending in a unit in parentheses,
        e.g. 'log g (dex)'. Its length sets the number of columns plotted.
    """
    n_params = len(param_names)
    fig, axes = plt.subplots(2, n_params, figsize=(5 * n_params, 8), sharex=False, squeeze=False)

    for i, full_name in enumerate(param_names):
        # Strip the " (unit)" suffix: 'log g (dex)' -> 'log g' (split()[0] would give 'log').
        short_name = full_name.split(' (')[0]

        # Top row: labels in physical units (density so the y-label is accurate).
        ax = axes[0, i]
        ax.hist(y_original[:, i], bins=50, density=True, alpha=0.8, edgecolor='black')
        ax.set_title(f"{full_name} - Original")
        ax.set_xlabel(full_name)
        ax.set_ylabel('Density')
        ax.grid(True, alpha=0.3)

        # Bottom row: scaled labels on their own x-range.
        scaled = y_scaled[:, i]
        ax = axes[1, i]
        ax.hist(scaled, bins=50, density=True, alpha=0.8, edgecolor='black')
        ax.set_title(f"Scaled {short_name} - After Target Scaling")
        ax.set_xlabel(f"Scaled {short_name}")
        ax.set_ylabel('Density')
        if scaled.size and scaled.min() < scaled.max():
            ax.set_xlim(scaled.min(), scaled.max())
        ax.grid(True, alpha=0.3)

    fig.suptitle('Comparison of Parameter Distributions', fontsize=18)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()



def inverse_transform_predictions(y_scaled, scalers_y):
    """
    Map scaled targets back to their original units, column by column.

    Column ``j`` of ``y_scaled`` is reshaped to (n, 1) and passed to
    ``scalers_y[j].inverse_transform``. The result has the same shape and dtype as
    ``y_scaled``.
    """
    restored = np.empty_like(y_scaled)
    for col, column in enumerate(y_scaled.T):
        restored[:, col] = scalers_y[col].inverse_transform(column.reshape(-1, 1)).ravel()
    return restored

In [None]:
processed_data, scalers = run_complete_preprocessing(X_raw_train, X_raw_test, y_raw_train, y_raw_test)

# Extract the final preprocessed data
X_train_final = processed_data['X_train']
X_test_final = processed_data['X_test']
y_train_final = processed_data['y_train']
y_test_final = processed_data['y_test']

# Hold out the last N_VAL shuffled training samples for validation.
# NOTE: the target scalers were fit above on the full training split, so they also
# saw the labels that end up in validation.
N_VAL = 2000
SPLIT_SEED = 42  # fixed seed so re-running the notebook reproduces the same split

n_available = X_train_final.shape[0]
if n_available <= N_VAL:
    raise ValueError(
        f"Need more than {N_VAL} training samples to carve out a validation set; got {n_available}."
    )

# Shuffle the training data (seeded)
perm = np.random.default_rng(SPLIT_SEED).permutation(n_available)
X_train_final = X_train_final[perm]
y_train_final = y_train_final[perm]

X_val_final = X_train_final[-N_VAL:]  # Last N_VAL samples for validation
y_val_final = y_train_final[-N_VAL:]

X_train_final = X_train_final[:-N_VAL]  # Remaining samples for training
y_train_final = y_train_final[:-N_VAL]

print(f"Final training set shape: X={X_train_final.shape}, y={y_train_final.shape}")
print(f"Final validation set shape: X={X_val_final.shape}, y={y_val_final.shape}")
print(f"Final test set shape: X={X_test_final.shape}, y={y_test_final.shape}")

In [None]:
# Cast every split (spectra and targets) to float32.
X_train_final, X_val_final, X_test_final = (
    arr.astype('float32') for arr in (X_train_final, X_val_final, X_test_final)
)
y_train_final, y_val_final, y_test_final = (
    arr.astype('float32') for arr in (y_train_final, y_val_final, y_test_final)
)

In [None]:
# Restrict every split to the Teff, [Fe/H], log g box covered by BOTH the training and
# validation labels (scaled units). The test split is filtered to this box too, but
# does not influence its bounds.
_train_lo, _val_lo = y_train_final[:, :3].min(axis=0), y_val_final[:, :3].min(axis=0)
_train_hi, _val_hi = y_train_final[:, :3].max(axis=0), y_val_final[:, :3].max(axis=0)
_lo = np.array([max(a, b) for a, b in zip(_train_lo, _val_lo)])
_hi = np.array([min(a, b) for a, b in zip(_train_hi, _val_hi)])
teff_min, feH_min, logg_min = _lo
teff_max, feH_max, logg_max = _hi

# All masks are computed before any split is filtered.
train_mask = ((y_train_final[:, :3] >= _lo) & (y_train_final[:, :3] <= _hi)).all(axis=1)
val_mask = ((y_val_final[:, :3] >= _lo) & (y_val_final[:, :3] <= _hi)).all(axis=1)
test_mask = ((y_test_final[:, :3] >= _lo) & (y_test_final[:, :3] <= _hi)).all(axis=1)

X_train_final, y_train_final = X_train_final[train_mask], y_train_final[train_mask]
X_val_final, y_val_final = X_val_final[val_mask], y_val_final[val_mask]
X_test_final, y_test_final = X_test_final[test_mask], y_test_final[test_mask]

In [None]:
print("After applying common limits:")
for _split, _X, _y in (("training", X_train_final, y_train_final),
                       ("validation", X_val_final, y_val_final),
                       ("test", X_test_final, y_test_final)):
    print(f"Final {_split} set shape: X={_X.shape}, y={_y.shape}")

In [None]:
# Overlay normalized histograms of each scaled label for train / validation / test.
plt.figure(figsize=(18, 5))
_split_styles = (('Train', 'blue', y_train_final),
                 ('Validation', 'orange', y_val_final),
                 ('Test', 'green', y_test_final))
for _col, _param in enumerate(('Teff', '[Fe/H]', 'log g')):
    plt.subplot(1, 3, _col + 1)
    for _label, _color, _labels in _split_styles:
        plt.hist(_labels[:, _col], bins=50, alpha=0.5, label=_label, color=_color, density=True)
    plt.xlabel(f'Scaled {_param}', fontsize=14)
    plt.ylabel('Density', fontsize=14)
    plt.title(f'Distribution of Scaled {_param}', fontsize=16)
    plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Convert each split back to physical units and print per-label min/max.
_scalers_y = scalers['target_scalers']
_row_tags = (("Teff: ", 0), ("FeH:  ", 1), ("logg: ", 2))

y_train_original = inverse_transform_predictions(y_train_final, _scalers_y)
print("Limits in training set:")
for _tag, _c in _row_tags:
    print(f"{_tag}{y_train_original[:, _c].min()} - {y_train_original[:, _c].max()}")

print("Limits in validation set:")
y_val_original = inverse_transform_predictions(y_val_final, _scalers_y)
for _tag, _c in _row_tags:
    print(f"{_tag}{y_val_original[:, _c].min()} - {y_val_original[:, _c].max()}")

print("Limits in test set:")
y_test_original = inverse_transform_predictions(y_test_final, _scalers_y)
for _tag, _c in _row_tags:
    print(f"{_tag}{y_test_original[:, _c].min()} - {y_test_original[:, _c].max()}")

In [None]:
# Remove duplicates from the test set while leaving them in the training data
def remove_duplicates_from_test(y_train, y_val, X_test, y_test):
    """
    Drop test rows whose label vector also appears exactly in the training or validation labels.

    Matching compares complete label rows for exact equality, so near-duplicates are
    not caught. Training and validation arrays are left untouched.

    Returns
    -------
    X_test_kept, y_test_kept : ndarray
        The rows of ``X_test`` / ``y_test`` whose labels were not seen in train/val.
    """
    train_val_labels = set(map(tuple, np.vstack((y_train, y_val))))

    # dtype=bool keeps ``~mask`` and boolean indexing valid when y_test is empty
    # (np.array([]) would otherwise be float64 and ``~`` would raise TypeError).
    mask = np.array([tuple(label) not in train_val_labels for label in y_test], dtype=bool)

    num_duplicates = np.sum(~mask)
    if num_duplicates > 0:
        print(f"Removing {num_duplicates} duplicate entries from the test set.")
    else:
        print("No duplicate entries to remove from the test set.")

    return X_test[mask], y_test[mask]

X_test_final, y_test_final = remove_duplicates_from_test(
    y_train=y_train_final,
    y_val=y_val_final,
    X_test=X_test_final,
    y_test=y_test_final,
)

In [None]:
_dedup_msg = "Test set shape after removing duplicates: X={}, y={}".format(
    X_test_final.shape, y_test_final.shape
)
print(_dedup_msg)

In [None]:
def check_duplicates_on_train_and_val(y_train, y_val):
    """Print how many distinct label rows occur (exactly) in both ``y_train`` and ``y_val``."""
    shared = {tuple(row) for row in y_train} & {tuple(row) for row in y_val}
    if not shared:
        print("No duplicate entries found between training and validation sets.")
        return
    print(f"Found {len(shared)} duplicate entries between training and validation sets.")
        
check_duplicates_on_train_and_val(y_train=y_train_final, y_val=y_val_final)

In [None]:
# === Plot normalized spectra ===

# Plot up to six normalized training spectra for inspection on a single 2 x 3 grid
# (three plots per row across two rows); unused panels are hidden.
n_examples = min(6, X_train_final.shape[0])
if n_examples == 0:
    print("No training spectra available to plot.")
else:
    fig, axes = plt.subplots(2, 3, figsize=(18, 8), sharex=True, squeeze=False)
    fig.suptitle('Examples of Normalized Spectra (Training)', fontsize=24)
    for k, ax in enumerate(axes.ravel()):
        if k >= n_examples:
            ax.set_axis_off()
            continue
        ax.plot(X_train_final[k], label=f'Spectrum {k + 1}')
        ax.set_xlabel('Pixel (resampled wavelength)', fontsize=14)
        ax.set_ylabel('Normalized Flux', fontsize=14)
        ax.legend()
        ax.grid(True)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

In [None]:

# --- Primitive transforms -----------------------------------------------------
def add_noise(spectrum: np.ndarray, noise_level: float = 0.05) -> np.ndarray:
    """Return ``spectrum`` plus zero-mean Gaussian noise.

    The noise std is ``noise_level`` times the NaN-ignoring std of ``spectrum``.
    Draws from NumPy's global RNG.
    """
    sigma = float(np.nanstd(spectrum)) * noise_level
    return spectrum + np.random.normal(0.0, sigma, spectrum.shape)

def shift_spectrum(spectrum: np.ndarray, max_shift: int = 2) -> np.ndarray:
    """Circularly roll ``spectrum`` by a random whole number of pixels.

    The offset is drawn uniformly from [-max_shift, max_shift] using NumPy's
    global RNG. Pixels pushed off one end wrap around to the other.
    """
    low, high = -max_shift, max_shift + 1
    offset = np.random.randint(low, high)
    return np.roll(spectrum, offset)

def ripple_spectrum(
    spectrum: np.ndarray,
    amplitude: float = 0.02,
    freq_range: Tuple[float, float] = (1.0, 5.0)
) -> np.ndarray:
    """Multiply ``spectrum`` by a sinusoidal ripple ``1 + amplitude * sin(...)``.

    A random number of cycles (uniform in ``freq_range``) spans the whole spectrum
    with a random phase in [0, 2*pi). Both values come from NumPy's global RNG,
    cycles first, then phase.
    """
    n_pix = len(spectrum)
    cycles = np.random.uniform(*freq_range)
    offset = np.random.uniform(0.0, 2.0 * np.pi)
    grid = np.linspace(0.0, cycles * 2.0 * np.pi, n_pix)
    modulation = 1.0 + amplitude * np.sin(grid + offset)
    return spectrum * modulation

# --- Sampler weights for BALANCED augmentation --------------------------------
def _balanced_sampling_weights(
    y_train: np.ndarray,
    eps: float = 1e-8,
    target: Literal["mvn", "self-kde"] = "mvn"
) -> np.ndarray:
    """
    Compute per-sample weights ~ p_target(y) / p_current(y).
    target='mvn' uses N(mean, cov); target='self-kde' flattens density via 1/p_current.
    """
    assert y_train.ndim == 2, "y_train must be (n_samples, n_targets)"
    kde = gaussian_kde(y_train.T)
    p_current = kde(y_train.T)  # shape: (n_samples,)

    if target == "mvn":
        mean_target = np.mean(y_train, axis=0)
        cov_target = np.cov(y_train.T)
        mvn = multivariate_normal(mean=mean_target, cov=cov_target, allow_singular=True)
        p_target = mvn.pdf(y_train)
        w = p_target / (p_current + eps)
    else:
        # Equalize by down-weighting dense regions only
        w = 1.0 / (p_current + eps)

    w = np.clip(w, a_min=0.0, a_max=np.finfo(np.float64).max)
    w_sum = w.sum()
    if not np.isfinite(w_sum) or w_sum <= 0.0:
        # Fall back to uniform
        w = np.ones_like(w) / len(w)
    else:
        w /= w_sum
    return w

# --- Main API -----------------------------------------------------------------
def augment_data(
    X_train: np.ndarray,
    y_train: np.ndarray,
    augmentation_factor: float = 3.0,
    mode: Literal["balanced", "uniform", "none"] = "balanced",
    balanced_target: Literal["mvn", "self-kde"] = "mvn",
    transform_p: Optional[Dict[str, float]] = None,
    transform_cfg: Optional[Dict[str, dict]] = None,
    rng_seed: Optional[int] = 42,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Configurable augmentation for 1D spectra.

    Parameters
    ----------
    X_train : (N, L) float
        Spectra matrix (N samples, L wavelengths).
    y_train : (N, D) float
        Targets matrix (D parameters).
    augmentation_factor : float
        Total size multiplier (e.g., 3.0 => produce ~2N synthetic + keep N original).
        If <= 1.0 or mode=='none', returns inputs unchanged.
    mode : {'balanced','uniform','none'}
        'balanced' -> weighted resampling using KDE (and MVN target or self-KDE equalization).
        'uniform'  -> uniform random resampling.
        'none'     -> no augmentation.
    balanced_target : {'mvn','self-kde'}
        For 'balanced' mode: use MVN target or self-KDE flattening.
    transform_p : dict
        Per-transform application probabilities, e.g.:
        {'noise':0.5, 'shift':0.5, 'ripple':0.5}
    transform_cfg : dict
        Per-transform hyperparameters, e.g.:
        {'noise':{'noise_level':0.03}, 'shift':{'max_shift':2}, 'ripple':{'amplitude':0.01,'freq_range':(1,5)}}
    rng_seed : int or None
        When not None, reseeds the *global* ``numpy.random`` and ``random`` generators
        (a process-wide side effect) for reproducibility.

    Returns
    -------
    X_out, y_out : arrays
        Augmented (or original) datasets; synthetic rows are copies of resampled
        originals with the enabled transforms applied, and all rows are shuffled.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values, or ``balanced_target`` is
        invalid while ``mode == 'balanced'`` (previously such typos silently skipped
        augmentation or switched targets).
    """
    valid_modes = ("balanced", "uniform", "none")
    if mode not in valid_modes:
        raise ValueError(f"mode must be one of {valid_modes}, got {mode!r}")
    valid_targets = ("mvn", "self-kde")
    if mode == "balanced" and balanced_target not in valid_targets:
        raise ValueError(f"balanced_target must be one of {valid_targets}, got {balanced_target!r}")

    if rng_seed is not None:
        np.random.seed(rng_seed)
        random.seed(rng_seed)

    N = X_train.shape[0]
    if augmentation_factor <= 1.0 or mode == "none":
        return X_train, y_train

    # Defaults
    if transform_p is None:
        transform_p = {'noise': 0.5, 'shift': 0.5, 'ripple': 0.5}
    if transform_cfg is None:
        transform_cfg = {
            'noise':  {'noise_level': 0.03},
            'shift':  {'max_shift': 2},
            'ripple': {'amplitude': 0.01, 'freq_range': (1.0, 5.0)},
        }

    # Number of synthetic samples to add
    n_new = int((augmentation_factor - 1.0) * N)
    if n_new <= 0:
        return X_train, y_train

    # Sampling indices (with replacement)
    if mode == "balanced":
        weights = _balanced_sampling_weights(y_train, target=balanced_target)
        chosen_idx = np.random.choice(np.arange(N), size=n_new, p=weights)
    else:  # mode == "uniform" (validated above)
        chosen_idx = np.random.randint(0, N, size=n_new)

    # Generate augmented spectra from copies so the originals are never mutated.
    aug_X = [X_train[i].copy() for i in chosen_idx]
    aug_y = [y_train[i].copy() for i in chosen_idx]

    for j in range(n_new):
        s = aug_X[j]

        # Each transform is applied independently with its own probability.
        if random.random() < float(transform_p.get('noise', 0.0)):
            s = add_noise(s, **transform_cfg.get('noise', {}))

        if random.random() < float(transform_p.get('shift', 0.0)):
            s = shift_spectrum(s, **transform_cfg.get('shift', {}))

        if random.random() < float(transform_p.get('ripple', 0.0)):
            s = ripple_spectrum(s, **transform_cfg.get('ripple', {}))

        aug_X[j] = s

    # Concatenate and shuffle (preserve dtype)
    X_out = np.concatenate([X_train, np.asarray(aug_X, dtype=X_train.dtype)], axis=0)
    y_out = np.concatenate([y_train, np.asarray(aug_y, dtype=y_train.dtype)], axis=0)

    # Global shuffle with the same (seeded) RNG stream for determinism
    perm = np.random.permutation(X_out.shape[0])
    X_out = X_out[perm]
    y_out = y_out[perm]
    return X_out, y_out

# --- Convenience wrappers (optional) ------------------------------------------
def augment_uniform(
    X_train: np.ndarray,
    y_train: np.ndarray,
    factor: float = 3.0,
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray]:
    """Thin wrapper: ``augment_data(..., mode="uniform")`` with ``factor`` as the size multiplier.

    Remaining keywords (transform_p, transform_cfg, rng_seed, ...) pass through unchanged.
    """
    return augment_data(
        X_train,
        y_train,
        augmentation_factor=factor,
        mode="uniform",
        **kwargs,
    )

def augment_balanced(
    X_train: np.ndarray,
    y_train: np.ndarray,
    factor: float = 3.0,
    balanced_target: Literal["mvn", "self-kde"] = "mvn",
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray]:
    """Thin wrapper: ``augment_data(..., mode="balanced")`` with the given density target.

    ``factor`` becomes ``augmentation_factor``; remaining keywords pass through unchanged.
    """
    return augment_data(
        X_train,
        y_train,
        augmentation_factor=factor,
        mode="balanced",
        balanced_target=balanced_target,
        **kwargs,
    )

In [None]:

def plot_parameter_distributions(
    y_original: np.ndarray,
    y_augmented: np.ndarray,
    param_names = ('Teff', '[Fe/H]', 'log g'),
    bins: int = 50,
    augmented_label: str = 'After Balanced Augmentation'
):
    """
    Compare label distributions before vs. after augmentation.

    - Uses common bin edges per parameter for fair comparison
    - Plots densities (normalized histograms)
    - Ignores NaNs safely
    - Works for any number of parameters, including a single one
    - Returns the figure (it does not call ``plt.show()``) so you can save it

    NOTE: this cell redefines ``plot_parameter_distributions``; once it has run, any
    later call to that name (e.g. from ``run_complete_preprocessing``) uses this version.

    Raises
    ------
    ValueError
        If the inputs are not 2-D, have different column counts, or ``param_names``
        does not match the number of columns.
    """
    if y_original.ndim != 2 or y_augmented.ndim != 2:
        raise ValueError("y_original and y_augmented must be 2D arrays of shape (N, D).")
    if y_original.shape[1] != y_augmented.shape[1]:
        raise ValueError("Both arrays must have the same number of columns (parameters).")
    if len(param_names) != y_original.shape[1]:
        raise ValueError("param_names length must match number of parameters (columns).")

    D = y_original.shape[1]
    # squeeze=False keeps ``axes`` 2-D even when D == 1, so axes[r, i] always works.
    fig, axes = plt.subplots(2, D, figsize=(5*D, 8), sharex='col', squeeze=False)
    fig.suptitle('Comparison of Parameter Distributions', fontsize=18)

    for i in range(D):
        yo = y_original[:, i]
        ya = y_augmented[:, i]
        yo = yo[np.isfinite(yo)]
        ya = ya[np.isfinite(ya)]

        # Handle degenerate cases
        if yo.size == 0 and ya.size == 0:
            for r in (0, 1):
                axes[r, i].text(0.5, 0.5, 'No finite data', ha='center', va='center')
                axes[r, i].set_axis_off()
            continue

        data_min = np.nanmin([yo.min() if yo.size else np.inf,
                              ya.min() if ya.size else np.inf])
        data_max = np.nanmax([yo.max() if yo.size else -np.inf,
                              ya.max() if ya.size else -np.inf])
        if not np.isfinite(data_min) or not np.isfinite(data_max) or data_min == data_max:
            for r in (0, 1):
                axes[r, i].text(0.5, 0.5, 'Insufficient spread', ha='center', va='center')
                axes[r, i].set_axis_off()
            continue

        # Shared edges so both rows are binned identically.
        edges = np.linspace(data_min, data_max, bins+1)

        # Row 1: original
        axes[0, i].hist(yo, bins=edges, density=True, alpha=0.75, edgecolor='black')
        axes[0, i].set_title(f"{param_names[i]} - Original", fontsize=12)
        axes[0, i].set_ylabel('Density', fontsize=11)
        axes[0, i].grid(True, alpha=0.25)

        # Row 2: augmented
        axes[1, i].hist(ya, bins=edges, density=True, alpha=0.75, edgecolor='black')
        axes[1, i].set_title(f"{param_names[i]} - {augmented_label}", fontsize=12)
        axes[1, i].set_xlabel(param_names[i], fontsize=11)
        axes[1, i].set_ylabel('Density', fontsize=11)
        axes[1, i].grid(True, alpha=0.25)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig


# Uniform augmentation with specific transform settings.
# NOTE: the "_balanced" suffix on the result names is kept because later cells read
# these variables, but this call uses augment_uniform (equal-probability resampling).
# Use augment_balanced(..., balanced_target="mvn" or "self-kde") for KDE-weighted resampling.
X_train_aug_balanced, y_train_aug_balanced = augment_uniform(
    X_train_final,
    y_train_final,
    factor=7.0,                          # total size multiplier (N originals + ~6N synthetic)
    transform_p={'noise': 0.5, 'shift': 0.5, 'ripple': 0.5},
    transform_cfg={
        'noise':  {'noise_level': 0.03},
        'shift':  {'max_shift': 2},
        'ripple': {'amplitude': 0.01, 'freq_range': (1.0, 5.0)},
    },
    rng_seed=42
)

# Plot (common bin edges + densities)
fig = plot_parameter_distributions(
    y_train_final,
    y_train_aug_balanced,
    param_names=['Teff', '[Fe/H]', 'log g'],
    augmented_label='After Uniform Augmentation'
)
# Optionally save:
# fig.savefig('param_dists_before_after_uniform_aug.png', dpi=300, bbox_inches='tight')

In [None]:
# Inspect the label limits of the (non-augmented) training split, in scaled units.
print("Limits in training set:")
for _tag, _c in (("Teff: ", 0), ("FeH:  ", 1), ("logg: ", 2)):
    # Every line reads the same array (the original mixed in y_train_aug_balanced for log g).
    print(f"{_tag}{y_train_final[:, _c].min()} - {y_train_final[:, _c].max()}")

In [None]:
# One array per model output head, keyed by output-layer name.
_output_names = ('teff_output', 'feh_output', 'logg_output')
y_train_dict = {head: y_train_aug_balanced[:, col] for col, head in enumerate(_output_names)}
y_test_dict = {head: y_test_final[:, col] for col, head in enumerate(_output_names)}

print("\n--- Data shapes ---")
for _caption, _shape in (("X_train shape:", X_train_aug_balanced.shape),
                         ("X_val_final shape:", X_val_final.shape),
                         ("X_test shape:", X_test_final.shape),
                         ("y_train_teff shape:", y_train_dict['teff_output'].shape)):
    print(_caption, _shape)


## Model Architecture
We optimize a fully connected residual network (MLP-ResNet) tailored for multitask regression across the three stellar labels.

- **Shared feature extractor**: stacked Dense + BatchNorm + ReLU blocks with residual skip connections capture non-linear spectral structure while keeping gradients stable.
- **Task-specific heads**: lightweight Dense blocks with dropout and $L_2$ regularization specialize the shared representation for Teff, [Fe/H], and log g predictions.
- **Hyperparameter search**: Keras Tuner (BayesianOptimization) sweeps layer widths, activation choices, dropout rates, and optimizer configurations to match dataset scale.
- **Training signals**: the network minimizes mean-squared error per task, monitors MAE for interpretability, and inverts the stored scalers to report metrics in physical units.

This layout keeps the capacity flexible enough for precise regression while remaining lightweight for experimentation on standard GPUs.

In [None]:


def resnet_block(x, units, dropout_rate=0.2, l2_reg=1e-4):
    """Residual MLP block with a (possibly projected) skip connection.

    Main path:  Dense -> BatchNorm -> ReLU -> Dropout -> Dense -> BatchNorm.
    Skip path:  identity, or a bias-free Dense + BatchNorm projection whenever the
                incoming feature width differs from ``units``.
    The two paths are summed and passed through a final ReLU.

    Args:
        x: Input Keras tensor.
        units: Width of both Dense layers on the main path (and of the output).
        dropout_rate: Dropout rate applied after the first activation.
        l2_reg: L2 kernel penalty for the two main-path Dense layers.

    Returns:
        Keras tensor of width ``units``.
    """
    def _regularized_dense(tensor):
        # Bias-free: the BatchNormalization that follows provides the shift.
        return layers.Dense(
            units,
            kernel_regularizer=regularizers.l2(l2_reg),
            use_bias=False
        )(tensor)

    residual = x

    h = _regularized_dense(x)
    h = layers.BatchNormalization()(h)
    h = layers.Activation('relu')(h)
    h = layers.Dropout(dropout_rate)(h)

    h = _regularized_dense(h)
    h = layers.BatchNormalization()(h)

    needs_projection = residual.shape[-1] != units
    if needs_projection:
        # Match widths so the element-wise Add is well defined.
        residual = layers.Dense(units, use_bias=False)(residual)
        residual = layers.BatchNormalization()(residual)

    merged = layers.Add()([h, residual])
    return layers.Activation('relu')(merged)


def build_model(input_shape, num_outputs=3, dropout_rate=0.3, l2_reg=1e-4, hp=None):
    """
    Build and compile the multitask MLP-ResNet regressor.

    Two construction modes are supported:

    * ``hp is None``: a fixed architecture. It has a Dense(256) + BatchNorm + ReLU stem,
      ResNet blocks of 256, 128 and 64 units, and a shared Dense(64, relu) layer. Each
      task gets one Dense(32, relu) + Dropout head. ``dropout_rate`` and ``l2_reg``
      apply only in this mode. The model is compiled with the default ``'adam'`` optimizer.
    * ``hp`` given: Keras Tuner samples the following, and ``dropout_rate`` and
      ``l2_reg`` are ignored:
      - the stem width, activation and dropout;
      - one L2 strength shared by the stem, blocks and shared head;
      - the number (1-3), widths and dropout rates of the shared ResNet blocks;
      - the shared head width, activation and dropout;
      - per task: head depth (1-3 layers), width, activation, dropout and L2 strength;
      - the learning rate and optimizer (Adam or Nadam).

    Parameters
    ----------
    input_shape : int or tuple/list of int
        Input feature shape; a bare int ``d`` is interpreted as ``(d,)``.
    num_outputs : int
        Number of scalar regression heads. Head ``i`` is a layer named ``output_{i}``.
    dropout_rate : float
        Dropout rate of the fixed-architecture ResNet blocks and task heads.
    l2_reg : float
        L2 kernel penalty of the fixed-architecture stem Dense and ResNet blocks.
    hp : keras_tuner.HyperParameters, optional
        Hyperparameter container supplied by Keras Tuner.

    Returns
    -------
    keras.Model
        Compiled model whose outputs are a list of ``num_outputs`` linear heads. Each
        head is trained with a Huber loss and monitored with MAE.

    NOTE(review): the names and call order of the ``hp.*`` lookups define the search
    space recorded by the tuner. Changing them while reusing an existing tuner
    directory (``overwrite=False``) may not behave as expected; confirm before editing.
    """
    # Normalise `input_shape` so that a plain int feature count is accepted.
    if isinstance(input_shape, (tuple, list)):
        input_shape_tuple = tuple(input_shape)
    else:
        input_shape_tuple = (input_shape,)

    inputs = layers.Input(shape=input_shape_tuple)

    if hp is None:
        # ---- Fixed architecture ----
        x = layers.Dense(
            256,
            kernel_regularizer=regularizers.l2(l2_reg),
            use_bias=False  # the following BatchNormalization supplies the shift
        )(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        # Progressively narrower residual blocks (256 -> 128 -> 64).
        x = resnet_block(x, 256, dropout_rate, l2_reg)
        x = resnet_block(x, 128, dropout_rate, l2_reg)
        x = resnet_block(x, 64, dropout_rate, l2_reg)

        shared = layers.Dense(64, activation='relu')(x)
    else:
        # ---- Tuner-driven architecture ----
        # One L2 strength shared by the stem, the ResNet blocks and the shared head.
        shared_l2 = hp.Float('shared_l2', 1e-6, 1e-3, sampling='log')

        initial_units = hp.Choice('initial_units', [128, 192, 256, 320])
        x = layers.Dense(
            initial_units,
            kernel_regularizer=regularizers.l2(shared_l2),
            use_bias=False
        )(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Activation(hp.Choice('initial_activation', ['relu', 'selu']))(x)
        initial_dropout = hp.Float('initial_dropout', 0.0, 0.5, step=0.1)
        # Skip adding a Dropout layer entirely when the sampled rate is 0.
        if initial_dropout > 0:
            x = layers.Dropout(initial_dropout)(x)

        # Per-block hyperparameters are keyed by block index; a trial only uses
        # the first `num_blocks` of them.
        num_blocks = hp.Int('shared_blocks', min_value=1, max_value=3)
        for block_idx in range(num_blocks):
            block_units = hp.Choice(
                f'shared_block_units_{block_idx}',
                values=[128, 160, 192, 224, 256]
            )
            block_dropout = hp.Float(
                f'shared_block_dropout_{block_idx}',
                min_value=0.0,
                max_value=0.5,
                step=0.1
            )
            x = resnet_block(
                x,
                units=block_units,
                dropout_rate=block_dropout,
                l2_reg=shared_l2
            )

        shared_units = hp.Choice('shared_head_units', [64, 96, 128, 160, 192])
        shared_activation = hp.Choice('shared_head_activation', ['relu', 'selu', 'gelu'])
        shared = layers.Dense(
            shared_units,
            activation=shared_activation,
            kernel_regularizer=regularizers.l2(shared_l2)
        )(x)
        shared_dropout = hp.Float('shared_head_dropout', 0.0, 0.5, step=0.1)
        if shared_dropout > 0:
            shared = layers.Dropout(shared_dropout)(shared)

    # ---- One head per task, all branching from the shared representation ----
    outputs = []
    for i in range(num_outputs):
        head = shared
        if hp is None:
            head = layers.Dense(32, activation='relu')(head)
            head = layers.Dropout(dropout_rate)(head)
        else:
            # Each task samples its own depth, width, activation, dropout and L2.
            head_layers = hp.Int(f'task_{i}_layers', min_value=1, max_value=3)
            head_units = hp.Choice(f'task_{i}_units', [32, 48, 64, 96, 128])
            head_activation = hp.Choice(f'task_{i}_activation', ['relu', 'selu', 'gelu'])
            head_dropout = hp.Float(f'task_{i}_dropout', 0.0, 0.5, step=0.1)
            head_l2 = hp.Float(f'task_{i}_l2', 1e-6, 1e-3, sampling='log')

            for layer_idx in range(head_layers):
                head = layers.Dense(
                    head_units,
                    activation=head_activation,
                    kernel_regularizer=regularizers.l2(head_l2)
                )(head)
                if head_dropout > 0:
                    head = layers.Dropout(head_dropout)(head)

        # Linear scalar output. The layer name `output_{i}` is the key used for
        # per-output targets and in metric names such as `output_{i}_mae`.
        out = layers.Dense(1, activation='linear', name=f'output_{i}')(head)
        outputs.append(out)

    model = Model(inputs=inputs, outputs=outputs)

    # Both modes use a Huber loss and an MAE metric per head; only the optimizer differs.
    if hp is None:
        model.compile(
            optimizer='adam',
            loss=['huber'] * num_outputs,
            metrics=['mae'] * num_outputs
        )
    else:
        learning_rate = hp.Float('learning_rate', 1e-4, 5e-3, sampling='log')
        optimizer_name = hp.Choice('optimizer', ['adam', 'nadam'])
        if optimizer_name == 'adam':
            optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        else:
            optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)

        model.compile(
            optimizer=optimizer,
            loss=['huber'] * num_outputs,
            metrics=['mae'] * num_outputs
        )

    return model


def build_bayesian_tuner(
    input_shape,
    num_outputs=3,
    max_trials=7,
    executions_per_trial=1,
    directory='keras_tuner',
    project_name='stellar_multitask'
):
    """Return a Keras Tuner Bayesian-optimisation search over ``build_model``.

    Each trial builds the tuner-driven variant of ``build_model`` for the given
    input shape and number of heads. Trials are ranked by minimum ``val_loss``.
    Because ``overwrite=False``, an existing ``directory/project_name`` is reused,
    so previously completed trials are resumed rather than discarded.

    Args:
        input_shape: Feature shape forwarded to ``build_model``.
        num_outputs: Number of regression heads.
        max_trials: Maximum number of hyperparameter configurations to try.
        executions_per_trial: Models trained per configuration.
        directory: Root folder for tuner state.
        project_name: Sub-folder name for this search.

    Returns:
        A configured ``kt.BayesianOptimization`` tuner (search not yet started).
    """
    search_objective = kt.Objective('val_loss', direction='min')

    return kt.BayesianOptimization(
        hypermodel=lambda hp: build_model(
            input_shape=input_shape,
            num_outputs=num_outputs,
            hp=hp
        ),
        objective=search_objective,
        max_trials=max_trials,
        executions_per_trial=executions_per_trial,
        directory=directory,
        project_name=project_name,
        overwrite=False,
    )

In [None]:
# Re-key the targets by Keras output-layer name (`output_0`..`output_2`), matching
# the `name=f'output_{i}'` heads built by `build_model` (Teff, [Fe/H], log g order).
y_train_split = {
    f'output_{column}': y_train_dict[head_key]
    for column, head_key in enumerate(('teff_output', 'feh_output', 'logg_output'))
}
y_val_split = {f'output_{column}': y_val_final[:, column] for column in range(3)}


In [None]:
# Problem dimensions, read from the balanced training set.
input_dim = X_train_aug_balanced.shape[1]
num_outputs = y_train_aug_balanced.shape[1]

# ---- Stage 1: Bayesian hyperparameter search --------------------------------
print("\n--- Starting Bayesian search with Keras Tuner ---")
tuner = build_bayesian_tuner(
    input_shape=input_dim,
    num_outputs=num_outputs,
    max_trials=40,
    executions_per_trial=1,
    directory='keras_tuner',
    project_name='stellar_multitask_head_search_4000_huber',
)

# Short patience so unpromising trials stop early; LR is halved on plateaus.
search_callbacks = [
    EarlyStopping(monitor='val_loss', patience=8, verbose=1, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, verbose=1),
]
tuner.search(
    X_train_aug_balanced,
    y_train_split,
    validation_data=(X_val_final, y_val_split),
    epochs=30,
    batch_size=256,
    callbacks=search_callbacks,
    verbose=1,
)
tuner.results_summary()

# ---- Stage 2: rebuild the best configuration and train it from scratch ------
best_hp = tuner.get_best_hyperparameters(num_trials=1)[0]
print("\nBest hyperparameters:")
for hp_name, hp_value in best_hp.values.items():
    print(f"   {hp_name}: {hp_value}")

model = build_model(input_shape=input_dim, num_outputs=num_outputs, hp=best_hp)

# Longer patience than during the search. The best-val_loss model is also
# checkpointed to disk as training progresses.
final_callbacks = [
    EarlyStopping(monitor='val_loss', patience=12, verbose=1, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=6, verbose=1),
    ModelCheckpoint(
        'best_stellar_multitask_model.keras',
        monitor='val_loss',
        save_best_only=True,
        verbose=1,
    ),
]

print("\n--- Training the best model found ---")
history = model.fit(
    X_train_aug_balanced,
    y_train_split,
    validation_data=(X_val_final, y_val_split),
    epochs=120,
    batch_size=128,
    callbacks=final_callbacks,
    verbose=1,
)

# NOTE(review): whether EarlyStopping(restore_best_weights=True) restores the best
# weights when all 120 epochs run without an early stop depends on the Keras
# version; 'best_stellar_multitask_model.keras' holds the best-val_loss model.
model.save('stellar_regression_model.keras')
print("\nFinal model saved as 'stellar_regression_model.keras'")


In [None]:
# Plot the training history: the total loss, then the MAE of each output head.
# NOTE: `loss` is the optimisation objective, not an MSE. It is the sum of the
# per-head Huber losses plus the L2 weight penalties. MAE values are in the scaled
# target units used for training.
head_labels = ['Teff', '[Fe/H]', 'log g']  # output_0, output_1, output_2
curves = history.history

hist_fig, hist_axes = plt.subplots(1, 1 + len(head_labels), figsize=(20, 5))

loss_ax = hist_axes[0]
loss_ax.plot(curves['loss'], label='Train Loss')
loss_ax.plot(curves['val_loss'], label='Val Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.set_ylabel('Total loss (Huber + L2)')

for i, (ax, label) in enumerate(zip(hist_axes[1:], head_labels)):
    ax.plot(curves[f'output_{i}_mae'], label='Train MAE')
    ax.plot(curves[f'val_output_{i}_mae'], label='Val MAE')
    # Name the parameter so the three MAE panels can be told apart.
    ax.set_title(f'{label}: Training and Validation MAE')
    ax.set_ylabel('MAE (scaled units)')

for ax in hist_axes:
    ax.set_xlabel('Epochs')
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:

def _to_matrix(pred):
    """Accepts list of heads [(N,1), ...] or array (N,T); returns (N,T)."""
    if isinstance(pred, (list, tuple)):
        cols = []
        for a in pred:
            a = np.asarray(a).reshape(len(a), -1)
            # if someone passes hetero heads [mu, log_var], keep mu
            cols.append(a[:, 0])
        return np.column_stack(cols)
    A = np.asarray(pred)
    return A if A.ndim == 2 else A.reshape(len(A), -1)

def _inverse_targets(Y, scalers_y):
    """Inverse-transform each column using its scaler; returns (N,T)."""
    Y = np.asarray(Y)
    out = np.zeros_like(Y, dtype=float)
    for i, sc in enumerate(scalers_y):
        out[:, i] = sc.inverse_transform(Y[:, i].reshape(-1, 1)).ravel()
    return out

def bootstrap_mae(
    model,
    X_test: np.ndarray,
    y_test: np.ndarray,
    scalers_y: Optional[list] = None,
    n_bootstrap: int = 1000,
    seed: Optional[int] = 42,
    param_names: Optional[list] = None,
) -> Dict[str, Tuple[float, float, Tuple[float, float]]]:
    """
    Bootstrap the per-parameter MAE of a single trained model (no retraining).

    The model is evaluated once on ``X_test``. The test rows are then resampled
    with replacement ``n_bootstrap`` times and the MAE is recomputed on each resample.

    Parameters
    ----------
    model
        Object with ``predict(X, verbose=0)`` returning either a list of per-head
        arrays or a single 2-D array.
    X_test : np.ndarray
        Test inputs.
    y_test : np.ndarray, shape (N, T) or (N,)
        Targets on the same scale as the model outputs (scaled if the model was
        trained on scaled targets). A 1-D array is treated as a single target.
    scalers_y : list, optional
        One fitted scaler per target column. If given, both y_true and y_pred are
        inverse-transformed, so the MAE is reported in physical units.
    n_bootstrap : int
        Number of bootstrap resamples; must be >= 2 for a sample SD.
    seed : int, optional
        Seed for ``np.random.default_rng``; ``None`` gives non-reproducible draws.
    param_names : list of str, optional
        Keys for the returned dict. Defaults to ``param_0``, ``param_1``, ...

    Returns
    -------
    dict
        ``{name: (mean, sd, (ci_lo, ci_hi))}`` per target. Mean and sd are taken
        over the bootstrap MAEs (sd with ddof=1). The CI is the 2.5/97.5
        percentile interval.

    Raises
    ------
    ValueError
        If predictions and targets have different shapes, ``y_test`` is empty,
        ``n_bootstrap < 2``, or ``param_names`` has the wrong length.
    """
    if n_bootstrap < 2:
        raise ValueError(f"n_bootstrap must be >= 2, got {n_bootstrap}.")

    rng = np.random.default_rng(seed)

    # 1) Predict once on the full test set.
    y_pred = _to_matrix(model.predict(X_test, verbose=0))   # (N, T)
    y_true = np.asarray(y_test)
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)                       # single target -> (N, 1)

    # Mismatched shapes (e.g. (N, 1) vs (N, T)) would otherwise broadcast silently
    # and produce meaningless errors.
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Prediction shape {y_pred.shape} does not match target shape {y_true.shape}."
        )

    N, T = y_true.shape
    if N == 0:
        raise ValueError("y_test is empty; cannot bootstrap.")

    if param_names is None:
        keys = [f'param_{t}' for t in range(T)]
    else:
        keys = list(param_names)
        if len(keys) != T:
            raise ValueError(f"Expected {T} param_names, got {len(keys)}.")

    # 2) Put both on the same scale (optionally inverse-transform to physical units).
    if scalers_y is not None:
        y_true = _inverse_targets(y_true, scalers_y)
        y_pred = _inverse_targets(y_pred, scalers_y)

    # Absolute errors are computed once; resamples only re-average them.
    abs_err = np.abs(y_true - y_pred)                        # (N, T)

    # 3) Bootstrap by resampling row indices with replacement.
    idx = rng.integers(0, N, size=(n_bootstrap, N))
    boot = np.empty((n_bootstrap, T), dtype=float)
    for b in range(n_bootstrap):
        boot[b, :] = abs_err[idx[b]].mean(axis=0)

    # 4) Aggregate: mean, sd (ddof=1), and 95% percentile CI.
    stats = {}
    for t, key in enumerate(keys):
        vals = boot[:, t]
        lo, hi = np.percentile(vals, [2.5, 97.5])
        stats[key] = (float(vals.mean()), float(vals.std(ddof=1)), (float(lo), float(hi)))
    return stats


In [None]:
# Bootstrap the test-set MAE to attach uncertainty estimates to each parameter.
scalers_y = scalers['target_scalers']  # report in physical units; use None for scaled units
mae_stats = bootstrap_mae(
    model,
    X_test_final,
    y_test_final,         # must be on the scale the model was trained on (scaled targets)
    scalers_y=scalers_y,
    n_bootstrap=1000,
    seed=42
)

# Fixed 4-decimal formatting, consistent with the other metric reports.
for k, (m, s, (lo, hi)) in mae_stats.items():
    print(f"{k}: MAE = {m:.4f} +/- {s:.4f}  (95% CI [{lo:.4f}, {hi:.4f}])")


In [None]:
# Reload the model artifact written after training, so the evaluation below
# runs on exactly what was saved to disk.
model = tf.keras.models.load_model(filepath='stellar_regression_model.keras')

In [None]:
#see how many parameters the model has
# Print the layer-by-layer architecture, including output shapes and the
# total / trainable / non-trainable parameter counts.
model.summary()

In [None]:
# Single forward pass over the test spectra. The multi-output model returns one
# array per head, still in scaled target units.
predictions_scaled_list = model.predict(x=X_test_final)

In [None]:
def wmape(y_true, y_pred):
    """Weighted mean absolute percentage error: ``sum|y_true - y_pred| / sum|y_true|``.

    Both inputs are flattened, so ``(N, 1)`` and ``(N,)`` arrays are compared
    element-wise. Previously they would broadcast into an ``(N, N)`` grid and
    give a wrong value.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground truth and predictions containing the same number of elements.

    Returns
    -------
    float
        WMAPE as a fraction (multiply by 100 for percent). Returns ``nan`` when
        ``sum|y_true| == 0``, where the metric is undefined. For labels whose values
        sit near or cross zero (presumably [Fe/H], measured relative to solar), the
        denominator is small and the value is hard to interpret.

    Raises
    ------
    ValueError
        If the inputs contain different numbers of elements.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true has {y_true.size} elements but y_pred has {y_pred.size}."
        )

    denominator = np.sum(np.abs(y_true))
    if denominator == 0:
        return float('nan')
    return float(np.sum(np.abs(y_true - y_pred)) / denominator)

In [None]:
# Final evaluation: report test-set metrics per parameter in physical units.

# `predict` yields one array per output head; stack them column-wise.
predictions_scaled = np.column_stack(predictions_scaled_list)  # Shape: (n_samples, 3)

param_names = ['Teff', 'Fe/H', 'log(g)']
print("\n--- Final model performance on the test set ---")

for i, name in enumerate(param_names):
    # Column i of predictions and ground truth as (n_samples, 1) arrays (scalers expect 2-D).
    pred_scaled, true_scaled = predictions_scaled[:, [i]], y_test_final[:, [i]]

    # Undo the target scaling with the scaler fitted on parameter i.
    pred_unscaled, true_unscaled = (
        scalers['target_scalers'][i].inverse_transform(column)
        for column in (pred_scaled, true_scaled)
    )

    mae, r2, wmape_value = (
        metric(true_unscaled, pred_unscaled)
        for metric in (mean_absolute_error, r2_score, wmape)
    )

    print(f"Parameter: {name}")
    print(f"  -> Mean Absolute Error (MAE): {mae:.4f}")
    print(f"  -> Coefficient of Determination (R^2): {r2:.4f}")
    print(f"  -> Weighted Mean Absolute Percentage Error (WMAPE): {wmape_value:.4f}\n")


In [None]:
print("\n--- Generating regression plots (no outlier filtering, no fitted line) ---")

fig, axes = plt.subplots(1, 3, figsize=(22, 6))
fig.suptitle('Regression Plots', fontsize=24)

scatter_mappable = None

for i, name in enumerate(param_names):
    ax = axes[i]

    # Invert the scaling for predictions and ground truth (no filtering)
    pred_unscaled = scalers['target_scalers'][i].inverse_transform(
        predictions_scaled_list[i].reshape(-1, 1)
    ).flatten()

    true_scaled = y_test_final[:, i].reshape(-1, 1)
    true_values = scalers['target_scalers'][i].inverse_transform(true_scaled).flatten()

    # Metrics
    mae = mean_absolute_error(true_values, pred_unscaled)
    r2 = r2_score(true_values, pred_unscaled)

    # Density-colored scatter plot without any outlier filtering.
    # The KDE is evaluated in each panel's own physical units, so raw densities
    # are not comparable across panels. Normalise to the panel maximum and pin the
    # color range to [0, 1] so the single shared color bar is valid for all panels.
    xy = np.vstack([true_values, pred_unscaled])
    z = gaussian_kde(xy)(xy)
    z = z / z.max()
    idx = z.argsort()  # draw densest points last so they stay visible
    x, y, z = true_values[idx], pred_unscaled[idx], z[idx]
    scatter_mappable = ax.scatter(
        x, y, c=z, s=15, cmap='viridis', alpha=0.7, vmin=0.0, vmax=1.0
    )

    # y = x reference line
    lim_min = np.min(np.concatenate([x, y]))
    lim_max = np.max(np.concatenate([x, y]))
    lims = [lim_min, lim_max]
    ax.plot(lims, lims, 'r--', alpha=0.8, zorder=3, label='Ideal')

    # Text box with summary metrics
    stats_text = f'MAE = {mae:.4f}\n$R^2$ = {r2:.4f}'
    ax.text(0.05, 0.95, stats_text,
            transform=ax.transAxes,
            fontsize=16, verticalalignment='top',
            bbox=dict(boxstyle='square,pad=0.5', fc='wheat', alpha=0.5))

    ax.set_title(f'Parameter: {name}', fontsize=18)
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.legend(fontsize=14)

    if i == 0:
        ax.set_ylabel('Predicted', fontsize=16)
    ax.set_xlabel('True Value', fontsize=16)

# Layout and color bar (shared; valid because every panel uses the same 0-1 scale)
fig.tight_layout(rect=[0, 0, 0.9, 0.95])
if scatter_mappable is not None:
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(scatter_mappable, cax=cbar_ax)
    cbar.set_label('Relative Point Density', rotation=270, labelpad=25, fontsize=16)

plt.show()


In [None]:
# For each parameter, fit the straight line pred = m * true + b on the test set
# (physical units) and store the (slope m, intercept b) pair in
# `regression_params[name]` for later use.
regression_params = {}
for i, name in enumerate(param_names):
    # Invert the scaling for predictions and ground truth (no filtering)
    pred_unscaled = scalers['target_scalers'][i].inverse_transform(
        predictions_scaled_list[i].reshape(-1, 1)
    ).flatten()

    true_scaled = y_test_final[:, i].reshape(-1, 1)
    true_values = scalers['target_scalers'][i].inverse_transform(true_scaled).flatten()

    # Linear regression fit; np.polyfit with deg=1 returns [slope, intercept]
    m, b = np.polyfit(true_values, pred_unscaled, 1)
    regression_params[name] = (m, b)