## **8. Masked Multi-head Attention**
1. Masked Multi-head Attention 구현.
2. Encoder-Decoder Attention 구현.

### **필요 패키지 import**

In [1]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

### **데이터 전처리**

데이터의 값과 형태를 좀 더 명확하게 보기 위해 sample을 줄이겠습니다.

In [2]:
# Token id used to fill sequences up to a common length.
pad_id = 0
# Number of distinct token ids; sizes the embedding table created later.
vocab_size = 100

# Toy batch: five sequences of token ids with different lengths.
data = [
  [62, 13, 47, 39, 78, 33, 56, 13],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
  [66, 88, 98, 47],
  [77, 65, 51, 77, 19, 15, 35, 19, 23]
]

In [3]:
def padding(data):
  """Right-pad every sequence in ``data`` to the length of the longest one.

  Sequences shorter than the maximum are replaced by a new list extended
  with the module-level ``pad_id``. The outer list ``data`` is updated in
  place and the same list object is returned.

  Args:
    data: list of token-id lists.

  Returns:
    Tuple ``(data, max_len)``; ``max_len`` is the longest sequence length,
    or 0 when ``data`` is empty.
  """
  # default=0 keeps an empty batch from raising ValueError inside max().
  max_len = max(map(len, data), default=0)
  print(f"Maximum sequence length: {max_len}")

  for i, seq in enumerate(tqdm(data)):
    missing = max_len - len(seq)
    if missing > 0:
      data[i] = seq + [pad_id] * missing

  return data, max_len

In [4]:
data, max_len = padding(data)

100%|██████████| 5/5 [00:00<00:00, 22168.63it/s]

Maximum sequence length: 10





In [5]:
data

[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]

### **Hyperparameter 세팅 및 embedding**

In [6]:
d_model = 8  # hidden size of the model
num_heads = 2  # number of attention heads
inf = 1e12  # large finite stand-in for infinity; its negation fills masked attention scores

In [7]:
embedding = nn.Embedding(vocab_size, d_model)

# B: batch size, L: maximum sequence length
batch = torch.LongTensor(data)  # (B, L)
batch_emb = embedding(batch)  # (B, L, d_model)

In [8]:
print(batch_emb)
print(batch_emb.shape)

tensor([[[-0.3839, -1.6025,  0.1949, -1.0202,  0.4348, -0.3048,  0.9260,
          -0.4802],
         [ 0.4507,  0.8300,  0.4032, -1.4784, -0.6163,  0.0249, -1.0281,
           0.3248],
         [-0.9727,  0.2589, -0.9030,  0.2146, -0.4724,  0.4236, -0.2628,
           0.3315],
         [-0.0299,  1.5641,  1.1531, -1.0335,  0.1508,  0.4652,  0.4110,
          -0.1267],
         [ 1.3567, -0.4157, -0.4153, -0.0233, -1.1758, -2.2988, -0.2780,
          -1.0320],
         [-0.8699,  0.5465,  0.5797,  1.4772, -0.0591,  0.4099,  0.1740,
          -0.1735],
         [-1.4360, -0.2477, -1.0659,  0.7206,  0.0434,  0.7370, -1.3688,
          -0.2318],
         [ 0.4507,  0.8300,  0.4032, -1.4784, -0.6163,  0.0249, -1.0281,
           0.3248],
         [ 0.9736, -0.4646, -0.0426,  0.6201,  0.1612,  0.7134,  0.5097,
           1.9258],
         [ 0.9736, -0.4646, -0.0426,  0.6201,  0.1612,  0.7134,  0.5097,
           1.9258]],

        [[ 0.4918,  0.2270,  0.4832, -1.0459, -1.3621,  1.1617,  0.7

### **Mask 구축**

`True`는 attention이 적용될 부분, `False`는 masking될 자리입니다.

In [9]:
# Padding mask: True where the token is a real token, False where it equals pad_id.
padding_mask = (batch != pad_id).unsqueeze(1)  # (B, 1, L)

print(padding_mask)
print(padding_mask.shape)

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True,  True,  True,  True,  True, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False]]])
torch.Size([5, 1, 10])


In [10]:
# Look-ahead ("no peek") mask: start from all True, then keep only the lower triangle
# so that row i is True for columns 0..i and False for every later column.
nopeak_mask = torch.ones([1, max_len, max_len], dtype=torch.bool)  # (1, L, L)
nopeak_mask = torch.tril(nopeak_mask)  # (1, L, L)

print(nopeak_mask)
print(nopeak_mask.shape)

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]])
torch.Size([1, 10, 10])


In [11]:
# Element-wise AND with broadcasting: an entry is True only when both the padding mask
# and the look-ahead mask are True at that position.
mask = padding_mask & nopeak_mask  # (B, L, L)

print(mask)
print(mask.shape)

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  T

### **Linear transformation & 여러 head로 나누기**

In [12]:
# Linear projections that produce queries, keys and values from the embeddings.
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)

# Output projection; the step-by-step cells below do not use it, the full class
# applies its own copy to the concatenated heads.
w_0 = nn.Linear(d_model, d_model)

In [13]:
# Self-attention: queries, keys and values are all projected from the same embeddings.
q = w_q(batch_emb)  # (B, L, d_model)
k = w_k(batch_emb)  # (B, L, d_model)
v = w_v(batch_emb)  # (B, L, d_model)

batch_size = q.shape[0]
d_k = d_model // num_heads  # feature size handled by each head

# Split the last axis into num_heads chunks of size d_k.
q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)

# Put the head axis before the sequence axis so each head gets its own (L, d_k) matrix.
q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, L, d_k)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])


### **Masking이 적용된 self-attention 구현**

In [14]:
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)

In [15]:
masks = mask.unsqueeze(1)  # (B, 1, L, L); the singleton axis broadcasts over heads
# Out-of-place fill: attn_scores keeps the raw scores instead of being silently
# overwritten, so the two names no longer alias the same tensor.
masked_attn_scores = attn_scores.masked_fill(~masks, -1 * inf)  # (B, num_heads, L, L)

# Masked positions hold -1e12 (a very large finite negative, not a true -inf).
print(masked_attn_scores)
print(masked_attn_scores.shape)

tensor([[[[ 2.3512e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-1.2979e-01, -2.4226e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-1.5244e-01, -9.0459e-02, -2.2510e-01, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-4.6227e-01, -7.8110e-02, -7.8825e-02, -8.0203e-02, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 2.5430e-01, -2.5678e-01,  2.2410e-01, -5.1298e-02,  3.3492e-01,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-3.2856e-01, -2.4977e-01, -9.3572e-02, -6.4003e-01, -2.6189e-01,
           -2.4245e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 3.2029e-02, -8.4978e-02, -6.7304e-02, -2.8347e-01, -2.3434e-01,
      

`-1 * inf`(여기서 `inf`는 `1e12`로 둔 매우 큰 유한값입니다)로 masking된 부분은 softmax 후 attention 가중치가 0이 됩니다.

In [16]:
attn_dists = F.softmax(masked_attn_scores, dim=-1)  # (B, num_heads, L, L)

print(attn_dists)
print(attn_dists.shape)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.5281, 0.4719, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.3340, 0.3554, 0.3106, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1852, 0.2719, 0.2717, 0.2713, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2278, 0.1366, 0.2210, 0.1678, 0.2469, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1603, 0.1734, 0.2028, 0.1174, 0.1714, 0.1747, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1740, 0.1548, 0.1576, 0.1269, 0.1333, 0.1213, 0.1321, 0.0000,
           0.0000, 0.0000],
          [0.1014, 0.0906, 0.1308, 0.0957, 0.1491, 0.1666, 0.1752, 0.0906,
           0.0000, 0.0000],
          [0.1831, 0.1329, 0.1155, 0.1395, 0.1199, 0.1000, 0.0760, 0.1329,
           0.0000, 0.0000],
          [0.1831, 0.1329, 0.1155, 0.1395, 0.1199, 0.1000, 0.0760, 0.1329

In [17]:
attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)

print(attn_values.shape)

torch.Size([5, 2, 10, 4])


### **전체 코드**

In [18]:
class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product attention with an optional mask.

  The same module covers masked self-attention (q, k, v from one sequence)
  and encoder-decoder attention (q from the target, k/v from the source).

  Args:
    d_model: hidden size of inputs and outputs. When omitted, the
      module-level ``d_model`` hyperparameter is used.
    num_heads: number of attention heads. When omitted, the module-level
      ``num_heads`` hyperparameter is used.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_model=None, num_heads=None):
    super().__init__()

    # Fall back to the notebook-level hyperparameters so the original
    # zero-argument construction keeps working.
    module_scope = globals()
    self.d_model = module_scope["d_model"] if d_model is None else d_model
    self.num_heads = module_scope["num_heads"] if num_heads is None else num_heads
    if self.d_model % self.num_heads != 0:
      raise ValueError(
        f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
      )
    # Derived here instead of read from a global that another cell may
    # have left stale or never defined.
    self.d_k = self.d_model // self.num_heads

    # Q, K, V learnable matrices
    self.w_q = nn.Linear(self.d_model, self.d_model)
    self.w_k = nn.Linear(self.d_model, self.d_model)
    self.w_v = nn.Linear(self.d_model, self.d_model)

    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(self.d_model, self.d_model)

  def forward(self, q, k, v, mask=None):
    """Project q/k/v, attend per head, merge the heads and project again.

    Args:
      q: (B, L_q, d_model) query source.
      k: (B, L_k, d_model) key source.
      v: (B, L_k, d_model) value source.
      mask: optional (B, L_q, L_k) or (B, 1, L_k) tensor (any shape that
        broadcasts after a head axis is inserted); positions equal to
        False/0 are excluded from attention.

    Returns:
      (B, L_q, d_model) tensor.
    """
    batch_size = q.shape[0]

    q = self._split_heads(self.w_q(q), batch_size)  # (B, num_heads, L_q, d_k)
    k = self._split_heads(self.w_k(k), batch_size)  # (B, num_heads, L_k, d_k)
    v = self._split_heads(self.w_v(v), batch_size)  # (B, num_heads, L_k, d_k)

    attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L_q, d_k)
    # (B, L_q, num_heads, d_k) => (B, L_q, d_model); contiguous() is needed
    # because view() cannot reinterpret the transposed layout.
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    return self.w_0(attn_values)

  def _split_heads(self, x, batch_size):
    """Reshape (B, L, d_model) into (B, num_heads, L, d_k)."""
    return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

  def self_attention(self, q, k, v, mask=None):
    """Scaled dot-product attention over already-split heads.

    Args:
      q: (B, num_heads, L_q, d_k).
      k: (B, num_heads, L_k, d_k).
      v: (B, num_heads, L_k, d_k).
      mask: optional mask without the head axis; see ``forward``.

    Returns:
      (B, num_heads, L_q, d_k) attention-weighted values.
    """
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (B, num_heads, L_q, L_k)

    if mask is not None:
      mask = mask.unsqueeze(1)  # (B, 1, L_q, L_k) or (B, 1, 1, L_k)
      # Out-of-place fill with the most negative finite value of the score
      # dtype: masked weights become 0 after softmax, and the fill value
      # stays representable in reduced-precision dtypes.
      attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)

    return torch.matmul(attn_dists, v)  # (B, num_heads, L_q, d_k)

In [19]:
multihead_attn = MultiheadAttention()

outputs = multihead_attn(batch_emb, batch_emb, batch_emb, mask=mask)  # (B, L, d_model)

In [20]:
print(outputs)
print(outputs.shape)

tensor([[[-8.4967e-01, -5.8826e-01, -4.2850e-01,  9.9093e-02,  1.4771e-01,
           6.2223e-01,  2.8084e-01, -5.0115e-01],
         [-6.7497e-01, -3.9401e-01,  4.2540e-02,  4.1373e-02,  3.1136e-01,
           4.1219e-01, -1.4659e-01, -2.1613e-01],
         [-4.2850e-01, -3.7295e-01,  6.8726e-02,  3.4801e-02,  2.4097e-01,
           3.9402e-01, -1.9667e-01, -1.3758e-01],
         [-4.0753e-01, -4.6717e-01,  3.8919e-02, -9.0501e-02,  2.1593e-01,
           3.3076e-01, -2.6306e-01, -5.6494e-02],
         [-4.2232e-01, -3.3443e-01,  1.7336e-01,  2.1401e-02,  2.7166e-01,
           2.7729e-01, -3.2880e-01,  1.2248e-03],
         [-3.9664e-01, -3.1598e-01,  1.6006e-01,  6.9023e-02,  2.3970e-01,
           2.4342e-01, -3.3928e-01,  4.9142e-02],
         [-3.5027e-01, -2.7480e-01,  1.6992e-01,  8.1613e-02,  2.1119e-01,
           2.6034e-01, -2.8892e-01, -9.8891e-03],
         [-2.5498e-01, -2.4454e-01,  1.7983e-01,  5.4166e-02,  2.9174e-01,
           2.0812e-01, -3.9076e-01,  4.7339e-02],


### **Encoder-Decoder attention**

Query, key, value만 달라질 뿐 구현은 동일합니다.  
Decoder에 들어갈 batch만 별도 구현하겠습니다.

In [21]:
# Target-side toy batch of token ids; the sequences have different lengths.
trg_data = [
  [33, 11, 49, 10],
  [88, 34, 5, 29, 99, 45, 11, 25],
  [67, 25, 15, 90, 54, 4, 92, 10, 46, 20, 88 ,19],
  [16, 58, 91, 47, 12, 5, 8],
  [71, 63, 62, 7, 9, 11, 55, 91, 32, 48]
]

# Pad every target sequence to the longest target length.
trg_data, trg_max_len = padding(trg_data)

100%|██████████| 5/5 [00:00<00:00, 25922.77it/s]

Maximum sequence length: 12





In [22]:
# S_L: source maximum sequence length, T_L: target maximum sequence length
src_batch = batch  # (B, S_L)
trg_batch = torch.LongTensor(trg_data)  # (B, T_L)

print(src_batch.shape)
print(trg_batch.shape)

torch.Size([5, 10])
torch.Size([5, 12])


In [23]:
# The same embedding table is reused for source and target token ids.
src_emb = embedding(src_batch)  # (B, S_L, d_model)
trg_emb = embedding(trg_batch)  # (B, T_L, d_model)

print(src_emb.shape)
print(trg_emb.shape)

torch.Size([5, 10, 8])
torch.Size([5, 12, 8])


`src_emb`를 encoder에서 나온 결과, 그리고 `trg_emb`를 masked multi-head attention 후 결과로 가정합니다.

In [24]:
# Encoder-decoder attention: the query comes from the target side, while keys and
# values come from the source side.
q = w_q(trg_emb)  # (B, T_L, d_model)
k = w_k(src_emb)  # (B, S_L, d_model)
v = w_v(src_emb)  # (B, S_L, d_model)

batch_size = q.shape[0]
d_k = d_model // num_heads  # feature size handled by each head

q = q.view(batch_size, -1, num_heads, d_k)  # (B, T_L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)  # (B, S_L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)  # (B, S_L, num_heads, d_k)

q = q.transpose(1, 2)  # (B, num_heads, T_L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, S_L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, S_L, d_k)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([5, 2, 12, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])


In [25]:
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, T_L, S_L)

# Hide source <pad> positions from every target query; without this the softmax
# assigns attention weight to padding tokens of the shorter source sequences.
src_padding_mask = (src_batch != pad_id).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, S_L)
attn_scores = attn_scores.masked_fill(~src_padding_mask, -1 * inf)  # (B, num_heads, T_L, S_L)
attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, T_L, S_L)

print(attn_dists.shape)

torch.Size([5, 2, 12, 10])


In [26]:
attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, T_L, d_k)

print(attn_values.shape)

torch.Size([5, 2, 12, 4])


Masked multi-head attention 후 나온 결과와 동일한 shape를 가지며 이후 layer에서 전체 연산도 동일하게 진행됩니다.