fix: Add compatibility handling for Float8_e4m3fn data types on MPS devices #12406
Kiruno-lz wants to merge 8 commits into Comfy-Org:master from
Conversation
On MPS devices, PyTorch does not support directly creating or converting float8 data types, which causes a runtime error when attempting to move a float8 tensor to an MPS device.

- Add recursive fallback-to-CPU logic in the `stochastic_rounding` function when an MPS device is detected and the target type is float8.
- Add special handling for MPS devices in the `cast_to` function, ensuring float8 tensors are converted to float16 before being moved to MPS.
- Tested two workflows: templates-6-key-frames and image_z_image_turbo.

This is more maintainable than Comfy-Org#12378.
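The fallback described above can be sketched as a small decision helper. This is a hypothetical illustration, not the PR's actual code: device and dtype are plain strings so the logic can be checked without torch.

```python
# Hypothetical sketch of the MPS float8 fallback described in this PR.
# Only the two float8 variants the PR's rounding path handles are listed.
FLOAT8_NAMES = {"float8_e4m3fn", "float8_e5m2"}

def rounding_device(device_type: str, target_dtype: str) -> str:
    """Pick the device stochastic rounding should run on."""
    if device_type == "mps" and target_dtype in FLOAT8_NAMES:
        # MPS cannot create or convert float8 tensors; fall back to CPU.
        return "cpu"
    return device_type
```

In the real function the same check would gate a `value.to("cpu")` before the rounding kernel runs.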
Which version will include this fix (#12406)?

+1
Test Evidence Check
If this PR changes user-facing behavior, visual proof (a screen recording or screenshot) is required. PRs without applicable visual documentation may not be reviewed until it is provided. You can add it by:
Note: Reviews paused
It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID:
📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (2)

📝 Walkthrough
Adds MPS-specific handling for float8 types in two files: comfy/float.py and comfy/model_management.py.

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 1
🧹 Nitpick comments (1)
comfy/float.py (1)
58-62: Returned tensor changes device from MPS to CPU; consider documenting this contract.

When the input is on MPS and the target dtype is float8, the returned tensor will be on CPU rather than the original MPS device. This is a necessary workaround since MPS doesn't support float8, but it changes the function's implicit contract. Callers that subsequently need the tensor on MPS will need to handle this (via `cast_to`, which converts to float16). Consider adding a brief docstring or inline note clarifying this behavior for future maintainers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@comfy/float.py` around lines 58 - 62, Document the device-change behavior of stochastic_rounding: when stochastic_rounding(value, dtype, seed) sees value.device.type == "mps" and forces a CPU conversion (cpu_value = value.to("cpu")) it returns a CPU tensor (not MPS) because MPS doesn't support float8; update the stochastic_rounding function docstring or add an inline comment near the MPS workaround to state this explicit contract, note that callers who need the tensor back on MPS should re-cast (e.g., via cast_to which converts to float16) and include the dtype=float8 limitation and recommended workaround in the comment.
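One possible shape for the suggested note, as a sketch only; the function name and signature follow the review comment, and the rounding body is elided:

```python
def stochastic_rounding(value, dtype, seed=0):
    """Stochastically round `value` to `dtype`.

    Note: if `value` is on MPS and `dtype` is a float8 type, rounding is
    performed on CPU and the returned tensor stays on CPU, since MPS has no
    float8 support. Callers that need the result back on MPS should re-cast
    it (e.g. via `cast_to`, which widens to float16).
    """
    ...  # rounding body elided in this sketch
```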
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@comfy/model_management.py`:
- Line 1250: In cast_to, when handling an explicit dtype that's in FLOAT8_TYPES
and the target device is MPS, update the dtype variable after converting/moving
the weight to float16 so subsequent calls don't attempt weight.to(dtype=float8)
on MPS; specifically, in the cast_to function adjust the branch that converts
weights to float16 for MPS (the lines around dtype = torch.float16 if dtype is
None else dtype and the FLOAT8 handling) to set dtype = torch.float16 whenever
dtype is a FLOAT8_TYPES value and the device is MPS.
♻️ Duplicate comments (1)
comfy/model_management.py (1)
1209-1224: ⚠️ Potential issue | 🔴 Critical

Bug: Explicit FLOAT8 `dtype` parameter will still cause RuntimeError on MPS.

When `cast_to` is called with an explicit `dtype` in `FLOAT8_TYPES` targeting MPS, the weight is correctly converted to float16 and moved to MPS at line 1223. However, line 1224 only updates `dtype` when it's `None`, leaving an explicit FLOAT8 dtype unchanged. This causes the subsequent code at line 1235 to attempt `weight.to(dtype=float8)` on MPS, which will fail.

Proposed fix:

- dtype = torch.float16 if dtype is None else dtype
+ if dtype is None or dtype in FLOAT8_TYPES:
+     dtype = torch.float16

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@comfy/model_management.py` around lines 1209 - 1224, The code fails when an explicit FLOAT8 dtype is passed while targeting MPS because dtype is only changed when None; update the branch that handles MPS/FLOAT8 (using symbols target_device, is_quantized, real_dtype, FLOAT8_TYPES, weight) to force dtype to torch.float16 after converting/dequantizing and moving weight to MPS; specifically, if real_dtype or the passed dtype is in FLOAT8_TYPES, set dtype = torch.float16 (overriding an explicit FLOAT8) so later weight.to(dtype=...) does not attempt to move a FLOAT8 dtype to MPS.
…o` function. When the dtype is `None` or `float8`, it should be uniformly converted to `float16` to prevent type mismatches during the conversion of `float8` types.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@comfy/model_management.py`:
- Line 1213: stochastic_rounding in comfy/float.py currently only handles
float8_e4m3fn and float8_e5m2 but may be called with any member of FLOAT8_TYPES
(e.g., float8_e4m3fnuz, float8_e5m2fnuz, float8_e8m0fnu); update
stochastic_rounding to explicitly handle all FLOAT8_TYPES (implement appropriate
rounding logic for float8_e4m3fnuz, float8_e5m2fnuz, float8_e8m0fnu) or, if
those types are not supported, have stochastic_rounding raise a clear exception
or log an error instead of returning None; reference the stochastic_rounding
function and the FLOAT8_TYPES constant and ensure callers in comfy/ops.py and
comfy/model_patcher.py receive a definite result or catch the explicit error.
if target_device is not None and is_device_mps(target_device):
    is_quantized = hasattr(weight, "storage_dtype")
    real_dtype = weight.storage_dtype if is_quantized else weight.dtype
    if real_dtype in FLOAT8_TYPES:
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
echo "=== FLOAT8_TYPES definition in model_management.py ==="
rg -n "FLOAT8_TYPES\s*=" comfy/model_management.py -A 1
echo ""
echo "=== Full context of FLOAT8_TYPES and its usage ==="
rg -n "FLOAT8_TYPES" comfy/model_management.py
echo ""
echo "=== stochastic_rounding function in float.py ==="
rg -n "def stochastic_rounding" comfy/float.py -A 20
echo ""
echo "=== float8 type checks in float.py ==="
rg -n "float8" comfy/float.py -B 2 -A 2

Repository: Comfy-Org/ComfyUI
Length of output: 3245
🏁 Script executed:
#!/bin/bash
echo "=== get_supported_float8_types definition ==="
rg -n "def get_supported_float8_types" comfy/model_management.py -A 10
echo ""
echo "=== Check if there are other float8 type references ==="
rg -n "float8_e" comfy/model_management.py
echo ""
echo "=== Verify what torch.float8 variants exist and are being used ==="
rg -n "torch\.float8" comfy/ -t py --no-heading | sort -u | head -20

Repository: Comfy-Org/ComfyUI
Length of output: 2623
🏁 Script executed:
#!/bin/bash
echo "=== Context around line 1213 in model_management.py ==="
sed -n '1210,1230p' comfy/model_management.py
echo ""
echo "=== Check function containing line 1213 ==="
rg -n -B 20 "if real_dtype in FLOAT8_TYPES" comfy/model_management.py | head -40
echo ""
echo "=== Check where stochastic_rounding is called ==="
rg -n "stochastic_rounding" comfy/

Repository: Comfy-Org/ComfyUI
Length of output: 3702
Fix stochastic_rounding to handle all supported FLOAT8 types, or clarify which types it supports.
The stochastic_rounding function in comfy/float.py only explicitly handles float8_e4m3fn and float8_e5m2, but it can be called from comfy/ops.py (line 191) and comfy/model_patcher.py with weight.dtype or orig.dtype, which could be any of the 5 types in FLOAT8_TYPES (including e4m3fnuz, e5m2fnuz, e8m0fnu). Unsupported types currently return None silently.
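If the extra float8 variants are out of scope, the review's fallback suggestion (fail loudly instead of silently returning None) could look like this sketch; helper name and string dtype names are hypothetical stand-ins:

```python
# Variants the current rounding path actually implements.
SUPPORTED_FLOAT8 = {"float8_e4m3fn", "float8_e5m2"}
# All five members of FLOAT8_TYPES named in the review.
ALL_FLOAT8 = SUPPORTED_FLOAT8 | {"float8_e4m3fnuz", "float8_e5m2fnuz",
                                 "float8_e8m0fnu"}

def check_rounding_support(dtype_name: str) -> None:
    """Raise a clear error for float8 types the rounding path cannot handle."""
    if dtype_name in ALL_FLOAT8 and dtype_name not in SUPPORTED_FLOAT8:
        raise NotImplementedError(
            f"stochastic_rounding does not support {dtype_name}")
```

With a guard like this, callers in comfy/ops.py and comfy/model_patcher.py would get a definite exception to catch rather than an unexplained None.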
On MPS devices, PyTorch lacks the capability to directly create or convert float8 data types. Consequently, attempting to move a float8 tensor to an MPS device results in a runtime error.

- In the `stochastic_rounding` function, add recursive fallback-to-CPU logic when an MPS device is detected and the target type is Float8_e4m3fn.
- Modify the `cast_to` function to include special handling for MPS devices, ensuring float8 tensors are converted to float16 before being moved to the MPS device.
- Tested two workflows: templates-6-key-frames and image_z_image_turbo.

This solution is more maintainable than the approach outlined in #12378.