In [1]:
from keras.layers import *
from keras.models import Model 
from keras.optimizers import Nadam
import keras.backend as K

Using TensorFlow backend.


In [2]:
# soft-attention

def unchanged_shape(input_shape):
    """
    Output-shape callback for Lambda layers whose function keeps the tensor shape.

    Returns the very same shape object it receives.
    """
    output_shape = input_shape
    return output_shape

def soft_attention_alignment(input_1, input_2):
    """
    Cross-align two sequences with dot-product soft attention.

    Both inputs are expected to be 3-D tensors laid out as
    (batch_size, seq_len, size): input_1 is (bs, sl1, size) and
    input_2 is (bs, sl2, size).

    Returns a tuple (in1_aligned, in2_aligned):
      * in1_aligned -- input_1 summarised for every position of input_2,
        shape (bs, sl2, size)
      * in2_aligned -- input_2 summarised for every position of input_1,
        shape (bs, sl1, size)

    Note that the sequence-length axes of the outputs are swapped relative
    to the inputs they were computed from.
    """
    def softmax_over_seq1(scores):
        # Normalise the score matrix along the sl1 axis.
        return K.softmax(scores, axis=1)

    def softmax_over_seq2(scores):
        # Normalise the score matrix along the sl2 axis.
        return K.softmax(scores, axis=2)

    # (bs, sl1, size) . (bs, sl2, size) ==> (bs, sl1, sl2)
    scores = Dot(axes=-1)([input_1, input_2])

    # (bs, sl1, sl2), weights sum to 1 over sl1
    weights_over_1 = Lambda(softmax_over_seq1, output_shape=unchanged_shape)(scores)

    # (bs, sl1, sl2) with weights summing to 1 over sl2, then transposed to (bs, sl2, sl1)
    scores_norm_2 = Lambda(softmax_over_seq2, output_shape=unchanged_shape)(scores)
    weights_over_2 = Permute((2, 1))(scores_norm_2)

    # (bs, sl1, sl2) . (bs, sl1, size) ==> (bs, sl2, size)
    in1_aligned = Dot(axes=1)([weights_over_1, input_1])
    # (bs, sl2, sl1) . (bs, sl2, size) ==> (bs, sl1, size)
    in2_aligned = Dot(axes=1)([weights_over_2, input_2])

    return in1_aligned, in2_aligned

In [3]:
def pool_corr(q1,q2,pool_way):
    """
    Pool two sequence tensors over time and combine them into one similarity feature.

    Steps:
      1. Apply the same global pooling layer ('max' or 'ave') to q1 and q2.
      2. Standardise each pooled vector along axis 1 (zero mean, unit std).
      3. Return the element-wise product of the two vectors divided by
         sum(a^2) + sum(b^2) - sum(|a*b|), a Jaccard/Tanimoto-style normaliser
         computed per sample.

    q1, q2   -- sequence tensors; assumed to be (bs, seq_len, features) as
                required by the global 1-D pooling layers.
    pool_way -- 'max' for GlobalMaxPooling1D, 'ave' for GlobalAveragePooling1D.

    Returns a tensor with the same shape as the pooled inputs.
    Raises RuntimeError for any other pool_way.
    """
    if pool_way == 'max':
        pool = GlobalMaxPooling1D()
    elif pool_way == 'ave':
        pool = GlobalAveragePooling1D()
    else:
        raise RuntimeError("don't have this pool way")

    q1 = pool(q1)
    q2 = pool(q2)

    def norm_layer(x, axis=1):
        centered = x - K.mean(x, axis=axis, keepdims=True)
        # K.epsilon() keeps a constant vector (std == 0) from producing 0/0 = NaN.
        return centered / (K.std(x, axis=axis, keepdims=True) + K.epsilon())
    q1 = Lambda(norm_layer)(q1)
    q2 = Lambda(norm_layer)(q2)

    def jaccard(x):
        a, b = x[0], x[1]
        denominator = (K.sum(a ** 2, axis=1, keepdims=True) +
                       K.sum(b ** 2, axis=1, keepdims=True) -
                       K.sum(K.abs(a * b), axis=1, keepdims=True))
        # The denominator is >= 0 and is exactly 0 when both vectors are all-zero;
        # the epsilon keeps the division finite in that case.
        return a * b / (denominator + K.epsilon())
    merged = Lambda(jaccard)([q1,q2])
    return merged

In [4]:
# Text matching with ESIM (Enhanced Sequential Inference Model)

def esim(text_len=20, max_features=20000, embedding_dim=100, lstm_units=128, lr=0.0008):
    """
    Build and compile an ESIM-style model that scores a pair of token-id sequences.

    Pipeline:
      1. Both inputs go through one shared Embedding, one shared
         BatchNormalization and a SpatialDropout1D(0.2).
      2. A shared bidirectional CuDNNLSTM (directions summed) encodes each sequence.
      3. soft_attention_alignment cross-aligns the two encodings.
      4. Each encoding is concatenated with its aligned counterpart and their
         element-wise product, then passed through a second shared
         bidirectional CuDNNLSTM.
      5. pool_corr is applied with average and with max pooling; the two
         results are concatenated and fed to Dense(256) -> Dense(64) ->
         Dense(1, sigmoid).

    Parameters (defaults reproduce the original hard-coded configuration):
      text_len      -- length of each input sequence.
      max_features  -- vocabulary size of the embedding table.
      embedding_dim -- width of the embedding vectors.
      lstm_units    -- units of both the encoding and the composition LSTM.
      lr            -- learning rate passed to Nadam.

    Returns a compiled Model with inputs [q1, q2], trained with
    binary cross-entropy and reporting binary cross-entropy and accuracy.

    NOTE: CuDNNLSTM is used, so per the Keras documentation the model needs
    a GPU and the TensorFlow backend.
    """
    q1 = Input(name='q1', shape=(text_len,))
    q2 = Input(name='q2', shape=(text_len,))

    # The embedding table and the batch-norm layer are shared by both inputs.
    embedding = Embedding(max_features, embedding_dim, input_length=text_len)

    bn = BatchNormalization()
    q1_embed = bn(embedding(q1))
    q1_embed = SpatialDropout1D(0.2)(q1_embed)
    q2_embed = bn(embedding(q2))
    q2_embed = SpatialDropout1D(0.2)(q2_embed)

    # Shared encoder: forward and backward outputs are summed, so the feature
    # width stays at lstm_units.
    encode = Bidirectional(CuDNNLSTM(lstm_units, return_sequences=True), merge_mode='sum')
    q1_encoded = encode(q1_embed)
    q2_encoded = encode(q2_embed)

    # q1_aligned has q2's sequence length and q2_aligned has q1's, which is
    # why each encoding is paired with the *other* aligned tensor below.
    q1_aligned, q2_aligned = soft_attention_alignment(q1_encoded, q2_encoded)

    q1_combined = Concatenate()([q1_encoded, q2_aligned, multiply([q1_encoded, q2_aligned])])
    q2_combined = Concatenate()([q2_encoded, q1_aligned, multiply([q2_encoded, q1_aligned])])

    # Shared composition layer over the enriched representations.
    compose = Bidirectional(CuDNNLSTM(lstm_units, return_sequences=True), merge_mode='sum')
    q1_compare = compose(q1_combined)
    q2_compare = compose(q2_combined)

    merged_ave = pool_corr(q1_compare,q2_compare,'ave')
    merged_max = pool_corr(q1_compare,q2_compare,'max')

    merged = Concatenate()([merged_max, merged_ave])

    dense = Dense(256, activation='relu')(merged)
    dense = Dense(64, activation='relu')(dense)
    out = Dense(1, activation='sigmoid')(dense)

    model = Model(inputs=[q1, q2], outputs=out)
    model.compile(optimizer=Nadam(lr=lr), loss='binary_crossentropy', metrics=['binary_crossentropy', 'accuracy'])
    return model

In [5]:
# Build the ESIM model and print its layer-by-layer summary.
model = esim()
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
q1 (InputLayer)                 (None, 20)           0                                            
__________________________________________________________________________________________________
q2 (InputLayer)                 (None, 20)           0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 20, 100)      2000000     q1[0][0]                         
                                                                 q2[0][0]                         
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 20, 100)      400         embedding_1[0][0]                
          