More robust Online DPO changes for RL update #1664

Open
wants to merge 16 commits into main
Changes from 1 commit
Update llama.py: rotary embeddings were not included correctly in the reset functions
pluesclues authored Jan 7, 2025
commit 4705906536f8aa1a10143a3cfa814ddd50f05bdc
4 changes: 4 additions & 0 deletions unsloth/models/llama.py
@@ -1511,6 +1511,10 @@ def _fast_generate(*args, **kwargs):
original_llama_model_forward = LlamaModel.forward
original_llama_for_causal_lm_forward = LlamaForCausalLM.forward
original_peft_model_for_causal_lm_forward = PeftModelForCausalLM.forward
+import transformers.models.llama.modeling_llama
+original_LLamaRotaryEmbedding = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
+original_LLamaLinearScalingRotaryEmbedding = transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding
+
class FastLlamaModel:
def set_functions():
LlamaAttention .forward = LlamaAttention_fast_forward
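
For context, the added lines follow the file's existing pattern: keep a reference to the original transformers objects (here the Llama rotary-embedding classes) so that a later reset/unpatch step can restore them after Unsloth's fast forwards have been installed. Below is a minimal sketch of that save-and-restore pattern; it is an illustration under stated assumptions, not the PR's actual reset code, and `reset_llama_patches` is a hypothetical name.

```python
# Minimal sketch of the save/patch/restore pattern the added lines extend.
# Assumes a transformers version that still defines LlamaLinearScalingRotaryEmbedding
# (as the diff does); reset_llama_patches is a hypothetical name for illustration.
import transformers.models.llama.modeling_llama as modeling_llama

# Save references to the original classes before any monkey-patching.
original_LLamaRotaryEmbedding = modeling_llama.LlamaRotaryEmbedding
original_LLamaLinearScalingRotaryEmbedding = (
    modeling_llama.LlamaLinearScalingRotaryEmbedding
)

def reset_llama_patches():
    # Point the module attributes back at the saved originals so a later
    # re-patch (e.g. for an RL / Online DPO run) starts from a clean state.
    modeling_llama.LlamaRotaryEmbedding = original_LLamaRotaryEmbedding
    modeling_llama.LlamaLinearScalingRotaryEmbedding = (
        original_LLamaLinearScalingRotaryEmbedding
    )
```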