In [161]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.palu_attention import HeadwiseLowRankModule

# Toy attention configuration shared by every cell below.
# Dependent sizes are derived from the primary ones so they cannot drift apart;
# the resulting values are identical to the previously hard-coded literals.
num_groups = 2
num_heads = 10
head_dim = 8
group_size = num_heads // num_groups  # heads per group (5)
hidden_dim = num_heads * head_dim     # model width (80)
rank_v = hidden_dim                   # total rank handed to the decomposition (80)
group_rank = rank_v // num_groups     # rank per group (40)
group_dim = hidden_dim // num_groups  # output features per group (40)
q_len = 1

# Seed right before the layers are created so their random initial weights
# (and every later torch.randn draw) are the same on each Restart & Run All.
torch.manual_seed(0)
v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

In [162]:
# Reference path: attention output computed with the dense v_proj / o_proj.
inputs = torch.randn(1, q_len, 80)
# Random stand-in scores laid out as (bsz, num_heads, q_len, kv_seq_len);
# kv_seq_len equals q_len in this toy setup.
attn_weight = torch.randn(1, num_heads, q_len, q_len)
# Split the value projection into heads:
# (bsz, q_len, num_heads, head_dim) -> (bsz, num_heads, q_len, head_dim).
v_states = v_proj(inputs).view(1, q_len, num_heads, head_dim).transpose(1, 2)
# Mix the values per head, then merge the heads back into one feature axis.
attn_output = (
    (attn_weight @ v_states)
    .transpose(1, 2)
    .contiguous()
    .reshape(1, q_len, -1)
)
ori_output = o_proj(attn_output)

print(f'attn_weight size: {attn_weight.size()}')
print(f'v_states size: {v_states.size()}')
print(f'attn_output size: {attn_output.size()}')
print(f'ori_output size: {ori_output.size()}')

attn_weight size: torch.Size([1, 10, 1, 1])
v_states size: torch.Size([1, 10, 1, 8])
attn_output size: torch.Size([1, 1, 80])
ori_output size: torch.Size([1, 1, 80])


In [163]:
# Sanity pass: rebuild the attention output with the dense v_proj and confirm
# it still matches ori_output. new_v_proj is constructed here but is not used
# by the remaining statements of this cell.
rank_v_list = [group_rank] * num_groups
new_v_proj = HeadwiseLowRankModule.from_linear(v_proj, rank_v_list)

# Per-head score layout: (bsz, num_heads, q_len, kv_seq_len).
attn_weight = attn_weight.reshape(1, num_heads, q_len, q_len)
print(f'attn_weight size: {attn_weight.size()}')
# Dense value projection split into heads: (bsz, num_heads, q_len, head_dim).
v_states = v_proj(inputs).view(1, q_len, num_heads, head_dim).transpose(1, 2)
print(f'v_states size: {v_states.size()}')

# Mix values per head, fold the heads into the feature axis, apply o_proj.
attn_output = (
    (attn_weight @ v_states)
    .transpose(1, 2)
    .contiguous()
    .reshape(1, q_len, -1)
)
new_v_output = o_proj(attn_output)
torch.testing.assert_close(ori_output, new_v_output)

attn_weight size: torch.Size([1, 10, 1, 1])
v_states size: torch.Size([1, 10, 1, 8])


In [164]:
# Round trip through the decomposed projection: project the inputs to the
# latent space, reconstruct the value states, and check the final output
# against ori_output.
rank_v_list = [group_rank] * num_groups
new_v_proj = HeadwiseLowRankModule.from_linear(v_proj, rank_v_list)

# Per-head score layout: (bsz, num_heads, q_len, kv_seq_len).
attn_weight = attn_weight.reshape(1, num_heads, q_len, q_len)
print(f'attn_weight size: {attn_weight.size()}')

# Latent value states, kept in the layout project_to_latent returns
# (the recorded run shows torch.Size([1, 1, 80])).
v_h_states = new_v_proj.project_to_latent(inputs)
print(f'v_h_states size: {v_h_states.size()}')

# Reconstructed values split into heads: (bsz, num_heads, q_len, head_dim).
v_states = new_v_proj.reconstruct(v_h_states)
v_states = v_states.reshape(1, q_len, num_heads, head_dim).transpose(1, 2)
print(f'v_states size: {v_states.size()}')

# The per-head result is printed before the heads are merged.
attn_output = attn_weight @ v_states
print(f'attn_output size: {attn_output.size()}')
attn_output = attn_output.transpose(1, 2).contiguous().reshape(1, q_len, -1)
new_v_output = o_proj(attn_output)
torch.testing.assert_close(ori_output, new_v_output)

attn_weight size: torch.Size([1, 10, 1, 1])
v_h_states size: torch.Size([1, 1, 80])
v_states size: torch.Size([1, 10, 1, 8])
attn_output size: torch.Size([1, 10, 1, 8])


In [176]:
# Grouped low-rank path: apply the attention weights to the latent value
# states first, then reconstruct each head's output from its slice of
# new_v_proj.U, and compare against the reference results.
rank_v_list = [group_rank for _ in range(num_groups)]
new_v_proj = HeadwiseLowRankModule.from_linear(v_proj, rank_v_list)

# Fold the heads of each group into the query axis:
# (bsz, num_groups, q_len * group_size, kv_seq_len).
attn_weight = attn_weight.reshape(1, num_groups, q_len * group_size, q_len)
print(f'attn_weight size: {attn_weight.size()}')
# Latent value states per group: (bsz, num_groups, kv_seq_len, group_rank).
v_h_states = new_v_proj.project_to_latent(inputs).reshape(1, q_len, num_groups, group_rank).transpose(1, 2)
print(f'v_h_states size: {v_h_states.size()}')

# Attention in latent space, then unfold the group axis back into heads:
# (bsz, num_heads, q_len, group_rank).
attn_h_output = torch.matmul(attn_weight, v_h_states).reshape(1, num_heads, q_len, group_rank)
print(f'attn_h_output size: {attn_h_output.size()}')

# U is indexed as a 2-D tensor whose rows restart at 0 for every group and
# whose columns advance by group_rank per group -- the same slicing the
# original running counters produced. NOTE(review): confirm this layout
# against HeadwiseLowRankModule.
heads_per_group = group_dim // head_dim
outputs = []
for i in range(num_heads):
    group_idx, head_in_group = divmod(i, heads_per_group)
    row_start = head_in_group * head_dim
    rank_start = group_idx * group_rank
    u_block = new_v_proj.U[row_start:row_start + head_dim,
                           rank_start:rank_start + group_rank]
    outputs.append(F.linear(attn_h_output[:, i:i+1, :, :], u_block))

# Concatenating on the last axis places the heads side by side per position,
# giving (bsz, q_len, num_heads * head_dim) after the reshape.
new_attn_output = torch.cat(outputs, dim=-1).reshape(1, q_len, -1)
print(f'attn_output size: {new_attn_output.size()}')

# attn_output arrives from the previous cells already flattened to
# (bsz, q_len, hidden), so it is compared directly; transposing it a second
# time would only be harmless for q_len == 1.
torch.testing.assert_close(attn_output, new_attn_output)

# Push the *new* tensor through o_proj so the end-to-end check actually
# exercises the grouped low-rank path.
new_v_output = o_proj(new_attn_output)
torch.testing.assert_close(ori_output, new_v_output)

attn_weight size: torch.Size([1, 2, 5, 1])
v_h_states size: torch.Size([1, 2, 1, 40])
attn_h_output size: torch.Size([1, 10, 1, 40])
attn_output size: torch.Size([1, 1, 80])


In [None]:
# Batched variant of the grouped low-rank path: reconstruct all heads with a
# single batched matmul instead of a Python loop over heads.
rank_v_list = [group_rank for _ in range(num_groups)]
new_v_proj = HeadwiseLowRankModule.from_linear(v_proj, rank_v_list)

# Fold the heads of each group into the query axis:
# (bsz, num_groups, q_len * group_size, kv_seq_len).
attn_weight = attn_weight.reshape(1, num_groups, q_len * group_size, q_len)
print(f'attn_weight size: {attn_weight.size()}')
# Latent value states per group: (bsz, num_groups, kv_seq_len, group_rank).
v_h_states = new_v_proj.project_to_latent(inputs).reshape(1, q_len, num_groups, group_rank).transpose(1, 2)
print(f'v_h_states size: {v_h_states.size()}')

# Attention in latent space, then unfold the group axis back into heads:
# (bsz, num_heads, q_len, group_rank).
attn_h_output = torch.matmul(attn_weight, v_h_states).reshape(1, num_heads, q_len, group_rank)
print(f'attn_h_output size: {attn_h_output.size()}')

# Gather, for every head, the same (head_dim, group_rank) slice of U that the
# per-head loop uses, and stack them into one batch. A plain
# U.view(num_heads, head_dim, -1) is not used because it would not select
# these slices. NOTE(review): the slicing assumes U is a 2-D tensor whose rows
# restart per group and whose columns are concatenated per group -- confirm
# against HeadwiseLowRankModule.
heads_per_group = group_dim // head_dim
reshaped_U = torch.stack([
    new_v_proj.U[
        (h % heads_per_group) * head_dim:(h % heads_per_group + 1) * head_dim,
        (h // heads_per_group) * group_rank:(h // heads_per_group + 1) * group_rank,
    ]
    for h in range(num_heads)
]).transpose(1, 2)  # (num_heads, group_rank, head_dim)

# Batched matmul broadcasts over bsz: (bsz, num_heads, q_len, head_dim).
outputs = torch.matmul(attn_h_output, reshaped_U)

# Move the head axis next to head_dim before flattening so the result is
# (bsz, q_len, num_heads * head_dim) for any q_len, not only q_len == 1.
outputs = outputs.transpose(1, 2).reshape(1, q_len, -1)
new_attn_output = outputs
print(f'attn_output size: {new_attn_output.size()}')

# attn_output arrives from the previous cells already flattened to
# (bsz, q_len, hidden); compare it with the tensor computed in THIS cell.
torch.testing.assert_close(attn_output, new_attn_output)

# Push the new tensor through o_proj so the end-to-end check exercises the
# batched low-rank path.
new_v_output = o_proj(new_attn_output)
torch.testing.assert_close(ori_output, new_v_output)