<a href="https://colab.research.google.com/github/Taiga10969/Lecture-Transformer/blob/main/copy_code_Transformer_05_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 写経して理解するTransformer_05：Transformer

01〜04で作成してきた各機構を用いてTransformerを構築していく．

## 必要ライブラリのインポート

In [60]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

## 各クラスの定義

In [61]:
class Embedder(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model) # torch.nn の nn.Embedding を利用
        self.d_model = d_model

    def forward(self, x):
        # x：torch.Size([1, 7]) << [batch_size, src_len]
        embedded = self.embedding(x)
        # embedded : torch.Size([1, 7, 128]) << [batch_size, src_len, d_model]
        return embedded

In [62]:
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len):
        super().__init__()

        self.d_model = d_model

        pe = torch.zeros(max_seq_len, d_model)

        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2*i) / d_model)))
                pe[pos, i+1] = math.cos(pos / (10000 ** ((2*i) / d_model)))

        self.pe = nn.Parameter(pe, requires_grad=False)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        ret = math.sqrt(self.d_model) * x + self.pe[:seq_len, :].unsqueeze(0)
        return ret

    def get_pe(self, x):
        batch_size, seq_len, _ = x.size()
        return self.pe[:seq_len, :].unsqueeze(0)

In [63]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        query = self.W_q(query)
        key = self.W_k(key)
        value = self.W_v(value)

        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)

        attention_score = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.depth, dtype=torch.float32))

        if mask is not None:
            attention_score += mask

        attention_weights = F.softmax(attention_score, dim=-1)
        attention = torch.matmul(attention_weights, value)
        attention = attention.permute(0, 2, 1, 3).contiguous()
        attention = attention.view(batch_size, -1, self.d_model)

        return self.W_o(attention)

In [64]:
class FeedForwardNetwork(nn.Module):

  def __init__(self, dim, hidden_dim, dropout=0.1):
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    self.linear_1 = nn.Linear(dim, hidden_dim)
    self.relu = nn.ReLU()
    self.linear_2 = nn.Linear(hidden_dim, dim)

  def forward(self, x):
    x = self.linear_1(x)
    x = self.relu(x)
    x = self.dropout(x)
    x = self.linear_2(x)
    return x

## Transformer Encoder の構築
Encoderブロックを複数積層してEncoderを構築するため，<br>
EncoderBlockをクラスで定義し，そのクラスをインスタンスとして複数呼び出すEncoderクラスを作成する．

In [93]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, hidden_dim, dropout=0.1):
        super().__init__()

        self.MHA = MultiHeadAttention(d_model, num_heads)
        self.layer_norm1 = nn.LayerNorm([d_model])
        self.layer_norm2 = nn.LayerNorm([d_model])
        self.FFN = FeedForwardNetwork(d_model, hidden_dim)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
      # self-Attention >> Q，K，Vに同じデータを流す
      Q = K = V = _x = x

      x = self.MHA(Q, K, V, mask)
      x = self.dropout_1(x)
      x = x + _x # Add
      x = self.layer_norm1(x) # Norm

      _x = x

      x = self.FFN(x)
      x = self.dropout_2(x)
      x = x + _x #Add
      x = self.layer_norm2(x) # Norm

      return x

In [125]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, max_seq_len, hidden_dim, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embed = Embedder(vocab_size, d_model)
        self.PE = PositionalEncoder(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)
        self.EncoderBlocks = nn.ModuleList([EncoderBlock(d_model, num_heads, hidden_dim) for _ in range(6)])

    def forward(self, x, mask=None):
        x = self.embed(x)
        x = x*(self.d_model**0.5)
        x = self.PE(x)
        x = self.dropout(x)

        if mask is not None:
            mask = torch.where(mask == 1, torch.tensor(0), torch.tensor(float('-inf')))

        for i in range(6):
            x = self.EncoderBlocks[i](x, mask)
        return x

## Transformer Decoder の構築
Decoderブロックを複数積層してDecoderを構築するため，<br>
DecoderBlockをクラスで定義し，そのクラスをインスタンスとして複数呼び出すDecoderクラスを作成する．

In [110]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, hidden_dim, dropout=0.1):
        super().__init__()

        self.MHA_1 = MultiHeadAttention(d_model, num_heads)
        self.MHA_2 = MultiHeadAttention(d_model, num_heads)

        self.layer_norm_1 = nn.LayerNorm([d_model])
        self.layer_norm_2 = nn.LayerNorm([d_model])
        self.layer_norm_3 = nn.LayerNorm([d_model])

        self.FFN = FeedForwardNetwork(d_model, hidden_dim)

        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)

    def forward(self, tgt, encoder_output, decoder_mask=None, encoder_mask=None):
        # Decoder の一つ目のMHAは，self AttentionのMasked Multi-Head Attentionとなっており，入力QKVは，全て同じデータ
        Q = K = V = _x = tgt

        # 情報の先読みを防止するマスクの作成
        seq_len = tgt.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len) * -float('inf'), diagonal=1)
        mask = mask.unsqueeze(0).unsqueeze(1) # バッチとヘッドの次元を追加

        if decoder_mask is not None:
            decoder_mask = torch.where(decoder_mask == 1, torch.tensor(0), torch.tensor(float('-inf')))
            mask = mask + decoder_mask

        x = self.MHA_1(Q, K, V, mask)
        x = self.dropout_1(x)
        x = x + _x
        x = self.layer_norm_1(x)

        # ここまでアーキテクチャ図のDecoderの下半分
        # ここからアーキテクチャ図のDecoderの上半分

        Q = x # queryには下半分からの出力を
        _x = x
        K = V = encoder_output # key,valueにはencoderからの出力を

        x = self.MHA_2(Q, K, V, encoder_mask)
        x = self.dropout_2(x)
        x = x + _x
        x = self.layer_norm_2(x)

        _x = x

        x = self.FFN(x)
        x = self.dropout_3(x)
        x = x + _x
        x = self.layer_norm_3(x)

        return x

In [113]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, max_seq_len, hidden_dim, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embed = Embedder(vocab_size, d_model)

        self.PE = PositionalEncoder(d_model, max_seq_len)

        self.DecoderBlocks = nn.ModuleList([DecoderBlock(d_model, num_heads, hidden_dim) for _ in range(6)])
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, encoder_output, mask=None, encoder_mask=None):
        x = self.embed(tgt)
        x = x*(self.d_model**0.5)
        x = self.PE(x)
        x = self.dropout(x)

        for i in range(6):
            x = self.DecoderBlocks[i](x, encoder_output, mask, encoder_mask)

        x = self.linear(x)
        return x

## Transformer実装

In [114]:
class Transformer(nn.Module):
    def __init__(self, vocab_size_jp, vocab_size_en, d_model, num_heads, max_seq_len, hidden_dim):
        super().__init__()
        self.encoder = Encoder(vocab_size_jp, d_model, num_heads, max_seq_len, hidden_dim)
        self.decoder = Decoder(vocab_size_en, d_model, num_heads, max_seq_len, hidden_dim)

    def forward(self, src, tgt, src_mask, tgt_mask):
        encoder_output = self.encoder(src, src_mask)
        output = self.decoder(tgt, encoder_output, tgt_mask, src_mask)
        return output

## Transformerの動作確認

In [117]:
from transformers import GPT2Tokenizer, AutoTokenizer
# 文字列をトークンのID列に変換するトークナイザーはGPT-2のものを使用します
# 日本語用と英語用でそれぞれ定義
tokenizer_jp = AutoTokenizer.from_pretrained("colorfulscoop/gpt2-small-ja")
tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
                                          #add_bos_token = True
                                          )
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# 入力するtextを定義
text_jp = '人工知能 (AI) は、推論・判断などの知的な機能を備えたコンピュータ・システムのことです。'

# 出力を期待するtext (翻訳結果) を定義
text = "Artificial intelligence (AI) is a computer system equipped with intelligent functions such as reasoning and judgment."

# 入力するtextをトークナイザーを用いて数値(id)化する
inputs_jp = tokenizer_jp(text_jp,
                         padding=True,
                         truncation=True,
                         return_tensors='pt',)
outputs_en = tokenizer(text,
                       padding=True,
                       truncation=True,
                       return_tensors='pt',)

# モデルに入力するID列を表示・形状を確認

print("inputs_jp['input_ids']　: ", inputs_jp['input_ids'])
print("inputs_jp['input_ids'].shape　: ", inputs_jp['input_ids'].shape)
print("inputs_jp['input_ids'] token : ", tokenizer_jp.convert_ids_to_tokens(inputs_jp['input_ids'][0]))

print("outputs_en['input_ids']　: ", outputs_en['input_ids'])
print("outputs_en['input_ids'].shape　: ", outputs_en['input_ids'].shape)
print("outputs_en['input_ids'] token : ", tokenizer.convert_ids_to_tokens(outputs_en['input_ids'][0]))

inputs_jp['input_ids']　:  tensor([[27193,    71,  5257,    12,    19,     9,     6, 26805,    11,  3731,
           134,   808,   141,   693,  8238,  2606,    11, 10149,   641,   318,
             7]])
inputs_jp['input_ids'].shape　:  torch.Size([1, 21])
inputs_jp['input_ids'] token :  ['人工知能', '▁(', 'AI', ')', '▁', 'は', '、', '推論', '・', '判断', 'などの', '知', '的な', '機能', 'を備えた', 'コンピュータ', '・', 'システムの', 'ことで', 'す', '。']
outputs_en['input_ids']　:  tensor([[ 8001,  9542,  4430,   357, 20185,     8,   318,   257,  3644,  1080,
         10911,   351, 12661,  5499,   884,   355, 14607,   290,  8492,    13]])
outputs_en['input_ids'].shape　:  torch.Size([1, 20])
outputs_en['input_ids'] token :  ['Art', 'ificial', 'Ġintelligence', 'Ġ(', 'AI', ')', 'Ġis', 'Ġa', 'Ġcomputer', 'Ġsystem', 'Ġequipped', 'Ġwith', 'Ġintelligent', 'Ġfunctions', 'Ġsuch', 'Ġas', 'Ġreasoning', 'Ġand', 'Ġjudgment', '.']


In [135]:
vocab_size_jp = tokenizer_jp.vocab_size
vocab_size_en = tokenizer_jp.vocab_size
d_model = 128
num_heads = 4
max_seq_len = 256
hidden_dim = 256

model = Transformer(vocab_size_jp = vocab_size_jp,
                    vocab_size_en = vocab_size_en,
                    d_model = d_model,
                    num_heads = num_heads,
                    max_seq_len = max_seq_len,
                    hidden_dim = hidden_dim)

output = model(src = inputs_jp['input_ids'],
               tgt = outputs_en['input_ids'],
               src_mask = inputs_jp['attention_mask'],
               tgt_mask = outputs_en['attention_mask'])

print("output.shape : ", output.shape)

outputs_IDs = torch.argmax(output, dim=-1)
print("outputs_IDs : ", outputs_IDs)

# モデルが出力した数値(id)を各単語に逆変換
generated_text = tokenizer.decode(outputs_IDs[0])
print("generated_text : ", generated_text)
print("model_output token : ", tokenizer.convert_ids_to_tokens(outputs_IDs[0]))

output.shape :  torch.Size([1, 20, 32000])
outputs_IDs :  tensor([[10013, 10388,    88, 15909, 30198, 28016, 24371, 11900,  8524, 31477,
         10796, 22039,   303, 14237, 28863,  5878,  8341, 28826, 27332, 31768]])
generated_text :   visiting contraryy adjacent Films Whereas Console /*then squads coins"><ve summitornings 2001 listsseed �Init
model_output token :  ['Ġvisiting', 'Ġcontrary', 'y', 'Ġadjacent', 'ĠFilms', 'ĠWhereas', 'ĠConsole', 'Ġ/*', 'then', 'Ġsquads', 'Ġcoins', '"><', 've', 'Ġsummit', 'ornings', 'Ġ2001', 'Ġlists', 'seed', 'Ġï', 'Init']


現状は何も学習していない乱数によるパラメータなので，期待している次単語予測による文章が生成できていないが，<br>
Transformerの大まかなデータの流れはこのようになっている．<br>
<br>
このプログラムを参考に，色々データの中身や形状をPrint()して表示させてみて，理解につなげてみよう！<br>
<br>
以上<br>
<br>
中部大学 機械知覚&ロボティクスグループ（藤吉研究室）<br>
増田 大河
