Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
bef70df
feat: add embedding dataset build pipeline
vojtech-cifka May 8, 2026
911bec2
feat: add class thresholds and run ids
vojtech-cifka May 8, 2026
1a02395
fix: wrong run id
vojtech-cifka May 8, 2026
08d7ba5
Merge remote-tracking branch 'origin/master' into feature/embedding-d…
vojtech-cifka May 9, 2026
b38465e
feat: add timing
vojtech-cifka May 9, 2026
bfc9578
refactor: use pyarrow to avoid to pandas conversion
vojtech-cifka May 9, 2026
eb213c6
fix: join on keys only
vojtech-cifka May 9, 2026
c92d9a1
fix: typing
vojtech-cifka May 9, 2026
01cc394
fix: add prints
vojtech-cifka May 9, 2026
cad0d37
refactor: use combine chunks
vojtech-cifka May 9, 2026
ae04552
fix: lazy-cast embeddings to large_list and stay in Arrow during join
vojtech-cifka May 9, 2026
82320db
fix: validate label/tissue_prop columns when derive=False
vojtech-cifka May 9, 2026
3b0137f
chore: remove time
vojtech-cifka May 9, 2026
8df47aa
feat: add timing
vojtech-cifka May 10, 2026
926753d
chore: revert to the previous state
vojtech-cifka May 10, 2026
b0e9ba4
feat: add prints
vojtech-cifka May 10, 2026
6a915de
refactor: use discussed thresholds
vojtech-cifka May 11, 2026
0f50307
refactor: use different labeling strategy
vojtech-cifka May 11, 2026
c421c74
refactor: drop tiles that are covered by two or more distinct labels
vojtech-cifka May 11, 2026
718ec08
fix: format
vojtech-cifka May 11, 2026
389a0a5
chore: update embeddings run id
vojtech-cifka May 11, 2026
11ed4e3
chore: remove timing prints
vojtech-cifka May 11, 2026
d59425a
refactor: use int64
vojtech-cifka May 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/data/dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dataset:
test_split_filename: "split_mapping/test_split.csv"
tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86"
filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba"
kfold_run_id: "850c81506684450b9af92296acfd045a"
embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6"
tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f"
tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a"

Expand Down
23 changes: 23 additions & 0 deletions configs/experiment/preprocessing/embedding_dataset.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# @package _global_

defaults:
- /data: dataset
- _self_

tissue_prop_min: 0.2
thresholds:
Nerve: 0.0
Blood: 0.0
Connective-Tissue: 0.4
Fat: 0.5
Epithelium: 0.2
Muscle: 0.4
Other: 0.5

metadata:
run_name: Embedding dataset ${dataset.name}
description: "Join k-fold (${dataset.mlflow_artifacts.kfold_run_id}) and filter_tiles (${dataset.mlflow_artifacts.filter_tiles_run_id}) tile metadata with embeddings (${dataset.mlflow_artifacts.embedding_run_id})."
hyperparams:
kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id}
filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id}
embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id}
13 changes: 13 additions & 0 deletions configs/preprocessing/embedding_dataset.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# @package _global_

mlflow_artifact_path: embedding_dataset

tissue_prop_min: ???
thresholds: ???

metadata:
run_name: "Embedding dataset ${dataset.name}"
description: "Build embedding training dataset by joining k-fold/filter_tiles tile metadata with precomputed embeddings."
hyperparams:
tissue_prop_min: ${tissue_prop_min}
thresholds: ${thresholds}
24 changes: 24 additions & 0 deletions preprocessing/_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Shared helpers for deriving tile labels from roi_coverage_* columns."""

from collections.abc import Mapping
from typing import Any

import numpy as np
import pandas as pd


def compute_label_and_tissue_prop(
    roi_data: Mapping[str, Any],
    roi_cols: list[str],
) -> tuple[np.ndarray, np.ndarray]:
    """Compute (label, tissue_prop) from roi_coverage_* columns.

    label = argmax across roi_cols (with ``roi_coverage_`` prefix stripped),
    falling back to ``"background"`` whenever all coverages are zero.
    tissue_prop = sum across roi_cols.

    Args:
        roi_data: Mapping from column name to a column of per-tile coverage
            values; only the keys listed in ``roi_cols`` are read.
        roi_cols: Names of the coverage columns to consider.

    Returns:
        ``(label, tissue_prop)`` arrays with one entry per tile. ``label`` is
        a freshly allocated object array, so it is always safe to modify.
    """
    roi_df = pd.DataFrame({col: roi_data[col] for col in roi_cols})
    tp = roi_df.sum(axis=1).to_numpy()
    # Request an explicit object-dtype copy: a plain ``to_numpy()`` may hand
    # back a read-only view (pandas copy-on-write), which would make the
    # in-place "background" assignment below raise.
    lbl = (
        roi_df.idxmax(axis=1)
        .str.removeprefix("roi_coverage_")
        .to_numpy(dtype=object, copy=True)
    )
    lbl[tp == 0] = "background"
    return lbl, tp
Comment thread
vojtech-cifka marked this conversation as resolved.
283 changes: 283 additions & 0 deletions preprocessing/embedding_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
"""Build an embedding training dataset by joining tile metadata with embeddings.

Joins precomputed tile embeddings with k-fold metadata (train) / filter_tiles
metadata (test), applies tissue + per-class ROI thresholds before the join, and
emits a training-ready Parquet dataset (per-split ``slides.parquet`` +
``tiles.parquet``) ready for ``rationai.mlkit.data.datasets.SlidesTilesLoader``.
"""

import shutil
import tempfile
from pathlib import Path

import hydra
import mlflow
import mlflow.artifacts
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as pads
import pyarrow.parquet as pq
from omegaconf import DictConfig, OmegaConf
from rationai.mlkit import autolog, with_cli_args
from rationai.mlkit.lightning.loggers import MLFlowLogger

from preprocessing._labels import compute_label_and_tissue_prop


def apply_thresholds(
    df: pd.DataFrame,
    tissue_prop_min: float,
    thresholds: dict[str, float],
    roi_cols: list[str],
) -> tuple[pd.DataFrame, int, int]:
    """Filter tiles by tissue, drop multi-annotation tiles, then apply argmax-then-threshold.

    The steps run in this order:

    1. keep rows whose ``tissue_prop`` is at least ``tissue_prop_min``;
    2. drop rows where more than one ``roi_cols`` column is positive;
    3. take the dominant (argmax) class of each remaining row and keep the row
       only if its coverage is positive and reaches that class's threshold.

    Args:
        df: Tile table containing ``tissue_prop`` and every column in
            ``roi_cols``. It is not modified.
        tissue_prop_min: Minimum ``tissue_prop`` for a tile to be considered.
        thresholds: Minimum dominant coverage per class, keyed by class name
            (the column name without the ``roi_coverage_`` prefix). A class
            with no entry never passes the comparison, so its tiles are dropped.
        roi_cols: Names of the ``roi_coverage_*`` columns.

    Returns:
        ``(filtered_df, after_tissue_count, after_single_class_count)``.
        ``filtered_df`` is a copy whose ``label`` column holds the dominant
        class. If a filter stage leaves no rows, the empty frame is returned
        as is, without (re)writing ``label``.
    """
    df = df.loc[df["tissue_prop"] >= tissue_prop_min]
    after_tissue = len(df)
    if df.empty:
        return df, after_tissue, after_tissue

    nonzero_classes = (df[roi_cols].to_numpy() > 0).sum(axis=1)
    df = df.loc[nonzero_classes <= 1]
    after_single_class = len(df)
    if df.empty:
        return df, after_tissue, after_single_class

    roi_only = df[roi_cols]
    dominant = roi_only.idxmax(axis=1).str.removeprefix("roi_coverage_")
    dominant_value = roi_only.max(axis=1).to_numpy()
    threshold_per_row = dominant.map(thresholds).to_numpy()
    # A dominant coverage of zero means no class is annotated on the tile;
    # idxmax then just returns the first column, so such a tile must never be
    # kept, even when that class has a threshold of 0.0.
    keep = (dominant_value > 0) & (dominant_value >= threshold_per_row)

    out = df.loc[keep].copy()
    out["label"] = dominant.to_numpy()[keep]
    return out, after_tissue, after_single_class


def join_embeddings(
    tiles_table: pa.Table,
    embedding_run_id: str,
    embedding_split: str,
) -> tuple[pa.Table, int]:
    """Attach embeddings to the filtered tiles, matching on (slide_id, x, y).

    Everything is done in Arrow so the embedding column never goes through the
    very slow list<double> -> pandas conversion. Because Acero's join engine
    does not accept list columns as non-key fields, only the key columns plus
    a synthetic row index are joined; the embeddings are then gathered with
    take(). List-typed embeddings are cast chunk by chunk to large_list so the
    concatenated array cannot overflow int32 offsets.

    Returns:
        ``(joined_table, dropped)`` where ``dropped`` is the number of rows of
        ``tiles_table`` minus the number of rows in the joined table.
    """
    key_columns = ["slide_id", "x", "y"]

    local_dir = mlflow.artifacts.download_artifacts(
        run_id=embedding_run_id,
        artifact_path=f"{embedding_split}/tiles",
    )
    source = pads.dataset(local_dir, format="parquet").to_table(
        columns=[*key_columns, "embedding"]
    )

    vectors = source.column("embedding")
    if pa.types.is_list(vectors.type):
        wide_type = pa.large_list(vectors.type.value_type)
        widened_chunks = [chunk.cast(wide_type) for chunk in vectors.chunks]
        vectors = pa.chunked_array(widened_chunks, type=wide_type)

    # Key-only lookup table: the three keys plus the row position of each
    # embedding in ``vectors``.
    row_ids = pa.array(range(source.num_rows), type=pa.int64())
    lookup = source.drop(["embedding"]).append_column("_emb_idx", row_ids)
    del source, row_ids

    matched = tiles_table.join(lookup, keys=key_columns, join_type="inner")
    del lookup

    positions = matched.column("_emb_idx")
    if isinstance(positions, pa.ChunkedArray):
        positions = positions.combine_chunks()

    # Flatten to one contiguous array before gathering, releasing each
    # intermediate as soon as it is no longer needed.
    flat_vectors = vectors.combine_chunks()
    del vectors

    picked = flat_vectors.take(positions)
    del flat_vectors

    result = matched.drop(["_emb_idx"]).append_column("embedding", picked)
    return result, tiles_table.num_rows - result.num_rows


def process_split(
    split_name: str,
    src_run_id: str,
    src_artifact_path: str,
    embedding_run_id: str,
    tissue_prop_min: float,
    thresholds: dict[str, float],
    output_split_dir: Path,
    derive: bool,
) -> dict[str, int]:
    """Build the ``tiles.parquet`` / ``slides.parquet`` pair for one split.

    Downloads the tile metadata artifact, validates it, applies the tissue and
    class thresholds, joins the surviving tiles with their embeddings, sorts
    by ``slide_id`` and writes the result into ``output_split_dir``. The
    ``slides.parquet`` of the embedding run is copied alongside, and label
    distributions are logged to MLflow.

    Args:
        split_name: Split identifier; also the sub-directory of the embedding
            run that holds this split's artifacts.
        src_run_id: MLflow run holding the tile metadata.
        src_artifact_path: Artifact path of the tile metadata inside that run.
        embedding_run_id: MLflow run holding the embeddings.
        tissue_prop_min: Minimum ``tissue_prop`` for a tile to be kept.
        thresholds: Per-class minimum coverage, keyed by class name.
        output_split_dir: Directory to write the split's Parquet files into.
        derive: When true, ``label`` and ``tissue_prop`` are computed from the
            ``roi_coverage_*`` columns; otherwise both must already be present.

    Returns:
        Row counts after each stage of the pipeline.

    Raises:
        RuntimeError: If there are no ``roi_coverage_*`` columns, required
            columns are missing with ``derive=False``, or no tile survives the
            thresholds.
        ValueError: If ``thresholds`` lacks an entry for a class in the data.
    """
    print(f"[{split_name}] downloading source tiles", flush=True)
    local_src = mlflow.artifacts.download_artifacts(
        run_id=src_run_id, artifact_path=src_artifact_path
    )
    df = pads.dataset(local_src, format="parquet").to_table().to_pandas()
    input_count = len(df)

    roi_cols = [name for name in df.columns if name.startswith("roi_coverage_")]
    if not roi_cols:
        raise RuntimeError(
            f"No roi_coverage_* columns in {src_artifact_path}. "
            "Cannot apply class thresholds."
        )

    # Every class present in the data needs a configured threshold.
    missing = {name.removeprefix("roi_coverage_") for name in roi_cols}.difference(
        thresholds
    )
    if missing:
        raise ValueError(
            f"thresholds is missing entries for roi_coverage_* classes present "
            f"in data: {sorted(missing)}"
        )

    if derive:
        df["label"], df["tissue_prop"] = compute_label_and_tissue_prop(df, roi_cols)
    else:
        missing_required = {"label", "tissue_prop"}.difference(df.columns)
        if missing_required:
            raise RuntimeError(
                f"Source split '{split_name}' (derive=False) is missing required "
                f"columns {sorted(missing_required)} in {src_artifact_path}. "
                "Expected the kfold_split artifact, which writes label/tissue_prop/fold."
            )

    df, after_tissue_filter, after_single_class = apply_thresholds(
        df, tissue_prop_min, thresholds, roi_cols
    )
    after_class_threshold = len(df)
    if after_class_threshold == 0:
        raise RuntimeError(
            f"All {input_count} tiles dropped by thresholds for split '{split_name}'."
        )

    # The coverage columns have served their purpose; keep them out of the output.
    coverage_prefixes = ("roi_coverage_", "tile_coverage_")
    df = df.drop(
        columns=[name for name in df.columns if name.startswith(coverage_prefixes)]
    )
    print(
        f"[{split_name}] {input_count} -> {after_tissue_filter} (tissue) "
        f"-> {after_single_class} (single-class) "
        f"-> {after_class_threshold} (class threshold), joining embeddings",
        flush=True,
    )

    tiles_table = pa.Table.from_pandas(df, preserve_index=False)
    del df

    merged_table, dropped_no_embedding = join_embeddings(
        tiles_table, embedding_run_id, split_name
    )
    del tiles_table
    if dropped_no_embedding != 0:
        print(
            f"WARNING: {dropped_no_embedding} tiles in split '{split_name}' have "
            "no matching embedding and were dropped on join.",
            flush=True,
        )

    slide_order = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")])
    merged_table = merged_table.take(slide_order)

    output_split_dir.mkdir(parents=True, exist_ok=True)
    tiles_path = output_split_dir / "tiles.parquet"
    pq.write_table(merged_table, str(tiles_path))

    slides_local = mlflow.artifacts.download_artifacts(
        run_id=embedding_run_id, artifact_path=f"{split_name}/slides.parquet"
    )
    shutil.copy(slides_local, output_split_dir / "slides.parquet")

    log_label_distributions(split_name, merged_table)
    print(f"[{split_name}] wrote {merged_table.num_rows} rows", flush=True)

    stats = {
        "input_count": input_count,
        "after_tissue_filter": after_tissue_filter,
        "after_single_class": after_single_class,
        "after_class_threshold": after_class_threshold,
        "after_join": merged_table.num_rows,
        "dropped_no_embedding": dropped_no_embedding,
    }
    return stats


def log_label_distributions(split_name: str, table: pa.Table) -> None:
    """Log label counts for a split to MLflow as JSON tables.

    Always logs the overall count per ``label``. When the table has a ``fold``
    column, additionally logs a fold-by-label count matrix in which missing
    combinations are filled with 0.
    """
    columns = ["label"]
    has_fold = "fold" in table.schema.names
    if has_fold:
        columns.append("fold")
    frame = table.select(columns).to_pandas()

    counts = frame["label"].value_counts()
    mlflow.log_table(
        data=counts.rename_axis("label").reset_index(name="count"),
        artifact_file=f"fold_statistics/{split_name}_label_distribution.json",
    )

    if not has_fold:
        return

    per_fold = frame.groupby(["fold", "label"]).size().unstack(fill_value=0)
    mlflow.log_table(
        data=per_fold.reset_index(),
        artifact_file=f"fold_statistics/{split_name}_fold_label_distribution.json",
    )


@with_cli_args(["+preprocessing=embedding_dataset"])
@hydra.main(config_path="../configs", config_name="preprocessing", version_base=None)
@autolog
def main(config: DictConfig, logger: MLFlowLogger) -> None:
    """Build the train and test embedding datasets and log them to MLflow.

    Reads the source run ids, ``tissue_prop_min`` and ``thresholds`` from the
    config, processes each split into a temporary directory, logs the
    per-split counts as metrics and finally uploads the directory under
    ``config.mlflow_artifact_path``.

    Raises:
        ValueError: If ``tissue_prop_min`` is not strictly positive.
        TypeError: If ``config.thresholds`` does not resolve to a mapping.
    """
    run_ids = config.dataset.mlflow_artifacts
    train_source_run = run_ids.kfold_run_id
    test_source_run = run_ids.filter_tiles_run_id
    embedding_run_id = run_ids.embedding_run_id

    tissue_prop_min = float(config.tissue_prop_min)
    if tissue_prop_min <= 0:
        raise ValueError(
            f"tissue_prop_min must be > 0 (got {tissue_prop_min}); "
            "otherwise background tiles are not filtered out."
        )

    raw_thresholds = OmegaConf.to_container(config.thresholds, resolve=True)
    if not isinstance(raw_thresholds, dict):
        raise TypeError("config.thresholds must be a mapping of class -> threshold")
    thresholds = {
        str(class_name): float(limit) for class_name, limit in raw_thresholds.items()
    }

    # (split name, source run, source artifact, derive label/tissue_prop?)
    split_specs = [
        ("train", train_source_run, "kfold_split/kfold_tiles.parquet", False),
        ("test", test_source_run, "filter_tiles/test_tiles.parquet", True),
    ]

    with tempfile.TemporaryDirectory() as tmp_root:
        staging_dir = Path(tmp_root)
        for split_name, src_run_id, src_artifact_path, derive in split_specs:
            split_stats = process_split(
                split_name=split_name,
                src_run_id=src_run_id,
                src_artifact_path=src_artifact_path,
                embedding_run_id=embedding_run_id,
                tissue_prop_min=tissue_prop_min,
                thresholds=thresholds,
                output_split_dir=staging_dir / split_name,
                derive=derive,
            )
            for stat_name, stat_value in split_stats.items():
                mlflow.log_metric(f"{split_name}_{stat_name}", stat_value)

        # Upload while the temporary directory still exists.
        mlflow.log_artifacts(str(staging_dir), config.mlflow_artifact_path)


# Script entry point: run the pipeline only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
18 changes: 18 additions & 0 deletions scripts/submit_embedding_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from kube_jobs import storage, submit_job


# Shell steps run inside the job: fetch the repository, install the
# environment, then launch the embedding dataset build module.
job_script = [
    "git clone https://github.com/RationAI/tissue-classification.git workdir",
    "cd workdir",
    "uv sync",
    "uv run python -m preprocessing.embedding_dataset +experiment=...",
]

# CPU-only job; ``username`` and the ``+experiment`` override are left as
# placeholders to fill in before submitting.
submit_job(
    job_name="tissue-classification-embedding-dataset",
    username=...,
    cpu=8,
    memory="64Gi",
    gpu=None,
    public=False,
    script=job_script,
    storage=[storage.secure.PROJECTS],
)
7 changes: 3 additions & 4 deletions split/kfold_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from rationai.mlkit.lightning.loggers import MLFlowLogger
from sklearn.model_selection import StratifiedKFold

from preprocessing._labels import compute_label_and_tissue_prop


def derive_labels(
dataset: Dataset,
Expand All @@ -20,10 +22,7 @@ def derive_labels(
"""Derive label, tissue_prop, and slide_id arrays from the dataset."""

def compute(batch: dict[str, Any]) -> dict[str, Any]:
roi_df = pd.DataFrame({col: batch[col] for col in roi_cols})
tp = roi_df.sum(axis=1).values
lbl = roi_df.idxmax(axis=1).str.removeprefix("roi_coverage_").values
lbl[tp == 0] = "background"
lbl, tp = compute_label_and_tissue_prop(batch, roi_cols)
return {"label": lbl.tolist(), "tissue_prop": tp.tolist()}

label_ds = dataset.select_columns(["slide_id", *roi_cols]).map(
Expand Down