diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py
index 58f082b12..15538d787 100644
--- a/fast_llm_external_models/apriel2/modeling_apriel2.py
+++ b/fast_llm_external_models/apriel2/modeling_apriel2.py
@@ -2330,18 +2330,20 @@ def _prepare_cache_for_generation(
             return
         model_kwargs["past_key_values"] = Apriel2Cache(config=self.config)
 
-    def _init_weights(self, module):
-        std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, MistralRMSNorm):
-            module.weight.data.fill_(1.0)
+    if _TRANSFORMERS_V4:
+
+        def _init_weights(self, module):
+            std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02
+            if isinstance(module, nn.Linear):
+                module.weight.data.normal_(mean=0.0, std=std)
+                if module.bias is not None:
+                    module.bias.data.zero_()
+            elif isinstance(module, nn.Embedding):
+                module.weight.data.normal_(mean=0.0, std=std)
+                if module.padding_idx is not None:
+                    module.weight.data[module.padding_idx].zero_()
+            elif isinstance(module, MistralRMSNorm):
+                module.weight.data.fill_(1.0)
 
     def tie_weights(self, **kwargs):
         super().tie_weights(**kwargs)
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index 73e9f4807..aa50fbb5e 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -147,6 +147,7 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
             torch.full(input_.shape[:-1], float((labels_ >= 0).sum()), dtype=torch.float32, device=device)
             for labels_ in kwargs[LanguageModelKwargs.labels]
         ]
+        kwargs[LanguageModelKwargs.num_documents_in_batch] = 1
         return input_, kwargs
 
     def get_reference_outputs(