## Self Attention with Relative Position Representations 논문 실습

- 본 실습은 Attention is all you need (NIPS 2017) 에서 제안한 Transformer Architecture를 기반으로 진행합니다.
- Attention is all you need 에서 제안한 아키텍처 상에서 Self-Attention 모듈만 개선함으로써 성능 개선을 실습합니다.

#### 데이터 전처리 (PreProcessing)
- 허깅페이스 API를 이용해서 대표적인 영어-독어 데이터셋인 **Multi30k** 를 불러옵니다.

In [3]:
from datasets import load_dataset

# Load the Multi30k dataset through the Hugging Face `datasets` API.
dataset = load_dataset("bentrevett/multi30k")

# Keep a separate handle on each split.
train_dataset = dataset['train']
validation_dataset = dataset['validation']
test_dataset = dataset['test']

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
print(train_dataset[0])

{'en': 'Two young, White males are outside near many bushes.', 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}


- **Tokenizer** 및 **Vocab** 생성

In [6]:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

In [7]:
unknown_token = "<unk>"


def initialize_tokenizer() -> Tokenizer:
    """Build an untrained word-level tokenizer.

    The ``WordLevel`` model is configured with ``unknown_token`` as its
    ``unk_token`` and the tokenizer uses ``Whitespace`` pre-tokenization.

    Returns:
        A fresh ``Tokenizer`` instance that still has to be trained.
    """
    word_level_model = WordLevel(unk_token=unknown_token)
    new_tokenizer = Tokenizer(word_level_model)
    new_tokenizer.pre_tokenizer = Whitespace()
    return new_tokenizer


# One independent tokenizer per language (German and English).
de_tokenizer = initialize_tokenizer()
en_tokenizer = initialize_tokenizer()

In [8]:
# Special tokens registered in both vocabularies (the order fixes their ids).
pad_token = "<pad>"
sos_token = "<sos>"
eos_token = "<eos>"
special_tokens = [unknown_token, pad_token, sos_token, eos_token]

# Trainer used to fit the tokenizers.
trainer = WordLevelTrainer(min_frequency=2, special_tokens=special_tokens)

In [9]:
# Train each tokenizer on its own language column of the training split.
train_de = train_dataset['de']
train_en = train_dataset['en']

for corpus, language_tokenizer in ((train_de, de_tokenizer), (train_en, en_tokenizer)):
    language_tokenizer.train_from_iterator(corpus, trainer=trainer)

In [10]:
# Inspect the trained tokenizers: vocabulary sizes first, then a few entries.
de_vocab_size = de_tokenizer.get_vocab_size()
en_vocab_size = en_tokenizer.get_vocab_size()

print("[DE] vocab size: {}".format(de_vocab_size))
print("[EN] vocab size: {}".format(en_vocab_size))

de_sample_tokens = list(de_tokenizer.get_vocab())[:10]
en_sample_tokens = list(en_tokenizer.get_vocab())[:10]

print("[DE] Sample DE vocab tokens: {}".format(de_sample_tokens))
print("[EN] Sample EN vocab tokens: {}".format(en_sample_tokens))

[DE] vocab size: 8060
[EN] vocab size: 6203
[DE] Sample DE vocab tokens: ['Spielkonsole', 'zuläuft', 'Barkeeperin', 'kontrolliert', 'rettet', 'Mädchengruppe', 'Esel', 'Hochzeit', 'verloren', 'oder']
[EN] Sample EN vocab tokens: ['juice', 'pajama', 'watched', 'Hopper', 'architectural', 'onto', 'engaging', 'tortillas', 'away', 'littered']


In [11]:
# Check the id assigned to every special token in both vocabularies.
de_vocab = de_tokenizer.get_vocab()
en_vocab = en_tokenizer.get_vocab()

for special_token in special_tokens:
    de_index = de_vocab[special_token]
    print("[DE] special token: {}, index: {}".format(special_token, de_index))
    en_index = en_vocab[special_token]
    print("[EN] special token: {}, index: {}".format(special_token, en_index))

[DE] special token: <unk>, index: 0
[EN] special token: <unk>, index: 0
[DE] special token: <pad>, index: 1
[EN] special token: <pad>, index: 1
[DE] special token: <sos>, index: 2
[EN] special token: <sos>, index: 2
[DE] special token: <eos>, index: 3
[EN] special token: <eos>, index: 3


- 하이퍼 파라미터 정의

In [96]:
class ModelConfiguration:
    """Container for the hyperparameters used throughout the notebook.

    Every constructor argument is stored unchanged as an attribute of the
    same name.

    Args:
        max_len: Fixed sequence length that inputs are truncated/padded to.
        batch_size: Number of examples per DataLoader batch.
        hidden_size: Width of the embeddings and attention layers.
        ffn_size: Presumably the inner width of the feed-forward sublayer
            (not used in the visible code -- confirm against the model).
        num_heads: Number of attention heads.
        num_layers: Presumably the number of encoder/decoder layers
            (not used in the visible code -- confirm against the model).
        dropout_pb: Dropout probability.
        eps: Epsilon passed to layer normalization.
        max_relative_position: Presumably the clipping distance for relative
            position representations (not used in the visible code).
        src_vocab_size: Size of the source-language vocabulary.
        trg_vocab_size: Size of the target-language vocabulary.
    """

    def __init__(self,
                 max_len: int = 768,
                 batch_size: int = 32,
                 hidden_size: int = 512,
                 ffn_size: int = 2048,
                 num_heads: int = 8,
                 num_layers: int = 6,
                 dropout_pb: float = 0.1,
                 eps: float = 1e-12,
                 max_relative_position: int = 128,
                 src_vocab_size: int = 0,
                 trg_vocab_size: int = 0
                ):
        self.max_len = max_len
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.ffn_size = ffn_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout_pb = dropout_pb
        self.eps = eps
        # Previously accepted but silently dropped; keep it so later code can
        # read it from the configuration.
        self.max_relative_position = max_relative_position
        self.src_vocab_size = src_vocab_size
        self.trg_vocab_size = trg_vocab_size

# The German tokenizer supplies the source vocabulary size and the English
# tokenizer the target vocabulary size; all other settings keep their defaults.
model_config = ModelConfiguration(
    src_vocab_size=de_tokenizer.get_vocab_size(),
    trg_vocab_size=en_tokenizer.get_vocab_size(),
)

- 데이터 전처리
    - 데이터 패딩 등...

In [15]:
# Resolve the ids of the padding / start / end tokens in each vocabulary.
de_pad_id = de_tokenizer.token_to_id(pad_token)
en_pad_id = en_tokenizer.token_to_id(pad_token)

de_sos_id = de_tokenizer.token_to_id(sos_token)
en_sos_id = en_tokenizer.token_to_id(sos_token)

de_eos_id = de_tokenizer.token_to_id(eos_token)
en_eos_id = en_tokenizer.token_to_id(eos_token)

In [16]:
def preprocess(dataset: dict) -> dict:
    """Convert one translation pair into fixed-length model inputs.

    Args:
        dataset: A single example of the form
            ``{"en": "example_en", "de": "example_de"}``.

    Returns:
        A dict with the keys ``encoder_input_ids``, ``encoder_attention_mask``,
        ``decoder_input_ids``, ``decoder_attention_mask`` and ``labels``.
        Every value is a list of ``model_config.max_len`` integers.
    """
    max_len = model_config.max_len

    def fit_to_max_len(token_ids: list, pad_id: int) -> list:
        # Truncate to max_len, then right-pad with pad_id up to max_len.
        missing = max(0, max_len - len(token_ids))
        return token_ids[:max_len] + [pad_id] * missing

    # Text -> token ids (German feeds the encoder, English the decoder).
    source_ids = de_tokenizer.encode(dataset['de']).ids
    target_ids = en_tokenizer.encode(dataset['en']).ids

    # The decoder input starts with <sos>; the labels are the same tokens
    # shifted by one position and terminated with <eos>.
    encoder_input = fit_to_max_len(source_ids, de_pad_id)
    decoder_input = fit_to_max_len([en_sos_id] + target_ids, en_pad_id)
    padded_labels = fit_to_max_len(target_ids + [en_eos_id], en_pad_id)

    # Optional: replace padding in the labels with -100 so that a loss using
    # ignore_index=-100 does not count padded positions.
    labels = [-100 if token == en_pad_id else token for token in padded_labels]

    # Attention masks: 1 for real tokens, 0 for padding.
    encoder_attention_mask = [int(token != de_pad_id) for token in encoder_input]
    decoder_attention_mask = [int(token != en_pad_id) for token in decoder_input]

    return {
        "encoder_input_ids" : encoder_input,
        "encoder_attention_mask" : encoder_attention_mask,
        "decoder_input_ids" : decoder_input,
        "decoder_attention_mask" : decoder_attention_mask,
        "labels" : labels
    }

In [17]:
# Run preprocess over every split and drop the raw text columns afterwards.
raw_text_columns = ['en', 'de']

train_dataset = train_dataset.map(preprocess, remove_columns=raw_text_columns)
validation_dataset = validation_dataset.map(preprocess, remove_columns=raw_text_columns)
test_dataset = test_dataset.map(preprocess, remove_columns=raw_text_columns)

Map: 100%|███████████████████████| 29000/29000 [00:08<00:00, 3621.94 examples/s]
Map: 100%|█████████████████████████| 1014/1014 [00:00<00:00, 3628.54 examples/s]
Map: 100%|█████████████████████████| 1000/1000 [00:00<00:00, 3574.80 examples/s]


- DataLoader 설정

In [19]:
import torch

def collate_function(batch):
    """Stack a list of per-example dicts into one dict of ``torch.long`` tensors.

    Args:
        batch: Non-empty list of dicts that share the keys of ``batch[0]``;
            each value must be convertible by ``torch.tensor`` once stacked.

    Returns:
        A dict with the keys of ``batch[0]``; each value is a tensor whose
        first dimension indexes the examples in ``batch``.
    """
    collated = {}
    for key in batch[0]:
        column = [data[key] for data in batch]
        collated[key] = torch.tensor(column, dtype=torch.long)
    return collated

In [20]:
from torch.utils.data import DataLoader

batch_size = model_config.batch_size

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_function,
)
validation_loader = DataLoader(
    validation_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_function,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_function,
)

In [21]:
# Pull a single training batch and report the shape of each tensor in it.
batch = next(iter(train_loader))

for key in batch:
    print("{}: shape={}".format(key, batch[key].shape))

encoder_input_ids: shape=torch.Size([32, 768])
encoder_attention_mask: shape=torch.Size([32, 768])
decoder_input_ids: shape=torch.Size([32, 768])
decoder_attention_mask: shape=torch.Size([32, 768])
labels: shape=torch.Size([32, 768])


#### 토큰 임베딩
- 해당 실습에서는 사인/코사인 함수 기반의 고정 위치 인코딩(sinusoidal positional encoding)이 아닌, 학습 가능한 위치 임베딩을 이용하여 실습합니다.

In [23]:
import torch

# Select the training device: CUDA when available, otherwise Apple's MPS
# backend, otherwise the CPU. The getattr guard keeps this working on torch
# builds that do not ship `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

print(device)

mps


In [24]:
import torch.nn as nn

class Embeddings(nn.Module):
    """Token embedding plus learned absolute positional embedding.

    The two embeddings are summed, layer-normalized and passed through
    dropout.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_size: Embedding width.
        max_len: Number of rows in the positional embedding table, i.e. the
            longest sequence the module can embed.
        dropout_pb: Dropout probability applied to the final embeddings.
        eps: Epsilon used by the layer normalization.
    """

    def __init__(self, vocab_size: int, hidden_size: int, max_len: int, dropout_pb: float, eps: float):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.positional_embedding = nn.Embedding(max_len, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=eps)
        self.dropout = nn.Dropout(dropout_pb)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed a batch of token ids.

        Args:
            input_ids: Integer tensor of shape (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        # Build position ids on the same device as the input instead of
        # relying on a notebook-level `device` global, so the module works
        # wherever its inputs live.
        sequence_len = input_ids.size(1)
        positional_ids = torch.arange(sequence_len, device=input_ids.device)
        positional_ids = positional_ids.unsqueeze(0).expand_as(input_ids)  # (batch_size, seq_len)

        # Embedding lookups: both are (batch_size, seq_len, hidden_size).
        token_embeddings = self.token_embedding(input_ids)
        positional_embeddings = self.positional_embedding(positional_ids)

        # Add -> LayerNorm -> Dropout
        embeddings = token_embeddings + positional_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

- 임베딩 결과 확인

In [26]:
# 임베딩 검증
# Sanity-check the embedding layer on one real training batch.
embedding_layer = Embeddings(
    model_config.src_vocab_size,
    model_config.hidden_size,
    model_config.max_len,
    model_config.dropout_pb,
    model_config.eps,
)
embedding_layer = embedding_layer.to(device)

batch = next(iter(train_loader))
input_ids = batch['encoder_input_ids'].to(device)

embeddings = embedding_layer(input_ids)

# Report the input and output shapes.
input_shape = input_ids.shape
embedding_shape = embeddings.shape
print("Input Shape: {}".format(input_shape))
print("Embedding Shape: {}".format(embedding_shape))

Input Shape: torch.Size([32, 768])
Embedding Shape: torch.Size([32, 768, 512])


#### Multi-Head Attention 구현

- Transformer 아키텍처의 핵심인 멀티 헤드 어텐션을 구현합니다.
    - **scaled-dot-product attention** 구현
    - **Attention Head** 구현
    - Attention Head를 조합하여 **Multi-Head Attention** 구현
- **Scaled dot-product Attention with RPR (Relative Position Representations)** 을 이용하여 상대 위치 임베딩을 적용합니다.

In [92]:
# playground: build the relative-position index matrix for a toy sequence
seq_len = 4
# Largest distance that can occur in a sequence of this length.
max_relative_position = seq_len - 1
position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
# (seq_len, seq_len): entry [i, j] is the signed distance j - i
relative_positions = position_ids[None, :] - position_ids[:, None]

# Clip to [-max_relative_position, max_relative_position], then shift by
# max_relative_position so every entry is a non-negative index. This gives
# 2 * seq_len - 1 distinct values. clamp() needs explicit bounds and returns
# a new tensor, so the result has to be assigned back.
relative_positions = relative_positions.clamp(-max_relative_position, max_relative_position)
relative_positions = relative_positions + max_relative_position

print(relative_positions)

tensor([[3, 4, 5, 6],
        [2, 3, 4, 5],
        [1, 2, 3, 4],
        [0, 1, 2, 3]], device='mps:0')


- **Attention Head 구현**

In [84]:
import torch.nn.functional as F
import torch

class AttentionHead(nn.Module):
    """A single attention head.

    Projects query, key and value from ``hidden_dim`` to ``head_dim`` with
    separate linear layers and applies scaled dot-product attention.

    Args:
        hidden_dim: Size of the last dimension of the incoming tensors.
        head_dim: Size of the last dimension of this head's output.
    """

    def __init__(self, hidden_dim: int, head_dim: int):
        super().__init__()
        self.query_projection = nn.Linear(hidden_dim, head_dim)
        self.key_projection = nn.Linear(hidden_dim, head_dim)
        self.value_projection = nn.Linear(hidden_dim, head_dim)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor=None) -> torch.Tensor:
        """Project the inputs and return the attention output.

        Args:
            query: (batch_size, query_len, hidden_dim)
            key: (batch_size, key_len, hidden_dim)
            value: (batch_size, key_len, hidden_dim)
            mask: Optional tensor broadcastable to
                (batch_size, query_len, key_len); positions equal to 0 are
                masked out.

        Returns:
            Tensor of shape (batch_size, query_len, head_dim). The attention
            weights are computed but not returned.
        """
        Q = self.query_projection(query)
        K = self.key_projection(key)
        V = self.value_projection(value)

        attention_output, _ = self.scaled_dot_product_attention(Q, K, V, mask)

        return attention_output

    def scaled_dot_product_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor=None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute softmax(Q K^T / sqrt(d_head)) V.

        Args:
            query: (batch_size, query_len, d_head)
            key: (batch_size, key_len, d_head)
            value: (batch_size, key_len, d_head)
            mask: Optional tensor broadcastable to
                (batch_size, query_len, key_len); positions equal to 0 are
                masked out.

        Returns:
            A tuple ``(output, attention_weights)`` with shapes
            (batch_size, query_len, d_head) and
            (batch_size, query_len, key_len).
        """
        # Size of the per-head feature dimension.
        dim_k = query.size(-1)

        # (batch, query_len, d_head) x (batch, d_head, key_len)
        #   -> (batch, query_len, key_len), scaled by sqrt(dim_k)
        scores = torch.bmm(query, key.transpose(1, 2)) / (dim_k ** 0.5)

        if mask is not None:
            # Fill masked positions with the most negative finite value rather
            # than -inf: a row that is masked everywhere would otherwise turn
            # into NaN after softmax (and poison gradients), whereas masked
            # entries in partially masked rows still get exactly zero weight.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Normalize over the key dimension.
        attention_weights = F.softmax(scores, dim=-1)

        # (batch, query_len, key_len) x (batch, key_len, d_head)
        #   -> (batch, query_len, d_head)
        output = torch.bmm(attention_weights, value)

        return output, attention_weights

    def scaled_dot_product_attention_with_rpr(
        self, 
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        relative_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # TODO: not implemented yet -- this placeholder currently returns None
        # despite the annotated tuple return type. Intended to hold the
        # relative-position variant of scaled dot-product attention.
        pass

- **Multi-Head Attention** 구현

In [87]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``AttentionHead`` modules.

    Each head maps its inputs to ``hidden_dim // num_heads`` features; the
    head outputs are concatenated along the last dimension and passed through
    a final linear layer.

    Args:
        hidden_dim: Size of the last dimension of the inputs and the output.
        num_heads: Number of attention heads; must divide ``hidden_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_dim``.
    """

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        # The concatenated heads must add up to hidden_dim again, otherwise
        # output_linear would only fail later, inside forward().
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                "hidden_dim ({}) must be divisible by num_heads ({})".format(hidden_dim, num_heads)
            )
        head_dim = hidden_dim // num_heads

        self.head_list = nn.ModuleList([AttentionHead(hidden_dim, head_dim) for _ in range(num_heads)])
        self.output_linear = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor=None) -> torch.Tensor:
        """Run every head on the same inputs and merge the results.

        Args:
            query, key, value: Tensors whose last dimension is ``hidden_dim``.
            mask: Optional mask forwarded unchanged to every head.

        Returns:
            Tensor with the same leading dimensions as the head outputs and
            ``hidden_dim`` as its last dimension.
        """
        head_outputs = [head(query, key, value, mask) for head in self.head_list]
        concatenated_attention = torch.cat(head_outputs, dim=-1)
        output = self.output_linear(concatenated_attention)

        return output

- 멀티 헤드 어텐션 테스트

In [90]:
# Smoke-test the multi-head attention module on the embeddings computed above.
multi_head_attention = MultiHeadAttention(
    hidden_dim=model_config.hidden_size,
    num_heads=model_config.num_heads,
)
multi_head_attention = multi_head_attention.to(device)

# Self-attention: the same tensor is used as query, key and value.
attn_output = multi_head_attention(embeddings, embeddings, embeddings)

print(attn_output.size())

torch.Size([32, 768, 512])
