**Q:** I want to learn the Transformer architecture in deep learning from scratch, but I find the mathematical parts of the original paper hard to follow. Could you give me some advice?

**A (by ChatGPT):** Understanding the Transformer architecture from scratch can be challenging, especially the mathematical details in the original paper ("Attention Is All You Need"). Here are a few suggestions to help you grasp the concepts and get past the difficult parts:

1. Start with the basics: Make sure you are comfortable with the prerequisites: feedforward neural networks and backpropagation, sequence modeling with RNNs, the attention mechanism used in earlier encoder-decoder translation models, and the linear algebra the paper relies on (matrix multiplication and the softmax function). A solid grasp of these makes the Transformer much easier to follow.

2. Break it down: Rather than trying to absorb the entire architecture at once, split it into components and understand each one individually. Begin with the self-attention mechanism, since it is the core of the Transformer. Study how it scores each query against every key, turns those scores into attention weights with a softmax, takes a weighted sum of the values, and runs several of these attention heads in parallel.

3. Use visualizations: Diagrams of the architecture and its components build intuition for complex concepts. Look for illustrations that show how information flows through the model and how the different parts interact.

4. Explore online resources: Seek out alternative resources beyond the original paper. Numerous blog posts, tutorials, and videos provide explanations of the Transformer architecture in a more accessible and intuitive manner. Some popular resources include the "Illustrated Transformer" blog post by Jay Alammar and various YouTube tutorials on the subject.

5. Study code implementations: Reading and analyzing code implementations of the Transformer architecture can provide valuable insights. Look for open-source implementations available on platforms like GitHub, and try to follow the code to understand how the components and mathematical formulas are implemented in practice.

6. Seek explanations from others: Engage in discussion forums, online communities, or local study groups focused on deep learning. Pose questions, seek clarifications, and discuss the concepts with fellow learners or experts in the field. Getting different perspectives and explanations can enhance your understanding.

7. Practice with hands-on exercises: Reinforce your understanding by implementing the Transformer yourself. Start with small tasks and datasets, such as character-level language modeling on a short text corpus, before moving on to harder problems like machine translation. Hands-on experience will solidify your understanding of the architecture and its mathematical formulation.
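To make steps 2 and 7 concrete, here is a minimal NumPy sketch of scaled dot-product attention, $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$, wrapped in a multi-head self-attention function. The function names are hypothetical (not from any library), and random matrices stand in for learned weights, so this only illustrates the shapes and the computation, not a trained model.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(q, k, v):
    """q, k: (..., seq_len, d_k); v: (..., seq_len, d_v)."""
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)  # (..., seq_len, seq_len)
    weights = softmax(scores, axis=-1)                   # each row sums to 1
    return weights @ v, weights


def multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """x: (seq_len, d_model); every weight matrix: (d_model, d_model)."""
    seq_len, d_model = x.shape
    depth = d_model // num_heads

    def split_heads(t):
        # (seq_len, d_model) -> (num_heads, seq_len, depth)
        return t.reshape(seq_len, num_heads, depth).transpose(1, 0, 2)

    # Project the same input into queries, keys and values ("self" attention)
    q, k, v = (split_heads(x @ w) for w in (w_q, w_k, w_v))
    heads, weights = scaled_dot_product_attention(q, k, v)
    # (num_heads, seq_len, depth) -> (seq_len, d_model), then output projection
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ w_o, weights


rng = np.random.default_rng(0)
seq_len, d_model, num_heads = 5, 16, 4
x = rng.normal(size=(seq_len, d_model))
# Random stand-ins for the learned projection matrices
w_q, w_k, w_v, w_o = (rng.normal(size=(d_model, d_model)) / np.sqrt(d_model) for _ in range(4))

out, weights = multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(out.shape)      # (5, 16)
print(weights.shape)  # (4, 5, 5): one seq_len x seq_len weight matrix per head
print(np.allclose(weights.sum(axis=-1), 1.0))  # True
```

Each head sees only `depth = d_model // num_heads` features, which is why `d_model` must be divisible by `num_heads` in this formulation.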

Remember that understanding the Transformer architecture takes time and persistence. Be patient with yourself and focus on gradually building your knowledge and intuition. It's also important to note that mathematical rigor is not always necessary for applying and utilizing deep learning models effectively. As long as you have a good conceptual understanding and practical experience, you can leverage the Transformer architecture in various applications.

```python
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers


# Multi-head self-attention layer (scaled dot-product attention, as in the Transformer)
class SelfAttention(layers.Layer):
    def __init__(self, num_heads, key_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim  # size of each attention head

    def build(self, input_shape):
        # Expects a sequence: (batch, seq_len, d_model)
        self.d_model = input_shape[-1]
        inner_dim = self.num_heads * self.key_dim

        self.query_dense = layers.Dense(units=inner_dim)
        self.key_dense = layers.Dense(units=inner_dim)
        self.value_dense = layers.Dense(units=inner_dim)
        # Project the concatenated heads back to the input width
        self.combine_heads = layers.Dense(units=self.d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, heads * key_dim) -> (batch, heads, seq_len, key_dim)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.key_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]

        query = self.split_heads(self.query_dense(inputs), batch_size)
        key = self.split_heads(self.key_dense(inputs), batch_size)
        value = self.split_heads(self.value_dense(inputs), batch_size)

        # softmax(Q K^T / sqrt(d_k)) V, computed independently for each head
        attention_logits = tf.matmul(query, key, transpose_b=True)
        attention_logits = attention_logits / tf.math.sqrt(tf.cast(self.key_dim, attention_logits.dtype))
        attention_weights = tf.nn.softmax(attention_logits, axis=-1)
        attention_output = tf.matmul(attention_weights, value)

        # (batch, heads, seq_len, key_dim) -> (batch, seq_len, heads * key_dim)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        attention_output = tf.reshape(attention_output, (batch_size, -1, self.num_heads * self.key_dim))

        return self.combine_heads(attention_output)


# Load MNIST and scale pixels to [0, 1]
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Self-attention operates on sequences, so keep each image as (28, 28):
# a sequence of 28 rows ("tokens"), each with 28 pixel features.
# Flattening to (784,) gives a sequence of length 1, and the model output
# (None, None, 784) then no longer matches the (None, 784) targets.
inputs = layers.Input(shape=(28, 28))
x = SelfAttention(num_heads=8, key_dim=64)(inputs)
outputs = layers.Dense(units=28, activation="sigmoid")(x)  # (batch, 28, 28), same shape as the targets

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="binary_crossentropy")

# Train as an autoencoder: the targets are the inputs themselves
model.fit(
    x=x_train,
    y=x_train,
    batch_size=64,
    epochs=10,
    validation_data=(x_test, x_test),
)

# Reconstruct the first 10 test images
reconstructed_images = model.predict(x_test[:10])

# Top row: originals; bottom row: reconstructions
fig, axes = plt.subplots(nrows=2, ncols=10, figsize=(10, 2))
for i in range(10):
    axes[0, i].imshow(x_test[i], cmap="gray")
    axes[0, i].axis("off")
    axes[1, i].imshow(reconstructed_images[i], cmap="gray")
    axes[1, i].axis("off")

plt.show()
```


Running the cell with the flattened `(784,)` input fails on the first training step:

```
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Epoch 1/10

ValueError: in user code:

    File "/opt/miniconda3/envs/ds/lib/python3.8/site-packages/keras/src/engine/training.py", line 1338, in train_function  *
        return step_function(self, iterator)
    File "/opt/miniconda3/envs/ds/lib/python3.8/site-packages/keras/src/engine/training.py", line 1322, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/opt/miniconda3/envs/ds/lib/python3.8/site-packages/keras/src/engine/training.py", line 1303, in run_step  **
        outputs = model.train_step(data)
    File "/opt/miniconda3/envs/ds/lib/python3.8/site-packages/keras/src/engine/training.py", line 1081, in train_step
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/opt/miniconda3/envs/ds/lib/python3.8/site-packages/keras/src/engine/training.py", line 1139, in compute_loss
        return self.compiled_loss(
    File "/opt/miniconda3/envs/ds/lib/python3.8/site-packages/keras/src/engine/compile_utils.py", line 265, in __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "/opt/miniconda3/envs/ds/lib/python3.8/site-packages/keras/src/losses.py", line 142, in __call__
        losses = call_fn(y_true, y_pred)
    File "/opt/miniconda3/envs/ds/lib/python3.8/site-packages/keras/src/losses.py", line 268, in call  **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/opt/miniconda3/envs/ds/lib/python3.8/site-packages/keras/src/losses.py", line 2432, in binary_crossentropy
        backend.binary_crossentropy(y_true, y_pred, from_logits=from_logits),
    File "/opt/miniconda3/envs/ds/lib/python3.8/site-packages/keras/src/backend.py", line 5809, in binary_crossentropy
        return tf.nn.sigmoid_cross_entropy_with_logits(

    ValueError: `logits` and `labels` must have the same shape, received ((None, None, 784) vs (None, 784)).
```

The error comes from feeding the flattened 784-pixel vector (`layers.Input(shape=(784,))`) into `SelfAttention`. The layer treats its input as a sequence and reshapes its output to `(batch, seq_len, 784)`, here with a sequence length of 1. The final `Dense` keeps that rank-3 shape, while the targets `x_train` are rank-2 `(batch, 784)`, so `binary_crossentropy` rejects the mismatch. Even with matching shapes, attention over a single token is trivial, because a softmax over one score always yields a weight of 1. Keeping each image as a sequence of 28 rows of 28 pixels gives attention something to attend over and makes the output shape `(batch, 28, 28)` match the targets.
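The point about single-token attention can be checked directly. The sketch below is a simplification: it computes single-head attention weights with identity projections (no learned query/key/value matrices), and a random array stands in for a normalized MNIST digit. It compares the flattened layout with the rows-as-tokens layout.

```python
import numpy as np


def attention_weights(x):
    """Single-head self-attention weights softmax(x x^T / sqrt(d)), identity projections."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
image = rng.random((28, 28))  # stand-in for one normalized MNIST digit

flat = image.reshape(1, 784)  # flattened: a sequence of ONE 784-dim token
rows = image                  # rows as tokens: 28 tokens of 28 pixels each

w_flat = attention_weights(flat)
w_rows = attention_weights(rows)
print(w_flat.shape, w_flat)   # (1, 1) [[1.]]: nothing to attend over
print(w_rows.shape)           # (28, 28): each row attends to all 28 rows
```

With one token, the only weight is 1 regardless of the input, so the attention layer reduces to a fixed linear map; with 28 row tokens, each row gets a distribution over all rows.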
