# Original self-attention: scaled dot-product attention

### Refs
* [Sebastian's blog](https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html)

### Notes
* 算矩阵乘法时，如果不确定是否要转置，可以先打印形状，对齐即可

#### 单头注意力

In [17]:
# Build a word -> integer-id vocabulary from a single sentence.
sentence = 'Life is short, eat dessert first'

# Strip commas, split on whitespace, then de-duplicate before sorting so that a
# repeated word cannot leave gaps in the ids: they are always 0..len(dc)-1.
# sorted() orders strings by code point, so capitalised words come first.
vocab = sorted(set(sentence.replace(',', '').split()))
dc = {word: idx for idx, word in enumerate(vocab)}
print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


In [18]:
# Look up every word of the sentence in `dc` to get a tensor of integer ids.
# (These are token indices, not one-hot vectors.)
import torch

tokens = sentence.replace(',', '').split()
sentence_int = torch.tensor([dc[token] for token in tokens])
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])


In [19]:
# Token indices -> dense embedding vectors.
torch.manual_seed(123)  # fixed seed so the random embedding table is reproducible

# Size the table from the vocabulary rather than hard-coding 6, so the
# embedding stays in sync with `dc` if the sentence is changed.
embed = torch.nn.Embedding(len(dc), 16)

# detach() drops the autograd history: the embeddings are used as plain inputs.
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])
torch.Size([6, 16])


In [20]:
# Create the projection matrices W_q, W_k and W_v.
torch.manual_seed(123)

# Input feature size: the width of one embedded token.
d = embedded_sentence.size(1)

# Why 24 and 28? The projection sizes are adjustable hyper-parameters,
# but d_q must equal d_k.
d_q = d_k = 24
d_v = 28

# The three matrices are drawn from the seeded generator in this order
# (query, key, value); reordering the calls would change their values.
W_query = torch.nn.Parameter(torch.rand((d_q, d)))
W_key = torch.nn.Parameter(torch.rand((d_k, d)))
W_value = torch.nn.Parameter(torch.rand((d_v, d)))

In [21]:
# Project the second token (x_2) through W_q, W_k, W_v to obtain
# query_2, key_2 and value_2.
x_2 = embedded_sentence[1]
query_2 = W_query @ x_2
key_2 = W_key @ x_2
value_2 = W_value @ x_2

# Print the shape of each projected vector.
for projected in (query_2, key_2, value_2):
    print(projected.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


In [6]:
# Project every token of the sentence into key and value vectors
# (only keys and values are computed in this cell).
#
# Shapes: W_key is (24, 16) and embedded_sentence is (seq_len, d) = (6, 16).
# The input is transposed to (16, 6) so that it lines up with W_key.
# The product is then transposed back so that each row corresponds to one
# token again, i.e. the result is (seq_len, d_k).
keys = (W_key @ embedded_sentence.T).T  # (seq_len, d_k)
values = (W_value @ embedded_sentence.T).T  # (seq_len, d_v)

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])


In [16]:
# Attention weights for the second token: score query_2 against every key,
# scale the scores, and normalise them with softmax.
import torch.nn.functional as F

omega_2 = query_2 @ keys.T  # unnormalised attention scores, one per token
print(omega_2.shape)

# Softmax amplifies the largest values and shrinks the smallest ones.
# If the entries of omega_2 have a large variance, one softmax output gets
# close to 1 and the rest get close to 0, and the gradients flowing back
# through it may become very small.
# Dividing by sqrt(d_k) keeps the variance of the scores small to avoid this.
scaled_omega_2 = omega_2 / d_k**0.5
attention_weights_2 = F.softmax(scaled_omega_2, dim=0)
print(attention_weights_2)

torch.Size([6])
tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458],
       grad_fn=<SoftmaxBackward0>)


In [8]:
# Context vector for the second token: the attention-weighted sum of the
# value vectors of all tokens (`values`, not just value_2).
context_vector_2 = attention_weights_2 @ values

print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084], grad_fn=<SqueezeBackward4>)


#### 多头注意力

In [9]:
# Create the multi-head projection matrices: one (d_q|d_k|d_v, d) matrix per
# head, stacked along a leading head axis of size h.
h = 3

# Creation order (query, key, value) is kept so the random draws are unchanged.
multihead_W_query = torch.nn.Parameter(torch.rand((h, d_q, d)))
multihead_W_key = torch.nn.Parameter(torch.rand((h, d_k, d)))
multihead_W_value = torch.nn.Parameter(torch.rand((h, d_v, d)))

In [10]:
# Project x_2 through the multi-head W_q, W_k, W_v to obtain
# multihead_query_2, multihead_key_2 and multihead_value_2 (one row per head).
multihead_query_2 = multihead_W_query @ x_2
multihead_key_2 = multihead_W_key @ x_2
multihead_value_2 = multihead_W_value @ x_2
print(multihead_query_2.shape)

torch.Size([3, 24])


In [12]:
# Show the shape of the embedded sentence (the notebook echoes the value of
# the cell's last expression).
embedded_sentence.shape

torch.Size([6, 16])

In [11]:
# Multi-head form of the input x: (seq_len, d) -> (h, d, seq_len).
# Transpose once, then repeat the result h times along a new leading axis.
transposed_inputs = embedded_sentence.T
stacked_inputs = transposed_inputs.repeat(h, 1, 1)
print(stacked_inputs.shape)

torch.Size([3, 16, 6])


In [15]:
# Compute the keys and values of every head in one shot with batch matrix
# multiplication.
raw_keys = torch.bmm(multihead_W_key, stacked_inputs)
raw_values = torch.bmm(multihead_W_value, stacked_inputs)

# permute swaps the last two axes so that the multi-head keys/values have the
# same dimension order as embedded_sentence (tokens before features).
multihead_keys = raw_keys.permute(0, 2, 1)
multihead_values = raw_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 6, 24])
multihead_values.shape: torch.Size([3, 6, 28])


# cross-attention