
Install torch nightly for flash attention #11

Merged: 4 commits, May 5, 2023
3 changes: 3 additions & 0 deletions .github/workflows/cpu-tests.yml
@@ -40,6 +40,9 @@ jobs:

 - name: Run tests without the package installed
   run: |
+    # install torch cpu nightly
+    python -c "with open('requirements.txt', 'r+') as fp: c = fp.read().replace('cu118', 'cpu'); fp.seek(0); fp.write(c); fp.truncate()"
+
     pip install pytest -r requirements.txt 'transformers==4.27.3'
     pip list
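
For reference, the inline python -c one-liner above expands to the following script. It rewrites requirements.txt in place so that CI resolves the CPU nightly wheel instead of the cu118 one:

# Expanded form of the one-liner in the CI step above (same behaviour).
with open("requirements.txt", "r+") as fp:
    contents = fp.read().replace("cu118", "cpu")  # swap the CUDA wheel index for the CPU one
    fp.seek(0)        # rewind so the rewritten text overwrites the original
    fp.write(contents)
    fp.truncate()     # drop leftover bytes, since "cpu" is shorter than "cu118"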

7 changes: 2 additions & 5 deletions generate.py
@@ -114,17 +114,14 @@ def main(
         config = StableLMConfig(**json.load(fp))
 
     fabric = L.Fabric(devices=1)
-    dtype = torch.float32
-    # avoid RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
-    # see https://github.com/Lightning-AI/lit-stablelm/issues/2
-    # dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
+    dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
 
     print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
     t0 = time.time()
     with EmptyInitOnDevice(device=fabric.device, dtype=dtype, quantization_mode=quantize):
         model = StableLM(config)
     with lazy_load(checkpoint_path) as checkpoint:
-        model.load_state_dict(checkpoint, strict=False)
+        model.load_state_dict(checkpoint)
     print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
 
     model.eval()
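
Dropping strict=False pairs with removing the unused bias buffer in model.py below: once the module no longer registers a buffer that the checkpoint lacks, the default strict load succeeds (this pairing is inferred from the diff, not stated in it). A toy illustration of the behaviour, not the repo's model:

import torch

class Toy(torch.nn.Module):
    def __init__(self, extra_buffer: bool = False):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        if extra_buffer:
            self.register_buffer("bias", torch.tril(torch.ones(2, 2)))

checkpoint = Toy(extra_buffer=False).state_dict()  # checkpoint without the buffer
Toy(extra_buffer=True).load_state_dict(checkpoint, strict=False)  # needed when keys mismatch
Toy(extra_buffer=False).load_state_dict(checkpoint)               # strict default works when keys match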
8 changes: 8 additions & 0 deletions lit_stablelm/__init__.py
@@ -1,2 +1,10 @@
 from lit_stablelm.model import StableLMConfig, StableLM, build_rope_cache, apply_rope
 from lit_stablelm.tokenizer import Tokenizer
+
+from lightning_utilities.core.imports import RequirementCache
+
+if not bool(RequirementCache("torch>=2.1.0dev")):
+    raise ImportError(
+        "Lit-StableLM requires torch nightly (future torch 2.1). Please follow the installation instructions in the"
+        " repository README.md"
+    )
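
RequirementCache evaluates a PEP 440 requirement string against the installed distribution. As a rough, hand-rolled equivalent of the guard above (an illustration only, not how lightning_utilities implements it):

# Illustrative stand-in for the RequirementCache check (assumes `packaging` is available).
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

try:
    installed = Version(version("torch"))
except PackageNotFoundError:
    installed = Version("0")

# Nightly builds report versions like "2.1.0.dev20230505+cu118", which satisfy ">=2.1.0dev".
if installed < Version("2.1.0.dev0"):
    raise ImportError(
        "Lit-StableLM requires torch nightly (future torch 2.1). Please follow the installation"
        " instructions in the repository README.md"
    )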
26 changes: 6 additions & 20 deletions lit_stablelm/model.py
@@ -126,14 +126,6 @@ def __init__(self, config: StableLMConfig) -> None:
         self.rotary_percentage = config.rotary_percentage
         self.rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
 
-        if self.rotary_percentage != 1.0:
-            self.register_buffer(
-                "bias",
-                torch.tril(torch.ones(config.block_size, config.block_size)).view(
-                    1, 1, config.block_size, config.block_size
-                ),
-            )
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
 
@@ -153,17 +145,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
         k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
 
-        if hasattr(self, "bias"):
-            # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
-            # NOTE: cannot use flash attention because it takes q.size(-1) as the norm factor which is different to the
-            # head size when rotary_percentage is set
-            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
-            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
-            att = F.softmax(att, dim=-1)
-            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-        else:
-            # efficient attention using Flash Attention CUDA kernels
-            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
+        # efficient attention using Flash Attention CUDA kernels
+        y = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=1.0 / math.sqrt(head_size)
+        )
 
         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
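
The NOTE deleted above explained the old limitation: the fused kernel normalized by q.size(-1), which differs from the head size whenever rotary_percentage is set. The scale argument added in the torch 2.1 nightlies lifts that restriction. A standalone sanity check (illustrative shapes, not code from this PR) that the fused call reproduces the removed manual path:

import math

import torch
import torch.nn.functional as F

# Illustrative shapes; head_size deliberately differs from q.size(-1), as with rotary_percentage < 1.0.
B, nh, T, hs = 2, 4, 8, 16
head_size = 32
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# Removed manual path: causal mask plus softmax with a 1/sqrt(head_size) norm factor.
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
att = att.masked_fill(~causal, float("-inf"))
y_manual = F.softmax(att, dim=-1) @ v

# New fused path: the same norm factor passed explicitly via `scale` (requires torch >= 2.1.0dev).
y_fused = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=1.0 / math.sqrt(head_size)
)

torch.testing.assert_close(y_manual, y_fused, rtol=1e-4, atol=1e-4)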

@@ -219,4 +204,5 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
     x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
     rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
-    return (x * cos) + (rotated * sin)
+    roped = (x * cos) + (rotated * sin)
+    return roped.type_as(x)
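
The added type_as(x) guards the dtype of the result: if the rope cache (cos/sin) is kept in a wider dtype than the activations (for example float32 versus bfloat16, an assumption here rather than something this diff shows), the elementwise products promote to the wider dtype, and the cast brings the output back to the input's dtype. A minimal illustration:

import torch

x = torch.ones(2, 2, dtype=torch.bfloat16)  # stand-in for the roped activations
cos = torch.ones(2, 2)                      # stand-in for a float32 rope cache entry
print((x * cos).dtype)                      # torch.float32 -- promoted by the type promotion rules
print((x * cos).type_as(x).dtype)           # torch.bfloat16 -- what the new return value preserves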
5 changes: 4 additions & 1 deletion requirements.txt
@@ -1,4 +1,7 @@
-torch>=2.0.0
+# we require torch nightly for flash attention support
+# NOTE: you might need to replace "cu118" with "cpu" or "cu117"
+--extra-index-url https://download.pytorch.org/whl/nightly/cu118 --pre
+torch>=2.1.0dev
 lightning @ git+https://github.com/Lightning-AI/lightning@master
 tokenizers
 jsonargparse[signatures]  # CLI
5 changes: 4 additions & 1 deletion setup.py
@@ -15,7 +15,10 @@
     author='Lightning AI',
     url='https://github.com/lightning-AI/lit-stablelm',
     install_requires=[
-        "torch>=2.0.0",
+        # needs to be installed with
+        # pip install . --extra-index-url https://download.pytorch.org/whl/nightly/cu118 --pre
+        # NOTE: you might need to replace "cu118" with "cpu" or "cu117"
+        "torch>=2.1.0dev",
         "lightning @ git+https://github.com/Lightning-AI/lightning@master",
         "tokenizers",
     ],
2 changes: 1 addition & 1 deletion tests/test_model.py
@@ -78,7 +78,7 @@ def test_against_hf_model(rotary_pct, batch_size, n_embd, lit_stablelm) -> None:

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
 @torch.no_grad()
-def test_bfloat16_llama_init(lit_stablelm) -> None:
+def test_model_bfloat16(lit_stablelm) -> None:
     from lit_stablelm.utils import EmptyInitOnDevice
 
     block_size = 64