In [1]:
import collections
import logging
import os
import pathlib
import re
import string
import sys
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

In [2]:
def create_padding_mask(seq):
  """Build a float32 mask that is 1.0 wherever `seq` equals 0, else 0.0.

  Two singleton axes are inserted after the first axis so that the mask
  can be broadcast against attention logits.
  """
  is_padding = tf.math.equal(seq, 0)
  padding_mask = tf.cast(is_padding, tf.float32)

  # (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
  return padding_mask[:, tf.newaxis, tf.newaxis, :]


In [3]:
def create_look_ahead_mask(size):
  """Return a (size, size) mask with 1.0 strictly above the main diagonal.

  Entries on and below the diagonal are 0.0, so position i is only blocked
  from positions j > i.
  """
  # band_part(..., -1, 0) keeps the whole lower triangle, diagonal included.
  lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return 1 - lower_triangle  # (seq_len, seq_len)


In [4]:
def scaled_dot_product_attention(q, k, v, mask=None):
  """Calculate the attention weights and the attended values.

  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type (padding or look ahead)
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable
          to (..., seq_len_q, seq_len_k). Entries equal to 1 mark positions
          to suppress. Defaults to None (no masking).

  Returns:
    output: shape == (..., seq_len_q, depth_v)
    attention_weights: shape == (..., seq_len_q, seq_len_k)
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # Scale by sqrt(depth). The scale is cast to the dtype of the logits rather
  # than a fixed float32, so the division does not raise a dtype mismatch
  # when q/k are float64.
  dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # Add a large negative number where mask == 1 so that those positions get
  # a weight of (almost) zero after the softmax.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights


In [5]:
def print_out(q, k, v):
  """Run unmasked scaled dot-product attention and print both results.

  The attention weights are printed first, followed by the attention output.
  """
  attn_output, attn_weights = scaled_dot_product_attention(q, k, v, None)

  report = (
      ('Attention weights are:', attn_weights),
      ('Output is:', attn_output),
  )
  for heading, tensor in report:
    print(heading)
    print(tensor)


In [6]:
# Print floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

# Four keys of depth 3; the last two rows are identical.
temp_k = tf.constant(
    [[10, 0, 0],
     [0, 10, 0],
     [0, 0, 10],
     [0, 0, 10]],
    dtype=tf.float32,
)  # (4, 3)

# One value row of depth 2 per key.
temp_v = tf.constant(
    [[1, 0],
     [10, 0],
     [100, 5],
     [1000, 6]],
    dtype=tf.float32,
)  # (4, 2)

# The query lines up with the second key only,
# so the second value row is what comes back.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)


Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10.  0.]], shape=(1, 2), dtype=float32)


In [70]:
class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head attention layer that prints tensor shapes at every stage.

  The inputs are projected with three Dense layers, split into `num_heads`
  heads of size `d_model // num_heads`, passed through
  `scaled_dot_product_attention`, merged back to `d_model` features and
  projected once more.

  NOTE: the shape report uses plain Python `print`, so when the layer is
  called inside a `tf.function` it is expected to appear only while the
  function is being traced, not on every execution.

  Args:
    d_model: Size of the projected feature axis; must be divisible by
      `num_heads`.
    num_heads: Number of attention heads.
    **kwargs: Forwarded unchanged to `tf.keras.layers.Layer`
      (for example `name` or `dtype`).
  """

  def __init__(self, d_model, num_heads, **kwargs):
    super(MultiHeadAttention, self).__init__(**kwargs)
    self.num_heads = num_heads
    self.d_model = d_model

    assert d_model % self.num_heads == 0, (
        "d_model ({}) must be divisible by num_heads ({})".format(
            d_model, num_heads))

    self.depth = d_model // self.num_heads

    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    self.dense = tf.keras.layers.Dense(d_model)

  @staticmethod
  def _print_shapes(named_tensors, heading=None):
    """Print a blank line, an optional heading, then one shape line per pair."""
    print()
    if heading is not None:
      print(heading)
    for name, tensor in named_tensors:
      print("The shape of '" + name + "' is " + str(tensor.shape))

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask=None):
    """Apply multi-head attention and print the shape of each intermediate.

    Args:
      v: Values, shape (batch_size, seq_len_v, features).
      k: Keys, shape (batch_size, seq_len_k, features).
      q: Queries, shape (batch_size, seq_len_q, features).
      mask: Optional float tensor broadcastable to
        (batch_size, num_heads, seq_len_q, seq_len_k); passed straight to
        `scaled_dot_product_attention`. Defaults to None (no masking).

    Returns:
      output: shape (batch_size, seq_len_q, d_model)
      attention_weights: shape (batch_size, num_heads, seq_len_q, seq_len_k)
    """
    print("Inside 'MultiHeadAttention' class...")
    batch_size = tf.shape(q)[0]

    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)

    self._print_shapes([('q', q), ('k', k), ('v', v)])

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

    self._print_shapes([('q', q), ('k', k), ('v', v)],
                       heading="After splitting the heads....")

    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)

    self._print_shapes([('attention_weights', attention_weights),
                        ('scaled_attention', scaled_attention)])

    # Move the head axis next to depth so the two can be merged by a reshape.
    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    self._print_shapes([('scaled_attention', scaled_attention)],
                       heading="After transposing....")

    concat_attention = tf.reshape(scaled_attention,
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    self._print_shapes([('concat_attention', concat_attention)])

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

    self._print_shapes([('output', output)])

    return output, attention_weights


In [71]:
# Self-attention shape check: the same random tensor is fed in as value,
# key and query, with no mask.
temp_mha = MultiHeadAttention(num_heads=8, d_model=512)
y = tf.random.uniform(shape=(1, 9, 512))  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, k=y, q=y, mask=None)


Inside 'MultiHeadAttention' class...

The shape of 'q' is (1, 9, 512)
The shape of 'k' is (1, 9, 512)
The shape of 'v' is (1, 9, 512)

After splitting the heads....
The shape of 'q' is (1, 8, 9, 64)
The shape of 'k' is (1, 8, 9, 64)
The shape of 'v' is (1, 8, 9, 64)

The shape of 'attention_weights' is (1, 8, 9, 9)
The shape of 'scaled_attention' is (1, 8, 9, 64)

After transposing....
The shape of 'scaled_attention' is (1, 9, 8, 64)

The shape of 'concat_attention' is (1, 9, 512)

The shape of 'output' is (1, 9, 512)


In [54]:
# Display the shapes of the layer output and of the attention weights it returned.
out.shape, attn.shape

(TensorShape([1, 9, 512]), TensorShape([1, 8, 9, 9]))

In [55]:
# Consecutive integers 1..4608 laid out as (1, 9, 512), so each element's
# original position can be read off its value.
sample_query = np.arange(1, 1 * 9 * 512 + 1).reshape(1, 9, 512)

In [56]:
# Display the shape of the sample array.
sample_query.shape

(1, 9, 512)

In [62]:
# NOTE(review): the output saved under this cell is a tf.Tensor of shape
# (1, 8, 9, 64), i.e. the result of the convert/reshape cells further down,
# so this cell was last run after them (out of order). Re-run the notebook
# top to bottom to get an output that matches the cells above.
sample_query

<tf.Tensor: shape=(1, 8, 9, 64), dtype=int64, numpy=
array([[[[   1,    2,    3, ...,   62,   63,   64],
         [  65,   66,   67, ...,  126,  127,  128],
         [ 129,  130,  131, ...,  190,  191,  192],
         ...,
         [ 385,  386,  387, ...,  446,  447,  448],
         [ 449,  450,  451, ...,  510,  511,  512],
         [ 513,  514,  515, ...,  574,  575,  576]],

        [[ 577,  578,  579, ...,  638,  639,  640],
         [ 641,  642,  643, ...,  702,  703,  704],
         [ 705,  706,  707, ...,  766,  767,  768],
         ...,
         [ 961,  962,  963, ..., 1022, 1023, 1024],
         [1025, 1026, 1027, ..., 1086, 1087, 1088],
         [1089, 1090, 1091, ..., 1150, 1151, 1152]],

        [[1153, 1154, 1155, ..., 1214, 1215, 1216],
         [1217, 1218, 1219, ..., 1278, 1279, 1280],
         [1281, 1282, 1283, ..., 1342, 1343, 1344],
         ...,
         [1537, 1538, 1539, ..., 1598, 1599, 1600],
         [1601, 1602, 1603, ..., 1662, 1663, 1664],
         [1665, 1

In [63]:
# Turn the sample array into a TensorFlow tensor. The name is rebound, so
# after this cell `sample_query` no longer refers to the NumPy array.
sample_query = tf.convert_to_tensor(sample_query)

In [64]:
# Plain reshape to (1, 8, 9, 64) with no transpose.
# NOTE(review): this is not the layout MultiHeadAttention.split_heads
# produces (that method reshapes to (batch, -1, num_heads, depth) and then
# transposes). Assuming the tensor still holds the (1, 9, 512) values,
# sample_query[0][i] is just 9 consecutive 64-value chunks, not head i.
sample_query = tf.reshape(sample_query, (1, 8, 9, 64))

In [65]:
# Confirm the shape after the reshape.
sample_query.shape

TensorShape([1, 8, 9, 64])

In [66]:
# Inspect the first (9, 64) slice along axis 1 to see which values landed in it.
sample_query[0][0]

<tf.Tensor: shape=(9, 64), dtype=int64, numpy=
array([[  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
         27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
         40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
         53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64],
       [ 65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
         91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
        104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
        117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128],
       [129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154,
        155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168,

In [67]:
# Inspect the second (9, 64) slice along axis 1 for comparison with the first.
sample_query[0][1]

<tf.Tensor: shape=(9, 64), dtype=int64, numpy=
array([[ 577,  578,  579,  580,  581,  582,  583,  584,  585,  586,  587,
         588,  589,  590,  591,  592,  593,  594,  595,  596,  597,  598,
         599,  600,  601,  602,  603,  604,  605,  606,  607,  608,  609,
         610,  611,  612,  613,  614,  615,  616,  617,  618,  619,  620,
         621,  622,  623,  624,  625,  626,  627,  628,  629,  630,  631,
         632,  633,  634,  635,  636,  637,  638,  639,  640],
       [ 641,  642,  643,  644,  645,  646,  647,  648,  649,  650,  651,
         652,  653,  654,  655,  656,  657,  658,  659,  660,  661,  662,
         663,  664,  665,  666,  667,  668,  669,  670,  671,  672,  673,
         674,  675,  676,  677,  678,  679,  680,  681,  682,  683,  684,
         685,  686,  687,  688,  689,  690,  691,  692,  693,  694,  695,
         696,  697,  698,  699,  700,  701,  702,  703,  704],
       [ 705,  706,  707,  708,  709,  710,  711,  712,  713,  714,  715,
         716,