Add flash-attention patch for falcon-7b (#3580)
Enable the `use_flash_attention` configuration flag for Falcon models.
When `use_flash_attention` is set to `true`, the
[FalconAttention.forward()](https://github.com/huggingface/transformers/blob/c965d302791cf935d6ea7776428749be678cf509/src/transformers/models/falcon/modeling_falcon.py#L281)
method is replaced with a variant that uses Tri Dao's flash-attn
instead of PyTorch's `scaled_dot_product_attention` function.
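
Mechanically this is a monkey patch: for every `FalconAttention` module the bound `forward` is replaced by a `functools.partial` over the flash-attention variant (see `models/patching.py` in the diff below). A minimal sketch of the idea, with the import path of the patched forward assumed:

```python
# Sketch only: illustrates how the forward method is swapped, not the full patch_model() logic.
from functools import partial

import transformers
from flash_attn.modules.mha import FlashSelfAttention

# assumed import path for the function added in this PR (model/model_training/models/patching_falcon.py)
from model_training.models.patching_falcon import falcon_forward_with_flash_attn


def add_flash_attn_to_falcon(model, causal: bool = True) -> None:
    """Replace FalconAttention.forward on every decoder layer of a FalconForCausalLM-style model."""
    flash_attn = FlashSelfAttention(causal=causal)
    for layer in model.transformer.h:
        attn = layer.self_attention
        assert isinstance(attn, transformers.models.falcon.modeling_falcon.FalconAttention)
        # bind the attention module and the FlashSelfAttention instance as the first two arguments,
        # exactly as patching.py does with functools.partial
        attn.forward = partial(falcon_forward_with_flash_attn, attn, flash_attn)
```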

At the moment the patch works only for falcon-7b, but technically it should
also work for falcon-40b with the right configuration. The Falcon model
situation is currently a bit messy: the Falcon model was recently added
to Hugging Face transformers (see [PR
transformers#24523](huggingface/transformers#24523)),
but the falcon models on the Hugging Face hub still use the code that is
shipped together with the weights (a PR to change this [was
reverted](https://huggingface.co/tiiuae/falcon-7b/discussions/66)).
Falcon-7b and 40b each use slightly different code, which was unified in
the HF transformers implementation and can be controlled there via a configuration
member called `new_decoder_architecture`, see
[configuration_falcon.py#L65-L67](https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/configuration_falcon.py#L65-L67).
The HF Falcon implementation also uses different names in its configuration class; compare, e.g., the new
[configuration_falcon.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/configuration_falcon.py)
with the old
[configuration_RW.py](https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py).
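
To check which configuration class a given checkpoint resolves to, and whether the unified `new_decoder_architecture` field is available, something along these lines can be used (what you get depends on the checkpoint revision and the installed transformers version):

```python
# Sketch: inspect which Falcon config implementation a checkpoint resolves to.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
print(type(cfg).__name__)  # e.g. RWConfig (code shipped with the weights) or FalconConfig (in-library)
# Unified flag of the in-library implementation; may be absent on the older RWConfig.
print(getattr(cfg, "new_decoder_architecture", None))
```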

Model configurations compatible with the HF Falcon implementation can be found here:

- 7B: [config.json](https://huggingface.co/tiiuae/falcon-7b/blob/4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76/config.json)
- 40B: [config.json](https://huggingface.co/tiiuae/falcon-40b/blob/f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232/config.json)
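
One way to use these is to pin the corresponding revision when loading, so that the HF-impl-compatible `config.json` is picked up instead of the remote code shipped with the weights; a hedged sketch (assumes a transformers version that already ships `FalconForCausalLM`, and that the pinned revision loads cleanly with it):

```python
# Sketch: pin the revision whose config.json is compatible with the in-library Falcon implementation.
import torch
from transformers import FalconForCausalLM

model = FalconForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    revision="4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76",  # SHA from the 7B link above
    torch_dtype=torch.bfloat16,
)
```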
andreaskoepf committed Jul 19, 2023
1 parent aa42ed4 commit 1e6e569
Showing 10 changed files with 269 additions and 68 deletions.
39 changes: 36 additions & 3 deletions model/model_training/configs/config.yaml
@@ -257,10 +257,10 @@ oasst-top1:

falcon-7b:
dtype: bf16
log_dir: "llama_log_7b"
log_dir: "falcon_log_7b"
learning_rate: 1e-5
model_name: "tiiuae/falcon-7b"
deepspeed_config: configs/zero_config_falcon.json
deepspeed_config: configs/zero_config.json
output_dir: falcon
weight_decay: 0.0
max_length: 2048
@@ -787,7 +787,7 @@ rope_scaling_test:
log_dir: "llama_log_7b"
learning_rate: 1e-5
model_name: "huggyllama/llama-7b"
deepspeed_config: configs/zero_config_falcon.json
deepspeed_config: configs/zero_config.json
output_dir: llama
weight_decay: 0.0
max_length: 4048
@@ -814,3 +814,36 @@ rope_scaling_test:
scale: 2
datasets:
- dolly15k

falcon_7b_ntk_test:
dtype: bf16
learning_rate: 1e-5
model_name: "tiiuae/falcon-7b"
deepspeed_config: configs/zero_config.json
log_dir: "falcon_7b_ntk"
output_dir: falcon_7b_ntk
weight_decay: 0.0
max_length: 4048
warmup_steps: 100
gradient_checkpointing: true
gradient_accumulation_steps: 2
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
eval_steps: 500
save_steps: 1000
num_train_epochs: 8
save_total_limit: 4
use_flash_attention: true
residual_dropout: 0.3
residual_dropout_lima: true
log_wandb: true
# peft_model: true
# peft_config:
# peft_type: "lora"
# r: 16
superhot: true
superhot_config:
type: ntk
alpha: 2
datasets:
- dolly15k
50 changes: 50 additions & 0 deletions model/model_training/configs/zero3_config_falcon.json
@@ -0,0 +1,50 @@
{
"bf16": {
"enabled": "auto"
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
29 changes: 0 additions & 29 deletions model/model_training/configs/zero_config_falcon.json

This file was deleted.

28 changes: 17 additions & 11 deletions model/model_training/models/__init__.py
@@ -1,9 +1,5 @@
import transformers

# from .gptj import get_model as get_gptj_model

SUPPORTED_MODELS = ["galactica", "gpt-j"]


def freeze_top_n_layers(model, target_layers):
# it's possible we can simply detect which module is a ModuleList
@@ -25,17 +21,27 @@ def freeze_top_n_layers(model, target_layers):


def get_specific_model(
model_name, seq2seqmodel=False, without_head=False, cache_dir=".cache", quantization=False, **kwargs
model_name,
seq2seqmodel=False,
without_head=False,
cache_dir=".cache",
quantization=False,
**kwargs,
):
# encoder-decoder support for Flan-T5 like models
# for now, we can use an argument but in the future,
# we can automate this
if without_head:
model = transformers.AutoModel.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
elif seq2seqmodel:
# encoder-decoder support for Flan-T5 like models
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
else:
if "falcon" in model_name:
kwargs["trust_remote_code"] = True
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
if "falcon-7b" in model_name:
# temporary hack until tiiuae/falcon-7b uses the transformers Falcon impl by default
# in-library PR was reverted https://huggingface.co/tiiuae/falcon-7b/commit/378337427557d1df3e742264a2901a49f25d4eb1
model = transformers.models.falcon.modeling_falcon.FalconForCausalLM.from_pretrained(
model_name, cache_dir=cache_dir, **kwargs
)
else:
if "falcon" in model_name:
kwargs["trust_remote_code"] = True
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
return model
32 changes: 25 additions & 7 deletions model/model_training/models/patching.py
@@ -6,9 +6,18 @@

import torch.nn as nn
import transformers
from transformers import AutoConfig, GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel
from transformers import (
AutoConfig,
FalconForCausalLM,
FalconModel,
GPTNeoXForCausalLM,
GPTNeoXModel,
LlamaForCausalLM,
LlamaModel,
)
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

from .patching_falcon import falcon_forward_with_flash_attn
from .patching_llama import llama_forward_with_flash_attn
from .patching_neox import neox_forward_with_flash_attn
from .reward_model import GPTNeoXRewardModel
@@ -19,6 +28,8 @@
GPTNeoXForCausalLM,
LlamaForCausalLM,
LlamaModel,
FalconForCausalLM,
FalconModel,
GPTNeoXRewardModel,
# Currently only supported by NeoX models; Will work on LLaMa models
AutoModelForCausalLMWithHydraValueHead,
@@ -65,6 +76,8 @@ def add_flash_attn(module: nn.Module, causal: bool = True):
if not hasattr(module, "_attn"):
warnings.warn("Provided module doesn't have a _attn() function to be patched.")
module._attn = partial(neox_forward_with_flash_attn, module, flash_attn)
elif isinstance(module, transformers.models.falcon.modeling_falcon.FalconAttention):
module.forward = partial(falcon_forward_with_flash_attn, module, flash_attn)
else:
raise NotImplementedError(f"Flash attention is not implemented for {module.__class__.__name__}.")

@@ -149,17 +162,22 @@ def patch_model(
if model.__class__.__name__ == "RWForCausalLM":
model = model.base_model

if isinstance(model, FalconForCausalLM):
model = model.transformer

attention_key_lookup = {
GPTNeoXModel: "attention",
GPTNeoXRewardModel: "attention",
LlamaModel: "self_attn",
FalconModel: "self_attention",
}
mlp_key_lookup = {
GPTNeoXModel: "mlp",
GPTNeoXRewardModel: "mlp",
LlamaModel: "mlp",
FalconModel: "mlp",
}
if model.__class__.__name__ == "RWModel":
if isinstance(model, FalconModel) or model.__class__.__name__ == "RWModel":
layers = model.h
attention_key = "self_attention"
mlp_key = "mlp"
@@ -187,8 +205,8 @@ def __init__(self, model_name, **kwargs):
architecture = config.architectures
if architecture:
self.model_name = architecture[0]
if "RWForCausalLM" in architecture:
self.architecture = "RWForCausalLM"
if "FalconForCausalLM" in architecture or "RWForCausalLM" in architecture:
self.architecture = "FalconForCausalLM"
if rope_type == "ntk":
self.patch_fun = RWNTKScaledRope
else:
@@ -213,14 +231,14 @@ def from_config(cls, config):
return cls(model_name, **args)

def patch(self, model):
if self.architecture == "RWForCausalLM":
self.patch_rw_model(model, **self.args)
if self.architecture == "FalconForCausalLM":
self.patch_falcon_model(model, **self.args)
elif self.architecture == "LlamaForCausalLM":
self.patch_llama_model(model, **self.args)
else:
raise NotImplementedError()

def patch_rw_model(self, model, **kwargs):
def patch_falcon_model(self, model, **kwargs):
for each in model.transformer.h:
each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)

78 changes: 78 additions & 0 deletions model/model_training/models/patching_falcon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn

from .patching_utils import compute_flash_attention


def falcon_forward_with_flash_attn(
self,
flash_attn: nn.Module, # flash_attn.modules.mha.FlashSelfAttention
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
head_mask, alibi & output_attentions are not supported.
Reference to the original `FalconAttention.forward()` method which this patch replaces:
https://github.com/huggingface/transformers/blob/c965d302791cf935d6ea7776428749be678cf509/src/transformers/models/falcon/modeling_falcon.py#L281
"""

assert head_mask is None # not supported.
assert alibi is None # not supported.
assert not output_attentions # not supported.

fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

batch_size, query_length, _, _ = query_layer.shape

query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)

if use_cache:
present = (key_layer, value_layer)
else:
present = None

query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)

q = query_layer_.permute(0, 2, 1, 3)
k = key_layer_.permute(0, 2, 1, 3).expand(q.shape)
v = value_layer_.permute(0, 2, 1, 3).expand(q.shape)

if attention_mask is not None:
attention_mask = attention_mask[:, 0, -1]

flash_attn.train(self.training)
attn_output = compute_flash_attention(flash_attn, q, k, v, attention_mask=attention_mask)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

output_tensor = self.dense(attn_output)

return output_tensor, present
12 changes: 9 additions & 3 deletions model/model_training/models/patching_utils.py
@@ -7,11 +7,14 @@ def compute_flash_attention(flash_attn, q, k, v, attention_mask=None, head_mask=
# attention_mask (float): [bs, seq_len]
batch_size, max_len = q.size(0), q.size(1)

qkv = torch.stack([q, k, v], dim=2).to(torch.float16) # need to truncate in case input is fp32
qkv = torch.stack([q, k, v], dim=2)
dtype_in = qkv.dtype
if dtype_in == torch.float32:
qkv = qkv.to(torch.float16) # need to truncate in case input is fp32
cu_seqlens, max_seqlen = None, None

if attention_mask is None:
return flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
out = flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
else:
# Limitation: non-contiguous attention mask will not be handled correctly
# model will be able to pay attention between the first and last non-masked token, i.e. left- and right-side padding is supported.
@@ -35,7 +38,10 @@ def compute_flash_attention(flash_attn, q, k, v, attention_mask=None, head_mask=
for i in range(batch_size)
]
out = torch.stack(padded_seqs)
return out

if out.dtype != dtype_in:
out = out.to(dtype_in)
return out


if __name__ == "__main__":