In [1]:
from keras.layers import Layer
import keras.backend as K
from keras import initializers
import tensorflow as tf

In [3]:
class PositionEmbedding(Layer):
    """Keras layer configuration for a learned position-embedding table.

    The constructor only records its settings on the instance; it does not
    create any weights.

    Args:
        input_dim: Input dimensionality (``max_position``).
        output_dim: Output dimensionality (``embedding_size``).
        merge_mode: Stored unchanged as ``self.merge_mode``; defaults to
            ``'add'``.
        hierarchical: Stored unchanged as ``self.hierarchical``; defaults to
            ``None``.
        embeddings_initializer: Initializer identifier, resolved through
            ``initializers.get``; defaults to ``'zeros'``.
        custom_position_ids: Stored unchanged as
            ``self.custom_position_ids``; defaults to ``False``.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, input_dim, output_dim, merge_mode='add',
                 hierarchical=None, embeddings_initializer='zeros',
                 custom_position_ids=False, **kwargs):
        super(PositionEmbedding, self).__init__(**kwargs)

        # Size of the embedding table.
        self.input_dim = input_dim  # input dimensionality (max_position)
        self.output_dim = output_dim  # output dimensionality (embedding_size)

        # Turn the initializer identifier into an initializer object.
        resolved_initializer = initializers.get(embeddings_initializer)
        self.embeddings_initializer = resolved_initializer

        # Behaviour switches, kept exactly as given.
        self.custom_position_ids = custom_position_ids
        self.hierarchical = hierarchical
        self.merge_mode = merge_mode


In [4]:
def build(self, input_shape):
    """Create the position-embedding weight and store it on ``self``.

    NOTE(review): this is written as a method (``self`` plus
    ``super(PositionEmbedding, self)``) but is defined at module level, so
    it must be called with a ``PositionEmbedding`` instance.

    Args:
        input_shape: Passed straight through to the parent ``build``.
    """
    # Let the parent class run its own build step first.
    super(PositionEmbedding, self).build(input_shape)

    # One row per position id, one column per embedding feature.
    table_shape = (self.input_dim, self.output_dim)
    self.embeddings = self.add_weight(
        name='embeddings',
        shape=table_shape,
        initializer=self.embeddings_initializer,
    )

In [13]:
def call(self, inputs):
    """Merge ``inputs`` with position embeddings read from ``self.embeddings``.

    NOTE(review): this is written as a method (it takes ``self``) but is
    defined at module level, so it must be called with a layer instance.

    Args:
        inputs: A single tensor when ``self.custom_position_ids`` is falsy;
            otherwise a pair ``(inputs, position_ids)``. Axes 0 and 1 of the
            tensor are read as batch size and sequence length, and the
            concatenate path tiles with three multiples, so a rank-3 tensor
            is presumably expected -- TODO confirm against callers.

    Returns:
        ``inputs + pos_embeddings`` when ``merge_mode == 'add'``,
        ``inputs * pos_embeddings`` when ``merge_mode == 'mul'``, and
        otherwise ``K.concatenate([inputs, pos_embeddings])``.
    """
    # Unpack first so that K.shape below sees the tensor itself and not the
    # (tensor, position_ids) pair.
    if self.custom_position_ids:
        inputs, position_ids = inputs
        if K.dtype(position_ids) != 'int32':
            position_ids = K.cast(position_ids, 'int32')

    input_shape = K.shape(inputs)
    batch_size, seq_len = input_shape[0], input_shape[1]

    if not self.custom_position_ids:
        # Default ids 0..seq_len-1 with a leading axis of size 1, so the
        # looked-up embeddings broadcast over the batch axis.
        position_ids = K.arange(0, seq_len, dtype='int32')[None]

    if self.hierarchical:
        # Hierarchical decomposition: each position id is split into a
        # quotient and a remainder w.r.t. input_dim, and the two looked-up
        # rows are blended with weights alpha and (1 - alpha).
        alpha = 0.4 if self.hierarchical is True else self.hierarchical
        embeddings = self.embeddings - alpha * self.embeddings[:1]
        embeddings = embeddings / (1 - alpha)
        embeddings_x = K.gather(embeddings, position_ids // self.input_dim)
        embeddings_y = K.gather(embeddings, position_ids % self.input_dim)
        pos_embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y
    else:
        # Plain table lookup, for both custom and default position ids.
        pos_embeddings = K.gather(self.embeddings, position_ids)

    # The merge step runs for every branch above; it is no longer nested
    # inside the non-hierarchical branch.
    if self.merge_mode == 'add':
        return inputs + pos_embeddings
    elif self.merge_mode == 'mul':
        return inputs * pos_embeddings
    else:
        if not self.custom_position_ids:
            # Default ids carry a batch axis of 1; concatenation needs the
            # batch axes to match, so repeat along axis 0.
            pos_embeddings = K.tile(pos_embeddings, [batch_size, 1, 1])
        return K.concatenate([inputs, pos_embeddings])

In [14]:
def compute_output_shape(self, input_shape):
    """Return the output shape for the given ``input_shape``.

    Args:
        input_shape: The shape of the input tensor, or -- when
            ``self.custom_position_ids`` is truthy -- a sequence whose first
            element is that shape.

    Returns:
        The input tensor's shape unchanged for the ``'add'`` and ``'mul'``
        merge modes; otherwise its first two entries followed by the third
        entry increased by ``self.output_dim``.
    """
    # With custom position ids only the first entry describes the tensor.
    tensor_shape = input_shape[0] if self.custom_position_ids else input_shape

    if self.merge_mode not in ['add', 'mul']:
        # Any other mode widens the last axis by the embedding size.
        widened_dim = tensor_shape[2] + self.output_dim
        return tensor_shape[:2] + (widened_dim,)

    return tensor_shape

In [15]:
def get_config(self):
    """Return the layer configuration as a plain dictionary.

    The parent class configuration is taken as the base and this layer's
    own settings are written on top, so a key present in both ends up with
    this layer's value.

    NOTE(review): this is written as a method (``self`` plus
    ``super(PositionEmbedding, self)``) but is defined at module level, so
    it must be called with a ``PositionEmbedding`` instance.
    """
    # Initializer objects are stored in serialized form.
    initializer_config = initializers.serialize(self.embeddings_initializer)

    merged = dict(super(PositionEmbedding, self).get_config())
    merged.update({
        'input_dim': self.input_dim,
        'output_dim': self.output_dim,
        'merge_mode': self.merge_mode,
        'hierarchical': self.hierarchical,
        'embeddings_initializer': initializer_config,
        'custom_position_ids': self.custom_position_ids,
    })
    return merged