feat: add embedding dataset build pipeline #10
Closed
23 commits
bef70df feat: add embedding dataset build pipeline (vojtech-cifka)
911bec2 feat: add class tresholds and run ids (vojtech-cifka)
1a02395 fix: wrong run id (vojtech-cifka)
08d7ba5 Merge remote-tracking branch 'origin/master' into feature/embedding-d… (vojtech-cifka)
b38465e feat: add timing (vojtech-cifka)
bfc9578 refactor: use pyarrow to avoid to pandas conversion (vojtech-cifka)
eb213c6 fix: join on keys only (vojtech-cifka)
c92d9a1 fix: typing (vojtech-cifka)
01cc394 fix: add prints (vojtech-cifka)
cad0d37 refactor: use combine chunks (vojtech-cifka)
ae04552 fix: lazy-cast embeddings to large_list and stay in Arrow during join (vojtech-cifka)
82320db fix: validate label/tissue_prop columns when derive=False (vojtech-cifka)
3b0137f chore: remove time (vojtech-cifka)
8df47aa feat: add timing (vojtech-cifka)
926753d chore: revert to the previous state (vojtech-cifka)
b0e9ba4 feat: add prints (vojtech-cifka)
6a915de refactor: use discusssed thresholds (vojtech-cifka)
0f50307 refactor: use different labeling strategy (vojtech-cifka)
c421c74 refactor: drop tiles that are covered by two or more distinct labels (vojtech-cifka)
718ec08 fix: format (vojtech-cifka)
389a0a5 chore: update embeddings run id (vojtech-cifka)
11ed4e3 chore: remove timing prints (vojtech-cifka)
d59425a refactor: use int64 (vojtech-cifka)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # @package _global_ | ||
|
|
||
| defaults: | ||
| - /data: dataset | ||
| - _self_ | ||
|
|
||
| tissue_prop_min: 0.2 | ||
| thresholds: | ||
| Nerve: 0.0 | ||
| Blood: 0.0 | ||
| Connective-Tissue: 0.4 | ||
| Fat: 0.5 | ||
| Epithelium: 0.2 | ||
| Muscle: 0.4 | ||
| Other: 0.5 | ||
|
|
||
| metadata: | ||
| run_name: Embedding dataset ${dataset.name} | ||
| description: "Join k-fold (${dataset.mlflow_artifacts.kfold_run_id}) and filter_tiles (${dataset.mlflow_artifacts.filter_tiles_run_id}) tile metadata with embeddings (${dataset.mlflow_artifacts.embedding_run_id})." | ||
| hyperparams: | ||
| kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} | ||
| filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} | ||
| embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} |
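How the `tissue_prop_min` and `thresholds` values in this config act on a single tile can be sketched in plain Python. This is a dependency-free illustration of the filtering order used by `apply_thresholds` in this PR (tissue filter, single-class filter, argmax-then-threshold), not the pipeline code itself. The function name `decide` and the `coverage` input shape are hypothetical, and the sketch assumes `tissue_prop` equals the sum of the ROI coverages, which is how the shared label helper derives it.

```python
from typing import Optional

# Values copied from the experiment config above.
TISSUE_PROP_MIN = 0.2
THRESHOLDS = {
    "Nerve": 0.0,
    "Blood": 0.0,
    "Connective-Tissue": 0.4,
    "Fat": 0.5,
    "Epithelium": 0.2,
    "Muscle": 0.4,
    "Other": 0.5,
}


def decide(coverage: dict) -> Optional[str]:
    """Return the label for one tile, or None if the tile is dropped.

    `coverage` maps class name -> ROI coverage (hypothetical input shape;
    the pipeline reads these values from roi_coverage_* columns).
    """
    # Assumption: tissue_prop is the sum of all ROI coverages.
    tissue_prop = sum(coverage.values())
    if tissue_prop < TISSUE_PROP_MIN:
        return None  # not enough annotated tissue

    # Drop tiles covered by two or more distinct labels.
    if sum(1 for value in coverage.values() if value > 0) > 1:
        return None

    # Argmax first, then compare against that class's own threshold.
    dominant = max(coverage, key=coverage.get)
    if coverage[dominant] < THRESHOLDS[dominant]:
        return None
    return dominant
```

Under these assumptions a tile that passes the single-class filter has a tissue proportion equal to its dominant coverage, so a per-class threshold below `tissue_prop_min` (such as `Nerve: 0.0`) adds no restriction beyond the tissue filter.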
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # @package _global_ | ||
|
|
||
| mlflow_artifact_path: embedding_dataset | ||
|
|
||
| tissue_prop_min: ??? | ||
| thresholds: ??? | ||
|
|
||
| metadata: | ||
| run_name: "Embedding dataset ${dataset.name}" | ||
| description: "Build embedding training dataset by joining k-fold/filter_tiles tile metadata with precomputed embeddings." | ||
| hyperparams: | ||
| tissue_prop_min: ${tissue_prop_min} | ||
| thresholds: ${thresholds} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| """Shared helpers for deriving tile labels from roi_coverage_* columns.""" | ||
|
|
||
| from collections.abc import Mapping | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
|
|
||
|
|
||
| def compute_label_and_tissue_prop( | ||
| roi_data: Mapping[str, Any], | ||
| roi_cols: list[str], | ||
| ) -> tuple[np.ndarray, np.ndarray]: | ||
| """Compute (label, tissue_prop) from roi_coverage_* columns. | ||
|
|
||
| label = argmax across roi_cols (with ``roi_coverage_`` prefix stripped), | ||
| falling back to ``"background"`` whenever all coverages are zero. | ||
| tissue_prop = sum across roi_cols. | ||
| """ | ||
| roi_df = pd.DataFrame({col: roi_data[col] for col in roi_cols}) | ||
| tp = roi_df.sum(axis=1).to_numpy() | ||
| lbl = roi_df.idxmax(axis=1).str.removeprefix("roi_coverage_").to_numpy() | ||
| lbl[tp == 0] = "background" | ||
| return lbl, tp | ||
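The rule in the docstring above (label is the argmax over `roi_coverage_*` columns, `"background"` when every coverage is zero, tissue proportion is the row sum) can be illustrated for a single row without pandas. This is a sketch only: `label_and_tissue_prop` and the dict-shaped row are hypothetical, and the helper in the diff operates on whole columns rather than one row at a time.

```python
PREFIX = "roi_coverage_"


def label_and_tissue_prop(row: dict) -> tuple:
    """Return (label, tissue_prop) for one tile row given as a dict."""
    # Keep only the ROI coverage columns; other fields (ids, coordinates) are ignored.
    roi = {name: value for name, value in row.items() if name.startswith(PREFIX)}
    tissue_prop = sum(roi.values())
    if tissue_prop == 0:
        # No annotated class covers the tile at all.
        return "background", 0.0
    # max() returns the first maximal key on ties, the same tie-breaking as idxmax.
    dominant = max(roi, key=roi.get)
    return dominant.removeprefix(PREFIX), tissue_prop
```

For example, a row with `roi_coverage_Fat = 0.5` and `roi_coverage_Muscle = 0.25` gets the label `Fat` and a tissue proportion of `0.75`.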
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,283 @@ | ||
| """Build an embedding training dataset by joining tile metadata with embeddings. | ||
|
|
||
| Joins precomputed tile embeddings with k-fold metadata (train) / filter_tiles | ||
| metadata (test), applies tissue + per-class ROI thresholds before the join, and | ||
| emits a training-ready Parquet dataset (per-split ``slides.parquet`` + | ||
| ``tiles.parquet``) ready for ``rationai.mlkit.data.datasets.SlidesTilesLoader``. | ||
| """ | ||
|
|
||
| import shutil | ||
| import tempfile | ||
| from pathlib import Path | ||
|
|
||
| import hydra | ||
| import mlflow | ||
| import mlflow.artifacts | ||
| import pandas as pd | ||
| import pyarrow as pa | ||
| import pyarrow.compute as pc | ||
| import pyarrow.dataset as pads | ||
| import pyarrow.parquet as pq | ||
| from omegaconf import DictConfig, OmegaConf | ||
| from rationai.mlkit import autolog, with_cli_args | ||
| from rationai.mlkit.lightning.loggers import MLFlowLogger | ||
|
|
||
| from preprocessing._labels import compute_label_and_tissue_prop | ||
|
|
||
|
|
||
| def apply_thresholds( | ||
| df: pd.DataFrame, | ||
| tissue_prop_min: float, | ||
| thresholds: dict[str, float], | ||
| roi_cols: list[str], | ||
| ) -> tuple[pd.DataFrame, int, int]: | ||
| """Filter tiles by tissue, drop multi-annotation tiles, then apply argmax-then-threshold. | ||
|
|
||
| Returns ``(filtered_df, after_tissue_count, after_single_class_count)``. | ||
| """ | ||
| df = df.loc[df["tissue_prop"] >= tissue_prop_min] | ||
| after_tissue = len(df) | ||
| if df.empty: | ||
| return df, after_tissue, after_tissue | ||
|
|
||
| nonzero_classes = (df[roi_cols].to_numpy() > 0).sum(axis=1) | ||
| df = df.loc[pd.Series(nonzero_classes <= 1, index=df.index)] | ||
| after_single_class = len(df) | ||
| if df.empty: | ||
| return df, after_tissue, after_single_class | ||
|
|
||
| roi_only = df[roi_cols] | ||
| dominant = roi_only.idxmax(axis=1).str.removeprefix("roi_coverage_") | ||
| dominant_value = roi_only.max(axis=1).to_numpy() | ||
| threshold_per_row = dominant.map(thresholds).to_numpy() | ||
| keep = dominant_value >= threshold_per_row | ||
|
|
||
| out = df.loc[pd.Series(keep, index=df.index)].copy() | ||
| out["label"] = dominant.to_numpy()[keep] | ||
| return out, after_tissue, after_single_class | ||
|
|
||
|
|
||
| def join_embeddings( | ||
| tiles_table: pa.Table, | ||
| embedding_run_id: str, | ||
| embedding_split: str, | ||
| ) -> tuple[pa.Table, int]: | ||
| """Join filtered tile metadata with embeddings on (slide_id, x, y). | ||
|
|
||
| Stays in Arrow throughout to avoid the very slow list<double> -> pandas | ||
| conversion. Acero's join engine doesn't accept list columns in non-key | ||
| fields, so we join on keys plus a synthetic row index, then pull embeddings | ||
| via take(). The embedding column is cast per chunk to large_list to avoid | ||
| int32 offset overflow that bites take() when chunks are concatenated. | ||
| """ | ||
| emb_dir = mlflow.artifacts.download_artifacts( | ||
| run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" | ||
| ) | ||
| emb_table = pads.dataset(emb_dir, format="parquet").to_table( | ||
| columns=["slide_id", "x", "y", "embedding"] | ||
| ) | ||
|
|
||
| emb_col = emb_table.column("embedding") | ||
| if pa.types.is_list(emb_col.type): | ||
| target_type = pa.large_list(emb_col.type.value_type) | ||
| emb_col = pa.chunked_array( | ||
| [c.cast(target_type) for c in emb_col.chunks], type=target_type | ||
| ) | ||
|
|
||
| emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) | ||
| emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) | ||
| del emb_table, emb_idx | ||
|
|
||
| joined_keys = tiles_table.join( | ||
| emb_keys, keys=["slide_id", "x", "y"], join_type="inner" | ||
| ) | ||
| del emb_keys | ||
|
|
||
| indices = joined_keys.column("_emb_idx") | ||
| if isinstance(indices, pa.ChunkedArray): | ||
| indices = indices.combine_chunks() | ||
|
|
||
| emb_contig = emb_col.combine_chunks() | ||
| del emb_col | ||
|
|
||
| embeddings = emb_contig.take(indices) | ||
| del emb_contig | ||
|
vojtech-cifka marked this conversation as resolved.
|
||
|
|
||
| joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) | ||
| dropped_no_embedding = tiles_table.num_rows - joined.num_rows | ||
| return joined, dropped_no_embedding | ||
|
|
||
|
|
||
| def process_split( | ||
| split_name: str, | ||
| src_run_id: str, | ||
| src_artifact_path: str, | ||
| embedding_run_id: str, | ||
| tissue_prop_min: float, | ||
| thresholds: dict[str, float], | ||
| output_split_dir: Path, | ||
| derive: bool, | ||
| ) -> dict[str, int]: | ||
| print(f"[{split_name}] downloading source tiles", flush=True) | ||
| src_local = mlflow.artifacts.download_artifacts( | ||
| run_id=src_run_id, artifact_path=src_artifact_path | ||
| ) | ||
| df = pads.dataset(src_local, format="parquet").to_table().to_pandas() | ||
| input_count = len(df) | ||
|
|
||
| roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] | ||
| if not roi_cols: | ||
| raise RuntimeError( | ||
| f"No roi_coverage_* columns in {src_artifact_path}. " | ||
| "Cannot apply class thresholds." | ||
| ) | ||
|
|
||
| classes_in_data = {c.removeprefix("roi_coverage_") for c in roi_cols} | ||
| missing = classes_in_data - set(thresholds.keys()) | ||
| if missing: | ||
| raise ValueError( | ||
| f"thresholds is missing entries for roi_coverage_* classes present " | ||
| f"in data: {sorted(missing)}" | ||
| ) | ||
|
|
||
| if derive: | ||
| lbl, tp = compute_label_and_tissue_prop(df, roi_cols) | ||
| df["label"] = lbl | ||
| df["tissue_prop"] = tp | ||
| else: | ||
| required = {"label", "tissue_prop"} | ||
| missing_required = required - set(df.columns) | ||
| if missing_required: | ||
| raise RuntimeError( | ||
| f"Source split '{split_name}' (derive=False) is missing required " | ||
| f"columns {sorted(missing_required)} in {src_artifact_path}. " | ||
| "Expected the kfold_split artifact, which writes label/tissue_prop/fold." | ||
| ) | ||
|
|
||
| df, after_tissue_filter, after_single_class = apply_thresholds( | ||
| df, tissue_prop_min, thresholds, roi_cols | ||
| ) | ||
| after_class_threshold = len(df) | ||
| if after_class_threshold == 0: | ||
| raise RuntimeError( | ||
| f"All {input_count} tiles dropped by thresholds for split '{split_name}'." | ||
| ) | ||
|
|
||
| drop_cols = [ | ||
| c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) | ||
| ] | ||
| df = df.drop(columns=drop_cols) | ||
| print( | ||
| f"[{split_name}] {input_count} -> {after_tissue_filter} (tissue) " | ||
| f"-> {after_single_class} (single-class) " | ||
| f"-> {after_class_threshold} (class threshold), joining embeddings", | ||
| flush=True, | ||
| ) | ||
|
|
||
| tiles_table = pa.Table.from_pandas(df, preserve_index=False) | ||
| del df | ||
|
|
||
| merged_table, dropped_no_embedding = join_embeddings( | ||
| tiles_table, embedding_run_id, split_name | ||
| ) | ||
| del tiles_table | ||
| if dropped_no_embedding != 0: | ||
| print( | ||
| f"WARNING: {dropped_no_embedding} tiles in split '{split_name}' have " | ||
| "no matching embedding and were dropped on join.", | ||
| flush=True, | ||
| ) | ||
|
|
||
| sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) | ||
| merged_table = merged_table.take(sort_indices) | ||
|
|
||
| output_split_dir.mkdir(parents=True, exist_ok=True) | ||
| pq.write_table(merged_table, str(output_split_dir / "tiles.parquet")) | ||
|
|
||
| slides_local = mlflow.artifacts.download_artifacts( | ||
| run_id=embedding_run_id, artifact_path=f"{split_name}/slides.parquet" | ||
| ) | ||
| shutil.copy(slides_local, output_split_dir / "slides.parquet") | ||
|
|
||
| log_label_distributions(split_name, merged_table) | ||
| print(f"[{split_name}] wrote {merged_table.num_rows} rows", flush=True) | ||
|
|
||
| return { | ||
| "input_count": input_count, | ||
| "after_tissue_filter": after_tissue_filter, | ||
| "after_single_class": after_single_class, | ||
| "after_class_threshold": after_class_threshold, | ||
| "after_join": merged_table.num_rows, | ||
| "dropped_no_embedding": dropped_no_embedding, | ||
| } | ||
|
|
||
|
|
||
| def log_label_distributions(split_name: str, table: pa.Table) -> None: | ||
| has_fold = "fold" in table.schema.names | ||
| cols = ["label", "fold"] if has_fold else ["label"] | ||
| df = table.select(cols).to_pandas() | ||
|
|
||
| label_dist = ( | ||
| df["label"].value_counts().rename_axis("label").reset_index(name="count") | ||
| ) | ||
| mlflow.log_table( | ||
| data=label_dist, | ||
| artifact_file=f"fold_statistics/{split_name}_label_distribution.json", | ||
| ) | ||
|
|
||
| if has_fold: | ||
| fold_dist = ( | ||
| df.groupby(["fold", "label"]).size().unstack(fill_value=0).reset_index() | ||
| ) | ||
| mlflow.log_table( | ||
| data=fold_dist, | ||
| artifact_file=f"fold_statistics/{split_name}_fold_label_distribution.json", | ||
| ) | ||
|
|
||
|
|
||
| @with_cli_args(["+preprocessing=embedding_dataset"]) | ||
| @hydra.main(config_path="../configs", config_name="preprocessing", version_base=None) | ||
| @autolog | ||
| def main(config: DictConfig, logger: MLFlowLogger) -> None: | ||
| artifacts = config.dataset.mlflow_artifacts | ||
| kfold_run_id = artifacts.kfold_run_id | ||
| filter_tiles_run_id = artifacts.filter_tiles_run_id | ||
| embedding_run_id = artifacts.embedding_run_id | ||
|
|
||
| tissue_prop_min = float(config.tissue_prop_min) | ||
| if tissue_prop_min <= 0: | ||
| raise ValueError( | ||
| f"tissue_prop_min must be > 0 (got {tissue_prop_min}); " | ||
| "otherwise background tiles are not filtered out." | ||
| ) | ||
| raw_thresholds = OmegaConf.to_container(config.thresholds, resolve=True) | ||
| if not isinstance(raw_thresholds, dict): | ||
| raise TypeError("config.thresholds must be a mapping of class -> threshold") | ||
| thresholds = {str(k): float(v) for k, v in raw_thresholds.items()} | ||
|
|
||
| splits = [ | ||
| ("train", kfold_run_id, "kfold_split/kfold_tiles.parquet", False), | ||
| ("test", filter_tiles_run_id, "filter_tiles/test_tiles.parquet", True), | ||
| ] | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmp_root: | ||
| tmp_root_path = Path(tmp_root) | ||
| for split_name, src_run_id, src_artifact_path, derive in splits: | ||
| stats = process_split( | ||
| split_name=split_name, | ||
| src_run_id=src_run_id, | ||
| src_artifact_path=src_artifact_path, | ||
| embedding_run_id=embedding_run_id, | ||
| tissue_prop_min=tissue_prop_min, | ||
| thresholds=thresholds, | ||
| output_split_dir=tmp_root_path / split_name, | ||
| derive=derive, | ||
| ) | ||
| for key, value in stats.items(): | ||
| mlflow.log_metric(f"{split_name}_{key}", value) | ||
|
|
||
| mlflow.log_artifacts(str(tmp_root_path), config.mlflow_artifact_path) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
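The strategy described in the `join_embeddings` docstring (join on keys plus a synthetic row index, then pull the embeddings by index) can be sketched with plain Python lists in place of Arrow tables. The names below are hypothetical and nothing here calls pyarrow; the dictionary lookup stands in for the inner join and list indexing stands in for `take()`. The sketch assumes each `(slide_id, x, y)` key occurs at most once among the embeddings.

```python
def join_on_keys_then_take(tiles, emb_keys, emb_values):
    """Inner-join tile rows with embeddings without moving embeddings through the join.

    tiles:      list of dict rows containing slide_id, x, y
    emb_keys:   list of (slide_id, x, y) tuples, one per embedding row
    emb_values: list of embedding vectors, parallel to emb_keys
    """
    # Step 1: attach a synthetic row index to every embedding key.
    index_by_key = {key: idx for idx, key in enumerate(emb_keys)}

    # Step 2: inner join on the keys only, carrying the row index along.
    joined, indices = [], []
    for row in tiles:
        idx = index_by_key.get((row["slide_id"], row["x"], row["y"]))
        if idx is not None:
            joined.append(dict(row))
            indices.append(idx)

    # Step 3: gather the embeddings by row index (the take() step).
    for row, idx in zip(joined, indices):
        row["embedding"] = emb_values[idx]

    dropped_no_embedding = len(tiles) - len(joined)
    return joined, dropped_no_embedding
```

Keeping the wide embedding column out of the join means only the narrow key columns are matched, and the vectors are touched once, in the final gather.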
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| from kube_jobs import storage, submit_job | ||
|
|
||
|
|
||
| submit_job( | ||
| job_name="tissue-classification-embedding-dataset", | ||
| username=..., | ||
| cpu=8, | ||
| memory="64Gi", | ||
| gpu=None, | ||
| public=False, | ||
| script=[ | ||
| "git clone https://github.com/RationAI/tissue-classification.git workdir", | ||
| "cd workdir", | ||
| "uv sync", | ||
| "uv run python -m preprocessing.embedding_dataset +experiment=...", | ||
| ], | ||
| storage=[storage.secure.PROJECTS], | ||
| ) |