In [2]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

In [3]:
# Toy corpus: five variable-length sequences of token ids, all in [1, vocab_size).
# Id 0 is reserved for padding, so no real sequence contains it.
vocab_size = 100
pad_id = 0

data = [
    [62, 13, 47, 39, 78, 33, 56, 13],
    [60, 96, 51, 32, 90],
    [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
    [66, 88, 98, 47],
    [77, 65, 51, 77, 19, 15, 35, 19, 23],
]

In [6]:
def padding(data, pad_value=None):
    """Right-pad every sequence in ``data`` to the length of the longest one.

    Args:
        data: list of token-id sequences (lists or tuples of ints).
        pad_value: id appended to shorter sequences. Defaults to the
            notebook-level ``pad_id`` so padding stays consistent with the
            ``batch != pad_id`` mask built later.

    Returns:
        ``(padded, max_len)`` where ``padded`` is a NEW list of lists, each of
        length ``max_len``. The input ``data`` is not modified. Empty input
        returns ``([], 0)``.
    """
    if pad_value is None:
        pad_value = pad_id

    # `default=0` makes empty input well-defined instead of raising ValueError.
    max_len = max((len(seq) for seq in data), default=0)
    print(f"Maximum sequence length: {max_len}")

    # Build fresh lists rather than assigning into `data`, so the caller's
    # object is never mutated; `list(seq)` also accepts tuple sequences.
    padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in data]
    return padded, max_len

In [7]:
# Replace `data` with its right-padded version and keep the common length,
# which the no-peek mask below uses as its size.
data, max_len = padding(data)

100%|██████████| 5/5 [00:00<00:00, 17077.79it/s]

Maximum sequence length: 10





In [8]:
# Display the padded batch: shorter rows now end in pad_id tokens.
data

[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]

In [9]:
d_model = 8  # hidden size of the model
num_heads = 2  # number of attention heads
inf = 1e12  # magnitude used to fill masked attention scores before softmax

# Heads split d_model evenly: `view(..., num_heads, d_model // num_heads)` below
# would fail or split the features incorrectly otherwise.
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

# Seed torch so the randomly initialised embedding and projection weights below
# (and therefore every printed tensor) are reproducible on Restart & Run All.
torch.manual_seed(0)

In [10]:
# B: batch size, L: maximum sequence length (after padding).
batch = torch.tensor(data, dtype=torch.long)  # (B, L) token ids
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
batch_emb = embedding(batch)  # (B, L, d_model): one learned vector per token

In [11]:
# Show the embedded batch followed by its (B, L, d_model) shape.
print(batch_emb, batch_emb.shape, sep="\n")

tensor([[[ 0.8786,  0.3163, -0.8188, -1.7411,  0.6010,  1.3551, -0.0304,
          -0.8868],
         [-0.5582, -0.0140,  0.3634,  0.3151,  0.8392,  0.5616, -0.3119,
           0.6164],
         [-1.0502,  1.3600, -0.0739, -1.5370,  0.0528,  0.7068,  0.9794,
           2.1416],
         [-0.0666,  0.7159,  0.7612, -1.5253, -0.8612, -0.0231,  0.2097,
          -0.0408],
         [ 0.3174, -1.7833,  0.8307, -0.3914, -1.4998, -0.2120, -1.2893,
           1.3583],
         [ 0.3277, -0.5655, -0.9004,  1.3814,  1.2269,  0.9255,  0.3268,
           0.3711],
         [-0.8452,  1.6297,  0.4007,  1.4033,  0.1686, -0.2018,  0.2711,
           1.7532],
         [-0.5582, -0.0140,  0.3634,  0.3151,  0.8392,  0.5616, -0.3119,
           0.6164],
         [-0.1733,  1.7511,  0.2618, -0.2619, -0.8064,  0.3460,  0.1049,
          -0.2198],
         [-0.1733,  1.7511,  0.2618, -0.2619, -0.8064,  0.3460,  0.1049,
          -0.2198]],

        [[-0.5780,  0.0299,  1.6646,  1.5522,  1.0711,  0.2238, -1.2

In [14]:
# True where a token is real, False where it is padding; the singleton middle
# axis lets it broadcast over every query position.
padding_mask = batch.ne(pad_id).unsqueeze(dim=1)  # (B, 1, L)

print(padding_mask, padding_mask.shape, sep="\n")

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True,  True,  True,  True,  True, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False]]])
torch.Size([5, 1, 10])


In [26]:
# Lower-triangular "no peek" mask: query position i may attend to keys 0..i only.
nopeak_mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool)).unsqueeze(0)  # (1, L, L)

print(nopeak_mask, nopeak_mask.shape, sep="\n")

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]])
torch.Size([1, 10, 10])


In [27]:
# Broadcast (B, 1, L) against (1, L, L): a key is visible only if it is a real
# token AND is not in the future of the query.
mask = torch.logical_and(padding_mask, nopeak_mask)  # (B, L, L)

print(mask, mask.shape, sep="\n")

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  T

In [29]:
# Learnable d_model -> d_model projections for queries, keys and values, plus an
# output projection (w_0). Created in the same order as before: q, k, v, 0.
w_q, w_k, w_v, w_0 = (nn.Linear(d_model, d_model) for _ in range(4))

In [30]:
# Project the embeddings, split d_model into num_heads chunks of size d_k, and
# move the head axis forward so each head can be processed independently.
batch_size = batch_emb.shape[0]
d_k = d_model // num_heads

q, k, v = (
    proj(batch_emb)                          # (B, L, d_model)
    .view(batch_size, -1, num_heads, d_k)    # (B, L, num_heads, d_k)
    .transpose(1, 2)                         # (B, num_heads, L, d_k)
    for proj in (w_q, w_k, w_v)
)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])


In [34]:
# Scaled dot-product scores: every query against every key, divided by sqrt(d_k).
attn_scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (B, num_heads, L, L)

In [35]:
# Add a head axis so the (B, L, L) mask broadcasts over every head.
masks = mask.unsqueeze(1)  # (B, 1, L, L)

# Out-of-place masked_fill: the in-place `masked_fill_` used before overwrote
# `attn_scores` from the previous cell (both names then pointed at the same
# masked tensor), so the raw scores were silently lost.
masked_attn_scores = attn_scores.masked_fill(~masks, -1 * inf)

print(masked_attn_scores, masked_attn_scores.shape, sep="\n")

tensor([[[[-1.8532e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-1.5706e-01, -1.6180e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-2.9744e-01, -2.7197e-01, -1.7719e-01, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-1.5711e-01, -1.8225e-01,  1.6399e-01,  3.8102e-02, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 1.2096e-01,  2.5698e-01, -1.4434e+00, -9.8080e-01, -3.2418e-01,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 7.1007e-02, -8.7613e-04,  5.9786e-01,  5.3197e-01,  1.4229e-01,
           -2.2755e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-6.8584e-02, -4.9841e-02, -2.2177e-01,  1.2075e-01, -2.1001e-01,
      

In [36]:
# Normalise each query's scores over the key axis; entries filled with
# -inf (-1e12) end up with zero weight.
attn_dists = masked_attn_scores.softmax(dim=-1)  # (B, num_heads, L, L)

print(attn_dists, attn_dists.shape, sep="\n")

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.5012, 0.4988, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.3171, 0.3253, 0.3576, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2188, 0.2134, 0.3017, 0.2660, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.3005, 0.3443, 0.0629, 0.0998, 0.1925, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1423, 0.1325, 0.2411, 0.2257, 0.1529, 0.1056, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1514, 0.1543, 0.1299, 0.1830, 0.1314, 0.1186, 0.1315, 0.0000,
           0.0000, 0.0000],
          [0.1173, 0.1167, 0.1497, 0.1689, 0.1171, 0.0866, 0.1270, 0.1167,
           0.0000, 0.0000],
          [0.1077, 0.1025, 0.1958, 0.1726, 0.1104, 0.0752, 0.1334, 0.1025,
           0.0000, 0.0000],
          [0.1077, 0.1025, 0.1958, 0.1726, 0.1104, 0.0752, 0.1334, 0.1025

In [37]:
# Attention-weighted sum of value vectors, still one result per head.
attn_values = attn_dists @ v  # (B, num_heads, L, d_k)

print(attn_values.shape)

torch.Size([5, 2, 10, 4])


In [39]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        model_dim: hidden size of inputs and outputs. Defaults to the
            notebook-level ``d_model`` so ``MultiheadAttention()`` keeps working.
        n_heads: number of attention heads. Defaults to the notebook-level
            ``num_heads``.

    Raises:
        ValueError: if the hidden size is not divisible by the number of heads.
    """

    def __init__(self, model_dim=None, n_heads=None):
        super().__init__()

        # Resolve the notebook globals only when no explicit value is given, and
        # keep the sizes on the instance so forward() no longer depends on
        # whatever d_model / num_heads / d_k currently hold in the namespace.
        self.d_model = d_model if model_dim is None else model_dim
        self.num_heads = num_heads if n_heads is None else n_heads
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
            )
        self.d_k = self.d_model // self.num_heads

        # Q, K, V learnable matrices
        self.w_q = nn.Linear(self.d_model, self.d_model)
        self.w_k = nn.Linear(self.d_model, self.d_model)
        self.w_v = nn.Linear(self.d_model, self.d_model)

        # Linear transformation for concatenated outputs
        self.w_0 = nn.Linear(self.d_model, self.d_model)

    def _split_heads(self, x, batch_size):
        # (B, L, d_model) -> (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from queries ``q`` to keys ``k`` / values ``v``.

        Args:
            q: (B, L_q, d_model) query inputs.
            k, v: (B, L_k, d_model) key and value inputs.
            mask: optional tensor where True/non-zero means "may attend",
                shaped (B, L_q, L_k) or (B, 1, L_k); a head axis is added
                internally.

        Returns:
            (B, L_q, d_model) tensor after the output projection.
        """
        batch_size = q.shape[0]

        q = self._split_heads(self.w_q(q), batch_size)  # (B, num_heads, L, d_k)
        k = self._split_heads(self.w_k(k), batch_size)  # (B, num_heads, L, d_k)
        v = self._split_heads(self.w_v(v), batch_size)  # (B, num_heads, L, d_k)

        attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L, d_k)
        # (B, num_heads, L, d_k) -> (B, L, num_heads, d_k) -> (B, L, d_model)
        attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.w_0(attn_values)

    def self_attention(self, q, k, v, mask=None):
        """Scaled dot-product attention over already-split heads.

        Masked positions are filled with the most negative finite value of the
        score dtype instead of a hard-coded -1e12. The hard-coded value does not
        fit in float16 and makes masked_fill fail there; with the dtype-specific
        value, fully masked rows still get finite, uniform weights.
        """
        # Scale by the actual head width of the inputs.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))  # (B, num_heads, L, L)

        if mask is not None:
            mask = mask.unsqueeze(1)  # (B, 1, L, L) or (B, 1, 1, L): broadcast over heads
            fill_value = torch.finfo(attn_scores.dtype).min
            # Out-of-place fill; `== 0` accepts both bool and 0/1 integer masks.
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

        attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)

        return torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)

In [40]:
# Self-attention over the embedded batch: queries, keys and values are all
# batch_emb, restricted by the combined padding + no-peek mask.
multihead_attn = MultiheadAttention()
outputs = multihead_attn(q=batch_emb, k=batch_emb, v=batch_emb, mask=mask)  # (B, L, d_model)

In [41]:
# Show the attention output followed by its (B, L, d_model) shape.
print(outputs, outputs.shape, sep="\n")

tensor([[[ 0.6314,  0.4195, -0.1171,  0.0016,  0.0380, -0.2114,  0.1884,
          -0.1033],
         [ 0.4099,  0.4891, -0.0286,  0.0265, -0.1412, -0.1441,  0.1327,
          -0.0787],
         [ 0.3775,  0.6340, -0.1236, -0.0063, -0.2284,  0.0338,  0.3218,
          -0.1715],
         [ 0.4674,  0.4579, -0.2584,  0.0084, -0.0502,  0.1042,  0.3732,
          -0.2337],
         [ 0.2759,  0.2971, -0.0438,  0.0971, -0.1022, -0.0687,  0.0759,
           0.0370],
         [ 0.2449,  0.4944,  0.0694,  0.0766, -0.2990, -0.1418,  0.0251,
           0.0619],
         [ 0.3605,  0.3897, -0.0747, -0.0022, -0.2653,  0.0202,  0.1802,
          -0.1252],
         [ 0.2659,  0.5072, -0.0055,  0.0210, -0.2753, -0.0351,  0.1209,
          -0.0538],
         [ 0.3506,  0.4273, -0.1284,  0.0107, -0.2250,  0.0663,  0.2211,
          -0.1566],
         [ 0.3506,  0.4273, -0.1284,  0.0107, -0.2250,  0.0663,  0.2211,
          -0.1566]],

        [[ 0.3631,  0.2638,  0.0502, -0.0193, -0.1381, -0.1555,  0.0