# Scaled Dot-Product Attention from Scratch (NumPy)

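This notebook implements scaled dot-product attention, $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(QK^\top / \sqrt{d_k}\right)V$, in NumPy and visualizes the resulting attention weights as a heatmap.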

## Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt

## Softmax Function

In [None]:
def softmax(x, axis=-1):
    """Return the softmax of ``x`` computed along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating, so the
    largest exponent evaluated is ``exp(0)``; this avoids overflow for large
    inputs without changing the result.

    Args:
        x: Array-like of scores.
        axis: Axis along which the values are normalized to sum to 1.

    Returns:
        Array with the same shape as ``x`` whose entries along ``axis``
        are non-negative and sum to 1.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    numerator = np.exp(shifted)
    # keepdims=True keeps the reduced axis so the division broadcasts.
    denominator = numerator.sum(axis=axis, keepdims=True)
    return numerator / denominator

## Scaled Dot-Product Attention

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Only the last two axes of each input are treated as matrix dimensions;
    any leading axes are batch dimensions handled by ``np.matmul``
    broadcasting. This allows unbatched ``(seq, d_k)``, batched
    ``(batch, seq, d_k)`` and multi-head ``(batch, heads, seq, d_k)`` inputs.

    Args:
        Q: Query array of shape ``(..., seq_q, d_k)`` (at least 2-D).
        K: Key array of shape ``(..., seq_k, d_k)`` (at least 2-D).
        V: Value array of shape ``(..., seq_k, d_v)``.
        mask: Optional array-like broadcastable to the score shape
            ``(..., seq_q, seq_k)``. Positions where the mask equals 0 have
            their score replaced by ``-1e9`` before the softmax, so they
            receive (near-)zero attention weight.

    Returns:
        Tuple ``(output, attention_weights)`` where ``attention_weights``
        has shape ``(..., seq_q, seq_k)`` with rows summing to 1, and
        ``output`` is ``attention_weights @ V``.
    """
    d_k = Q.shape[-1]
    # Swap only the trailing two axes so the transpose is independent of how
    # many leading batch/head dimensions the inputs carry.
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(d_k)
    if mask is not None:
        # np.asarray makes the comparison element-wise even for nested lists;
        # a plain ``list == 0`` is a scalar False and would ignore the mask.
        scores = np.where(np.asarray(mask) == 0, -1e9, scores)
    attention_weights = softmax(scores, axis=-1)
    output = np.matmul(attention_weights, V)
    return output, attention_weights

## Example & Visualization

In [None]:
# Toy sentence: one query row and one key column per token.
tokens = ['I', 'love', 'machine', 'learning']
seq_len, batch_size, d_k = len(tokens), 1, 8

# Fixed seed so the random Q/K/V (and therefore the heatmap) are reproducible.
np.random.seed(42)
qkv_shape = (batch_size, seq_len, d_k)
# The generator is consumed in order, so the draws happen as Q, then K, then V.
Q, K, V = (np.random.rand(*qkv_shape) for _ in range(3))
output, attention_weights = scaled_dot_product_attention(Q, K, V)

# Plot the single batch element: rows are query tokens, columns are key tokens.
tick_positions = range(seq_len)
plt.imshow(attention_weights[0], cmap='viridis')
plt.colorbar()
plt.xticks(tick_positions, tokens, rotation=45)
plt.yticks(tick_positions, tokens)
plt.title('Attention Weights Heatmap')
plt.show()