In [38]:
import tensorflow as tf
import numpy as np

In [39]:
tf.__version__

'2.9.1'

In [40]:
input_embeddings = [[
    "Salut", "comment", "ca", "va", "?",
]]

output_embeddings = [[
    "<START>", "Hi", "how", "are", "you", "?",
]]
print(input_embeddings)
print(output_embeddings)

[['Salut', 'comment', 'ca', 'va', '?']]
[['<START>', 'Hi', 'how', 'are', 'you', '?']]


In [41]:
def get_vocabulary(sequences):

    token_to_info = {}

    for sequence in sequences:
        for word in sequence:
            if word not in token_to_info:
                token_to_info[word] = len(token_to_info)
    return token_to_info

input_voc = get_vocabulary(input_embeddings)
output_voc = get_vocabulary(output_embeddings)

input_voc["<START>"] = len(input_voc)
input_voc["<END>"] = len(input_voc)
input_voc["<PAD>"] = len(input_voc)

output_voc["<END>"] = len(output_voc)
output_voc["<PAD>"] = len(output_voc)

print(input_voc)
print(output_voc)

{'Salut': 0, 'comment': 1, 'ca': 2, 'va': 3, '?': 4, '<START>': 5, '<END>': 6, '<PAD>': 7}
{'<START>': 0, 'Hi': 1, 'how': 2, 'are': 3, 'you': 4, '?': 5, '<END>': 6, '<PAD>': 7}


In [42]:
def sequences_to_int(sequences, voc):
    for sequence in sequences:
        for s, word in enumerate(sequence):
            sequence[s] = voc[word]
    return(np.array(sequences))

input_seq = sequences_to_int(input_embeddings, input_voc)
output_seq = sequences_to_int(output_embeddings, output_voc)

print(input_seq)
print(output_seq)

[[0 1 2 3 4]]
[[0 1 2 3 4 5]]


In [43]:
class EmbeddingLayer(tf.keras.layers.Layer):

    def __init__(self, nb_token, **kwargs):
        self.nb_token = nb_token
        super(**kwargs).__init__()

    def build(self, input_shape):
        self.word_embedding = tf.keras.layers.Embedding(
            self.nb_token, 256,
        )
        super().build(input_shape)

    def call(self, x):
        embed = self.word_embedding(x)
        return embed


class ScaledDotProductAttention(tf.keras.layers.Layer):

    def __init__(self, **kwargs):
        super(**kwargs).__init__()

    def build(self, input_shape):
        self.query_layer = tf.keras.layers.Dense(256)
        self.value_layer = tf.keras.layers.Dense(256)
        self.key_layer = tf.keras.layers.Dense(256)
        super().build(input_shape)

    def call(self, x):
        Q = self.query_layer(x)
        K = self.key_layer(x)
        V = self.value_layer(x)
        QK = tf.matmul(Q, K, transpose_b=True)
        QK = QK / tf.math.sqrt(256.)
        softmax_QK = tf.nn.softmax(QK, axis=-1)
        attention = tf.matmul(softmax_QK, V)
        # print("Shape Q", Q.shape)
        # print("Shape K", K.shape)
        # print("Shape V", V.shape)
        # print("Shape QK", QK.shape)
        # print("Shape softmax", softmax_QK.shape)
        # print("Shape attention", attention.shape)
        return attention

def test():
    layer_input = tf.keras.Input(shape=(5))
    embedding = EmbeddingLayer(nb_token=5)(layer_input)
    attention = ScaledDotProductAttention()(embedding)
    model = tf.keras.Model(layer_input, attention)
    #model.summary()
    return model

m_test = test()
out = m_test(input_seq)
print(out.shape)

(1, 5, 256)


In [81]:
class MultiHeadAttention(tf.keras.layers.Layer):

    def __init__(self, dim=256, nb_head=8, **kwargs):
        self.dim = 256
        self.head_dim = 256 // 8
        self.nb_head = nb_head
        super(**kwargs).__init__()

    def build(self, input_shape):
        self.query_layer = tf.keras.layers.Dense(256)
        self.value_layer = tf.keras.layers.Dense(256)
        self.key_layer = tf.keras.layers.Dense(256)
        self.out_proj = tf.keras.layers.Dense(256)
        super().build(input_shape)

    def mask_softmax(self, x, mask):
        x_expe = tf.math.exp(x)
        x_expe_masked = x_expe * mask
        x_expe_sum = tf.reduce_sum(x_expe_masked, axis = -1)
        x_expe_sum = tf.expand_dims(x_expe_sum, axis=-1)
        softmax = x_expe_masked / x_expe_sum
        return softmax

    def call(self, x, mask = None):

        in_query, in_key, in_value = x

        Q = self.query_layer(in_query)
        K = self.key_layer(in_key)
        V = self.value_layer(in_value)

        batch_size = tf.shape(Q)[0]
        Q_seq_len = tf.shape(Q)[1]
        K_seq_len = tf.shape(K)[1]
        V_seq_len = tf.shape(V)[1]

        Q = tf.reshape(Q, [batch_size, Q_seq_len, self.nb_head, self.head_dim])
        K = tf.reshape(K, [batch_size, K_seq_len, self.nb_head, self.head_dim])
        V = tf.reshape(V, [batch_size, V_seq_len, self.nb_head, self.head_dim])

        Q = tf.transpose(Q, [0, 2, 1, 3])
        K = tf.transpose(K, [0, 2, 1, 3])
        V = tf.transpose(V, [0, 2, 1, 3])

        Q = tf.reshape(Q, [batch_size * self.nb_head, Q_seq_len, self.head_dim])
        K = tf.reshape(K, [batch_size * self.nb_head, K_seq_len, self.head_dim])
        V = tf.reshape(V, [batch_size * self.nb_head, V_seq_len, self.head_dim])

        # Scaled dot product attention
        QK = tf.matmul(Q, K, transpose_b=True)
        QK = QK / tf.math.sqrt(256.)

        if mask is not None:
            QK = QK * mask
            softmax_QK = self.mask_softmax(QK, mask)
        else:
            softmax_QK = tf.nn.softmax(QK, axis=-1)

        attention = tf.matmul(softmax_QK, V)
        attention = tf.reshape(attention, [batch_size, self.nb_head, Q_seq_len, self.head_dim])
        attention = tf.transpose(attention, [0, 2, 1, 3])

        # Concat
        attention = tf.reshape(attention, [batch_size, Q_seq_len, self.nb_head * self.head_dim])

        out_attention = self.out_proj(attention)

        return out_attention

def test():
    layer_input = tf.keras.Input(shape=(6))
    embedding = EmbeddingLayer(nb_token=6)(layer_input)

    mask = tf.sequence_mask(tf.range(6) + 1, 6)
    mask = tf.cast(mask, tf.float32)
    mask = tf.expand_dims(mask, axis=0)
    multi_attention = MultiHeadAttention()((embedding,embedding, embedding), mask =mask)

    model = tf.keras.Model(layer_input, multi_attention)
    model.summary()
    return model

m_test = test()
out = m_test(output_seq)
print(out.shape)

Model: "model_51"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_94 (InputLayer)          [(None, 6)]          0           []                               
                                                                                                  
 embedding_layer_93 (EmbeddingL  (None, 6, 256)      1536        ['input_94[0][0]']               
 ayer)                                                                                            
                                                                                                  
 multi_head_attention_20 (Multi  (None, 6, 256)      263168      ['embedding_layer_93[0][0]',     
 HeadAttention)                                                   'embedding_layer_93[0][0]',     
                                                                  'embedding_layer_93[0][0]

In [82]:
class EncoderLayer(tf.keras.layers.Layer):

    def __init__(self, **kwargs):
        super(**kwargs).__init__()

    def build(self, input_shape):
        self.multi_head_attention = MultiHeadAttention()
        self.norm = tf.keras.layers.LayerNormalization()
        self.dense_out = tf.keras.layers.Dense(256)
        super().build(input_shape)

    def call(self, x):
        attention = self.multi_head_attention((x, x, x))
        post_attention = self.norm(x + attention)
        x = self.dense_out(post_attention)
        enc_output = self.norm(x + post_attention)
        return enc_output

class Encoder(tf.keras.layers.Layer):

    def __init__(self, nb_encoder, **kwargs):
        self.nb_encoder = nb_encoder
        super(**kwargs).__init__()

    def build(self, input_shape):
        self.encoder_layers = []
        for nb in range(self.nb_encoder):
            self.encoder_layers.append(
                EncoderLayer()
            )
        super().build(input_shape)

    def call(self, x):

        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
        return x

def test():
    layer_input = tf.keras.Input(shape=(5))
    embedding = EmbeddingLayer(nb_token=5)(layer_input)
    enc_output = Encoder(nb_encoder=6)(embedding)
    model = tf.keras.Model(layer_input, enc_output)
    model.summary()
    return model

m_test = test()
out = m_test(input_seq)
print(out.shape)

Model: "model_52"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_95 (InputLayer)       [(None, 5)]               0         
                                                                 
 embedding_layer_94 (Embeddi  (None, 5, 256)           1280      
 ngLayer)                                                        
                                                                 
 encoder_30 (Encoder)        (None, 5, 256)            1976832   
                                                                 
Total params: 1,978,112
Trainable params: 1,978,112
Non-trainable params: 0
_________________________________________________________________
(1, 5, 256)


In [83]:
class DecoderLayer(tf.keras.layers.Layer):

    def __init__(self, **kwargs):
        super(**kwargs).__init__()

    def build(self, input_shape):
        self.multi_head_self_attention = MultiHeadAttention()
        self.multi_head_enc_attention = MultiHeadAttention()
        self.norm = tf.keras.layers.LayerNormalization()
        self.proj_output = tf.keras.layers.Dense(256)
        super().build(input_shape)

    def call(self, x):
        enc_output, output_embedding, mask = x
        self_attention = self.multi_head_self_attention((output_embedding, output_embedding, output_embedding), mask)
        post_self_att = self.norm(output_embedding + self_attention)
        enc_attention = self.multi_head_enc_attention((post_self_att, enc_output, enc_output)) # Pas sur de l'ordre
        post_enc_attention = self.norm(enc_attention + post_self_att)
        proj_out = self.proj_output(post_enc_attention)
        dec_output = self.norm(proj_out + post_enc_attention)
        return dec_output

class Decoder(tf.keras.layers.Layer):

    def __init__(self, nb_decoder, **kwargs):
        self.nb_decoder = nb_decoder
        super(**kwargs).__init__()

    def build(self, input_shape):
        self.decoder_layers = []
        for nb in range(self.nb_decoder):
            self.decoder_layers.append(
                DecoderLayer()
            )
        super().build(input_shape)

    def call(self, x):

        enc_out, output_embedding, mask = x
        dec_output = output_embedding
        for decoder_layer in self.decoder_layers:
            dec_output = decoder_layer((enc_out, dec_output, mask))
        return dec_output

def test():
    input_token = tf.keras.Input(shape=(5))
    output_token = tf.keras.Input(shape=(6))

    # Retrieve embedding
    input_embedding = EmbeddingLayer(nb_token=5)(input_token)
    output_embedding = EmbeddingLayer(nb_token=6)(output_token)

    # Encoder
    enc_output = Encoder(nb_encoder=6)(input_embedding)

    # mask
    mask = tf.sequence_mask(tf.range(6) + 1, 6)
    mask = tf.cast(mask, tf.float32)
    mask = tf.expand_dims(mask, axis=0)

    # Decoder
    dec_output = Decoder(nb_decoder=6)((enc_output, output_embedding, mask))

    model = tf.keras.Model([input_token, output_token], dec_output)
    model.summary()
    return model

m_test = test()
out = m_test((input_seq, output_seq))
print(out.shape)

Model: "model_53"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_96 (InputLayer)          [(None, 5)]          0           []                               
                                                                                                  
 embedding_layer_95 (EmbeddingL  (None, 5, 256)      1280        ['input_96[0][0]']               
 ayer)                                                                                            
                                                                                                  
 input_97 (InputLayer)          [(None, 6)]          0           []                               
                                                                                                  
 encoder_31 (Encoder)           (None, 5, 256)       1976832     ['embedding_layer_95[0][0]