feat: replace current attention mechanism with flash-attn
luzian-hahn committed Mar 18, 2024
1 parent c3242e3 commit 7d27b59
Showing 3 changed files with 35 additions and 90 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -21,12 +21,15 @@ dependencies = [
"jq",
"xformers",
"class_resolver",
"wandb"
"wandb",
"flash-attn" # install this directly via `pip install flash-attn --no-build-isolation`

]

[project.optional-dependencies]
linting = ["pre-commit"]
tests = ["pytest", "pytest-cov"]
install_helper = ["ninja"]

[project.scripts]
modalities = "modalities.__main__:main"
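The new flash-attn dependency must be compiled against the locally installed PyTorch and CUDA toolchain, which is why the inline comment recommends `pip install flash-attn --no-build-isolation` and the `install_helper` extra ships `ninja` to speed up that build. As a hedged illustration (a hypothetical helper, not part of this commit), a runtime availability check could look like this:

# Sketch of a runtime check (hypothetical helper, not part of the repository).
# flash_attn may fail to import even when installed, e.g. if its CUDA extension
# was built against a different torch/CUDA combination, so a plain find_spec()
# test is not sufficient on its own.
import importlib.util


def flash_attn_available() -> bool:
    if importlib.util.find_spec("flash_attn") is None:
        return False
    try:
        from flash_attn import flash_attn_func  # noqa: F401
    except ImportError:
        return False
    return True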
51 changes: 12 additions & 39 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -6,6 +6,7 @@
import torch
import torch.nn as nn
import xformers.ops as xops
from flash_attn import flash_attn_func
from pydantic import BaseModel, Field, model_validator
from torch.nn import functional as F

@@ -127,21 +128,13 @@ def __init__(
)

# regularization
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
self._dropout = dropout
self.resid_dropout = nn.Dropout(self._dropout)
self.n_head_q = n_head_q
self.n_head_kv = n_head_kv

self.n_embd = n_embd
self.dropout = dropout
self.flash = attention_type == AttentionType.PYTORCH_FLASH_ATTENTION

if not self.flash:
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
"bias",
torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, _ = x.size() # batch size (B), sequence length (T), embedding dimensionality (self.n_embd)
@@ -151,35 +144,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
k = self.k_attn(x) # (B, T, n_embd / n_rep)
v = self.v_attn(x) # (B, T, n_embd / n_rep)

q = q.view(B, T, self.n_head_q, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_q, T, hs)
k = k.view(B, T, self.n_head_kv, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_kv, T, hs)
v = v.view(B, T, self.n_head_kv, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_kv, T, hs)

# repeat k/v heads if self.n_rep > 1
k = repeat_kv(k, self.n_rep) # (B, nh_q, T, hs)
v = repeat_kv(v, self.n_rep) # (B, nh_q, T, hs)

# causal self-attention; Self-attend: (B, nh_q, T, hs) x (B, nh_q, hs, T) -> (B, nh_q, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
y = torch.nn.functional.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=None,
dropout_p=self.dropout if self.training else 0,
is_causal=True,
)
else:
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # (B, nh_q, T, T)
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh_q, T, T) x (B, nh_q, T, hs) -> (B, nh_q, T, hs)
y = (
y.transpose(1, 2).contiguous().view(B, T, self.n_embd)
) # (B, T, n_embd), re-assemble all head outputs side by side
q = q.view(B, T, self.n_head_q, self.n_embd // self.n_head_q) # (B, T, nh_q, hs)
k = k.view(B, T, self.n_head_kv, self.n_embd // self.n_head_q) # (B, T, nh_kv, hs)
v = v.view(B, T, self.n_head_kv, self.n_embd // self.n_head_q) # (B, T, nh_kv, hs)

# TODO: make parameters configurable
y = flash_attn_func(
q, k, v, dropout_p=self._dropout, causal=True, softmax_scale=None, window_size=(-1, -1)
).reshape(B, T, self.n_embd)
# (B, T, n_embd), re-assemble all head outputs side by side

# output projection
y = self.resid_dropout(self.c_proj(y)) # (B, T, n_embd)
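For reference, `flash_attn_func` consumes tensors in (batch, seq_len, n_heads, head_dim) layout, which is why the transposes of the previous implementation are dropped, and it handles grouped-query attention natively as long as the number of query heads is a multiple of the number of key/value heads, making the explicit `repeat_kv` step unnecessary. A minimal, self-contained sketch of the call (illustrative only, shapes chosen arbitrarily; flash-attn requires fp16/bf16 tensors on a CUDA device):

# Illustrative sketch, not part of the diff: the layout flash_attn_func expects.
import torch
from flash_attn import flash_attn_func

B, T, n_head_q, n_head_kv, head_dim = 2, 10, 8, 2, 4  # arbitrary example sizes

# flash-attn kernels only accept half-precision tensors on a CUDA device.
q = torch.randn(B, T, n_head_q, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, T, n_head_kv, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, T, n_head_kv, head_dim, dtype=torch.bfloat16, device="cuda")

# causal=True reproduces the lower-triangular mask that the removed `bias` buffer encoded;
# softmax_scale=None defaults to 1/sqrt(head_dim), matching the manual implementation;
# window_size=(-1, -1) disables sliding-window attention.
y = flash_attn_func(
    q, k, v, dropout_p=0.0, causal=True, softmax_scale=None, window_size=(-1, -1)
)
assert y.shape == (B, T, n_head_q, head_dim)  # heads are flattened afterwards via reshape

Note that the kernel applies `dropout_p` whenever it is non-zero, so callers that need deterministic evaluation typically gate it on the module's training flag.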
69 changes: 19 additions & 50 deletions tests/models/test_causal_self_attention.py
@@ -1,43 +1,42 @@
from copy import deepcopy

import pytest
import torch

from modalities.models.gpt2.gpt2_model import AttentionType, CausalSelfAttention


def _get_random_input_seq(embedding_shape):
return torch.rand(size=embedding_shape, dtype=torch.float32)
flash_attn_supported_dtype = torch.bfloat16
return torch.rand(size=embedding_shape, dtype=flash_attn_supported_dtype)


def _get_random_attention_layer(n_head_q, n_head_kv, n_embd, attention_type, block_size):
return CausalSelfAttention(
self_attention_layer = CausalSelfAttention(
n_head_q=n_head_q,
n_head_kv=n_head_kv,
n_embd=n_embd,
attention_type=attention_type,
bias=False,
dropout=0.0,
block_size=block_size,
)
).cuda()
self_attention_layer.q_attn = self_attention_layer.q_attn.bfloat16()
self_attention_layer.k_attn = self_attention_layer.k_attn.bfloat16()
self_attention_layer.v_attn = self_attention_layer.v_attn.bfloat16()
self_attention_layer.c_proj = self_attention_layer.c_proj.bfloat16()
return self_attention_layer


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This e2e test requires 1 GPU.")
@pytest.mark.parametrize(
"n_head_q, n_head_kv, n_embd, attention_type, successful",
"n_head_q, n_head_kv, n_embd, successful",
[
# Flash Attention
(4, 4, 32, AttentionType.PYTORCH_FLASH_ATTENTION, True),
(8, 2, 32, AttentionType.PYTORCH_FLASH_ATTENTION, True),
(9, 8, 32, AttentionType.PYTORCH_FLASH_ATTENTION, False),
(8, 3, 32, AttentionType.PYTORCH_FLASH_ATTENTION, False),
# Default Attention
(4, 4, 32, AttentionType.DEFAULT_ATTENTION, True),
(8, 2, 32, AttentionType.DEFAULT_ATTENTION, True),
(9, 8, 32, AttentionType.DEFAULT_ATTENTION, False),
(8, 3, 32, AttentionType.DEFAULT_ATTENTION, False),
(4, 4, 32, True),
(8, 2, 32, True),
(9, 8, 32, False),
(8, 3, 32, False),
],
)
def test_forward_pass_success(n_head_q, n_head_kv, n_embd, attention_type, successful):
def test_forward_pass_success(n_head_q, n_head_kv, n_embd, successful):
batch_size = 2
block_size = 10
embedding_shape = (batch_size, block_size, n_embd)
@@ -46,45 +45,15 @@ def test_forward_pass_success(n_head_q, n_head_kv, n_embd, attention_type, succe
"n_head_q": n_head_q,
"n_head_kv": n_head_kv,
"n_embd": n_embd,
"attention_type": attention_type,
"attention_type": AttentionType.DEFAULT_ATTENTION,
"block_size": block_size,
}

if not successful:
with pytest.raises(Exception):
_get_random_attention_layer(**attention_layer_args)
else:
attention_layer = _get_random_attention_layer(**attention_layer_args)
embedded_input_seq = _get_random_input_seq(embedding_shape)
attention_layer = _get_random_attention_layer(**attention_layer_args).cuda()
embedded_input_seq = _get_random_input_seq(embedding_shape).cuda()
output_tensor = attention_layer(embedded_input_seq)
assert output_tensor.shape == embedding_shape


@pytest.mark.parametrize(
"n_head_q, n_head_kv, n_embd",
[
(4, 4, 32),
(8, 2, 32),
],
)
def test_attention_types_equality(n_head_q, n_head_kv, n_embd):
batch_size = 2
block_size = 10
embedding_shape = (batch_size, block_size, n_embd)
embedded_input_seq = _get_random_input_seq(embedding_shape)

attention_layer_args = {
"n_head_q": n_head_q,
"n_head_kv": n_head_kv,
"n_embd": n_embd,
"attention_type": AttentionType.DEFAULT_ATTENTION,
"block_size": block_size,
}

attention_layer_default = _get_random_attention_layer(**attention_layer_args)
attention_layer_flash = deepcopy(attention_layer_default)
attention_layer_flash.flash = True

output_tensor_default = attention_layer_default(embedded_input_seq)
output_tensor_flash = attention_layer_flash(embedded_input_seq)
torch.testing.assert_close(output_tensor_default, output_tensor_flash)
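The removed equality test compared the default and PyTorch flash code paths; with only the flash-attn path left, an analogous numerical cross-check against `torch.nn.functional.scaled_dot_product_attention` could look like the following hedged sketch (not part of this commit; tolerances are loose because the inputs are bfloat16):

# Hypothetical parity check, illustrative only — not part of the repository's test suite.
import torch
from flash_attn import flash_attn_func

B, T, n_head, head_dim = 2, 10, 4, 8
q = torch.randn(B, T, n_head, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, T, n_head, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, T, n_head, head_dim, dtype=torch.bfloat16, device="cuda")

# flash_attn_func works on (B, T, n_head, head_dim) tensors directly.
y_flash = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

# scaled_dot_product_attention expects (B, n_head, T, head_dim), so transpose in and out.
y_sdpa = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

torch.testing.assert_close(y_flash, y_sdpa, atol=2e-2, rtol=2e-2)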
