<a href="https://colab.research.google.com/github/Tyanakai/transformer_from_scratch/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer訓練
本ノートブックでは、スクラッチ実装したTransformerを訓練します。<br>
訓練データとして、斎藤 康毅氏著「ゼロから作るDeep Learning ②」で使用されていた、[こちらの日付変換データ](https://github.com/oreilly-japan/deep-learning-from-scratch-2/blob/master/dataset/date.txt)をお借りしました。<br>
多様な形式で日付を表現する文字列から、YYYY-MM-DD形式の文字列を予測するタスクとなっています。



## 環境、事前準備
本ノートブックは、上記の日付データをGoogle Driveの`/portfolio/transformer_from_scratch`フォルダ下に保存した上で、Google Colaboratoryでの実行を想定しています。<br>

In [1]:
import os

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K

In [2]:
# Mount Google Drive so that the data file under DRIVE (defined in the next
# cell) can be read. The two lines are left commented out; uncomment them on a
# Colab runtime where Drive has not been mounted yet.
# from google.colab import drive
# drive.mount('/content/drive')

In [6]:
# Google Drive folder that is expected to contain date.txt.
DRIVE = "/content/drive/MyDrive/portfolio/transformer_from_scratch"

## load & preprocess data

In [7]:
# Read the whole dataset into memory as a single string.
# The encoding is given explicitly so the result does not depend on the
# locale default of the machine running the notebook.
with open(os.path.join(DRIVE, "date.txt"), mode="r", encoding="utf-8") as f:
    text = f.read()

In [8]:
# Records are separated by "\n"; show the first 200 characters.
text[:200]

'september 27, 1994           _1994-09-27\nAugust 19, 2003              _2003-08-19\n2/10/93                      _1993-02-10\n10/31/90                     _1990-10-31\nTUESDAY, SEPTEMBER 25, 1984  _1984-0'

In [9]:
# Each record is fixed-width: the free-form date (space-padded on the right)
# followed by an 11-character target of the form "_YYYY-MM-DD".
TARGET_WIDTH = 11

encoder_text = []
decoder_text = []

# splitlines() plus the blank-line filter keeps every record whether or not
# the file ends with a newline; split("\n")[:-1] would silently drop the last
# record when the trailing newline is missing.
for line in text.splitlines():
    if not line.strip():
        continue
    # Source side: lower-cased; trailing spaces are kept so that every source
    # string has the same length.
    encoder_text.append(line[:-TARGET_WIDTH].lower().lstrip())
    # Target side: append "." as the final character
    # (presumably the end-of-sequence marker -- confirm against the decoding code).
    decoder_text.append(line[-TARGET_WIDTH:] + ".")

In [10]:
# Source (input) strings
encoder_text[:5]

['september 27, 1994           ',
 'august 19, 2003              ',
 '2/10/93                      ',
 '10/31/90                     ',
 'tuesday, september 25, 1984  ']

In [11]:
# Target (ground-truth) strings
decoder_text[:5]

['_1994-09-27.',
 '_2003-08-19.',
 '_1993-02-10.',
 '_1990-10-31.',
 '_1984-09-25.']

In [12]:
# Number of parsed examples
len(encoder_text)

50000

## tokenizer

In [13]:
class Tokenizer():
    """
    Character-level tokenizer built from a fixed list of strings.

    On construction every string in `text_list` has its trailing whitespace
    replaced by the pad character, is split into characters, and every
    distinct character is assigned an integer id (ids follow the sorted order
    returned by np.unique).

    Attributes:
        text_list: the strings passed to the constructor
        char_list: one list of characters per text, trailing whitespace padded
        unique_char: np.ndarray holding the distinct characters of char_list
        id_char_dict: dict mapping id -> character
        char_id_dict: dict mapping character -> id
    """

    # Character that stands in for trailing whitespace.
    PAD_CHAR = "＠"

    def __init__(self, text_list):
        self.text_list = text_list
        # Split every text into a list of characters.
        self.create_char_list()

        # Assign an id to every character that occurs.
        self.create_char_id_dict()


    def pad_text(self, text):
        """
        Replace the trailing whitespace of `text` with the pad character.

        Only trailing whitespace is touched: rstrip() is used rather than
        strip(), so leading whitespace does not shift the cut position.

        Args:
            text: str
        Returns:
            str of the same length as `text`
        """
        last_char_idx = len(text.rstrip())
        return text[:last_char_idx] + self.PAD_CHAR * (len(text) - last_char_idx)


    def create_char_list(self):
        """Pad every text of text_list and split it into a list of characters."""
        self.char_list = [list(self.pad_text(text)) for text in self.text_list]


    def create_char_id_dict(self):
        """Build the id <-> character lookup tables from char_list."""
        self.id_char_dict = dict()
        self.char_id_dict = dict()
        # Flatten before np.unique so the vocabulary does not depend on all
        # texts having the same length.
        flat_chars = [c for chars in self.char_list for c in chars]
        self.unique_char = np.unique(flat_chars)

        for char_id, c in enumerate(self.unique_char):
            self.id_char_dict[char_id] = c
            self.char_id_dict[c] = char_id


    def attention_mask(self):
        """
        Build the attention mask for the texts given at construction.

        Returns:
            np.ndarray with 1 where the character is not the pad character
            and 0 where it is
        """
        attention_mask = []
        for line in self.char_list:
            chars = np.array(line)
            attention_mask.append((chars != self.PAD_CHAR) * 1)
        return np.array(attention_mask)


    def tokenize(self, text_list):
        """
        Convert texts to arrays of character ids.

        The texts are used as given (they are not passed through pad_text),
        so trailing spaces are looked up as ordinary space characters.

        Args:
            text_list: iterable of str; the texts are expected to share one
                       length so that the result is a rectangular array
        Returns:
            np.ndarray of ids, one row per text
        Raises:
            KeyError: if a character was not seen at construction
        """
        token_list = []
        for text in text_list:
            token_list.append([self.char_id_dict[c] for c in text])

        return np.array(token_list)


    def detokenize(self, token_list):
        """
        Convert sequences of ids back to characters.

        Args:
            token_list: iterable of id sequences
        Returns:
            list of lists of characters
        """
        return [[self.id_char_dict[t] for t in line] for line in token_list]

In [14]:
# Fit one character-level tokenizer per side and derive token ids, attention
# masks, vocabulary size and sequence length from it.
encoder_tokenizer = Tokenizer(encoder_text)
encoder_token_list = encoder_tokenizer.tokenize(encoder_text)
encoder_attention_mask = encoder_tokenizer.attention_mask()
encoder_num_char = encoder_tokenizer.unique_char.shape[0]
encoder_max_len = encoder_token_list.shape[1]

decoder_tokenizer = Tokenizer(decoder_text)
decoder_token_list = decoder_tokenizer.tokenize(decoder_text)
decoder_attention_mask = decoder_tokenizer.attention_mask()
decoder_num_char = decoder_tokenizer.unique_char.shape[0]
# The decoder is fed the target without its last character (see the
# decoder_input_ids slices below), hence the -1.
decoder_max_len = decoder_token_list.shape[1] - 1

# Number of trailing examples held out as validation data.
NUM_VALID = 5000


def _split_token_dict(rows):
    """Return the model-input dict for the examples selected by the slice `rows`."""
    return {
        "encoder_input_ids": encoder_token_list[rows],
        "encoder_attention_mask": encoder_attention_mask[rows],
        "decoder_input_ids": decoder_token_list[rows, :-1],
        "decoder_attention_mask": decoder_attention_mask[rows, :-1],
    }


# Model inputs: everything except the last NUM_VALID examples is training data.
tr_token_dict = _split_token_dict(slice(None, -NUM_VALID))
val_token_dict = _split_token_dict(slice(-NUM_VALID, None))

# Targets: the decoder sequence shifted by one token relative to its input.
tr_decoder_output_ids = decoder_token_list[:-NUM_VALID, 1:]
val_decoder_output_ids = decoder_token_list[-NUM_VALID:, 1:]

In [15]:
# The decoder input and the target tokens are offset from each other by one token.
decoder_tokenizer.detokenize([decoder_token_list[0, :-1], tr_decoder_output_ids[0]])

[['_', '1', '9', '9', '4', '-', '0', '9', '-', '2', '7'],
 ['1', '9', '9', '4', '-', '0', '9', '-', '2', '7', '.']]

In [16]:
# Print the array shape of every training input.
for key, v in tr_token_dict.items():
    print(f"{key}, {v.shape}")

encoder_input_ids, (45000, 29)
encoder_attention_mask, (45000, 29)
decoder_input_ids, (45000, 11)
decoder_attention_mask, (45000, 11)


In [17]:
# Encoder side: maximum input length and number of distinct characters
encoder_max_len, encoder_num_char

(29, 36)

In [18]:
# Decoder side: maximum input length and number of distinct characters
decoder_max_len, decoder_num_char

(11, 13)

## tf.data.Dataset

In [19]:
def get_dataset(x, y=None, dataset="valid", batch_size=64, shuffle_buffer=512):
    """
    Wrap features and labels in a batched, prefetched tf.data.Dataset.

    Args:
        x: features accepted by tf.data.Dataset.from_tensor_slices
           (in this notebook a dict of np.ndarray keyed by model input name)
        y: labels aligned with x along the first axis
        dataset: "train" enables shuffling; any other value keeps the order
        batch_size: number of examples per batch
        shuffle_buffer: buffer size passed to Dataset.shuffle when
                        dataset == "train"
    Returns:
        tf.data.Dataset yielding (x, y) batches
    """
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if dataset == "train":
        # Shuffling happens inside a sliding buffer, not over the whole dataset.
        ds = ds.shuffle(shuffle_buffer)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    return ds

In [20]:
# Training dataset (shuffled) and validation dataset (original order).
tr_ds = get_dataset(x=tr_token_dict, y=tr_decoder_output_ids, dataset="train")
val_ds = get_dataset(x=val_token_dict, y=val_decoder_output_ids, dataset="valid")

In [21]:
# Look at a single batch. take(1) stops after the first batch instead of
# materialising every batch of the dataset into a list just to index [0].
next(iter(tr_ds.take(1).as_numpy_iterator()))

({'decoder_attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
      

## model

### layers

In [22]:
class EncoderSelfAttention(tf.keras.layers.Layer):
    """
    Multi-head self attention for the encoder side.
    Attributes:
        weight_dim: total width of the query/key/value projections, split
                    evenly across the heads (int)
        num_heads: number of attention heads (int)
    Raises:
        ValueError: if num_heads is not positive or does not divide weight_dim
    """

    def __init__(self, weight_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # split_transpose reshapes weight_dim into (num_heads, -1); fail early
        # with a clear message instead of a reshape error at call time.
        if num_heads <= 0 or weight_dim % num_heads != 0:
            raise ValueError(
                f"weight_dim ({weight_dim}) must be divisible by "
                f"a positive num_heads ({num_heads})")
        self.weight_dim = weight_dim
        self.num_heads = num_heads


    def split_transpose(self, x):
        """
        Split the last axis of x into heads and move the head axis forward
        so the per-head matrix products can be batched.
        Args:
            x: tensor (batch_size, max_length, weight_dim)
        Returns:
            x: tensor (batch_size, num_heads, max_length, weight_dim/num_heads)
        """
        x_shape = tf.shape(x)
        x = tf.reshape(x, [x_shape[0], x_shape[1], self.num_heads, -1])
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        return x


    def create_mask_for_pad(self, attention_mask1, attention_mask2):
        """
        Build the mask that makes attention ignore padding positions.
        Args:
            attention_mask1: tensor (batch_size, max_length1)  pad position = 0
            attention_mask2: tensor (batch_size, max_length2)  pad position = 0
        Returns:
            bool tensor (batch_size, num_heads, max_length1, max_length2),
            True wherever the row position or the column position is padding
        Note:
            #1  build a (batch_size, max_length1, max_length2) mask
            #2  repeat it once per head
            #3  invert 0 and 1
        """
        # Reshape to mask1: (batch_size, max_length1, 1)
        #            mask2: (batch_size, 1, max_length2)
        mask1 = tf.reshape(attention_mask1, [tf.shape(attention_mask1)[0], -1, 1])
        mask2 = tf.reshape(attention_mask2, [tf.shape(attention_mask2)[0], 1, -1])

        p_mask = mask1 * mask2  #1
        p_mask = tf.repeat(p_mask[:, None, :, :], self.num_heads, axis=1)  #2
        p_mask = 1 - p_mask  #3
        return tf.cast(p_mask, tf.bool)


    def build(self, input_shape):
        """Create the query/key/value/output projection matrices."""
        hidden_dim = input_shape[-1]
        # `name` is passed by keyword: the first positional parameter of
        # Layer.add_weight is not the name in every Keras version.
        self.wq = self.add_weight(
            name="wq", shape=[hidden_dim, self.weight_dim])
        self.wk = self.add_weight(
            name="wk", shape=[hidden_dim, self.weight_dim])
        self.wv = self.add_weight(
            name="wv", shape=[hidden_dim, self.weight_dim])
        self.wo = self.add_weight(
            name="wo", shape=[self.weight_dim, hidden_dim])
        super().build(input_shape)


    def call(self, input, attention_mask):
        """
        Args:
            input: tensor (batch_size, max_length, hidden_dim)
            attention_mask: tensor (batch_size, max_length) pad position = 0

        Returns:
            tensor (batch_size, max_length, hidden_dim)
        """
        q = tf.matmul(input, self.wq)
        k = tf.matmul(input, self.wk)
        v = tf.matmul(input, self.wv)

        q = self.split_transpose(q)
        k = self.split_transpose(k)
        v = self.split_transpose(v)

        logit = tf.matmul(q, k, transpose_b=True)

        # Cast the mask to the dtype of the scores (rather than a fixed
        # float32) so the addition also works for other compute dtypes.
        p_mask = self.create_mask_for_pad(attention_mask, attention_mask)
        mask = tf.cast(p_mask, logit.dtype)
        logit += logit.dtype.min * mask   # set pad position to "-inf"

        # NOTE(review): the scores are divided by sqrt(weight_dim), i.e. the
        # full projection width, not the per-head width weight_dim / num_heads
        # used in "Attention Is All You Need". Left unchanged to keep the
        # model's numerics -- confirm whether this is intended.
        attention_weight = tf.nn.softmax(
            logit / tf.sqrt(tf.cast(self.weight_dim, logit.dtype)))
        multi_context_vec = tf.matmul(attention_weight, v)

        # Merge the heads back: (batch, heads, len, d) -> (batch, len, weight_dim)
        input_shape = tf.shape(input)
        multi_context_vec = tf.transpose(multi_context_vec, perm=[0, 2, 1, 3])
        concat_vec = tf.reshape(
            multi_context_vec,
            shape=[input_shape[0], input_shape[1], self.weight_dim]
            )
        encoded_vec = tf.matmul(concat_vec, self.wo)
        return encoded_vec


class DecoderSelfAttention(EncoderSelfAttention):
    """
    Multi-head self attention for the decoder side: like the encoder
    version, but every position is additionally prevented from attending to
    later positions.
    Attributes:
        weight_dim: width of the query/key/value projections (int)
        num_heads: number of attention heads (int)
    """

    def __init__(self, weight_dim, num_heads, **kwargs):
        super().__init__(weight_dim, num_heads, **kwargs)


    def create_mask_for_future_input(self, input):
        """
        Build the mask that hides positions later than the current one.
        Args:
            input: tensor (batch_size, num_heads, max_length, max_length)
        Returns:
            bool tensor of the same shape, True at masked positions
        Notes:
            all-ones minus its lower triangle (diagonal included) leaves
            exactly the strictly-upper triangle, i.e. the future positions:
            [[0, 1, 1, 1]
             [0, 0, 1, 1]
             [0, 0, 0, 1]
             [0, 0, 0, 0]]
        """
        all_ones = tf.ones(tf.shape(input))
        lower_with_diagonal = tf.linalg.band_part(all_ones, -1, 0)
        return tf.cast(all_ones - lower_with_diagonal, tf.bool)


    def call(self, input, attention_mask):
        """
        Args:
            input: tensor (batch_size, max_length, hidden_dim)
            attention_mask: tensor (batch_size, max_length) pad position = 0

        Returns:
            tensor (batch_size, max_length, hidden_dim)
        Notes:
            Differs from EncoderSelfAttention.call only in that the future
            mask is combined with the padding mask.
        """
        # Project the input three times and split each projection into heads.
        query, key, value = [
            self.split_transpose(tf.matmul(input, weight))
            for weight in (self.wq, self.wk, self.wv)
        ]

        scores = tf.matmul(query, key, transpose_b=True)

        # A score is blocked when it looks at the future or involves padding.
        future_blocked = self.create_mask_for_future_input(scores)
        pad_blocked = self.create_mask_for_pad(attention_mask, attention_mask)
        blocked = tf.cast(tf.logical_or(future_blocked, pad_blocked), tf.float32)
        scores = scores + scores.dtype.min * blocked  # blocked -> "-inf"

        scale = tf.sqrt(tf.cast(self.weight_dim, tf.float32))
        weights = tf.nn.softmax(scores / scale)
        context = tf.matmul(weights, value)

        # Merge the heads back into one vector per position.
        dims = tf.shape(input)
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        merged = tf.reshape(context, shape=[dims[0], dims[1], self.weight_dim])
        return tf.matmul(merged, self.wo)


class EncoderDecoderAttention(EncoderSelfAttention):
    """
    Decoder-side attention layer whose queries come from the decoder and
    whose keys and values come from the encoder output.
    Attributes:
        weight_dim: width of the query/key/value projections (int)
        num_heads: number of attention heads (int)
    """

    def __init__(self, weight_dim, num_heads, **kwargs):
        super().__init__(weight_dim, num_heads, **kwargs)


    def call(self,
             decoder_input,
             decoder_attention_mask,
             encoder_output,
             encoder_attention_mask):
        """
        Args:
            decoder_input: tensor (batch_size, decoder_max_length, hidden_dim)
            decoder_attention_mask: tensor (batch_size, decoder_max_length) pad position = 0
            encoder_output: tensor (batch_size, encoder_max_length, hidden_dim)
            encoder_attention_mask: tensor (batch_size, encoder_max_length) pad position = 0
        Returns:
            tensor (batch_size, decoder_max_length, hidden_dim)
        """
        # Queries from the decoder, keys and values from the encoder.
        query = self.split_transpose(tf.matmul(decoder_input, self.wq))
        key = self.split_transpose(tf.matmul(encoder_output, self.wk))
        value = self.split_transpose(tf.matmul(encoder_output, self.wv))

        # Rows follow decoder positions, columns follow encoder positions.
        pad_blocked = tf.cast(
            self.create_mask_for_pad(decoder_attention_mask, encoder_attention_mask),
            tf.float32)

        scores = tf.matmul(query, key, transpose_b=True)
        scores = scores + scores.dtype.min * pad_blocked  # pad -> "-inf"

        scale = tf.sqrt(tf.cast(self.weight_dim, tf.float32))
        weights = tf.nn.softmax(scores / scale)
        context = tf.matmul(weights, value)

        # Merge the heads back into one vector per decoder position.
        dims = tf.shape(decoder_input)
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        merged = tf.reshape(context, shape=[dims[0], dims[1], self.weight_dim])
        return tf.matmul(merged, self.wo)


class LayerNormalizer(tf.keras.layers.Layer):
    """
    Normalises the input to zero mean and unit standard deviation, then
    applies a learned scalar scale and bias.

    By default the statistics are computed per example over the sequence and
    hidden axes together (axes 1 and 2), i.e. sentence by sentence.

    Attributes:
        axis: list of axes over which mean and std are computed
    """

    def __init__(self, axis=(1, 2), **kwargs):
        """
        Args:
            axis: int or list/tuple of ints; axes to normalise over.
                  The default (1, 2) gives per-sentence normalisation.
        """
        super().__init__(**kwargs)
        self.axis = list(axis) if isinstance(axis, (list, tuple)) else [axis]


    def build(self, input_shape):
        """Create the scalar scale (initialised to 1) and bias (initialised to 0)."""
        # `name` is passed by keyword and the scalar shape is spelled out: the
        # first positional parameter of Layer.add_weight is not the name in
        # every Keras version.
        self.scale = self.add_weight(
            name="scale", shape=(),
            initializer=tf.keras.initializers.Constant(1.))
        self.bias = self.add_weight(
            name="bias", shape=(),
            initializer=tf.keras.initializers.Constant(0.))
        super().build(input_shape)


    def call(self, input):
        """
        Args:
            input: tensor (batch_size, max_length, hidden_dim)

        Returns:
            tensor (batch_size, max_length, hidden_dim)
        """
        # NOTE(review): with the default axes the statistics cover every
        # position of the sequence, including padding and -- inside the
        # decoder -- positions later than the current one. Pass axis=-1 for
        # the usual per-position layer normalisation if that is not intended.
        mean = tf.math.reduce_mean(input, axis=self.axis, keepdims=True)
        std = tf.math.reduce_std(input, axis=self.axis, keepdims=True)
        normalized = (input - mean) / (std + K.epsilon())  # epsilon avoids /0
        output = normalized * self.scale + self.bias
        return output


class FeedForwardNeuralBlock(tf.keras.Model):
    """
    Position-wise fully connected block shared by the encoder and decoder:
    widen to 4 * hidden_dim with ReLU, apply dropout, project back.
    Attributes:
        hidden_dim: width of the block's output
        dropout_rate: rate given to the dropout layer
    """

    def __init__(self, hidden_dim, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate

        dense = tf.keras.layers.Dense
        self.filter_layer = dense(
            4 * hidden_dim, activation="relu", use_bias=True, name="filter_layer")
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.output_layer = dense(
            hidden_dim, use_bias=True, name="output_layer")


    def call(self, input):
        """
        Args:
            input: tensor (batch_size, max_length, hidden_dim)
        Returns
            tensor (batch_size, max_length, hidden_dim)
        """
        widened = self.filter_layer(input)
        return self.output_layer(self.dropout(widened))


class PositionalEncoder(tf.keras.layers.Layer):
    """
    Layer that adds a fixed sinusoidal position vector to every token vector.
    The position table is computed once in build() from the input shape.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


    def positional_vec(self, pos, embd_dim):
        """
        Compute the table of position vectors.

        Args:
            pos: number of positions (sequence length)
            embd_dim: dimension of the token vectors
        Returns:
            pos_v: np.array (1, pos, embd_dim)
                   the leading axis is added so the table broadcasts over
                   the batch axis
        """
        pos_v = np.zeros(shape=[pos, embd_dim])
        # An even column and the odd column after it share one wavelength:
        # sine goes into the even column, cosine into the odd one.
        for even in range(0, embd_dim, 2):
            wavelength = np.power(10000, (even / embd_dim))
            odd = even + 1
            for p in range(pos):
                pos_v[p, even] = np.sin(p / wavelength)
                if odd < embd_dim:
                    pos_v[p, odd] = np.cos(p / wavelength)
        return pos_v[None, ...]


    def build(self, input_shape):
        """Pre-compute the position table for the built length and width."""
        table = self.positional_vec(input_shape[1], input_shape[-1])
        self.pos_vec = tf.constant(table, dtype=tf.float32)
        super().build(input_shape)


    def call(self, input):
        """
        Args:
            input: embedded sentences, tensor (batch_size, max_length, hidden_dim)
        Returns:
            input plus the position vectors, tensor (batch_size, max_length, hidden_dim)
        """
        return tf.add(input, self.pos_vec)

### transformer

In [23]:
class Encoder(tf.keras.models.Model):
    """
    A single encoder layer: self attention and a feed-forward block, each
    followed by a residual connection and normalisation.
    Attributes:
        at_weight_dim: width of the attention projections
        num_heads: number of attention heads
        ffn_weight_dim: output width of the feed-forward block; must equal
                        the embedding dimension
        dropout_rate: rate given to the dropout layer
    """

    def __init__(
        self,
        at_weight_dim=512,
        num_heads=8,
        ffn_weight_dim=256,
        dropout_rate=0.2,
        **kwargs
        ):

        super().__init__(**kwargs)
        self.at_weight_dim = at_weight_dim
        self.num_heads = num_heads
        self.ffn_weight_dim = ffn_weight_dim
        self.dropout_rate = dropout_rate

        # Sub-layers are created in a fixed order, which determines their
        # auto-generated names.
        self.self_attention = EncoderSelfAttention(at_weight_dim, num_heads)
        self.layer_norm1 = LayerNormalizer()
        self.layer_norm2 = LayerNormalizer()
        self.ffn = FeedForwardNeuralBlock(ffn_weight_dim, dropout_rate)


    def call(self, input, attention_mask):
        """
        Args:
            input: tensor (batch_size, max_length, hidden_dim)
            attention_mask: tensor (batch_size, max_length)
        Returns:
            tensor (batch_size, max_length, hidden_dim)
        """
        # Residual connection + normalisation around the attention ...
        attended = self.layer_norm1(
            input + self.self_attention(input, attention_mask))
        # ... and around the feed-forward block.
        return self.layer_norm2(attended + self.ffn(attended))


class Decoder(tf.keras.models.Model):
    """
    A single decoder layer: masked self attention, encoder-decoder attention
    and a feed-forward block, each followed by a residual connection and
    normalisation.
    Attributes:
        at_weight_dim: width of the attention projections
        num_heads: number of attention heads
        ffn_weight_dim: output width of the feed-forward block; must equal
                        the embedding dimension
        dropout_rate: rate given to the dropout layer
    """

    def __init__(
        self,
        at_weight_dim=512,
        num_heads=8,
        ffn_weight_dim=256,
        dropout_rate=0.2,
        **kwargs
        ):

        super().__init__(**kwargs)
        self.at_weight_dim = at_weight_dim
        self.num_heads = num_heads
        self.ffn_weight_dim = ffn_weight_dim
        self.dropout_rate = dropout_rate

        # Sub-layers are created in a fixed order, which determines their
        # auto-generated names.
        self.self_attention = DecoderSelfAttention(at_weight_dim, num_heads)
        self.ed_attention = EncoderDecoderAttention(at_weight_dim, num_heads)
        self.ffn = FeedForwardNeuralBlock(ffn_weight_dim, dropout_rate)
        self.layer_norm1 = LayerNormalizer()
        self.layer_norm2 = LayerNormalizer()
        self.layer_norm3 = LayerNormalizer()


    def call(self,
             decoder_input,
             decoder_attention_mask,
             encoder_output,
             encoder_attention_mask
             ):
        """
        Args:
            decoder_input: decoder-side input tensor (batch_size, decoder_max_length, hidden_dim)
            decoder_attention_mask: tensor (batch_size, decoder_max_length)
            encoder_output: final encoder output tensor (batch_size, encoder_max_length, hidden_dim)
            encoder_attention_mask: tensor (batch_size, encoder_max_length)

        Returns:
            tensor (batch_size, decoder_max_length, hidden_dim)
        """
        # 1. masked self attention over the decoder input
        self_attended = self.layer_norm1(
            decoder_input
            + self.self_attention(decoder_input, decoder_attention_mask))

        # 2. attention over the encoder output
        cross_attended = self.layer_norm2(
            self_attended
            + self.ed_attention(
                self_attended, decoder_attention_mask,
                encoder_output, encoder_attention_mask))

        # 3. position-wise feed-forward block
        return self.layer_norm3(cross_attended + self.ffn(cross_attended))


class Transformer(tf.keras.models.Model):
    """
    Encoder-decoder Transformer: embedding + positional encoding on each
    side, a stack of Encoder layers, a stack of Decoder layers and a softmax
    layer over the decoder vocabulary.
    Attributes:
        encoder_num_vocabs: vocabulary size on the encoder side
        decoder_num_vocabs: vocabulary size on the decoder side
        hidden_dim: dimension of the embeddings and of every Encoder/Decoder output
        at_weight_dim: width of the attention projections
        num_heads: number of attention heads
        dropout_rate: rate given to the dropout layers
        num_encoders: number of stacked Encoder layers
        num_decoders: number of stacked Decoder layers
    """

    def __init__(self,
                 encoder_num_vocabs,
                 decoder_num_vocabs,
                 hidden_dim=256,
                 at_weight_dim=512,
                 num_heads=8,
                 dropout_rate=0.2,
                 num_encoders=8,
                 num_decoders=8,
                 **kwargs
                 ):

        super().__init__(**kwargs)
        self.encoder_num_vocabs = encoder_num_vocabs
        self.decoder_num_vocabs = decoder_num_vocabs
        self.hidden_dim = hidden_dim
        self.at_weight_dim = at_weight_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.num_encoders = num_encoders
        self.num_decoders = num_decoders

        embedding = tf.keras.layers.Embedding
        self.encoder_embedding_layer = embedding(encoder_num_vocabs, hidden_dim)
        self.decoder_embedding_layer = embedding(decoder_num_vocabs, hidden_dim)
        self.encoder_pe_layer = PositionalEncoder()
        self.decoder_pe_layer = PositionalEncoder()

        # Every stacked layer is built with the same hyper-parameters; the
        # feed-forward width is tied to hidden_dim.
        stack_kwargs = dict(at_weight_dim=at_weight_dim,
                            num_heads=num_heads,
                            ffn_weight_dim=hidden_dim,
                            dropout_rate=dropout_rate)
        self.encoders_list = [
            Encoder(**stack_kwargs) for _ in range(self.num_encoders)]
        self.decoders_list = [
            Decoder(**stack_kwargs) for _ in range(self.num_decoders)]

        self.vocab_prob_layer = tf.keras.layers.Dense(
            decoder_num_vocabs, name="vocab_prob_layer", activation="softmax")


    def call(self,
             encoder_input_ids,
             encoder_attention_mask,
             decoder_input_ids,
             decoder_attention_mask
             ):
        """
        Args:
            encoder_input_ids: encoder-side token ids, tensor (batch_size, encoder_max_length)
            encoder_attention_mask: tensor (batch_size, encoder_max_length)
            decoder_input_ids: decoder-side token ids, tensor (batch_size, decoder_max_length)
            decoder_attention_mask: tensor (batch_size, decoder_max_length)

        Returns:
            dict with
              "vocab_prob": softmax output of the vocabulary layer,
                            tensor (batch_size, decoder_max_length, decoder_num_vocabs)
              "last_hidden_state": output of the last Decoder layer,
                            tensor (batch_size, decoder_max_length, hidden_dim)
        """
        # Encoder side: embed, add positions, run through the stack.
        encoder_vec = self.encoder_pe_layer(
            self.encoder_embedding_layer(encoder_input_ids))
        for encoder in self.encoders_list:
            encoder_vec = encoder(encoder_vec, encoder_attention_mask)

        # Decoder side: every layer also sees the final encoder output.
        decoder_vec = self.decoder_pe_layer(
            self.decoder_embedding_layer(decoder_input_ids))
        for decoder in self.decoders_list:
            decoder_vec = decoder(
                decoder_vec, decoder_attention_mask, encoder_vec, encoder_attention_mask)

        return {
            "vocab_prob": self.vocab_prob_layer(decoder_vec),
            "last_hidden_state": decoder_vec,
        }

### build model

In [24]:
def build_model():
    """
    Build and compile the functional Keras model around Transformer.

    Reads the module-level encoder_max_len, decoder_max_len, encoder_num_char
    and decoder_num_char. The model takes the four named int32 inputs and
    outputs the "vocab_prob" tensor of the Transformer.

    Returns:
        compiled tf.keras.Model (Adam, sparse categorical cross-entropy, "acc")
    """

    def int_input(name, length):
        # One int32 input of fixed sequence length.
        return tf.keras.layers.Input(shape=(length,), dtype=tf.int32, name=name)

    # The order matches the positional arguments of Transformer.call.
    inputs = [
        int_input("encoder_input_ids", encoder_max_len),
        int_input("encoder_attention_mask", encoder_max_len),
        int_input("decoder_input_ids", decoder_max_len),
        int_input("decoder_attention_mask", decoder_max_len),
    ]

    transformer = Transformer(
        encoder_num_vocabs=encoder_num_char,
        decoder_num_vocabs=decoder_num_char,
        hidden_dim=64,
        at_weight_dim=128,
        num_heads=4,
        dropout_rate=0.2,
        num_encoders=4,
        num_decoders=4,
    )

    output = transformer(*inputs)

    model = tf.keras.Model(inputs=inputs, outputs=output["vocab_prob"])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=["acc"],
    )
    return model

In [26]:
# Build the compiled model and print its layer summary.
model = build_model()
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 encoder_input_ids (InputLayer)  [(None, 29)]        0           []                               
                                                                                                  
 encoder_attention_mask (InputL  [(None, 29)]        0           []                               
 ayer)                                                                                            
                                                                                                  
 decoder_input_ids (InputLayer)  [(None, 11)]        0           []                               
                                                                                                  
 decoder_attention_mask (InputL  [(None, 11)]        0           []                           

In [27]:
# Print the name of every trainable weight of the model.
for weight in model.trainable_weights:
    print(weight.name)

transformer/embedding/embeddings:0
transformer/embedding_1/embeddings:0
transformer/encoder/encoder_self_attention/wq:0
transformer/encoder/encoder_self_attention/wk:0
transformer/encoder/encoder_self_attention/wv:0
transformer/encoder/encoder_self_attention/wo:0
transformer/encoder/layer_normalizer/scale:0
transformer/encoder/layer_normalizer/bias:0
transformer/encoder/layer_normalizer_1/scale:0
transformer/encoder/layer_normalizer_1/bias:0
transformer/encoder/feed_forward_neural_block/filter_layer/kernel:0
transformer/encoder/feed_forward_neural_block/filter_layer/bias:0
transformer/encoder/feed_forward_neural_block/output_layer/kernel:0
transformer/encoder/feed_forward_neural_block/output_layer/bias:0
transformer/encoder_1/encoder_self_attention_1/wq:0
transformer/encoder_1/encoder_self_attention_1/wk:0
transformer/encoder_1/encoder_self_attention_1/wv:0
transformer/encoder_1/encoder_self_attention_1/wo:0
transformer/encoder_1/layer_normalizer_2/scale:0
transformer/encoder_1/layer_n

## fit

In [28]:
# Train for 2 epochs with val_ds as validation data; the object returned by
# fit() is kept in `history`.
history = model.fit(tr_ds, validation_data=val_ds, epochs=2)

Epoch 1/2
Epoch 2/2
