fix: Add the compatibility processing of MPS devices for Float8_e4m3fn data types#12406

Open
Kiruno-lz wants to merge 8 commits into Comfy-Org:master from Kiruno-lz:fix/mps-Float8_e4m3fn

Conversation

@Kiruno-lz Kiruno-lz commented Feb 11, 2026

On MPS devices, PyTorch lacks the capability to directly create or convert float8 data types. Consequently, attempting to move a float8 tensor to an MPS device will result in a runtime error.

  • When an MPS device is detected and the target type is Float8_e4m3fn, add a recursive fallback to the CPU in the stochastic_rounding function.

  • Modify the cast_to function to add special handling for MPS devices, ensuring that a float8 tensor is converted to float16 before being moved to the MPS device.

  • Tested two workflows: templates-6-key-frames and image_z_image_turbo.

This solution is more maintainable compared to the approach outlined in #12378.

@Kiruno-lz Kiruno-lz changed the title from "fix: 添加MPS设备对float8数据类型的兼容性处理" (fix: Add MPS device compatibility handling for the float8 data type) to "fix: Add the compatibility processing of MPS devices for Float8_e4m3fn data types" Feb 11, 2026
@kalias kalias commented Feb 25, 2026

On MPS devices, PyTorch lacks the capability to directly create or convert float8 data types. Consequently, attempting to move a float8 tensor to an MPS device will result in a runtime error.

  • When an MPS device is detected and the target type is Float8_e4m3fn, add a recursive fallback to the CPU in the stochastic_rounding function.
  • Modify the cast_to function to add special handling for MPS devices, ensuring that a float8 tensor is converted to float16 before being moved to the MPS device.
  • Tested two workflows: templates-6-key-frames and image_z_image_turbo.

This solution is more maintainable compared to the approach outlined in #12378.

Which version will include this fix (#12406)?

@altasol altasol commented Feb 25, 2026

+1

@comfy-pr-bot
Member

Test Evidence Check

⚠️ Warning: Visual Documentation Missing

If this PR changes user-facing behavior, visual proof (screen recording or screenshot) is required. PRs without applicable visual documentation may not be reviewed until provided.

You can add it by:

  • GitHub: Drag & drop media directly into the PR description
  • YouTube: Include a link to a short demo

@Kiruno-lz Kiruno-lz marked this pull request as draft March 6, 2026 05:50
@Kiruno-lz Kiruno-lz marked this pull request as ready for review March 6, 2026 05:50
@coderabbitai coderabbitai bot commented Mar 6, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 3a21b0b8-0e26-4307-86e4-cefc9ca8ed8a

📥 Commits

Reviewing files that changed from the base of the PR and between 891cd07 and d729086.

📒 Files selected for processing (2)
  • comfy/float.py
  • comfy/model_management.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • comfy/float.py
  • comfy/model_management.py

📝 Walkthrough

Adds MPS-specific handling for float8 types in two files. comfy/float.py: stochastic_rounding detects float8 dtypes on MPS devices, moves the tensor to CPU, and recursively calls stochastic_rounding on the CPU before continuing with CPU-based generation and casting. comfy/model_management.py: cast_to adds a branch for FLOAT8 weights targeting MPS that ensures the weight is moved to CPU, dequantizes if quantized, converts FLOAT8 to float, and then moves/casts the tensor to the MPS device, defaulting the dtype to float16 when appropriate. Function signatures are unchanged.

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): the title accurately describes the main change: adding MPS device compatibility for Float8_e4m3fn data types by implementing workarounds in the stochastic_rounding and cast_to functions.
  • Description check (✅ Passed): the description is directly related to the changeset, explaining the rationale for the MPS compatibility fixes and detailing the modifications made to both stochastic_rounding and cast_to.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

📝 Coding Plan
  • Generate coding plan for human review comments



@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
comfy/float.py (1)

58-62: Returned tensor changes device from MPS to CPU — consider documenting this contract.

When the input is on MPS and the target dtype is float8, the returned tensor will be on CPU rather than the original MPS device. This is a necessary workaround since MPS doesn't support float8, but it changes the function's implicit contract.

Callers that subsequently need the tensor on MPS will need to handle this (via cast_to which converts to float16). Consider adding a brief docstring or inline note clarifying this behavior for future maintainers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/float.py` around lines 58 - 62, Document the device-change behavior of
stochastic_rounding: when stochastic_rounding(value, dtype, seed) sees
value.device.type == "mps" and forces a CPU conversion (cpu_value =
value.to("cpu")) it returns a CPU tensor (not MPS) because MPS doesn't support
float8; update the stochastic_rounding function docstring or add an inline
comment near the MPS workaround to state this explicit contract, note that
callers who need the tensor back on MPS should re-cast (e.g., via cast_to which
converts to float16) and include the dtype=float8 limitation and recommended
workaround in the comment.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@comfy/model_management.py`:
- Line 1250: In cast_to, when handling an explicit dtype that's in FLOAT8_TYPES
and the target device is MPS, update the dtype variable after converting/moving
the weight to float16 so subsequent calls don't attempt weight.to(dtype=float8)
on MPS; specifically, in the cast_to function adjust the branch that converts
weights to float16 for MPS (the lines around dtype = torch.float16 if dtype is
None else dtype and the FLOAT8 handling) to set dtype = torch.float16 whenever
dtype is a FLOAT8_TYPES value and the device is MPS.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 08aaaed2-2dd1-4a7d-9f5a-6a2a70e1fb14

📥 Commits

Reviewing files that changed from the base of the PR and between dc9822b and f47fb54.

📒 Files selected for processing (2)
  • comfy/float.py
  • comfy/model_management.py

@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (1)
comfy/model_management.py (1)

1209-1224: ⚠️ Potential issue | 🔴 Critical

Bug: Explicit FLOAT8 dtype parameter will still cause RuntimeError on MPS.

When cast_to is called with an explicit dtype in FLOAT8_TYPES targeting MPS, the weight is correctly converted to float16 and moved to MPS at line 1223. However, line 1224 only updates dtype when it's None, leaving an explicit FLOAT8 dtype unchanged. This causes the subsequent code at line 1235 to attempt weight.to(dtype=float8) on MPS, which will fail.

Proposed fix:

```diff
-            dtype = torch.float16 if dtype is None else dtype
+            if dtype is None or dtype in FLOAT8_TYPES:
+                dtype = torch.float16
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/model_management.py` around lines 1209 - 1224, The code fails when an
explicit FLOAT8 dtype is passed while targeting MPS because dtype is only
changed when None; update the branch that handles MPS/FLOAT8 (using symbols
target_device, is_quantized, real_dtype, FLOAT8_TYPES, weight) to force dtype to
torch.float16 after converting/dequantizing and moving weight to MPS;
specifically, if real_dtype or the passed dtype is in FLOAT8_TYPES, set dtype =
torch.float16 (overriding an explicit FLOAT8) so later weight.to(dtype=...) does
not attempt to move a FLOAT8 dtype to MPS.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 1fd4c955-38b6-41b6-a299-3d10e4ea5e0f

📥 Commits

Reviewing files that changed from the base of the PR and between f47fb54 and 5f01662.

📒 Files selected for processing (1)
  • comfy/model_management.py

…o` function.

When the dtype is `None` or `float8`, it should be uniformly converted to `float16` to prevent type mismatches during the conversion of `float8` types.
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@comfy/model_management.py`:
- Line 1213: stochastic_rounding in comfy/float.py currently only handles
float8_e4m3fn and float8_e5m2 but may be called with any member of FLOAT8_TYPES
(e.g., float8_e4m3fnuz, float8_e5m2fnuz, float8_e8m0fnu); update
stochastic_rounding to explicitly handle all FLOAT8_TYPES (implement appropriate
rounding logic for float8_e4m3fnuz, float8_e5m2fnuz, float8_e8m0fnu) or, if
those types are not supported, have stochastic_rounding raise a clear exception
or log an error instead of returning None; reference the stochastic_rounding
function and the FLOAT8_TYPES constant and ensure callers in comfy/ops.py and
comfy/model_patcher.py receive a definite result or catch the explicit error.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c12f640f-cec8-4f6f-87cc-03af9aa643a4

📥 Commits

Reviewing files that changed from the base of the PR and between 5f01662 and b7cbe8a.

📒 Files selected for processing (1)
  • comfy/model_management.py

```python
if target_device is not None and is_device_mps(target_device):
    is_quantized = hasattr(weight, "storage_dtype")
    real_dtype = weight.storage_dtype if is_quantized else weight.dtype
    if real_dtype in FLOAT8_TYPES:
```

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

```shell
#!/bin/bash

echo "=== FLOAT8_TYPES definition in model_management.py ==="
rg -n "FLOAT8_TYPES\s*=" comfy/model_management.py -A 1

echo ""
echo "=== Full context of FLOAT8_TYPES and its usage ==="
rg -n "FLOAT8_TYPES" comfy/model_management.py

echo ""
echo "=== stochastic_rounding function in float.py ==="
rg -n "def stochastic_rounding" comfy/float.py -A 20

echo ""
echo "=== float8 type checks in float.py ==="
rg -n "float8" comfy/float.py -B 2 -A 2
```

Repository: Comfy-Org/ComfyUI

Length of output: 3245


🏁 Script executed:

```shell
#!/bin/bash

echo "=== get_supported_float8_types definition ==="
rg -n "def get_supported_float8_types" comfy/model_management.py -A 10

echo ""
echo "=== Check if there are other float8 type references ==="
rg -n "float8_e" comfy/model_management.py

echo ""
echo "=== Verify what torch.float8 variants exist and are being used ==="
rg -n "torch\.float8" comfy/ -t py --no-heading | sort -u | head -20
```

Repository: Comfy-Org/ComfyUI

Length of output: 2623


🏁 Script executed:

```shell
#!/bin/bash

echo "=== Context around line 1213 in model_management.py ==="
sed -n '1210,1230p' comfy/model_management.py

echo ""
echo "=== Check function containing line 1213 ==="
rg -n -B 20 "if real_dtype in FLOAT8_TYPES" comfy/model_management.py | head -40

echo ""
echo "=== Check where stochastic_rounding is called ==="
rg -n "stochastic_rounding" comfy/
```

Repository: Comfy-Org/ComfyUI

Length of output: 3702


Fix stochastic_rounding to handle all supported FLOAT8 types, or clarify which types it supports.

The stochastic_rounding function in comfy/float.py only explicitly handles float8_e4m3fn and float8_e5m2, but it can be called from comfy/ops.py (line 191) and comfy/model_patcher.py with weight.dtype or orig.dtype, which could be any of the 5 types in FLOAT8_TYPES (including e4m3fnuz, e5m2fnuz, e8m0fnu). Unsupported types currently return None silently.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/model_management.py` at line 1213, stochastic_rounding in
comfy/float.py currently only handles float8_e4m3fn and float8_e5m2 but may be
called with any member of FLOAT8_TYPES (e.g., float8_e4m3fnuz, float8_e5m2fnuz,
float8_e8m0fnu); update stochastic_rounding to explicitly handle all
FLOAT8_TYPES (implement appropriate rounding logic for float8_e4m3fnuz,
float8_e5m2fnuz, float8_e8m0fnu) or, if those types are not supported, have
stochastic_rounding raise a clear exception or log an error instead of returning
None; reference the stochastic_rounding function and the FLOAT8_TYPES constant
and ensure callers in comfy/ops.py and comfy/model_patcher.py receive a definite
result or catch the explicit error.
