<a href="https://colab.research.google.com/github/allanbatista/classificacao-de-produtos-no-e-commerce/blob/master/codigo/notebooks/Transformer_Codificacao_Posicional.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Codificação Posicional (Positional Encoding)

In [32]:
import tensorflow as tf

def posicional_encoding(x):
    """Add sinusoidal positional encodings to a batch of embedded sequences.

    Reference: https://github.com/tensorflow/tensor2tensor/pull/177

    For position ``pos`` in ``[0, seq_length)`` and frequency index ``i`` in
    ``[0, emb_dim // 2)``:

        angle(pos, i) = pos / 10000 ^ (2i / emb_dim)

    The encoding is laid out as two concatenated halves (not interleaved):

        PE[pos, i]                = cos(angle(pos, i))
        PE[pos, emb_dim // 2 + i] = sin(angle(pos, i))

    NOTE: the original Transformer paper interleaves sin (even dims) with
    cos (odd dims), and tensor2tensor concatenates [sin, cos]. The layout here
    is a fixed permutation of the same features (which a learned model
    absorbs); it is kept so previously computed outputs stay reproducible.

    When ``emb_dim`` is odd, one trailing zero column is appended so the
    encoding still matches ``x``'s last dimension.

    Args:
        x: floating-point tensor of rank 3, shape (batch_size, seq_length,
            emb_dim).

    Returns:
        ``x + PE``: a tensor with the same shape and dtype as ``x``; PE is
        identical for every batch element and is broadcast over axis 0.
    """
    # Dynamic shape so the function also works when dims are unknown statically.
    _, seq_length, emb_dim = tf.unstack(tf.shape(x), num=3)
    half_dim = emb_dim // 2

    # Positions 0 .. seq_length-1. They do not depend on the batch element,
    # so compute them once and rely on broadcasting for the batch axis.
    positions = tf.cast(tf.range(seq_length), tf.float32)  # shape=(seq_length,)

    # Inverse frequencies 1 / 10000^(2i / emb_dim) for i in [0, half_dim).
    exponents = 2. * tf.cast(tf.range(half_dim), tf.float32) / tf.cast(emb_dim, tf.float32)
    inv_freq = 1. / tf.pow(10000., exponents)  # shape=(half_dim,)

    angles = positions[:, tf.newaxis] * inv_freq[tf.newaxis, :]  # shape=(seq_length, half_dim)
    encoding = tf.concat([tf.math.cos(angles), tf.math.sin(angles)], axis=1)  # shape=(seq_length, 2 * half_dim)

    # An odd emb_dim leaves one channel uncovered; pad it with zeros so the
    # addition below lines up instead of failing (or silently broadcasting
    # to an empty tensor when emb_dim == 1).
    encoding = tf.pad(encoding, [[0, 0], [0, emb_dim % 2]])  # shape=(seq_length, emb_dim)

    # Cast to x's dtype (e.g. float16/float64) and broadcast over the batch axis.
    return x + tf.cast(encoding[tf.newaxis, :, :], x.dtype)

In [33]:
# Fix the global seed so a fresh "Restart & Run All" reproduces the same sample.
tf.random.set_seed(42)

# Toy batch: 2 sequences, 5 positions each, 2-dimensional embeddings.
sample_sequence = tf.random.uniform(shape=[2, 5, 2])
sample_sequence

<tf.Tensor: shape=(2, 5, 2), dtype=float32, numpy=
array([[[0.26484156, 0.8003231 ],
        [0.8221879 , 0.5932431 ],
        [0.5634736 , 0.8112178 ],
        [0.90830576, 0.74487853],
        [0.4941579 , 0.58787477]],

       [[0.99294484, 0.21522999],
        [0.11550105, 0.40101182],
        [0.40586925, 0.91038096],
        [0.18660271, 0.83135617],
        [0.90797865, 0.7912141 ]]], dtype=float32)>

In [34]:
# Encode the toy batch. At position 0 every angle is 0, so the first half of
# the features gains cos(0) = 1 and the second half gains sin(0) = 0.
encoded_sample = posicional_encoding(sample_sequence)
encoded_sample

<tf.Tensor: shape=(2, 5, 2), dtype=float32, numpy=
array([[[ 1.2648416 ,  0.8003231 ],
        [ 1.3624902 ,  1.4347141 ],
        [ 0.14732677,  1.7205153 ],
        [-0.08168674,  0.88599855],
        [-0.1594857 , -0.16892773]],

       [[ 1.9929448 ,  0.21522999],
        [ 0.6558033 ,  1.2424829 ],
        [-0.01027757,  1.8196783 ],
        [-0.8033898 ,  0.9724762 ],
        [ 0.25433505,  0.03441161]]], dtype=float32)>