# Implement Scaled Dot-Product Attention in PyTorch

## Problem Statement

Your task is to **implement the Scaled Dot-Product Attention mechanism**, including the **Query (Q), Key (K), and Value (V) transformations**, in PyTorch. This mechanism is a fundamental operation in **Transformer-based models**, as introduced in the **Attention Is All You Need** paper (Vaswani et al., 2017). 

You will build a **self-attention module** and apply it to a **machine translation** task using the **Multi30k dataset** (English-German sentence pairs).

---

## 📌 Background

The **Scaled Dot-Product Attention** computes attention weights and aggregates relevant information from different parts of the input sequence. Given input embeddings, the mechanism generates Query, Key, and Value matrices, then applies attention to learn contextual relationships.

### **Mathematical Formulation**
For each pair of positions \(i\) and \(j\) in a sequence, the **attention score** is computed as:

$$e_{ij} = \frac{Q_i K_j^T}{\sqrt{d_k}}$$

where:
- **\(Q = X W_Q\)** → Query matrix
- **\(K = X W_K\)** → Key matrix
- **\(V = X W_V\)** → Value matrix
- \( X \) is the matrix of input embeddings.
- \( W_Q, W_K, W_V \) are learnable weight matrices.
- \( d_k \) is the dimension of the keys and queries (used for scaling).
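
The effect of the \(\sqrt{d_k}\) scaling can be checked numerically. The sketch below assumes that query and key entries are independent standard-normal values (an assumption made only for illustration, not a property of trained projections). Under that assumption the raw dot product has a variance of roughly \(d_k\), and dividing by \(\sqrt{d_k}\) brings it back to roughly 1, which helps keep the softmax away from saturation.

```python
import math
import torch

torch.manual_seed(0)
d_k = 64        # illustrative key/query dimension
n = 10_000      # number of random query/key pairs

# Assumption: entries are i.i.d. standard normal
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

raw = (q * k).sum(dim=-1)        # unscaled dot products
scaled = raw / math.sqrt(d_k)    # scaled as in e_ij

raw_var = raw.var().item()
scaled_var = scaled.var().item()
print(f"raw variance ~ {raw_var:.1f}, scaled variance ~ {scaled_var:.2f}")
```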

The normalized attention weights are:

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T} \exp(e_{ik})}$$

The final attention output is:


$$O_i = \sum_{j=1}^{T} \alpha_{ij} V_j$$

where:
- \( O \) is the output sequence after applying attention.
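
The three formulas above can be traced with plain tensor operations. This is a minimal sketch that assumes \(Q\), \(K\), and \(V\) are already computed (random tensors with illustrative sizes stand in for them here); the function name `attention_from_formulas` is hypothetical.

```python
import math
import torch

def attention_from_formulas(Q, K, V):
    """Q, K: (T, d_k); V: (T, d_v). Returns O (T, d_v) and alpha (T, T)."""
    d_k = Q.size(-1)
    e = Q @ K.T / math.sqrt(d_k)       # e_ij: scaled dot product of Q_i and K_j
    alpha = torch.softmax(e, dim=-1)   # normalize each row over j
    O = alpha @ V                      # O_i = sum_j alpha_ij V_j
    return O, alpha

torch.manual_seed(0)
T, d_k, d_v = 4, 8, 6                  # illustrative sizes
Q, K, V = torch.randn(T, d_k), torch.randn(T, d_k), torch.randn(T, d_v)
O, alpha = attention_from_formulas(Q, K, V)
print(O.shape)            # torch.Size([4, 6])
print(alpha.sum(dim=-1))  # each row of weights sums to 1
```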

---

## **Task Requirements**

### 1️⃣ Implement **Scaled Dot-Product Attention with Query, Key, and Value Transformations**
- Compute **Query, Key, and Value matrices** from the input.
- Apply **dot-product attention**:
  - Compute attention scores.
  - Normalize with softmax.
  - Aggregate the values, weighted by the normalized attention weights.

### 2️⃣ Integrate into a Simple **Self-Attention Model**
- Use the **Multi30k dataset** for English-German translation.
- Implement a **Transformer-style attention block** using your scaled dot-product attention module.
- Build a basic **encoder-decoder architecture** for sequence-to-sequence translation.

### 3️⃣ Handle Variable-Length Sequences
- Implement **masking** to ignore padding tokens.
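
One way to build such a mask is sketched below. It assumes that padded positions hold a dedicated token id; `pad_idx = 0` and the token values are hypothetical, and the real padding id depends on the vocabulary.

```python
import torch

pad_idx = 0  # assumption: id of the padding token in the vocabulary
tokens = torch.tensor([[5, 7, 9, 0, 0],
                       [3, 4, 0, 0, 0]])   # (B=2, T=5), padded to equal length

# True for real tokens, False for padding; shape (B, 1, T)
key_mask = (tokens != pad_idx).unsqueeze(1)

# Repeat along the query axis so every query ignores the padded keys: (B, T, T)
mask = key_mask.expand(-1, tokens.size(1), -1)
print(mask.shape)       # torch.Size([2, 5, 5])
print(mask[1].int())    # columns of padded keys are 0 in every row
```

For attention scores that carry a head dimension, shape `(B, heads, T, T)`, `key_mask.unsqueeze(1)` gives a `(B, 1, 1, T)` mask that broadcasts in the same way.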

---

## **Constraints**
- The input to the attention module should be of shape:

$$X \in \mathbb{R}^{B \times T \times d}$$

where:
- **\(B\)** is the batch size.
- **\(T\)** is the sequence length.
- **\(d\)** is the hidden dimension of embeddings.

- The Query, Key, and Value matrices should be computed using:

$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$

where \( W_Q, W_K, W_V \) are **trainable weight matrices**.
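
Because `nn.Linear` acts on the last dimension, a single layer applies the same weight matrix to every position of every sequence in the batch. The shape check below is a small sketch with illustrative sizes that the task does not prescribe.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d, d_k = 2, 5, 16, 8             # illustrative sizes
X = torch.randn(B, T, d)

# nn.Linear stores a (d_k, d) weight and computes X @ weight.T
W_Q = nn.Linear(d, d_k, bias=False)
Q = W_Q(X)
print(Q.shape)  # torch.Size([2, 5, 8])

# The same projection written as an explicit matrix product
Q_manual = X @ W_Q.weight.T
print(torch.allclose(Q, Q_manual, atol=1e-5))
```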

---

## **💡 Hints**
1. **Define Learnable Weight Matrices**:
   - Use `nn.Linear` layers to generate Q, K, and V.
   - Example: `self.W_Q = nn.Linear(d_model, d_k, bias=False)`

2. **Apply Dot-Product Attention**:
   - Compute scores: `torch.matmul(Q, K.transpose(-2, -1))`
   - Scale by \(\sqrt{d_k}\): `scores /= math.sqrt(d_k)`
   - Apply `torch.softmax(scores, dim=-1)`

3. **Handle Padding Masks**:
   - Mask **padding tokens** before softmax using `scores.masked_fill(mask == 0, float('-inf'))`
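
The effect of filling masked scores with negative infinity can be seen on a tiny example. The scores and mask below are made up for illustration. The sketch also shows one pitfall worth guarding against: a row in which every position is masked has no valid softmax.

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0]])   # one query, three keys
mask = torch.tensor([[1, 1, 0]])           # the last key is padding

# Masked positions get -inf, so exp(-inf) = 0 after softmax
masked = scores.masked_fill(mask == 0, float("-inf"))
weights = torch.softmax(masked, dim=-1)
print(weights)   # the last weight is exactly 0; the rest sum to 1

# Pitfall: if every key is masked, the softmax row becomes NaN
empty_mask = torch.zeros_like(mask)
all_masked = scores.masked_fill(empty_mask == 0, float("-inf"))
nan_weights = torch.softmax(all_masked, dim=-1)
print(nan_weights)
```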

---

## **📌 Example Implementation**

### **1️⃣ Implement Scaled Dot-Product Attention**
```python


import torch
import torch.nn as nn
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model, dropout=0.1, num_heads=8):
        super(ScaledDotProductAttention, self).__init__()
        self.num_heads = num_heads
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        
        # Learnable linear layers for Q, K, V transformations
        self.Q = nn.Linear(d_model, d_model)
        self.K = nn.Linear(d_model, d_model)
        self.V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model, bias=False)  # Final output projection
        
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        query, key, value: (batch_size, seq_len, d_model) - Input embeddings
            (pass the same tensor three times for self-attention)
        mask: broadcastable to (batch_size, num_heads, seq_len, seq_len) -
            Optional padding mask; positions where mask == 0 are ignored

        Returns:
        - Attention output (batch_size, seq_len, d_model)
        - Attention weights (batch_size, num_heads, seq_len, seq_len)
        """
        B, T, D = query.shape

        # Compute Query, Key, and Value and split them into heads
        # (batch_size, seq_len, head, d_k) -> (batch_size, head, seq_len, d_k)
        Query = self.Q(query).reshape(B, T, self.num_heads, self.d_k).transpose(1, 2)
        Key = self.K(key).reshape(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        Value = self.V(value).reshape(B, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Compute attention scores
        # (batch_size, head, seq_len, seq_len)
        attention_scores = torch.matmul(Query, Key.transpose(-1, -2)) / math.sqrt(self.d_k)

        # Apply mask (optional): masked positions get -inf so softmax assigns them zero weight
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))

        # Compute attention weights
        attn_weights = self.softmax(attention_scores)

        # Compute weighted sum of values
        output = torch.matmul(attn_weights, Value)  # (batch_size, head, seq_len, d_k)
        output = output.transpose(1, 2).reshape(B, T, D)  # (batch_size, seq_len, d_model)
        output = self.W_O(output)

        return output, attn_weights


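# Single-head variant that follows the problem statement directly.
# It has its own name so that it does not shadow the multi-head class above.
class SingleHeadScaledDotProductAttention(nn.Module):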
    def __init__(self, d_model, d_k, d_v=32):
        super().__init__()
        self.d_k = d_k

        # Query, Key, and Value transformations
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_v, bias=False)
        self.W_O = nn.Linear(d_v, d_model, bias=False)  # Final output projection

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, mask=None):
        batch_size, seq_len, _ = X.shape

        # Compute Query, Key, and Value projections
        Q = self.W_Q(X)  # Shape: (batch_size, seq_len, d_k)
        K = self.W_K(X)  # Shape: (batch_size, seq_len, d_k)
        V = self.W_V(X)  # Shape: (batch_size, seq_len, d_v)

        # Compute attention scores
        attention_scores = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(self.d_k)

        # Apply mask (optional)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))

        # Compute attention weights
        attn_weights = self.softmax(attention_scores)  # Shape: (batch_size, seq_len, seq_len)

        # Compute weighted sum of values
        output = torch.matmul(attn_weights, V)  # Shape: (batch_size, seq_len, d_v)

        # Project back to d_model
        # output = self.W_O(output)  # Shape: (batch_size, seq_len, d_model)

        return output, attn_weights


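# Inspect the multi-head module.
# torchsummary's summary(model, input_size=(...)) feeds a single tensor to
# forward(), which raises a TypeError here because forward() expects separate
# query, key, and value inputs. Counting parameters directly avoids that.
batch_size, seq_len, d_model = 16, 10, 512
num_heads = 8
model = ScaledDotProductAttention(d_model=d_model, num_heads=num_heads)
num_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {num_params:,}")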

from torch.utils.data import Dataset, DataLoader

# Training the multi-head attention module on a toy dataset
def train_attention_with_toy_dataset():
    batch_size, seq_len, d_model = 16, 10, 512
    num_heads = 8
    num_epochs = 50
    learning_rate = 0.001
    
    attention = ScaledDotProductAttention(d_model=d_model, num_heads=num_heads)
    optimizer = torch.optim.Adam(attention.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    
    # Initialize Toy Dataset and DataLoader
    class ToySequenceDataset(Dataset):
        def __init__(self, num_samples=500, sequence_length=10, input_dim=512):
            self.sequence_length = sequence_length
            self.input_dim = input_dim
            self.data = torch.rand(num_samples, sequence_length, input_dim)  # Shape: (num_samples, seq_length, input_dim)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    toy_dataset = ToySequenceDataset(num_samples=500)
    train_loader = DataLoader(toy_dataset, batch_size=batch_size, shuffle=True)

    # Training Loop
    for epoch in range(num_epochs):
        attention.train()
        total_loss = 0
        for sequences in train_loader:
            query, key, value = sequences, sequences, sequences
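            # All-ones mask (the toy data has no padding). Shape (1, 1, T, T)
            # broadcasts over batch and heads, including the smaller final batch.
            mask = torch.ones(1, 1, seq_len, seq_len)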
            
            optimizer.zero_grad()
            outputs, attn_weights = attention(query, key, value, mask)
            loss = criterion(outputs, sequences)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        if epoch % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}")
    
    # Testing the trained model
    attention.eval()
    with torch.no_grad():
        test_sequences = next(iter(train_loader))
        test_outputs, test_attn_weights = attention(test_sequences, test_sequences, test_sequences, mask)
        print("Sample Output Shape:", test_outputs.shape)  # Expected: (batch_size, sequence_length, d_model)
        print("Sample Attention Weights Shape:", test_attn_weights.shape)  # Expected: (batch_size, num_heads, sequence_length, sequence_length)

train_attention_with_toy_dataset()

# Output from one run:
# Epoch [1/50], Loss: 0.1077
# Epoch [11/50], Loss: 0.0815
# Epoch [21/50], Loss: 0.0254
# Epoch [31/50], Loss: 0.0071
# Epoch [41/50], Loss: 0.0022
# Sample Output Shape: torch.Size([16, 10, 512])
# Sample Attention Weights Shape: torch.Size([16, 8, 10, 10])
```