diff --git a/gpt_neox/gpt_neox.py b/gpt_neox/gpt_neox.py index ab456bb4a..e4e91ce40 100644 --- a/gpt_neox/gpt_neox.py +++ b/gpt_neox/gpt_neox.py @@ -4,10 +4,6 @@ from einops import rearrange -# constants - -MASK_VALUE = -1e7 - # helpers def exists(val): @@ -49,12 +45,12 @@ def forward(self, x, **kwargs): # attention -def dense_attn(q, k, v, key_padding_mask = None, dropout_fn = None): +def dense_attn(q, k, v, attn_mask = None, dropout_fn = None): scale = q.shape[-1] ** -0.5 sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale - if exists(key_padding_mask): - sim = sim + key_padding_mask[:, None, :, :] + if exists(attn_mask): + sim = sim + attn_mask[None, None, :, :] attn = sim.softmax(dim=-1) @@ -103,7 +99,8 @@ def forward(self, x, **kwargs): i, j = q.shape[-2], k.shape[-2] bool_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool() mask = torch.zeros(i, j, device=device).to(q) - mask.masked_fill_(bool_mask, MASK_VALUE) + mask_value = -torch.finfo(q.dtype).max + mask.masked_fill_(bool_mask, mask_value) out = self.attn_fn(q, k, v, attn_mask=mask) out = rearrange(out, 'b h n d -> b n (h d)') @@ -132,7 +129,7 @@ def __init__(self, *, num_tokens, dim, seq_len, depth, heads=8, dim_head=64, att for _, layer_sparse_attn in zip(range(depth), layers_sparse_attn): self.layers.append(nn.ModuleList([ - PreNorm(dim, norm_class, Attention(dim=dim, heads=heads, dim_head=dim_head, dropout=attn_dropout, sparse_attn=layer_sparse_attn)), + PreNorm(dim, norm_class, Attention(dim=dim, heads=heads, seq_len=seq_len, dim_head=dim_head, dropout=attn_dropout, sparse_attn=layer_sparse_attn)), PreNorm(dim, norm_class, FeedForward(dim=dim, dropout=ff_dropout)), ])) diff --git a/install_deepspeed.sh b/install_deepspeed.sh index 928d21e76..79add8150 100644 --- a/install_deepspeed.sh +++ b/install_deepspeed.sh @@ -1,3 +1,3 @@ sudo apt-get -y install llvm-9-dev cmake git clone https://github.com/microsoft/DeepSpeed.git /tmp/Deepspeed -cd /tmp/Deepspeed && DS_BUILD_SPARSE_ATTN=1 ./install.sh +cd /tmp/Deepspeed && DS_BUILD_SPARSE_ATTN=1 ./install.sh -s