Skip to content

Virchow2 model#2

Open
Jurgee wants to merge 52 commits intomainfrom
feature/virchow2-model
Open

Virchow2 model#2
Jurgee wants to merge 52 commits intomainfrom
feature/virchow2-model

Conversation

@Jurgee
Copy link
Copy Markdown
Collaborator

@Jurgee Jurgee commented Mar 13, 2026

This PR introduces support for the Virchow2 foundation model (paige-ai/Virchow2) within the Ray Serve infrastructure.

New Model Deployment: Added virchow2.py implementing the Virchow2 class as a Ray Serve deployment

Summary by CodeRabbit

  • New Features

    • Virchow2 foundation-model deployment with FastAPI ingress for image embeddings.
    • Automated HuggingFace downloader job and local model-download utility for offline caching.
    • PersistentVolumeClaim and shared cache mounts for HuggingFace assets.
    • New provider helper to fetch models/files from HuggingFace.
  • Refactor

    • Removed thread-level inference configuration from model settings.
  • Chores

    • Updated service deployment, worker scaling, resource allocations, GPU runtime packaging and environment (HF token, cache) for HuggingFace access.

@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Mar 13, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds a Virchow2 Ray Serve deployment with FastAPI ingress, HuggingFace caching/support (PVC, provider helper, downloader job), Dockerfile and Ray service updates for HF/GPU integration, and removes intra_op_num_threads usage from model configs.

Changes

Cohort / File(s) Summary
Docker Configuration
docker/Dockerfile.cpu, docker/Dockerfile.gpu
Dockerfile.cpu: reformatted pip install into a multiline RUN. Dockerfile.gpu: split pip installs into multiple RUN steps, added GPU-specific packages and external indices (onnxruntime-gpu, tensorrt-cu12, torch/cu121, torchvision/cu121, timm, huggingface-hub).
Virchow2 Downloader Utilities
misc/virchow2_downloader/download_virchow2.py, misc/virchow2_downloader/virchow2_downloader_job.yaml
Added script to cache HuggingFace model snapshots (supports HF_TOKEN) and a Kubernetes Job manifest to run it with PVC mount, resource limits, secure non-root context, and required env vars.
Model Config Cleanup
models/binary_classifier.py, models/semantic_segmentation.py
Removed intra_op_num_threads from Config TypedDicts and stopped setting sess_options.intra_op_num_threads; small formatting tweaks.
Virchow2 Model Implementation
models/virchow2.py
New Ray Serve deployment Virchow2 with embedded FastAPI ingress, reconfigure to load timm model from HF/cache, preprocessing transforms, batched async predict, LZ4-compressed request handling, and exported app = Virchow2.bind().
Model Provider Enhancement
providers/model_provider.py
Added `huggingface(repo_id, filename: str
Kubernetes / Ray Serve Manifests
pvc/huggingface-pvc.yaml, ray-service.yaml
Added 15Gi NFS-backed PVC huggingface-cache-pvc. Updated ray-service.yaml: new virchow2 application, HF_HOME envs, HF_TOKEN secret in GPU workers, huggingface-cache volume mounts in cpu/gpu workers, head/worker image and resource updates, MIG GPU selector, actor sizing, and virchow2 user_config. Removed intra_op_num_threads from BinaryClassifier user_config.

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant RayServe as Ray Serve (Virchow2)
    participant Provider as providers.huggingface
    participant PVC as HuggingFace PVC
    participant HFHub as HuggingFace Hub
    participant Downloader as Downloader Job

    Client->>RayServe: POST LZ4-compressed image(s)
    RayServe->>Provider: request local path for repo_id (HF_HOME)
    Provider-->>RayServe: return local cache path (exists?) 
    alt cache miss
        RayServe->>Downloader: schedule/run downloader job
        Downloader->>HFHub: snapshot_download / hf_hub_download (uses HF_TOKEN)
        HFHub-->>PVC: write model files to cache
        Downloader-->>RayServe: notify completion
    end
    RayServe->>RayServe: preprocess, batch, run timm model (GPU)
    RayServe-->>Client: return embeddings (LZ4-compressed)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐇 I hopped to the cache and found a bright seed,
A model unrolled with embeddings to feed,
From HuggingFace burrow to PVC nest,
Ray serves each tile — the rabbits impressed!
🥕

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 22.22% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title "Virchow2 model" is vague and generic. While it mentions the model name, it does not describe the actual change or what was accomplished—such as 'adds support for', 'integrates', 'implements deployment for', or similar action verbs that clarify the scope of work. Consider a more descriptive title like 'Add Virchow2 foundation model support to Ray Serve' or 'Implement Virchow2 model deployment with Ray Serve integration'.
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feature/virchow2-model

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Copy Markdown

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the model serving infrastructure by integrating the Virchow2 foundation model and optimizing existing models for GPU performance. It establishes a robust framework for deploying advanced deep learning models, leveraging TensorRT for efficient inference and Hugging Face for model management. The changes also include necessary infrastructure updates for dependency management and persistent caching, ensuring a more efficient and scalable system.

Highlights

  • Virchow2 Model Integration: Introduced the Virchow2 foundation model (paige-ai/Virchow2) as a new Ray Serve deployment, enabling its use within the existing infrastructure.
  • GPU Acceleration with TensorRT: Enabled GPU acceleration for existing BinaryClassifier and SemanticSegmentation models by integrating TensorRT, significantly improving inference performance.
  • Infrastructure Updates for Model Caching: Added support for persistent caching of Hugging Face models and TensorRT engines through new Persistent Volume Claims (PVCs) and a dedicated model downloader script.
  • Dockerfile and Dependency Management: Updated Dockerfiles to include necessary Python packages like PyTorch, timm, and huggingface_hub, and created a new GPU-specific Dockerfile with source builds for image processing libraries and TensorRT.
  • Ray Serve Configuration Enhancements: Adjusted resource allocations, autoscaling parameters, and added volume mounts in ray-service.yaml to optimize performance and resource utilization across all deployments, including the new Virchow2 model.
Changelog
  • builders/heatmap_builder.py
    • Added a call to mask_builder.flush() to ensure data is written before saving.
  • docker/Dockerfile.cpu
    • Updated pip install commands to include torch, torchvision, timm, and huggingface-hub.
  • docker/Dockerfile.gpu
    • Added a new Dockerfile for GPU environments, including building libvips and openslide from source, and installing onnxruntime-gpu, tensorrt, torch (CUDA-enabled), torchvision (CUDA-enabled), timm, and huggingface-hub.
  • misc/tile_heatmap_builder.py
    • Introduced a flush method to explicitly flush image and count data.
  • misc/virchow2_downloader/download_virchow2.py
    • Added a Python script to download the paige-ai/Virchow2 model from Hugging Face Hub and verify its loading.
  • misc/virchow2_downloader/virchow2_downloader_job.yaml
    • Added a Kubernetes Job definition to run the download_virchow2.py script, configuring security contexts, resources, environment variables, and volume mounts for caching.
  • models/binary_classifier.py
    • Imported asyncio and os.
    • Removed mean and std from Config as normalization is now handled by the model.
    • Modified reconfigure to be synchronous and include TensorRT execution provider options, cache path creation, graph optimizations, and a model warmup step.
    • Updated predict to handle uint8 input directly and flatten outputs.
    • Adjusted root method to use asyncio.to_thread for decompression and transpose for image reshaping.
  • models/semantic_segmentation.py
    • Imported os.
    • Modified reconfigure to be synchronous and include TensorRT execution provider options, cache path creation, graph optimizations, and a model warmup step.
    • Updated providers list to prioritize TensorrtExecutionProvider.
  • models/virchow2.py
    • Added a new Ray Serve deployment for the Virchow2 model, including __init__, reconfigure (for model loading, transformations, and warmup), predict (for inference and embedding extraction), and root (for handling requests).
  • providers/model_provider.py
    • Added a new provider function for fetching models from Hugging Face Hub, supporting both single files and snapshots, with local caching.
  • pvc/huggingface-pvc.yaml
    • Added a Persistent Volume Claim for the Hugging Face cache with 15Gi storage.
  • pvc/tensorrt-cache-pvc.yaml
    • Added a Persistent Volume Claim for the TensorRT cache with 20Gi storage.
  • pyproject.toml
    • Added 'Jiří Štípek' as an author.
  • ray-service.yaml
    • Updated BinaryClassifier deployment: changed num_cpus to 4, added num_gpus to 1, reduced memory to 4Gi, decreased max_batch_size to 16, reduced batch_wait_timeout_s to 0.01, removed mean and std config, and updated artifact_uri.
    • Updated SemanticSegmentation deployment: reduced max_queued_requests to 32, max_replicas to 2, target_ongoing_requests to 8, num_cpus to 4, added num_gpus to 1, increased max_batch_size to 8, and reduced batch_wait_timeout_s to 0.1.
    • Updated HeatmapBuilder deployment: increased max_replicas to 4, num_cpus to 8, num_threads to 8, and max_concurrent_tasks to 24.
    • Added a new virchow2 service with a Virchow2 deployment, configuring runtime environment, resource limits (8 CPUs, 1 GPU, 8Gi memory), autoscaling, and user configuration for tile size, batching, and model provider.
    • Updated ray-head image to cerit.io/rationai/model-service:2.53.0 and increased memory limits/requests.
    • Added HTTPS_PROXY environment variable to ray-head and ray-worker containers.
    • Added securityContext and lifecycle hooks to ray-worker containers.
    • Added trt-cache-volume and huggingface-cache volume mounts to ray-worker containers.
    • Added gpu-workers group with specific GPU resource requests (nvidia.com/mig-2g.20gb: 1).
    • Added trt-cache-volume and huggingface-cache PVCs to the volumes section.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Virchow2 foundation model within the Ray Serve infrastructure, which is a significant addition. The changes include a new Ray Serve deployment for Virchow2, updates to Dockerfiles for GPU support with necessary dependencies, and Kubernetes configurations for model downloading and caching. The refactoring of existing models to leverage GPU and TensorRT is a great performance enhancement. However, the pull request introduces critical security vulnerabilities by including hardcoded secrets in configuration files. These must be addressed by using a secure secret management solution like Kubernetes Secrets. Additionally, there are minor areas for improvement regarding Docker image consistency and file permissions.

@Jurgee Jurgee self-assigned this Mar 13, 2026
@matejpekar matejpekar requested review from Adames4 and matejpekar and removed request for JakubPekar and ejdam87 March 17, 2026 10:08
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (1)
models/virchow2.py (1)

50-55: Clarify provider contract: currently its return value is discarded.

At Line 51, provider(**...) is called for side effects, but timm.create_model at Line 53 always uses repo_id. This makes _target_ abstraction misleading (compare with models/semantic_segmentation.py:55-69, where provider output is consumed directly).

Consider either:

  1. explicitly documenting provider as cache warm-up only, or
  2. consuming provider output as the model source of truth.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@models/virchow2.py` around lines 50 - 55, The call to
provider(**config["model"]) in the constructor is currently invoked only for
side effects and its return value is ignored while timm.create_model always uses
repo_id; either consume the provider return as the authoritative model source or
explicitly document it as a cache-warmup-only call. Update the code so that
provider(...)'s return (e.g., a model path, HF repo override, or config dict) is
checked and passed into timm.create_model (replace f"hf-hub:{repo_id}" with the
provider-provided identifier) and assign to self.model, or alternatively add a
clear comment and/or refactor the _target_ contract and call site to match
models/semantic_segmentation.py behavior where the provider output is consumed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@models/virchow2.py`:
- Line 66: The log message always says "moved to GPU" even when the model is
loaded to CPU; update the logger call that currently reads logger.info("Virchow2
model loaded and moved to GPU.") to report the actual device variable used
(e.g., include the `device` or `device_str` value computed when selecting
CPU/GPU) so the message becomes device-accurate; locate the model load/transfer
code (where `model.to(device)` or device selection is performed) and change the
info log to include that device identifier.
- Around line 46-48: The reconfigure function is mutating config["model"] by
calling .pop("_target_"), which will remove the key and break subsequent calls;
instead, read the target without mutating (e.g., target =
config["model"].get("_target_") or work on a shallow copy of config["model"])
and then split that target into module_path and attr_name, and fall back to a
clear error if the key is missing; update the code references in reconfigure
that currently use module_path, attr_name, provider to use the non-mutating
value so the original user_config remains unchanged.
- Around line 99-107: The code currently lets client-caused errors from
lz4.frame.decompress, np.frombuffer/reshape and np.dtype bubble up as 500s; wrap
the decompression/reshape and dtype parsing in a try/except that catches
RuntimeError, ValueError and TypeError (and optionally OverflowError) and
convert them to a client error (raise an HTTP 4xx, e.g.,
HTTPException(status_code=400) or 422) with a clear message; specifically
protect the lz4.frame.decompress call, the
np.frombuffer(...).reshape(self.tile_size, self.tile_size, 3) sequence, and the
np.dtype(requested_dtype) call (and validate
request.headers.get("x-output-dtype") before parsing) so malformed compressed
payloads, wrong buffer sizes, or invalid dtypes return 4xx instead of 500.

---

Nitpick comments:
In `@models/virchow2.py`:
- Around line 50-55: The call to provider(**config["model"]) in the constructor
is currently invoked only for side effects and its return value is ignored while
timm.create_model always uses repo_id; either consume the provider return as the
authoritative model source or explicitly document it as a cache-warmup-only
call. Update the code so that provider(...)'s return (e.g., a model path, HF
repo override, or config dict) is checked and passed into timm.create_model
(replace f"hf-hub:{repo_id}" with the provider-provided identifier) and assign
to self.model, or alternatively add a clear comment and/or refactor the _target_
contract and call site to match models/semantic_segmentation.py behavior where
the provider output is consumed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 85e59ece-bf28-466e-a429-d0cc10ffa6ae

📥 Commits

Reviewing files that changed from the base of the PR and between fb646c4 and 75ed923.

📒 Files selected for processing (2)
  • models/virchow2.py
  • providers/model_provider.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • providers/model_provider.py

@Jurgee Jurgee requested a review from matejpekar March 18, 2026 19:43
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (2)
ray-service.yaml (2)

257-257: ⚠️ Potential issue | 🟠 Major

Replace mutable latest-gpu image tag with a pinned version (and digest).

Line 257 can drift independently of rayVersion: 2.53.0, causing rollout non-determinism.

Suggested change
-                image: cerit.io/rationai/model-service:latest-gpu
+                image: cerit.io/rationai/model-service:2.53.0-gpu@sha256:<digest>
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ray-service.yaml` at line 257, The image reference "image:
cerit.io/rationai/model-service:latest-gpu" is using a mutable tag; replace it
with a specific, immutable image: use the release-aligned semantic version
(matching rayVersion: 2.53.0 if applicable) and append the image digest (sha256)
so the line becomes a pinned image (e.g.,
cerit.io/rationai/model-service:<version>@sha256:<digest>); update CI/build
manifests that produce the digest or fetch the digest from your registry and
ensure the new pinned string replaces the "latest-gpu" token to guarantee
deterministic rollouts.

92-95: ⚠️ Potential issue | 🟠 Major

Pin working_dir to an immutable commit archive.

Line 95 points to refs/heads/master.zip, which is mutable and breaks reproducibility/rollback guarantees on restart.

Suggested change
-          working_dir: https://github.com/RationAI/model-service/archive/refs/heads/master.zip
+          working_dir: https://github.com/RationAI/model-service/archive/<commit-sha>.zip
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ray-service.yaml` around lines 92 - 95, The working_dir currently points to a
mutable branch archive (https://.../refs/heads/master.zip) which breaks
reproducibility; change the runtime_env.working_dir to an immutable commit
archive by replacing the refs/heads/master.zip URL with the repository archive
URL for a specific commit SHA (e.g.
https://github.com/RationAI/model-service/archive/<commit-sha>.zip). Locate the
runtime_env block and update working_dir accordingly, ensuring you use a pinned
commit SHA (not a branch name) and update any deployment docs to record the
chosen SHA for rollbacks.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docker/Dockerfile.gpu`:
- Around line 56-64: The Dockerfile's GPU pip installs (the pip install lines
installing onnxruntime-gpu, tensorrt-cu12, torch, torchvision, timm and
huggingface-hub) must use explicit, validated version pins to avoid upstream
drift; update the two pip install invocations to pin each critical package
(e.g., torch==2.4.1+cu121 and torchvision==0.19.1+cu121 or your tested
equivalents, timm==0.9.11+, and fixed onnxruntime-gpu and tensorrt-cu12 versions
that you validated for Python 3.12 + CUDA 12.1) and document the chosen
combinations in the Dockerfile comment so future rebuilds use the same, tested
package matrix.

In `@ray-service.yaml`:
- Around line 270-274: The HF_TOKEN secret is being injected into all GPU worker
pods (env name HF_TOKEN from secret huggingface-secret) even though
providers/model_provider.py already uses local_files_only=True and does not need
the token; remove HF_TOKEN from the generic GPU worker container spec and
instead inject the secret only into the specific pod/container that performs
authenticated Hugging Face calls (the component that imports
providers/model_provider.py or any service that sets local_files_only=False).
Update ray-service.yaml to delete the HF_TOKEN env entry from the generic worker
template and add it to the targeted deployment/container spec (or create a
separate pod template/service account) so only that component receives
huggingface-secret.

---

Duplicate comments:
In `@ray-service.yaml`:
- Line 257: The image reference "image:
cerit.io/rationai/model-service:latest-gpu" is using a mutable tag; replace it
with a specific, immutable image: use the release-aligned semantic version
(matching rayVersion: 2.53.0 if applicable) and append the image digest (sha256)
so the line becomes a pinned image (e.g.,
cerit.io/rationai/model-service:<version>@sha256:<digest>); update CI/build
manifests that produce the digest or fetch the digest from your registry and
ensure the new pinned string replaces the "latest-gpu" token to guarantee
deterministic rollouts.
- Around line 92-95: The working_dir currently points to a mutable branch
archive (https://.../refs/heads/master.zip) which breaks reproducibility; change
the runtime_env.working_dir to an immutable commit archive by replacing the
refs/heads/master.zip URL with the repository archive URL for a specific commit
SHA (e.g. https://github.com/RationAI/model-service/archive/<commit-sha>.zip).
Locate the runtime_env block and update working_dir accordingly, ensuring you
use a pinned commit SHA (not a branch name) and update any deployment docs to
record the chosen SHA for rollbacks.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1e5753ec-93b8-4bb1-86fb-69a3fdeac1c2

📥 Commits

Reviewing files that changed from the base of the PR and between 75ed923 and bfcbe5b.

📒 Files selected for processing (2)
  • docker/Dockerfile.gpu
  • ray-service.yaml

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (2)
models/virchow2.py (2)

46-51: ⚠️ Potential issue | 🟡 Minor

Avoid mutating deployment config in reconfigure.

Line 46 uses .pop("_target_"), which mutates config["model"]. Reconfigure can be invoked multiple times, so this can make subsequent calls fail when _target_ is missing.

♻️ Proposed fix
-        module_path, attr_name = config["model"].pop("_target_").split(":")
+        model_cfg = dict(config["model"])
+        module_path, attr_name = model_cfg["_target_"].split(":")
         provider = getattr(importlib.import_module(module_path), attr_name)
-        repo_id = config["model"]["repo_id"]
+        repo_id = model_cfg["repo_id"]
@@
-        provider(**config["model"])
+        provider(**{k: v for k, v in model_cfg.items() if k != "_target_"})
#!/bin/bash
# Verify mutation pattern is present and compare with similar deployment code.
rg -nP 'config\["model"\]\.pop\("_target_"\)' models/virchow2.py models/binary_classifier.py -C2
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@models/virchow2.py` around lines 46 - 51, The code currently mutates
config["model"] by calling .pop("_target_"), which breaks repeated calls (e.g.,
reconfigure); instead, read the target string without removing it and pass a
non-mutated kwargs dict to the provider: extract target_str =
config["model"]["_target_"] then do module_path, attr_name =
target_str.split(":"), load provider =
getattr(importlib.import_module(module_path), attr_name), build model_kwargs =
{k:v for k,v in config["model"].items() if k != "_target_"} (or use
config["model"].copy() and del the key on the copy) and call
provider(**model_kwargs); keep using repo_id and logger.info as before.

97-104: ⚠️ Potential issue | 🟠 Major

Map malformed input and dtype parsing errors to 4xx responses.

Lines 97-104 can raise client-caused exceptions (lz4.frame.decompress, reshape, np.dtype) that currently bubble as 500s. This should return a 4xx with a clear message.

🛠️ Proposed fix
-        data = await asyncio.to_thread(lz4.frame.decompress, await request.body())
-        image = np.frombuffer(data, dtype=np.uint8).reshape(
-            self.tile_size, self.tile_size, 3
-        )
-
         requested_dtype = request.headers.get("x-output-dtype", "float32").lower()
-
-        output_dtype = np.dtype(requested_dtype)
+        if requested_dtype not in {"float16", "float32"}:
+            return Response("Unsupported x-output-dtype", status_code=400)
+
+        try:
+            data = await asyncio.to_thread(lz4.frame.decompress, await request.body())
+            image = np.frombuffer(data, dtype=np.uint8).reshape(
+                self.tile_size, self.tile_size, 3
+            )
+            output_dtype = np.dtype(requested_dtype)
+        except (RuntimeError, ValueError, TypeError, OverflowError):
+            return Response("Malformed request payload", status_code=400)
#!/bin/bash
# Verify risky parsing operations and whether 4xx handling is present around them.
rg -nP 'lz4\.frame\.decompress|np\.frombuffer|reshape\(|np\.dtype\(' models/virchow2.py -C2
rg -nP 'HTTPException|status_code\s*=\s*4[0-9]{2}' models/virchow2.py -C2
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@models/virchow2.py` around lines 97 - 104, The code that calls
lz4.frame.decompress, np.frombuffer(...).reshape(...), and np.dtype(...) can
raise client-caused errors and should be mapped to 4xx responses: wrap the risky
block (the async to_thread call that decompresses request.body(), the
np.frombuffer(...).reshape(self.tile_size, self.tile_size, 3) call, and the
np.dtype(requested_dtype) parsing) in a try/except that catches
lz4.frame.LZ4FrameError (or a broad lz4 error), ValueError and TypeError and
then raise fastapi.HTTPException(status_code=400, detail="...") with a short,
clear message (include the original exception text) instead of allowing a 500;
reference the handler method containing these calls and use
HTTPException/status_code=400 for the response.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@models/virchow2.py`:
- Around line 46-51: The code currently mutates config["model"] by calling
.pop("_target_"), which breaks repeated calls (e.g., reconfigure); instead, read
the target string without removing it and pass a non-mutated kwargs dict to the
provider: extract target_str = config["model"]["_target_"] then do module_path,
attr_name = target_str.split(":"), load provider =
getattr(importlib.import_module(module_path), attr_name), build model_kwargs =
{k:v for k,v in config["model"].items() if k != "_target_"} (or use
config["model"].copy() and del the key on the copy) and call
provider(**model_kwargs); keep using repo_id and logger.info as before.
- Around line 97-104: The code that calls lz4.frame.decompress,
np.frombuffer(...).reshape(...), and np.dtype(...) can raise client-caused
errors and should be mapped to 4xx responses: wrap the risky block (the async
to_thread call that decompresses request.body(), the
np.frombuffer(...).reshape(self.tile_size, self.tile_size, 3) call, and the
np.dtype(requested_dtype) parsing) in a try/except that catches
lz4.frame.LZ4FrameError (or a broad lz4 error), ValueError and TypeError and
then raise fastapi.HTTPException(status_code=400, detail="...") with a short,
clear message (include the original exception text) instead of allowing a 500;
reference the handler method containing these calls and use
HTTPException/status_code=400 for the response.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4e298101-689a-49a4-81ba-f4c51aaa0a3f

📥 Commits

Reviewing files that changed from the base of the PR and between bfcbe5b and 1247e72.

📒 Files selected for processing (1)
  • models/virchow2.py

@Jurgee Jurgee requested a review from matejpekar March 23, 2026 19:55
@Jurgee Jurgee requested a review from matejpekar March 28, 2026 16:03
Copy link
Copy Markdown
Member

@matejpekar matejpekar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is not inline with #5.


RUN pip install --no-cache-dir onnxruntime-gpu tensorrt lz4 ratiopath "mlflow<3.0"
RUN pip install --no-cache-dir \
onnxruntime-gpu tensorrt-cu12 lz4 ratiopath "mlflow<3.0" torch torchvision \
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you really need TensorRT for cuda 12? By default it uses cuda 13, same as torch. This way you are using two different cuda version

self.tile_size, self.tile_size, 3
)

requested_dtype = request.headers.get("x-output-dtype", "float32").lower()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output dtype should be applied in the predict method to avoid serialization of large arrays


@serve.batch
async def predict(
self, inputs: list[torch.Tensor | NDArray[np.float16] | NDArray[np.float32]]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accept only one type. Also check if ray can serialize tensors

model = cast("torch.nn.Module", self.model)

device_type = self.device.type
autocast_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you switching to bfp16 for CPUs? Is fp16 not supported by the cluster CPUs?

Comment on lines +90 to +92
class_token = output[:, 0]
patch_tokens = output[:, 5:]
embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be optional. Some users might want to access the individual patch tokens

Comment on lines +52 to +60
provider(**model_config)

self.model = timm.create_model(
f"hf-hub:{repo_id}",
pretrained=True,
num_classes=0,
mlp_layer=SwiGLUPacked,
act_layer=torch.nn.SiLU,
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be fragile. I guess the provider downloads the model to cache and then timm loads it from the cache based on the environmental variables set by the provider. Why do you even need to call the provider? I think timm can handle the downloading and storing to cache?

Co-authored-by: Matěj Pekár <matej.pekar120@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants