<a href="https://colab.research.google.com/github/antonemking/at-challenges/blob/challenge-0/at_sr_challenge0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np

# Problem sizes for the toy attention example
d_model = 64  # width of each token embedding
seq_len = 10  # tokens per sequence
batch_size = 2  # sequences per batch

# Random stand-ins for Q, K and V: uniform samples in [0, 1), all sharing one shape
qkv_shape = (batch_size, seq_len, d_model)
queries = np.random.random_sample(qkv_shape)
keys = np.random.random_sample(qkv_shape)
values = np.random.random_sample(qkv_shape)

# Report the shape of each tensor
for label, tensor in (
    ("Queries shape:", queries),
    ("Keys shape:", keys),
    ("Values shape:", values),
):
    print(label, tensor.shape)

In [None]:
def scaled_dot_product_attention(query, key, value):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: Array of shape (..., seq_len_q, d_k).
        key: Array of shape (..., seq_len_k, d_k).
        value: Array of shape (..., seq_len_k, d_v).

    Leading dimensions (batch, heads, ...) are handled by ``np.matmul``
    broadcasting; inputs must be at least 2-D.

    Returns:
        A tuple ``(output, attention_weights)`` where ``output`` has shape
        (..., seq_len_q, d_v) and ``attention_weights`` has shape
        (..., seq_len_q, seq_len_k) with each row summing to 1.
    """
    # Step 1: similarity of every query with every key. Swapping only the
    # last two axes (instead of a fixed 3-D transpose) keeps this valid for
    # 2-D inputs and for extra leading dims such as attention heads.
    scores = np.matmul(query, np.swapaxes(key, -1, -2))

    # Step 2: scale by sqrt(d_k). Out-of-place division so integer inputs are
    # promoted to float rather than failing on an in-place cast.
    scores = scores / np.sqrt(query.shape[-1])

    # Step 3: softmax over the key axis. Subtracting the row-wise max first
    # leaves the result unchanged mathematically but prevents exp() overflow
    # (which would otherwise yield inf/inf = NaN for large scores).
    shifted = scores - np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(shifted)
    attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # Step 4: weighted sum of the values
    output = np.matmul(attention_weights, value)

    return output, attention_weights

# Run attention over the Q/K/V batch and report the shapes of both results
output, attention_weights = scaled_dot_product_attention(
    query=queries, key=keys, value=values
)

print("Attention output shape:", np.shape(output))
print("Attention weights shape:", np.shape(attention_weights))


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap of the attention weights for the first example in the batch
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(attention_weights[0], annot=True, cmap='Blues', ax=ax)
ax.set_title('Attention Weights for First Example in Batch')
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
plt.show()

In [None]:
# Show the attention output for the first example in the batch, header first
print("Attention Output (First Example):", output[0], sep="\n")