## **7. Multi-head Attention**
1. Multi-head attention 및 self-attention 구현.
2. 각 과정에서 일어나는 연산과 input/output 형태 이해.

### **필요 패키지 import**

In [15]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

### **데이터 전처리**

In [16]:
# Token id used to right-pad short sequences so every row has the same length.
pad_id = 0
# Vocabulary size; every token id in `data` is below this value.
vocab_size = 100

# Toy batch: ten variable-length sequences of token ids.
data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43]
]

In [17]:
def padding(data):
  """Right-pad every sequence in ``data`` with ``pad_id`` up to the longest length.

  The list is modified in place (short rows are replaced by padded copies)
  and is also returned for convenience.

  Args:
    data: List of token-id lists. May be empty.

  Returns:
    A ``(data, max_len)`` tuple, where ``max_len`` is the length of the
    longest sequence (0 when ``data`` is empty).
  """
  # ``default=0`` keeps an empty batch from raising ValueError inside ``max``.
  max_len = max((len(seq) for seq in data), default=0)
  print(f"Maximum sequence length: {max_len}")

  for i, seq in enumerate(tqdm(data)):
    if len(seq) < max_len:
      # Build a new padded row instead of extending ``seq`` so the original
      # inner list object is left untouched.
      data[i] = seq + [pad_id] * (max_len - len(seq))

  return data, max_len

In [18]:
# Pad every sequence to the longest length and keep that length for later use.
data, max_len = padding(data)

100%|██████████| 10/10 [00:00<00:00, 62883.12it/s]

Maximum sequence length: 20





In [19]:
# Display the padded batch (notebook cell output).
data

[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77,
  65,
  51,
  77,
  19,
  15,
  35,
  19,
  23,
  97,
  50,
  46,
  53,
  42,
  45,
  91,
  66,
  3,
  43,
  10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]

### **Hyperparameter 세팅 및 embedding**

In [20]:
d_model = 512  # hidden size of the model
num_heads = 8  # number of attention heads; d_model % num_heads must be 0

In [21]:
# Lookup table mapping each of the vocab_size token ids to a d_model-dim vector.
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

# B: batch size, L: maximum sequence length
batch = torch.tensor(data, dtype=torch.long)  # (B, L)
batch_emb = embedding(batch)  # (B, L, d_model)

In [22]:
# Inspect the embedded batch and its shape, expected to be (B, L, d_model).
print(batch_emb)
print(batch_emb.shape)

tensor([[[ 1.7307, -1.4483, -0.0433,  ..., -1.8369, -0.7755, -0.6538],
         [ 1.2350, -0.2202,  0.1944,  ..., -0.8821, -1.9208, -0.9424],
         [ 1.1804, -0.4475,  1.6954,  ...,  1.6695,  0.0476,  2.2247],
         ...,
         [-1.1962,  2.8726, -1.2271,  ..., -1.1268, -1.8536,  0.2114],
         [-1.1962,  2.8726, -1.2271,  ..., -1.1268, -1.8536,  0.2114],
         [-1.1962,  2.8726, -1.2271,  ..., -1.1268, -1.8536,  0.2114]],

        [[-1.0603, -0.0627, -0.1550,  ..., -0.4284,  0.7544, -1.2059],
         [ 0.9366,  1.0761, -0.4529,  ...,  1.7042,  0.3171, -0.2775],
         [-0.1113,  1.2346,  1.0658,  ...,  0.8423,  0.6119,  2.4125],
         ...,
         [-1.1962,  2.8726, -1.2271,  ..., -1.1268, -1.8536,  0.2114],
         [-1.1962,  2.8726, -1.2271,  ..., -1.1268, -1.8536,  0.2114],
         [-1.1962,  2.8726, -1.2271,  ..., -1.1268, -1.8536,  0.2114]],

        [[-0.6782, -0.4049,  0.1398,  ..., -0.2391, -1.0342,  0.5864],
         [-0.2481, -0.0813, -0.9470,  ..., -1

### **Linear transformation & 여러 head로 나누기**

Multi-head attention 내에서 쓰이는 linear transformation matrix들을 정의합니다.

In [23]:
# Learnable projections for queries, keys and values (created in that order).
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

In [24]:
# Output projection applied after the heads are concatenated back to d_model.
w_0 = nn.Linear(d_model, d_model)

In [36]:
# Self-attention: the same embedded batch is projected into queries, keys and values.
q, k, v = (proj(batch_emb) for proj in (w_q, w_k, w_v))  # each (B, L, d_model)

for projected in (q, k, v):
  print(projected.shape)

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


q, k, v를 `num_heads`개의 head로 차원을 나눈 여러 vector로 만듭니다.

In [26]:
batch_size = q.size(0)
# As in the original paper, the features of each embedded token are split
# into several chunks, one per head.
d_k = d_model // num_heads

# (B, L, d_model) -> (B, L, num_heads, d_k)
q, k, v = (t.view(batch_size, -1, num_heads, d_k) for t in (q, k, v))

for per_head in (q, k, v):
  print(per_head.shape)

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


In [27]:
# Move the head axis ahead of the sequence axis so attention runs per head:
# (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
q, k, v = (t.transpose(1, 2) for t in (q, k, v))

for per_head in (q, k, v):
  print(per_head.shape)

torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])


### **Scaled dot-product self-attention 구현**

각 head에서 실행되는 self-attention 과정입니다.

In [28]:
# Dot product of every query with every key, then scaled by sqrt(d_k).
attn_scores = torch.matmul(q, k.transpose(-2, -1))  # (B, num_heads, L, L)
attn_scores = attn_scores / math.sqrt(d_k)
# Normalise over the key axis so each query's weights sum to 1.
attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)

print(attn_dists)
print(attn_dists.shape)

tensor([[[[0.0846, 0.0487, 0.0480,  ..., 0.0371, 0.0371, 0.0371],
          [0.0474, 0.0475, 0.0419,  ..., 0.0449, 0.0449, 0.0449],
          [0.0416, 0.0432, 0.0746,  ..., 0.0525, 0.0525, 0.0525],
          ...,
          [0.0357, 0.0621, 0.0570,  ..., 0.0577, 0.0577, 0.0577],
          [0.0357, 0.0621, 0.0570,  ..., 0.0577, 0.0577, 0.0577],
          [0.0357, 0.0621, 0.0570,  ..., 0.0577, 0.0577, 0.0577]],

         [[0.0644, 0.0443, 0.0494,  ..., 0.0351, 0.0351, 0.0351],
          [0.0439, 0.0373, 0.0386,  ..., 0.0536, 0.0536, 0.0536],
          [0.0704, 0.0452, 0.0316,  ..., 0.0697, 0.0697, 0.0697],
          ...,
          [0.0511, 0.0473, 0.0432,  ..., 0.0665, 0.0665, 0.0665],
          [0.0511, 0.0473, 0.0432,  ..., 0.0665, 0.0665, 0.0665],
          [0.0511, 0.0473, 0.0432,  ..., 0.0665, 0.0665, 0.0665]],

         [[0.0525, 0.0675, 0.0236,  ..., 0.0370, 0.0370, 0.0370],
          [0.0630, 0.0313, 0.0394,  ..., 0.0355, 0.0355, 0.0355],
          [0.0435, 0.0525, 0.0636,  ..., 0

In [29]:
# Weighted sum of the value vectors using the attention distribution.
attn_values = attn_dists @ v  # (B, num_heads, L, d_k)

print(attn_values.shape)

torch.Size([10, 8, 20, 64])


### **각 head의 결과물 병합**

각 head의 결과물을 concat하고 동일 차원으로 linear transformation합니다.

In [30]:
attn_values = attn_values.transpose(1, 2)  # (B, L, num_heads, d_k)
# contiguous() returns a tensor laid out contiguously in memory; view() needs
# that here because transpose() only changes strides.
attn_values = attn_values.contiguous().view(batch_size, -1, d_model)  # (B, L, d_model)

print(attn_values.shape)

torch.Size([10, 20, 512])


In [31]:
# Final linear transformation of the concatenated heads; shape stays (B, L, d_model).
outputs = w_0(attn_values)

print(outputs)
print(outputs.shape)

tensor([[[-5.6372e-02,  2.0260e-02, -2.0673e-01,  ..., -1.0473e-01,
           3.9225e-02, -1.7363e-02],
         [-4.0734e-02,  3.6285e-02, -1.6315e-01,  ..., -9.6915e-02,
           9.2038e-02,  2.4939e-02],
         [-5.6946e-02,  5.4392e-02, -1.4964e-01,  ..., -7.5285e-02,
           4.9044e-02, -1.8751e-02],
         ...,
         [-5.0095e-02,  7.5420e-02, -1.6180e-01,  ..., -7.3211e-02,
           5.2999e-02, -5.5118e-03],
         [-5.0095e-02,  7.5420e-02, -1.6180e-01,  ..., -7.3211e-02,
           5.2999e-02, -5.5118e-03],
         [-5.0095e-02,  7.5420e-02, -1.6180e-01,  ..., -7.3211e-02,
           5.2999e-02, -5.5118e-03]],

        [[ 5.4670e-02,  9.4746e-02,  6.1076e-02,  ..., -1.5567e-01,
          -1.8943e-02, -9.3740e-02],
         [ 3.8374e-02, -1.0920e-04,  1.8537e-02,  ..., -1.5706e-01,
          -2.7648e-02, -4.2522e-02],
         [ 6.9620e-02,  8.9878e-02,  3.2770e-02,  ..., -1.3834e-01,
           3.0849e-02, -7.5695e-02],
         ...,
         [ 8.5678e-02,  6

### **전체 코드**

위의 과정을 모두 합쳐 하나의 Multi-head attention 모듈을 구현하겠습니다.

In [32]:
class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Inputs are linearly projected to queries, keys and values, split into
  ``num_heads`` heads of size ``d_k = d_model // num_heads``, attended per
  head, concatenated back together and passed through a final linear layer.

  Args:
    embed_dim: Hidden size of the model. When ``None``, the module-level
      ``d_model`` is used, so ``MultiheadAttention()`` keeps working.
    n_heads: Number of attention heads. When ``None``, the module-level
      ``num_heads`` is used.

  Raises:
    ValueError: If the hidden size is not divisible by the number of heads.
  """

  def __init__(self, embed_dim=None, n_heads=None):
    super().__init__()

    # Only the chosen branch of each conditional is evaluated, so the
    # module-level names are needed only when an argument is omitted.
    self.d_model = d_model if embed_dim is None else embed_dim
    self.num_heads = num_heads if n_heads is None else n_heads

    if self.d_model % self.num_heads != 0:
      raise ValueError(
        f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
      )
    self.d_k = self.d_model // self.num_heads

    # Q, K, V learnable matrices
    self.w_q = nn.Linear(self.d_model, self.d_model)
    self.w_k = nn.Linear(self.d_model, self.d_model)
    self.w_v = nn.Linear(self.d_model, self.d_model)

    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(self.d_model, self.d_model)

  def forward(self, q, k, v, mask=None):
    """Run multi-head attention.

    Args:
      q: Query input of shape (B, L_q, d_model).
      k: Key input of shape (B, L_k, d_model).
      v: Value input of shape (B, L_k, d_model).
      mask: Optional mask; see ``self_attention`` for accepted shapes.
        ``None`` (the default) attends to every position.

    Returns:
      Tensor of shape (B, L_q, d_model).
    """
    batch_size = q.shape[0]

    q = self._split_heads(self.w_q(q), batch_size)  # (B, num_heads, L_q, d_k)
    k = self._split_heads(self.w_k(k), batch_size)  # (B, num_heads, L_k, d_k)
    v = self._split_heads(self.w_v(v), batch_size)  # (B, num_heads, L_k, d_k)

    attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L_q, d_k)

    # (B, num_heads, L_q, d_k) => (B, L_q, num_heads, d_k) => (B, L_q, d_model)
    # contiguous() is required because view() cannot follow transposed strides.
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    return self.w_0(attn_values)

  def _split_heads(self, x, batch_size):
    """Reshape (B, L, d_model) into per-head form (B, num_heads, L, d_k)."""
    return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

  def self_attention(self, q, k, v, mask=None):
    """Scaled dot-product attention applied independently to every head.

    Args:
      q: Queries of shape (B, num_heads, L_q, d_k).
      k: Keys of shape (B, num_heads, L_k, d_k).
      v: Values of shape (B, num_heads, L_k, d_k).
      mask: Optional tensor whose zero / False entries mark key positions
        that must receive no attention. Accepted shapes are (B, L_k),
        (B, L_q, L_k), or anything broadcastable to (B, num_heads, L_q, L_k).

    Returns:
      Tensor of shape (B, num_heads, L_q, d_k).
    """
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))  # (B, num_heads, L_q, L_k)

    if mask is not None:
      if mask.dim() == 2:
        mask = mask[:, None, None, :]  # key-padding mask -> (B, 1, 1, L_k)
      elif mask.dim() == 3:
        mask = mask[:, None, :, :]  # per-query mask -> (B, 1, L_q, L_k)
      # Use the most negative finite value rather than -inf so a fully masked
      # row yields a uniform distribution instead of NaN.
      attn_scores = attn_scores.masked_fill(
        mask.logical_not(), torch.finfo(attn_scores.dtype).min
      )

    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)

    attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L_q, d_k)

    return attn_values

In [33]:
# Build the module with the notebook's hyperparameters.
multihead_attn = MultiheadAttention()

# Self-attention: the same embedded batch serves as query, key and value.
outputs = multihead_attn(batch_emb, batch_emb, batch_emb)  # (B, L, d_model)

In [34]:
# Inspect the module output and its shape.
print(outputs)
print(outputs.shape)

tensor([[[-0.0404,  0.0224,  0.2301,  ...,  0.0857,  0.1325,  0.1532],
         [-0.0600,  0.0390,  0.1469,  ...,  0.1620,  0.1024,  0.2262],
         [-0.0850,  0.0875,  0.1645,  ...,  0.0564,  0.1324,  0.1759],
         ...,
         [-0.0516,  0.0557,  0.1825,  ...,  0.1076,  0.1603,  0.1408],
         [-0.0516,  0.0557,  0.1825,  ...,  0.1076,  0.1603,  0.1408],
         [-0.0516,  0.0557,  0.1825,  ...,  0.1076,  0.1603,  0.1408]],

        [[-0.2245, -0.0291,  0.2946,  ..., -0.0619,  0.1280,  0.2630],
         [-0.2048, -0.0114,  0.3183,  ..., -0.0680,  0.1399,  0.2983],
         [-0.1814, -0.0591,  0.3207,  ..., -0.0459,  0.1685,  0.2697],
         ...,
         [-0.1671, -0.0630,  0.2798,  ..., -0.0554,  0.1066,  0.2606],
         [-0.1671, -0.0630,  0.2798,  ..., -0.0554,  0.1066,  0.2606],
         [-0.1671, -0.0630,  0.2798,  ..., -0.0554,  0.1066,  0.2606]],

        [[-0.0944, -0.0834,  0.1834,  ..., -0.0875,  0.0301,  0.2550],
         [-0.1133, -0.0635,  0.2359,  ..., -0