<a href="https://colab.research.google.com/github/JSJeong-me/AI-Innovation-2024/blob/main/Transformer/5-1-Multi-Head-Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
# Implementing the Scaled-Dot Product Attention
from tensorflow import math, matmul, reshape, shape, transpose, cast, float32
from tensorflow.keras.layers import Dense, Layer
from tensorflow.keras.activations import softmax # Import softmax from tf.keras.activations

# Implementing the Scaled-Dot Product Attention
class DotProductAttention(Layer):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k) + mask_bias) V.

    The layer has no trainable weights; all dimensions come from the inputs.
    """

    def __init__(self, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        """Compute attention of ``queries`` over ``keys``/``values``.

        Args:
            queries: Query tensor; its last axis is contracted with the keys'.
            keys: Key tensor, transposed on its last two axes for the product.
            values: Value tensor weighted by the attention weights.
            d_k: Scaling dimension; the raw scores are divided by sqrt(d_k).
            mask: Optional tensor broadcastable to the score shape. Wherever
                it is 1 (or True), -1e9 is added to the score, which drives
                that position's softmax weight to (almost) zero.

        Returns:
            ``matmul(softmax(scores), values)``.
        """
        # Score the queries against the transposed keys, then scale.
        scores = matmul(queries, keys, transpose_b=True) / math.sqrt(cast(d_k, float32))

        if mask is not None:
            # Cast so boolean / integer masks work too (``-1e9 * bool_tensor``
            # raises); float masks produce exactly the same values as before.
            scores += -1e9 * cast(mask, scores.dtype)

        # Softmax over the last (key) axis turns scores into attention weights.
        weights = softmax(scores)

        # Weighted sum of the value vectors.
        return matmul(weights, values)

# Multi-head attention reuses the DotProductAttention layer for every head.

# Implementing the Multi-Head Attention
class MultiHeadAttention(Layer):
    """Multi-head attention layer built on ``DotProductAttention``.

    Queries, keys and values each pass through one Dense projection (to
    ``d_k``, ``d_k`` and ``d_v`` units respectively). Each projection is split
    evenly across ``h`` heads and attended per head. The heads are then
    concatenated back together and projected to ``d_model`` units by ``W_o``.

    Args:
        h: Number of attention heads. The projected sizes ``d_k`` and ``d_v``
            are split across the heads, so both must be divisible by ``h``.
        d_k: Total projected size of the queries and keys (all heads).
        d_v: Total projected size of the values (all heads).
        d_model: Number of units in the final output projection.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, h, d_k, d_v, d_model, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.attention = DotProductAttention()  # Scaled dot-product attention
        self.heads = h  # Number of attention heads to use
        self.d_k = d_k  # Projected size of queries and keys, summed over heads
        self.d_v = d_v  # Projected size of values, summed over heads
        self.d_model = d_model  # Size of the layer's output
        self.W_q = Dense(d_k)  # Learned projection matrix for the queries
        self.W_k = Dense(d_k)  # Learned projection matrix for the keys
        self.W_v = Dense(d_v)  # Learned projection matrix for the values
        self.W_o = Dense(d_model)  # Learned projection matrix for the multi-head output

    def reshape_tensor(self, x, heads, flag):
        """Split the last axis into heads, or merge the heads back together.

        Args:
            x: Tensor to rearrange.
            heads: Number of heads to split into (only used when ``flag`` is true).
            flag: If true, turn ``(batch, seq, features)`` into
                ``(batch, heads, seq, features // heads)``. If false, turn the
                per-head attention output ``(batch, heads, seq, d_v // heads)``
                back into ``(batch, seq, d_v)``.

        Returns:
            The rearranged tensor.
        """
        if flag:
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], heads, -1))
            return transpose(x, perm=(0, 2, 1, 3))

        # The merged tensor is the concatenation of the *value* heads, so its
        # last axis is d_v (the previous d_k broke whenever d_k != d_v). A
        # static int rather than -1 keeps that dimension known, so W_o can be
        # built even when shapes are symbolic.
        x = transpose(x, perm=(0, 2, 1, 3))
        return reshape(x, shape=(shape(x)[0], shape(x)[1], self.d_v))

    def call(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        Args:
            queries: Query inputs, projected by ``W_q``.
            keys: Key inputs, projected by ``W_k``.
            values: Value inputs, projected by ``W_v``.
            mask: Optional mask, forwarded unchanged to ``DotProductAttention``.

        Returns:
            ``W_o`` applied to the concatenated heads; the last axis has
            ``d_model`` units.
        """
        # Project, then split into heads so all heads are computed in parallel.
        # Each result has shape (batch_size, heads, seq_length, -1).
        q_reshaped = self.reshape_tensor(self.W_q(queries), self.heads, True)
        k_reshaped = self.reshape_tensor(self.W_k(keys), self.heads, True)
        v_reshaped = self.reshape_tensor(self.W_v(values), self.heads, True)

        # NOTE(review): scores are scaled by sqrt(self.d_k), the total projected
        # key size, although each head only sees d_k // heads features. Kept
        # as-is to preserve existing outputs; confirm this is intended.
        o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, mask=mask, d_k=self.d_k)
        # Shape: (batch_size, heads, query_seq_length, d_v // heads)

        # Concatenate the heads back: (batch_size, query_seq_length, d_v)
        output = self.reshape_tensor(o_reshaped, self.heads, False)

        # Final linear projection: (batch_size, query_seq_length, d_model)
        return self.W_o(output)

In [7]:
from numpy import random

# Hyperparameters for a quick smoke test of the layer's output shape.
input_seq_length = 5  # Maximum length of the input sequence
h = 8  # Number of self-attention heads
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
d_model = 512  # Dimensionality of the model sub-layers' outputs
batch_size = 64  # Batch size from the training process

# Uniform [0, 1) dummy inputs, drawn in the order queries, keys, values.
query_key_shape = (batch_size, input_seq_length, d_k)
value_shape = (batch_size, input_seq_length, d_v)
queries = random.random(query_key_shape)
keys = random.random(query_key_shape)
values = random.random(value_shape)

# Build the layer and print its output for the dummy batch.
multihead_attention = MultiHeadAttention(h, d_k, d_v, d_model)
attention_output = multihead_attention(queries, keys, values)
print(attention_output)

tf.Tensor(
[[[-6.31234720e-02  1.18290134e-01  1.13987789e-01 ... -6.44244924e-02
   -1.37954816e-01 -2.35938966e-01]
  [-6.16301782e-02  1.18591234e-01  1.15305856e-01 ... -6.20990470e-02
   -1.39237195e-01 -2.35218495e-01]
  [-6.20680116e-02  1.18751526e-01  1.13772482e-01 ... -6.22441210e-02
   -1.39342666e-01 -2.35152274e-01]
  [-6.18725643e-02  1.17973648e-01  1.13956407e-01 ... -6.29143566e-02
   -1.38948768e-01 -2.34578684e-01]
  [-6.08854219e-02  1.17912218e-01  1.13320045e-01 ... -5.95226660e-02
   -1.40971050e-01 -2.34650612e-01]]

 [[-1.97930187e-02  1.00369148e-01 -7.84399062e-02 ... -2.75027975e-02
   -7.85642043e-02 -2.54923165e-01]
  [-1.90120991e-02  9.56457704e-02 -7.70042837e-02 ... -2.85565779e-02
   -7.76824057e-02 -2.51348257e-01]
  [-1.88432019e-02  9.84462798e-02 -8.09082910e-02 ... -3.05076148e-02
   -7.95387030e-02 -2.55302340e-01]
  [-1.81312840e-02  9.61948708e-02 -7.79727623e-02 ... -2.90170442e-02
   -7.90269375e-02 -2.54960388e-01]
  [-1.70641411e-02  1.00