Skip to content
1 change: 1 addition & 0 deletions configs/data/dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dataset:
test_split_filename: "split_mapping/test_split.csv"
tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86"
tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f"
tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a"

exclusions:
bad_slides:
Expand Down
12 changes: 12 additions & 0 deletions configs/experiment/preprocessing/filter_tiles.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# @package _global_

defaults:
- /data: dataset
- _self_

metadata:
run_name: Filter tiles ${dataset.name}
description: "Drop tiles with no annotation coverage and no tile tissue coverage over tiling run ${dataset.mlflow_artifacts.tiling_run_id} and tissue stats run ${dataset.mlflow_artifacts.tissue_stats_run_id}."
hyperparams:
tiling_run_id: ${dataset.mlflow_artifacts.tiling_run_id}
tissue_stats_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id}
13 changes: 13 additions & 0 deletions configs/preprocessing/filter_tiles.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# @package _global_

tissue_stats_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id}
tissue_stats_artifact_path: tissue_stats
tissue_coverage_column: tile_tissue_coverage

mlflow_artifact_path: filter_tiles

metadata:
run_name: "Filter tiles ${dataset.name}"
description: "Drop tiles with no annotation coverage and no tile tissue coverage."
hyperparams:
tissue_coverage_column: ${tissue_coverage_column}
113 changes: 113 additions & 0 deletions preprocessing/filter_tiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import tempfile
from pathlib import Path

import hydra
import mlflow
import mlflow.artifacts
import pyarrow as pa
import pyarrow.dataset as pads
import pyarrow.parquet as pq
from omegaconf import DictConfig
from rationai.mlkit import autolog, with_cli_args
from rationai.mlkit.lightning.loggers import MLFlowLogger


def filter_split(
tiling_run_id: str,
tissue_stats_run_id: str,
tissue_stats_artifact_path: str,
tissue_column: str,
split_name: str,
output_path: Path,
) -> dict[str, int]:
"""Drop tiles with no annotation coverage and no tissue coverage.

Uses PyArrow predicate pushdown so the full tiles parquet is never loaded into
memory — only rows passing the annotation filter are materialised. The tissue
coverage table is then joined in-memory to drop tiles outside tissue and to
carry through the per-tile coverage values into the output.
"""
tiles_local = mlflow.artifacts.download_artifacts(
run_id=tiling_run_id, artifact_path=f"{split_name}_split/tiles.parquet"
)
tiles_ds = pads.dataset(tiles_local, format="parquet")
original_count = tiles_ds.count_rows()

ann_cols = [f.name for f in tiles_ds.schema if f.name.startswith("tile_coverage_")]
if not ann_cols:
raise RuntimeError(
"No tile_coverage_* columns found in tiles parquet. "
"Check that tiling used a class mapping with annotations."
)
ann_filter = pads.field(ann_cols[0]) > 0
for c in ann_cols[1:]:
ann_filter = ann_filter | (pads.field(c) > 0)

tiles_table = tiles_ds.to_table(filter=ann_filter)
ann_count = len(tiles_table)
if ann_count == 0:
raise RuntimeError(
f"All {original_count} tiles dropped by annotation filter for {split_name}. "
"Check the tiling run's class mapping and annotation sources."
)

tissue_local = mlflow.artifacts.download_artifacts(
run_id=tissue_stats_run_id,
artifact_path=f"{tissue_stats_artifact_path}/{split_name}_tiles.parquet",
)
tissue_ds = pads.dataset(tissue_local, format="parquet")
tissue_schema_names = {f.name for f in tissue_ds.schema}
if tissue_column not in tissue_schema_names:
raise RuntimeError(
f"tissue_column '{tissue_column}' not found in tissue stats parquet. "
f"Available columns: {sorted(tissue_schema_names)}"
)
tissue_table = tissue_ds.to_table(filter=pads.field(tissue_column) > 0)

tiles_df = tiles_table.to_pandas()
tissue_df = tissue_table.to_pandas()
del tiles_table, tissue_table
filtered_df = tiles_df.merge(tissue_df, on=["slide_id", "x", "y"], how="inner")
del tiles_df, tissue_df
filtered = pa.Table.from_pandas(filtered_df, preserve_index=False)
final_count = len(filtered)
del filtered_df
Comment thread
vojtech-cifka marked this conversation as resolved.
if final_count == 0:
raise RuntimeError(
f"All {ann_count} annotation-passing tiles dropped by tissue filter for {split_name}. "
f"Check that tissue_column '{tissue_column}' is non-zero for at least some tiles."
)

pq.write_table(filtered, str(output_path))
return {
"original_count": original_count,
"after_annotation": ann_count,
"after_tissue": final_count,
}


@with_cli_args(["+preprocessing=filter_tiles"])
@hydra.main(config_path="../configs", config_name="preprocessing", version_base=None)
@autolog
def main(config: DictConfig, logger: MLFlowLogger) -> None:
tiling_run_id = config.dataset.mlflow_artifacts.tiling_run_id

with tempfile.TemporaryDirectory() as tmp_dir:
for split_name in ("train", "test"):
output_path = Path(tmp_dir) / f"{split_name}_tiles.parquet"
stats = filter_split(
tiling_run_id=tiling_run_id,
tissue_stats_run_id=config.tissue_stats_run_id,
tissue_stats_artifact_path=config.tissue_stats_artifact_path,
tissue_column=config.tissue_coverage_column,
split_name=split_name,
output_path=output_path,
)
for key, value in stats.items():
mlflow.log_metric(f"{split_name}_{key}", value)

mlflow.log_artifacts(tmp_dir, config.mlflow_artifact_path)


if __name__ == "__main__":
main()
17 changes: 17 additions & 0 deletions scripts/submit_filter_tiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from kube_jobs import storage, submit_job


submit_job(
job_name="tissue-classification-filter-tiles",
username=...,
cpu=8,
memory="16Gi",
public=False,
script=[
"git clone https://github.com/RationAI/tissue-classification.git workdir",
"cd workdir",
"uv sync",
"uv run python -m preprocessing.filter_tiles +experiment=...",
],
Comment thread
vojtech-cifka marked this conversation as resolved.
storage=[storage.secure.PROJECTS],
)