More robust Online DPO changes for RL update #1664

Open · wants to merge 16 commits into base: main
Changes from 1 commit
Update llama.py making set and reset functions in order to properly use autoSequenceClassification
pluesclues authored Dec 12, 2024
commit 744562847ab2925ac5f131014e9a61b89d57d2d1
28 changes: 27 additions & 1 deletion unsloth/models/llama.py
@@ -1504,9 +1504,35 @@ def _fast_generate(*args, **kwargs):
    return _fast_generate
pass


# Keep handles to the stock Hugging Face / PEFT forwards and rotary embeddings so they can be restored later.
original_llama_attention_forward          = LlamaAttention.forward
original_llama_sdpa_attention_forward     = LlamaSdpaAttention.forward
original_llama_flash_attention2_forward   = LlamaFlashAttention2.forward
original_llama_decoder_layer_forward      = LlamaDecoderLayer.forward
original_llama_model_forward              = LlamaModel.forward
original_llama_for_causal_lm_forward      = LlamaForCausalLM.forward
original_peft_model_for_causal_lm_forward = PeftModelForCausalLM.forward
original_llama_rotary_embedding                = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
original_llama_linear_scaling_rotary_embedding = transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding


class FastLlamaModel:

    @staticmethod
    def set_functions():
        # Monkey-patch Unsloth's fast forwards onto the transformers / PEFT classes.
        LlamaAttention      .forward = LlamaAttention_fast_forward
        LlamaSdpaAttention  .forward = LlamaAttention_fast_forward
        LlamaFlashAttention2.forward = LlamaAttention_fast_forward
        LlamaDecoderLayer   .forward = LlamaDecoderLayer_fast_forward
        LlamaModel          .forward = LlamaModel_fast_forward
        LlamaForCausalLM    .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
        PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
        transformers.models.llama.modeling_llama.LlamaRotaryEmbedding              = LlamaRotaryEmbedding
        transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding

    @staticmethod
    def reset_functions():
        # Restore the original forwards, e.g. so AutoModelForSequenceClassification
        # builds its model from the unpatched classes.
        LlamaAttention      .forward = original_llama_attention_forward
        LlamaSdpaAttention  .forward = original_llama_sdpa_attention_forward
        LlamaFlashAttention2.forward = original_llama_flash_attention2_forward
        LlamaDecoderLayer   .forward = original_llama_decoder_layer_forward
        LlamaModel          .forward = original_llama_model_forward
        LlamaForCausalLM    .forward = original_llama_for_causal_lm_forward
        PeftModelForCausalLM.forward = original_peft_model_for_causal_lm_forward
        transformers.models.llama.modeling_llama.LlamaRotaryEmbedding              = original_llama_rotary_embedding
        transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = original_llama_linear_scaling_rotary_embedding
    @staticmethod
    def pre_patch():
        init_name, function = patch_llama_rope_scaling(
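
A minimal usage sketch of the new toggles, based on the commit message's mention of "autoSequenceClassification": restore the stock forwards, load a classification head via `AutoModelForSequenceClassification`, then re-apply the fast forwards. The import path, checkpoint name, and ordering below are illustrative assumptions, not code from this PR.

```python
from transformers import AutoModelForSequenceClassification
from unsloth.models.llama import FastLlamaModel  # assumed import path

# Undo Unsloth's monkey-patches so the classification model is built
# from the unmodified Hugging Face classes.
FastLlamaModel.reset_functions()

reward_model = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # hypothetical checkpoint
    num_labels = 1,
)

# Re-enable the fast forwards for the RL / Online DPO training loop.
FastLlamaModel.set_functions()
```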