In [3]:
import numpy as np

In [84]:
# Toy inputs: 3 sequences of 4 token ids each (ids range 1..5).
a = np.array(
    [
        [1, 2, 4, 5],
        [1, 3, 2, 4],
        [3, 2, 1, 4],
    ]
)

# One binary label per row of `a`.
b = np.array([0, 1, 0])

In [180]:
from keras.layers import Embedding, Input, Dense, Layer, TimeDistributed, Flatten
from keras.models import Model
from keras import backend as K

In [214]:
class Attention(Layer):
    """Multi-head scaled dot-product attention without masking.

    Called on a list ``[Q_seq, K_seq, V_seq]`` of 3-D tensors. Each input is
    projected by its own trainable matrix to ``nb_head * size_per_head``
    features, split into ``nb_head`` heads of ``size_per_head`` features, and
    combined per head as ``softmax(Q K^T / sqrt(size_per_head)) V``. The heads
    are then concatenated back into a single feature axis.

    Args:
        nb_head: number of attention heads.
        size_per_head: number of features per head.
        **kwargs: forwarded to ``keras.layers.Layer``.

    Output shape: ``(batch, Q_seq length, nb_head * size_per_head)``.
    """

    def __init__(self, nb_head, size_per_head, **kwargs):
        self.nb_head = nb_head
        self.size_per_head = size_per_head
        self.output_dim = nb_head * size_per_head
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the Q/K/V projection matrices.

        ``input_shape`` is a list of three shapes (Q, K, V); each matrix maps
        the corresponding input's last dimension to ``output_dim``.
        """
        # Names and creation order are unchanged so existing weight files still load.
        self.WQ = self._projection_weight('WQ', input_shape[0][-1])
        self.WK = self._projection_weight('WK', input_shape[1][-1])
        self.WV = self._projection_weight('WV', input_shape[2][-1])
        super(Attention, self).build(input_shape)

    def _projection_weight(self, name, input_dim):
        """Add a trainable (input_dim, output_dim) Glorot-uniform matrix."""
        return self.add_weight(name=name,
                               shape=(input_dim, self.output_dim),
                               initializer='glorot_uniform',
                               trainable=True)

    def _split_heads(self, seq, weight):
        """Project ``seq`` by ``weight`` and reshape to (batch, heads, length, size_per_head)."""
        seq = K.dot(seq, weight)
        seq = K.reshape(seq, (-1, K.shape(seq)[1], self.nb_head, self.size_per_head))
        return K.permute_dimensions(seq, (0, 2, 1, 3))

    def call(self, x):
        """Apply multi-head attention to ``x = [Q_seq, K_seq, V_seq]``."""
        Q_seq, K_seq, V_seq = x

        Q_seq = self._split_heads(Q_seq, self.WQ)
        K_seq = self._split_heads(K_seq, self.WK)
        V_seq = self._split_heads(V_seq, self.WV)

        # Scores per head: (batch, heads, q_len, k_len), scaled by sqrt(size_per_head).
        # NOTE(review): relies on K.batch_dot treating (batch, heads) as batch
        # axes for 4-D inputs, as older Keras 2.x TF backends do; batch_dot
        # semantics changed in later Keras releases — verify shapes there.
        A = K.batch_dot(Q_seq, K_seq, axes=[3, 3]) / self.size_per_head ** 0.5
        # Normalise over the last (key) axis. No padding mask is applied.
        A = K.softmax(A)

        # Weighted sum of values, then merge heads back: (batch, q_len, output_dim).
        O_seq = K.batch_dot(A, V_seq, axes=[3, 2])
        O_seq = K.permute_dimensions(O_seq, (0, 2, 1, 3))
        O_seq = K.reshape(O_seq, (-1, K.shape(O_seq)[1], self.output_dim))
        return O_seq

    def compute_output_shape(self, input_shape):
        """Keep Q's batch and length dimensions; features become ``output_dim``."""
        return (input_shape[0][0], input_shape[0][1], self.output_dim)

    def get_config(self):
        """Serialize the constructor arguments so a saved model can be reloaded.

        Without this, ``from_config`` calls ``Attention(**config)`` with no
        ``nb_head``/``size_per_head`` and raises. Loading still requires
        ``custom_objects={'Attention': Attention}``.
        """
        config = super(Attention, self).get_config()
        config.update({'nb_head': self.nb_head,
                       'size_per_head': self.size_per_head})
        return config

In [215]:
# Labels are 1-D here; they are reshaped to a column before fitting.
np.shape(b)

(3,)

In [216]:
# Toy model: token ids -> embeddings -> multi-head self-attention -> binary classifier.
# shape=(4,) matches the 4 tokens per row of `a`; input_dim=6 must exceed the
# largest token id in `a` (which is 5).
inp = Input(shape=(4,))
emb = Embedding(input_dim=6, output_dim=512)(inp)

# Self-attention: the same embedded sequence is used as query, key and value.
# 8 heads * 64 features per head = 512 output features.
self_attn1 = Attention(nb_head=8, size_per_head=64)([emb, emb, emb])

x = Flatten()(self_attn1)

# A single-unit binary output needs sigmoid: softmax over one unit always returns
# 1.0, so binary_crossentropy would see a constant prediction and zero gradient.
y = Dense(units=1, activation="sigmoid")(x)

In [217]:
# Wire the functional graph from the token-id input to the classifier output.
model = Model(inputs=[inp], outputs=[y])

In [218]:
# Binary 0/1 labels -> binary cross-entropy; optimizer selected by name.
model.compile(optimizer="rmsprop", loss="binary_crossentropy")

In [219]:
# Print each layer's output shape and parameter count.
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_35 (InputLayer)           (None, 4)            0                                            
__________________________________________________________________________________________________
embedding_35 (Embedding)        (None, 4, 512)       3072        input_35[0][0]                   
__________________________________________________________________________________________________
attention_33 (Attention)        (None, 4, 512)       786432      embedding_35[0][0]               
                                                                 embedding_35[0][0]               
                                                                 embedding_35[0][0]               
__________________________________________________________________________________________________
flatten_17

In [220]:
# Targets must be (n_samples, 1) to match the single-unit Dense output.
model.fit(x=[a], y=[b[:, np.newaxis]])

Epoch 1/1


<keras.callbacks.History at 0x13223aa58>