# Implement Attention from Scratch
### Problem Statement
Implement a **Scaled Dot-Product Attention** mechanism from scratch using PyTorch. Your mission (should you choose to accept it) is to replicate what PyTorch's built-in `scaled_dot_product_attention` does, but manually. This core component is essential in Transformer architectures and helps models focus on relevant parts of a sequence. You'll test your implementation against PyTorch's native one to ensure you nailed it.


### Requirements
1. **Define the Function**:
   - Create a function `scaled_dot_product_attention(q, k, v, mask=None)` that:
     - Computes attention scores via the dot product of query and key vectors.
     - Scales the scores by dividing them by the square root of the key dimension, `d_k`.
     - Applies an optional mask to the scaled scores.
     - Applies softmax to convert scores into attention weights.
     - Uses these weights to compute a weighted sum of values (V).

2. **Test Your Work**:
   - Use sample tensors for query (Q), key (K), and value (V).
   - Compare the result of your custom implementation with PyTorch's `F.scaled_dot_product_attention`, using an `assert` on `torch.allclose` to check that the two outputs agree within floating-point tolerance.
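
Taken together, the steps in item 1 are the standard scaled dot-product attention formula. Written out below, $M$ stands for the optional mask expressed additively ($0$ where attention is allowed, $-\infty$ where it is blocked, or an arbitrary float bias), and the softmax runs row-wise, over the keys:

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}} + M\right) V,
\qquad Q \in \mathbb{R}^{n_q \times d_k},\; K \in \mathbb{R}^{n_k \times d_k},\; V \in \mathbb{R}^{n_k \times d_v}
$$

Any leading batch (or head) dimensions are carried along unchanged, which is why the transpose of $K$ must swap only its last two axes.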


### Constraints
- ❌ Do NOT use `F.scaled_dot_product_attention` inside your custom function; that defeats the whole point.
- ✅ Your implementation must handle **batch dimensions** correctly.
- ✅ Support optional **masking** for future tokens or padding.
- ✅ Use only PyTorch ops, no cheating with external attention libs.
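
The masking constraint covers two common cases: hiding future tokens (a causal mask) and hiding padding. A minimal sketch of how both could be built and combined, assuming the boolean convention of `F.scaled_dot_product_attention` (True means "may attend"); the sizes and the per-sequence `lengths` tensor are hypothetical:

```python
import torch

# Hypothetical sizes for illustration.
batch, seq_len = 2, 4

# Causal ("future token") mask: query position i may attend to key positions j <= i.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # (seq_len, seq_len)

# Padding mask: hypothetical true lengths per sequence; keys past each length are padding.
lengths = torch.tensor([4, 2])
key_is_real = torch.arange(seq_len) < lengths[:, None]              # (batch, seq_len)
padding = key_is_real[:, None, :]                                    # (batch, 1, seq_len)

# Broadcasting combines both into one mask per batch element.
combined = causal & padding                                          # (batch, seq_len, seq_len)
print(combined.shape)
print(combined[1].int())
```

For multi-head inputs of shape `(batch, heads, seq_len, dim)`, `combined[:, None]` adds a head axis that broadcasts. An equivalent additive float mask is `torch.zeros(combined.shape).masked_fill(~combined, float("-inf"))`.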



<details>
  <summary>üí° Hint</summary>
  Use `torch.matmul()` to compute dot products and `F.softmax()` for the final attention weights. The mask (if used) should be applied **before** the softmax using `masked_fill`.
</details>
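
To see why the hint insists on masking **before** the softmax, here is a small sketch with hypothetical random scores and a causal mask. It compares filling blocked scores before the softmax with zeroing blocked weights after it:

```python
import torch

torch.manual_seed(0)
scores = torch.randn(3, 3)                             # (seq_len_q, seq_len_k)
keep = torch.tril(torch.ones(3, 3, dtype=torch.bool))  # True = may attend

# Masking before softmax: blocked logits become -1e9, so exp() underflows to 0
# and the remaining weights are renormalized to sum to 1.
weights_before = torch.softmax(scores.masked_fill(~keep, -1e9), dim=-1)

# Zeroing after softmax: blocked weights are removed, but nothing renormalizes.
weights_after = torch.softmax(scores, dim=-1) * keep

print(weights_before.sum(dim=-1))  # each row sums to 1
print(weights_after.sum(dim=-1))   # rows 0 and 1 sum to less than 1
```

Only the first version still yields a valid probability distribution over the allowed keys.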


```python
import torch
import torch.nn.functional as F
```

```python
import torch

torch.manual_seed(42)

# Small sample inputs, each of shape (batch_size, seq_len, dim)
batch_size = 1
seq_len = 3
dim = 3

q = torch.randn(batch_size, seq_len, dim)
k = torch.randn(batch_size, seq_len, dim)
v = torch.randn(batch_size, seq_len, dim)
```

```python
import math

import torch


def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Compute the scaled dot-product attention.

    Args:
        q: Query tensor of shape (..., seq_len_q, d_k)
        k: Key tensor of shape (..., seq_len_k, d_k)
        v: Value tensor of shape (..., seq_len_k, d_v)
        mask: Optional mask broadcastable to (..., seq_len_q, seq_len_k).
              Boolean: True marks positions that may be attended to.
              Float: added to the scaled scores as a bias.

    Returns:
        output: Attention output tensor of shape (..., seq_len_q, d_v)
        attention_weights: Attention weights tensor of shape (..., seq_len_q, seq_len_k)
    """
    d_k = k.shape[-1]
    # Swap only the last two dims so any number of leading batch dims works.
    k_transposed = k.transpose(-2, -1)          # (..., d_k, seq_len_k)
    scores = q @ k_transposed / math.sqrt(d_k)  # (..., seq_len_q, seq_len_k)

    if mask is not None:
        if mask.dtype == torch.bool:
            # Block positions where mask is False; -1e9 drives their weight to ~0.
            scores = scores.masked_fill(~mask, -1e9)
        else:
            # Float masks are additive, as in F.scaled_dot_product_attention.
            scores = scores + mask

    attention = torch.softmax(scores, dim=-1)   # (..., seq_len_q, seq_len_k)
    output = attention @ v                      # (..., seq_len_q, d_v)
    return output, attention
```
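
The boolean branch fills blocked scores with `-1e9` rather than `float("-inf")`. The two choices only diverge when a query row is masked everywhere. A small sketch with hypothetical scores, where the second query row has no allowed keys:

```python
import torch

scores = torch.tensor([[0.5, 1.0, -0.3],
                       [0.2, 0.1, 0.4]])
# Second query row is fully masked (e.g. a padding query with no valid keys).
keep = torch.tensor([[True, True, False],
                     [False, False, False]])

with_large_negative = torch.softmax(scores.masked_fill(~keep, -1e9), dim=-1)
with_neg_inf = torch.softmax(scores.masked_fill(~keep, float("-inf")), dim=-1)

print(with_large_negative[1])  # uniform weights over the (all blocked) keys
print(with_neg_inf[1])         # NaN: softmax of an all -inf row is undefined
```

`-1e9` avoids NaNs but silently averages over blocked keys. Since the documented reference implementation of `F.scaled_dot_product_attention` fills blocked positions with `-inf`, comparisons against it are only reliable on masks that leave at least one key per row, such as a causal mask. What PyTorch returns for a fully masked row may depend on version and backend.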

```python
# Alternative boolean mask (True = may attend); not used in the comparison below.
boolean_mask = torch.randint(high=2, size=(batch_size, seq_len, seq_len)).bool()
# Random 0/1 float mask, added to the scaled scores as a bias.
float_mask = torch.randint(high=2, size=(batch_size, seq_len, seq_len), dtype=k.dtype)

# Compare the custom implementation with PyTorch's built-in version.
output_custom, _ = scaled_dot_product_attention(q, k, v, mask=float_mask)
print(output_custom)
output = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
print(output)

assert torch.allclose(output_custom, output, atol=1e-08, rtol=1e-05)  # Check if they are close enough.
```


```text
tensor([[[-0.0462, -0.4405,  0.9905],
         [ 0.2242, -0.4261,  0.9416],
         [ 0.3432, -0.1818,  0.7171]]])
tensor([[[-0.0462, -0.4405,  0.9905],
         [ 0.2242, -0.4261,  0.9416],
         [ 0.3432, -0.1818,  0.7171]]])
```
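
The comparison above exercises a single random 0/1 float mask on a 3-D batch. A broader check, sketched under assumptions: hypothetical 4-D inputs `(batch, heads, seq_len, dim)` with `d_v != d_k`, and causal masks, whose kept diagonal guarantees no query row is fully masked. It covers no mask, a boolean mask, an equivalent float mask, and PyTorch's `is_causal=True` flag:

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v, mask=None):
    # Compact copy of the custom function above so this block runs on its own.
    scores = q @ k.transpose(-2, -1) / math.sqrt(k.shape[-1])  # (..., seq_len_q, seq_len_k)
    if mask is not None:
        if mask.dtype == torch.bool:
            scores = scores.masked_fill(~mask, -1e9)  # True = may attend
        else:
            scores = scores + mask                    # float mask = additive bias
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
# Hypothetical sizes: 4-D inputs with a heads dimension and d_v != d_k.
batch, heads, seq_len, d_k, d_v = 2, 4, 5, 8, 6
q = torch.randn(batch, heads, seq_len, d_k)
k = torch.randn(batch, heads, seq_len, d_k)
v = torch.randn(batch, heads, seq_len, d_v)

# Causal mask keeps the diagonal, so every query row has at least one allowed key.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# Equivalent float mask: 0 where allowed, -inf where blocked.
causal_float = torch.zeros(seq_len, seq_len).masked_fill(~causal, float("-inf"))

results = {}
for name, mask in {"no mask": None, "bool causal": causal, "float causal": causal_float}.items():
    ours, _ = scaled_dot_product_attention(q, k, v, mask=mask)
    ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    results[name] = torch.allclose(ours, ref, atol=1e-5)
    print(f"{name}: match = {results[name]}")

# PyTorch's is_causal=True flag should agree with the explicit causal mask.
ours_causal, _ = scaled_dot_product_attention(q, k, v, mask=causal)
ref_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
causal_match = torch.allclose(ours_causal, ref_causal, atol=1e-5)
print("is_causal: match =", causal_match)
```

The 2-D masks broadcast across the batch and head dimensions in both implementations. The slightly looser `atol` allows for PyTorch dispatching to a fused kernel whose rounding differs from the plain matmul-softmax path.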
