In [1]:
import torch

In [2]:
def softmax_native(x):
    """Hand-written softmax, normalizing along dimension 0.

    Args:
        x: Tensor of raw scores. The normalization runs over ``dim=0``.

    Returns:
        Tensor with the same shape as ``x``. For finite inputs the entries
        are non-negative and sum to 1 along dimension 0.
    """
    # An empty tensor has no maximum to subtract; exp() of it is already
    # the (empty) result the plain formula would produce.
    if x.numel() == 0:
        return torch.exp(x)
    # Shifting by the max along dim 0 leaves the softmax value unchanged
    # mathematically, but keeps exp() from overflowing to inf (and the
    # ratio from becoming nan) when the scores are large.
    shifted = x - x.max(dim=0, keepdim=True).values
    # Exponentiate once and reuse it for numerator and denominator.
    exp_shifted = torch.exp(shifted)
    return exp_shifted / exp_shifted.sum(dim=0)

In [3]:
# Six example input vectors, one per row, each with three components.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
])

# Uninitialized buffer with one attention-score slot per input row.
attention_score = torch.empty(inputs.size(0))

# The second row is used as the query.
query = inputs[1]

In [None]:
# Fill each slot with the dot product of the query and the matching input row.
for row_idx in range(inputs.shape[0]):
    attention_score[row_idx] = torch.dot(inputs[row_idx], query)
print("Attention Score: ", attention_score)

Attention Score:  tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [None]:
# Plain normalization: divide every score by the total so the weights add up to 1.
attention_weight = torch.div(attention_score, attention_score.sum())
print("Attention Weight: ", attention_weight)
print("Sum of Attention Weight: ", torch.sum(attention_weight))


Attention Weight:  tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum of Attention Weight:  tensor(1.0000)


In [None]:
# Normalize the scores with the hand-written softmax defined earlier in this notebook.
attention_weight_softmax_naive = softmax_native(x=attention_score)
print("Softmax Attention Weight: ", attention_weight_softmax_naive)
print("Sum of Softmax Attention Weight: ", torch.sum(attention_weight_softmax_naive))

Softmax Attention Weight:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum of Softmax Attention Weight:  tensor(1.)


In [None]:
# Same normalization, this time with PyTorch's built-in softmax along dim 0.
attention_weight_softmax = attention_score.softmax(dim=0)
print("Softmax Attention Weight (PyTorch): ", attention_weight_softmax)
print("Sum of Softmax Attention Weight (PyTorch): ", torch.sum(attention_weight_softmax))

Softmax Attention Weight (PyTorch):  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum of Softmax Attention Weight (PyTorch):  tensor(1.)


In [8]:
# Context vector: sum of the input rows, each scaled by its softmax attention weight.
context_vec = torch.zeros(query.size())
for weight, row in zip(attention_weight_softmax, inputs):
    context_vec += weight * row
print("Context Vector: ", context_vec)

Context Vector:  tensor([0.4419, 0.6515, 0.5683])
