In [1]:
import tensorflow as tf
from tensorflow import keras

In [31]:
def qbert_model(vocab_size, num_layers, dff, d_model, num_heads, dropout, name = 'qbert'):

    input = tf.keras.layers.Input(shape = (None, ), name = 'inputs')

    padding_mask = tf.keras.Input(shape = (1, 1, None), name = 'padding_mask')

    embeddings = tf.keras.layers.Embedding(vocab_size, d_model)(input)
    embeddings *= tf.math.sqrt(tf.cast(d_model, tf.float32))
    print(embeddings.shape)
    pos_layer = tf.keras.layers.Po
    embeddings += tf.kears.layers.add(embeddings, pos_layer)
    outputs = tf.keras.layers.Dropout(rate = dropout)(embeddings)


#     for i in range(num_layers) :

#         outputs = encoder_layer(dff = dff
#                                 , d_model=d_model
#                                 , num_heads = num_heads
#                                 , dropout = dropout
#                                 , name = 'encoding_layer_{}'.format(i))([outputs, padding_mask])

    return tf.keras.Model(inputs = [input, padding_mask], outputs = outputs, name = name)


In [44]:
class PositionEmbedding(tf.keras.layers.Layer):
    """Creates a positional embedding.
    Example:
    ```python
    position_embedding = PositionEmbedding(max_length=100)
    inputs = tf.keras.Input((100, 32), dtype=tf.float32)
    outputs = position_embedding(inputs)
    ```
    Args:
    max_length: The maximum size of the dynamic sequence.
    initializer: The initializer to use for the embedding weights. Defaults to
      "glorot_uniform".
    seq_axis: The axis of the input tensor where we add the embeddings.
    Reference: This layer creates a positional embedding as described in
    [BERT: Pre-training of Deep Bidirectional Transformers for Language
    Understanding](https://arxiv.org/abs/1810.04805).
    """

    def __init__(self, max_length, initializer="glorot_uniform", seq_axis=1,  **kwargs):

        super(PositionEmbedding, self).__init__(**kwargs)
        
        if max_length is None:
            raise ValueError("`max_length` must be an Integer, not `None`.")
        
        self._max_length = max_length
        self._initializer = tf.keras.initializers.get(initializer)
        self._seq_axis = seq_axis

        
    def get_config(self):
        config = {
            "max_length": self._max_length,
            "initializer": tf.keras.initializers.serialize(self._initializer),
            "seq_axis": self._seq_axis,
        }
        base_config = super(PositionEmbedding, self).get_config()
    
        return dict(list(base_config.items()) + list(config.items()))

    
    def build(self, input_shape):
#         print(input_shape)
        dimension_list = input_shape.as_list()

        seq_length = dimension_list[self._seq_axis]
        width = dimension_list[-1]

        if self._max_length is not None:
            weight_sequence_length = self._max_length
        else:
            weight_sequence_length = seq_length

        self._position_embeddings = self.add_weight("embeddings", shape=[weight_sequence_length, width], initializer=self._initializer)

        super(PositionEmbedding, self).build(input_shape)

        
    def call(self, inputs):
#         print("CALL")
        print(inputs)
        input_shape = tf.shape(inputs)
        print(input_shape)
        print(input_shape[1])
        actual_seq_len = input_shape[self._seq_axis]
        position_embeddings = self._position_embeddings[:actual_seq_len, :]
        
        new_shape = [1 for _ in inputs.get_shape().as_list()]
        new_shape[self._seq_axis] = actual_seq_len
        new_shape[-1] = position_embeddings.get_shape().as_list()[-1]
        
        print(new_shape)
        
        position_embeddings = tf.reshape(position_embeddings, new_shape)
        
        return tf.broadcast_to(position_embeddings, input_shape)

In [45]:
PositionEmbedding(255)(keras.layers.Input((255, 50)))

Tensor("input_16:0", shape=(None, 255, 50), dtype=float32)
Tensor("position_embedding_11/Shape:0", shape=(3,), dtype=int32)
Tensor("position_embedding_11/strided_slice:0", shape=(), dtype=int32)
[1, <tf.Tensor 'position_embedding_11/strided_slice_1:0' shape=() dtype=int32>, 50]


<tf.Tensor 'position_embedding_11/Identity:0' shape=(None, 255, 50) dtype=float32>