Permalink
Switch branches/tags
Nothing to show
Find file Copy path
c989e11 Sep 21, 2017
2 contributors

Users who have contributed to this file

@mahnerak @MartinXPN
52 lines (36 sloc) 1.44 KB
# from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from keras import backend as K
from keras.layers import Layer
from keras.layers.wrappers import TimeDistributed
from helpers import compute_mask, softmax
class QuestionPooling(Layer):
    """Attention-pool a sequence tensor ``uQ`` into a single vector per sample.

    The layer owns no weights of its own (``build`` creates none). All
    parameters arrive as inputs, so ``call`` receives a list of exactly five
    tensors ``[uQ, WQ_u, WQ_v, v, VQ_r]`` and computes::

        s_hat = tanh(uQ . WQ_u + ones . (WQ_v . VQ_r))
        s     = batch_flatten(s_hat . v)
        a     = softmax(s, mask=mask(uQ), axis=1)
        rQ    = sum over axis 1 of uQ weighted by a

    If ``uQ`` has shape ``(B, Q, H)``, the output has shape ``(B, H)``.
    The shapes of the four weight tensors are not fixed by this class.
    Presumably ``WQ_u`` is ``(H, k)``, ``v`` is ``(k, 1)``, and
    ``WQ_v . VQ_r`` is ``(H, k)``. TODO: confirm against the model that
    builds these inputs.

    The name suggests the question-pooling step of R-Net. NOTE(review):
    confirm against the paper if that matters.
    """

    def __init__(self, **kwargs):
        super(QuestionPooling, self).__init__(**kwargs)
        # Masked inputs are accepted; only uQ's mask is consumed in `call`.
        self.supports_masking = True

    @staticmethod
    def _check_inputs(items, what):
        """Raise ValueError unless ``items`` is a list of exactly five entries.

        Used for both the tensor list (``call``) and the shape list
        (``build`` / ``compute_output_shape``). An explicit raise is used
        instead of ``assert`` so the check survives ``python -O``.
        """
        if not (isinstance(items, list) and len(items) == 5):
            raise ValueError(
                'QuestionPooling expects %s to be a list of 5 elements '
                '[uQ, WQ_u, WQ_v, v, VQ_r], got %r' % (what, items))

    def compute_output_shape(self, input_shape):
        """Return ``(B, H)`` given that the first input shape is ``(B, Q, H)``.

        Args:
            input_shape: list of the five input shapes; only the first
                (``uQ``'s) is inspected.

        Raises:
            ValueError: if ``input_shape`` is not a 5-element list, or the
                first shape does not have exactly three dimensions.
        """
        self._check_inputs(input_shape, 'input_shape')
        batch, _, features = input_shape[0]
        return (batch, features)

    def build(self, input_shape):
        """Validate the input shapes and mark the layer as built.

        No weights are created here; the weight tensors are passed to
        ``call`` as inputs.

        Raises:
            ValueError: if ``input_shape`` is not a 5-element list, or
                ``uQ``'s shape is not rank 3.
        """
        self._check_inputs(input_shape, 'input_shape')
        if len(input_shape[0]) != 3:
            raise ValueError(
                'QuestionPooling expects uQ to have shape (B, Q, H), '
                'got %r' % (input_shape[0],))
        super(QuestionPooling, self).build(input_shape)

    def call(self, inputs, mask=None):
        """Compute attention weights over axis 1 of ``uQ`` and return the pooled vector.

        Args:
            inputs: list ``[uQ, WQ_u, WQ_v, v, VQ_r]``.
            mask: optional list of per-input masks; only ``mask[0]`` (for
                ``uQ``) is used.

        Returns:
            Tensor ``rQ`` contracting ``uQ`` with the attention weights
            along axis 1.

        Raises:
            ValueError: if ``inputs`` is not a 5-element list.
        """
        self._check_inputs(inputs, 'inputs')
        uQ, WQ_u, WQ_v, v, VQ_r = inputs
        uQ_mask = mask[0] if mask is not None else None

        # All-ones tensor shaped like one time step of uQ. Dotting it with
        # the batch-independent matrix WQ_v . VQ_r tiles that term across
        # the batch. Because every entry of `ones` is 1, the result is the
        # sum of WQ_v . VQ_r over its first axis.
        ones = K.ones_like(K.sum(uQ, axis=1, keepdims=True))  # (B, 1, 2H)
        s_hat = K.dot(uQ, WQ_u)
        s_hat += K.dot(ones, K.dot(WQ_v, VQ_r))  # broadcasts over axis 1
        s_hat = K.tanh(s_hat)

        # One score per position along uQ's axis 1 (shape (B, Q) when v
        # has a single column -- TODO confirm v's shape).
        s = K.batch_flatten(K.dot(s_hat, v))
        # helpers.softmax presumably excludes masked positions -- verify.
        a = softmax(s, mask=uQ_mask, axis=1)
        rQ = K.batch_dot(uQ, a, axes=[1, 1])
        return rQ

    def compute_mask(self, input, mask=None):
        # The pooled output (B, H) has no sequence axis left, so no mask is
        # propagated downstream. The parameter name `input` is kept for
        # backward compatibility even though it shadows the builtin.
        return None