
Enabling deepspeed.zero.Init causes very strange spikes in PPO policy_loss #191

Closed
wuxibin89 opened this issue Jan 17, 2024 · 2 comments

@wuxibin89
Collaborator

Also reported this issue to DeepSpeed (microsoft/DeepSpeed#4932) for help.
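For context, "enabling deepspeed.zero.Init" here means constructing the model inside the zero.Init context so parameters are partitioned at construction time. A minimal sketch, assuming a ZeRO-3 setup (the config and model name below are illustrative, not taken from this issue):

import deepspeed
from transformers import AutoModelForCausalLM

# Illustrative ZeRO stage-3 config; the actual training config is not shown in this issue.
ds_config = {"zero_optimization": {"stage": 3}}

# zero.Init hooks tensor-creation functions (including torch.arange) and partitions
# parameters as each submodule is constructed.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")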

@wuxibin89 wuxibin89 self-assigned this Jan 17, 2024
@wuxibin89 wuxibin89 added bug Something isn't working P0 High priority labels Jan 17, 2024
@hijkzzz
Collaborator

hijkzzz commented Jan 22, 2024

For RoPE-based models, we can fix the training instability with the following patch, which temporarily restores the original torch.arange (hooked by deepspeed.zero.Init) while the rotary embedding is being built:

import deepspeed
import torch
import torch.nn as nn


class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        # deepspeed.zero.Init hooks torch.arange; temporarily swap back the original
        # (CPU) implementation so inv_freq and the cos/sin caches are computed correctly.
        hooked_arange = torch.arange
        torch.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

        self.cos_cached = self.cos_cached.to("cuda")
        self.sin_cached = self.sin_cached.to("cuda")
        # Restore the hooked torch.arange so deepspeed.zero.Init keeps working elsewhere.
        torch.arange = hooked_arange

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # Differs from the paper: uses a different permutation but obtains the same calculation.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

# Monkey-patch transformers' Llama implementation with the fixed rotary embedding.
from transformers.models.llama import modeling_llama
modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
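A hedged usage note: the monkey-patch has to run before the model is built under deepspeed.zero.Init, otherwise the stock rotary embedding is the one that gets instantiated. A minimal sketch (ds_config and the model name are illustrative placeholders):

import deepspeed
from transformers import AutoModelForCausalLM
from transformers.models.llama import modeling_llama

modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding  # apply the patch first

ds_config = {"zero_optimization": {"stage": 3}}  # illustrative ZeRO-3 config
with deepspeed.zero.Init(config_dict_or_path=ds_config):    # then construct the model
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")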

@hijkzzz
Copy link
Collaborator

hijkzzz commented Mar 4, 2024

This has been fixed in transformers v4.38.2.
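So instead of keeping the patch above, you can upgrade and verify the installed version. A minimal check, assuming the packaging library is available (it is a transformers dependency):

import transformers
from packaging import version

# Upgrade with: pip install -U "transformers>=4.38.2"
assert version.parse(transformers.__version__) >= version.parse("4.38.2"), (
    f"transformers {transformers.__version__} still needs the rotary-embedding patch above"
)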
