In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline

In [3]:
from astromodal.config import load_config
from tqdm import tqdm
import polars as pl
import random
from pathlib import Path
from astromodal.datasets.datacubes import load_datacube_files

In [4]:
# Project configuration; everything below reads its paths from this mapping.
config = load_config("/home/schwarz/projetoFM/config.yaml")

# `hdd_folder` is the configured value as-is; `hddfolder` is the
# "image_latents" sub-folder underneath it, as a Path.
hdd_folder = config["hdd_folder"]
hddfolder = Path(hdd_folder) / "image_latents"

In [5]:
# Resolve the datacube glob to one concrete file so its schema can be inspected.
file = config["datacubes_paths"].replace("*", "STRIPE82-0002")

import polars as pl

# n_rows=0 reads only the schema, not the data.
header = pl.read_parquet(file, n_rows=0)

# Every column whose name mentions "gaiaxp", followed by the two extra
# columns needed for bookkeeping and the magnitude cut.
columns = [
    *filter(lambda name: "gaiaxp" in name, header.columns),
    "id",
    "mag_psf_r",
]

In [6]:
# Subsample 500 datacube files and split them 90/10 into train/validation lists.
train_files, val_files = load_datacube_files(
    config["datacubes_paths"],
    nfiles_subsample=500,
    train_val_split=0.9,
)

[info] - Found 2444 datacube files
[info] - Subsampled to 500 files
[info] - Training files: 450
[info] - Validation files: 50


In [7]:
def _load_filtered_frames(files, columns, *, desc, max_mag_r=21):
    """Read the given parquet files, filter each one, and stack the survivors.

    For every file only ``columns`` are read; rows are kept when the first
    requested column is non-null and ``mag_psf_r < max_mag_r``. Files that end
    up empty are skipped. Files that raise (unreadable file, missing column,
    schema mismatch on concat, ...) are skipped as well, but each failure is
    reported instead of being silently swallowed.

    Parameters
    ----------
    files : iterable of paths to parquet files.
    columns : list of column names to read; ``columns[0]`` drives the null filter.
    desc : label shown on the progress bar.
    max_mag_r : exclusive upper bound applied to ``mag_psf_r``.

    Returns
    -------
    A single polars DataFrame, or ``None`` when no file contributed any row.
    """
    combined = None
    n_failed = 0

    for path in tqdm(files, desc=desc):
        try:
            frame = (
                pl.read_parquet(path, columns=columns, use_pyarrow=True)
                .filter(pl.col(columns[0]).is_not_null())
                .filter(pl.col("mag_psf_r") < max_mag_r)
            )
            if frame.height == 0:
                continue

            # rechunk=False keeps each append cheap (no copy of earlier data).
            combined = (
                frame
                if combined is None
                else pl.concat([combined, frame], how="vertical", rechunk=False)
            )
        except Exception as exc:
            # Best-effort load: keep going, but make the skipped file visible.
            n_failed += 1
            tqdm.write(f"[warn] - Skipping {path}: {type(exc).__name__}: {exc}")

    if n_failed:
        tqdm.write(f"[warn] - {desc}: {n_failed} file(s) could not be loaded")

    return combined


train_df = _load_filtered_frames(train_files, columns, desc="Loading train files")
val_df = _load_filtered_frames(val_files, columns, desc="Loading val files")

Loading train files:   0%|          | 0/450 [00:00<?, ?it/s]

Loading train files: 100%|██████████| 450/450 [01:26<00:00,  5.21it/s]
Loading val files: 100%|██████████| 50/50 [00:10<00:00,  4.81it/s]


In [8]:
# Keep only rows that have a Gaia XP solution id, in both splits.
train_df, val_df = (
    frame.filter(pl.col("gaiaxp_solution_id").is_not_null())
    for frame in (train_df, val_df)
)

In [9]:
# Row counts of the two splits (train, val).
(train_df.height, val_df.height)

(184987, 25358)

In [12]:
# Strip the "gaiaxp_" marker from every column name in both splits.
train_df, val_df = (
    frame.rename(lambda name: name.replace("gaiaxp_", ""))
    for frame in (train_df, val_df)
)

In [12]:
import ast

def _parse_coeff_columns(df, cols=("bp_coefficients", "rp_coefficients")):
    """Convert string-encoded coefficient columns into ``List(Float64)`` columns.

    Only columns whose current dtype is a string are parsed, so running this
    cell again on an already-converted frame is a no-op instead of an error.

    Parameters
    ----------
    df : polars DataFrame containing the columns in ``cols``.
    cols : names of the coefficient columns to convert.

    Returns
    -------
    A DataFrame with the string columns replaced by parsed list columns
    (``df`` itself when nothing needed parsing).
    """
    pending = [name for name in cols if df.schema[name] == pl.Utf8]
    if not pending:
        return df
    return df.with_columns(
        [
            pl.col(name).map_elements(
                ast.literal_eval, return_dtype=pl.List(pl.Float64)
            )
            for name in pending
        ]
    )


# Parse both splits so train and validation share the same column dtypes.
train_df = _parse_coeff_columns(train_df)
val_df = _parse_coeff_columns(val_df)


In [36]:
from astromodal.specifics.gaiaxp_scaler import fit_standard_scaler_vec_from_gaiaxp

# Fit one scaler per coefficient column (BP first, then RP) on the training
# split, using the same row cap and seed for both.
scaler_bp, scaler_rp = (
    fit_standard_scaler_vec_from_gaiaxp(
        train_df,
        col=f"{band}_coefficients",
        max_rows=100000,
        seed=0,
    )
    for band in ("bp", "rp")
)


In [37]:
# Persist both fitted scalers under <models_folder>/scalers/.
scaler_bp.save(
    Path(config["models_folder"], "scalers", "gaiaxp_scaler_bp.pkl")
)
scaler_rp.save(
    Path(config["models_folder"], "scalers", "gaiaxp_scaler_rp.pkl")
)

In [39]:
import numpy as np
import ast

def _parse_list_cell(x):
    # Handles: list/np.ndarray OR string like "[1,2,3]"
    if x is None:
        return None
    if isinstance(x, (list, tuple, np.ndarray)):
        return np.asarray(x, dtype=np.float32)
    if isinstance(x, str):
        # safe parse
        return np.asarray(ast.literal_eval(x), dtype=np.float32)
    # fallback
    return np.asarray(x, dtype=np.float32)

def df_col_to_matrix(df, col="bp_coefficients"):
    """Stack the cells of ``df[col]`` into a 2-D float32 matrix ``[N, D]``.

    Cells that parse to ``None`` are dropped. Raises ``ValueError`` when no
    usable cell remains or when the surviving rows differ in length.
    """
    rows = [
        parsed
        for parsed in map(_parse_list_cell, df[col].to_list())
        if parsed is not None
    ]
    if not rows:
        raise ValueError(f"No valid rows found in {col}")

    # The first row defines the expected width; every other row must match it.
    width = len(rows[0])
    mismatched = [idx for idx, row in enumerate(rows) if len(row) != width]
    if mismatched:
        raise ValueError(f"Inconsistent coefficient lengths. Example bad idx: {mismatched[:5]}")

    return np.stack(rows, axis=0).astype(np.float32)  # [N, D]

In [49]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path

from astromodal.tokenizers.rvq import ResidualVQ
from astromodal.tokenizers.spectralrvq import SpectralPatchRVQ


# -------------------------
# helpers: weights
# -------------------------

def compute_coeff_weights_from_train(
    train_df,
    *,
    col="bp_coefficients",
    method="inv_p90",   # "inv_p90" | "inv_std" | "none"
    eps=1e-3,
) -> np.ndarray:
    """
    Build one weight per coefficient position from the training split.

    The coefficients are read in their original (unscaled) units. Depending on
    ``method`` the weight of a position is 1 ("none"), the inverse of the 90th
    percentile of its absolute value ("inv_p90"), or the inverse of its
    standard deviation ("inv_std"); the scale is floored at ``eps`` before
    inversion. The result is normalised to mean 1 and returned as float32.

    Raises ``ValueError`` for any other ``method``.
    """
    coeffs = df_col_to_matrix(train_df, col=col).astype(np.float64)  # [N, D]

    if method == "none":
        weights = np.ones(coeffs.shape[1], dtype=np.float64)
    elif method in ("inv_p90", "inv_std"):
        if method == "inv_p90":
            scale = np.quantile(np.abs(coeffs), 0.90, axis=0)  # [D]
        else:
            scale = np.std(coeffs, axis=0)  # [D]
        weights = 1.0 / np.clip(scale, eps, None)
    else:
        raise ValueError(f"Unknown method={method}")

    return (weights / np.mean(weights)).astype(np.float32)


# -------------------------
# DataLoader that also provides w: [B,55]
# -------------------------

def make_coeff_loader(
    df,
    scaler,
    *,
    col="bp_coefficients",
    batch_size=1024,
    shuffle=True,
    weight_vec: np.ndarray | None = None,   # [55] or None
):
    """Wrap the scaled coefficients of ``df[col]`` in a DataLoader.

    The column is converted to a matrix, passed through ``scaler.transform_x``
    and reshaped to ``[N, 55, 1]``. Batches are ``(x,)`` when ``weight_vec`` is
    None, otherwise ``(x, None, w)`` with ``w`` being ``weight_vec`` repeated
    for every sample in the batch (``[B, 55]``).
    """
    raw = df_col_to_matrix(df, col=col).astype("float32")      # [N,55]
    scaled = scaler.transform_x(raw).astype("float32")         # [N,55]

    # Add a trailing channel axis so each coefficient is one sequence step.
    n_rows, n_coeff = scaled.shape[0], scaled.shape[1]
    scaled = scaled.reshape(n_rows, n_coeff, 1)
    assert scaled.ndim == 3 and scaled.shape[1] == 55 and scaled.shape[2] == 1, f"bad Xn shape {scaled.shape}"

    dataset = TensorDataset(torch.from_numpy(scaled))

    weights = (
        None
        if weight_vec is None
        else torch.as_tensor(weight_vec, dtype=torch.float32)  # [55]
    )

    def collate(samples):
        # Each sample is a 1-tuple (x,); stack them into [B,55,1].
        xs = torch.stack([sample[0] for sample in samples], dim=0)
        if weights is None:
            return (xs,)
        # Same weight vector for every row of the batch -> [B,55].
        tiled = weights.unsqueeze(0).expand(xs.shape[0], -1).contiguous()
        # Middle slot is left as None (mask position in the (x, mask?, w?) layout).
        return (xs, None, tiled)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate,
    )


# -------------------------
# eval in ORIGINAL units + per-index metrics
# -------------------------

@torch.no_grad()
def eval_tok_unscaled_stats(
    tok: SpectralPatchRVQ,
    dl,
    scaler,
    device="cuda",
    p90_ref: float | None = None,  # optional: global p90(|coeff|) for NRMSE
):
    tok.eval()

    sum_mse = 0.0
    n = 0

    sum_diff2_i = None  # [55]
    count_i = 0

    for batch in dl:
        # batch can be (x,) or (x, None, w)
        x_norm = batch[0].to(device=device, dtype=torch.float32, non_blocking=True)  # [B,55,1]

        out = tok.encode(x_norm, update_ema=False)
        xq_norm = out["x_q"]  # [B,55,1]

        # inverse transform expects [B,55] (safer)
        x_raw  = scaler.inverse_transform_x(x_norm.detach().cpu().numpy().squeeze(-1))   # [B,55]
        xq_raw = scaler.inverse_transform_x(xq_norm.detach().cpu().numpy().squeeze(-1))  # [B,55]

        diff2 = (xq_raw - x_raw) ** 2  # [B,55]

        mse_global = float(diff2.mean())
        sum_mse += mse_global * x_norm.shape[0]
        n += x_norm.shape[0]

        s = diff2.sum(axis=0)  # [55]
        sum_diff2_i = s if sum_diff2_i is None else (sum_diff2_i + s)
        count_i += diff2.shape[0]

    mse = sum_mse / max(n, 1)
    rmse = float(np.sqrt(mse))

    mse_i = sum_diff2_i / max(count_i, 1)  # [55]
    rmse_i = np.sqrt(mse_i)

    # NRMSE using provided p90 reference (recommended)
    nrmse_p90 = None
    if p90_ref is not None:
        nrmse_p90 = float(rmse / max(float(p90_ref), 1e-12))

    return {
        "mse": float(mse),
        "rmse": float(rmse),
        "mse_x_1e6": float(mse * 1e6),
        "-log10(mse)": float(-np.log10(mse + 1e-300)),
        "rmse_idx_mean": float(np.mean(rmse_i)),
        "rmse_idx_median": float(np.median(rmse_i)),
        "rmse_idx_max": float(np.max(rmse_i)),
        "rmse_idx_min": float(np.min(rmse_i)),
        "nrmse_p90": nrmse_p90,
        # if you want to plot later:
        "rmse_i": rmse_i,
    }


# -------------------------
# training
# -------------------------

def train_coeff_spectralrvq(
    train_df,
    val_df,
    scaler,
    *,
    col="bp_coefficients",
    codebook_size=1024,
    num_stages=4,
    decay=0.99,
    epochs=60,
    batch_size=1024,
    device="cuda",
    save_path="best_spectralrvq.pt",
    weight_method="inv_p90",   # "inv_p90" | "inv_std" | "none"
    patience=8,                # early stopping patience (epochs without improvement)
    min_delta=0.0,             # improvement must exceed this to count
):
    """Train a SpectralPatchRVQ tokenizer on one coefficient column.

    Each epoch runs ``tok.train_epoch`` on the training loader, then evaluates
    the validation split in original units with ``eval_tok_unscaled_stats``.
    Whenever the validation MSE improves by more than ``min_delta`` the model
    is saved to ``save_path``; training stops early after ``patience``
    consecutive epochs without improvement.

    Parameters
    ----------
    train_df, val_df : frames containing ``col``.
    scaler : scaler used to normalise inputs and to undo it for evaluation.
    col : coefficient column to tokenize.
    codebook_size, num_stages, decay : forwarded to ``ResidualVQ``.
    epochs, batch_size, device : usual training knobs.
    save_path : where the best checkpoint is written.
    weight_method : per-position weighting passed to
        ``compute_coeff_weights_from_train``; "none" disables weighting.
    patience, min_delta : early-stopping controls.

    Returns
    -------
    (tok, history) : ``tok`` is the in-memory model after the LAST epoch run
    (not necessarily the best one; the best one is on disk at ``save_path``),
    ``history`` is a list with one dict of metrics per epoch.
    """
    # weights in ORIGINAL space (per index)
    w_vec = compute_coeff_weights_from_train(
        train_df, col=col, method=weight_method
    ) if weight_method != "none" else None

    # p90 reference for NRMSE (global, original units)
    Xv = df_col_to_matrix(val_df, col=col).astype(np.float64)
    p90_ref = float(np.quantile(np.abs(Xv), 0.90))

    dl_train = make_coeff_loader(
        train_df, scaler,
        col=col, batch_size=batch_size, shuffle=True,
        weight_vec=w_vec
    )
    dl_val = make_coeff_loader(
        val_df, scaler,
        col=col, batch_size=batch_size, shuffle=False,
        weight_vec=None
    )

    # one coefficient per token -> patch_size=1, channels=1 -> RVQ dim=1
    P = 1
    C = 1
    rvq = ResidualVQ(
        dim=P * C,
        num_stages=num_stages,
        codebook_size=codebook_size,
        decay=decay,
    ).to(device)

    tok = SpectralPatchRVQ(
        rvq=rvq,
        patch_size=P,
        channels=C,
    ).to(device)

    # Make sure the checkpoint directory exists before the first save.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    history = []
    best_mse = float("inf")
    best_epoch = -1
    bad_epochs = 0

    for epoch in range(epochs):
        # The validation pass below switches the model to eval mode and does
        # not switch it back, so training mode must be re-enabled every epoch.
        tok.train()

        # IMPORTANT: EMA must stay ON; otherwise RVQ stops learning.
        train_stats = tok.train_epoch(
            dl_train, device=device, update_ema=True
        )

        val_stats = eval_tok_unscaled_stats(
            tok, dl_val, scaler, device=device, p90_ref=p90_ref
        )

        history.append({
            "epoch": int(epoch),
            "train_loss_norm": float(train_stats["loss"]),
            "weight_method": str(weight_method),
            **{k: v for k, v in val_stats.items() if k != "rmse_i"},
        })

        # -------- best checkpoint + early stopping --------
        improved = (val_stats["mse"] < (best_mse - float(min_delta)))

        if improved:
            best_mse = float(val_stats["mse"])
            best_epoch = int(epoch)
            bad_epochs = 0

            tok.save(
                str(save_path),
                additional_info={
                    "best_epoch": best_epoch,
                    "best_val_mse": best_mse,
                    "codebook_size": int(codebook_size),
                    "num_stages": int(num_stages),
                    "decay": float(decay),
                    "col": str(col),
                    "weight_method": str(weight_method),
                    "p90_ref": float(p90_ref),
                },
            )
        else:
            bad_epochs += 1

        print(
            f"Epoch {epoch:03d} | "
            f"train_loss(norm)={train_stats['loss']:.6g} | "
            f"val_mse={val_stats['mse']:.6g} rmse={val_stats['rmse']:.6g} "
            f"nrmse_p90={val_stats['nrmse_p90']} | "
            f"rmse_idx(mean/med/max)={val_stats['rmse_idx_mean']:.4g}/"
            f"{val_stats['rmse_idx_median']:.4g}/"
            f"{val_stats['rmse_idx_max']:.4g} | "
            f"best_mse={best_mse:.6g} @ {best_epoch} | "
            f"bad_epochs={bad_epochs}/{patience}"
        )

        if bad_epochs >= int(patience):
            print(f"\n✔ Early stopping at epoch {epoch} (best @ {best_epoch}, mse={best_mse:.6g})")
            break

    if best_epoch < 0:
        # Happens when epochs == 0 or the validation MSE was never finite/improving.
        print("\n✘ No checkpoint was saved: validation MSE never improved")
    else:
        print(f"\n✔ Best model saved at epoch {best_epoch} with val_mse={best_mse:.6g}")
    return tok, history

# -------------------------
# example
# -------------------------

# Which XP band to tokenize in this run, and the scaler fitted for it.
filt = "bp"
scaler = scaler_bp

tok, hist = train_coeff_spectralrvq(
    train_df,
    val_df,
    scaler,
    col=filt + "_coefficients",
    save_path=Path(
        config["models_folder"], "tokenizers", f"gaiaxp_spectral_rvq_{filt}.pt"
    ),
    # quantizer size
    codebook_size=1024,
    num_stages=3,
    # optimisation / stopping
    epochs=40,
    patience=8,
    min_delta=0.0,
    weight_method="inv_p90",
)

[info] - Saved SpectralPatchRVQ to /home/schwarz/projetoFM/outputs/tokenizers/gaiaxp_spectral_rvq_bp.pt
Epoch 000 | train_loss(norm)=0.303548 | val_mse=714.362 rmse=26.7275 nrmse_p90=7.768789349716187 | rmse_idx(mean/med/max)=6.966/0.5249/143.6 | best_mse=714.362 @ 0 | bad_epochs=0/8
[info] - Saved SpectralPatchRVQ to /home/schwarz/projetoFM/outputs/tokenizers/gaiaxp_spectral_rvq_bp.pt
Epoch 001 | train_loss(norm)=0.118965 | val_mse=94.7076 rmse=9.73178 nrmse_p90=2.8286981819200445 | rmse_idx(mean/med/max)=2.473/0.2408/64.73 | best_mse=94.7076 @ 1 | bad_epochs=0/8
[info] - Saved SpectralPatchRVQ to /home/schwarz/projetoFM/outputs/tokenizers/gaiaxp_spectral_rvq_bp.pt
Epoch 002 | train_loss(norm)=0.060663 | val_mse=72.607 rmse=8.52097 nrmse_p90=2.4767570575025246 | rmse_idx(mean/med/max)=2.358/0.1936/48.49 | best_mse=72.607 @ 2 | bad_epochs=0/8
Epoch 003 | train_loss(norm)=0.0272439 | val_mse=86.8943 rmse=9.32171 nrmse_p90=2.7095049526134956 | rmse_idx(mean/med/max)=2.033/0.1435/66.79 | 

In [42]:
# Scale of the raw validation BP coefficients: typical and upper-tail magnitude.
X = np.asarray(df_col_to_matrix(val_df, col="bp_coefficients"), dtype=np.float64)
print("median |coeff|:", np.median(np.absolute(X)))
print("p90 |coeff|:", np.quantile(np.absolute(X), 0.90))

median |coeff|: 0.2966558188199997
p90 |coeff|: 3.440374636650086
