Optimize rl #2065

Open · wants to merge 6 commits into base: main
Changes from 5 commits
19 changes: 8 additions & 11 deletions unsloth/models/llama.py
@@ -2735,18 +2735,15 @@ def _for_training(m):
         m = m.model
     _for_training(m)
 
     # Also re-enable training for embeddings for NEFTune
-    if hasattr(model, "get_input_embeddings"):
-        embeddings = model.get_input_embeddings()
-        if hasattr(embeddings, "training"): embeddings.training = True
-    pass
-    if hasattr(model, "get_output_embeddings"):
-        embeddings = model.get_output_embeddings()
-        if hasattr(embeddings, "training"): embeddings.training = True
-    pass
+    for get_embeddings_fn in (
+        model.get_input_embeddings,
+        model.get_output_embeddings):

[Inline review thread on the lines above]

@KareemMusleh (Contributor), Mar 17, 2025:
Wouldn't this result in an error if get_input_embeddings or get_output_embeddings aren't defined for model?

Author (Contributor):
Yes, I think you might be right.

+        if hasattr(model, get_embeddings_fn.__name__):
+            embeddings = get_embeddings_fn()
+            if hasattr(embeddings, "training"):
+                embeddings.training = True
     return model
 pass
 pass

 from .rl import PatchFastRL
 PatchFastRL(FastLanguageModel = FastLlamaModel)
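
A note on the reviewer's point: in the new loop, model.get_input_embeddings and model.get_output_embeddings are evaluated when the tuple is built, so a model lacking either method raises AttributeError before the hasattr guard ever runs. Below is a minimal sketch of a variant that resolves each getter by name first, so the guard actually protects the access (illustrative only, not part of this PR; enable_embedding_training is a hypothetical helper name):

def enable_embedding_training(model):
    # Look each getter up by name; getattr returns None instead of raising
    # AttributeError when the method is not defined on the model.
    for name in ("get_input_embeddings", "get_output_embeddings"):
        get_embeddings_fn = getattr(model, name, None)
        if get_embeddings_fn is None:
            continue
        embeddings = get_embeddings_fn()
        if hasattr(embeddings, "training"):
            embeddings.training = True
    return model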
12 changes: 8 additions & 4 deletions unsloth/models/rl.py
@@ -576,10 +576,14 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
     old_init = init
 
     # Remove peft_config
-    init = init.replace("elif peft_config is None:", "elif False:")
-    init = init.replace("elif peft_config is not None:", "elif False:")
-    init = init.replace("if peft_config is None:", "if False:")
-    init = init.replace("if peft_config is not None:", "if False:")
+    replacements = {
+        "elif peft_config is None:": "elif False:",
+        "elif peft_config is not None:": "elif False:",
+        "if peft_config is None:": "if False:",
+        "if peft_config is not None:": "if False:",
+    }
+    for old, new in replacements.items():
+        init = init.replace(old, new)
     init = init.replace("get_peft_model(model, peft_config)", "model")
 
     # Set use_vllm if not set
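
For context on what this table-driven rewrite does: each peft_config guard in the captured trainer __init__ source becomes an unreachable if False:/elif False: branch before the patched source is used to rebuild the trainer, so the trl PEFT path is skipped. A toy sketch of the mechanism (the init string below is invented for illustration, not trl's real source):

# Toy source standing in for a captured trainer __init__ body.
init = '''
if peft_config is not None:
    raise RuntimeError("peft path taken")
print("peft path skipped")
'''

replacements = {
    "elif peft_config is None:": "elif False:",
    "elif peft_config is not None:": "elif False:",
    "if peft_config is None:": "if False:",
    "if peft_config is not None:": "if False:",
}
for old, new in replacements.items():
    init = init.replace(old, new)

# The guard is now `if False:`, so exec() never enters the peft branch
# and only prints "peft path skipped".
exec(init)

Note that replacement order does not matter here: even if "if peft_config is None:" were replaced inside "elif peft_config is None:" first, the leading "el" would combine with "if False:" to give the same "elif False:".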
1 change: 0 additions & 1 deletion unsloth/models/vision.py
@@ -50,7 +50,6 @@
 import os
 import gc
 import math
-import functools
 from typing import Optional, Tuple, List, Union
 import re, inspect, sys
 import types