In [212]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.palu_attention import HeadwiseLowRankModule
from transformers.models.llama.modeling_llama import LlamaConfig

# Experiment configuration: default Llama shapes, how heads are grouped for the
# head-wise low-rank decomposition of v_proj, and the total rank budget.
SEED = 0
# Seed before any random init so nn.Linear weights and the torch.randn inputs in
# later cells are reproducible under Restart & Run All.
torch.manual_seed(SEED)

config = LlamaConfig()

q_len = 3
group_size = 4  # number of attention heads that share one low-rank factor
num_heads = config.num_attention_heads
hidden_dim = config.hidden_size
# Rank summed over all groups. With the default LlamaConfig this equals
# hidden_dim, i.e. each group is decomposed at full rank (no compression).
total_rank_v = 4096
out_features = config.hidden_size

# The shapes below are derived with integer division; fail loudly instead of
# silently truncating if a configuration change breaks divisibility.
assert num_heads % group_size == 0, "num_heads must be divisible by group_size"
assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
num_groups = num_heads // group_size
assert total_rank_v % num_groups == 0, "total_rank_v must be divisible by num_groups"
head_dim = hidden_dim // num_heads
group_rank = total_rank_v // num_groups   # rank given to each group
group_dim = hidden_dim // num_groups      # output features covered by each group

v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
o_proj = nn.Linear(hidden_dim, out_features, bias=False)

In [213]:
# Reference (unfactored) attention path: per-head value states from the original
# v_proj, weighted by random attention scores, then o_proj. `ori_output` is the
# ground truth that later cells compare against.
inputs = torch.randn(1, q_len, hidden_dim)
# Random stand-in for attention scores: (bsz, num_heads, q_len, kv_seq_len)
attn_weight = torch.randn(1, num_heads, q_len, q_len)

# (bsz, q_len, hidden_dim) -> (bsz, q_len, num_heads, head_dim) -> (bsz, num_heads, q_len, head_dim)
v_states = v_proj(inputs).reshape(1, q_len, num_heads, head_dim).transpose(1, 2)

# Weight the values per head, then merge the heads back into (bsz, q_len, hidden_dim);
# reshape copies the transposed view into contiguous memory.
attn_output = torch.matmul(attn_weight, v_states).transpose(1, 2).reshape(1, q_len, -1)
ori_output = o_proj(attn_output)

for label, tensor in (
    ('attn_weight', attn_weight),
    ('v_states', v_states),
    ('attn_output', attn_output),
    ('ori_output', ori_output),
):
    print(f'{label} size: {tensor.size()}')

attn_weight size: torch.Size([1, 32, 3, 3])
v_states size: torch.Size([1, 32, 3, 128])
attn_output size: torch.Size([1, 3, 4096])
ori_output size: torch.Size([1, 3, 4096])


In [214]:
# Decompose v_proj head-wise, then reconstruct full value states from the latent
# projection and check that the (unfused) attention path still matches ori_output.
total_rank_v_list = [group_rank for _ in range(num_groups)]
new_v_proj = HeadwiseLowRankModule.from_linear(v_proj, total_rank_v_list)

# attn_weights: (bsz, num_heads, q_len, kv_seq_len), the per-head layout.
# A no-op on the tensor from the reference cell; kept so this cell still works if
# another cell rebound `attn_weight` to the grouped (num_groups, ...) layout.
attn_weight = attn_weight.reshape(1, num_heads, q_len, q_len)
print(f'attn_weight size: {attn_weight.size()}')

# Latent value states: (bsz, q_len, sum of per-group ranks), not yet split per group.
v_h_states = new_v_proj.project_to_latent(inputs)
print(f'v_h_states size: {v_h_states.size()}')
assert v_h_states.shape == (1, q_len, sum(total_rank_v_list)), "unexpected latent shape"

# value_states: (bsz, q_len, num_heads, head_dim).transpose(1, 2)
v_states = new_v_proj.reconstruct(v_h_states).reshape(1, q_len, num_heads, head_dim).transpose(1, 2)
print(f'v_states size: {v_states.size()}')

attn_output = torch.matmul(attn_weight, v_states)
print(f'attn_output size: {attn_output.size()}')
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(1, q_len, -1)
new_v_output = o_proj(attn_output)
torch.testing.assert_close(ori_output, new_v_output)

attn_weight size: torch.Size([1, 32, 3, 3])
v_h_states size: torch.Size([1, 3, 4096])
v_states size: torch.Size([1, 32, 3, 128])
attn_output size: torch.Size([1, 32, 3, 128])


In [215]:
# Fused path: run attention directly on the low-rank latent values (one shared
# latent per group of heads), and only afterwards expand each head with its
# slice of U. Result is checked against the unfactored computation.
total_rank_v_list = [group_rank for _ in range(num_groups)]
new_v_proj = HeadwiseLowRankModule.from_linear(v_proj, total_rank_v_list)

# The per-head U slicing below assumes each group spans exactly group_size heads.
assert group_dim == group_size * head_dim, "group_dim must equal group_size * head_dim"

# Grouped layout: (bsz, num_groups, q_len * group_size, kv_seq_len). Bound to a new
# name so the per-head `attn_weight` used by other cells is not mutated.
attn_weight_grouped = attn_weight.reshape(1, num_groups, q_len * group_size, q_len)
print(f'attn_weight size: {attn_weight_grouped.size()}')
# value_h_states: (bsz, kv_seq_len, num_groups, group_rank).transpose(1, 2)
v_h_states = new_v_proj.project_to_latent(inputs).reshape(1, q_len, num_groups, group_rank).transpose(1, 2)
print(f'v_h_states size: {v_h_states.size()}')

# attn_h_output: (bsz, num_groups, q_len * group_size, group_rank), then split the
# stacked heads back out to (bsz, num_heads, q_len, group_rank).
attn_h_output = torch.matmul(attn_weight_grouped, v_h_states)
print(f'attn_h_output size: {attn_h_output.size()}')
attn_h_output = attn_h_output.reshape(1, num_heads, q_len, group_rank)
print(f'attn_h_output size: {attn_h_output.size()}')

# Expand each head with its (head_dim, group_rank) slice of U: the rows select the
# head's position inside its group, the columns select the group's rank block.
outputs = []
for head_idx in range(num_heads):
    group_idx, head_in_group = divmod(head_idx, group_size)
    u_head = new_v_proj.U[head_in_group * head_dim:(head_in_group + 1) * head_dim,
                          group_idx * group_rank:(group_idx + 1) * group_rank]
    outputs.append(F.linear(attn_h_output[:, head_idx, :, :], u_head))

new_attn_output = torch.cat(outputs, dim=-1)
print(f'attn_output size: {new_attn_output.size()}')

# Reference built here from the original v_proj instead of relying on an
# `attn_output` left over from an earlier cell.
ref_v_states = v_proj(inputs).view(1, q_len, num_heads, head_dim).transpose(1, 2)
ref_attn_output = torch.matmul(attn_weight.reshape(1, num_heads, q_len, q_len), ref_v_states)
ref_attn_output = ref_attn_output.transpose(1, 2).reshape(1, q_len, -1)
torch.testing.assert_close(ref_attn_output, new_attn_output)

new_v_output = o_proj(new_attn_output)
torch.testing.assert_close(ori_output, new_v_output)

attn_weight size: torch.Size([1, 8, 12, 3])
v_h_states size: torch.Size([1, 8, 3, 512])
attn_h_output size: torch.Size([1, 8, 12, 512])
attn_h_output size: torch.Size([1, 32, 3, 512])
attn_output size: torch.Size([1, 3, 4096])


In [216]:
# Rebuild the decomposed v_proj and the low-rank (latent) attention output that
# the fused o_proj in the next cell consumes.
total_rank_v_list = [group_rank for _ in range(num_groups)]
new_v_proj = HeadwiseLowRankModule.from_linear(v_proj, total_rank_v_list)

# Grouped layout: (bsz, num_groups, q_len * group_size, kv_seq_len). Bound to a new
# name so the per-head `attn_weight` used by other cells is not mutated.
attn_weight_grouped = attn_weight.reshape(1, num_groups, q_len * group_size, q_len)
print(f'attn_weight size: {attn_weight_grouped.size()}')
# value_h_states: (bsz, kv_seq_len, num_groups, group_rank).transpose(1, 2)
v_h_states = new_v_proj.project_to_latent(inputs).reshape(1, q_len, num_groups, group_rank).transpose(1, 2)
print(f'v_h_states size: {v_h_states.size()}')

# attn_h_output: (bsz, num_groups, q_len * group_size, group_rank), then split the
# stacked heads back out to (bsz, num_heads, q_len, group_rank).
attn_h_output = torch.matmul(attn_weight_grouped, v_h_states)
print(f'attn_h_output size: {attn_h_output.size()}')
attn_h_output = attn_h_output.reshape(1, num_heads, q_len, group_rank)
print(f'attn_h_output size: {attn_h_output.size()}')

attn_weight size: torch.Size([1, 8, 12, 3])
v_h_states size: torch.Size([1, 8, 3, 512])
attn_h_output size: torch.Size([1, 8, 12, 512])
attn_h_output size: torch.Size([1, 32, 3, 512])


In [217]:
# Fold each head's slice of U into o_proj so that o_proj consumes the low-rank
# attention output (attn_h_output) directly. For head h, sitting at position j
# inside group g, the fused columns are
#   W_o[:, h*head_dim:(h+1)*head_dim] @ U[j*head_dim:(j+1)*head_dim, g*group_rank:(g+1)*group_rank]
fused_hidden_dim = group_rank * num_heads
# The index arithmetic assumes each group spans exactly group_size heads.
assert group_dim == group_size * head_dim, "group_dim must equal group_size * head_dim"

new_o_proj = nn.Linear(fused_hidden_dim, out_features, bias=False,
                       device=o_proj.weight.device, dtype=o_proj.weight.dtype)
# zeros_like keeps dtype/device in sync with new_o_proj (and does not require grad).
new_o_weight = torch.zeros_like(new_o_proj.weight)

with torch.no_grad():
    for head_idx in range(num_heads):
        group_idx, head_in_group = divmod(head_idx, group_size)
        o_cols = slice(head_idx * head_dim, (head_idx + 1) * head_dim)
        u_rows = slice(head_in_group * head_dim, (head_in_group + 1) * head_dim)
        u_cols = slice(group_idx * group_rank, (group_idx + 1) * group_rank)
        fused_cols = slice(head_idx * group_rank, (head_idx + 1) * group_rank)
        new_o_weight[:, fused_cols] = o_proj.weight[:, o_cols] @ new_v_proj.U[u_rows, u_cols]

    print(f'new_o_proj size: {new_o_proj.weight.data.size()}')
    new_o_proj.weight.copy_(new_o_weight)

# (bsz, num_heads, q_len, group_rank) -> (bsz, q_len, num_heads * group_rank)
final_fused_o_output = new_o_proj(attn_h_output.transpose(1, 2).reshape(1, q_len, -1))
torch.testing.assert_close(ori_output, final_fused_o_output)

new_o_proj size: torch.Size([4096, 16384])
