
Commit

Come back to NumPy due to time-consuming TensorFlow slicing
rolczynski committed May 19, 2020
1 parent 42be7e8 commit 401d3f6
Showing 1 changed file with 11 additions and 27 deletions.
38 changes: 11 additions & 27 deletions aspect_based_sentiment_analysis/alignment.py
@@ -87,31 +87,13 @@ def merge_input_attentions(
 ) -> tf.Tensor:
     """ Merge input sub-token attentions into token attentions. """
 
-    @tf.function
-    def map_fn(*args, **kwargs):
-        return tf.map_fn(*args, **kwargs)
-
-    def aggregate(x, fun):
-        new = tf.stack([fun([x[..., i] for i in a], axis=0)
-                        if len(a) > 1 else x[..., a[0]]
-                        for a in alignment], axis=-1)
+    def aggregate(a, fun):
+        n = len(alignment)
+        new = np.zeros(n)
+        for i in range(n):
+            new[i] = fun(a[alignment[i]])
         return new
 
-    def apply_along_axis(fun, x, axis):
-        other, = {2, 3} - {axis}
-        perm = [other, 0, 1, axis]
-        # Unfortunately, map_fn iterates over dim 0 rather than
-        # applying along the given axis, so we have to transpose the
-        # tensor `x` back and forth.
-        x = tf.transpose(x, perm)
-        x = map_fn(fun, x, parallel_iterations=len(x))
-        # Put dim 0 in the last or next-to-last place; the other
-        # dimensions are unchanged: [1, 2, 3].
-        perm = [1, 2, 3]
-        perm.insert(other, 0)
-        x = tf.transpose(x, perm)
-        return x
-
     attentions = tf.reduce_sum(attentions, axis=[0, 1], keepdims=True) \
         if reduce else attentions
     # For attention _to_ a split-up word, we sum up the attention weights
@@ -120,8 +102,10 @@ def apply_along_axis(fun, x, axis):
     # mean over rows, and sum over columns of split tokens according to the
     # alignment. Note that if we go along one axis, the aggregation
     # impacts the orthogonal dimension.
-    fun_to = partial(aggregate, fun=tf.reduce_mean)
-    attentions = apply_along_axis(fun_to, attentions, axis=2)
-    fun_from = partial(aggregate, fun=tf.reduce_sum)
-    attentions = apply_along_axis(fun_from, attentions, axis=3)
+    attentions = attentions.numpy()
+    attention_to = partial(aggregate, fun=np.mean)
+    attentions = np.apply_along_axis(attention_to, 2, attentions)
+    attention_from = partial(aggregate, fun=np.sum)
+    attentions = np.apply_along_axis(attention_from, 3, attentions)
+    attentions = tf.convert_to_tensor(attentions)
     return attentions
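
For readers following along, here is a minimal, self-contained sketch (not part of the commit) of what the new NumPy path computes. The toy `alignment` and attention values are made-up assumptions for illustration; `aggregate` is copied from the diff above.

    from functools import partial

    import numpy as np

    # Hypothetical alignment: word 0 maps to sub-token [0],
    # word 1 was split into sub-tokens [1, 2].
    alignment = [[0], [1, 2]]

    def aggregate(a, fun):
        # Same helper as in the diff: collapse a vector of sub-token
        # values into one value per word, using `fun` (mean or sum).
        n = len(alignment)
        new = np.zeros(n)
        for i in range(n):
            new[i] = fun(a[alignment[i]])
        return new

    # Toy attentions shaped [layers, heads, from, to] = [1, 1, 3, 3];
    # each row is a probability distribution (sums to 1).
    attentions = np.array([[[[0.2, 0.4, 0.4],
                             [0.5, 0.3, 0.2],
                             [0.1, 0.6, 0.3]]]])

    # Exactly the two merging steps from the new code: a mean along
    # axis 2, then a sum along axis 3.
    attentions = np.apply_along_axis(partial(aggregate, fun=np.mean), 2, attentions)
    attentions = np.apply_along_axis(partial(aggregate, fun=np.sum), 3, attentions)

    print(attentions[0, 0])
    # [[0.2 0.8]
    #  [0.3 0.7]]
    print(attentions[0, 0].sum(axis=-1))  # rows still sum to 1

Merging this way keeps each row a valid attention distribution, which appears to be why the mean is used along one axis and the sum along the other.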
