In [1]:
import keras

Using TensorFlow backend.


In [7]:
from collections.abc import Callable,Sequence

In [8]:
from keras.layers import merge
from keras.layers.core import *
from keras.layers.recurrent import LSTM
from keras.models import *

In [9]:
INPUT_DIM = 32
TIME_STEPS = 20
SINGLE_ATTENTION_VECTOR = False
APPLY_ATTENTION_BEFORE_LSTM = False


In [87]:
class SelfAttention2DLayer(Layer):
    """加性相似度,最经典的注意力相似度机制,如果是在self attention中\
则该层有一个dim为Key_time_step的向量和一个(Key_dim,Key_time_step)的矩阵作为用于训练的参数

    .. math::  Similarity(Key) = v \cdot tanh(W_k\cdot Key)


如果不是在self attention中,则该层有一个dim为Key_time_step的向量和两个(Key_dim,Key_time_step)\
的矩阵作为用于训练的参数

    .. math::  Similarity(Key) = v \cdot tanh(W_k\cdot Key+W_q\cdot Query)
    """

    def __init__(self, similarity="additive",*,
                 kernel_size = None,
                 kernel_initializer='glorot_uniform',
                 wk_kernel_initializer='glorot_uniform',
                 wq_kernel_initializer='glorot_uniform',
                 **kwargs):
        if isinstance(similarity,Callable):
            self.similarity = similarity
        elif isinstance(similarity,str) and similarity in ("multiplicative","dot_product","additive"):
                self.similarity = similarity
        else:
            raise ValueError(
                    'similarity now only support "multiplicative","dot_product","additive",'
                    'and you can input a function as the similarity function!'
                                )
        if (isinstance(
            kernel_size,
            Sequence) and len(kernel_size) == 2) or kernel_size is None:
            self.kernel_size = kernel_size
        else:
            raise ValueError(
                    'kernel_size must be a Sequence with 2 int element')
            
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.wk_kernel_initializer = initializers.get(
            wk_kernel_initializer)
        self.wq_kernel_initializer = initializers.get(
            wq_kernel_initializer)
        super().__init__(**kwargs)

    def build(self, input_shape):
        if len(input_shape) != 3:
            raise ValueError('A additive weight layer should be called '
                             'by a (batch,time_step,dim)3D inputs.'
                             'Got ' + str(input_shape) + ' inputs.')
        time = input_shape[-2]
        dim = input_shape[-1]
        if self.similarity == "additive":
            
            if self.kernel_size is None:
                self.kernel_size = (time,time)
            r,d_a = self.kernel_size
            self.kernel = self.add_weight(name='kernel',
                                      shape=(r,d_a),
                                      initializer=self.kernel_initializer,
                                      trainable=True)
                
            self.wk_kernel = self.add_weight(
                name='wk_kernel',
                shape=(d_a, dim),
                initializer=self.wk_kernel_initializer,
            trainable=True)
        elif self.similarity == "multiplicative":
            self.kernel = self.add_weight(name='kernel',
                                          shape=(
                                              dim,dim),
                                          initializer=self.kernel_initializer,
                                          trainable=True)
        else:
            pass

        # Be sure to call this somewhere!
        super().build(input_shape)
    def multiplicative(self, Source):
        Source_t = K.permute_dimensions(Source, (0,2,1))
        s = K.dot(Source,self.kernel)
        print(s)
        sim = K.batch_dot(s,Source_t)
        print(sim)
        return sim
        
    def dot_product(self, Source):
        Source_t = K.permute_dimensions(Source, (0,2,1))
        sim = K.batch_dot(Source,Source_t)
        print(sim)
        return sim
        
    def additive(self, Source):
        Source_t = K.permute_dimensions(Source, (0,2,1))
        f_att = K.dot(self.wk_kernel,Source_t)
        f_att = K.permute_dimensions(f_att, (1,0,2))
        sim = K.dot(self.kernel,K.tanh(f_att))
        sim = K.permute_dimensions(sim, (1,0,2))
        print(sim)
        return sim

    def call(self, inputs):
        Source = inputs
        if isinstance(self.similarity,Callable):
            sim = self.similarity(Source)
        else:
            sim = getattr(self, self.similarity)(Source)
        sm = activations.softmax(sim)
        result = K.batch_dot(sm,Source)
        print(result)
        return result

    def compute_output_shape(self, input_shape):
        return input_shape


In [88]:
inputs = Input(shape=(TIME_STEPS, INPUT_DIM,))
attention_mul =  SelfAttention2DLayer(similarity='multiplicative',kernel_size=(5,6))(inputs)#MyLayer((20,32))(inputs)#
lstm_units = 32
attention_mul = LSTM(lstm_units, return_sequences=False)(attention_mul)
output = Dense(1, activation='sigmoid')(attention_mul)
m = Model(input=[inputs], output=output)

m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
print(m.summary())

Tensor("self_attention2d_layer_42/Reshape_2:0", shape=(?, 20, 32), dtype=float32)
Tensor("self_attention2d_layer_42/MatMul_1:0", shape=(?, 20, 20), dtype=float32)
Tensor("self_attention2d_layer_42/MatMul_2:0", shape=(?, 20, 32), dtype=float32)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_43 (InputLayer)        (None, 20, 32)            0         
_________________________________________________________________
self_attention2d_layer_42 (S (None, 20, 32)            1024      
_________________________________________________________________
lstm_12 (LSTM)               (None, 32)                8320      
_________________________________________________________________
dense_12 (Dense)             (None, 1)                 33        
Total params: 9,377
Trainable params: 9,377
Non-trainable params: 0
_________________________________________________________________
None




In [44]:
import numpy as np

In [337]:
k = np.random.rand(20,12)

In [338]:
q = np.random.rand(18,12)

In [355]:
wk = np.random.rand(12,18)

In [356]:
(k@wk).shape

(20, 18)

In [364]:
wq = np.random.rand(20,12)

In [365]:
(wq@q.T).shape

(20, 18)

In [350]:
(k@wk)+(q@wq)

ValueError: operands could not be broadcast together with shapes (20,18) (18,18) 

In [270]:
v =  np.random.rand(20,20)

In [271]:
(v.T@k).shape

(20, 2)

In [276]:
k = np.random.rand(20,2)

In [279]:
vv = np.random.rand(2,4)

In [281]:
(k@vv).shape

(20, 4)

In [322]:
w = np.random.rand(2,2)

In [324]:
q = np.random.rand(18,2)

In [333]:
s = k@w

In [334]:
s@q.T

array([[ 0.1121432 ,  0.35664993,  0.24638674,  0.18976187,  0.45051852,
         0.28007788,  0.30429968,  0.10896963,  0.37784277,  0.15818082,
         0.11349591,  0.15118776,  0.30723985,  0.40662203,  0.33915779,
         0.04253393,  0.41135932,  0.47588102],
       [ 0.48384752,  1.53861945,  1.05035995,  0.8302624 ,  1.94133307,
         1.22395786,  1.31929431,  0.47444006,  1.63372104,  0.68215898,
         0.49560182,  0.661576  ,  1.30174022,  1.740443  ,  1.45916821,
         0.18600621,  1.77159176,  2.04782356],
       [ 0.51202271,  1.62843195,  1.12817858,  0.86348186,  2.05759833,
         1.2748229 ,  1.38774718,  0.49644283,  1.72426204,  0.72230243,
         0.51669352,  0.68793441,  1.40886357,  1.86010041,  1.54957895,
         0.1935674 ,  1.87900573,  2.1741455 ],
       [ 0.11008194,  0.3500118 ,  0.23548987,  0.19205839,  0.44100711,
         0.2827332 ,  0.30190753,  0.10911737,  0.37265406,  0.15511266,
         0.11438   ,  0.15306079,  0.28961649,  0.392

In [325]:
(k@w@q.T).shape

(20, 18)

In [320]:
k.shape

(20, 2)

In [316]:
(k.T@w).shape

(2, 20)

In [304]:
s = k@w

In [305]:
s.shape

(20, 20)

In [306]:
(s@k)

array([[ 1.56967935,  1.13957991],
       [ 6.80108574,  4.9554269 ],
       [ 7.15955082,  5.19325489],
       [ 1.55519432,  1.13803288],
       [ 5.64839779,  4.12002862],
       [ 5.19543016,  3.80455297],
       [ 4.5430327 ,  3.30276969],
       [ 1.47712727,  1.0628364 ],
       [ 3.07430949,  2.24569032],
       [ 3.32585045,  2.40973565],
       [ 4.60755453,  3.34114439],
       [ 4.79723194,  3.48051961],
       [ 3.39030022,  2.46085421],
       [ 4.5407952 ,  3.28482748],
       [ 6.30656563,  4.58412114],
       [ 1.50355706,  1.08677645],
       [ 5.01366034,  3.65012031],
       [ 1.90026988,  1.37405081],
       [ 5.97611325,  4.35109981],
       [ 6.49339227,  4.72102345]])

In [336]:
(k@q.T).shape

(20, 18)

In [372]:
3!=3!=2

False

In [381]:
import keras.backend as K
t = K.ones((5,12, 3))
t1 = t[:, :6,:] + 1
t2 = t[:, 1:,:] - 1

In [382]:
t1.shape

TensorShape([Dimension(5), Dimension(6), Dimension(3)])