In [1]:
from keras.layers import *
import keras.backend as K
from keras.engine.topology import Layer

Using TensorFlow backend.


In [2]:
# soft-attention

def unchanged_shape(input_shape):
    """Shape function for ``Lambda`` layers whose output shape equals the input shape.

    Args:
        input_shape: the shape object handed in by the caller.

    Returns:
        The very same ``input_shape`` object, untouched.
    """
    output_shape = input_shape
    return output_shape

def soft_attention_alignment(input_1, input_2):
    """Soft-align two sequences against each other with dot-product attention.

    Shapes below follow the annotations left by the original author.

    Args:
        input_1: 3D tensor, (bs, ml1, size).
        input_2: 3D tensor, (bs, ml2, size).

    Returns:
        Tuple ``(in1_aligned, in2_aligned)`` with shapes (bs, ml2, size) and
        (bs, ml1, size) -- i.e. the step dimensions are swapped relative to
        the inputs.
    """
    # Pairwise similarity of every step of input_1 with every step of input_2.
    similarity = Dot(axes=-1)([input_1, input_2])  # (bs, ml1, ml2)

    # Softmax over axis 1; shape stays (bs, ml1, ml2).
    softmax_axis_1 = Lambda(lambda x: K.softmax(x, axis=1), output_shape=unchanged_shape)
    weights_1 = softmax_axis_1(similarity)

    # Softmax over axis 2, then swap the two step axes -> (bs, ml2, ml1).
    swap_steps = Permute((2, 1))
    softmax_axis_2 = Lambda(lambda x: K.softmax(x, axis=2), output_shape=unchanged_shape)
    weights_2 = swap_steps(softmax_axis_2(similarity))

    # Contract axis 1 of the weights with axis 1 of the matching input.
    aligned_1 = Dot(axes=1)([weights_1, input_1])  # (bs, ml2, size)
    aligned_2 = Dot(axes=1)([weights_2, input_2])  # (bs, ml1, size)

    return aligned_1, aligned_2

# # Test
# a = K.ones((3, 5, 7))
# b = K.ones((3, 20, 7))
# res1, res2 = soft_attention_alignment(a, b)
# print(K.int_shape(res1), K.int_shape(res2))
# # >>>(3, 20, 7) (3, 5, 7)

In [3]:
# co-attention

def co_attention(input_1, input_2):
    """Co-attention pooling of two equally long sequences.

    Shapes below follow the annotations left by the original author.

    Args:
        input_1: 3D tensor, (bs, ml, size).
        input_2: 3D tensor, (bs, ml, size); must have the same number of
            steps as ``input_1``.

    Returns:
        Tuple ``(out1, out2)``, each (bs, size): the attention-weighted sum
        of ``input_1`` and of ``input_2`` respectively.
    """
    # One scoring layer, shared by both directions.
    dense_w = TimeDistributed(Dense(1))

    def _attention_weights(scores):
        # (bs, ml, ml) -> (bs, ml, 1) -> (bs, ml) -> softmax -> (bs, 1, ml)
        weights = dense_w(scores)
        weights = Flatten()(weights)
        weights = Activation('softmax')(weights)
        return Reshape((1, -1))(weights)

    def _weighted_sum(weights, sequence):
        # (bs, 1, ml) . (bs, ml, size) -> (bs, 1, size) -> (bs, size)
        pooled = Lambda(lambda x: K.batch_dot(x[0], x[1]))([weights, sequence])
        return Flatten()(pooled)

    # Affinity matrix: (bs, ml, size) . (bs, size, ml) -> (bs, ml, ml)
    affinity_layer = Lambda(lambda x: K.batch_dot(x[0], x[1]))
    input_2_transposed = Permute((2, 1))(input_2)
    atten = affinity_layer([input_1, input_2_transposed])

    atten_1 = _attention_weights(atten)
    atten_2 = _attention_weights(Permute((2, 1))(atten))

    out1 = _weighted_sum(atten_1, input_1)
    out2 = _weighted_sum(atten_2, input_2)

    return out1, out2  # (bs, size), (bs, size)

# # Test
# a = K.ones((3, 5, 7))
# b = K.ones((3, 5, 7))
# res1, res2 = co_attention(a, b)
# print(K.int_shape(res1), K.int_shape(res2))
# # >>>(3, 7) (3, 7)

In [4]:
# Hierarchical attention

def dot_product(x, kernel):
    """Backend-aware dot product of an input tensor with a weight tensor.

    On the TensorFlow backend the kernel gets an extra trailing axis before
    ``K.dot`` and that axis is squeezed out of the result again; on any other
    backend ``K.dot`` is applied directly.

    Args:
        x: input tensor.
        kernel: weight tensor.

    Returns:
        The tensor produced by the dot product.
    """
    if K.backend() != 'tensorflow':
        return K.dot(x, kernel)

    # TensorFlow path: append a trailing axis to the kernel, multiply, then
    # drop the trailing axis of the product.
    kernel_expanded = K.expand_dims(kernel)
    product = K.dot(x, kernel_expanded)
    return K.squeeze(product, axis=-1)


class AttentionWithContext(Layer):
    """
    Attention operation, with a context/query vector, for temporal data.
    Supports Masking: when an upstream layer provides a timestep mask, masked
    steps receive zero attention weight.
    Follows the work of Yang et al. [https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf]
    "Hierarchical Attention Networks for Document Classification"
    by using a context vector to assist the attention

    # Input shape
        3D tensor with shape: `(samples, steps, features)`.
    # Output shape
        2D tensor with shape: `(samples, features)`.

    # Arguments
        W_regularizer, u_regularizer, b_regularizer: regularizer identifiers
            for the projection matrix, the context vector and the bias.
        W_constraint, u_constraint, b_constraint: constraint identifiers for
            the same three weights.
        bias: whether to add a bias term before the tanh.

    How to use:
    Just put it on top of an RNN Layer (GRU/LSTM/SimpleRNN) with return_sequences=True.
    The dimensions are inferred based on the output shape of the RNN.
    Example:
        model.add(LSTM(64, return_sequences=True))
        model.add(AttentionWithContext())  # [None, features]
        # next add a Dense layer (for classification/regression) or whatever...
    """

    def __init__(self,
                 W_regularizer=None, u_regularizer=None, b_regularizer=None,
                 W_constraint=None, u_constraint=None, b_constraint=None,
                 bias=True, **kwargs):

        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.u_regularizer = regularizers.get(u_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.u_constraint = constraints.get(u_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        super(AttentionWithContext, self).__init__(**kwargs)
        # Set after the base constructor so it cannot be reset by it; this is
        # what lets Keras hand an incoming timestep mask to call().
        self.supports_masking = True

    def build(self, input_shape):
        """Create W (features, features), optional b (features,) and u (features,)."""
        assert len(input_shape) == 3  # input must be (samples, steps, features)

        feature_dim = input_shape[-1]

        # Shapes are passed by keyword (as the Attention layer in this file
        # does) instead of relying on positional-argument handling.
        self.W = self.add_weight(shape=(feature_dim, feature_dim),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight(shape=(feature_dim,),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)

        self.u = self.add_weight(shape=(feature_dim,),
                                 initializer=self.init,
                                 name='{}_u'.format(self.name),
                                 regularizer=self.u_regularizer,
                                 constraint=self.u_constraint)

        super(AttentionWithContext, self).build(input_shape)

    def compute_mask(self, inputs, mask=None):
        """The step axis is summed away, so no mask is passed downstream."""
        return None

    def call(self, x, mask=None):
        """Pool `x` over the step axis with learned attention weights.

        Args:
            x: tensor of shape (samples, steps, features).
            mask: optional (samples, steps) mask; steps where it is zero/False
                get zero attention weight.

        Returns:
            Tensor of shape (samples, features).
        """
        uit = dot_product(x, self.W)  # (bs, ml, size)

        if self.bias:
            uit = uit + self.b   # (bs, ml, size)

        uit = K.tanh(uit)  # (bs, ml, size)
        ait = dot_product(uit, self.u)   # (bs, ml)

        if mask is None:
            a = K.softmax(ait)  # (bs, ml)
        else:
            # Masked softmax: exponentiate (shifted by the row max for
            # numerical stability), zero the masked steps, renormalise.
            # epsilon keeps a fully masked row from dividing by zero.
            keep = K.cast(mask, K.floatx())
            a = K.exp(ait - K.max(ait, axis=-1, keepdims=True)) * keep
            a = a / (K.sum(a, axis=-1, keepdims=True) + K.epsilon())

        a = K.expand_dims(a)     # (bs, ml, 1)
        weighted_input = x * a   # (bs, ml, size) * (bs, ml, 1) => (bs, ml, size)

        return K.sum(weighted_input, axis=1)   # (bs, size)

    def compute_output_shape(self, input_shape):
        return input_shape[0], input_shape[-1]

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {
            'W_regularizer': regularizers.serialize(self.W_regularizer),
            'u_regularizer': regularizers.serialize(self.u_regularizer),
            'b_regularizer': regularizers.serialize(self.b_regularizer),
            'W_constraint': constraints.serialize(self.W_constraint),
            'u_constraint': constraints.serialize(self.u_constraint),
            'b_constraint': constraints.serialize(self.b_constraint),
            'bias': self.bias,
        }
        base_config = super(AttentionWithContext, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [5]:
# multi-head self-attention  (source: https://spaces.ac.cn/archives/4765)

class Position_Embedding(Layer):
    """Sinusoidal position embedding that is added to or concatenated with the input.

    # Arguments
        size: width of the position encoding; must be even. When it is None,
            or whenever `mode` is 'sum', it is replaced by the size of the
            input's last axis.
        mode: 'sum' adds the encoding to the input; 'concat' puts the
            encoding in front of the input along the last axis.

    # Input shape
        3D tensor, annotated by the original author as (bs, ml, size).
    # Output shape
        Same as the input for 'sum'; last axis grows by `size` for 'concat'.
    """

    _MODES = ('sum', 'concat')

    def __init__(self, size=None, mode='sum', **kwargs):
        self.size = size  # must be even; lets the caller choose the encoding width
        self.mode = mode
        super(Position_Embedding, self).__init__(**kwargs)

    def _check_mode(self):
        # Fail loudly instead of silently returning None for an unknown mode.
        if self.mode not in self._MODES:
            raise ValueError("mode must be 'sum' or 'concat', got %r" % (self.mode,))

    def call(self, x):   # (bs, ml, size)
        self._check_mode()
        if self.size is None or self.mode == 'sum':
            self.size = int(x.shape[-1])
        if self.size % 2 != 0:
            # cos and sin halves are concatenated, so the width must split evenly.
            raise ValueError('Position_Embedding size must be even, got %r' % (self.size,))

        half_size = self.size // 2
        # Inverse frequencies 1 / 10000^(2j / size) for j = 0 .. size/2 - 1.
        position_j = 1. / K.pow(10000., 2 * K.arange(half_size, dtype='float32') / self.size)   # (size/2,)
        position_j = K.expand_dims(position_j, 0)  # (1, size/2)
        # 0-based step index built from a running sum of ones, which works for
        # a step count that is only known at run time.
        position_i = K.cumsum(K.ones_like(x[:, :, 0]), 1) - 1   # (bs, ml)
        position_i = K.expand_dims(position_i, 2)   # (bs, ml, 1)
        position_ij = K.dot(position_i, position_j)  # (bs, ml, size/2)
        # Cosines fill the first half of the encoding, sines the second half.
        position_ij = K.concatenate([K.cos(position_ij), K.sin(position_ij)], 2)   # (bs, ml, size)
        if self.mode == 'sum':
            return position_ij + x   # (bs, ml, size)
        return K.concatenate([position_ij, x], 2)   # (bs, ml, size + input size)

    def compute_output_shape(self, input_shape):
        self._check_mode()
        if self.mode == 'sum':
            return input_shape
        return (input_shape[0], input_shape[1], input_shape[2] + self.size)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {'size': self.size, 'mode': self.mode}
        base_config = super(Position_Embedding, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Multi-head self-attention
class Attention(Layer):
    """Multi-head scaled dot-product attention.

    # Arguments
        nb_head: number of attention heads.
        size_per_head: width of each head; the output width is
            `nb_head * size_per_head`.
        mask_right: if True, each query step may only attend to key steps at
            the same or an earlier index (lower-triangular mask).

    # Input
        A list of 3 tensors `[Q_seq, K_seq, V_seq]` (no length masking) or of
        5 tensors `[Q_seq, K_seq, V_seq, Q_len, V_len]`, where the `*_len`
        tensors hold the valid length of each sample in column 0.
    # Output shape
        `(batch, query steps, nb_head * size_per_head)`.
    """

    def __init__(self, nb_head, size_per_head, mask_right=False, **kwargs):
        self.nb_head = nb_head   # number of attention heads
        self.size_per_head = size_per_head   # width of each head
        self.output_dim = nb_head * size_per_head   # total output width
        self.mask_right = mask_right   # whether to hide later positions
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create one projection matrix each for queries, keys and values."""
        self.WQ = self.add_weight(name='WQ',
                                  shape=(input_shape[0][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)   # (size, att_dim)
        self.WK = self.add_weight(name='WK',
                                  shape=(input_shape[1][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)   # (size, att_dim)
        self.WV = self.add_weight(name='WV',
                                  shape=(input_shape[2][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)   # (size, att_dim)
        super(Attention, self).build(input_shape)

    def Mask(self, inputs, seq_len, mode='mul'):
        """Mask positions along axis 1 of `inputs` at or beyond each sample's length.

        Args:
            inputs: tensor whose axis 1 is the step axis.
            seq_len: tensor with the valid length in column 0, or None to
                return `inputs` unchanged.
            mode: 'mul' zeroes masked positions; 'add' subtracts 1e12 from
                them (for use before a softmax).

        Raises:
            ValueError: if `seq_len` is given and `mode` is not 'mul' or 'add'.
        """
        if seq_len is None:
            return inputs
        if mode not in ('mul', 'add'):
            raise ValueError("mode must be 'mul' or 'add', got %r" % (mode,))

        # one_hot needs integer indices; the cast is a no-op for int32 lengths.
        lengths = K.cast(seq_len[:, 0], 'int32')
        mask = K.one_hot(lengths, K.shape(inputs)[1])   # (bs, ml)
        # The running sum turns the one-hot marker into ones from that index
        # onward; 1 - that is 1 for valid steps and 0 afterwards.
        mask = 1 - K.cumsum(mask, 1)  # (bs, ml)
        for _ in range(len(inputs.shape) - 2):
            mask = K.expand_dims(mask, 2)  # (bs, ml, 1, ...)
        if mode == 'mul':
            return inputs * mask
        return inputs - (1 - mask) * 1e12

    def _split_heads(self, seq):
        """Reshape (bs, ml, att_dim) to (bs, nb_head, ml, size_per_head)."""
        seq = K.reshape(seq, (-1, K.shape(seq)[1], self.nb_head, self.size_per_head))
        return K.permute_dimensions(seq, (0, 2, 1, 3))

    def call(self, x):
        """Apply multi-head attention to `[Q, K, V]` or `[Q, K, V, Q_len, V_len]`.

        Raises:
            ValueError: if `x` does not contain exactly 3 or 5 tensors.
        """
        # With only Q_seq, K_seq, V_seq no length mask is applied; with
        # Q_len and V_len as well, positions past those lengths are masked.
        if len(x) == 3:
            Q_seq, K_seq, V_seq = x
            Q_len, V_len = None, None
        elif len(x) == 5:
            Q_seq, K_seq, V_seq, Q_len, V_len = x
        else:
            raise ValueError('Attention expects 3 or 5 input tensors, got %d' % len(x))

        # Linear projections, then split into heads: (bs, nb_head, ml, size_ph)
        Q_seq = self._split_heads(K.dot(Q_seq, self.WQ))
        K_seq = self._split_heads(K.dot(K_seq, self.WK))
        V_seq = self._split_heads(K.dot(V_seq, self.WV))

        # Scaled dot-product scores, then masking, then softmax.
        # NOTE(review): this relies on K.batch_dot treating both leading axes
        # of a 4D input as batch axes -- confirm against the installed Keras.
        A = K.batch_dot(Q_seq, K_seq, axes=[3, 3]) / self.size_per_head**0.5   # (bs, nb_head, ml, ml)
        # Move the key-step axis to position 1 so Mask() can act on it.
        A = K.permute_dimensions(A, (0, 3, 2, 1))   # (bs, ml, ml, nb_head)
        A = self.Mask(A, V_len, 'add')
        A = K.permute_dimensions(A, (0, 3, 2, 1))   # (bs, nb_head, ml, ml)
        if self.mask_right:
            # Lower-triangular visibility mask built from backend ops only:
            # a key column is visible when its index <= the query row index.
            ones = K.ones_like(A[:1, :1])            # (1, 1, ml, ml)
            row_index = K.cumsum(ones, axis=2)
            col_index = K.cumsum(ones, axis=3)
            visible = K.cast(K.less_equal(col_index, row_index), K.dtype(ones))
            A = A - (ones - visible) * 1e12
        A = K.softmax(A)   # (bs, nb_head, ml, ml)

        # Weighted sum of values, merge heads, then mask padded query steps.
        O_seq = K.batch_dot(A, V_seq, axes=[3, 2])  # (bs, nb_head, ml, size_ph)
        O_seq = K.permute_dimensions(O_seq, (0, 2, 1, 3))   # (bs, ml, nb_head, size_ph)
        O_seq = K.reshape(O_seq, (-1, K.shape(O_seq)[1], self.output_dim))  # (bs, ml, att_dim)
        O_seq = self.Mask(O_seq, Q_len, 'mul')
        return O_seq

    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], input_shape[0][1], self.output_dim)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {
            'nb_head': self.nb_head,
            'size_per_head': self.size_per_head,
            'mask_right': self.mask_right,
        }
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))