# feat: linear classifier training pipeline on precomputed embeddings #11

Merged
vojtech-cifka merged 55 commits into master from feature/ml-linear-classifier
May 16, 2026

vojtech-cifka (Collaborator) commented May 11, 2026

Summary

Adds an end-to-end ML training pipeline for linear probing on precomputed tile
embeddings. Introduces the embedding dataset preprocessing step, a PyTorch
Lightning training module, and all supporting configs and submission scripts.

Changes

Preprocessing

  • preprocessing/_labels.py — shared label/tissue-prop derivation logic.

ML training

  • ml/meta_arch.py: the MetaArch Lightning module, combining a backbone, a decode head, and
    CrossEntropyLoss with balanced class weights computed from the train fold.
    Logs per-class metrics, confusion matrices, and per-slide accuracy.
  • ml/data/datasets/embedding_tiles.py: EmbeddingTilesDataset, which loads the
    embedding parquet, inner-joins it with metadata, and serves
    (embedding, label, slide_id) triples. The join stays in Arrow to avoid the
    large-list → pandas conversion overhead.
  • ml/data/data_module.py — Lightning DataModule wrapping train/val/test splits.
  • ml/callbacks/parquet_prediction_writer.py — writes model predictions to Parquet.
  • configs/experiment/ml/linear_classifier.yaml — full experiment config.
  • configs/ml/ — model, data, and trainer sub-configs.
  • scripts/submit_train_linear.py — MLflow submission script.

Summary by CodeRabbit

Release Notes

  • New Features
    • Linear classifier model training with configurable learning rate and regularization parameters
    • Stratified and stratified group k-fold cross-validation strategies for comprehensive model evaluation
    • Automatic prediction export to parquet format with per-class probability distributions
    • Tile embedding dataset support featuring adjustable tissue coverage filtering and cross-validation fold subsetting
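
The practical difference between the two cross-validation strategies is that the group variant keeps every member of a group in a single fold. A minimal sketch, assuming the grouping key is the slide and that fold subsetting works by including or excluding fold ids; the data, helper names, and exact semantics are hypothetical:

```python
def subset_by_folds(rows, include_folds=None, exclude_folds=None):
    """Keep rows whose fold is in include_folds and not in exclude_folds."""
    kept = []
    for row in rows:
        if include_folds is not None and row["fold"] not in include_folds:
            continue
        if exclude_folds is not None and row["fold"] in exclude_folds:
            continue
        kept.append(row)
    return kept


def leaked_groups(train_rows, val_rows, key="slide_id"):
    """Return group ids that appear on both sides of the split."""
    return {r[key] for r in train_rows} & {r[key] for r in val_rows}


# Hypothetical tiles. Slide "s2" is spread over folds 0 and 1,
# which a group-aware split would not allow.
tiles = [
    {"slide_id": "s1", "fold": 0},
    {"slide_id": "s1", "fold": 0},
    {"slide_id": "s2", "fold": 0},
    {"slide_id": "s2", "fold": 1},
    {"slide_id": "s3", "fold": 1},
]
val = subset_by_folds(tiles, include_folds={0})
train = subset_by_folds(tiles, exclude_folds={0})
leaks = leaked_groups(train, val)
```

With a plain stratified split, a non-empty result here is expected; with a group-aware split it would indicate a problem in the fold assignment.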


vojtech-cifka and others added 26 commits May 8, 2026 21:27
Extract derive_labels logic to shared preprocessing/_labels.py, then use it in
both split/kfold_split.py and the new embedding_dataset pipeline. The new
pipeline joins k-fold (train) / filter_tiles (test) tile metadata with
precomputed embeddings after applying tissue + per-dominant-class ROI thresholds,
and emits a SlidesTilesLoader-compatible Parquet dataset as an MLflow artifact.
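
A rough sketch of what a shared label-derivation helper might look like. It assumes each tile carries per-class ROI coverage fractions, that the label is the class with the largest coverage, that the tissue proportion is the total coverage, and that zero-coverage tiles fall back to a background class. The function name, class names, and the definition of tissue proportion are assumptions, not the code in preprocessing/_labels.py.

```python
LABELS = ["background", "tumor", "stroma"]  # hypothetical; index 0 is the fallback


def label_and_tissue_prop(coverage: dict[str, float]) -> tuple[int, float]:
    """Return (label_index, tissue_prop) for one tile's ROI coverage fractions."""
    roi_classes = LABELS[1:]
    # Assumption: tissue proportion is the sum of all ROI coverages.
    tissue_prop = sum(coverage.get(name, 0.0) for name in roi_classes)
    if tissue_prop == 0.0:
        # No ROI covers this tile: fall back to "background".
        return 0, 0.0
    best = max(roi_classes, key=lambda name: coverage.get(name, 0.0))
    return LABELS.index(best), tissue_prop


example = label_and_tissue_prop({"tumor": 0.25, "stroma": 0.5})
empty = label_and_tissue_prop({"tumor": 0.0, "stroma": 0.0})
```

Keeping this logic in one function is what lets the k-fold split and the embedding dataset pipeline agree on labels.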

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Joining 1M+ rows of list<double> embeddings was either OOMing on
to_pandas() or hitting int32 list-offset overflow inside take(). The fix:
- read embeddings into Arrow only and cast each chunk to large_list so
  take() concatenation uses int64 offsets;
- run the join on keys plus a synthetic row index because Acero refuses
  list columns in non-key fields, then pull embeddings via take();
- combine_chunks() before take() for an O(N) single-pass copy;
- write the parquet straight from Arrow, never materialising the
  embedding column in pandas.

Also bumps the kube job memory to 64Gi to give the combined-chunks +
take() peak some headroom, and trims the verbose [timing] prints down
to one progress line per split.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Without this guard a malformed train artifact would crash deep inside
apply_thresholds with a confusing KeyError. Surface a clear error that
points at the expected upstream artifact instead.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
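
A guard of the kind this commit describes could look like the sketch below. The required column names, the function name, and the message wording are hypothetical; the point is only to fail early with a readable error instead of a KeyError further downstream.

```python
REQUIRED_COLUMNS = ("slide_id", "tile_id", "fold")  # hypothetical column names


def require_columns(columns, required=REQUIRED_COLUMNS, source="train artifact"):
    """Fail early with a readable message when expected columns are missing."""
    missing = [name for name in required if name not in columns]
    if missing:
        raise ValueError(
            f"{source} is missing columns {missing}; "
            "check that it comes from the expected upstream run."
        )


# A malformed artifact without the "fold" column triggers the clear error.
try:
    require_columns(["slide_id", "tile_id"])
except ValueError as err:
    message = str(err)
```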
vojtech-cifka requested review from a team, JakubPekar and ejdam87 May 11, 2026 20:14
coderabbitai (Bot) commented May 11, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

This PR introduces an end-to-end linear classifier training pipeline for tissue classification. It adds type aliases and a shared label-computation helper, and implements a PyArrow-backed embedding dataset with metadata filtering and k-fold support. On top of that, it provides a Lightning DataModule and a MetaArch model with class-weighted training and MLflow logging, plus a callback that writes predictions to parquet. Hydra configurations cover the model, data, trainer, and experiments. A Hydra entrypoint runs the training, and a Kubernetes job submission script launches it.

Changes

Linear Classifier Training Pipeline

| Layer | File(s) | Summary |
| --- | --- | --- |
| Type aliases & label helper | ml/typing.py, preprocessing/_labels.py | Defines Sample, Input, Outputs type aliases and adds compute_label_and_tissue_prop helper that converts ROI coverage columns to integer labels and tissue proportions with a "background" fallback for zero-coverage tiles. |
| DataModule and EmbeddingTilesDataset | ml/data/__init__.py, ml/data/data_module.py, ml/data/datasets/__init__.py, ml/data/datasets/embedding_tiles.py | Implements a stage-aware Lightning DataModule that lazily instantiates datasets from Hydra configs. EmbeddingTilesDataset resolves MLflow artifact URIs, loads embeddings and metadata from parquet, filters by tissue proportion and per-class thresholds, supports optional k-fold subsetting, inner-joins embeddings to metadata, and returns (embedding_tensor, class_index, slide_id) samples. |
| MetaArch Lightning Module | ml/meta_arch.py | Multiclass classification module composing backbone, decode head, and criterion. Computes inverse-frequency class weights from training labels, applies class-weighted CrossEntropyLoss, tracks per-class/macro metrics and confusion matrices, logs per-class metrics and artifacts to MLflow for validation and testing, accumulates per-slide tile accuracy, and returns slide IDs and per-class probabilities for prediction. |
| ParquetPredictionWriter callback | ml/callbacks/__init__.py, ml/callbacks/parquet_prediction_writer.py | Lightning callback that aggregates predict_step outputs into a parquet file with columns slide_id, target, pred, and per-class probabilities (derived from class_names or prob_<i> naming). Writes to trainer.default_root_dir/predictions.parquet and logs it as an MLflow artifact. |
| Hydra configurations | configs/ml/trainer/default.yaml, configs/ml/model/linear_classifier.yaml, configs/ml/data/embedding.yaml, configs/data/dataset.yaml, configs/ml/linear_classifier.yaml, configs/experiment/ml/linear_classifier_stratified_kfold.yaml, configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml | Trainer defaults with callbacks (early stopping, checkpoint, learning-rate monitor, prediction writer). Model config defines identity backbone and fixed-in-features linear head. Embedding data config splits train/val/test by fold using include/exclude_folds. Dataset YAML adds artifact run IDs. Linear classifier base config wires embedding/k-fold/filter artifacts, thresholds, metadata. Experiment configs specify stratified k-fold strategies. |
| Hydra entrypoint & MLflow wiring | ml/__main__.py | Main entrypoint decorated with @hydra.main and @autolog. Registers OmegaConf resolvers, seeds RNG, instantiates DataModule/MetaArch/Trainer with MLFlowLogger, validates config.mode against {fit, test, validate, predict}, dispatches to trainer methods with checkpoint path, and ends MLflow run. |
| Job submission, split updates, dependencies | scripts/submit_train_linear.py, split/kfold_split.py, pyproject.toml | Kubernetes job submission script that clones repo, installs via uv, and runs training with multi-experiment config and 5-fold cross-validation. K-fold split code refactored to use shared compute_label_and_tissue_prop helper. Dependencies: pyarrow bumped to >=19.0.1, datasets to >=4.0.0. |
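
The mode validation and dispatch described for the Hydra entrypoint can be sketched roughly as follows. The function name and the stand-in trainer class are hypothetical; the real entrypoint uses Lightning's Trainer, whose fit, test, validate, and predict methods all accept datamodule and ckpt_path keyword arguments.

```python
VALID_MODES = {"fit", "test", "validate", "predict"}


def run_mode(trainer, mode, model, datamodule, ckpt_path=None):
    """Validate the mode, then call the trainer method of the same name."""
    if mode not in VALID_MODES:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(VALID_MODES)}")
    return getattr(trainer, mode)(model, datamodule=datamodule, ckpt_path=ckpt_path)


class FakeTrainer:
    """Stand-in for a Lightning Trainer so the sketch runs without dependencies."""

    def fit(self, model, datamodule=None, ckpt_path=None):
        return ("fit", ckpt_path)

    def test(self, model, datamodule=None, ckpt_path=None):
        return ("test", ckpt_path)


outcome = run_mode(FakeTrainer(), "test", model=None, datamodule=None, ckpt_path="last.ckpt")
```

Checking the mode against an explicit set before calling getattr keeps a typo in the config from resolving to some unrelated trainer attribute.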

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

The PR introduces substantial new logic across multiple domains: embedding dataset with complex filtering/joining logic, Lightning MetaArch with multi-metric tracking and MLflow integration, extensive Hydra configuration wiring, and a new training entrypoint. While individual files are coherent, the cross-file dependencies, data flow complexity, and MLflow/Lightning integration patterns demand careful review of assumptions around tensor shapes, metric computation, and artifact logging.

Possibly related PRs

  • RationAI/tissue-classification#7: Extends dataset.mlflow_artifacts with embedding-related run IDs used by this PR's embedding dataset and config wiring.
  • RationAI/tissue-classification#8: Introduces the tile-filtering preprocessing step whose artifact run ID (filter_tiles_run_id) is referenced in this PR's linear-classifier config.

Suggested reviewers

  • vejtek
  • JakubPekar
  • matejpekar

Poem

🐰 A linear head upon embeddings rides,
Stratified folds guide the training tides,
Confusion matrices logged with care,
Parquet predictions float through the air,
Config-driven pipelines, oh what a sight!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 11.11%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately describes the main objective of the changeset, implementing a linear classifier training pipeline on precomputed embeddings, which is reflected across all new ML modules, configs, and scripts. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


vejtek previously approved these changes May 13, 2026
Train loss ~0.02 vs val loss ~0.32 indicated severe overfit on the
linear probe. AdamW weight_decay was 0; bump to 1e-3 to regularize
the head.
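
To make the effect of this change concrete, here is a minimal pure-Python sketch of a single AdamW update on one scalar weight, following the decoupled weight decay rule: the weight is first shrunk by a factor of (1 - lr * weight_decay), then the usual Adam step is applied. The learning rate, gradient, and starting weight are made-up numbers, not values from this PR. In PyTorch the same setting corresponds to passing weight_decay=1e-3 to torch.optim.AdamW.

```python
import math


def adamw_step(w, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.0):
    """One AdamW update for a scalar weight; returns (w, m, v)."""
    w = w * (1.0 - lr * weight_decay)            # decoupled weight decay
    m = beta1 * m + (1.0 - beta1) * grad         # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * grad * grad  # second-moment estimate
    m_hat = m / (1.0 - beta1 ** t)               # bias correction
    v_hat = v / (1.0 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


# With a zero gradient, only the decay term moves the weight.
w_no_decay, _, _ = adamw_step(2.0, 0.0, 0.0, 0.0, t=1, weight_decay=0.0)
w_decay, _, _ = adamw_step(2.0, 0.0, 0.0, 0.0, t=1, weight_decay=1e-3)
```

With weight_decay at 0 the weight stays where the gradient leaves it, while a non-zero value pulls it slightly toward zero on every step, which is the regularizing pressure the commit adds to the linear head.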
vojtech-cifka requested a review from vejtek May 14, 2026 17:10
vejtek previously approved these changes May 15, 2026
vojtech-cifka merged commit 4c90041 into master May 16, 2026
3 checks passed
vojtech-cifka deleted the feature/ml-linear-classifier branch May 16, 2026 12:08

3 participants