<a href="https://colab.research.google.com/github/antonemking/at-challenges/blob/challenge-0/at_sr_challenge0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Challenge 0: Understanding Attention
# This notebook is designed to help you understand the key concepts of the paper
# "Attention Is All You Need". Your task is to implement the core components of
# the attention mechanism and experiment with it.

## 1. Introduction
"""
The "Attention Is All You Need" paper introduced the Transformer model, which
has since revolutionized natural language processing (NLP) and other machine
learning fields. At the heart of this model is the attention mechanism, which
allows the model to focus on different parts of the input sequence dynamically.

In this challenge, you'll implement a simplified version of the attention
mechanism and apply it to some test data. Follow the TODO sections to fill
in the code.
"""

## 2. Step 1: Setup & Imports
# First, let's import the necessary libraries for our task.

import numpy as np
import torch
import torch.nn.functional as F

# TODO: Explore any additional libraries that might be useful for visualizing the results.

## 3. Step 2: Defining the Attention Mechanism
"""
In the attention mechanism, we compute attention scores between a query and
a set of keys, normalize them with softmax to obtain attention weights, and
use those weights to compute a weighted sum of the values. Keys that are more
relevant to the query receive larger weights.

The formula for dot-product attention is:
    Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V

Here * denotes matrix multiplication, not element-wise multiplication.
"""

# Here, Q = query, K = keys, and V = values.
# d_k is the dimension of the key/query vectors.
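
# A minimal worked sketch of the formula above on a tiny hand-made input.
# Assumptions: one query, two keys, and d_k = d_v = 2. The numbers and the
# *_demo names are illustrative only and are not part of the challenge.
import math
import torch

q_demo = torch.tensor([[1.0, 0.0]])                # (1, d_k): one query
k_demo = torch.tensor([[1.0, 0.0], [0.0, 1.0]])    # (2, d_k): two keys
v_demo = torch.tensor([[10.0, 0.0], [0.0, 10.0]])  # (2, d_v): two values

# Scaled dot products between the query and each key, shape (1, 2).
scores_demo = q_demo @ k_demo.T / math.sqrt(q_demo.size(-1))
# Softmax turns the scores into weights that sum to 1 across the keys.
weights_demo = torch.softmax(scores_demo, dim=-1)
# Weighted sum of the values, shape (1, d_v).
output_demo = weights_demo @ v_demo

# The query matches the first key, so the first value gets the larger weight.
print(weights_demo, output_demo)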

def attention(query, key, value, d_k):
    """
    Computes the dot-product attention.

    Args:
        query: A tensor of shape (batch_size, seq_len, d_k)
        key: A tensor of shape (batch_size, seq_len, d_k)
        value: A tensor of shape (batch_size, seq_len, d_v)
        d_k: Dimension of the keys/queries.

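    Returns:
        output: The weighted sum of the values, shape (batch_size, seq_len, d_v).
        attention_weights: The softmax-normalized scores,
            shape (batch_size, seq_len, seq_len).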
    """

    # TODO: Implement the attention mechanism.
    # Step 1: Compute the dot product of the query and keys, then divide by sqrt(d_k)
    attention_scores = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(d_k)

    # TODO: Step 2: Apply softmax to get the attention weights
    attention_weights = F.softmax(attention_scores, dim=-1)

    # TODO: Step 3: Use the attention weights to compute a weighted sum of the values
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

# TODO: Test the function with random inputs to ensure it works as expected.
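
# One possible sanity check for the TODO above. This is a sketch: the helper
# name `check_attention` and its default sizes are hypothetical, not part of
# the original challenge. It takes the attention function as an argument,
# feeds it random inputs, and verifies the output shapes and that every row
# of attention weights sums to 1.
import torch

def check_attention(attn_fn, batch=2, seq_len=4, d_k=8, d_v=6):
    q = torch.rand(batch, seq_len, d_k)
    k = torch.rand(batch, seq_len, d_k)
    v = torch.rand(batch, seq_len, d_v)
    out, weights = attn_fn(q, k, v, d_k)
    # The output keeps the query's sequence length and the value's feature size.
    assert out.shape == (batch, seq_len, d_v)
    # There is one weight per (query position, key position) pair.
    assert weights.shape == (batch, seq_len, seq_len)
    # Softmax over the last dimension makes each row a probability distribution.
    row_sums = weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)
    return True

# Usage, assuming `attention` is the function defined above:
# check_attention(attention)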

## 4. Step 3: Applying Attention
"""
Now that you have implemented the attention mechanism, let's apply it to some test data.
In a real-world scenario, attention mechanisms are used to enhance sequence-to-sequence
models, such as those used in machine translation.
"""

# Let's create some random data to simulate queries, keys, and values.
# TODO: Generate random tensors for query, key, and value
query = torch.rand(5, 10, 64)  # Example: 5 sequences, 10 tokens per sequence, 64-dimensional embeddings
key = torch.rand(5, 10, 64)
value = torch.rand(5, 10, 64)

# Use the attention mechanism to compute the output
d_k = query.size(-1)
output, attention_weights = attention(query, key, value, d_k)

# TODO: Print the output and attention weights for inspection.
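
# A possible way to do the inspection in the TODO above. This is a sketch:
# the helper name `describe_attention` is hypothetical. It prints shapes and
# row-sum statistics instead of the full tensors, which are large for a
# (5, 10, 64) output.
import torch

def describe_attention(output, attention_weights):
    row_sums = attention_weights.sum(dim=-1)
    summary = {
        "output_shape": tuple(output.shape),
        "weights_shape": tuple(attention_weights.shape),
        # Both values should be close to 1 if softmax was applied over the keys.
        "min_row_sum": row_sums.min().item(),
        "max_row_sum": row_sums.max().item(),
    }
    for name, val in summary.items():
        print(f"{name}: {val}")
    return summary

# Usage with the tensors computed above:
# describe_attention(output, attention_weights)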

## 5. Step 4: Testing & Visualization
"""
To fully understand how attention works, it's useful to visualize the attention
weights. The attention weights tell us which parts of the input sequence the
model is focusing on at each step.
"""

# TODO: Visualize the attention weights using matplotlib or any other library of your choice.
# You can plot the attention weights as a heatmap to see how the model focuses on different tokens.

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights):
    """
    Visualizes the attention weights as a heatmap.
    """
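    # Move to the CPU first so this also works for tensors stored on a GPU.
    attention_weights_np = attention_weights.detach().cpu().numpy()
    # Plot the first sequence in the batch: rows are queries, columns are keys.
    ax = sns.heatmap(attention_weights_np[0], annot=True, fmt=".2f", cmap="Blues")
    ax.set_xlabel("Key position")
    ax.set_ylabel("Query position")
    plt.show()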

# TODO: Test the visualization function with the attention weights computed above.

visualize_attention(attention_weights)
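
# A small numeric complement to the heatmap. This is a sketch: the helper
# name `most_attended_keys` is hypothetical. For each query position it
# returns the index of the key position with the largest attention weight.
import torch

def most_attended_keys(attention_weights):
    # (batch_size, query_len, key_len) -> (batch_size, query_len)
    return attention_weights.argmax(dim=-1)

# Usage with the weights computed above (first sequence in the batch):
# print(most_attended_keys(attention_weights)[0])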

## 6. Conclusion
"""
Congratulations! You've just implemented the core of the attention mechanism,
a key component in modern neural networks like the Transformer model.

Feel free to experiment further by modifying the input data or adjusting the
dimension sizes. Understanding this mechanism is an important step toward
grasping more advanced models, such as BERT and GPT.
"""
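
# A sketch of one such experiment: how the 1/sqrt(d_k) scaling affects the
# sharpness of the attention weights as d_k grows. Assumptions: the helper
# name `mean_max_weight` is hypothetical, and the inputs are standard-normal
# (torch.randn) rather than the uniform torch.rand data used earlier.
import torch

def mean_max_weight(d_k, scale=True, seq_len=10, seed=0):
    gen = torch.Generator().manual_seed(seed)
    q = torch.randn(seq_len, d_k, generator=gen)
    k = torch.randn(seq_len, d_k, generator=gen)
    scores = q @ k.T
    if scale:
        scores = scores / d_k ** 0.5
    weights = torch.softmax(scores, dim=-1)
    # Average, over queries, of the largest weight in each row.
    # Values near 1 mean each query attends almost entirely to a single key.
    return weights.max(dim=-1).values.mean().item()

for d in (4, 64, 512):
    print(d, mean_max_weight(d, scale=True), mean_max_weight(d, scale=False))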
