<a href="https://colab.research.google.com/github/Taaniya/Transformers-architecture/blob/main/Explore_outputs_in_causal_attention_mechanism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The function below applies the attention mechanism for an encoder / decoder block in a transformer architecture, given query, key and value vectors and an appropriate mask (look-ahead / padding).

For ease of understanding, we'll explore the output on a small three-token example with a look-ahead mask.

Note that the look-ahead mask is applied only to the input of a decoder layer, not in the encoder layer. It is called `look_ahead_mask` because it masks the tokens ahead of the current token.

This function is used in the multi-head attention mechanism in both the Encoder and Decoder blocks.
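
For reference, the computation can be written as the standard scaled dot-product attention formula. The mask term $M$ (1 at positions to hide, 0 elsewhere) is notation assumed here to match the `mask * -1e9` convention in the code below; $d_k$ is the depth of the key vectors.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M \cdot (-10^{9})\right) V
$$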

In [3]:
import tensorflow as tf

In [72]:
def scaled_dot_product_attention(q, k, v, mask=None):
  """Calculate the attention output and the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type (padding or look ahead),
  but it must be broadcastable for addition.

  Args:
  q: query shape = (..., seq_len_q, depth)
  k: key shape = (..., seq_len_k, depth)
  v: value shape = (..., seq_len_v, depth_v)
  mask: float tensor with shape broadcastable to (..., seq_len_q, seq_len_k). Defaults to None

  Returns:
  output, attention_weights
  """
  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
  # print(f"matmul_qk: \n {matmul_qk}")
  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
  # print(f"scaled matmul_qk: \n {scaled_attention_logits}")

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9) 

  print(f"scaled_attention_logits: \n {scaled_attention_logits}")
  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)     # (..., seq_len_q, seq_len_k)
  output = tf.matmul(attention_weights, v)

  return output, attention_weights
     

In [73]:
def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)   # Set diagonal and all subdiagonals to zero, rest to 1 
  return mask
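
The docstring above also mentions a padding mask. As a minimal sketch of how one could be built and combined with the look-ahead mask, assuming integer token ids with `0` as the padding id: `create_padding_mask`, `pad_id` and `token_ids` are hypothetical names used only for illustration and are not defined elsewhere in this notebook.

```python
import tensorflow as tf

def create_padding_mask(seq, pad_id=0):
    """Return 1.0 where `seq` holds the padding id, 0.0 elsewhere.

    seq: integer token ids with shape (batch, seq_len).
    The result has shape (batch, 1, seq_len) so it broadcasts over the query axis.
    """
    mask = tf.cast(tf.math.equal(seq, pad_id), tf.float32)
    return mask[:, tf.newaxis, :]

def create_look_ahead_mask(size):
    # Same definition as in the notebook: 1.0 strictly above the diagonal.
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

# Hypothetical batch of two sequences; 0 is assumed to be the padding id.
token_ids = tf.constant([[7, 6, 0], [1, 0, 0]])
padding_mask = create_padding_mask(token_ids)    # (2, 1, 3)
look_ahead_mask = create_look_ahead_mask(3)      # (3, 3)

# A position is hidden if either mask hides it.
combined_mask = tf.maximum(padding_mask, look_ahead_mask)  # (2, 3, 3)
print(combined_mask)
```

Taking the element-wise maximum keeps the "1 means hidden" convention, so the combined mask can be passed to `scaled_dot_product_attention` unchanged.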

In [74]:
# Example: assume a three-token prefix received by the decoder

prefix = tf.constant([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [2.0, 3.0, 4.0, 5.0]])

In [75]:
prefix.shape

TensorShape([3, 4])

In [76]:
mask = create_look_ahead_mask(prefix.shape[0])
mask

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 1., 1.],
       [0., 0., 1.],
       [0., 0., 0.]], dtype=float32)>

In [77]:
causal_attn_output, attn_weights = scaled_dot_product_attention(prefix, prefix, prefix, mask)

scaled_attention_logits: 
 [[ 1.5000000e+01 -9.9999994e+08 -1.0000000e+09]
 [ 3.5000000e+01  8.7000000e+01 -9.9999994e+08]
 [ 2.0000000e+01  4.8000000e+01  2.7000000e+01]]


In [78]:
causal_attn_output

<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[1., 2., 3., 4.],
       [5., 6., 7., 8.],
       [5., 6., 7., 8.]], dtype=float32)>

In [79]:
attn_weights

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [2.6102792e-23, 1.0000000e+00, 0.0000000e+00],
       [6.9143996e-13, 1.0000000e+00, 7.5825607e-10]], dtype=float32)>

Each row in the above matrix corresponds to one token and shows how much attention that token pays to itself and to the tokens before it.
Each column corresponds to the token being attended to, so the value in a column is the weight applied to that token of the sequence.
Column 1 -> token 1, column 2 -> token 2, column 3 -> token 3.
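
To make the last row concrete, the softmax can be reproduced by hand from the scaled logits printed earlier. This is a plain-Python sketch; the `softmax` helper is written here only for illustration and is not the TensorFlow implementation.

```python
import math

def softmax(logits):
    # Subtract the maximum first for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Scaled logits of the 3rd token, copied from the printed output above.
row3_logits = [20.0, 48.0, 27.0]
row3_weights = softmax(row3_logits)
print(row3_weights)
```

Because the logit for token 2 exceeds the other two by 21 or more, the softmax is almost one-hot, which is why the third output row is numerically equal to the second token's values `[5, 6, 7, 8]`.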

In [80]:
tf.matmul(attn_weights, prefix)

<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[1., 2., 3., 4.],
       [5., 6., 7., 8.],
       [5., 6., 7., 8.]], dtype=float32)>

Notice that the 1st and 2nd tokens attend only to themselves and to the tokens before them, so they lose the information carried by later tokens, since no attention is paid to those.
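
To see what the mask removes, the same weights can be computed with and without it. The sketch below re-implements the steps of `scaled_dot_product_attention` in a small helper so that it runs on its own; `attention_weights` is a hypothetical name, not part of the notebook.

```python
import tensorflow as tf

prefix = tf.constant([[1.0, 2.0, 3.0, 4.0],
                      [5.0, 6.0, 7.0, 8.0],
                      [2.0, 3.0, 4.0, 5.0]])

def attention_weights(x, mask=None):
    """Self-attention weights for x (q = k = x), following the steps used above."""
    dk = tf.cast(tf.shape(x)[-1], tf.float32)
    logits = tf.matmul(x, x, transpose_b=True) / tf.math.sqrt(dk)
    if mask is not None:
        # Masked positions receive a very large negative logit.
        logits += mask * -1e9
    return tf.nn.softmax(logits, axis=-1)

look_ahead_mask = 1 - tf.linalg.band_part(tf.ones((3, 3)), -1, 0)

masked = attention_weights(prefix, look_ahead_mask)
unmasked = attention_weights(prefix)

print("masked, token 1:  ", masked[0].numpy())
print("unmasked, token 1:", unmasked[0].numpy())
```

Without the mask, token 1's largest scaled logit is for token 2 (35, versus 15 for itself), so nearly all of its attention shifts to token 2. With the mask, it can attend only to itself.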

This technique works well during training, where it prevents the model from attending to future tokens. It is a limitation at inference time, however: when a sequence is passed to the Decoder, the representation of each token is built only from that token and the tokens preceding it.
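
The papers listed in the references discuss mask patterns that relax this restriction for the input part of a sequence. The following is a minimal sketch of one such pattern, in which the first `prefix_len` tokens are visible to every token and the rest stay causal. `create_prefix_mask` and `prefix_len` are hypothetical names; this is an illustration under that assumption, not code from this notebook or from either paper.

```python
import tensorflow as tf

def create_prefix_mask(size, prefix_len):
    """Hypothetical helper: causal mask with a fully visible prefix.

    Returns a (size, size) float tensor with 1.0 at hidden positions,
    the same convention as create_look_ahead_mask.
    """
    causal = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    # 1.0 for columns that belong to the prefix, 0.0 otherwise.
    visible_prefix = tf.cast(tf.range(size) < prefix_len, tf.float32)  # (size,)
    # Unhide every prefix column; non-prefix columns keep the causal pattern.
    return causal * (1 - visible_prefix[tf.newaxis, :])

prefix_mask = create_prefix_mask(4, prefix_len=2)
print(prefix_mask)
```

With `prefix_len=2`, tokens 1 and 2 can attend to each other in both directions, while tokens 3 and 4 still see only themselves and earlier tokens.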

#### References 
* [Unified Language Model Pre-training for Natural Language Understanding and Generation, Dong et al., 2019](https://arxiv.org/pdf/1905.03197.pdf)
* [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer, Raffel et al., 2020](https://arxiv.org/pdf/1910.10683.pdf)