More robust Online DPO changes for RL update #1664

Open · wants to merge 16 commits into base: main
Changes from 1 commit
Update llama.py: kept reset functions and made their names generic
pluesclues authored Feb 8, 2025
commit d219f2563db45a5f6ab6d0f6d57e7e3bac47d86c
26 changes: 14 additions & 12 deletions unsloth/models/llama.py
@@ -1583,12 +1583,13 @@ def _fast_generate(*args, **kwargs):
return _fast_generate
pass

-original_llama_attention_forward = LlamaAttention.forward
-original_llama_sdpa_attention_forward = LlamaSdpaAttention.forward
-original_llama_flash_attention2_forward = LlamaFlashAttention2.forward
-original_llama_decoder_layer_forward = LlamaDecoderLayer.forward
-original_llama_model_forward = LlamaModel.forward
-original_llama_for_causal_lm_forward = LlamaForCausalLM.forward

+original_attention_forward = LlamaAttention.forward
Contributor
Actually that's very smart on keeping the old function!

+original_sdpa_attention_forward = LlamaSdpaAttention.forward
+original_flash_attention2_forward = LlamaFlashAttention2.forward
+original_decoder_layer_forward = LlamaDecoderLayer.forward
+original_model_forward = LlamaModel.forward
+original_for_causal_lm_forward = LlamaForCausalLM.forward
+original_peft_model_for_causal_lm_forward = PeftModelForCausalLM.forward
import transformers.models.llama.modeling_llama
original_LLamaRotaryEmbedding = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
@@ -1607,15 +1608,16 @@ def set_functions():
transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding

def reset_functions():
-LlamaAttention      .forward = original_llama_attention_forward
-LlamaSdpaAttention  .forward = original_llama_sdpa_attention_forward
-LlamaFlashAttention2.forward = original_llama_flash_attention2_forward
-LlamaDecoderLayer   .forward = original_llama_decoder_layer_forward
-LlamaModel          .forward = original_llama_model_forward
-LlamaForCausalLM    .forward = original_llama_for_causal_lm_forward
+LlamaAttention      .forward = original_attention_forward
+LlamaSdpaAttention  .forward = original_sdpa_attention_forward
+LlamaFlashAttention2.forward = original_flash_attention2_forward
+LlamaDecoderLayer   .forward = original_decoder_layer_forward
+LlamaModel          .forward = original_model_forward
+LlamaForCausalLM    .forward = original_for_causal_lm_forward
+PeftModelForCausalLM.forward = original_peft_model_for_causal_lm_forward
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = original_LLamaRotaryEmbedding
transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = original_LLamaLinearScalingRotaryEmbedding

@staticmethod
def pre_patch():
init_name, function = patch_llama_rope_scaling(
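
The diff above keeps a module-level reference to every original forward before anything is patched, which is what lets reset_functions() hand control back to the stock Hugging Face (and now PEFT) implementations later. Below is a minimal sketch of that save-and-restore pattern, assuming only standard Python; DummyModel, fast_forward, set_functions, and reset_functions here are hypothetical stand-ins, not the actual unsloth code.

    # Minimal save-and-restore patching sketch (hypothetical names, not unsloth code).
    class DummyModel:
        def forward(self, x):
            return x + 1

    # Keep a reference to the original function *before* any patching.
    original_forward = DummyModel.forward

    def fast_forward(self, x):
        # Pretend this is the optimized replacement.
        return x + 1

    def set_functions():
        # Monkey-patch the class attribute with the replacement.
        DummyModel.forward = fast_forward

    def reset_functions():
        # Restoring the saved reference fully undoes the patch.
        DummyModel.forward = original_forward

    set_functions()
    assert DummyModel.forward is fast_forward
    reset_functions()
    assert DummyModel.forward is original_forward

Because the patch is applied by plain attribute assignment, restoring the saved reference is enough to undo it, with no re-import of transformers required.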