In [1]:
import collections
import math
import string

from keras import constraints
from keras import initializers
from keras import regularizers
from keras.engine.base_layer import Layer
from keras.layers import einsum_dense
from keras.utils import tf_utils
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

In [2]:
def _build_proj_equation(free_dims, bound_dims, output_dims):
    """Assembles the einsum equation for a multi-head attention projection.

    Subscript letters are taken consecutively from `_CHR_IDX` in three runs:
    free axes (present in input and output), bound axes (present in input and
    kernel, hence contracted), and output axes (present in kernel and output,
    and also reported as the bias axes).

    Args:
        free_dims: Number of leading input axes carried through to the output.
        bound_dims: Number of input axes contracted against the kernel.
        output_dims: Number of axes the kernel contributes to the output.

    Returns:
        A tuple `(equation, bias_axes, output_rank)`, where `output_rank` is
        the number of subscripts on the output side of the equation.
    """
    bound_start = free_dims
    out_start = bound_start + bound_dims

    def _letters(start, count):
        # Subscripts `_CHR_IDX[start]` .. `_CHR_IDX[start + count - 1]`.
        return "".join(_CHR_IDX[pos] for pos in range(start, start + count))

    free_part = _letters(0, free_dims)
    bound_part = _letters(bound_start, bound_dims)
    out_part = _letters(out_start, output_dims)

    input_subscripts = free_part + bound_part
    kernel_subscripts = bound_part + out_part
    output_subscripts = free_part + out_part
    bias_axes = out_part

    equation = "{},{}->{}".format(
        input_subscripts, kernel_subscripts, output_subscripts)
    output_rank = len(output_subscripts)

    print("equation: ", equation)
    print("bias_axes: ", bias_axes)
    print("len(output_str) :", output_rank)
    return equation, bias_axes, output_rank

In [15]:
# Stand-in query input: a random tensor of shape [5, 100, 256]. The shape
# comments further down in this notebook read these axes as (Batch, seq, dim).
query = tf.random.uniform(shape=[5, 100, 256])
# Keep the static shape; the next cell derives `free_dims` from its rank.
_query_shape = tf.TensorShape(query.shape)

In [16]:
# Every axis except the last is treated as "free" (carried through the
# projection unchanged); the projection call below binds the one remaining axis.
free_dims =_query_shape.rank - 1
print(free_dims)

2


In [17]:
# Pool of einsum subscript letters; the equation builders index into it by
# position, so it must be defined before either builder is called.
_CHR_IDX = string.ascii_lowercase

In [18]:
# Projection equation for the query: keep `free_dims` leading axes, contract
# one input axis against the kernel (bound_dims=1) and emit two new trailing
# axes (output_dims=2). `output_rank` is the number of output subscripts.
einsum_equation, bias_axes, output_rank = _build_proj_equation(
    free_dims, bound_dims=1, output_dims=2)

equation:  abc,cde->abde
bias_axes:  de
len(output_str) : 4


In [20]:
def _build_attention_equation(rank, attn_axes):
    """Builds einsum equations for the attention computation.

    Query, key, value inputs after projection are expected to have the shape:
    `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
    `bs` and `<non-attention dims>` are treated as `<batch dims>`.

    The attention operations can be generalized:

    (1) Query-key dot product:
    `(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
    <key attention dims>, num_heads, channels) -> (<batch dims>,
    num_heads, <query attention dims>, <key attention dims>)`

    (2) Combination:
    `(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
    (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
    <query attention dims>, num_heads, channels)`

    Args:
        rank: Rank of query, key, value tensors.
        attn_axes: List/tuple of axes, `[-1, rank)`, that attention will be
            applied to.

    Returns:
        A tuple `(dot_product_equation, combine_equation, attn_scores_rank)`.
        The first operand of `dot_product_equation` uses the source (key-side)
        subscripts and the second the target (query-side) subscripts.
        `attn_scores_rank` is the number of subscripts of the score tensor.
    """
    # Accept lists as documented: a list cannot be concatenated with the
    # `(rank - 1,)` tuple below.
    attn_axes = tuple(attn_axes)

    target_notation = _CHR_IDX[:rank]
    # `batch_dims` includes the head dim.
    batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))

    # Source subscripts reuse the target letter on batch dims and on the last
    # (channel) axis; every attention axis gets a fresh letter instead.
    letter_offset = rank
    source_notation = ""
    for i in range(rank):
        if i in batch_dims or i == rank - 1:
            source_notation += target_notation[i]
        else:
            source_notation += _CHR_IDX[letter_offset]
            letter_offset += 1

    product_notation = "".join(
        [target_notation[i] for i in batch_dims]
        + [target_notation[i] for i in attn_axes]
        + [source_notation[i] for i in attn_axes]
    )
    dot_product_equation = "%s,%s->%s" % (
        source_notation, target_notation, product_notation)
    attn_scores_rank = len(product_notation)
    combine_equation = "%s,%s->%s" % (
        product_notation, source_notation, target_notation)

    print("dot_product_equation: ", dot_product_equation)
    print("combine_equation: ", combine_equation)
    # `attn_scores_rank` is already an int; calling len() on it raised TypeError.
    print("attn_scores_rank: ", attn_scores_rank)

    return dot_product_equation, combine_equation, attn_scores_rank

In [24]:
# Attention axes computed as range(1, rank - 2) with the rank hard-coded as 4.
# NOTE(review): 4 matches the `output_rank` printed by the projection cell
# above; presumably it should track that value. Confirm.
_attention_axes = tuple(range(1, 4 - 2))
print(_attention_axes)

(1,)


In [None]:
# Query projection: (Batch,seq,dim), (dim,num_head,head_size)   -> (Batch,seq,num_head,head_size)
# Key projection:   (Batch,seq,dim), (dim,num_head,head_size)   -> (Batch,seq,num_head,head_size)
# Value projection: (Batch,seq,dim), (dim,num_head,head_size_v) -> (Batch,seq,num_head,head_size_v)

In [27]:
# Build the attention einsum equations for rank 4 and the attention axes above.
_build_attention_equation(4, _attention_axes) 
# The returned tuple corresponds to:
# self._dot_product_equation, self._combine_equation, attn_scores_rank

('aecd,abcd->acbe', 'acbe,aecd->abcd', 4)

In [28]:
# The trailing `len(_attention_axes)` axes of a rank-4 score tensor; the
# literal 4 equals the attn_scores_rank shown in the output of the cell above.
norm_axes = tuple(range(4 - len(_attention_axes), 4))
norm_axes # softmax is applied over these axes (seq_key)

(3,)

In [None]:
# Reference excerpt of the attention computation. It refers to `self`, `key`,
# `value` and `attention_mask`, none of which are defined in the cells shown
# in this notebook, so it is not runnable as a standalone cell.

# Scale the query by 1 / sqrt(key_dim) before taking the dot product.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
attention_scores = tf.einsum(self._dot_product_equation, key, query) # operand order: key first, then query
# 'aecd,abcd->acbe': (Batch,Seq_key,num_head,head_size), (Batch,Seq_query,num_head,head_size)
# -> (Batch,num_head,Seq_query,Seq_key), the attention weight matrix

# Normalise the scores with the layer's `_masked_softmax` helper (not shown here).
attention_scores = self._masked_softmax(attention_scores, attention_mask) 

attention_output = tf.einsum(self._combine_equation, attention_scores, value) # weights first, then value
# 'acbe,aecd->abcd': (Batch,num_head,Seq_query,Seq_key), (Batch,Seq_value,num_head,head_size_v)
# -> (Batch,Seq_query,num_head,head_size_v)

In [None]:
# Output projection: (Batch,Seq_query,num_head,head_size_v), (num_head,head_size_v,dim) -> (Batch,Seq_query,dim)