Update pruning
GeneZC committed Dec 8, 2023
1 parent ecac7d0 commit eb91ded
Showing 6 changed files with 224 additions and 113 deletions.
9 changes: 6 additions & 3 deletions TUTORIAL.md
@@ -95,7 +95,7 @@ Now you get `llama2-7b-ada`.

**Pruning**

The pruning is executed with 1 Nvidia A100 GPU and only a small portion of the adaptation data, and the pruning data should be built to a maximum length of 512 for pruning efficiency.
The pruning is executed with 8 Nvidia A100 GPUs and only a small portion of the adaptation data, and the pruning data should be built to a maximum length of 512 for pruning efficiency.

The following is an example script (i.e., `scripts/build_pruning_data.sh`) to build pruning data (e.g., part of WuDao):
```bash
@@ -112,14 +112,17 @@ python run_building_data_llama.py \
The following is an example script (i.e., `scripts/prune_llama.sh`) to prune `llama2-7b-ada`:

```bash
python run_sparsification_llama.py \
CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node=$GPU_NUM --nnodes=$NODE_WORLD_SIZE --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT run_sparsification_llama.py \
--model_type sparsellama_lm \
--teacher_model_name_or_path path/to/llama2-7b-ada \
--record_path_or_regex "dir/to/builded/part-of-wudao/*.tfrecord" \
--data_type llama_lm \
--output_dir dir/to/outputs \
--max_length 512 \
--per_device_eval_batch_size 2 \
--per_device_eval_batch_size 8 \
--use_act_ckpt \
--use_bf16 \
--deepspeed ds_config.json \
--model_suffix 7b
```
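
The launcher placeholders in the script above (`$GPU_NUM`, `$NODE_WORLD_SIZE`, `$NODE_RANK`, `$MASTER_ADDR`, `$MASTER_PORT`) are left to the environment. As a minimal sketch only, assuming a single node with 8 A100s and that `scripts/prune_llama.sh` picks these variables up from the environment, they could be set as follows; the concrete values are illustrative assumptions, not part of the repository:

```bash
# Hypothetical single-node launcher settings; adjust to your cluster.
export GPU_NUM=8             # one torchrun process per A100
export NODE_WORLD_SIZE=1     # number of nodes
export NODE_RANK=0           # rank of this node
export MASTER_ADDR=127.0.0.1 # rendezvous address (localhost for a single node)
export MASTER_PORT=29500     # any free port
bash scripts/prune_llama.sh
```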

6 changes: 3 additions & 3 deletions minima/models/sparsellama_lm.py
@@ -6,12 +6,12 @@

import modules.modeling_sparsellama
from modules.modeling_sparsellama import SparseLlamaModel, SparseLlamaAttention, SparseLlamaForCausalLM
# from modules.flash_attn_monkey_patch_sparsellama import _prepare_decoder_attention_mask, forward
from modules.flash_attn_monkey_patch_sparsellama import _prepare_decoder_attention_mask, forward
from modules.modeling_sparsellama import SparseLlamaForCausalLM as CustomizedLlamaForCausalLM


# SparseLlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
# SparseLlamaAttention.forward = forward
SparseLlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
SparseLlamaAttention.forward = forward
SparseLlamaForCausalLM.forward = CustomizedLlamaForCausalLM.forward

import collections
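
The lines uncommented above activate the flash-attention monkey patch at import time by overwriting `_prepare_decoder_attention_mask` and `forward` on the sparsified LLaMA classes. A minimal, self-contained sketch of the same pattern, using hypothetical classes rather than the repository's, looks like this:

```python
# Method-level monkey patching: reassigning a function to a class attribute
# replaces the behavior for all existing and future instances.
# `Attention` and `fast_forward` are hypothetical stand-ins for illustration.
class Attention:
    def forward(self, x):
        return x + 1  # original, slower path

def fast_forward(self, x):
    return x + 2  # drop-in replacement (e.g., a flash-attention kernel wrapper)

Attention.forward = fast_forward
assert Attention().forward(1) == 3  # the patched method is now in effect
```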
13 changes: 11 additions & 2 deletions minima/modules/flash_attn_monkey_patch_sparsellama.py
@@ -3,6 +3,8 @@
import torch
from torch import nn

from transformers.models.llama.modeling_llama import repeat_kv

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange
@@ -27,11 +29,11 @@ def forward(
attention_mask: [bsz, q_len]
"""
self.hidden_size_sparsified = self.o_proj.in_features_sparsified
self.num_heads_sparsified = int(self.hidden_size_sparsified / self.head_dim)
self.num_heads_sparsified = int(self.hidden_size_sparsified / self.head_dim / self.num_key_value_groups)

bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads_sparsified, self.head_dim).transpose(1, 2)
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_key_value_groups * self.num_heads_sparsified, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads_sparsified, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads_sparsified, self.head_dim).transpose(1, 2)
# [bsz, q_len, nh, hd]
@@ -46,6 +48,10 @@ def forward(
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"

# Repeat k/v heads if n_kv_heads < n_heads.
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

@@ -79,6 +85,9 @@ def forward(
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, bsz, q_len),
'b s (h d) -> b s h d', h=nheads)
if head_mask is not None:
head_mask = repeat_kv(head_mask, self.num_key_value_groups)
output = output * head_mask.transpose(1, 2)
return self.o_proj(rearrange(output,
'b s h d -> b s (h d)')), None, None
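
The edits in this file adapt the flash-attention path to grouped-query attention: the sparsified key/value head count is now obtained by additionally dividing by `num_key_value_groups`, queries keep `num_key_value_groups` times as many heads, and `repeat_kv` expands the key/value heads before attention. A small shape check, with sizes that are illustrative assumptions rather than the model's real configuration, shows the bookkeeping:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the imported helper: (b, n_kv, s, d) -> (b, n_kv * n_rep, s, d).
    b, n_kv, s, d = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    return hidden_states[:, :, None, :, :].expand(b, n_kv, n_rep, s, d).reshape(b, n_kv * n_rep, s, d)

# Illustrative sizes only.
bsz, q_len, head_dim, num_key_value_groups = 2, 16, 64, 4
hidden_size_sparsified = 1024  # stands in for o_proj.in_features_sparsified
num_heads_sparsified = hidden_size_sparsified // head_dim // num_key_value_groups  # 4 kv heads

query_states = torch.randn(bsz, num_key_value_groups * num_heads_sparsified, q_len, head_dim)
key_states = torch.randn(bsz, num_heads_sparsified, q_len, head_dim)
key_states = repeat_kv(key_states, num_key_value_groups)
assert key_states.shape == query_states.shape  # repeated keys now line up with the queries
```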

74 changes: 49 additions & 25 deletions minima/modules/modeling_sparsellama.py
@@ -79,8 +79,9 @@ def __init__(self, hidden_size, eps=1e-6, sparsified_elements=0):
self.variance_epsilon = eps
self.normalized_shape_origin = (hidden_size,)
self.normalized_shape_sparsified = (hidden_size,)
self.sparsify(sparsified_elements)
self.densify()
if sparsified_elements:
self.sparsify(sparsified_elements)
self.densify()

def forward(self, hidden_states, hidden_mask=None):
weight = self.weight[:self.normalized_shape_sparsified[0]]
@@ -182,8 +183,9 @@ def __init__(self, num_embeddings, embedding_dim, padding_idx=None, sparsified_e
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
self.embedding_dim_origin = embedding_dim
self.embedding_dim_sparsified = embedding_dim
self.sparsify(sparsified_elements)
self.densify()
if sparsified_elements:
self.sparsify(sparsified_elements)
self.densify()

def forward(self, input):
weight = self.weight[:, :self.embedding_dim_sparsified]
@@ -222,9 +224,10 @@ def __init__(self, in_features, out_features, bias=True, element_size=1, dim=0,
self.out_features_sparsified = out_features
self.element_size = element_size
self.dim = dim
self.sparsify(sparsified_elements[0])
self.sparsify(sparsified_elements[1], for_hidden=True)
self.densify()
if sparsified_elements[0] or sparsified_elements[1]:
self.sparsify(sparsified_elements[0])
self.sparsify(sparsified_elements[1], for_hidden=True)
self.densify()

def forward(self, input):
weight = self.weight[:self.out_features_sparsified, :self.in_features_sparsified]
@@ -314,6 +317,18 @@ def forward(self, x, neuron_mask=None):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class SparseLlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -323,21 +338,24 @@ def __init__(self, config: LlamaConfig, sparsified_heads, sparsified_hiddens):
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = getattr(config, "rope_theta", 10000)

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.hidden_size_sparsified = self.hidden_size
self.num_heads_sparsified = self.num_heads
self.num_heads_sparsified = self.num_key_value_heads

self.q_proj = SparseLinear(self.hidden_size, self.num_heads * self.head_dim, bias=False, element_size=self.head_dim, dim=1, sparsified_elements=(sparsified_heads, sparsified_hiddens))
self.k_proj = SparseLinear(self.hidden_size, self.num_heads * self.head_dim, bias=False, element_size=self.head_dim, dim=1, sparsified_elements=(sparsified_heads, sparsified_hiddens))
self.v_proj = SparseLinear(self.hidden_size, self.num_heads * self.head_dim, bias=False, element_size=self.head_dim, dim=1, sparsified_elements=(sparsified_heads, sparsified_hiddens))
self.o_proj = SparseLinear(self.num_heads * self.head_dim, self.hidden_size, bias=False, element_size=self.head_dim, dim=0, sparsified_elements=(sparsified_heads, sparsified_hiddens))
self.rotary_emb = SparseLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
self.q_proj = SparseLinear(self.hidden_size, self.num_heads * self.head_dim, bias=False, element_size=self.num_key_value_groups * self.head_dim, dim=1, sparsified_elements=(sparsified_heads, sparsified_hiddens))
self.k_proj = SparseLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, element_size=self.head_dim, dim=1, sparsified_elements=(sparsified_heads, sparsified_hiddens))
self.v_proj = SparseLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, element_size=self.head_dim, dim=1, sparsified_elements=(sparsified_heads, sparsified_hiddens))
self.o_proj = SparseLinear(self.num_heads * self.head_dim, self.hidden_size, bias=False, element_size=self.num_key_value_groups * self.head_dim, dim=0, sparsified_elements=(sparsified_heads, sparsified_hiddens))
self.rotary_emb = SparseLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads_sparsified, self.head_dim).transpose(1, 2).contiguous()
@@ -353,11 +371,11 @@ def forward(
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
self.hidden_size_sparsified = self.o_proj.in_features_sparsified
self.num_heads_sparsified = int(self.hidden_size_sparsified / self.head_dim)
self.num_heads_sparsified = int(self.hidden_size_sparsified / self.head_dim / self.num_key_value_groups)

bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads_sparsified, self.head_dim).transpose(1, 2)
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_key_value_groups * self.num_heads_sparsified, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads_sparsified, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads_sparsified, self.head_dim).transpose(1, 2)

@@ -375,11 +393,15 @@

past_key_value = (key_states, value_states) if use_cache else None

# Repeat k/v heads if n_kv_heads < n_heads.
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads_sparsified, q_len, kv_seq_len):
if attn_weights.size() != (bsz, self.num_key_value_groups * self.num_heads_sparsified, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads_sparsified, q_len, kv_seq_len)}, but is"
f"Attention weights should be of size {(bsz, self.num_key_value_groups * self.num_heads_sparsified, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

@@ -394,12 +416,13 @@ def forward(
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
if head_mask is not None:
head_mask = repeat_kv(head_mask, self.num_key_value_groups)
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads_sparsified, q_len, self.head_dim):
if attn_output.size() != (bsz, self.num_key_value_groups * self.num_heads_sparsified, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads_sparsified, q_len, self.head_dim)}, but is"
f"`attn_output` should be of size {(bsz, self.num_key_value_groups * self.num_heads_sparsified, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

@@ -723,12 +746,13 @@ def custom_forward(*inputs):

def reorder(self, head_indices, neuron_indices, hidden_indices):
for layer_idx, indices in head_indices.items():
n, h = self.layers[layer_idx].self_attn.num_heads, self.layers[layer_idx].self_attn.head_dim
indices = torch.arange(n * h).reshape(n, h)[indices.cpu()].reshape(-1).contiguous().long()
self.layers[layer_idx].self_attn.q_proj.reorder(indices)
self.layers[layer_idx].self_attn.k_proj.reorder(indices)
self.layers[layer_idx].self_attn.v_proj.reorder(indices)
self.layers[layer_idx].self_attn.o_proj.reorder(indices)
n, g, h = self.layers[layer_idx].self_attn.num_key_value_heads, self.layers[layer_idx].self_attn.num_key_value_groups, self.layers[layer_idx].self_attn.head_dim
qo_indices = torch.arange(n * g * h).reshape(n, g * h)[indices.cpu()].reshape(-1).contiguous().long()
kv_indices = torch.arange(n * h).reshape(n, h)[indices.cpu()].reshape(-1).contiguous().long()
self.layers[layer_idx].self_attn.q_proj.reorder(qo_indices)
self.layers[layer_idx].self_attn.k_proj.reorder(kv_indices)
self.layers[layer_idx].self_attn.v_proj.reorder(kv_indices)
self.layers[layer_idx].self_attn.o_proj.reorder(qo_indices)
for layer_idx, indices in neuron_indices.items():
self.layers[layer_idx].mlp.up_proj.reorder(indices)
self.layers[layer_idx].mlp.gate_proj.reorder(indices)
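
With grouped-query attention, one pruned key/value head covers a single `head_dim`-sized block of `k_proj`/`v_proj` but a `num_key_value_groups * head_dim`-sized block of `q_proj`/`o_proj`, which is why the reorder step above now derives two index sets from the same head ranking. A toy computation with made-up sizes illustrates the difference:

```python
import torch

# Hypothetical sizes: 4 kv heads, 2 query heads per kv head, head_dim of 3.
n, g, h = 4, 2, 3
indices = torch.tensor([2, 0, 3, 1])  # an example head ranking from the pruner

qo_indices = torch.arange(n * g * h).reshape(n, g * h)[indices].reshape(-1)
kv_indices = torch.arange(n * h).reshape(n, h)[indices].reshape(-1)

# kv_indices moves head_dim-sized blocks:
#   [6, 7, 8, 0, 1, 2, 9, 10, 11, 3, 4, 5]
# qo_indices moves (g * head_dim)-sized blocks:
#   [12, 13, 14, 15, 16, 17, 0, 1, 2, 3, 4, 5, 18, 19, 20, 21, 22, 23, 6, 7, 8, 9, 10, 11]
print(kv_indices.tolist())
print(qo_indices.tolist())
```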