Please also report this issue to DeepSpeed (microsoft/DeepSpeed#4932) for help.
For RoPE-based models, the training instability can be worked around with the following patch:
```python
import torch
import torch.nn as nn
import deepspeed


class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # DeepSpeed ZeRO-3 hooks torch.arange during initialization; temporarily
        # restore the original CPU-side torch.arange so inv_freq is computed in
        # full precision instead of through the hooked version.
        hooked_arange = torch.arange
        torch.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )
        self.cos_cached = self.cos_cached.to("cuda")
        self.sin_cached = self.sin_cached.to("cuda")

        # Restore the hooked torch.arange for the rest of initialization.
        torch.arange = hooked_arange

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Different from the paper, but it uses a different permutation in order to obtain the same calculation.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


# Monkey-patch the transformers implementation so Llama models pick up the fixed class.
from transformers.models.llama import modeling_llama

modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
```
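A minimal usage sketch (the checkpoint name below is illustrative, not from this issue): the patch has to run before the model is constructed, since it replaces the class that the Llama modeling code instantiates.

```python
# Hypothetical usage sketch: run the patch above first, then build the model as usual.
from transformers import AutoModelForCausalLM

# ... patch above has already replaced modeling_llama.LlamaRotaryEmbedding ...
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # example checkpoint
# Continue with the usual deepspeed.initialize(...) / trainer setup afterwards.
```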
Fixed in 177f042. This has also been fixed upstream in transformers v4.38.2.
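If you prefer to rely on the upstream fix rather than the patch, a quick way to confirm the installed version is recent enough (this check is only a sketch and assumes the `packaging` library is available):

```python
# Sketch: verify that the installed transformers already contains the upstream fix.
import transformers
from packaging import version

assert version.parse(transformers.__version__) >= version.parse("4.38.2"), \
    "Upgrade with: pip install -U 'transformers>=4.38.2'"
```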