Add flash-attention patch for falcon-7b #3580

Merged: 5 commits on Jul 19, 2023
40 changes: 37 additions & 3 deletions model/model_training/configs/config.yaml
@@ -88,6 +88,7 @@ defaults:
deepspeed_config: configs/zero_config.json
peft_model: false
peft_type: "lora"
superhot: false

use_system_tag:
use_system_tag: True
@@ -256,10 +257,10 @@ oasst-top1:

falcon-7b:
dtype: bf16
log_dir: "llama_log_7b"
log_dir: "falcon_log_7b"
learning_rate: 1e-5
model_name: "tiiuae/falcon-7b"
deepspeed_config: configs/zero_config_falcon.json
deepspeed_config: configs/zero_config.json
output_dir: falcon
weight_decay: 0.0
max_length: 2048
@@ -786,7 +787,7 @@ rope_scaling_test:
log_dir: "llama_log_7b"
learning_rate: 1e-5
model_name: "huggyllama/llama-7b"
deepspeed_config: configs/zero_config_falcon.json
deepspeed_config: configs/zero_config.json
output_dir: llama
weight_decay: 0.0
max_length: 4048
@@ -813,3 +814,36 @@ rope_scaling_test:
scale: 2
datasets:
- dolly15k

falcon_7b_ntk_test:
dtype: bf16
learning_rate: 1e-5
model_name: "tiiuae/falcon-7b"
deepspeed_config: configs/zero_config.json
log_dir: "falcon_7b_ntk"
output_dir: falcon_7b_ntk
weight_decay: 0.0
max_length: 4048
warmup_steps: 100
gradient_checkpointing: true
gradient_accumulation_steps: 2
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
eval_steps: 500
save_steps: 1000
num_train_epochs: 8
save_total_limit: 4
use_flash_attention: true
residual_dropout: 0.3
residual_dropout_lima: true
log_wandb: true
# peft_model: true
# peft_config:
# peft_type: "lora"
# r: 16
superhot: true
superhot_config:
type: ntk
alpha: 2
datasets:
- dolly15k
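
For context on the superhot_config block above: NTK-style scaling stretches the rotary base rather than interpolating positions, so alpha: 2 roughly doubles the usable context length. A minimal, illustrative sketch of the scaling rule (the PR's actual implementation is the RWNTKScaledRope patch referenced in patching.py below; the function name here is made up):

import torch

def ntk_scaled_inv_freq(head_dim: int, alpha: float = 2.0, base: float = 10000.0) -> torch.Tensor:
    # NTK-aware scaling: enlarge the rotary base by alpha ** (dim / (dim - 2)),
    # which slows the lowest rotary frequencies enough to cover a longer context.
    scaled_base = base * alpha ** (head_dim / (head_dim - 2))
    return 1.0 / (scaled_base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

inv_freq = ntk_scaled_inv_freq(64, alpha=2.0)  # falcon-7b attention uses head_dim = 64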
50 changes: 50 additions & 0 deletions model/model_training/configs/zero3_config_falcon.json
@@ -0,0 +1,50 @@
{
"bf16": {
"enabled": "auto"
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
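
The new file enables ZeRO stage 3 with CPU offload for both optimizer state and parameters. A hedged sketch of how such a config typically reaches DeepSpeed through Hugging Face's TrainingArguments; in this repo the path is supplied via the deepspeed_config key in config.yaml, and the argument values below are illustrative:

from transformers import TrainingArguments

# illustrative only: the "auto" fields in the JSON are resolved from these arguments
training_args = TrainingArguments(
    output_dir="falcon",
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    deepspeed="model/model_training/configs/zero3_config_falcon.json",
)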
29 changes: 0 additions & 29 deletions model/model_training/configs/zero_config_falcon.json

This file was deleted.

28 changes: 17 additions & 11 deletions model/model_training/models/__init__.py
@@ -1,9 +1,5 @@
import transformers

# from .gptj import get_model as get_gptj_model

SUPPORTED_MODELS = ["galactica", "gpt-j"]


def freeze_top_n_layers(model, target_layers):
# it's possible we can simply detect which module is a ModuleList
@@ -25,17 +21,27 @@ def freeze_top_n_layers(model, target_layers):


def get_specific_model(
model_name, seq2seqmodel=False, without_head=False, cache_dir=".cache", quantization=False, **kwargs
model_name,
seq2seqmodel=False,
without_head=False,
cache_dir=".cache",
quantization=False,
**kwargs,
):
# encoder-decoder support for Flan-T5 like models
# for now, we can use an argument but in the future,
# we can automate this
if without_head:
model = transformers.AutoModel.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
elif seq2seqmodel:
# encoder-decoder support for Flan-T5 like models
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
else:
if "falcon" in model_name:
kwargs["trust_remote_code"] = True
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
if "falcon-7b" in model_name:
# temporary hack until tiiuae/falcon-7b uses transformers' Falcon impl by default
# in-library PR was reverted https://huggingface.co/tiiuae/falcon-7b/commit/378337427557d1df3e742264a2901a49f25d4eb1
model = transformers.models.falcon.modeling_falcon.FalconForCausalLM.from_pretrained(
model_name, cache_dir=cache_dir, **kwargs
)
else:
if "falcon" in model_name:
kwargs["trust_remote_code"] = True
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, **kwargs)
return model
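
For reference, a minimal usage sketch of the updated loader (the import path assumes the model/model_training package layout shown in this PR; extra keyword arguments are forwarded to from_pretrained):

import torch
from model_training.models import get_specific_model

# falcon-7b is routed to transformers' in-library FalconForCausalLM by the branch above;
# other "falcon" checkpoints still go through AutoModelForCausalLM with trust_remote_code=True.
model = get_specific_model(
    "tiiuae/falcon-7b",
    cache_dir=".cache",
    torch_dtype=torch.bfloat16,
)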
32 changes: 25 additions & 7 deletions model/model_training/models/patching.py
@@ -6,9 +6,18 @@

import torch.nn as nn
import transformers
from transformers import AutoConfig, GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel
from transformers import (
AutoConfig,
FalconForCausalLM,
FalconModel,
GPTNeoXForCausalLM,
GPTNeoXModel,
LlamaForCausalLM,
LlamaModel,
)
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

from .patching_falcon import falcon_forward_with_flash_attn
from .patching_llama import llama_forward_with_flash_attn
from .patching_neox import neox_forward_with_flash_attn
from .reward_model import GPTNeoXRewardModel
@@ -19,6 +28,8 @@
GPTNeoXForCausalLM,
LlamaForCausalLM,
LlamaModel,
FalconForCausalLM,
FalconModel,
GPTNeoXRewardModel,
# Currently only supported by NeoX models; Will work on LLaMa models
AutoModelForCausalLMWithHydraValueHead,
@@ -65,6 +76,8 @@ def add_flash_attn(module: nn.Module, causal: bool = True):
if not hasattr(module, "_attn"):
warnings.warn("Provided module doesn't have a _attn() function to be patched.")
module._attn = partial(neox_forward_with_flash_attn, module, flash_attn)
elif isinstance(module, transformers.models.falcon.modeling_falcon.FalconAttention):
module.forward = partial(falcon_forward_with_flash_attn, module, flash_attn)
else:
raise NotImplementedError(f"Flash attention is not implemented for {module.__class__.__name__}.")

@@ -149,17 +162,22 @@ def patch_model(
if model.__class__.__name__ == "RWForCausalLM":
model = model.base_model

if isinstance(model, FalconForCausalLM):
model = model.transformer

attention_key_lookup = {
GPTNeoXModel: "attention",
GPTNeoXRewardModel: "attention",
LlamaModel: "self_attn",
FalconModel: "self_attention",
}
mlp_key_lookup = {
GPTNeoXModel: "mlp",
GPTNeoXRewardModel: "mlp",
LlamaModel: "mlp",
FalconModel: "mlp",
}
if model.__class__.__name__ == "RWModel":
if isinstance(model, FalconModel) or model.__class__.__name__ == "RWModel":
layers = model.h
attention_key = "self_attention"
mlp_key = "mlp"
@@ -187,8 +205,8 @@ def __init__(self, model_name, **kwargs):
architecture = config.architectures
if architecture:
self.model_name = architecture[0]
if "RWForCausalLM" in architecture:
self.architecture = "RWForCausalLM"
if "FalconForCausalLM" in architecture or "RWForCausalLM" in architecture:
self.architecture = "FalconForCausalLM"
if rope_type == "ntk":
self.patch_fun = RWNTKScaledRope
else:
@@ -213,14 +231,14 @@ def from_config(cls, config):
return cls(model_name, **args)

def patch(self, model):
if self.architecture == "RWForCausalLM":
self.patch_rw_model(model, **self.args)
if self.architecture == "FalconForCausalLM":
self.patch_falcon_model(model, **self.args)
elif self.architecture == "LlamaForCausalLM":
self.patch_llama_model(model, **self.args)
else:
raise NotImplementedError()

def patch_rw_model(self, model, **kwargs):
def patch_falcon_model(self, model, **kwargs):
for each in model.transformer.h:
each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)

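
A hedged sketch of what the new Falcon branch does to a loaded model; during training patch_model() performs this dispatch via the attention_key_lookup above. It assumes flash-attn is installed, a CUDA device is available, and the installed transformers version ships the in-library Falcon implementation:

import torch
from model_training.models import get_specific_model
from model_training.models.patching import add_flash_attn

model = get_specific_model("tiiuae/falcon-7b", torch_dtype=torch.bfloat16)
for block in model.transformer.h:
    # replaces FalconAttention.forward with falcon_forward_with_flash_attn via functools.partial
    add_flash_attn(block.self_attention, causal=True)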
78 changes: 78 additions & 0 deletions model/model_training/models/patching_falcon.py
@@ -0,0 +1,78 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn

from .patching_utils import compute_flash_attention


def falcon_forward_with_flash_attn(
self,
flash_attn: nn.Module, # flash_attn.modules.mha.FlashSelfAttention
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
head_mask, alibi & output_attentions are not supported.
Reference to the original `FalconAttention.forward()` method which this patch replaces:
https://github.com/huggingface/transformers/blob/c965d302791cf935d6ea7776428749be678cf509/src/transformers/models/falcon/modeling_falcon.py#L281
"""

assert head_mask is None # not supported.
assert alibi is None # not supported.
assert not output_attentions # not supported.

fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

batch_size, query_length, _, _ = query_layer.shape

query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)

if use_cache:
present = (key_layer, value_layer)
else:
present = None

query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)

q = query_layer_.permute(0, 2, 1, 3)
k = key_layer_.permute(0, 2, 1, 3).expand(q.shape)
v = value_layer_.permute(0, 2, 1, 3).expand(q.shape)

if attention_mask is not None:
attention_mask = attention_mask[:, 0, -1]

flash_attn.train(self.training)
attn_output = compute_flash_attention(flash_attn, q, k, v, attention_mask=attention_mask)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

output_tensor = self.dense(attn_output)

return output_tensor, present
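
On the attention_mask handling above: the multi-dimensional mask Falcon hands to the attention layer is reduced to the [batch_size, seq_len] layout that compute_flash_attention expects by keeping the last query row. A shape-only illustration (the tensor values are made up; only the slicing comes from the patch):

import torch

batch_size, seq_len = 2, 8
mask_4d = torch.zeros(batch_size, 1, seq_len, seq_len)  # [bs, 1, q_len, kv_len]
mask_2d = mask_4d[:, 0, -1]  # -> [bs, kv_len], one entry per key position
print(mask_2d.shape)  # torch.Size([2, 8])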
12 changes: 9 additions & 3 deletions model/model_training/models/patching_utils.py
@@ -7,11 +7,14 @@ def compute_flash_attention(flash_attn, q, k, v, attention_mask=None, head_mask=
# attention_mask (float): [bs, seq_len]
batch_size, max_len = q.size(0), q.size(1)

qkv = torch.stack([q, k, v], dim=2).to(torch.float16) # need to truncate in case input is fp32
qkv = torch.stack([q, k, v], dim=2)
dtype_in = qkv.dtype
if dtype_in == torch.float32:
qkv = qkv.to(torch.float16) # need to truncate in case input is fp32
cu_seqlens, max_seqlen = None, None

if attention_mask is None:
return flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
out = flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
else:
# Limitation: non-contiguous attention mask will not be handled correctly
# model will be able to pay attention between the first and last non-masked token, i.e. left- and right-side padding is supported.
@@ -35,7 +38,10 @@ def compute_flash_attention(flash_attn, q, k, v, attention_mask=None, head_mask=
for i in range(batch_size)
]
out = torch.stack(padded_seqs)
return out

if out.dtype != dtype_in:
out = out.to(dtype_in)
return out


if __name__ == "__main__":
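
The patching_utils change preserves the caller's dtype: fp32 inputs are cast to fp16 for the flash-attn kernel and the result is cast back. A minimal sketch, assuming the flash-attn package and a CUDA device, with tensors in the [batch, seq, heads, head_dim] layout used by the callers above:

import torch
from flash_attn.modules.mha import FlashSelfAttention
from model_training.models.patching_utils import compute_flash_attention

flash = FlashSelfAttention(causal=True)
q = torch.randn(2, 16, 8, 64, dtype=torch.float32, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = compute_flash_attention(flash, q, k, v)  # kernel runs in fp16 internally
assert out.dtype == torch.float32  # output is cast back to the input dtype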