# self-attention

In [12]:
import torch
import torch.nn.functional as F

### 单头注意力

In [19]:
torch.manual_seed(123)

# Build a toy vocabulary from the sentence itself.
sentence = 'Life is short, eat dessert first'

# For simplicity the vocabulary is just the comma-stripped words of the
# sentence, sorted alphabetically and mapped to consecutive integer ids.
dc = {w: i for i, w in enumerate(sorted(sentence.replace(',', '').split()))}

# sentence -> vocabulary indices
sentence_int = torch.tensor([dc[w] for w in sentence.replace(',', '').split()])  # (seq_len,)

# vocabulary index -> embedding vector.
# num_embeddings is the vocabulary size (len(dc)), not the sentence length;
# the two only coincide here because no word in the sentence repeats.
d = 16  # embedding dimension
embed = torch.nn.Embedding(len(dc), d)
embedded_sentence = embed(sentence_int).detach()  # (seq_len, d)

# Projection matrices, stored as (out_dim, d).
# d_q must equal d_k because queries are dot-multiplied with keys.
d_q, d_k, d_v = 24, 24, 28
W_q = torch.nn.Parameter(torch.randn(d_q, d))
W_k = torch.nn.Parameter(torch.randn(d_k, d))
W_v = torch.nn.Parameter(torch.randn(d_v, d))

# Project every token embedding; X @ W.T equals (W @ X.T).T.
querys = embedded_sentence @ W_q.T  # (seq_len, d_q)
keys = embedded_sentence @ W_k.T  # (seq_len, d_k)
values = embedded_sentence @ W_v.T  # (seq_len, d_v)

# Unnormalized scores: omega[i, j] = q_i . k_j
omega = querys @ keys.T  # (seq_len, seq_len)
# Scaled dot-product attention. Normalize over the KEY axis (last dim) so each
# query's row of weights sums to 1; dim=0 would normalize over queries and the
# product below would no longer be a weighted average of the value vectors.
attention_weights = F.softmax(omega / d_k**0.5, dim=-1)  # (seq_len, seq_len)

# Each context vector is the attention-weighted sum of the value vectors.
context_vector = attention_weights @ values  # (seq_len, d_v)

### 多头注意力

In [23]:
torch.manual_seed(123)
h = 3  # number of heads

# Build a toy vocabulary from the sentence itself.
sentence = 'Life is short, eat dessert first'

# For simplicity the vocabulary is just the comma-stripped words of the
# sentence, sorted alphabetically and mapped to consecutive integer ids.
dc = {w: i for i, w in enumerate(sorted(sentence.replace(',', '').split()))}

# sentence -> vocabulary indices
sentence_int = torch.tensor([dc[w] for w in sentence.replace(',', '').split()])  # (seq_len,)

# vocabulary index -> single embedding (shared input for all heads).
# num_embeddings is the vocabulary size (len(dc)), not the sentence length.
d = 16  # embedding dimension
embed = torch.nn.Embedding(len(dc), d)
embedded_sentence = embed(sentence_int).detach()  # (seq_len, d)

# Give every head its own copy of the same input sequence.
stacked_inputs = embedded_sentence.repeat(h, 1, 1)  # (h, seq_len, d)

# One projection matrix per head, stored as (h, out_dim, d).
# d_q must equal d_k because queries are dot-multiplied with keys.
d_q, d_k, d_v = 24, 24, 28
W_q = torch.nn.Parameter(torch.randn(h, d_q, d))
W_k = torch.nn.Parameter(torch.randn(h, d_k, d))
W_v = torch.nn.Parameter(torch.randn(h, d_v, d))

# Batched per-head projection: (h, seq_len, d) @ (h, d, out_dim).
querys = stacked_inputs @ W_q.transpose(1, 2)  # (h, seq_len, d_q)
keys = stacked_inputs @ W_k.transpose(1, 2)  # (h, seq_len, d_k)
values = stacked_inputs @ W_v.transpose(1, 2)  # (h, seq_len, d_v)

# Per-head unnormalized scores.
omegas = querys @ keys.transpose(1, 2)  # (h, seq_len, seq_len)
# Normalize THIS cell's per-head scores (`omegas`, not the single-head `omega`)
# over the key axis (last dim). dim=0 would normalize across heads instead.
attention_weights = F.softmax(omegas / d_k**0.5, dim=-1)  # (h, seq_len, seq_len)

# Per-head context vectors. A full multi-head layer would additionally
# concatenate the heads and apply an output projection; this cell stops here.
context_vector = attention_weights @ values  # (h, seq_len, d_v)

# cross-attention