## **8. Masked Multi-head Attention**
1. Implementing Masked Multi-head Attention.
2. Implementing Encoder-Decoder Attention.

### **Importing the required packages**

```python
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math
```

### **Data preprocessing**

To make the values and shapes of the data easier to inspect, we use only a few short samples.

```python
pad_id = 0
vocab_size = 100

data = [
  [62, 13, 47, 39, 78, 33, 56, 13],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
  [66, 88, 98, 47],
  [77, 65, 51, 77, 19, 15, 35, 19, 23]
]
```

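```python
def padding(data):
  # Right-pad every sequence with pad_id up to the longest length.
  # Note: the list `data` is modified in place and also returned.
  max_len = len(max(data, key=len))
  print(f"Maximum sequence length: {max_len}")

  for i, seq in enumerate(tqdm(data)):
    if len(seq) < max_len:
      data[i] = seq + [pad_id] * (max_len - len(seq))

  return data, max_len
```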

```python
data, max_len = padding(data)
```

```
Maximum sequence length: 10
100%|███████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 106997.55it/s]
```


```python
data
```

```
[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]
```
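PyTorch also provides a helper for this step. The sketch below assumes the same `pad_id = 0` convention and uses `torch.nn.utils.rnn.pad_sequence`, which pads a list of 1-D tensors to the longest length. Unlike `padding`, it leaves the input list untouched and returns a tensor directly.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

pad_id = 0
raw = [
    [62, 13, 47, 39, 78, 33, 56, 13],
    [60, 96, 51, 32, 90],
    [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
]

# Convert each sequence to a 1-D LongTensor, then pad to the longest one.
# batch_first=True gives shape (B, L) instead of the default (L, B).
padded = pad_sequence([torch.LongTensor(seq) for seq in raw],
                      batch_first=True, padding_value=pad_id)
print(padded.shape)  # torch.Size([3, 10])
print(padded[1])     # the 5 real tokens followed by five pad_id entries
```

Which approach to use is mostly a matter of taste. The list-based `padding` keeps the intermediate Python data visible, which suits this walkthrough.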

### **Hyperparameter setup and embedding**

```python
d_model = 8  # hidden size of the model
num_heads = 2  # number of attention heads
inf = 1e12  # large finite value used to mask out attention scores
```

```python
embedding = nn.Embedding(vocab_size, d_model)

# B: batch size, L: maximum sequence length
batch = torch.LongTensor(data)  # (B, L)
batch_emb = embedding(batch)  # (B, L, d_model)
```
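`nn.Embedding` is a lookup table: calling it on a `(B, L)` index tensor is equivalent to indexing rows of its `weight` matrix, which is where the trailing `d_model` dimension comes from. The toy sketch below checks this. The `padding_idx` argument at the end is an optional extra that this notebook does not use; it initializes the `pad_id` row to zeros and keeps it out of gradient updates.

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, d_model, pad_id = 100, 8, 0

embedding = nn.Embedding(vocab_size, d_model)
batch = torch.LongTensor([[62, 13, 0], [60, 96, 51]])  # (B=2, L=3), toy indices

emb = embedding(batch)            # (B, L, d_model)
same = embedding.weight[batch]    # plain row indexing into the (vocab_size, d_model) table
print(emb.shape)                  # torch.Size([2, 3, 8])
print(torch.equal(emb, same))     # True

# Optional variant: the pad_id row starts at zero and receives no gradient.
pad_aware = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
print(pad_aware(batch)[0, 2].abs().sum().item())  # 0.0
```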

```python
print(batch_emb)
print(batch_emb.shape)
```

```
tensor([[[-8.5668e-01, -5.6692e-01,  1.4254e+00,  1.0340e+00, -5.6951e-01,
          -3.2086e-01, -1.6549e+00,  2.1722e-01],
         [ 1.2143e+00,  6.8545e-01, -7.5511e-02, -1.8285e+00, -5.8766e-01,
          -3.0575e-01, -5.1290e-01, -8.3287e-01],
         [ 4.3816e-01, -1.1702e+00,  5.9156e-01, -1.1647e+00,  4.4995e-01,
           1.1627e+00, -4.9441e-02, -2.8493e-01],
         [-4.9367e-01,  2.1836e-01,  2.3368e+00, -3.6854e-01,  6.8153e-01,
           1.1698e-01,  8.5581e-02, -1.4454e+00],
         [ 5.8628e-01,  1.2795e+00, -1.1619e+00,  1.0901e+00, -9.1943e-01,
           3.2441e-01,  8.9015e-01,  1.1154e-01],
         [ 1.0066e+00, -1.8870e+00, -2.2291e+00,  1.7119e+00, -7.0112e-01,
          -6.8351e-01,  1.7704e+00, -1.7746e+00],
         [-1.3992e+00,  1.0362e+00, -7.9058e-01, -2.3070e-01,  4.4054e-01,
           8.3687e-01,  1.0816e+00, -9.4442e-01],
         [ 1.2143e+00,  6.8545e-01, -7.5511e-02, -1.8285e+00, -5.8766e-01,
          -3.0575e-01, -5.1290e-01, -8.3287e-01],
...
```


### **Building the mask**

In the mask, `True` marks positions that can be attended to and `False` marks positions that are masked out.

- In Masked Multi-head Attention, the mask prevents each decoding timestep from attending to tokens that come after the current timestep. This part of the mask is given by `nopeak_mask`.
- `padding_mask`: the `pad_id` tokens added to build the batch carry no meaning, so they should not be attended to. `padding_mask` therefore masks out the padding positions, and combining it with `nopeak_mask` (which hides future tokens) gives the final `mask`.
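Written out element by element, the combined mask is `mask[b, i, j] = (token j of sequence b is not padding) and (j <= i)`. The sketch below builds that rule with explicit loops on a small hypothetical batch and checks that it matches the broadcast construction `padding_mask & nopeak_mask`, where shapes `(B, 1, L)` and `(1, L, L)` broadcast to `(B, L, L)`.

```python
import torch

pad_id = 0
# Hypothetical toy batch: the second sequence ends with one padding token.
batch = torch.LongTensor([[5, 6, 7, 8],
                          [3, 4, 9, pad_id]])
B, L = batch.shape

# Reference: mask[b, i, j] is True iff key j is a real token and j <= i.
ref = torch.zeros(B, L, L, dtype=torch.bool)
for b in range(B):
    for i in range(L):        # query (current decoding) position
        for j in range(L):    # key position being attended to
            ref[b, i, j] = bool(batch[b, j] != pad_id) and j <= i

# Vectorized version: (B, 1, L) & (1, L, L) broadcasts to (B, L, L).
padding_mask = (batch != pad_id).unsqueeze(1)
nopeak_mask = torch.tril(torch.ones(1, L, L, dtype=torch.bool))
mask = padding_mask & nopeak_mask

print(torch.equal(mask, ref))  # True
print(mask[1].int())           # the padded last column is 0 in every row
```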

```python
padding_mask = (batch != pad_id).unsqueeze(1)  # (B, 1, L)

print(padding_mask)
print(padding_mask.shape)
```

```
tensor([[[ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True,  True,  True,  True,  True, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False]]])
torch.Size([5, 1, 10])
```


```python
nopeak_mask = torch.ones([1, max_len, max_len], dtype=torch.bool)  # (1, L, L)
nopeak_mask = torch.tril(nopeak_mask)  # (1, L, L)

print(nopeak_mask)
print(nopeak_mask.shape)
```

```
tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]])
torch.Size([1, 10, 10])
```


```python
mask = padding_mask & nopeak_mask  # (B, L, L)

print(mask)
print(mask.shape)
```

```
tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
...
```

### **Linear transformation & splitting into multiple heads**

```python
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)

w_0 = nn.Linear(d_model, d_model)
```

```python
q = w_q(batch_emb)  # (B, L, d_model)
k = w_k(batch_emb)  # (B, L, d_model)
v = w_v(batch_emb)  # (B, L, d_model)

batch_size = q.shape[0]
d_k = d_model // num_heads

q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)

q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, L, d_k)

print(q.shape)
print(k.shape)
print(v.shape)
```

```
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])
```
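Splitting into heads is a pure reshape: `view` cuts the last dimension `d_model` into `num_heads` chunks of size `d_k`, and `transpose(1, 2)` moves the head axis in front of the sequence axis so that each head runs attention independently. The sketch below, on random toy input, checks that head `h` holds feature slice `[h * d_k, (h + 1) * d_k)` of every token and that the inverse used in the full code, `transpose(1, 2).contiguous().view(B, L, d_model)`, recovers the input exactly.

```python
import torch

B, L, d_model, num_heads = 5, 10, 8, 2
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
d_k = d_model // num_heads

x = torch.randn(B, L, d_model)

# Split: (B, L, d_model) -> (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
heads = x.view(B, L, num_heads, d_k).transpose(1, 2)

# Head h holds the contiguous feature slice [h*d_k, (h+1)*d_k) of every token.
for h in range(num_heads):
    assert torch.equal(heads[:, h], x[..., h * d_k:(h + 1) * d_k])

# Merge: transpose back, then make memory contiguous so that view() is allowed.
merged = heads.transpose(1, 2).contiguous().view(B, L, d_model)
print(heads.shape)             # torch.Size([5, 2, 10, 4])
print(torch.equal(merged, x))  # True
```

The `contiguous()` call is needed because a transposed tensor no longer has a memory layout that `view` can reinterpret; `reshape` would copy automatically instead.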


### **Implementing self-attention with masking**

```python
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)
```

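```python
masks = mask.unsqueeze(1)  # (B, 1, L, L), broadcast over the head dimension
# Replace scores at masked positions (mask == False) with a large negative value.
masked_attn_scores = attn_scores.masked_fill(~masks, -1 * inf)  # (B, num_heads, L, L)
```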

```python
print(masked_attn_scores)
```

```
tensor([[[[ 4.5537e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-2.0454e-01, -2.4596e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-2.7268e-01,  4.1330e-01, -3.1739e-01, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 2.5532e-01,  2.5383e-01, -4.5472e-01,  1.7049e-01, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-3.9722e-01, -6.2343e-01,  1.0968e-01, -7.6718e-01,  3.8968e-01,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 9.4070e-02, -4.0943e-01,  4.8519e-01,  1.4621e-01,  1.7652e-01,
           -2.3282e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-6.5960e-01, -1.0805e-01, -3.9850e-01, -8.9610e-01,  2.9317e-01,
...
```

```python
print(masked_attn_scores.shape)
```

```
torch.Size([5, 2, 10, 10])
```


Positions masked with `-1 * inf` become 0 after the softmax.
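Softmax exponentiates each score (after subtracting the row maximum for numerical stability), and `exp(-1e12 - max)` underflows to exactly 0 in float32. The small sketch below shows this, and also one practical difference between the finite `-1e12` used here and a true `float("-inf")`: if every position in a row were masked, `-inf` would yield `nan`, while the finite value degrades to a uniform distribution. In this notebook no row is fully masked, because the first token of every sequence is a real token.

```python
import torch
import torch.nn.functional as F

inf = 1e12  # same large constant as in the notebook

scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[True, True, False]])  # True = may attend; last position masked

probs = F.softmax(scores.masked_fill(~mask, -1 * inf), dim=-1)
print(probs)  # last entry is exactly 0, the first two sum to 1

# Edge case: every position in the row is masked.
all_masked = torch.zeros(1, 3, dtype=torch.bool)
finite = F.softmax(scores.masked_fill(~all_masked, -1 * inf), dim=-1)
true_inf = F.softmax(scores.masked_fill(~all_masked, float("-inf")), dim=-1)
print(finite)    # tensor([[0.3333, 0.3333, 0.3333]])
print(true_inf)  # tensor([[nan, nan, nan]])
```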

```python
attn_dists = F.softmax(masked_attn_scores, dim=-1)  # (B, num_heads, L, L)

print(attn_dists)
print(attn_dists.shape)
```

```
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.5104, 0.4896, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2537, 0.5037, 0.2426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2934, 0.2929, 0.1442, 0.2695, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1576, 0.1257, 0.2616, 0.1089, 0.3462, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1682, 0.1017, 0.2488, 0.1773, 0.1827, 0.1213, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0569, 0.0989, 0.0739, 0.0450, 0.1477, 0.5080, 0.0697, 0.0000,
           0.0000, 0.0000],
          [0.0999, 0.0958, 0.0913, 0.0846, 0.1980, 0.2246, 0.1099, 0.0958,
           0.0000, 0.0000],
          [0.1650, 0.1080, 0.1024, 0.1260, 0.1512, 0.1193, 0.1201, 0.1080,
           0.0000, 0.0000],
          [0.1650, 0.1080, 0.1024, 0.1260, 0.1512, 0.1193, 0.1201, 0.1080
...
```

```python
attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)

print(attn_values.shape)
```

```
torch.Size([5, 2, 10, 4])
```
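One way to check that the no-peak mask does its job: since every attention weight on a future position is 0, changing the key and value at a later timestep must leave the outputs at all earlier timesteps unchanged. The sketch below tests this with random tensors in the same `(B, num_heads, L, d_k)` layout; `masked_attention` is a hypothetical standalone helper that repeats the cells above.

```python
import math
import torch
import torch.nn.functional as F

def masked_attention(q, k, v, mask, inf=1e12):
    """Scaled dot-product attention; mask is (B, L, L) with True = may attend."""
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, H, L, L)
    scores = scores.masked_fill(~mask.unsqueeze(1), -1 * inf)       # broadcast over heads
    return torch.matmul(F.softmax(scores, dim=-1), v)               # (B, H, L, d_k)

torch.manual_seed(0)
B, H, L, d_k = 2, 2, 6, 4
q, k, v = (torch.randn(B, H, L, d_k) for _ in range(3))
mask = torch.tril(torch.ones(1, L, L, dtype=torch.bool)).expand(B, L, L)

out = masked_attention(q, k, v, mask)

# Perturb only the key and value of the last timestep.
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] += 10.0
v2[:, :, -1] += 10.0
out2 = masked_attention(q, k2, v2, mask)

print(torch.allclose(out[:, :, :-1], out2[:, :, :-1]))  # True: earlier steps unaffected
print(torch.allclose(out[:, :, -1], out2[:, :, -1]))    # False: the last step sees the change
```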


### **Full code**

```python
class MultiheadAttention(nn.Module):
  # Defaults to the notebook's global hyperparameters, so MultiheadAttention() still works.
  def __init__(self, d_model=d_model, num_heads=num_heads):
    super().__init__()
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads  # hidden size of each head

    # Q, K, V learnable matrices
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)

    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(d_model, d_model)

  def forward(self, q, k, v, mask=None):
    batch_size = q.shape[0]

    q = self.w_q(q)  # (B, L, d_model)
    k = self.w_k(k)  # (B, L, d_model)
    v = self.w_v(v)  # (B, L, d_model)

    q = q.view(batch_size, -1, self.num_heads, self.d_k)  # (B, L, num_heads, d_k)
    k = k.view(batch_size, -1, self.num_heads, self.d_k)  # (B, L, num_heads, d_k)
    v = v.view(batch_size, -1, self.num_heads, self.d_k)  # (B, L, num_heads, d_k)

    q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
    k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
    v = v.transpose(1, 2)  # (B, num_heads, L, d_k)

    attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L, d_k)
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)  # (B, L, num_heads, d_k) => (B, L, d_model)

    return self.w_0(attn_values)

  def self_attention(self, q, k, v, mask=None):
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (B, num_heads, L, L)

    if mask is not None:
      mask = mask.unsqueeze(1)  # (B, 1, L, L) or (B, 1, 1, L)
      attn_scores = attn_scores.masked_fill(~mask, -1 * inf)

    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)

    attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)

    return attn_values
```
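As a cross-check, and assuming PyTorch 2.0 or later, `torch.nn.functional.scaled_dot_product_attention` computes the same `softmax(QK^T / sqrt(d_k)) V` core as `self_attention`. Its boolean `attn_mask` uses the same convention as this notebook (`True` = may attend), so the `(B, L, L)` mask only needs the head dimension added. The sketch compares the two on random inputs with a causal + padding mask; `self_attention` here is a hypothetical standalone copy of the method, and the sequence lengths are made up.

```python
import math
import torch
import torch.nn.functional as F

def self_attention(q, k, v, mask=None, inf=1e12):
    # Standalone copy of MultiheadAttention.self_attention.
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(~mask.unsqueeze(1), -1 * inf)
    return torch.matmul(F.softmax(scores, dim=-1), v)

torch.manual_seed(0)
B, H, L, d_k = 5, 2, 10, 4
q, k, v = (torch.randn(B, H, L, d_k) for _ in range(3))

# Causal + padding mask; e.g. sample 0 has length 7, so positions 7-9 are padding.
lengths = torch.tensor([7, 10, 5, 10, 9])
padding_mask = (torch.arange(L) < lengths[:, None]).unsqueeze(1)  # (B, 1, L)
mask = padding_mask & torch.tril(torch.ones(1, L, L, dtype=torch.bool))

ours = self_attention(q, k, v, mask)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.unsqueeze(1))
print(torch.allclose(ours, ref, atol=1e-5))  # True
```

Agreement is only up to floating-point tolerance, since the two paths need not accumulate in the same order.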

```python
multihead_attn = MultiheadAttention()

outputs = multihead_attn(batch_emb, batch_emb, batch_emb, mask=mask)  # (B, L, d_model)
```

```python
print(outputs)
```

```
tensor([[[-0.0950, -0.6099, -0.1194, -0.4724,  0.0516, -0.2494,  0.0768,
           0.7056],
         [-0.2845, -0.2152, -0.3627, -0.6293,  0.3523,  0.0401, -0.1301,
           0.3934],
         [-0.2273,  0.0778, -0.3762, -0.5926,  0.3266,  0.2731, -0.2716,
           0.1992],
         [-0.2018,  0.0960, -0.2749, -0.5768,  0.3603,  0.1116, -0.0936,
           0.1510],
         [-0.1883,  0.1396, -0.3003, -0.4102,  0.2189, -0.1391,  0.0320,
           0.1338],
         [-0.0967,  0.2005, -0.2014, -0.3439,  0.1561,  0.0376,  0.0032,
           0.1217],
         [-0.0721,  0.2410, -0.1946, -0.3351,  0.1565, -0.0019,  0.0417,
           0.0974],
         [-0.1803,  0.2622, -0.2623, -0.3797,  0.2532,  0.0154, -0.0396,
           0.0744],
         [-0.1231,  0.2505, -0.2116, -0.3664,  0.2563,  0.1203, -0.0404,
           0.0844],
         [-0.1231,  0.2505, -0.2116, -0.3664,  0.2563,  0.1203, -0.0404,
           0.0844]],

        [[-0.3115, -0.2966, -0.1596, -0.7238,  0.2179, -0.3929, ...
```

```python
print(outputs.shape)
```

```
torch.Size([5, 10, 8])
```


### **Encoder-Decoder attention**

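Only the inputs for the query, key, and value change; the implementation is the same. The query comes from the decoder, while the key and value come from the encoder output.  
Here we only build a separate batch for the decoder input.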

```python
trg_data = [
  [33, 11, 49, 10],
  [88, 34, 5, 29, 99, 45, 11, 25],
  [67, 25, 15, 90, 54, 4, 92, 10, 46, 20, 88, 19],
  [16, 58, 91, 47, 12, 5, 8],
  [71, 63, 62, 7, 9, 11, 55, 91, 32, 48]
]

trg_data, trg_max_len = padding(trg_data)
```

```
Maximum sequence length: 12
100%|████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 98922.26it/s]
```


```python
# S_L: source maximum sequence length, T_L: target maximum sequence length
src_batch = batch  # (B, S_L)
trg_batch = torch.LongTensor(trg_data)  # (B, T_L)

print(src_batch.shape)
print(trg_batch.shape)
```

```
torch.Size([5, 10])
torch.Size([5, 12])
```


```python
src_emb = embedding(src_batch)  # (B, S_L, d_model)
trg_emb = embedding(trg_batch)  # (B, T_L, d_model)

print(src_emb.shape)
print(trg_emb.shape)
```

```
torch.Size([5, 10, 8])
torch.Size([5, 12, 8])
```


We assume that `src_emb` is the output of the encoder and `trg_emb` is the output of the decoder's masked multi-head attention.

```python
q = w_q(trg_emb)  # (B, T_L, d_model)
k = w_k(src_emb)  # (B, S_L, d_model)
v = w_v(src_emb)  # (B, S_L, d_model)

batch_size = q.shape[0]
d_k = d_model // num_heads

q = q.view(batch_size, -1, num_heads, d_k)  # (B, T_L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)  # (B, S_L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)  # (B, S_L, num_heads, d_k)

q = q.transpose(1, 2)  # (B, num_heads, T_L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, S_L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, S_L, d_k)

print(q.shape)
print(k.shape)
print(v.shape)
```

```
torch.Size([5, 2, 12, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])
```


```python
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, T_L, S_L)
attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, T_L, S_L)

print(attn_dists.shape)
```

```
torch.Size([5, 2, 12, 10])
```


```python
attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, T_L, d_k)

print(attn_values.shape)
```

```
torch.Size([5, 2, 12, 4])
```


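The result has the same shape as the output of masked multi-head attention, `(B, num_heads, T_L, d_k)`, so the rest of the layer (merging the heads and applying `w_0`) proceeds exactly as before.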
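The encoder-decoder cells above apply no mask, so target positions can also attend to the padded positions of `src_batch`. A common addition, sketched below under the same `pad_id = 0` convention with a hypothetical toy batch, is a source padding mask of shape `(B, 1, S_L)`. After `unsqueeze(1)` it becomes `(B, 1, 1, S_L)` and broadcasts over all heads and all `T_L` query positions. No no-peak mask is needed here, because the decoder may look at the entire source sentence.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pad_id, inf = 0, 1e12
B, H, d_k = 2, 2, 4

# Hypothetical toy source batch (S_L = 5); sample 1 ends with two padding tokens.
src_batch = torch.LongTensor([[11, 12, 13, 14, 15],
                              [21, 22, 23, pad_id, pad_id]])
S_L, T_L = src_batch.size(1), 3

src_mask = (src_batch != pad_id).unsqueeze(1)  # (B, 1, S_L)

q = torch.randn(B, H, T_L, d_k)  # queries from the decoder side
k = torch.randn(B, H, S_L, d_k)  # keys from the encoder output
v = torch.randn(B, H, S_L, d_k)  # values from the encoder output

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, H, T_L, S_L)
scores = scores.masked_fill(~src_mask.unsqueeze(1), -1 * inf)   # mask: (B, 1, 1, S_L)
attn_dists = F.softmax(scores, dim=-1)
attn_values = torch.matmul(attn_dists, v)                       # (B, H, T_L, d_k)

print(attn_dists[1, 0])   # columns 3 and 4 (padding) are exactly 0 in every row
print(attn_values.shape)  # torch.Size([2, 2, 3, 4])
```

With the `MultiheadAttention` class above, the same idea would amount to calling it with `q=trg_emb`, `k=v=src_emb`, and `mask=(src_batch != pad_id).unsqueeze(1)`; its `self_attention` method already unsqueezes a `(B, 1, S_L)` mask to `(B, 1, 1, S_L)`.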

### **Content License**

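<font color='red'><b>**WARNING**</b></font> : **The intellectual property rights to this educational content belong to the NAVER Connect Foundation. Leaking this content outside through any channel, or modifying it, is strictly prohibited.** It may, however, be used for non-commercial educational and research purposes only, with the Foundation's permission. Violations may result in liability under applicable law.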

