In [None]:
# spectra_set_arrow.py
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
import numpy as np
from typing import Any, Dict, Sequence, Optional, Union, Literal, Tuple
from sklearn.base import TransformerMixin

AugPolicy   = Literal["all", "original", "augmented"]
FeatShape   = Literal["concatenate", "interlace", "2d", "transpose2d"]

class SpectraSet:
    """
    A minimal Arrow + Polars replacement for the previous xarray SpectraSet.
    All public methods keep (almost) the same signature.
    """
    # ------------------------------------------------------------------ #
    # constructor / helpers                                              #
    # ------------------------------------------------------------------ #
    def __init__(self, meta: pl.DataFrame, spectra_tbl: pa.Table):
        self.meta   = meta      # every row = one observation
        self.tbl    = spectra_tbl    # column "spec" list<float32>

    # >>> factory for toy examples ------------------------------------- #
    @classmethod
    def from_dict(cls,
                  spectra: Dict[str, Sequence[Sequence[Sequence[float]]]],
                  # spectra[source][sample][augmentation] = 1-D list/array
                  target: Optional[Sequence[Any]] = None):
        rows_meta = []
        spec_buf  = []
        obs_id    = 0
        for src, samples in spectra.items():
            for sid, aug_list in enumerate(samples):
                for aug_id, vec in enumerate(aug_list):
                    spec_buf.append(np.asarray(vec, dtype=np.float32))
                    rows_meta.append({
                        "obs_id":       obs_id,
                        "sample_id":    sid,
                        "source":       src,
                        "processing":   "raw",
                        "augmentation_id": aug_id,
                        "branch":       None,
                        "split":        None,
                        "fold_id":      None,
                        "row_idx":      obs_id
                    })
                    obs_id += 1
        meta = pl.DataFrame(rows_meta)
        tbl  = pa.Table.from_pydict({"spec": pa.array(spec_buf, type=pa.list_(pa.float32()))})
        if target is not None:
            meta = meta.with_columns(pl.Series("target", list(target))[meta["sample_id"]])
        return cls(meta, tbl)

    # internal: return a filtered meta dataframe ---------------------- #
    def _filter(self,
                *,
                split:  Optional[Union[str, Sequence[str]]] = None,
                fold_id:Optional[Union[int, Sequence[int]]] = None,
                groups: Optional[Dict[str, Any]]           = None,
                branch: Optional[Union[str, Sequence[str]]] = None,
                augment:AugPolicy = "all") -> pl.DataFrame:

        m = self.meta.lazy()

        def _isin(col, val):
            return pl.col(col).is_in(val if isinstance(val, (list, tuple, np.ndarray)) else [val])

        if split is not None:     m = m.filter(_isin("split", split))
        if fold_id is not None:   m = m.filter(_isin("fold_id", fold_id))
        if branch is not None:    m = m.filter(_isin("branch", branch))
        if groups:
            for gname, v in groups.items():
                m = m.filter(_isin(f"group_{gname}", v))
        if augment != "all":
            if augment == "original":
                m = m.filter(pl.col("augmentation_id") == 0)
            else:  # "augmented"
                m = m.filter(pl.col("augmentation_id") > 0)

        return m.collect()

    # internal: Arrow take -------------------------------------------- #
    def _take_specs(self, row_idx: np.ndarray) -> list[np.ndarray]:
        """
        Given Arrow row indices, return a Python list of NumPy 1-D float32
        vectors (one per observation, in the same order).
        """
        arr = pc.take(self.tbl.column("spec"), pa.array(row_idx))
        # arr is a ListArray; arr.to_pylist() -> list[list[float]]
        return [np.asarray(v, dtype=np.float32) for v in arr.to_pylist()]

    # ------------------------------------------------------------------ #
    # public getters                                                     #
    # ------------------------------------------------------------------ #
    def X(self, *, split=None, fold_id=None, groups=None, branch=None,
           sources: Optional[Union[bool, str, Sequence[str]]] = None,
           augment: AugPolicy = "all",
           feature_shape: FeatShape = "concatenate") -> np.ndarray:

        meta = self._filter(split=split, fold_id=fold_id, groups=groups,
                            branch=branch, augment=augment)

        if meta.is_empty():
            return np.empty((0, 0), dtype=np.float32)

        # --- choose sources ----------------------------------------- #
        if sources is True or sources is None:
            kept = meta
        elif sources is False:
            return np.empty((meta.height, 0), dtype=np.float32)
        elif isinstance(sources, str):
            kept = meta.filter(pl.col("source") == sources)
        else:
            kept = meta.filter(pl.col("source").is_in(list(sources)))

        # --- fetch spectra ------------------------------------------ #
        vecs = self._take_specs(kept["row_idx"].to_numpy())

        # concatenate/interlace/reshape per feature_shape ------------- #
        if feature_shape == "concatenate":
            # 1. quelles colonnes définissent une 'clé observation' ?
            key_cols = ["sample_id", "processing", "augmentation_id",
                        "branch", "split", "fold_id"]
            keys = kept.select(key_cols)

            # 2. ordre stable des sources  & taille max par source
            srcs_order = sorted(set(kept["source"]))
            max_feat = {s: max(len(v) for v, s2 in zip(vecs, kept["source"])
                               if s2 == s) for s in srcs_order}

            # 3. dictionnaire {key_tuple -> dict(source -> vector)}
            rows: Dict[Tuple, Dict[str, np.ndarray]] = {}
            for key_vals, src, vec in zip(keys.iter_rows(), kept["source"], vecs):
                k = tuple(key_vals)
                rows.setdefault(k,
                                {s: np.full(max_feat[s], np.nan, np.float32)
                                 for s in srcs_order})
                rows[k][src][:len(vec)] = vec

            # 4. concatène chaque source dans l’ordre & empile les lignes
            matrix = np.vstack([
                np.hstack([rows[k][s] for s in srcs_order])
                for k in sorted(rows)            # ordre déterministe
            ])
            return matrix
        else:
            # Build 3-D tensor [obs, source, feat] first
            srcs   = kept["source"].to_list()
            uniq   = {s: i for i, s in enumerate(sorted(set(srcs)))}
            max_f  = max(len(v) for v in vecs)
            tensor = np.full((kept.height, len(uniq), max_f), np.nan,
                             dtype=np.float32)
            for i, (s, v) in enumerate(zip(srcs, vecs)):
                tensor[i, uniq[s], :len(v)] = v
            if feature_shape == "2d":
                return tensor
            if feature_shape == "transpose2d":
                return np.transpose(tensor, (0, 2, 1))
            # "interlace"
            return tensor.reshape(tensor.shape[0], -1)

    def X_with_labels(self, **kw) -> Tuple[np.ndarray, np.ndarray]:
        meta   = self._filter(**{k: kw[k] for k in kw if k in
                                 {"split", "fold_id", "groups", "branch", "augment"}})
        labels = meta["obs_id"].to_numpy()
        return self.X(**kw), labels

    def y(self, *, encode_labels=False, **kw) -> np.ndarray:
        meta = self._filter(**kw)
        if "target" not in meta.columns:
            return np.array([])
        y = meta["target"].to_numpy()
        if encode_labels and y.dtype.kind in "UO":
            from sklearn.preprocessing import LabelEncoder
            enc = LabelEncoder()
            y   = enc.fit_transform(y)
        return y

    # ------------------------------------------------------------------ #
    # mutators                                                           #
    # ------------------------------------------------------------------ #
    def add_split(self, labels: Sequence[Any]):
        self.meta = self.meta.with_columns(split=pl.Series("split", labels)[self.meta["sample_id"]])

    def add_folds(self, labels: Sequence[Any]):
        self.meta = self.meta.with_columns(fold_id=pl.Series(labels)[self.meta["sample_id"]])

    def add_groups(self, name: str, labels: Sequence[Any]):
        self.meta = self.meta.with_columns(pl.Series(f"group_{name}", labels)[self.meta["sample_id"]])

    # --- add_processing = sample-level augmentation ------------------- #
    def augment_samples(self,
                        new_samples: Dict[str, Sequence[np.ndarray]],
                        parent_obs: Sequence[int],
                        name: str = "aug",
                        branch: Any = None):
        """
        Duplicate observations referenced by `parent_obs`, replace their spectra
        with `new_samples[src][i]`, push into Arrow table, extend meta.
        """
        if not new_samples:
            raise ValueError("new_samples cannot be empty.")
        n_new = len(parent_obs)
        # 1/ build Arrow column
        new_vecs = []
        rows     = []
        next_rid = int(self.tbl.num_rows)
        for i, obs in enumerate(parent_obs):
            parent = self.meta.row(obs)
            for src, vec_list in new_samples.items():
                vec = np.asarray(vec_list[i], dtype=np.float32)
                new_vecs.append(vec)
                rows.append({
                    **parent,
                    "obs_id":        int(self.meta["obs_id"].max()) + 1 + i,
                    "processing":    name,
                    "augmentation_id": parent["augmentation_id"] + 1,
                    "branch":        branch,
                    "row_idx":       next_rid
                })
                next_rid += 1
        self.tbl  = pa.concat_tables([self.tbl,
                                      pa.Table.from_pydict({"spec":
                                      pa.array(new_vecs,
                                               type=pa.list_(pa.float32()))})])
        self.meta = pl.concat([self.meta, pl.DataFrame(rows)])

    # --- add_features -------------------------------------------------- #
    def add_features(self, source: str, new_feat: np.ndarray,
                     feature_names: Optional[Sequence[str]] = None):
        """
        Append features horizontally for *all* observations of a source.
        new_feat shape = (n_obs_source, n_new_feat)
        """
        mask = self.meta["source"] == source
        row_idx = self.meta.filter(mask)["row_idx"].to_numpy()
        old_vecs = self._take_specs(row_idx)
        if new_feat.shape[0] != len(old_vecs):
            raise ValueError("new_feat rows mismatch.")
        merged = [np.hstack([o, n.astype(np.float32)]) for o, n in
                  zip(old_vecs, new_feat)]
        col = pa.array(merged, type=pa.list_(pa.float32()))
        self.tbl = self.tbl.set_column(0, "spec",
                   pc.if_else(pc.is_in(pa.array(row_idx),
                                       value_set=pa.array(row_idx)),
                              col, self.tbl.column("spec")))

    # ------------------------------------------------------------------ #
    # predictions                                                        #
    # ------------------------------------------------------------------ #
    def add_prediction(self, model_id: str, y_pred: Sequence[float], **kw):
        meta = self._filter(**kw)
        if len(y_pred) != meta.height:
            raise ValueError("y_pred length mismatch after filtering.")
        pred_col = np.full(self.meta.height, np.nan, dtype=float)
        pred_col[meta["obs_id"].to_numpy()] = y_pred
        self.meta = self.meta.with_columns(
            pl.Series(f"pred_{model_id}", pred_col))

# ---------------------------------------------------------------------- #
# Fix: Each sample should be a list of augmentations (even if only one)
toy = {
    "nirs":  [ [[1,2,3]], [[1.1,2.1,3.1]], [[4,5,6]] ],
    "raman": [ [[10,11]], [[10.1,11.1]], [[40,41]] ]
}
ss = SpectraSet.from_dict(toy, target=[0, 1, 2])
print(ss.meta)
print(ss.X(sources="raman"))
print(ss.y(sources="raman"))



shape: (6, 10)
┌────────┬───────────┬────────┬────────────┬───┬───────┬─────────┬─────────┬────────┐
│ obs_id ┆ sample_id ┆ source ┆ processing ┆ … ┆ split ┆ fold_id ┆ row_idx ┆ target │
│ ---    ┆ ---       ┆ ---    ┆ ---        ┆   ┆ ---   ┆ ---     ┆ ---     ┆ ---    │
│ i64    ┆ i64       ┆ str    ┆ str        ┆   ┆ null  ┆ null    ┆ i64     ┆ i64    │
╞════════╪═══════════╪════════╪════════════╪═══╪═══════╪═════════╪═════════╪════════╡
│ 0      ┆ 0         ┆ nirs   ┆ raw        ┆ … ┆ null  ┆ null    ┆ 0       ┆ 0      │
│ 1      ┆ 1         ┆ nirs   ┆ raw        ┆ … ┆ null  ┆ null    ┆ 1       ┆ 1      │
│ 2      ┆ 2         ┆ nirs   ┆ raw        ┆ … ┆ null  ┆ null    ┆ 2       ┆ 2      │
│ 3      ┆ 0         ┆ raman  ┆ raw        ┆ … ┆ null  ┆ null    ┆ 3       ┆ 0      │
│ 4      ┆ 1         ┆ raman  ┆ raw        ┆ … ┆ null  ┆ null    ┆ 4       ┆ 1      │
│ 5      ┆ 2         ┆ raman  ┆ raw        ┆ … ┆ null  ┆ null    ┆ 5       ┆ 2      │
└────────┴───────────┴────────┴────────

TypeError: SpectraSet._filter() got an unexpected keyword argument 'sources'

In [72]:
from __future__ import annotations

import numpy as np
import polars as pl
from typing import Any, Sequence, Union

class SpectraDataset:
    """
    Dataset ultra-rapide pour ML basé sur polars/Arrow.
    
    Chaque spectre est stocké dans une colonne 'spectrum' de type List[Float64],
    et un identifiant unique 'row_id' est assigné à chaque entrée.
    Les index par défaut sont gérés dynamiquement, avec broadcast ou vecteur dédié.
    """

    DEFAULT_INDEX = (
        "origin", "sample", "type", "set",
        "processing", "augmentation", "branch"
    )
    _DEFAULT_VALUES: dict[str, Any] = {
        "set":        "train",
        "processing": "raw",
        "augmentation": "raw",
        "branch":       0,
    }

    def __init__(self) -> None:
        self.df: pl.DataFrame | None = None
        self._next_id: int = 0

    def _mask(self, **filters: Any) -> pl.Expr:
        exprs = []
        for key, val in filters.items():
            if isinstance(val, (list, tuple, set, np.ndarray)):
                exprs.append(pl.col(key).is_in(val))
            else:
                exprs.append(pl.col(key) == val)
        return exprs[0] if len(exprs) == 1 else pl.all_horizontal(exprs)

    def _select(self, **filters: Any) -> pl.DataFrame:
        if self.df is None:
            return pl.DataFrame()
        return self.df.filter(self._mask(**filters)) if filters else self.df

    def add_spectra(
        self,
        spectra: Sequence[Sequence[float]],
        target: Sequence[Any] | None = None,
        **index_values: Union[Any, Sequence[Any]],
    ) -> None:
        n = len(spectra)
        tgt = target if target is not None else [None] * n
        if len(tgt) != n:
            raise ValueError("La longueur de 'target' ne correspond pas au nombre de spectres")

        data: dict[str, list[Any]] = {}
        for k in self.DEFAULT_INDEX:
            if k in index_values:
                v = index_values[k]
                if isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
                    if len(v) != n:
                        raise ValueError(f"Index '{k}' longueur {len(v)} != {n}")
                    data[k] = list(v)
                else:
                    data[k] = [v] * n
            else:
                default = self._DEFAULT_VALUES.get(k)
                data[k] = [default] * n

        data["spectrum"] = [list(vec) for vec in spectra]
        data["target"]   = list(tgt)
        data["row_id"]   = list(range(self._next_id, self._next_id + n))
        self._next_id += n

        new_df = pl.DataFrame(data)
        if self.df is None:
            self.df = new_df
        else:
            self.df = pl.concat([self.df, new_df], how="vertical")

    def change_spectra(
        self,
        new_spectra: Sequence[Sequence[float]],
        **filters: Any,
    ) -> None:
        if self.df is None:
            return
        mask = self._mask(**filters)
        n = self.df.filter(mask).height
        if n != len(new_spectra):
            raise ValueError("new_spectra count mismatch")
        self.df = self.df.with_columns(
            pl.when(mask)
              .then(pl.Series(new_spectra, dtype=pl.List(pl.Float64)))
              .otherwise(pl.col("spectrum"))
              .alias("spectrum")
        )

    def add_tag(
        self,
        tag_name: str,
        tag_values: Union[Any, Sequence[Any]],
        **filters: Any,
    ) -> None:
        """
        Ajoute une nouvelle colonne `tag_name`.
        - Si `tag_values` scalaire → broadcast sur toutes les lignes filtrées.
        - Si `tag_values` séquence de longueur m → assigné séquentiellement aux m lignes filtrées.
        """
        if self.df is None:
            raise ValueError("Dataset vide.")
        if tag_name in self.df.columns:
            raise ValueError(f"Le tag '{tag_name}' existe déjà ; utilisez set_tag.")

        df = self.df
        mask_series = df.select(self._mask(**filters).alias("_mask"))
        mask = mask_series.to_series()
        total = len(mask)
        # Construire liste complète
        full: list[Any] = [None] * total
        if isinstance(tag_values, Sequence) and not isinstance(tag_values, (str, bytes)):
            values = list(tag_values)
            if sum(mask) != len(values):
                raise ValueError("Le nombre de 'tag_values' ne correspond pas aux lignes filtrées")
            it = iter(values)
            for i, m in enumerate(mask):
                if m:
                    full[i] = next(it)
        else:
            for i, m in enumerate(mask):
                if m:
                    full[i] = tag_values
        self.df = df.with_columns(pl.Series(tag_name, full))

    def set_tag(
        self,
        tag_name: str,
        new_value: Union[Any, Sequence[Any]],
        **filters: Any,
    ) -> None:
        """
        Change les valeurs d'un tag existant.
        - Si `new_value` scalaire → broadcast.
        - Si `new_value` séquence de longueur m → assigné séquentiellement aux m lignes filtrées.
        """
        if self.df is None:
            raise ValueError("Dataset vide.")
        if tag_name not in self.df.columns:
            raise KeyError(f"Tag '{tag_name}' inexistant.")

        df = self.df
        mask_series = df.select(self._mask(**filters).alias("_mask")).to_series()
        total = len(mask_series)
        full = df.get_column(tag_name).to_list()
        if isinstance(new_value, Sequence) and not isinstance(new_value, (str, bytes)):
            vals = list(new_value)
            if sum(mask_series) != len(vals):
                raise ValueError("Le nombre de 'new_value' ne correspond pas aux lignes filtrées")
            it = iter(vals)
            for i, m in enumerate(mask_series):
                if m:
                    full[i] = next(it)
        else:
            for i, m in enumerate(mask_series):
                if m:
                    full[i] = new_value
        self.df = df.with_columns(pl.Series(tag_name, full))

    def X(
        self,
        pad: bool = False,
        pad_value: float = np.nan,
        as_arrow: bool = False,
        return_ids: bool = True,
        **filters: Any,
    ) -> Union[np.ndarray, tuple[Union[np.ndarray, "pyarrow.ListArray"], np.ndarray]]:
        sub = self._select(**filters)
        spectra = sub["spectrum"].to_list()

        if as_arrow and not pad:
            import pyarrow as pa
            arr = pa.array(spectra, type=pa.list_(pa.float64()))
            ids = sub["row_id"].to_numpy() if return_ids else None
            return (arr, ids) if return_ids else arr

        if pad:
            max_len = max((len(v) for v in spectra), default=0)
            out = np.full((len(spectra), max_len), pad_value, dtype=np.float64)
            for i, vec in enumerate(spectra):
                out[i, : len(vec)] = vec
        else:
            out = np.array(spectra, dtype=object)

        if return_ids:
            ids = sub["row_id"].to_numpy()
            return out, ids
        return out

    def y(self, **filters: Any) -> np.ndarray:
        sub = self._select(**filters)
        return sub["target"].to_numpy()

    def to_arrow_table(self) -> "pyarrow.Table":
        return self.df.to_arrow()  # type: ignore

    def __len__(self) -> int:
        return self.df.height if self.df is not None else 0

    def __repr__(self) -> str:
        cols = self.df.columns if self.df is not None else []
        return f"SpectraDataset(n={len(self)}, cols={cols})"


In [74]:
ds = SpectraDataset()

# ajout de 2 spectres
ds.add_spectra(
    spectra=[[0.1, 0.2, 0.3], [1.0, 1.1]],
    target=[0.25, 1.05],
    origin=1,
    sample=[1, 2],
    type="nirs"
)

# récupérer X/y côté train
X, ids = ds.X(pad=True)  # pad à la longueur max
y = ds.y()

# taguer un groupe de spectres
ds.add_tag("groupe", ["A", "B"], type="nirs")

c = ds._select().to_numpy()
a, b = ds.X(pad=False)
t = ds.y()
for i in range(len(a)):
    print(a[i], b[i], t[i])
    print(c[i])




[0.1, 0.2, 0.3] 0 0.25
[1 1 'nirs' 'train' 'raw' 'raw' 0 array([0.1, 0.2, 0.3]) 0.25 0 'A']
[1.0, 1.1] 1 1.05
[1 2 'nirs' 'train' 'raw' 'raw' 0 array([1. , 1.1]) 1.05 1 'B']
