In [1]:
import tensorflow as tf
import numpy as np
import jax.numpy as jnp

In [6]:
def _compute_causal_mask(query, value=None):
    """Build a lower-triangular (causal) boolean attention mask.

    Args:
      query: Tensor/array whose axis 1 is the query sequence length.
      value: Optional tensor/array whose axis 1 is the key/value sequence
        length. When omitted, the query length is reused (self-attention).

    Returns:
      A boolean tensor of shape `(1, q_len, v_len)` where entry `[0, i, j]`
      is True iff `j <= i`.
    """
    query_len = tf.shape(query)[1]
    if value is None:
        value_len = query_len
    else:
        value_len = tf.shape(value)[1]
    all_true = tf.ones((1, query_len, value_len), tf.bool)
    # num_lower=-1 keeps every band below the diagonal; num_upper=0 zeroes
    # everything above it -> lower-triangular pattern.
    return tf.linalg.band_part(all_true, -1, 0)


# Dummy (8, 3, 5) input; the mask helpers below only read its shape.
# Seeded so that Restart & Run All reproduces identical values.
x = np.random.default_rng(0).uniform(size=(8, 3, 5))

In [7]:
# Self-attention case: no `value`, so the mask is square over x's axis 1.
_compute_causal_mask(query=x).numpy()

array([[[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]]])

In [9]:
def get_decoder_mask(self_attn_inputs):
    """Returns causal mask to apply for self-attention layer.

    Args:
      self_attn_inputs: Inputs to self attention layer to determine mask shape.
        Axis 0 is used as the batch size and axis 1 as the sequence length.

    Returns:
      A float tensor of shape `(batch, seq_len, seq_len)` (dtype from
      `tf.eye`'s default) holding one lower-triangular matrix of ones per
      batch element.
    """
    input_shape = tf.shape(self_attn_inputs)
    batch_shape = input_shape[:1]
    seq_len = input_shape[1]
    identity = tf.eye(seq_len, batch_shape=batch_shape)
    # Summing the identity cumulatively down the rows (axis 1) spreads each
    # diagonal 1 to every row below it, yielding a lower triangle of ones.
    return tf.cumsum(identity, axis=1)


# Every batch element gets the same mask; show only the first.
get_decoder_mask(self_attn_inputs=x)[0]

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)>

In [13]:
def causal_attention_mask(seq_len):
    """Creates a lower-triangular (causal) attention mask.

    Args:
      seq_len: Either the sequence length itself (an integer or 0-d array),
        or an array-like of shape `(batch, seq_len, ...)` whose axis 1 gives
        the length. The array form is accepted for backward compatibility
        with callers that pass the input tensor directly.

    Returns:
      A `jax.numpy` array of shape `[seq_len, seq_len]` (jnp's default float
      dtype) where entry `[i, j]` is 1 if `j <= i` (position `i` may attend
      to position `j`) and 0 otherwise.

    Raises:
      ValueError: If the resolved sequence length is negative.
    """
    if np.ndim(seq_len) == 0:
        len_s = int(seq_len)
    else:
        # Read the static length from axis 1 with NumPy rather than
        # `tf.shape`, so no TensorFlow tensor leaks into the jnp shape tuple.
        len_s = int(np.shape(seq_len)[1])
    if len_s < 0:
        raise ValueError(f"seq_len must be non-negative, got {len_s}")
    return jnp.tril(jnp.ones((len_s, len_s)))


# The mask size comes from x's axis 1 (sequence length).
causal_attention_mask(seq_len=x).shape

(3, 3)