
Fix Unsloth autocast dtype for bf16 models #663

Merged
vivekkalyan merged 2 commits into main from fix/infer-unsloth-autocast-dtype on April 27, 2026

Conversation

@vivekkalyan (Collaborator)

Summary

PipelineTrainer + LocalBackend can load bf16 base models like Llama 3.1 in 16-bit mode while ACCELERATE_MIXED_PRECISION is unset. We were defaulting the Unsloth logprob path to fp16 in that case, which can hit Half vs BFloat16 matmul mismatches.
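
As a minimal illustration of that failure class (a standalone repro of the dtype mismatch, not the actual Unsloth call path):

```python
import torch

a = torch.randn(2, 2, dtype=torch.float16)   # what the old fp16 default produced
b = torch.randn(2, 2, dtype=torch.bfloat16)  # what a bf16 checkpoint like Llama 3.1 loads as
torch.matmul(a, b)  # RuntimeError: mat1 and mat2 must have the same dtype
```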

This change keeps explicit mixed-precision settings authoritative; when the env var is unset, it infers the autocast dtype from the loaded model parameters. Unknown mixed-precision or model-dtype states fail early instead of silently falling through.
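
A minimal sketch of that resolution order (hypothetical helper name; the actual function and names in the diff may differ):

```python
import os

import torch


def infer_autocast_dtype(model: torch.nn.Module) -> torch.dtype:
    """Resolve the autocast dtype for the Unsloth logprob path (sketch)."""
    # Explicit mixed-precision settings stay authoritative.
    mixed = os.environ.get("ACCELERATE_MIXED_PRECISION")
    if mixed == "bf16":
        return torch.bfloat16
    if mixed == "fp16":
        return torch.float16
    if mixed is not None:
        # Fail early on mixed-precision states this sketch does not model.
        raise ValueError(f"Unhandled ACCELERATE_MIXED_PRECISION value: {mixed!r}")

    # Env var unset: infer from the loaded model's parameter dtypes.
    param_dtypes = {p.dtype for p in model.parameters() if p.is_floating_point()}
    if param_dtypes == {torch.bfloat16}:
        return torch.bfloat16
    if param_dtypes == {torch.float16}:
        return torch.float16
    # Unknown or mixed parameter dtypes fail early instead of silently
    # falling through to fp16.
    raise ValueError(f"Cannot infer autocast dtype from parameters: {param_dtypes}")
```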

Validation

  • .venv/bin/pytest tests/unit/test_unsloth_autocast_dtype.py -q
  • Sky 2x H200 smoke test: Llama 3.1 + PipelineTrainer + dedicated LocalBackend, with load_in_4bit=False, load_in_16bit=True, and ACCELERATE_MIXED_PRECISION unset; reached step 2 and reloaded adapters for steps 1 and 2.

vivekkalyan requested a review from arcticfly on April 27, 2026 at 21:27
@arcticfly (Collaborator) left a comment


Thanks @vivekkalyan!

vivekkalyan merged commit 5cfe180 into main on April 27, 2026
4 checks passed
vivekkalyan deleted the fix/infer-unsloth-autocast-dtype branch on April 27, 2026 at 22:11