## **7. Multi-head Attention**
1. Multi-head attention 및 self-attention 구현.
2. 각 과정에서 일어나는 연산과 input/output 형태 이해.

### **필요 패키지 import**

In [1]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

### **데이터 전처리**

In [2]:
# Token id used as filler when short sequences are padded (see padding()).
pad_id = 0
# Vocabulary size; used below as the number of rows of the embedding table.
vocab_size = 100

# Toy batch: ten variable-length sequences of token ids.
data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43]
]

In [3]:
def padding(data):
  """Right-pad every sequence in ``data`` to the length of the longest one.

  Short sequences are extended with the module-level ``pad_id``. ``data`` is
  modified in place: each short sequence is replaced by a new, padded list,
  while sequences that already have the maximum length are left as they are.

  Args:
    data: List of token-id lists. May be empty.

  Returns:
    A ``(data, max_len)`` tuple: the same list object that was passed in and
    the length of its longest sequence (0 when ``data`` is empty).
  """
  # default=0 keeps an empty batch from raising ValueError inside max().
  max_len = max(map(len, data), default=0)
  print(f"Maximum sequence length: {max_len}")

  for i, seq in enumerate(tqdm(data)):
    shortfall = max_len - len(seq)
    if shortfall > 0:
      data[i] = seq + [pad_id] * shortfall

  return data, max_len

In [4]:
# Pad all sequences to a common length and keep that length for later use.
data, max_len = padding(data)

Maximum sequence length: 20


100%|█████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 190650.18it/s]


In [5]:
data

[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77,
  65,
  51,
  77,
  19,
  15,
  35,
  19,
  23,
  97,
  50,
  46,
  53,
  42,
  45,
  91,
  66,
  3,
  43,
  10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]

### **Hyperparameter 세팅 및 embedding**

In [6]:
d_model = 512  # hidden size of the model
num_heads = 8  # number of attention heads

In [7]:
# Lookup table with one d_model-dimensional vector per token id.
# NOTE: padding_idx is not set, so pad_id gets an ordinary embedding row.
embedding = nn.Embedding(vocab_size, d_model)

# B: batch size, L: maximum sequence length
batch = torch.LongTensor(data)  # (B, L)
batch_emb = embedding(batch)  # (B, L, d_model)

In [9]:
print(batch_emb)

tensor([[[-1.3829,  0.2615, -0.2509,  ...,  0.7878,  0.7341, -1.5561],
         [-1.5745,  1.2790,  0.6530,  ...,  3.2941,  0.0916, -0.4197],
         [ 0.3946, -0.9075,  0.1039,  ..., -1.3712,  2.2372,  0.6184],
         ...,
         [-0.9100, -0.1834,  0.1774,  ..., -2.0337,  1.1422,  0.5385],
         [-0.9100, -0.1834,  0.1774,  ..., -2.0337,  1.1422,  0.5385],
         [-0.9100, -0.1834,  0.1774,  ..., -2.0337,  1.1422,  0.5385]],

        [[ 0.2132, -0.2783,  0.8494,  ...,  1.3866, -0.3096,  0.1386],
         [ 0.8789, -0.1996,  0.8290,  ..., -0.3165,  1.2627,  0.0771],
         [-1.1631,  0.5746, -0.9279,  ...,  0.5000,  0.1728,  0.4771],
         ...,
         [-0.9100, -0.1834,  0.1774,  ..., -2.0337,  1.1422,  0.5385],
         [-0.9100, -0.1834,  0.1774,  ..., -2.0337,  1.1422,  0.5385],
         [-0.9100, -0.1834,  0.1774,  ..., -2.0337,  1.1422,  0.5385]],

        [[ 2.5604,  1.4206,  0.3539,  ..., -0.1646, -0.0159,  1.1848],
         [ 0.3831,  0.7947, -1.1260,  ...,  0

In [10]:
print(batch_emb.shape)

torch.Size([10, 20, 512])


### **Linear transformation & 여러 head로 나누기**

Multi-head attention 내에서 쓰이는 linear transformation matrix들을 정의합니다.

- query, key, value를 서로 **다른** linear transformation matrix를 사용하여 linear projection 합니다.
- 동일한 representations (`batch_emb`)으로부터 서로 다른 query, key, value를 생성할 수 있습니다.

In [11]:
# Three separate learnable projections (d_model -> d_model) for query, key and value.
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)

In [12]:
# Output projection, applied later to the concatenated head outputs.
w_0 = nn.Linear(d_model, d_model)

In [None]:
# Self-attention: query, key and value are all projected from the same embeddings.
q = w_q(batch_emb)  # (B, L, d_model)
k = w_k(batch_emb)  # (B, L, d_model)
v = w_v(batch_emb)  # (B, L, d_model)

In [14]:
print(q)

tensor([[[ 0.3762, -0.3919,  0.2359,  ...,  0.2422, -0.1041, -0.2670],
         [ 0.4900,  0.5566, -0.6025,  ...,  0.1934,  0.1268, -1.5310],
         [-0.9232, -0.2311,  0.6882,  ..., -0.0455, -0.0301, -0.3920],
         ...,
         [-1.3817, -0.2090, -0.6947,  ..., -0.0486,  0.0123,  0.1691],
         [-1.3817, -0.2090, -0.6947,  ..., -0.0486,  0.0123,  0.1691],
         [-1.3817, -0.2090, -0.6947,  ..., -0.0486,  0.0123,  0.1691]],

        [[-0.1338, -0.5917, -0.3523,  ...,  0.1090, -1.1387,  0.1942],
         [ 1.1372, -0.5879,  0.5035,  ..., -0.9368, -0.9443, -0.1616],
         [ 0.0893, -0.4137,  0.5938,  ...,  0.1984,  0.2541,  0.0110],
         ...,
         [-1.3817, -0.2090, -0.6947,  ..., -0.0486,  0.0123,  0.1691],
         [-1.3817, -0.2090, -0.6947,  ..., -0.0486,  0.0123,  0.1691],
         [-1.3817, -0.2090, -0.6947,  ..., -0.0486,  0.0123,  0.1691]],

        [[ 0.0539, -0.8496,  0.1570,  ..., -0.4510, -0.0048,  1.3414],
         [ 0.1464, -0.3278, -0.0076,  ..., -0

In [13]:
print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


q, k, v를 `num_heads`개의 차원 분할된 여러 vector로 만듭니다.

In [15]:
batch_size = q.shape[0]
d_k = d_model // num_heads  # dimension of each head

# Split the last (d_model) axis into num_heads chunks of size d_k.
q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


In [16]:
# Move the head axis in front of the sequence axis so the matmuls below
# operate on one (L, d_k) matrix per head.
q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, L, d_k)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])


### **Scaled dot-product self-attention 구현**

각 head에서 실행되는 self-attention 과정입니다.

In [17]:
# Dot product of every query with every key, scaled by sqrt(d_k).
# NOTE: no mask is applied here, so padded positions also receive attention weight.
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)
# Softmax over the last (key) axis turns each query's scores into weights that sum to 1.
attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)

print(attn_dists)
print(attn_dists.shape)

tensor([[[[0.0408, 0.0443, 0.0295,  ..., 0.0756, 0.0756, 0.0756],
          [0.0521, 0.0429, 0.0406,  ..., 0.0640, 0.0640, 0.0640],
          [0.0542, 0.0406, 0.0442,  ..., 0.0406, 0.0406, 0.0406],
          ...,
          [0.0392, 0.0435, 0.0454,  ..., 0.0542, 0.0542, 0.0542],
          [0.0392, 0.0435, 0.0454,  ..., 0.0542, 0.0542, 0.0542],
          [0.0392, 0.0435, 0.0454,  ..., 0.0542, 0.0542, 0.0542]],

         [[0.0486, 0.0364, 0.0543,  ..., 0.0457, 0.0457, 0.0457],
          [0.0638, 0.0641, 0.0541,  ..., 0.0252, 0.0252, 0.0252],
          [0.0757, 0.0344, 0.0448,  ..., 0.0413, 0.0413, 0.0413],
          ...,
          [0.0705, 0.0539, 0.0397,  ..., 0.0650, 0.0650, 0.0650],
          [0.0705, 0.0539, 0.0397,  ..., 0.0650, 0.0650, 0.0650],
          [0.0705, 0.0539, 0.0397,  ..., 0.0650, 0.0650, 0.0650]],

         [[0.0601, 0.0447, 0.0686,  ..., 0.0362, 0.0362, 0.0362],
          [0.0345, 0.0353, 0.0340,  ..., 0.0851, 0.0851, 0.0851],
          [0.0418, 0.0613, 0.0547,  ..., 0

In [18]:
# Weighted sum of the value vectors, using the attention weights of each query.
attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)

print(attn_values.shape)

torch.Size([10, 8, 20, 64])


### **각 head의 결과물 병합**

각 head의 결과물을 concat하고 동일 차원(`d_model`)으로 linear transformation합니다.

In [20]:
# Put the head axis back behind the sequence axis ...
attn_values = attn_values.transpose(1, 2)  # (B, L, num_heads, d_k)
# ... then merge (num_heads, d_k) into d_model. transpose() leaves the tensor
# non-contiguous, so contiguous() is called before view().
attn_values = attn_values.contiguous().view(batch_size, -1, d_model)  # (B, L, d_model)

print(attn_values.shape)

torch.Size([10, 20, 512])


In [22]:
# Final linear projection of the concatenated heads.
outputs = w_0(attn_values)

print(outputs)

tensor([[[ 3.7536e-02, -1.1127e-01, -1.0192e-01,  ...,  3.4096e-02,
           1.9917e-01, -4.7854e-02],
         [ 9.2709e-02, -1.1835e-01, -3.0714e-02,  ...,  6.1924e-02,
           1.9109e-01, -4.9547e-02],
         [ 2.2903e-02, -3.2129e-02, -4.0533e-02,  ...,  8.2119e-03,
           5.1191e-02, -8.5249e-02],
         ...,
         [-3.7873e-02,  1.1777e-01, -8.4572e-02,  ...,  3.0946e-02,
          -8.4204e-02,  8.9798e-02],
         [-4.7355e-02,  8.8719e-02, -1.2927e-01,  ...,  1.0065e-01,
           1.2332e-01, -2.4758e-01],
         [ 5.6735e-03,  5.5543e-03, -1.1913e-01,  ...,  7.0541e-02,
           8.4119e-02, -1.9495e-01]],

        [[-2.0735e-01,  5.9715e-02, -5.6107e-01,  ...,  2.6348e-01,
           2.9746e-01, -1.3711e-01],
         [-2.1326e-01,  7.9526e-02, -5.6328e-01,  ...,  2.5622e-01,
           2.9996e-01, -1.3119e-01],
         [-8.9059e-02,  1.8120e-01, -2.4666e-02,  ...,  4.1081e-01,
           1.2801e-01, -1.1793e-01],
         ...,
         [-3.5409e-02, -6

In [23]:
print(outputs.shape)

torch.Size([10, 20, 512])


### **전체 코드**

위의 과정을 모두 합쳐 하나의 Multi-head attention 모듈을 구현하겠습니다.

In [24]:
class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Projects query, key and value with separate linear layers, splits each
  projection into ``num_heads`` heads of size ``d_k = d_model // num_heads``,
  runs scaled dot-product attention per head, concatenates the heads and
  applies a final linear projection.

  Args:
    d_model: Hidden size of the inputs and outputs. When omitted, the
      module-level ``d_model`` is used (original behaviour).
    num_heads: Number of attention heads. When omitted, the module-level
      ``num_heads`` is used (original behaviour).

  Raises:
    ValueError: If ``d_model`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_model=None, num_heads=None):
    super().__init__()

    # Backward compatibility: MultiheadAttention() without arguments keeps
    # reading the notebook-level hyperparameters, as the original class did.
    if d_model is None:
      d_model = globals()["d_model"]
    if num_heads is None:
      num_heads = globals()["num_heads"]

    # Without this check, view() in forward() could fail or silently reshape
    # the projections into the wrong number of positions.
    if d_model % num_heads != 0:
      raise ValueError(
        f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
      )

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads  # dimension of each head

    # Q, K, V learnable matrices
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)

    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(d_model, d_model)

  def forward(self, q, k, v, mask=None):
    """Apply multi-head attention.

    Args:
      q: Query tensor of shape (B, L_q, d_model).
      k: Key tensor of shape (B, L_k, d_model).
      v: Value tensor of shape (B, L_k, d_model).
      mask: Optional tensor broadcastable to (B, num_heads, L_q, L_k).
        Positions where it equals 0 / False are excluded from attention.

    Returns:
      Tensor of shape (B, L_q, d_model).
    """
    batch_size = q.shape[0]

    q = self.w_q(q)  # (B, L_q, d_model)
    k = self.w_k(k)  # (B, L_k, d_model)
    v = self.w_v(v)  # (B, L_k, d_model)

    # Split d_model into (num_heads, d_k), then move heads before the sequence axis.
    q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, L_q, d_k)
    k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, L_k, d_k)
    v = v.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, L_k, d_k)

    attn_values = self.self_attention(q, k, v, mask)  # (B, num_heads, L_q, d_k)

    # transpose() makes the tensor non-contiguous, so contiguous() precedes view().
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)  # (B, L_q, num_heads, d_k) => (B, L_q, d_model)

    return self.w_0(attn_values)

  def self_attention(self, q, k, v, mask=None):
    """Scaled dot-product attention over already-split heads.

    Args:
      q: Tensor of shape (B, num_heads, L_q, d_k).
      k: Tensor of shape (B, num_heads, L_k, d_k).
      v: Tensor of shape (B, num_heads, L_k, d_k).
      mask: Optional tensor broadcastable to (B, num_heads, L_q, L_k);
        positions equal to 0 / False get zero attention weight. A query row
        whose keys are all masked yields NaN, because softmax is taken over
        -inf only.

    Returns:
      Tensor of shape (B, num_heads, L_q, d_k).
    """
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (B, num_heads, L_q, L_k)
    if mask is not None:
      # -inf before softmax gives masked keys a weight of exactly zero.
      attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))
    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)

    attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L_q, d_k)

    return attn_values

In [25]:
multihead_attn = MultiheadAttention()

# Self-attention: the same embeddings are passed as query, key and value.
outputs = multihead_attn(batch_emb, batch_emb, batch_emb)  # (B, L, d_model)

In [27]:
print(outputs)

tensor([[[-0.0870,  0.0750,  0.0305,  ...,  0.1137,  0.0775, -0.0140],
         [-0.0297,  0.0265,  0.0692,  ...,  0.1198,  0.1127,  0.0661],
         [-0.0761,  0.0685,  0.0723,  ...,  0.0506,  0.1020, -0.0112],
         ...,
         [-0.1376,  0.0968, -0.0120,  ...,  0.0367,  0.0514,  0.0476],
         [-0.1376,  0.0968, -0.0120,  ...,  0.0367,  0.0514,  0.0476],
         [-0.1376,  0.0968, -0.0120,  ...,  0.0367,  0.0514,  0.0476]],

        [[-0.0709, -0.0145,  0.0304,  ...,  0.0891,  0.1155,  0.1585],
         [-0.0871, -0.0206,  0.0171,  ...,  0.0145,  0.1359,  0.1032],
         [-0.0992,  0.0107,  0.0189,  ..., -0.0168,  0.0872,  0.1363],
         ...,
         [-0.0776, -0.0085,  0.0244,  ...,  0.0144,  0.0907,  0.0911],
         [-0.0776, -0.0085,  0.0244,  ...,  0.0144,  0.0907,  0.0911],
         [-0.0776, -0.0085,  0.0244,  ...,  0.0144,  0.0907,  0.0911]],

        [[-0.0836, -0.0048,  0.0421,  ...,  0.0171,  0.0731,  0.1269],
         [-0.0532,  0.0138,  0.0299,  ...,  0

In [26]:
print(outputs.shape)

torch.Size([10, 20, 512])


### **콘텐츠 라이선스**

<font color='red'><b>**WARNING**</b></font> : **본 교육 콘텐츠의 지식재산권은 재단법인 네이버커넥트에 귀속됩니다. 본 콘텐츠를 어떠한 경로로든 외부로 유출 및 수정하는 행위를 엄격히 금합니다.** 다만, 비영리적 교육 및 연구활동에 한정되어 사용할 수 있으나 재단의 허락을 받아야 합니다. 이를 위반하는 경우, 관련 법률에 따라 책임을 질 수 있습니다.

