In [2]:
import collections
import logging
import os
import pathlib
import re
import string
import sys
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

In [3]:
def create_padding_mask(seq):
  """Build a float mask that flags the padding entries of `seq`.

  Entries of `seq` equal to 0 become 1.0 in the mask; every other entry
  becomes 0.0. Two singleton axes are inserted after the first axis so the
  mask can be broadcast against attention logits.

  Args:
    seq: tensor indexed as [batch, position]; 0 is treated as padding.

  Returns:
    tf.float32 tensor of shape (batch_size, 1, 1, seq_len).
  """
  is_padding = tf.math.equal(seq, 0)
  padding_mask = tf.cast(is_padding, tf.float32)

  # Insert the two broadcast axes: (batch_size, 1, 1, seq_len).
  return padding_mask[:, tf.newaxis, tf.newaxis, :]


In [4]:
def create_look_ahead_mask(size):
  """Return a (size, size) mask with 1s strictly above the main diagonal.

  Row i has 0 in columns 0..i and 1 in the columns after i, so a 1 marks a
  later position than the row's own.

  Args:
    size: side length of the square mask (the sequence length).

  Returns:
    Tensor of shape (seq_len, seq_len) holding 0s and 1s.
  """
  all_ones = tf.ones((size, size))
  # band_part(..., -1, 0) keeps the whole lower triangle, diagonal included.
  lower_triangle = tf.linalg.band_part(all_ones, -1, 0)
  return 1 - lower_triangle  # (seq_len, seq_len)


In [5]:
def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute scaled dot-product attention.

  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type (padding or look ahead)
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Optional tensor with shape broadcastable to
          (..., seq_len_q, seq_len_k). Entries equal to 1 mark positions to
          suppress; entries equal to 0 leave the logits untouched.
          Defaults to None (no masking).

  Returns:
    output: shape == (..., seq_len_q, depth_v)
    attention_weights: shape == (..., seq_len_q, seq_len_k)
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # Scale by sqrt(depth). dk is cast to the dtype of the logits (rather than a
  # fixed float32) so the division does not hit a dtype mismatch when q/k are
  # float16 or float64.
  dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # Push masked positions towards -infinity so softmax gives them ~0 weight.
  if mask is not None:
    # Form the penalty in float32, where -1e9 is exactly representable, then
    # cast to the logits dtype so masks of a different dtype still add.
    penalty = tf.cast(mask, tf.float32) * -1e9
    scaled_attention_logits += tf.cast(penalty, scaled_attention_logits.dtype)

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights


In [6]:
class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head attention layer that prints tensor shapes at every step.

  The queries, keys and values are each passed through a Dense(d_model)
  layer, split into `num_heads` heads of size `d_model // num_heads`, run
  through `scaled_dot_product_attention`, merged back to `d_model` and passed
  through a final Dense(d_model) layer. The shape of each intermediate tensor
  is printed along the way.
  """

  def __init__(self, d_model, num_heads):
    """Create the projection layers.

    Args:
      d_model: width of the projections; must be divisible by `num_heads`.
      num_heads: number of attention heads.
    """
    super().__init__()
    assert d_model % num_heads == 0

    self.num_heads = num_heads
    self.d_model = d_model
    self.depth = d_model // num_heads  # width of a single head

    # Input projections for queries, keys and values.
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    # Output projection applied to the concatenated heads.
    self.dense = tf.keras.layers.Dense(d_model)

  @staticmethod
  def _print_qkv_shapes(q, k, v):
    """Print the static shapes of q, k and v, one per line."""
    print("The shape of 'q' is " + str(q.shape))
    print("The shape of 'k' is " + str(k.shape))
    print("The shape of 'v' is " + str(v.shape))

  def split_heads(self, x, batch_size):
    """Reshape the last axis into (num_heads, depth) and move heads forward.

    The returned tensor has shape (batch_size, num_heads, seq_len, depth).
    """
    per_head_shape = (batch_size, -1, self.num_heads, self.depth)
    return tf.transpose(tf.reshape(x, per_head_shape), perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    """Run multi-head attention, printing shapes after each step.

    Args:
      v: values, (batch_size, seq_len_v, d_model).
      k: keys, (batch_size, seq_len_k, d_model).
      q: queries, (batch_size, seq_len_q, d_model).
      mask: forwarded unchanged to `scaled_dot_product_attention`; may be None.

    Returns:
      output: (batch_size, seq_len_q, d_model)
      attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
    """
    print("Inside 'MultiHeadAttention' class...")
    n_batch = tf.shape(q)[0]

    print()
    self._print_qkv_shapes(q, k, v)

    # Linear projections, each (batch_size, seq_len, d_model).
    q, k, v = self.wq(q), self.wk(k), self.wv(v)

    print()
    print("After passing 'q', 'k', 'v' through densely connected layers....")
    self._print_qkv_shapes(q, k, v)

    # Each becomes (batch_size, num_heads, seq_len, depth).
    q, k, v = (self.split_heads(t, n_batch) for t in (q, k, v))

    print()
    print("After splitting the heads....")
    self._print_qkv_shapes(q, k, v)

    # per_head_output.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    per_head_output, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)

    print()
    print("The shape of 'attention_weights' is " + str(attention_weights.shape))
    print("The shape of 'scaled_attention' is " + str(per_head_output.shape))

    # Bring the head axis next to depth: (batch_size, seq_len_q, num_heads, depth).
    heads_last = tf.transpose(per_head_output, perm=[0, 2, 1, 3])

    print()
    print("After transposing....")
    print("The shape of 'scaled_attention' is " + str(heads_last.shape))

    # Merge the heads back together: (batch_size, seq_len_q, d_model).
    concat_attention = tf.reshape(heads_last, (n_batch, -1, self.d_model))

    print()
    print("The shape of 'concat_attention' is " + str(concat_attention.shape))

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
    print()
    print("The shape of 'output' is " + str(output.shape))

    return output, attention_weights


In [7]:
# Self-attention walkthrough: the same random tensor is passed as values, keys
# and queries, so every sequence length printed by the layer is 9.
temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
sample_sentence = tf.random.uniform((1, 9, 512))  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(v=sample_sentence, k=sample_sentence, q=sample_sentence, mask=None)


Inside 'MultiHeadAttention' class...

The shape of 'q' is (1, 9, 512)
The shape of 'k' is (1, 9, 512)
The shape of 'v' is (1, 9, 512)

After passing 'q', 'k', 'v' through densely connected layers....
The shape of 'q' is (1, 9, 512)
The shape of 'k' is (1, 9, 512)
The shape of 'v' is (1, 9, 512)

After splitting the heads....
The shape of 'q' is (1, 8, 9, 64)
The shape of 'k' is (1, 8, 9, 64)
The shape of 'v' is (1, 8, 9, 64)

The shape of 'attention_weights' is (1, 8, 9, 9)
The shape of 'scaled_attention' is (1, 8, 9, 64)

After transposing....
The shape of 'scaled_attention' is (1, 9, 8, 64)

The shape of 'concat_attention' is (1, 9, 512)

The shape of 'output' is (1, 9, 512)


In [10]:
# NOTE(review): the recorded output below has seq_len_q == 12, which matches the
# cross-attention call (9 keys, 12 queries) in cell In [8] rather than the
# self-attention call in cell In [7]; this cell was evidently run out of order.
out.shape, attn.shape

(TensorShape([1, 12, 512]), TensorShape([1, 8, 12, 9]))

In [18]:
# Fill a (1, 9, 512) tensor with the consecutive integers 1..4608 so that each
# element can be followed by eye through the reshape/transpose in the next cells.
sample_query = np.arange(1*9*512).reshape((1, 9, 512)) + 1
sample_query = tf.convert_to_tensor(sample_query)
print(sample_query)

tf.Tensor(
[[[   1    2    3 ...  510  511  512]
  [ 513  514  515 ... 1022 1023 1024]
  [1025 1026 1027 ... 1534 1535 1536]
  ...
  [3073 3074 3075 ... 3582 3583 3584]
  [3585 3586 3587 ... 4094 4095 4096]
  [4097 4098 4099 ... 4606 4607 4608]]], shape=(1, 9, 512), dtype=int64)


In [19]:
# The same two steps as MultiHeadAttention.split_heads with num_heads=8, depth=64:
# (1, 9, 512) -> (1, 9, 8, 64) -> (1, 8, 9, 64).
# NOTE(review): this overwrites `sample_query`, so re-running the cell starts
# from the already-split tensor instead of the original (1, 9, 512) one.
sample_query = tf.reshape(sample_query, (1, 9, 8, 64))
sample_query = tf.transpose(sample_query, perm=[0, 2, 1, 3])

In [20]:
# After the split, head h holds elements 64*h .. 64*h + 63 of each of the 9 rows.
print(sample_query)

tf.Tensor(
[[[[   1    2    3 ...   62   63   64]
   [ 513  514  515 ...  574  575  576]
   [1025 1026 1027 ... 1086 1087 1088]
   ...
   [3073 3074 3075 ... 3134 3135 3136]
   [3585 3586 3587 ... 3646 3647 3648]
   [4097 4098 4099 ... 4158 4159 4160]]

  [[  65   66   67 ...  126  127  128]
   [ 577  578  579 ...  638  639  640]
   [1089 1090 1091 ... 1150 1151 1152]
   ...
   [3137 3138 3139 ... 3198 3199 3200]
   [3649 3650 3651 ... 3710 3711 3712]
   [4161 4162 4163 ... 4222 4223 4224]]

  [[ 129  130  131 ...  190  191  192]
   [ 641  642  643 ...  702  703  704]
   [1153 1154 1155 ... 1214 1215 1216]
   ...
   [3201 3202 3203 ... 3262 3263 3264]
   [3713 3714 3715 ... 3774 3775 3776]
   [4225 4226 4227 ... 4286 4287 4288]]

  ...

  [[ 321  322  323 ...  382  383  384]
   [ 833  834  835 ...  894  895  896]
   [1345 1346 1347 ... 1406 1407 1408]
   ...
   [3393 3394 3395 ... 3454 3455 3456]
   [3905 3906 3907 ... 3966 3967 3968]
   [4417 4418 4419 ... 4478 4479 4480]]

  [[ 385  

In [21]:
# First head of the first batch element: a (9, 64) slice.
print(sample_query[0][0])

tf.Tensor(
[[   1    2    3    4    5    6    7    8    9   10   11   12   13   14
    15   16   17   18   19   20   21   22   23   24   25   26   27   28
    29   30   31   32   33   34   35   36   37   38   39   40   41   42
    43   44   45   46   47   48   49   50   51   52   53   54   55   56
    57   58   59   60   61   62   63   64]
 [ 513  514  515  516  517  518  519  520  521  522  523  524  525  526
   527  528  529  530  531  532  533  534  535  536  537  538  539  540
   541  542  543  544  545  546  547  548  549  550  551  552  553  554
   555  556  557  558  559  560  561  562  563  564  565  566  567  568
   569  570  571  572  573  574  575  576]
 [1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038
  1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052
  1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066
  1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080
  1081 1082 1083 1084 1085 1086 1087 10

In [15]:
# On 4-D inputs tf.matmul multiplies the last two axes for every
# (batch, head) pair, giving one (9, 9) matrix per head.
tf.matmul(sample_query, sample_query, transpose_b=True)

<tf.Tensor: shape=(1, 8, 9, 9), dtype=int64, numpy=
array([[[[     89440,    1154400,    2219360,    3284320,    4349280,
             5414240,    6479200,    7544160,    8609120],
         [   1154400,   18996576,   36838752,   54680928,   72523104,
            90365280,  108207456,  126049632,  143891808],
         [   2219360,   36838752,   71458144,  106077536,  140696928,
           175316320,  209935712,  244555104,  279174496],
         [   3284320,   54680928,  106077536,  157474144,  208870752,
           260267360,  311663968,  363060576,  414457184],
         [   4349280,   72523104,  140696928,  208870752,  277044576,
           345218400,  413392224,  481566048,  549739872],
         [   5414240,   90365280,  175316320,  260267360,  345218400,
           430169440,  515120480,  600071520,  685022560],
         [   6479200,  108207456,  209935712,  311663968,  413392224,
           515120480,  616848736,  718576992,  820305248],
         [   7544160,  126049632,  244555104,

In [17]:
# The same product computed for the first head only; the (9, 9) result matches
# the [0][0] slice of the batched matmul output.
tf.matmul(sample_query[0][0], tf.transpose(sample_query[0][0]))

<tf.Tensor: shape=(9, 9), dtype=int64, numpy=
array([[     89440,    1154400,    2219360,    3284320,    4349280,
           5414240,    6479200,    7544160,    8609120],
       [   1154400,   18996576,   36838752,   54680928,   72523104,
          90365280,  108207456,  126049632,  143891808],
       [   2219360,   36838752,   71458144,  106077536,  140696928,
         175316320,  209935712,  244555104,  279174496],
       [   3284320,   54680928,  106077536,  157474144,  208870752,
         260267360,  311663968,  363060576,  414457184],
       [   4349280,   72523104,  140696928,  208870752,  277044576,
         345218400,  413392224,  481566048,  549739872],
       [   5414240,   90365280,  175316320,  260267360,  345218400,
         430169440,  515120480,  600071520,  685022560],
       [   6479200,  108207456,  209935712,  311663968,  413392224,
         515120480,  616848736,  718576992,  820305248],
       [   7544160,  126049632,  244555104,  363060576,  481566048,
         60

In [8]:
# As mentioned, the "queries" can be in a different language from the "keys" and "values."
# * They are supposed to be different in translation tasks.

# In this case you compare the "queries" in the target language with the "keys" in the original language,
# and after that you reweight the "values" in the original language.

# Usually, the number of "queries" is different from that of the "keys" or "values," because
# translated sentences usually have a different number of tokens.

# Let's see an example where the input sentence has 9 tokens and the translated sentence has 12.
temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
sample_sentence_source_lang = tf.random.uniform((1, 9, 512))  # (batch_size, encoder_sequence, d_model)
sample_sentence_target_lang = tf.random.uniform((1, 12, 512))  # (batch_size, target_sequence, d_model)
out, attn = temp_mha(v=sample_sentence_source_lang, k=sample_sentence_source_lang, q=sample_sentence_target_lang, mask=None)

# In the results below, you can see that you reweight the "values" of the original sentence with a (12, 9) sized matrix
# in each head, and the size of the resulting 'scaled_attention' is (12, 64) in each head.


Inside 'MultiHeadAttention' class...

The shape of 'q' is (1, 12, 512)
The shape of 'k' is (1, 9, 512)
The shape of 'v' is (1, 9, 512)

After passing 'q', 'k', 'v' through densely connected layers....
The shape of 'q' is (1, 12, 512)
The shape of 'k' is (1, 9, 512)
The shape of 'v' is (1, 9, 512)

After splitting the heads....
The shape of 'q' is (1, 8, 12, 64)
The shape of 'k' is (1, 8, 9, 64)
The shape of 'v' is (1, 8, 9, 64)

The shape of 'attention_weights' is (1, 8, 12, 9)
The shape of 'scaled_attention' is (1, 8, 12, 64)

After transposing....
The shape of 'scaled_attention' is (1, 12, 8, 64)

The shape of 'concat_attention' is (1, 12, 512)

The shape of 'output' is (1, 12, 512)
