In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline

In [3]:
from astromodal.config import load_config
from tqdm import tqdm
import polars as pl
import random
from pathlib import Path
from astromodal.datasets.datacubes import load_datacube_files

In [4]:
config = load_config("/home/schwarz/projetoFM/config.yaml")

hdd_folder = config['hdd_folder']

hddfolder = Path(config["hdd_folder"]) / "image_latents" 

In [5]:
splus_bands = [
    "u", "i", "r", "g", "z",
    "j0378", "j0395", "j0410", "j0430",
    "j0515", "j0660", "j0861",
]

In [6]:
# mandatory columns (must exist in all files)
CORE_COLUMNS = ["id", "ra", "dec"]

# your scalar features
SCALAR_COLUMNS = [
    "ellipticity_det",
    "elongation_det",
    "a_pixel_det",
    "b_pixel_det",
    "theta_det",
    "fwhm_n_det",
    *[f"mag_pstotal_{b}" for b in splus_bands],
    *[f"err_mag_pstotal_{b}" for b in splus_bands],
    "gaia_parallax",
    "gaia_parallax_error",
    "gaia_pmra",
    "gaia_pmdec",
    "gaia_pmra_error",
    "gaia_pmdec_error",
    "gaia_phot_bp_mean_flux",
    "gaia_phot_rp_mean_flux",
    "gaia_phot_g_mean_flux",
    "gaia_phot_bp_mean_flux_error",
    "gaia_phot_rp_mean_flux_error",
    "gaia_phot_g_mean_flux_error",
    "gaia_teff_gspphot",
    "gaia_logg_gspphot",
    "gaia_mh_gspphot",
    "specz_z",
    "specz_e_z",
    "vista_yapermag6",
    "vista_yapermag6err",
    "vista_japermag6",
    "vista_japermag6err",
    "vista_hapermag6",
    "vista_hapermag6err",
    "vista_ksapermag6",
    "vista_ksapermag6err",
]

# final schema (order matters for ML)
EXPECTED_COLUMNS = CORE_COLUMNS + SCALAR_COLUMNS

In [7]:
import polars as pl
from pathlib import Path
from typing import Sequence

def read_parquet_with_schema(
    path: str | Path,
    *,
    expected_columns: Sequence[str],
) -> pl.DataFrame:
    """
    Reads a parquet file and guarantees that all expected_columns exist.
    Missing columns are added as nulls.
    """
    path = Path(path)

    # inspect schema without loading data
    schema = pl.read_parquet_schema(path)
    available = set(schema.keys())

    # only read columns that exist
    cols_to_read = [c for c in expected_columns if c in available]
    df = pl.read_parquet(path, columns=cols_to_read, use_pyarrow=True)

    # add missing columns as null
    missing = [c for c in expected_columns if c not in df.columns]
    if missing:
        df = df.with_columns([pl.lit(None).alias(c) for c in missing])

    # reorder to canonical schema
    return df.select(expected_columns)

from pathlib import Path
import polars as pl
from typing import Sequence
from tqdm import tqdm


def load_datacubes_from_filelist(
    files: Sequence[str | Path],
    *,
    expected_columns: Sequence[str],
    desc: str = "Loading datacubes",
) -> pl.DataFrame:
    """
    Load multiple parquet files given explicitly as a list.
    Guarantees that ALL expected_columns exist in the final DataFrame.

    Parameters
    ----------
    files : list[str | Path]
        List of parquet files to read.
    expected_columns : list[str]
        Canonical schema (order is preserved).
    desc : str
        tqdm description.

    Returns
    -------
    pl.DataFrame
        Concatenated DataFrame with guaranteed schema.
    """
    dfs = []

    for f in tqdm(files, desc=desc):
        f = Path(f)
        try:
            # inspect schema first (cheap)
            schema = pl.read_parquet_schema(f)
            available = set(schema.keys())

            cols_to_read = [c for c in expected_columns if c in available]
            df = pl.read_parquet(f, columns=cols_to_read, use_pyarrow=True)

            # add missing columns as null
            missing = [c for c in expected_columns if c not in df.columns]
            if missing:
                df = df.with_columns([pl.lit(None).alias(c) for c in missing])

            # enforce canonical order
            df = df.select(expected_columns)

            dfs.append(df)

        except Exception as e:
            print(f"[skip] {f}: {e}")

    if not dfs:
        raise RuntimeError("No valid parquet files were loaded.")

    return pl.concat(dfs, how="vertical", rechunk=False)

In [8]:
file = config['datacubes_paths'].replace('*', 'STRIPE82-0002')

import polars as pl

header = pl.read_parquet(file, n_rows=0)
columns = [col for col in header.columns if 'gaiaxp' in col] + ["id", "mag_psf_r"]

In [9]:
train_files, val_files = load_datacube_files(config['datacubes_paths'], train_val_split=0.9, nfiles_subsample=500)

[info] - Found 2444 datacube files
[info] - Subsampled to 500 files
[info] - Training files: 450
[info] - Validation files: 50


In [10]:
train_df = load_datacubes_from_filelist(
    train_files,
    expected_columns=EXPECTED_COLUMNS,
)

train_df = train_df.filter(pl.col("mag_pstotal_r") < 21)

val_df = load_datacubes_from_filelist(
    val_files,
    expected_columns=EXPECTED_COLUMNS,
)

val_df = val_df.filter(pl.col("mag_pstotal_r") < 21)

Loading datacubes:   0%|          | 0/450 [00:00<?, ?it/s]

Loading datacubes: 100%|██████████| 450/450 [01:20<00:00,  5.58it/s]
Loading datacubes: 100%|██████████| 50/50 [00:12<00:00,  3.88it/s]


In [12]:
def count_non_null_per_column(df: pl.DataFrame) -> pl.DataFrame:
    """
    Returns a table with:
    - column name
    - number of non-null values
    """
    return df.select([
        pl.count().alias("n_rows"),
        *[
            pl.col(c).count().alias(c)
            for c in df.columns
        ]
    ]).transpose(
        include_header=True,
        header_name="column",
        column_names=["non_null_count"],
    ).filter(pl.col("column") != "n_rows")

In [13]:
nn = count_non_null_per_column(train_df)
nn.sort("non_null_count")

(Deprecated in version 0.20.5)
  pl.count().alias("n_rows"),


column,non_null_count
str,u32
"""specz_e_z""",299527
"""specz_z""",360786
"""vista_hapermag6""",1351706
"""vista_hapermag6err""",1351706
"""gaia_teff_gspphot""",2161068
…,…
"""b_pixel_det""",7259195
"""theta_det""",7259195
"""fwhm_n_det""",7259195
"""mag_pstotal_r""",7259195


In [29]:
def column_coverage(df: pl.DataFrame) -> pl.DataFrame:
    n = df.height
    return pl.DataFrame({
        "column": df.columns,
        "non_null": [df.select(pl.col(c).count()).item() for c in df.columns],
    }).with_columns(
        (pl.col("non_null") / n).alias("fraction")
    ).sort("fraction")

In [33]:
coverage = column_coverage(df)
coverage.write_csv("column_coverage.csv")

In [14]:
from astromodal.scalers.scaler1d import StandardScaler1D

In [None]:
from astromodal.scalers.scaler1d import StandardScaler1D
from astromodal.tokenizers.rvq import ResidualVQ
from astromodal.tokenizers.spectralrvq import SpectralResidualVQ

for col in SCALAR_COLUMNS[3:]:
    print(col)
    values = train_df.select(pl.col(col).drop_nulls()).to_series().to_numpy()

    mean = values.mean()
    std = values.std()

    scaler = StandardScaler1D(
        mean=mean,
        std=std,
    )

    val_values = val_df.select(pl.col(col).drop_nulls()).to_series().to_numpy()
    val_mean = val_values.mean()
    val_std = val_values.std()

    print(f"  train: mean={mean:.4f}, std={std:.4f}")
    print(f"  val:   mean={val_mean:.4f}, std={val_std:.4f}")




b_pixel_det
  train: mean=1.6234, std=0.8241
  val:   mean=1.6207, std=0.7945
theta_det
  train: mean=-3.9161, std=52.2860
  val:   mean=-0.4612, std=52.6731
fwhm_n_det
  train: mean=1.6279, std=1.8988
  val:   mean=1.5703, std=1.8145
mag_pstotal_u
  train: mean=21.2817, std=2.0948
  val:   mean=21.2497, std=2.0544
mag_pstotal_i
  train: mean=18.7422, std=1.7258
  val:   mean=18.7187, std=1.7152
mag_pstotal_r
  train: mean=19.2012, std=1.7691
  val:   mean=19.1764, std=1.7516
mag_pstotal_g
  train: mean=20.1524, std=1.9731
  val:   mean=20.1448, std=1.9521
mag_pstotal_z
  train: mean=18.4973, std=1.7116
  val:   mean=18.4764, std=1.7028
mag_pstotal_j0378
  train: mean=21.1070, std=2.0691
  val:   mean=21.1009, std=2.0381
mag_pstotal_j0395
  train: mean=20.8347, std=2.0272
  val:   mean=20.8651, std=2.0013
mag_pstotal_j0410
  train: mean=20.5400, std=2.1167
  val:   mean=20.5864, std=2.0987
mag_pstotal_j0430
  train: mean=20.4869, std=2.1148
  val:   mean=20.5491, std=2.0948
mag_pstotal

  val_mean = val_values.mean()
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


  train: mean=17.1637, std=1.6277
  val:   mean=17.2050, std=1.6329
vista_japermag6err
  train: mean=0.0943, std=7.1069
  val:   mean=0.0827, std=0.8039
vista_hapermag6
  train: mean=16.5750, std=1.5731
  val:   mean=16.5888, std=1.6005
vista_hapermag6err
  train: mean=0.1701, std=49.8035
  val:   mean=0.1007, std=0.6168
vista_ksapermag6
  train: mean=16.1005, std=1.4546
  val:   mean=16.2140, std=1.4916
vista_ksapermag6err
  train: mean=0.1702, std=4.3447
  val:   mean=0.1676, std=10.1178


In [30]:
"""
Fit + save (1) a StandardScaler1D and (2) a 1D SpectralPatchRVQ tokenizer
for EACH scalar column (including error columns), handling missing columns
gracefully.

Notes:
- For scalar columns we treat each object value as a length-1 sequence: [B, L=1, C=1].
- The tokenizer is therefore RVQ(dim=1) + SpectralPatchRVQ(patch_size=1, channels=1).
- If a column is missing (or has too few finite values), we still save a "default"
  scaler (mean=0,std=1) and *skip* tokenizer training (or you can train anyway on empties).
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Dict, Any, Tuple, List

import numpy as np
import polars as pl
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from astromodal.scalers.scaler1d import StandardScaler1D
from astromodal.tokenizers.rvq import ResidualVQ
from astromodal.tokenizers.spectralrvq import SpectralPatchRVQ


# -------------------------
# config / paths
# -------------------------

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE_T = torch.float32

OUT_ROOT = Path("outputs/scalars_tokenizers")  # <-- change to your config["models_folder"]/...
SCALERS_DIR = OUT_ROOT / "scalers"
TOK_DIR = OUT_ROOT / "tokenizers"
SCALERS_DIR.mkdir(parents=True, exist_ok=True)
TOK_DIR.mkdir(parents=True, exist_ok=True)

# RVQ/tokenizer hyperparams per scalar column
CODEBOOK_SIZE = 1024
NUM_STAGES = 3
DECAY = 0.99

# training
EPOCHS = 30
BATCH_SIZE = 4096
EMA_EPOCHS = 10  # freeze EMA after this (stability)
MIN_FINITE = 200  # skip training if too few values


# -------------------------
# columns
# -------------------------

CORE_COLUMNS = ["id", "ra", "dec"]

# you already define splus_bands somewhere above
# splus_bands = [...]

SCALAR_COLUMNS = [
    "ellipticity_det",
    "elongation_det",
    "a_pixel_det",
    "b_pixel_det",
    "theta_det",
    "fwhm_n_det",
    *[f"mag_pstotal_{b}" for b in splus_bands],
    *[f"err_mag_pstotal_{b}" for b in splus_bands],
    "gaia_parallax",
    "gaia_parallax_error",
    "gaia_pmra",
    "gaia_pmdec",
    "gaia_pmra_error",
    "gaia_pmdec_error",
    "gaia_phot_bp_mean_flux",
    "gaia_phot_rp_mean_flux",
    "gaia_phot_g_mean_flux",
    "gaia_phot_bp_mean_flux_error",
    "gaia_phot_rp_mean_flux_error",
    "gaia_phot_g_mean_flux_error",
    "gaia_teff_gspphot",
    "gaia_logg_gspphot",
    "gaia_mh_gspphot",
    "specz_z",
    "specz_e_z",
    "vista_yapermag6",
    "vista_yapermag6err",
    "vista_japermag6",
    "vista_japermag6err",
    "vista_hapermag6",
    "vista_hapermag6err",
    "vista_ksapermag6",
    "vista_ksapermag6err",
]

EXPECTED_COLUMNS = CORE_COLUMNS + SCALAR_COLUMNS


# -------------------------
# helpers
# -------------------------

def _finite_values_from_df(df: pl.DataFrame, col: str) -> np.ndarray:
    """
    Returns finite float64 values from df[col], empty if col missing.
    """
    if col not in df.columns:
        return np.array([], dtype=np.float64)

    s = df.select(pl.col(col)).to_series()
    # Convert to numpy + filter finite
    v = s.to_numpy()
    v = v.astype(np.float64, copy=False)
    v = v[np.isfinite(v)]
    return v


def fit_scaler_1d(
    train_df: pl.DataFrame,
    col: str,
    *,
    transform: str = "none",      # optionally "asinh"
    asinh_scale: float = 1.0,
    clip_quantile: Optional[float] = None,  # e.g. 0.999
) -> StandardScaler1D:
    """
    Fit StandardScaler1D pooling all finite values from the column.
    """
    v = _finite_values_from_df(train_df, col)
    if v.size == 0:
        return StandardScaler1D(mean=0.0, std=1.0, transform=transform, asinh_scale=asinh_scale)

    if transform == "asinh":
        s0 = asinh_scale if asinh_scale > 0 else 1.0
        v = np.arcsinh(v / s0)

    if clip_quantile is not None and 0.0 < clip_quantile < 1.0 and v.size > 10:
        lo = np.quantile(v, 1.0 - clip_quantile)
        hi = np.quantile(v, clip_quantile)
        v = np.clip(v, lo, hi)

    mean = float(np.mean(v))
    std = float(np.std(v))
    if not np.isfinite(std) or std < 1e-12:
        std = 1.0

    return StandardScaler1D(mean=mean, std=std, transform=transform, asinh_scale=asinh_scale)


def make_scalar_loader(
    df: pl.DataFrame,
    col: str,
    scaler: StandardScaler1D,
    *,
    batch_size: int,
    shuffle: bool,
) -> DataLoader:
    """
    Returns DataLoader yielding x_norm: [B, L=1, C=1].
    Missing/NaN values are dropped.
    """
    if col not in df.columns:
        # empty loader
        Xn = np.zeros((0, 1, 1), dtype=np.float32)
        return DataLoader(TensorDataset(torch.from_numpy(Xn)), batch_size=batch_size, shuffle=False, num_workers=0)

    # take finite values only
    v = _finite_values_from_df(df, col).astype(np.float32)  # [N]
    if v.size == 0:
        Xn = np.zeros((0, 1, 1), dtype=np.float32)
        return DataLoader(TensorDataset(torch.from_numpy(Xn)), batch_size=batch_size, shuffle=False, num_workers=0)

    # normalize
    vn = scaler.transform_x(v).astype(np.float32)  # StandardScaler1D works elementwise
    Xn = vn.reshape(-1, 1, 1)  # [N,1,1]

    ds = TensorDataset(torch.from_numpy(Xn))
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=0, pin_memory=True)


@torch.no_grad()
def eval_scalar_tok_unscaled_mse(
    tok: SpectralPatchRVQ,
    dl: DataLoader,
    scaler: StandardScaler1D,
    *,
    device: str,
) -> Dict[str, float]:
    """
    Evaluate unscaled MSE/RMSE in original units.
    Each batch x_norm is [B,1,1].
    """
    tok.eval().to(device)

    sum_mse = 0.0
    n = 0

    for (x_norm,) in dl:
        if x_norm.numel() == 0:
            continue

        x_norm = x_norm.to(device=device, dtype=DTYPE_T, non_blocking=True)  # [B,1,1]
        out = tok.encode(x_norm, update_ema=False)
        xq_norm = out["x_q"]  # [B,1,1]

        x_raw = scaler.inverse_transform_x(x_norm.detach().cpu().numpy().reshape(-1))   # [B]
        xq_raw = scaler.inverse_transform_x(xq_norm.detach().cpu().numpy().reshape(-1)) # [B]

        mse = float(np.mean((xq_raw - x_raw) ** 2))
        sum_mse += mse * x_norm.shape[0]
        n += x_norm.shape[0]

    mse = sum_mse / max(n, 1)
    rmse = float(np.sqrt(mse))
    return {
        "mse": float(mse),
        "rmse": float(rmse),
        "mse_x_1e6": float(mse * 1e6),
        "-log10(mse)": float(-np.log10(mse + 1e-300)),
    }

@torch.no_grad()
def compute_raw_metrics(x_norm, xq_norm, scaler):
    x_raw  = scaler.inverse_transform_x(x_norm)
    xq_raw = scaler.inverse_transform_x(xq_norm)
    diff2 = (xq_raw - x_raw) ** 2
    mse = float(diff2.mean())
    rmse = float(np.sqrt(mse))
    return mse, rmse

def train_scalar_tokenizer(
    train_dl: DataLoader,
    val_dl: DataLoader,
    *,
    codebook_size: int,
    num_stages: int,
    decay: float,
    epochs: int,
    ema_epochs: int,
    device: str,
    save_path: Path,
    scaler: StandardScaler1D,
) -> Tuple[SpectralPatchRVQ, List[Dict[str, Any]]]:
    """
    Train a scalar tokenizer (RVQ(dim=1) + SpectralPatchRVQ(patch_size=1, channels=1))
    and save best checkpoint by val(unscaled) MSE.
    """
    P = 1
    C = 1
    rvq = ResidualVQ(dim=P * C, num_stages=num_stages, codebook_size=codebook_size, decay=decay).to(device)
    tok = SpectralPatchRVQ(rvq=rvq, patch_size=P, channels=C).to(device)

    best_mse = float("inf")
    best_epoch = -1
    history: List[Dict[str, Any]] = []

    tok.train()
    for epoch in range(epochs):
        update_ema = (epoch < int(ema_epochs))
        tr = tok.train_epoch(train_dl, device=device, update_ema=update_ema)

        va = eval_scalar_tok_unscaled_mse(tok, val_dl, scaler, device=device)
        row = {"epoch": epoch, "train_loss_norm": float(tr["loss"]), "update_ema": bool(update_ema), **va}
        history.append(row)

        if va["mse"] < best_mse:
            best_mse = va["mse"]
            best_epoch = epoch
            tok.save(
                str(save_path),
                additional_info={
                    "best_epoch": int(best_epoch),
                    "best_val_mse": float(best_mse),
                    "codebook_size": int(codebook_size),
                    "num_stages": int(num_stages),
                    "decay": float(decay),
                    "patch_size": int(P),
                    "channels": int(C),
                },
            )

        print(
            f"Epoch {epoch:03d} | ema={int(update_ema)} | "
            f"train_loss(norm)={tr['loss']:.6g} | "
            f"val_mse={va['mse']:.6g} rmse={va['rmse']:.6g} | "
            f"best_mse={best_mse:.6g} @ {best_epoch}"
        )

    return tok, history


# -------------------------
# main: per-column fit + save scaler + tokenizer
# -------------------------

def fit_and_save_all_scalars(
    train_df: pl.DataFrame,
    val_df: pl.DataFrame,
    *,
    scalar_columns: List[str],
    scalers_dir: Path,
    tok_dir: Path,
    transform_for_errors: str = "asinh",  # often good for error-like heavy tails
    transform_for_values: str = "none",
    asinh_scale_default: float = 1.0,
    clip_quantile: Optional[float] = 0.999,
):
    results: List[Dict[str, Any]] = []

    for col in tqdm(scalar_columns, desc="Fitting scalers + tokenizers"):
        is_error_col = ("err_" in col) or col.endswith("error") or col.endswith("err")

        # choose scaler options
        transform = transform_for_errors if is_error_col else transform_for_values
        asinh_scale = asinh_scale_default

        # ---- fit scaler (train) ----
        scaler = fit_scaler_1d(
            train_df,
            col,
            transform=transform,
            asinh_scale=asinh_scale,
            clip_quantile=clip_quantile,
        )

        # ---- report train/val stats (raw, without transform) ----
        tr_vals = _finite_values_from_df(train_df, col)
        va_vals = _finite_values_from_df(val_df, col)

        tr_mean = float(np.mean(tr_vals)) if tr_vals.size else float("nan")
        tr_std = float(np.std(tr_vals)) if tr_vals.size else float("nan")
        va_mean = float(np.mean(va_vals)) if va_vals.size else float("nan")
        va_std = float(np.std(va_vals)) if va_vals.size else float("nan")

        print(f"\n[{col}]")
        print(f"  train(raw): mean={tr_mean:.6g}, std={tr_std:.6g}, n={tr_vals.size}")
        print(f"  val(raw):   mean={va_mean:.6g}, std={va_std:.6g}, n={va_vals.size}")
        print(f"  scaler: mean={scaler.mean:.6g}, std={scaler.std:.6g}, transform={scaler.transform}")

        # ---- save scaler ----
        scaler_path = scalers_dir / f"{col}.npz"
        scaler.save(scaler_path)

        # ---- if missing or too few points, skip tokenizer training but still record ----
        if tr_vals.size < MIN_FINITE or col not in train_df.columns:
            print(f"  -> skip tokenizer (too few finite values or missing). saved scaler only.")
            results.append({
                "col": col,
                "saved_scaler": str(scaler_path),
                "saved_tokenizer": None,
                "train_n_finite": int(tr_vals.size),
                "val_n_finite": int(va_vals.size),
                "skipped_tokenizer": True,
            })
            continue

        # ---- train tokenizer for this scalar ----
        train_dl = make_scalar_loader(train_df, col, scaler, batch_size=BATCH_SIZE, shuffle=True)
        val_dl = make_scalar_loader(val_df, col, scaler, batch_size=BATCH_SIZE, shuffle=False)

        tok_path = tok_dir / f"{col}.pt"
        tok, hist = train_scalar_tokenizer(
            train_dl,
            val_dl,
            codebook_size=CODEBOOK_SIZE,
            num_stages=NUM_STAGES,
            decay=DECAY,
            epochs=EPOCHS,
            ema_epochs=EMA_EPOCHS,
            device=DEVICE,
            save_path=tok_path,
            scaler=scaler,
        )

        results.append({
            "col": col,
            "saved_scaler": str(scaler_path),
            "saved_tokenizer": str(tok_path),
            "train_n_finite": int(tr_vals.size),
            "val_n_finite": int(va_vals.size),
            "skipped_tokenizer": False,
            "best_val_mse": float(min(h["mse"] for h in hist)) if hist else None,
        })

    return pl.DataFrame(results)


# -------------------------
# usage
# -------------------------
# train_df, val_df are your polars dataframes
# Make sure they already have the scalar columns as numeric (Float32/Float64).
#
# out_df = fit_and_save_all_scalars(train_df, val_df, scalar_columns=SCALAR_COLUMNS,
#                                  scalers_dir=SCALERS_DIR, tok_dir=TOK_DIR)
# out_df.write_parquet(OUT_ROOT / "scalar_fit_summary.parquet")
# print(out_df)

In [None]:
out_df = fit_and_save_all_scalars(train_df, val_df, scalar_columns=SCALAR_COLUMNS,
                                 scalers_dir=SCALERS_DIR, tok_dir=TOK_DIR)


Fitting scalers + tokenizers:   0%|          | 0/55 [00:00<?, ?it/s]


[ellipticity_det]
  train(raw): mean=0.159833, std=0.125802, n=7259195
  val(raw):   mean=0.156818, std=0.125913, n=905803
  scaler: mean=0.159777, std=0.125484, transform=none
[info] - Saved SpectralPatchRVQ to outputs/scalars_tokenizers/tokenizers/ellipticity_det.pt
Epoch 000 | ema=1 | train_loss(norm)=4.13459e-07 | val_mse=1.09266e-09 rmse=3.30555e-05 | best_mse=1.09266e-09 @ 0
[info] - Saved SpectralPatchRVQ to outputs/scalars_tokenizers/tokenizers/ellipticity_det.pt
Epoch 001 | ema=1 | train_loss(norm)=3.64264e-08 | val_mse=3.89674e-10 rmse=1.97402e-05 | best_mse=3.89674e-10 @ 1
[info] - Saved SpectralPatchRVQ to outputs/scalars_tokenizers/tokenizers/ellipticity_det.pt
Epoch 002 | ema=1 | train_loss(norm)=1.52262e-08 | val_mse=1.77457e-10 rmse=1.33213e-05 | best_mse=1.77457e-10 @ 2
[info] - Saved SpectralPatchRVQ to outputs/scalars_tokenizers/tokenizers/ellipticity_det.pt
Epoch 003 | ema=1 | train_loss(norm)=8.32845e-09 | val_mse=9.7587e-11 rmse=9.87862e-06 | best_mse=9.7587e-11 

Fitting scalers + tokenizers:   2%|▏         | 1/55 [08:50<7:57:35, 530.65s/it]

Epoch 029 | ema=0 | train_loss(norm)=1.27468e-09 | val_mse=2.1642e-11 rmse=4.6521e-06 | best_mse=2.1642e-11 @ 9

[elongation_det]
  train(raw): mean=1.24178, std=9.76633, n=7259195
  val(raw):   mean=1.23485, std=1.53701, n=905803
  scaler: mean=1.23145, std=0.305562, transform=none
[info] - Saved SpectralPatchRVQ to outputs/scalars_tokenizers/tokenizers/elongation_det.pt
Epoch 000 | ema=1 | train_loss(norm)=742.255 | val_mse=0.1359 rmse=0.368646 | best_mse=0.1359 @ 0
Epoch 001 | ema=1 | train_loss(norm)=835.897 | val_mse=0.151534 rmse=0.389273 | best_mse=0.1359 @ 0
