In [2]:
import torch
import torch.nn.functional as F

In [7]:
# 使用常规方式计算的 Attention
KT = torch.transpose(K, 0, 1)
S = torch.matmul(Q, KT)
P = torch.softmax(S, 1)
O = torch.matmul(P, V)
ouputBase = O
print(ouputBase.size())

torch.Size([256, 64])


In [8]:
ouputBase.shape

torch.Size([256, 64])

### 计算单个向量

In [130]:
def func_m(x):
    res_m = torch.max(x, 1, keepdim=True).values # keep dim 是为了后面
    #print("res_m={}".format(res_m.shape))
    return res_m

def func_f(x):
    res_f = torch.exp(x - func_m(x))
    #print("res_f={}".format(res_f.shape))
    return res_f # 包含在0维的broadcast操作

def func_l(x):
    res_l = torch.sum(func_f(x), 1, keepdim=True)
    #print("res_l={}".format(res_l.shape))
    return res_l

def func_softmax(x): 
    """
    等价于 F.softmax(x, 1)
    """
    return func_f(x)/func_l(x)

### 计算两个向量合并

In [131]:
def func_m2(x1, x2): 
    #print(func_m(x1))
    #print(func_m(x2))
    return torch.maximum(func_m(x1), func_m(x2))

def func_f2(x1, x2):
    #x = torch.cat(x1, x2, 1)
    x1_part = torch.exp(func_m(x1) - func_m2(x1, x2)) * func_f(x1)
    x2_part = torch.exp(func_m(x2) - func_m2(x1, x2)) * func_f(x2)
    res = torch.cat((x1_part, x2_part), 1)
    print("res_f2={}".format(res.shape))
    return res

def func_l2(x1, x2):
    x1_part = torch.exp(func_m(x1)-func_m2(x1, x2)) * func_l(x1)
    x2_part = torch.exp(func_m(x2)-func_m2(x1, x2)) * func_l(x2)
    res = (x1_part + x2_part)
    print("res_l2={}".format(res.shape))
    return res

def func_softmax2(x1, x2): 
    return func_f2(x1, x2) / func_l2(x1, x2)

### 分块计算 和 整体计算对比

In [145]:
def split_attention(Q, K, V, Br=1, Bc=1):
    Qs = Q.split(Br, 0)
    Ks = K.split(Bc, 0)
    Vs = V.split(Bc, 0)
    
    Tr = len(Qs)
    Tc = len(Vs)
    
    Output = torch.zeros(Q.shape)
    Os = [o for o in Output.split(Br, 0)] # 因为split返回的tuple没法修改值 所以改为了list
    ls = [torch.zeros(q.shape[0], 1) for q in Qs]
    ms = [torch.ones(q.shape[0], 1) * float("-Inf") for q in Qs] # -inf
    for j in range(Tc):
        # load Kj, Vj from HBM
        Kj = Ks[j]
        Vj = Vs[j]
        for i in range(Tr):
            # load li Oi Qi from HBM
            li = ls[i]
            mi = ms[i]
            Oi = Os[i]
            Qi = Qs[i]
            
            Kt = Kj.transpose(1, 0)
            Sij = Qi.matmul(Kt)
            mij, _ = Sij.max(1, keepdim=True)
            minew = torch.maximum(mi, mij)
            Pij = (Sij - mij).exp() # 为了计算稳定 避免溢出
            lij = Pij.sum(1, keepdim=True) 
            linew = torch.exp(mi - minew) * li + torch.exp(mij - minew) * lij
            old_part = Oi * li * torch.exp(mi - minew)
            new_part = torch.exp(mij - minew) * Pij.matmul(Vj)
            Oi = (old_part + new_part) / linew # 乘上原来的值 
            
            # write back Oi li to HBM
            Os[i] = Oi
            ls[i] = linew
            ms[i] = minew
    
    Output = torch.cat(Os, 0)
    return Output

def attention(Q, K, V):
    KT = torch.transpose(K, 1, 0)
    S = torch.matmul(Q, KT)
    P = torch.softmax(S, 1)
    O = torch.matmul(P, V)
    return O

In [146]:
N =8
d = 3
Br = 2
Bc = 2
Q = torch.randn((N, d))
K = torch.randn((N, d))
V = torch.randn((N, d))

ori_res = attention(Q, K, V)
split_res = split_attention(Q, K, V, Br=Br, Bc=Bc)
CHECK_EQ(ori_res, split_res)

In [141]:
ori_res

tensor([[ 0.2309, -0.4989,  0.3301],
        [ 0.4307, -0.4754,  0.7235],
        [ 0.3869, -0.4787,  0.2893],
        [ 0.4626, -0.4860,  1.0540],
        [ 0.5465, -0.5487,  1.4449],
        [ 0.6318, -0.5808,  1.0802],
        [ 0.5865, -0.4663,  0.2924],
        [ 0.3239, -0.4439,  0.2795]])

In [142]:
split_res

tensor([[ 0.2456,  0.5406,  0.8136],
        [ 0.1798,  0.5123,  0.7575],
        [ 0.1319,  0.4918,  0.7166],
        [ 0.2535,  0.5440,  0.8204],
        [ 0.2286,  0.5333,  0.7991],
        [ 0.3873,  0.6014,  0.9345],
        [-0.1630,  0.3653,  0.4650],
        [-0.0550,  0.4116,  0.5572]])

In [100]:
ori_res

tensor([[-1.4864, -0.5113,  1.4111],
        [ 0.8483, -0.6665, -0.3296],
        [ 0.1248, -1.7141, -0.5245],
        [ 0.1519, -1.7555, -0.5763],
        [-1.7162, -1.1667,  1.9009],
        [ 0.9712, -0.8719, -0.4632],
        [-0.5934,  0.6552, -0.0957],
        [-0.8440, -0.7500,  0.5437]])

In [14]:
# 使用常规方式计算的 Attention
KT = torch.transpose(K, 0, 1)
S = torch.matmul(Q, KT)
P = torch.softmax(S, 1)
O = torch.matmul(P, V)

In [34]:
O2 = attention(Q, K, V)
CHECK_EQ(O, O2)

In [102]:
def CHECK_EQ(O1, O2):
    assert (O1 - O2).sum() > -1e-4

In [40]:
a = torch.randn(3, 3)
a

tensor([[-0.5640,  1.8226, -0.0988],
        [ 1.1099, -0.1472,  0.4329],
        [-1.3287, -1.9497,  0.0321]])

In [42]:
a.transpose(0, 1)

tensor([[-0.5640,  1.1099, -1.3287],
        [ 1.8226, -0.1472, -1.9497],
        [-0.0988,  0.4329,  0.0321]])

In [136]:
# test
x = torch.randn(3, 8)
x - func_m(x)
# 直接计算
res1 = F.softmax(x, 1)
# 分布计算
res2 = func_softmax(x)
# 分段分布计算
x1, x2 = x.split([3, 5], 1)
print(x1.shape, x2.shape)
res3 = func_softmax2(x1, x2)

torch.Size([3, 3]) torch.Size([3, 5])
res_f2=torch.Size([3, 8])
res_l2=torch.Size([3, 1])


In [135]:
res1

tensor([[0.2059, 0.0911, 0.0528, 0.0326, 0.2141, 0.1049, 0.1402, 0.1584],
        [0.0829, 0.0242, 0.0725, 0.0598, 0.0519, 0.2991, 0.0562, 0.3534],
        [0.2934, 0.1037, 0.1191, 0.1432, 0.0084, 0.0292, 0.1676, 0.1354]])

In [134]:
res3

tensor([[0.2059, 0.0911, 0.0528, 0.0326, 0.2141, 0.1049, 0.1402, 0.1584],
        [0.0829, 0.0242, 0.0725, 0.0598, 0.0519, 0.2991, 0.0562, 0.3534],
        [0.2934, 0.1037, 0.1191, 0.1432, 0.0084, 0.0292, 0.1676, 0.1354]])

In [133]:
## 分块计算矩阵 attention

In [None]:
Qs = Q.split(Br, 0) # 按照Br切分Q
Ks = Q.split(Bc, 0) # 按照Bc切分K
Vs = Q.split(Bc, 0) # 按照Bc切分V

def func_block_calc():
    assert len(Vs) == len(Ks), "K != v"
    for i in range(len(Qs)):
        
        for j in range(len(Vs)):
            torch.matmul(q, Ks[j])

In [117]:
x1, x2 = x.split(4, 1)

In [118]:
func_m2(x1, x2)

tensor([[0.7145],
        [1.8365],
        [1.9663]])

In [113]:
torch.max(x, 1, keepdim=True).values

tensor([[1.7144],
        [0.1132],
        [1.1683]])

In [129]:
res1

tensor([[0.0509, 0.2278, 0.0937, 0.2169, 0.0089, 0.1849, 0.0430, 0.1738],
        [0.0126, 0.0479, 0.0301, 0.0149, 0.1444, 0.3562, 0.1616, 0.2324],
        [0.3082, 0.1054, 0.0194, 0.2529, 0.0802, 0.1178, 0.0287, 0.0873]])

In [128]:
res3

tensor([[0.0509, 0.2278, 0.0937, 0.2169, 0.0089, 0.1849, 0.0430, 0.1738],
        [0.0126, 0.0479, 0.0301, 0.0149, 0.1444, 0.3562, 0.1616, 0.2324],
        [0.3082, 0.1054, 0.0194, 0.2529, 0.0802, 0.1178, 0.0287, 0.0873]])

In [6]:
diag1 = torch.diag(torch.Tensor([1, 2, 3]))
n = torch.randn(3, 3)

In [7]:
n

tensor([[ 0.6589, -0.3955, -0.7406],
        [-1.1916,  0.6056, -0.2908],
        [ 0.3575, -2.2452,  1.2301]])

In [10]:
torch.matmul(diag1, n)

tensor([[ 0.6589, -0.3955, -0.7406],
        [-2.3832,  1.2111, -0.5817],
        [ 1.0725, -6.7356,  3.6902]])