From 9785fe4052b440268360e87831b52dd2e802b7b4 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 24 Apr 2026 15:31:38 -0400
Subject: [PATCH] Fix two test failures introduced by transformers v5 support

- modeling_apriel2.py: Guard _init_weights under _TRANSFORMERS_V4. In v5,
  from_pretrained calls initialize_weights() after loading, which re-invokes
  _init_weights via smart_apply. The raw .data.normal_() calls bypass the
  guard_torch_init_functions patching that respects _is_hf_initialized,
  clobbering all loaded weights. In v5 the inherited PreTrainedModel default
  (which uses init.* functions) handles all cases correctly.
- test_lm_head.py: Add num_documents_in_batch=1 to GRPO test kwargs, which
  LanguageModelGRPOLoss._forward_backward requires when computing the
  new_logprobs metric.

Co-Authored-By: Claude Sonnet 4.6
---
 .../apriel2/modeling_apriel2.py | 26 ++++++++++---------
 tests/layers/test_lm_head.py    |  1 +
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py
index 58f082b12..15538d787 100644
--- a/fast_llm_external_models/apriel2/modeling_apriel2.py
+++ b/fast_llm_external_models/apriel2/modeling_apriel2.py
@@ -2330,18 +2330,20 @@ def _prepare_cache_for_generation(
             return
         model_kwargs["past_key_values"] = Apriel2Cache(config=self.config)
 
-    def _init_weights(self, module):
-        std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, MistralRMSNorm):
-            module.weight.data.fill_(1.0)
+    if _TRANSFORMERS_V4:
+
+        def _init_weights(self, module):
+            std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02
+            if isinstance(module, nn.Linear):
+                module.weight.data.normal_(mean=0.0, std=std)
+                if module.bias is not None:
+                    module.bias.data.zero_()
+            elif isinstance(module, nn.Embedding):
+                module.weight.data.normal_(mean=0.0, std=std)
+                if module.padding_idx is not None:
+                    module.weight.data[module.padding_idx].zero_()
+            elif isinstance(module, MistralRMSNorm):
+                module.weight.data.fill_(1.0)
 
     def tie_weights(self, **kwargs):
         super().tie_weights(**kwargs)
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index 73e9f4807..aa50fbb5e 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -147,6 +147,7 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
             torch.full(input_.shape[:-1], float((labels_ >= 0).sum()), dtype=torch.float32, device=device)
             for labels_ in kwargs[LanguageModelKwargs.labels]
         ]
+        kwargs[LanguageModelKwargs.num_documents_in_batch] = 1
         return input_, kwargs
 
     def get_reference_outputs(
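
Note on the first fix: the sketch below is plain PyTorch, not transformers internals; the no-op patch only stands in for the guard_torch_init_functions / _is_hf_initialized idea named in the commit message. It illustrates why a guard that patches torch init functions cannot protect loaded weights from raw .data calls.

    # Standalone illustration of the failure mode fixed above (assumption:
    # not transformers code, just the principle the commit message describes).
    import contextlib

    import torch
    from torch import nn

    @contextlib.contextmanager
    def noop_torch_init():
        # Temporarily make nn.init.normal_ a no-op, the way a post-load guard
        # would protect weights that were just loaded from a checkpoint.
        original = nn.init.normal_
        nn.init.normal_ = lambda tensor, *args, **kwargs: tensor
        try:
            yield
        finally:
            nn.init.normal_ = original

    linear = nn.Linear(4, 4)
    with torch.no_grad():
        linear.weight.fill_(3.0)  # stand-in for values loaded from a checkpoint

    with noop_torch_init():
        nn.init.normal_(linear.weight, mean=0.0, std=0.02)  # intercepted: values kept
        print(linear.weight.mean())  # ~3.0
        linear.weight.data.normal_(mean=0.0, std=0.02)      # bypasses the patch
        print(linear.weight.mean())  # ~0.0: loaded values clobbered

Since the v5 default PreTrainedModel._init_weights uses init.* functions (the first path above), dropping the custom override outside _TRANSFORMERS_V4 is sufficient.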