# K-means avec Random Forest

Ce notebook regroupe tout le code nécessaire pour entraîner une Random Forest à repérer les pixels du fond puis appliquer un K-means amélioré sur les structures d’images en discriminant les pixels du fond. 


### Imports et fonctions utilitaires
La cellule qui suit rassemble, en principe, toutes les bibliothèques nécessaires (NumPy, Pandas, scikit-learn, scikit-image…) ainsi que deux fonctions de compatibilité pour les différentes versions de scikit-image.

In [1]:
"""Core improved K-means segmentation primitives.
This cell is a verbatim port of `improved_kmeans_segmentation.py`
so that the notebook can run standalone without external modules.
"""
from __future__ import annotations

import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

import numpy as np
import pandas as pd
from PIL import Image
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.metrics import adjusted_rand_score
from sklearn.utils.validation import check_random_state
from skimage import exposure, morphology, util
from skimage.feature import hessian_matrix, structure_tensor

try:  # scikit-image <=0.22
    from skimage.feature import hessian_matrix_eigvals
except ImportError:  # scikit-image >=0.23
    from skimage.feature import hessian_matrix_eigenvalues as hessian_matrix_eigvals

try:
    from skimage.feature import structure_tensor_eigvals
except ImportError:
    from skimage.feature import structure_tensor_eigenvalues as structure_tensor_eigvals

from skimage.filters import gaussian, laplace, sobel
from skimage.segmentation import relabel_sequential


def _structure_tensor_eigenvalues(
    Axx: np.ndarray, Axy: np.ndarray, Ayy: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the structure-tensor eigenvalues for the three tensor components.

    The imported ``structure_tensor_eigvals`` is first called with the three
    components as separate positional arguments; if that raises ``TypeError``
    the same components are passed again packed into a single tuple.
    """
    components = (Axx, Axy, Ayy)
    try:
        eigenvalues = structure_tensor_eigvals(*components)
    except TypeError:
        # Signature mismatch: retry with the single-sequence calling convention.
        eigenvalues = structure_tensor_eigvals(components)
    return eigenvalues


def _hessian_matrix_eigenvalues(
    Hxx: np.ndarray, Hxy: np.ndarray, Hyy: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the Hessian eigenvalues for the three Hessian components.

    The imported ``hessian_matrix_eigvals`` is first called with the three
    components as separate positional arguments; if that raises ``TypeError``
    the same components are passed again packed into a single tuple.
    """
    components = (Hxx, Hxy, Hyy)
    try:
        eigenvalues = hessian_matrix_eigvals(*components)
    except TypeError:
        # Signature mismatch: retry with the single-sequence calling convention.
        eigenvalues = hessian_matrix_eigvals(components)
    return eigenvalues




## Data Loading Utilities
Cette section définit une fonction `load_png_stack` qui lit un dossier d’images PNG en niveaux de gris, les trie dans l’ordre numérique et les empile dans un tableau NumPy. On obtient d’un côté la liste des chemins (pratique pour tracer ou sauvegarder) et de l’autre un cube `nombre_de_slices × hauteur × largeur` qui servira de base au reste du pipeline.

In [2]:
def load_png_stack(
    img_dir: str | Path, limit: int | None = None
) -> Tuple[List[Path], np.ndarray]:
    """Load a stack of grayscale PNG slices sorted by numeric stem."""

    img_dir = Path(img_dir)
    pngs = sorted(
        img_dir.glob("*.png"),
        key=lambda f: int("".join(ch for ch in f.stem if ch.isdigit()) or 0),
    )
    if not pngs:
        raise FileNotFoundError(f"No PNG image found in {img_dir}")

    if limit is not None:
        pngs = pngs[:limit]

    shapes = {Image.open(p).size for p in pngs}
    if len(shapes) != 1:
        raise ValueError(f"Images with different shapes were detected: {shapes}")

    arrays: List[np.ndarray] = []
    for path in pngs:
        with Image.open(path) as img:
            arrays.append(np.array(img.convert("L"), dtype=np.uint8))

    stack = np.stack(arrays, axis=0)
    return pngs, stack

## Feature Engineering
Ici on transforme chaque image en un jeu de caractéristiques riche : intensité, flous gaussiens multi-échelles, gradients, Laplacien, tenseur de structure, Hessien, variance locale et coordonnées normalisées. Tout est standardisé pour que les futures étapes de clustering reçoivent une matrice `pixels × features` prête à l’emploi.

In [3]:
@dataclass
class FeatureBundle:
    """Per-pixel features computed for one image slice."""

    # Standardized feature matrix: one row per pixel, one column per feature.
    matrix: np.ndarray
    # Individual 2-D feature maps (not standardized), keyed by feature name.
    maps: Dict[str, np.ndarray]


def _prepare_slice(raw_slice: np.ndarray, clip_limit: float = 0.015) -> np.ndarray:
    """Scale a slice by 1/255 and apply adaptive histogram equalization.

    The input is cast to ``float32`` and divided by 255 (assumes 8-bit
    intensities), then passed through ``exposure.equalize_adapthist`` with
    the given ``clip_limit``. The result is returned as ``float32``.
    """

    scaled = raw_slice.astype(np.float32) / 255.0
    equalized = exposure.equalize_adapthist(scaled, clip_limit=clip_limit)
    return equalized.astype(np.float32)


def build_feature_bundle(
    raw_slice: np.ndarray, *, sigmas: Sequence[float] = (0.8, 1.6, 2.4)
) -> FeatureBundle:
    """Build the per-pixel feature matrix used for clustering/classification.

    Features, in column order: contrast-enhanced intensity, one Gaussian blur
    per entry of ``sigmas``, Sobel gradient magnitude, absolute Laplacian,
    two structure-tensor eigenvalues, two Hessian eigenvalues, local variance,
    and four positional maps (normalized x, normalized y, distance to the
    image centre, distance to the nearest image edge).

    Returns:
        A ``FeatureBundle`` whose ``matrix`` has one row per pixel with every
        column standardized as ``(x - mean) / (std + 1e-6)``, and whose
        ``maps`` holds the un-standardized 2-D maps by name.
    """

    img = _prepare_slice(raw_slice)
    n_rows, n_cols = img.shape

    maps: Dict[str, np.ndarray] = {}
    columns: List[np.ndarray] = []

    def register(name: str, values: np.ndarray) -> None:
        # Keep the named map and the feature column in lock-step; the column
        # order defines the layout of the final matrix.
        maps[name] = values
        columns.append(values)

    register("intensity", img)

    for sigma in sigmas:
        register(f"gauss_{sigma:.1f}", gaussian(img, sigma=sigma, preserve_range=True))

    register("gradient", sobel(img))
    register("laplacian", np.abs(laplace(img, ksize=3)))

    Axx, Axy, Ayy = structure_tensor(img, sigma=1.0, mode="reflect")
    st_first, st_second = _structure_tensor_eigenvalues(Axx, Axy, Ayy)
    register("st_eig1", st_first)
    register("st_eig2", st_second)

    Hxx, Hxy, Hyy = hessian_matrix(img, sigma=1.8, order="rc", mode="reflect")
    hess_first, hess_second = _hessian_matrix_eigenvalues(Hxx, Hxy, Hyy)
    register("hess_eig1", hess_first)
    register("hess_eig2", hess_second)

    # Local variance as E[x^2] - E[x]^2 under a Gaussian window, clipped at 0
    # to absorb small negative values.
    local_mean = gaussian(img, sigma=1.2, preserve_range=True)
    local_mean_sq = gaussian(img ** 2, sigma=1.2, preserve_range=True)
    register("variance", np.clip(local_mean_sq - local_mean ** 2, 0.0, None))

    row_grid, col_grid = np.mgrid[0:n_rows, 0:n_cols].astype(np.float32)
    xx_norm = col_grid / max(n_cols - 1, 1)
    yy_norm = row_grid / max(n_rows - 1, 1)
    register("x_norm", xx_norm)
    register("y_norm", yy_norm)
    register("radius", np.sqrt((xx_norm - 0.5) ** 2 + (yy_norm - 0.5) ** 2))
    register("border", np.minimum.reduce([xx_norm, 1 - xx_norm, yy_norm, 1 - yy_norm]))

    cube = np.stack(columns, axis=-1)
    flat = cube.reshape(-1, cube.shape[-1]).astype(np.float32)
    flat = (flat - flat.mean(axis=0)) / (flat.std(axis=0) + 1e-6)

    return FeatureBundle(matrix=flat, maps=maps)

## Clustering & Post-processing
Cette partie regroupe le cœur du K-means : on entraîne soit un K-means classique, soit un MiniBatchKMeans (plus rapide) sur les caractéristiques des pixels (intensité, etc.), on identifie le cluster de fond via des heuristiques simples, puis on nettoie les étiquettes (suppression des petits objets, fermeture morphologique, remplissage des trous). Les fonctions `segment_slice` et `segment_stack` appliquent le tout à une image ou à une pile entière.

In [4]:
def _fit_clusterer(
    X: np.ndarray, n_clusters: int, subsample: int, random_state: int, method: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Fit a K-means variant on ``X`` and label every row.

    Args:
        X: Feature matrix with one row per sample.
        n_clusters: Number of clusters to fit.
        subsample: Maximum number of rows used to fit ``MiniBatchKMeans``
            (ignored for ``"kmeans"``, which fits on all rows).
        random_state: Seed for the estimator and for the row subsampling.
        method: ``"kmeans"`` or ``"minibatch"``.

    Returns:
        ``(labels, centers)``: the predicted cluster id of every row of ``X``
        and the fitted cluster centres.

    Raises:
        ValueError: If ``method`` is not one of the supported names. A
            misspelled name used to fall back silently to MiniBatchKMeans.
    """
    if method not in ("kmeans", "minibatch"):
        raise ValueError(
            f"Unknown clustering method {method!r}; expected 'kmeans' or 'minibatch'."
        )

    if method == "kmeans":
        estimator = KMeans(
            n_clusters=n_clusters,
            n_init="auto",
            random_state=random_state,
            max_iter=400,
        )
        estimator.fit(X)
    else:
        # Fit on a random subset of rows without replacement; the full matrix
        # is only used for the final prediction.
        rng = check_random_state(random_state)
        n_samples = X.shape[0]
        choose = min(subsample, n_samples)
        idx = rng.choice(n_samples, choose, replace=False)
        estimator = MiniBatchKMeans(
            n_clusters=n_clusters,
            batch_size=4096,
            n_init="auto",
            reassignment_ratio=0.01,
            max_iter=200,
            random_state=random_state,
        )
        estimator.fit(X[idx])

    labels = estimator.predict(X)
    return labels, getattr(estimator, "cluster_centers_", None)


def _pick_background_label(
    labels: np.ndarray, maps: Dict[str, np.ndarray], margin: int = 6
) -> int:
    """Pick the cluster label that most looks like image background.

    Each label gets the score
    ``1.8 * border_ratio + 0.7 * area_ratio - 1.1 * grad_mean - 0.3 * intensity_mean``
    and the label with the highest score is returned.

    Args:
        labels: Flat cluster labels, one per pixel, reshaped here to the shape
            of ``maps["intensity"]``.
        maps: Feature maps; ``"intensity"`` and ``"gradient"`` are required,
            ``"gauss_0.8"`` is used in place of the raw intensity when present.
        margin: Width in pixels of the frame treated as image border. A value
            of 0 or less means there is no border, and the border term is 0
            for every label.

    Returns:
        The selected label as a plain ``int``.
    """
    h, w = maps["intensity"].shape
    label_img = labels.reshape(h, w)
    unique_labels = np.unique(label_img)

    # Guard the slices below: ``[-0:]`` is the whole axis, so margin=0 would
    # otherwise flag every pixel as border.
    margin = max(int(margin), 0)
    border_mask = np.zeros((h, w), dtype=bool)
    if margin > 0:
        border_mask[:margin, :] = True
        border_mask[-margin:, :] = True
        border_mask[:, :margin] = True
        border_mask[:, -margin:] = True
    has_border = bool(border_mask.any())

    grad = maps["gradient"]
    intensity = maps["gauss_0.8"] if "gauss_0.8" in maps else maps["intensity"]

    best_label = int(unique_labels[0])
    best_score = -np.inf

    for lbl in unique_labels:
        mask = label_img == lbl

        area_ratio = mask.mean()
        # Share of the border frame covered by this label (0 without a frame,
        # which also avoids taking the mean of an empty selection).
        border_ratio = mask[border_mask].mean() if has_border else 0.0
        grad_mean = float(grad[mask].mean())
        intensity_mean = float(intensity[mask].mean())

        score = (
            1.8 * border_ratio
            + 0.7 * area_ratio
            - 1.1 * grad_mean
            - 0.3 * intensity_mean
        )

        if score > best_score:
            best_score = score
            best_label = int(lbl)

    return best_label


def _post_process(
    label_img: np.ndarray, maps: Dict[str, np.ndarray], *, min_size: int = 80, hole_size: int = 96
) -> np.ndarray:
    """Clean a label image region by region and renumber the survivors.

    For every non-zero label: connected components smaller than ``min_size``
    are removed, the mask is closed with a disk of radius 2, and holes smaller
    than ``hole_size`` are filled. Labels that still have pixels are written
    into a fresh ``uint8`` image with consecutive ids starting at 1; label 0
    is skipped. ``maps`` is accepted but not used.
    """
    n_rows, n_cols = label_img.shape
    cleaned = np.zeros((n_rows, n_cols), dtype=np.uint8)
    footprint = morphology.disk(2)
    assigned = 0

    for lbl in np.unique(label_img):
        if lbl == 0:
            continue

        region = morphology.remove_small_objects(label_img == lbl, min_size=min_size)
        if not region.any():
            continue

        region = morphology.binary_closing(region, footprint)
        region = morphology.remove_small_holes(region, area_threshold=hole_size)
        if not region.any():
            continue

        # Regions processed later overwrite earlier ones where closing makes
        # them overlap.
        assigned += 1
        cleaned[region] = assigned

    return cleaned


def segment_slice(
    raw_slice: np.ndarray,
    *,
    k_structures: int = 12,
    extra_background_clusters: int = 2,
    subsample: int = 120_000,
    random_state: int = 42,
    method: str = "minibatch",
) -> Tuple[np.ndarray, FeatureBundle]:
    """Segment one 2-D slice with K-means on per-pixel features.

    Args:
        raw_slice: 2-D image to segment.
        k_structures: Number of clusters reserved for structures.
        extra_background_clusters: Additional clusters added to
            ``k_structures`` to obtain the total cluster count.
        subsample: Maximum number of pixels used to fit the clusterer.
        random_state: Seed forwarded to the clusterer.
        method: Clustering method name forwarded to ``_fit_clusterer``.

    Returns:
        ``(processed, bundle)``: the post-processed label image, where 0 is
        background, and the feature bundle computed for the slice.
    """
    bundle = build_feature_bundle(raw_slice)
    k_total = k_structures + extra_background_clusters

    labels, _ = _fit_clusterer(bundle.matrix, k_total, subsample, random_state, method)
    background_label = _pick_background_label(labels, bundle.maps)

    h, w = raw_slice.shape
    # Cluster ids start at 0, so 0 cannot be used as the background marker
    # directly: cluster 0 would be merged with the background and always
    # dropped. Shift every id by one first, then zero the background cluster.
    label_img = labels.reshape(h, w) + 1
    label_img[label_img == background_label + 1] = 0

    label_img, _, _ = relabel_sequential(label_img)

    processed = _post_process(label_img, bundle.maps)
    return processed, bundle


def segment_stack(
    stack: np.ndarray,
    *,
    k_structures: int = 12,
    method: str = "minibatch",
    subsample: int = 120_000,
    random_state: int = 42,
    extra_background_clusters: int = 2,
) -> np.ndarray:
    """Segment every slice of a stack with ``segment_slice``.

    Args:
        stack: Image stack iterated along its first axis.
        k_structures: Number of structure clusters per slice.
        method: Clustering method name forwarded to ``segment_slice``.
        subsample: Maximum number of pixels used to fit each clusterer.
        random_state: Base seed; slice ``i`` uses ``random_state + i``.
        extra_background_clusters: Additional clusters per slice (default 2,
            the value that was previously hard-coded).

    Returns:
        The per-slice label images stacked along axis 0. An empty stack
        yields an empty ``uint8`` array of the same shape.
    """
    if len(stack) == 0:
        # ``np.stack([])`` raises; return an empty result instead.
        return np.zeros(np.shape(stack), dtype=np.uint8)

    preds: List[np.ndarray] = []
    start = time.time()

    for idx, slc in enumerate(stack):
        pred, _ = segment_slice(
            slc,
            k_structures=k_structures,
            extra_background_clusters=extra_background_clusters,
            subsample=subsample,
            random_state=random_state + idx,
            method=method,
        )
        preds.append(pred)

        # Progress report every 25 slices.
        if (idx + 1) % 25 == 0:
            elapsed = time.time() - start
            rate = (idx + 1) / max(elapsed, 1e-6)
            print(f"Processed {idx + 1} slices in {elapsed:.1f}s ({rate:.2f} slices/s)")

    return np.stack(preds, axis=0)

## Evaluation Helpers
Les fonctions de cette section servent à sauvegarder les résultats et à mesurer leur qualité. `save_predictions` écrit un CSV (et, en option, un fichier NPY) prêt à être soumis au challenge, `load_ground_truth` reconstitue les masques à partir du CSV d’annotations et `evaluate_ari` calcule l’Adjusted Rand Index en ignorant le fond. `run_demo`, définie dans la section suivante, combine le tout pour un test rapide sur un sous-ensemble.

In [5]:
def save_predictions(
    preds: np.ndarray, out_csv: str | Path, save_npy: bool = True
) -> None:
    out_csv = Path(out_csv)
    flat = preds.transpose(1, 2, 0).reshape(-1, preds.shape[0])
    columns = [f"{idx}.png" for idx in range(preds.shape[0])]
    index = [f"Pixel {idx}" for idx in range(flat.shape[0])]

    df = pd.DataFrame(flat, index=index, columns=columns)
    df.index.name = ""
    df.to_csv(out_csv)
    print(f"Saved CSV predictions to {out_csv} with shape {df.shape}")

    if save_npy:
        np.save(out_csv.with_suffix(".npy"), preds.astype(np.uint8))
        print(f"Saved NumPy predictions to {out_csv.with_suffix('.npy')}")


def load_ground_truth(path: str | Path, *, expected_hw: Tuple[int, int]) -> np.ndarray:
    df = pd.read_csv(path)

    numeric_df = df.apply(pd.to_numeric, errors="coerce")
    numeric_df = numeric_df.dropna(axis=1, how="all")

    if numeric_df.isnull().values.any():
        raise ValueError(
            "Ground-truth CSV contains non-numeric values even after coercion. "
            "Please ensure labels are stored as integers."
        )

    flat = numeric_df.to_numpy(dtype=np.int16, copy=False).T
    h, w = expected_hw
    return flat.reshape(flat.shape[0], h, w)


def evaluate_ari(preds: np.ndarray, y_true: np.ndarray) -> Dict[str, float]:
    """Compute per-slice ARI on pixels that are foreground in both stacks.

    When the two stacks differ in shape they are first truncated to the same
    number of slices. A slice with fewer than 20 pixels that are non-zero in
    both prediction and ground truth scores 0.0. Summary statistics are
    printed and returned under the keys ``mean``, ``median``, ``min``, ``max``
    (all 0.0 when there is no slice).
    """
    if preds.shape != y_true.shape:
        n_min = min(preds.shape[0], y_true.shape[0])
        preds, y_true = preds[:n_min], y_true[:n_min]
        print(f"Aligned prediction and ground-truth stacks to {n_min} slices")

    def _slice_score(pred: np.ndarray, gt: np.ndarray) -> float:
        both_foreground = (pred > 0) & (gt > 0)
        if both_foreground.sum() < 20:
            return 0.0
        return adjusted_rand_score(gt[both_foreground].ravel(), pred[both_foreground].ravel())

    scores: List[float] = [_slice_score(pred, gt) for pred, gt in zip(preds, y_true)]

    reducers = (("mean", np.mean), ("median", np.median), ("min", np.min), ("max", np.max))
    stats = {name: float(reduce(scores)) if scores else 0.0 for name, reduce in reducers}

    print(
        "ARI statistics -> "
        f"mean: {stats['mean']:.4f}, median: {stats['median']:.4f}, "
        f"min: {stats['min']:.4f}, max: {stats['max']:.4f}"
    )
    return stats


## Foreground-Boosted Pipeline (RandomForest + K-means)
Cette section ajoute la couche « supervisée » : une Random Forest apprend à distinguer le fond des structures à partir d’un échantillon de slices annotées. On utilise ensuite ce masque de confiance pour ne segmenter via K-means que les pixels vraiment intéressants. Les fonctions définies ici (configurations, entraînement RF, segmentation et visualisation) permettent de lancer un pipeline complet sur un jeu d’apprentissage, d’évaluer la qualité avec l’ARI du challenge et d’exporter des aperçus visuels.

> **Choix de l’algorithme de clustering** : c’est la valeur passée via `seg_cfg` (par exemple `seg_config = ImprovedKMeansHybridConfig(method="kmeans")`) qui est utilisée par `run_demo`. Le champ `method` du dataclass ne fournit qu’une valeur par défaut.


In [6]:
"""Foreground-boosted segmentation pipeline (RandomForest + improved K-means)."""
from dataclasses import dataclass, field
from typing import Sequence

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import adjusted_rand_score
from sklearn.utils.validation import check_random_state
from skimage.segmentation import relabel_sequential


@dataclass(slots=True)
class ImprovedKMeansHybridConfig:
    """Settings for the K-means clustering and post-processing stage."""

    # Number of clusters reserved for structures.
    k_structures: int = 14
    # Additional clusters; added to k_structures to get the total cluster count.
    extra_background_clusters: int = 2
    # Maximum number of pixels sampled to fit the mini-batch clusterer.
    subsample: int = 250_000
    # Base seed; the slice index is added to it for each slice.
    random_state: int = 7
    method: str = "minibatch"  # default is the cheaper option, i.e. mini-batch; use "kmeans" for full K-means
    # Gaussian smoothing scales used when building the feature maps.
    sigmas: Sequence[float] = (0.8, 1.6, 2.4)
    # Width (pixels) of the border frame used to pick the background cluster.
    bg_margin: int = 6
    # Minimum region size kept during post-processing.
    post_min_size: int = 80
    # Holes smaller than this area are filled during post-processing.
    post_hole_size: int = 96


@dataclass(slots=True)
class ForegroundClassifierConfig:
    """Settings for training and applying the foreground Random Forest."""

    # Number of annotated slices used for training.
    n_training_slices: int = 40
    # Maximum number of pixels sampled per class (background / foreground) and per slice.
    pixels_per_class: int = 8_000
    # Seed for slice/pixel sampling and for the Random Forest.
    random_state: int = 7
    # Random Forest hyper-parameters, passed through to RandomForestClassifier.
    n_estimators: int = 400
    max_depth: int | None = 20
    min_samples_leaf: int = 10
    # Pixels whose foreground probability is >= this value are kept as structure.
    probability_threshold: float = 0.5


def _sample_training_pixels(
    stack: np.ndarray,
    masks: np.ndarray,
    seg_cfg: ImprovedKMeansHybridConfig,
    clf_cfg: ForegroundClassifierConfig,
    *,
    chosen_indices: Sequence[int] | None = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Draw a class-balanced pixel sample for the foreground classifier.

    Slices are either the given ``chosen_indices`` or a random draw of up to
    ``clf_cfg.n_training_slices`` slices. For each slice, up to
    ``clf_cfg.pixels_per_class`` pixels are drawn without replacement from
    the background (mask == 0) and from the foreground (mask > 0).

    Returns:
        ``(X_train, y_train)`` as ``float32`` features and ``uint8`` binary
        labels (1 = foreground).

    Raises:
        RuntimeError: If no pixel could be sampled at all.
    """
    rng = check_random_state(clf_cfg.random_state)
    if chosen_indices is not None:
        chosen_indices = np.array(chosen_indices, dtype=int)
    else:
        total = stack.shape[0]
        chosen_indices = rng.choice(
            total, size=min(clf_cfg.n_training_slices, total), replace=False
        )

    sampled_features: List[np.ndarray] = []
    sampled_labels: List[np.ndarray] = []

    for slice_idx in chosen_indices:
        image = stack[slice_idx]
        annotation = masks[slice_idx]

        pixel_features = build_feature_bundle(image, sigmas=seg_cfg.sigmas).matrix
        is_foreground = (annotation.ravel() > 0).astype(np.uint8)

        # Background first, then foreground: the order fixes the RNG sequence.
        for cls_val in (0, 1):
            candidates = np.flatnonzero(is_foreground == cls_val)
            if candidates.size == 0:
                continue

            n_take = min(candidates.size, clf_cfg.pixels_per_class)
            picked = rng.choice(candidates, size=n_take, replace=False)

            sampled_features.append(pixel_features[picked])
            sampled_labels.append(is_foreground[picked])

    if not sampled_features:
        raise RuntimeError("No training pixels were sampled. Check ground truth availability.")

    X_train = np.concatenate(sampled_features, axis=0).astype(np.float32)
    y_train = np.concatenate(sampled_labels, axis=0).astype(np.uint8)
    return X_train, y_train


def train_foreground_classifier(
    stack: np.ndarray,
    masks: np.ndarray,
    seg_cfg: ImprovedKMeansHybridConfig,
    clf_cfg: ForegroundClassifierConfig,
    *,
    chosen_indices: Sequence[int] | None = None,
) -> RandomForestClassifier:
    """Train the foreground/background Random Forest on sampled pixels.

    Pixels are drawn with ``_sample_training_pixels``; the forest is built
    from the hyper-parameters in ``clf_cfg`` and uses all cores
    (``n_jobs=-1``). A summary line with the sample size, the class balance
    and the elapsed time is printed.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    t0 = time.time()
    X_train, y_train = _sample_training_pixels(
        stack, masks, seg_cfg, clf_cfg, chosen_indices=chosen_indices
    )

    forest_params = {
        "n_estimators": clf_cfg.n_estimators,
        "max_depth": clf_cfg.max_depth,
        "min_samples_leaf": clf_cfg.min_samples_leaf,
        "n_jobs": -1,
        "random_state": clf_cfg.random_state,
        "oob_score": False,
    }
    clf = RandomForestClassifier(**forest_params)
    clf.fit(X_train, y_train)

    elapsed = time.time() - t0
    n_foreground = y_train.sum()
    n_background = (y_train == 0).sum()
    print(
        f"Trained foreground classifier on {len(y_train):,} pixels "
        f"(class balance: {n_foreground:,} foreground / {n_background:,} background) "
        f"in {elapsed:.1f}s"
    )

    return clf


@dataclass(slots=True)
class ForegroundBoostedImprovedKMeans:
    """K-means segmentation gated by a Random Forest foreground mask.

    ``fit`` trains the foreground classifier on annotated slices;
    ``segment_stack`` then clusters every slice and keeps only the pixels the
    classifier considers foreground.
    """

    # Clustering and post-processing settings.
    seg_cfg: ImprovedKMeansHybridConfig = field(default_factory=ImprovedKMeansHybridConfig)
    # Classifier training settings and probability threshold.
    clf_cfg: ForegroundClassifierConfig = field(default_factory=ForegroundClassifierConfig)
    # Set by ``fit``; stays ``None`` until then.
    classifier: RandomForestClassifier | None = None

    def fit(
        self,
        stack: np.ndarray,
        masks: np.ndarray,
        *,
        chosen_indices: Sequence[int] | None = None,
    ) -> "ForegroundBoostedImprovedKMeans":
        """Train the foreground classifier and return ``self``.

        Args:
            stack: Image stack used for training.
            masks: Ground-truth label stack aligned with ``stack``.
            chosen_indices: Optional explicit slice indices to train on;
                otherwise slices are drawn at random.
        """
        self.classifier = train_foreground_classifier(
            stack,
            masks,
            self.seg_cfg,
            self.clf_cfg,
            chosen_indices=chosen_indices,
        )
        return self

    def _segment_single(self, raw_slice: np.ndarray, slice_idx: int) -> np.ndarray:
        """Segment one slice; ``slice_idx`` offsets the clustering seed.

        Returns:
            A ``uint8`` label image where 0 is background.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if self.classifier is None:
            raise RuntimeError("Classifier must be trained before inference.")

        cfg = self.seg_cfg
        bundle = build_feature_bundle(raw_slice, sigmas=cfg.sigmas)
        h, w = raw_slice.shape

        # Foreground gate: probability of class 1 thresholded per pixel.
        proba = self.classifier.predict_proba(bundle.matrix)[:, 1]
        structure_mask = proba.reshape(h, w) >= self.clf_cfg.probability_threshold

        k_total = cfg.k_structures + cfg.extra_background_clusters
        labels, _ = _fit_clusterer(
            bundle.matrix,
            k_total,
            cfg.subsample,
            cfg.random_state + slice_idx,
            cfg.method,
        )
        background_label = _pick_background_label(labels, bundle.maps, margin=cfg.bg_margin)

        # Cluster ids start at 0, so 0 cannot be used as the background marker
        # directly: cluster 0 would be merged with the background and always
        # dropped. Shift every id by one first, then zero the background.
        label_img = labels.reshape(h, w) + 1
        label_img[label_img == background_label + 1] = 0
        label_img[~structure_mask] = 0

        label_img, _, _ = relabel_sequential(label_img)
        processed = _post_process(
            label_img,
            bundle.maps,
            min_size=cfg.post_min_size,
            hole_size=cfg.post_hole_size,
        )
        # Closing / hole filling may grow regions back into rejected pixels.
        processed[~structure_mask] = 0

        return processed.astype(np.uint8)

    def segment_stack(self, stack: np.ndarray, verbose: bool = True) -> np.ndarray:
        """Segment every slice of ``stack`` and stack the results on axis 0.

        An empty stack yields an empty ``uint8`` array of the same shape.
        With ``verbose`` a progress line is printed every 25 slices.
        """
        if len(stack) == 0:
            # ``np.stack([])`` raises; return an empty result instead.
            return np.zeros(np.shape(stack), dtype=np.uint8)

        preds: List[np.ndarray] = []
        start = time.time()

        for idx, slc in enumerate(stack):
            preds.append(self._segment_single(slc, slice_idx=idx))

            if verbose and (idx + 1) % 25 == 0:
                elapsed = time.time() - start
                rate = (idx + 1) / max(elapsed, 1e-6)
                print(f"Processed {idx + 1} slices in {elapsed:.1f}s ({rate:.2f} slices/s)")

        return np.stack(preds, axis=0)


def evaluate_ari_challenge(preds: np.ndarray, y_true: np.ndarray) -> Dict[str, float]:
    """Compute per-slice ARI over all pixels, background included.

    When the two stacks differ in shape they are first truncated to the same
    number of slices. Summary statistics are printed and returned under the
    keys ``mean``, ``median``, ``min``, ``max`` (all 0.0 when there is no
    slice).
    """
    if preds.shape != y_true.shape:
        n_min = min(preds.shape[0], y_true.shape[0])
        preds, y_true = preds[:n_min], y_true[:n_min]
        print(f"Aligned prediction and ground-truth stacks to {n_min} slices")

    scores: List[float] = [
        adjusted_rand_score(gt.ravel(), pred.ravel()) for pred, gt in zip(preds, y_true)
    ]

    reducers = (("mean", np.mean), ("median", np.median), ("min", np.min), ("max", np.max))
    stats = {name: float(reduce(scores)) if scores else 0.0 for name, reduce in reducers}

    print(
        "ARI statistics (challenge metric) -> "
        f"mean: {stats['mean']:.4f}, median: {stats['median']:.4f}, "
        f"min: {stats['min']:.4f}, max: {stats['max']:.4f}"
    )
    return stats


def save_visual_report(
    raw_stack: np.ndarray,
    gt_stack: np.ndarray,
    preds: np.ndarray,
    eval_files: Sequence[Path],
    *,
    output_dir: str | Path,
    max_examples: int = 6,
) -> List[Path]:
    if preds.size == 0:
        raise ValueError("No predictions available for visualization.")

    try:
        import matplotlib.pyplot as plt
    except ImportError as exc:
        raise RuntimeError("matplotlib is required for visualization") from exc

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    max_examples = max(1, min(int(max_examples), preds.shape[0]))
    per_slice_scores: List[Tuple[int, float]] = []
    for idx, (pred, gt) in enumerate(zip(preds, gt_stack)):
        score = adjusted_rand_score(gt.ravel(), pred.ravel())
        per_slice_scores.append((idx, score))

    per_slice_scores.sort(key=lambda item: item[1])
    score_lookup = {idx: score for idx, score in per_slice_scores}
    bottom = max(1, max_examples // 2)
    top = max_examples - bottom

    selected: List[int] = [idx for idx, _ in per_slice_scores[:bottom]]
    if top > 0:
        selected.extend(idx for idx, _ in per_slice_scores[-top:])

    ordered_unique: List[int] = []
    seen = set()
    for idx in selected:
        if idx not in seen:
            seen.add(idx)
            ordered_unique.append(idx)

    n_rows = len(ordered_unique)
    fig, axes = plt.subplots(n_rows, 3, figsize=(9, 3 * n_rows))
    axes = np.atleast_2d(axes)

    cmap_mask = "nipy_spectral"
    saved_paths: List[Path] = []
    for row, local_idx in enumerate(ordered_unique):
        raw = raw_stack[local_idx]
        gt = gt_stack[local_idx]
        pred = preds[local_idx]
        slice_name = Path(eval_files[local_idx]).name if eval_files else f"slice_{local_idx}"
        score = score_lookup.get(local_idx, float("nan"))

        axes[row, 0].imshow(raw, cmap="gray")
        axes[row, 0].set_title(f"{slice_name}\nRaw")
        axes[row, 1].imshow(gt, cmap=cmap_mask, interpolation="nearest")
        axes[row, 1].set_title("Vérité terrain")
        axes[row, 2].imshow(pred, cmap=cmap_mask, interpolation="nearest")
        axes[row, 2].set_title(f"Prédiction\nARI={score:.3f}")

        for col in range(3):
            axes[row, col].axis("off")

    fig.tight_layout()
    preview_path = output_dir / "foreground_boosted_improved_preview.png"
    fig.savefig(preview_path, dpi=200, bbox_inches="tight")
    plt.close(fig)

    saved_paths.append(preview_path)
    return saved_paths


# ---------------------------------------------------------------------------
# Convenience runner
# ---------------------------------------------------------------------------


def run_demo(
    train_img_dir: str | Path,
    y_train_path: str | Path,
    *,
    seg_cfg: ImprovedKMeansHybridConfig | None = None,
    clf_cfg: ForegroundClassifierConfig | None = None,
    n_slices: int | None = 200,
    output_csv: str | Path = "predictions_foreground_boosted_improved.csv",
    preview_dir: str | Path | None = None,
    preview_max_examples: int = 6,
) -> Dict[str, float]:
    """Train on a random subset of slices and evaluate on the remaining ones.

    Args:
        train_img_dir: Directory with the PNG slices.
        y_train_path: CSV file with the ground-truth masks.
        seg_cfg: Clustering settings; defaults are used when ``None``.
        clf_cfg: Classifier settings; defaults are used when ``None``.
        n_slices: Maximum number of slices loaded (``None`` loads all).
        output_csv: Where the hold-out predictions are written.
        preview_dir: If given, a visual preview is saved in this directory.
        preview_max_examples: Maximum number of slices shown in the preview.

    Returns:
        ARI statistics (``mean``, ``median``, ``min``, ``max``). When a
        preview is saved, the extra key ``preview_paths`` holds a list of
        path strings.

    Raises:
        RuntimeError: If no slice is left for evaluation.
    """
    seg_cfg = seg_cfg or ImprovedKMeansHybridConfig()
    clf_cfg = clf_cfg or ForegroundClassifierConfig()

    files, stack = load_png_stack(train_img_dir, limit=n_slices)
    print(f"Loaded {len(files)} slices of shape {stack.shape[1:]} from {train_img_dir}")

    y_true = load_ground_truth(y_train_path, expected_hw=stack.shape[1:])
    n_common = min(stack.shape[0], y_true.shape[0])
    if y_true.shape[0] < stack.shape[0]:
        print(
            f"Warning: ground truth has only {y_true.shape[0]} slices; "
            "truncating the image stack to match the annotations."
        )
        # Images without annotations can be neither trained nor evaluated on;
        # keeping them would let sampled indices run past the masks.
        stack = stack[:n_common]
        files = files[:n_common]
    y_true = y_true[:n_common]

    rng = np.random.RandomState(clf_cfg.random_state)
    total_slices = stack.shape[0]
    train_size = min(clf_cfg.n_training_slices, total_slices)
    if train_size >= total_slices:
        train_size = max(total_slices - 1, 1)
        print(
            "Adjusted training slice count to leave at least one slice for evaluation."
        )

    # Random train / hold-out split over slice indices.
    train_indices = rng.choice(total_slices, size=train_size, replace=False)
    eval_mask = np.ones(total_slices, dtype=bool)
    eval_mask[train_indices] = False
    eval_indices = np.flatnonzero(eval_mask)

    if eval_indices.size == 0:
        raise RuntimeError(
            "Evaluation set is empty. Reduce `n_training_slices` to leave holdout slices."
        )

    train_stack = stack[train_indices]
    train_masks = y_true[train_indices]
    eval_stack = stack[eval_indices]
    eval_masks = y_true[eval_indices]
    eval_files = [files[idx] for idx in eval_indices]

    print(
        f"Training RF on slices {train_indices.tolist()} (count={train_indices.size}), "
        f"evaluating on slices {eval_indices.tolist()} (count={eval_indices.size})."
    )

    segmenter = ForegroundBoostedImprovedKMeans(seg_cfg=seg_cfg, clf_cfg=clf_cfg)
    # The training stack is already the selected subset, so use all of it.
    segmenter.fit(train_stack, train_masks, chosen_indices=range(train_stack.shape[0]))

    preds = segmenter.segment_stack(eval_stack)

    save_predictions(preds, output_csv)
    stats = evaluate_ari_challenge(preds, eval_masks[: preds.shape[0]])

    if preview_dir is not None:
        preview_paths = save_visual_report(
            eval_stack[: preds.shape[0]],
            eval_masks[: preds.shape[0]],
            preds,
            eval_files[: preds.shape[0]],
            output_dir=preview_dir,
            max_examples=preview_max_examples,
        )
        if preview_paths:
            joined = ", ".join(str(p) for p in preview_paths)
            print(f"Saved visual preview(s) to: {joined}")
            # NOTE: this entry is a list of strings, not a float.
            stats["preview_paths"] = [str(p) for p in preview_paths]

    return stats


## Utilisation

1. Installez les dépendances nécessaires :
   ```bash
   pip install numpy pandas pillow scikit-learn scikit-image matplotlib
   ```
2. Définissez les chemins d'accès aux dossiers d'images et au CSV des masques.
3. Exécutez la cellule ci-dessous pour lancer un run d'entraînement/évaluation ou
   adaptez-la pour générer des prédictions sur un autre jeu d'images (ex : `X_test`).


Dans la cellule ci-dessous, on peut modifier les paramètres du K-means (ou du MiniBatchKMeans) pour essayer d'optimiser les résultats,
par exemple en modifiant le nombre de structures à identifier ou le nombre de pixels à prendre en compte.


In [7]:
# Example run (adapt the paths to your environment)
TRAIN_IMG_DIR = Path("/path/vers/X_train/images")
Y_TRAIN_CSV = Path("/path/vers/Y_train.csv")

seg_config = ImprovedKMeansHybridConfig(
    k_structures=14, # initially 12; also try 10 or 14
    extra_background_clusters=2,
    subsample=250_000, # initially 150_000; try 200_000 if possible
    random_state=7,
    method="minibatch",  # use "kmeans" for full K-means (mind the exact spelling)
    sigmas=(0.8, 1.6, 2.4), # initially (0.8, 1.6, 2.4); a smoother scale such as (0.6, 1.2, 2.4, 3.2) can also be tried
    bg_margin=6, 
    post_min_size=80, 
    post_hole_size=96, 
)

clf_config = ForegroundClassifierConfig(
    n_training_slices=40,
    pixels_per_class=8_000,
    random_state=7,
    n_estimators=400,
    max_depth=20,
    min_samples_leaf=10,
    probability_threshold=0.5,
)

# Launch a demo run: train the RF on a subset of slices, evaluate on the hold-out slices.
stats = run_demo(
    train_img_dir=TRAIN_IMG_DIR,
    y_train_path=Y_TRAIN_CSV,
    seg_cfg=seg_config,
    clf_cfg=clf_config,
    n_slices=200,
    output_csv="predictions_foreground_boosted_improved.csv",
    preview_dir="reports/foreground_boosted_improved",
    preview_max_examples=6,
)

stats

FileNotFoundError: No PNG image found in /path/vers/X_train/images