In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------
# Scaled Dot-Product Attention (standard)
# -------------------------
class ScaledDotProductAttention(nn.Module):
    """Exact (softmax) scaled dot-product attention.

    Computes ``softmax(Q @ K^T * d_k**-0.5) @ V`` over the last two axes; any
    leading axes (batch, heads) are carried through by ``torch.matmul``.

    Args:
        d_k: query/key feature size used for the ``1/sqrt(d_k)`` scale.
    """

    def __init__(self, d_k):
        super().__init__()
        self.scale = d_k ** -0.5

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: queries, ``(..., N_q, d_k)``.
            K: keys, ``(..., N_k, d_k)``.
            V: values, ``(..., N_k, d_v)``.
            mask: optional tensor broadcastable to the score shape
                ``(..., N_q, N_k)``; positions where ``mask == 0`` receive
                (effectively) zero attention probability.

        Returns:
            ``(output, attn_probs)`` with shapes ``(..., N_q, d_v)`` and
            ``(..., N_q, N_k)``.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        if mask is not None:
            # Use the most negative finite value of the score dtype rather than a
            # hard-coded -1e9: -1e9 is out of range for float16, where
            # masked_fill raises instead of masking.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)
        attn_probs = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output, attn_probs

# -------------------------
# Linear Attention (efficient for long seq)
# -------------------------
class LinearAttention(nn.Module):
    """Non-causal linear attention with the positive feature map ``elu(x) + 1``.

    With ``phi(x) = elu(x) + 1`` the output for query ``i`` is::

        phi(Q_i) @ (phi(K)^T @ V) / (phi(Q_i) . sum_j phi(K_j) + eps)

    Every query attends to all keys (the key sum runs over the whole
    sequence), but the ``(N, N)`` score matrix is never built, so cost grows
    linearly with the sequence length ``N``.

    Args:
        eps: small constant added to each per-query normaliser to avoid
            division by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, Q, K, V):
        """Apply linear attention.

        Args:
            Q: queries, ``(..., N, d)``.
            K: keys, ``(..., N, d)``.
            V: values, ``(..., N, e)``.

        Returns:
            Tensor of shape ``(..., N, e)``.
        """
        phi_q = F.elu(Q) + 1
        phi_k = F.elu(K) + 1

        kv = torch.einsum('...nd,...ne->...de', phi_k, V)  # (..., d, e)
        # eps goes on the per-query normaliser, after the dot product.
        # elu(x) + 1 underflows to exactly 0 for very negative x. An eps added
        # to the key sum cannot rescue an all-zero query: the denominator stays
        # 0, the reciprocal is inf, and 0 * inf = NaN.
        denom = torch.einsum('...nd,...d->...n', phi_q, phi_k.sum(dim=-2))
        z = 1.0 / (denom + self.eps)
        return torch.einsum('...nd,...de,...n->...ne', phi_q, kv, z)

# -------------------------
# Hybrid Multi-Head Attention
# -------------------------
class HybridAttention(nn.Module):
    """Multi-head self-attention that chooses exact or linear attention by length.

    Inputs are projected into ``n_heads`` heads of width
    ``d_k = d_model // n_heads``. Sequences with ``N <= seq_threshold`` use
    ``ScaledDotProductAttention``, and longer ones use ``LinearAttention``. A
    supplied ``mask`` always forces exact attention (see ``forward``).

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads; must be positive.
        seq_threshold: longest sequence length that uses exact attention.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, n_heads, seq_threshold=128):
        super().__init__()
        # Fail early with a clear message. Otherwise this surfaces later as a
        # ZeroDivisionError here or an opaque `view` shape error in forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.seq_threshold = seq_threshold

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)

        self.full_attn = ScaledDotProductAttention(self.d_k)
        self.linear_attn = LinearAttention()

    def _split_heads(self, t, B, N):
        """Reshape a ``(B, N, d_model)`` projection to ``(B, n_heads, N, d_k)``."""
        return t.view(B, N, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply hybrid multi-head self-attention.

        Args:
            x: input of shape ``(B, N, d_model)``.
            mask: optional mask forwarded to exact attention (positions where
                ``mask == 0`` are excluded). The linear kernel takes no mask,
                so any supplied mask forces exact attention regardless of ``N``.

        Returns:
            Tensor of shape ``(B, N, d_model)``.
        """
        B, N, D = x.shape

        Q = self._split_heads(self.W_q(x), B, N)
        K = self._split_heads(self.W_k(x), B, N)
        V = self._split_heads(self.W_v(x), B, N)

        # LinearAttention cannot exclude keys. Honour a mask with exact
        # attention instead of silently ignoring it on long sequences.
        if mask is not None or N <= self.seq_threshold:
            out, _ = self.full_attn(Q, K, V, mask)
        else:
            out = self.linear_attn(Q, K, V)

        # (B, n_heads, N, d_k) -> (B, N, d_model)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.fc_out(out)

# -------------------------
# Transformer Encoder Block
# -------------------------
class TransformerBlock(nn.Module):
    """Post-norm transformer encoder block built around ``HybridAttention``.

    Each sub-layer is applied as ``x = LayerNorm(x + Dropout(sublayer(x)))``:
    first hybrid self-attention, then a two-layer ReLU feed-forward network.

    Args:
        d_model: input/output feature size.
        n_heads: number of attention heads.
        ff_hidden: hidden width of the feed-forward network.
        dropout: dropout probability applied to each sub-layer output.
        seq_threshold: forwarded to ``HybridAttention``. Sequences no longer
            than this use exact attention, and longer ones use linear
            attention. Defaults to 128, matching ``HybridAttention``'s own
            default.
    """

    def __init__(self, d_model, n_heads, ff_hidden, dropout=0.1, seq_threshold=128):
        super().__init__()
        self.attn = HybridAttention(d_model, n_heads, seq_threshold=seq_threshold)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_hidden),
            nn.ReLU(),
            nn.Linear(ff_hidden, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run attention and feed-forward sub-layers.

        Args:
            x: input of shape ``(batch, seq_len, d_model)``; the attention
                layer unpacks exactly three dimensions.
            mask: optional attention mask, passed unchanged to ``self.attn``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_out = self.attn(x, mask)
        x = self.norm1(x + self.dropout(attn_out))

        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x

# -------------------------
# Mini Hybrid Transformer (for classification)
# -------------------------
class MiniHybridTransformer(nn.Module):
    """Sequence classifier: linear input projection, stacked hybrid blocks, mean pooling.

    Args:
        input_dim: number of features at each sequence position.
        d_model: hidden width shared by every block.
        n_heads: attention heads per block.
        ff_hidden: feed-forward hidden width per block.
        num_classes: number of output logits.
        num_layers: how many ``TransformerBlock`` modules are stacked.

    ``forward`` expects ``x`` as ``(batch, seq_len, input_dim)``. It averages
    over the sequence axis and returns ``(batch, num_classes)`` logits.
    """

    def __init__(self, input_dim, d_model=128, n_heads=4, ff_hidden=256, num_classes=10, num_layers=2):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        self.layers = nn.ModuleList(
            TransformerBlock(d_model, n_heads, ff_hidden) for _ in range(num_layers)
        )
        self.fc_out = nn.Linear(d_model, num_classes)

    def forward(self, x, mask=None):
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden, mask)
        # Mean over the sequence axis gives one vector per example.
        pooled = hidden.mean(dim=1)
        return self.fc_out(pooled)

# -------------------------
# Test Run (CPU)
# -------------------------
if __name__ == "__main__":
    # Smoke test: a single forward pass on random data, on the CPU.
    device = torch.device("cpu")
    model = MiniHybridTransformer(input_dim=64, num_classes=5).to(device)

    # seq_len=150 is above the default seq_threshold of 128, so the attention
    # layers take their linear-attention branch.
    batch, seq_len, features = 8, 150, 64
    dummy_input = torch.randn(batch, seq_len, features, device=device)
    logits = model(dummy_input)
    print("Output shape:", logits.shape)  # Expected: (8, 5)


Output shape: torch.Size([8, 5])



## **What the code does step by step**:

---

### 1. **Imports**

It loads the required Python libraries:

* `torch` → for deep learning.
* `torch.nn` → for building layers (like Linear, LayerNorm, Dropout).
* `torch.nn.functional` → for activation functions (like softmax, relu).

---

### 2. **Scaled Dot-Product Attention (standard attention)**

This is the **Full Attention** part:

* Takes **queries (Q), keys (K), and values (V)**.
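* Computes attention scores as `Q·Kᵀ / √d_k` (d_k = per-head dimension). If a mask is given, masked positions are set to a very large negative value so that they receive ~0 probability.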
* Applies **softmax** to convert scores into probabilities.
* Multiplies probabilities with **V** to get the final weighted output.

This is the **classic transformer attention**.

---

### 3. **Linear Attention (efficient attention)**

This is the **fast version**:

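* Instead of computing the full `Q·Kᵀ` matrix (an N×N matrix, which is slow and memory-hungry for long inputs),
  it replaces softmax with a positive feature map φ(x) = elu(x) + 1 applied to Q and K.
* Idea: compute `φ(Q)·(φ(K)ᵀ·V)`, normalised per query by `φ(Q)·Σⱼ φ(Kⱼ)`, instead of `softmax(Q·Kᵀ)·V`. Because `φ(K)ᵀ·V` is only d×d, no N×N matrix is ever built.
* Saves a **lot of memory and time** when sequence length is huge.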

---

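### 4. **HybridAttention (hybrid block)**

* This block combines **both Full Attention and Linear Attention**, but uses only **one of them per forward pass**.
* It projects the input into Q, K, V and splits them into `n_heads` heads.
* It then chooses by sequence length `N`:

  * `N <= seq_threshold` (default 128) → normal (full) attention.
  * `N > seq_threshold` → linear attention.
* The head outputs are merged back and passed through a final linear layer (`fc_out`).
* This makes it a **hybrid attention mechanism**: exact attention for short inputs and efficient attention for long ones.
  (Note: the `mask` argument is only used by the full-attention path.)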

---

### 5. **MiniHybridTransformer (the model)**

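* **Input layer**: Projects each input feature vector (`input_dim` features) to `d_model` with a linear layer. It uses `nn.Linear`, not `nn.Embedding`, because the inputs are continuous features.
* **Transformer blocks**: `num_layers` stacked `TransformerBlock`s, each using `HybridAttention`.
* **Feed-forward block**: A small neural network with two linear layers and ReLU activation.
* **LayerNorm**: Applied after each residual connection (post-norm); helps stabilize training.
* **Dropout**: Applied to each sub-layer output; helps prevent overfitting.
* **Pooling + output layer**: Averages over the sequence (`x.mean(dim=1)`) and maps the result to `num_classes` logits for classification.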

---

### 6. **Running the Model (example forward pass)**

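* Creates **dummy input data** with `torch.randn`: a batch of 8 sequences, each of length 150, with 64 features per position.
* Feeds it into the model. Because 150 > 128 (the default `seq_threshold`), the linear-attention path is used.
* Prints the output shape → `(batch_size, num_classes)` = `(8, 5)`.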

---

✅ **In short:**
This code builds a **small transformer-like classifier** whose attention layer **switches between full attention and linear attention** depending on sequence length. Short sequences get exact attention; long sequences get faster, more memory-efficient, but approximate linear attention.

---



A **line-by-line** walkthrough of a second, smaller **Mini Hybrid Transformer** variant, which splits the attention heads between full and linear attention.

---

## 📌 Code Recap (Mini Hybrid Transformer for CPU)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hybrid Attention: splits the heads of one layer between full (exact softmax)
# attention and linear (ReLU-kernel, unnormalised) attention.
class HybridAttention(nn.Module):
    """Multi-head self-attention where some heads use full attention and the rest use linear attention.

    Args:
        d_model: embedding size. NOTE(review): must be divisible by n_heads,
            otherwise the ``view`` calls in ``forward`` fail.
        n_heads: number of attention heads.
        linear_ratio: as written, the FIRST ``int(n_heads * linear_ratio)``
            heads use full attention and the remaining heads use linear
            attention. NOTE(review): the name suggests the opposite meaning;
            with the default 0.5 and an even n_heads the split is the same
            either way.
    """
    def __init__(self, d_model, n_heads, linear_ratio=0.5):
        super(HybridAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.linear_ratio = linear_ratio

        # Q/K/V projections and the output projection that mixes the heads.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Return a tensor with the same (B, N, D) shape as ``x``."""
        B, N, D = x.shape  # Batch size, Sequence length, Embedding size

        # Project input into Q, K, V and reshape each to (B, n_heads, N, D // n_heads)
        Q = self.q_linear(x).view(B, N, self.n_heads, D // self.n_heads).transpose(1, 2)
        K = self.k_linear(x).view(B, N, self.n_heads, D // self.n_heads).transpose(1, 2)
        V = self.v_linear(x).view(B, N, self.n_heads, D // self.n_heads).transpose(1, 2)

        # Split along the head axis: heads [0, split) -> full, [split, n_heads) -> linear
        split = int(self.n_heads * self.linear_ratio)
        Q_full, K_full, V_full = Q[:, :split], K[:, :split], V[:, :split]
        Q_lin, K_lin, V_lin = Q[:, split:], K[:, split:], V[:, split:]

        # Full attention.
        # NOTE(review): scales by sqrt(D), the full embedding size; standard
        # multi-head attention scales by sqrt(D // n_heads), the per-head size.
        attn_scores = torch.matmul(Q_full, K_full.transpose(-2, -1)) / (D ** 0.5)
        attn_probs = F.softmax(attn_scores, dim=-1)
        full_out = torch.matmul(attn_probs, V_full)

        # Linear attention: relu(Q) @ (relu(K)^T @ V). The KV product is only
        # (head_dim x head_dim), so no N x N matrix is built.
        # NOTE(review): there is no normaliser (no division by
        # relu(Q) . sum_j relu(K_j)), so outputs are not weighted averages of V
        # and their magnitude grows with N.
        K_lin_relu = F.relu(K_lin)
        Q_lin_relu = F.relu(Q_lin)
        KV = torch.matmul(K_lin_relu.transpose(-2, -1), V_lin)
        lin_out = torch.matmul(Q_lin_relu, KV)

        # Re-join full and linear heads along the head axis, merge heads back to (B, N, D)
        out = torch.cat([full_out, lin_out], dim=1).transpose(1, 2).contiguous().view(B, N, D)
        return self.out(out)

# Mini Transformer Block with Hybrid Attention (pre-norm; no embedding or output head)
class MiniHybridTransformer(nn.Module):
    """A single pre-norm transformer block: LayerNorm -> sub-layer -> residual add.

    Args:
        d_model: embedding size of the input and output.
        n_heads: number of attention heads.
        linear_ratio: forwarded to HybridAttention (controls the head split).
        dim_ff: hidden width of the feed-forward network.
    """
    def __init__(self, d_model=64, n_heads=4, linear_ratio=0.5, dim_ff=128):
        super(MiniHybridTransformer, self).__init__()
        self.attn = HybridAttention(d_model, n_heads, linear_ratio)
        # Position-wise feed-forward network: d_model -> dim_ff -> d_model
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(),
            nn.Linear(dim_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        x = x + self.attn(self.norm1(x))  # Residual connection (norm applied before attention)
        x = x + self.ff(self.norm2(x))    # Residual connection (norm applied before feed-forward)
        return x

# Smoke test on CPU: one forward pass through the block on random data.
# The block preserves shape, so both prints show (2, 10, 64).
if __name__ == "__main__":
    model = MiniHybridTransformer(d_model=64, n_heads=4, linear_ratio=0.5).cpu()
    x = torch.randn(2, 10, 64)  # (batch=2, seq_len=10, embedding=64)
    y = model(x)
    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
```

---

# 🔎 Line-by-Line Explanation

---

### 1. Import libraries

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

* `torch` → Core PyTorch library.
* `nn` → Used to build layers like Linear, ReLU, LayerNorm.
* `F` → Contains useful functions (softmax, relu, etc.).

---

### 2. Define **Hybrid Attention**

```python
class HybridAttention(nn.Module):
```

* Creates a **new type of Attention layer** that mixes **Full Attention** and **Linear Attention**.

---

### 3. Initialize constructor

```python
def __init__(self, d_model, n_heads, linear_ratio=0.5):
    super(HybridAttention, self).__init__()
```

* `d_model`: size of embeddings (example: 64).
* `n_heads`: number of attention heads (example: 4).
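* `linear_ratio`: controls how the heads are split. As written, the first `int(n_heads * linear_ratio)` heads use **full attention** and the remaining heads use **linear attention** (see step 7). So, despite its name, it sets the share of *full*-attention heads; with the default 0.5 and 4 heads the split is 2/2 either way.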

---

### 4. Define weight matrices

```python
self.q_linear = nn.Linear(d_model, d_model)
self.k_linear = nn.Linear(d_model, d_model)
self.v_linear = nn.Linear(d_model, d_model)
self.out = nn.Linear(d_model, d_model)
```

* These are the **Q (Query), K (Key), V (Value)** projection layers.
* `self.out` → final layer to recombine head outputs.

---

### 5. Forward method (main computation)

```python
def forward(self, x):
    B, N, D = x.shape
```

* `B` = batch size (number of samples at once).
* `N` = sequence length (words/tokens).
* `D` = embedding size (features per token).

---

### 6. Project inputs into Q, K, V

```python
Q = self.q_linear(x).view(B, N, self.n_heads, D // self.n_heads).transpose(1, 2)
K = self.k_linear(x).view(B, N, self.n_heads, D // self.n_heads).transpose(1, 2)
V = self.v_linear(x).view(B, N, self.n_heads, D // self.n_heads).transpose(1, 2)
```

* Convert input into Q, K, V matrices.
* Split into multiple **heads** (`n_heads`).
* `transpose(1, 2)` → moves the `n_heads` dimension in front of the sequence dimension, giving shape `(B, n_heads, N, D // n_heads)`.

---

### 7. Split heads into **full** and **linear**

```python
split = int(self.n_heads * self.linear_ratio)
Q_full, K_full, V_full = Q[:, :split], K[:, :split], V[:, :split]
Q_lin, K_lin, V_lin = Q[:, split:], K[:, split:], V[:, split:]
```

* First few heads (`split`) → use **Full Attention**.
* Remaining heads → use **Linear Attention**.

---

### 8. Full Attention

```python
attn_scores = torch.matmul(Q_full, K_full.transpose(-2, -1)) / (D ** 0.5)
attn_probs = F.softmax(attn_scores, dim=-1)
full_out = torch.matmul(attn_probs, V_full)
```

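* Calculate similarity between queries & keys.
* Normalize with **softmax**.
* Multiply with values → gives **weighted sum** (context).
* Note: the scores are divided by `√D` (the full embedding size). Standard multi-head attention divides by `√(D // n_heads)`, the per-head dimension.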

---

### 9. Linear Attention

```python
K_lin_relu = F.relu(K_lin)
Q_lin_relu = F.relu(Q_lin)
KV = torch.matmul(K_lin_relu.transpose(-2, -1), V_lin)
lin_out = torch.matmul(Q_lin_relu, KV)
```

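* Uses a **ReLU feature map** (kernel trick) for efficiency.
* Instead of `softmax(QKᵀ)·V`, it computes `ReLU(Q) @ (ReLU(K)ᵀ @ V)`. The `ReLU(K)ᵀ @ V` product is only `d × d`, so no `N × N` matrix is built.
* Much faster for long sequences.
* Note: unlike softmax (and unlike the normalised linear attention in the first code cell), nothing divides by `ReLU(Q)·Σ ReLU(K)`. The outputs are therefore not weighted averages, and their magnitude grows with sequence length.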

---

### 10. Concatenate results

```python
out = torch.cat([full_out, lin_out], dim=1).transpose(1, 2).contiguous().view(B, N, D)
return self.out(out)
```

* Combines **full + linear attention outputs**.
* Reorders & reshapes to original format `(B, N, D)`.
* Passes through final `Linear` layer.

---

### 11. Define **Transformer Block**

```python
class MiniHybridTransformer(nn.Module):
```

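* Despite its name, this is a single **pre-norm** transformer block: **Norm → Attention → Residual**, then **Norm → FeedForward → Residual**. It has no input embedding or output head, so the output shape equals the input shape.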

---

### 12. Constructor

```python
def __init__(self, d_model=64, n_heads=4, linear_ratio=0.5, dim_ff=128):
```

* `d_model=64` → embedding size.
* `n_heads=4` → multi-head attention.
* `dim_ff=128` → hidden size for feed-forward network.

---

### 13. Define layers

```python
self.attn = HybridAttention(d_model, n_heads, linear_ratio)
self.ff = nn.Sequential(
    nn.Linear(d_model, dim_ff),
    nn.ReLU(),
    nn.Linear(dim_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
```

* `self.attn`: Hybrid attention.
* `self.ff`: Feed-forward network.
* `self.norm1` & `self.norm2`: stabilize training.

---

### 14. Forward pass

```python
def forward(self, x):
    x = x + self.attn(self.norm1(x))  # Residual connection
    x = x + self.ff(self.norm2(x))    # Residual connection
    return x
```

* Apply **LayerNorm → Attention → Residual**.
* Apply **LayerNorm → FeedForward → Residual**.
* Return the transformed output.

---

### 15. Testing model on CPU

```python
if __name__ == "__main__":
    model = MiniHybridTransformer(d_model=64, n_heads=4, linear_ratio=0.5).cpu()
    x = torch.randn(2, 10, 64)  # (batch=2, seq_len=10, embedding=64)
    y = model(x)
    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
```

* Creates model.
* Input: `2` batches, sequence length `10`, embedding size `64`.
* Prints input & output shapes.

---

✅ **In short:**
This is a **mini Transformer block** that mixes **Full Attention (accurate but slow)** with **Linear Attention (fast but approximate)** to balance **speed and accuracy**.

---



In [None]:
## 📘 **Interview Questions on Full Attention vs Linear Attention**

The questions are split into **three levels — Easy, Moderate, Hard** — so you can handle **MCA + CSE-level interviews** fully prepared.

---

# 🔹 Easy Level Questions (Basics)

1. **Q:** What is the main purpose of this code?
   **A:** To implement and test a hybrid transformer model combining full attention and linear attention on CPU.

2. **Q:** Which library is primarily used in this code?
   **A:** PyTorch (`torch`).

3. **Q:** What does the `MiniHybridTransformer` class represent?
   **A:** A custom transformer model that mixes full and linear attention.

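4. **Q:** What is the function of `self.embedding = nn.Linear(input_dim, d_model)` in this code?
   **A:** It projects each continuous input feature vector into a `d_model`-dimensional representation. (`nn.Embedding` would be used instead if the inputs were integer token IDs.)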

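5. **Q:** What is the role of `nn.ModuleList` in `MiniHybridTransformer`?
   **A:** It stacks `num_layers` `TransformerBlock`s. This code builds its own blocks instead of using PyTorch's `nn.TransformerEncoder`.

6. **Q:** What does `TransformerBlock` do?
   **A:** It defines a single transformer encoder block: hybrid attention plus a feed-forward network, each with a residual connection, dropout, and LayerNorm. This is the role `nn.TransformerEncoderLayer` plays in PyTorch's built-in encoder.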

7. **Q:** What is linear attention?
   **A:** An approximation of standard attention that reduces computation from quadratic to linear in sequence length.

8. **Q:** Why do we use `nn.Linear` at the output?
   **A:** To map the pooled hidden state to `num_classes` logits for classification.

9. **Q:** Why is ReLU used in linear attention?
   **A:** To ensure positive features and stabilize attention approximation.

10. **Q:** What does `batch_size=2` mean?
    **A:** Two input sequences are processed at the same time.

11. **Q:** Why is `torch.randn` used in this code?
    **A:** Only to create random dummy input for the test run. The linear attention itself is deterministic and uses no random projections.

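12. **Q:** What floating-point precision does this code use, and how would you set it?
    **A:** PyTorch's default, `torch.float32`. No `dtype` is set explicitly, but passing one (e.g. `dtype=torch.float32`) defines the precision of computations.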

13. **Q:** What is the function of the `forward` method in PyTorch models?
    **A:** Defines how input data flows through the model.

14. **Q:** Why do we call `.to(device)`?
    **A:** To move tensors and models to CPU or GPU.

15. **Q:** What is the role of `seq_len` in this code?
    **A:** The length of each input sequence: 150 positions in the classifier demo and 10 in the block-only demo.

16. **Q:** Why are `input_dim` and `d_model` important?
    **A:** `input_dim` is the number of features per input position. `d_model` is the hidden (embedding) size used throughout the attention and feed-forward layers.

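17. **Q:** What is the purpose of `torch.randint`?
    **A:** It generates random integer tensors, e.g. token IDs for testing a model with an `nn.Embedding` input. This code does not need it, because its inputs are continuous features created with `torch.randn`.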

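18. **Q:** What does `.transpose(1, 2)` do after the Q/K/V projections?
    **A:** It changes the shape from `(batch, seq_len, n_heads, d_k)` to `(batch, n_heads, seq_len, d_k)`, so each head's attention can be computed in parallel.

19. **Q:** Does this code need `(seq_len, batch, embed_dim)` input, like PyTorch's built-in transformer?
    **A:** No. It is batch-first: inputs are `(batch, seq_len, features)`. `nn.TransformerEncoder` expects time-first input only when `batch_first=False`, which is its default.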

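20. **Q:** What do the printed output shapes mean?
    **A:** The classifier demo prints `(8, 5)`: batch size 8 × 5 class logits. The block-only demo prints `(2, 10, 64)`: batch 2, sequence length 10, embedding size 64, because a transformer block preserves its input shape.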

---

# 🔹 Moderate Level Questions (Deeper Understanding)

21. **Q:** What is the difference between full attention and linear attention?
    **A:** Full attention computes all pairwise token interactions (O(n²)), while linear attention approximates them to O(n).

22. **Q:** Why would we mix both attentions in one model?
    **A:** To balance accuracy (from full attention) and efficiency (from linear attention).

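23. **Q:** Do we use random projection in linear attention here?
    **A:** No. This code uses a fixed positive feature map (`elu(x) + 1`, or `ReLU` in the second version). Methods such as Performer use random feature projections to approximate the softmax kernel.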

24. **Q:** How does scaling by `1/sqrt(d_k)` help in attention?
    **A:** Prevents large dot-product values, stabilizing gradients.

25. **Q:** Why does the code define `MiniHybridTransformer` instead of using only PyTorch’s built-in transformer?
    **A:** To add the linear attention mechanism alongside the default full attention.

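26. **Q:** Why is a feature map like `elu(x) + 1` or `ReLU` applied to Q and K instead of softmax?
    **A:** Softmax must normalise over all keys for every query, which needs the full N×N score matrix. With a feature map φ, the product can be regrouped as `φ(Q)·(φ(K)ᵀ·V)`, whose cost is linear in sequence length.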

27. **Q:** What would happen if we increased `seq_len` from 10 to 1000?
    **A:** Full attention would slow down (O(n²)), but linear attention would scale better.

28. **Q:** What is the role of the output layer (`fc_out` in `MiniHybridTransformer`)?
    **A:** It projects the pooled hidden state to logits over `num_classes` for classification.

29. **Q:** What happens if we remove the output layer?
    **A:** The model would return pooled `d_model`-dimensional embeddings instead of class logits.

30. **Q:** Why do we use `torch.matmul` in attention computation?
    **A:** To compute dot-products efficiently between queries, keys, and values.

31. **Q:** How is memory usage different in linear vs. full attention?
    **A:** Full attention stores an n×n matrix; linear stores smaller projected representations.

32. **Q:** Why are several `TransformerBlock`s stacked in an `nn.ModuleList`?
    **A:** To build deeper models with multiple layers of attention. This is the same idea as stacking `nn.TransformerEncoderLayer`s inside `nn.TransformerEncoder`.

33. **Q:** Can this hybrid transformer be used for NLP tasks?
    **A:** Yes, for tasks like text classification, translation, summarization.

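34. **Q:** Why do sizes like `input_dim=64` and `num_classes=5` not match real-world tasks?
    **A:** This is just a toy example on random data. Real models take their input sizes and label sets from the actual dataset, and NLP models use vocabularies in the tens of thousands.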

35. **Q:** Why did we not train the model here?
    **A:** The code is only a forward pass demo, not training.

36. **Q:** Why might linear attention lose accuracy compared to full attention?
    **A:** Because approximation reduces exact token-to-token interactions.

37. **Q:** Why do we use batch processing in transformers?
    **A:** To process multiple sequences simultaneously for efficiency.

38. **Q:** Why is normalization important in transformers?
    **A:** It stabilizes training by keeping activations in a reasonable range.

39. **Q:** Can we extend this to a decoder model?
    **A:** Yes, by adding transformer decoder layers.

40. **Q:** Why might this model perform differently on CPU vs GPU?
    **A:** GPUs are optimized for matrix multiplications, while CPUs may be slower for large inputs.

---

# 🔹 Hard Level Questions (Advanced / Research)

41. **Q:** How does linear attention achieve O(n) complexity?
    **A:** By using kernel tricks and projections to avoid computing the full n×n attention matrix.

42. **Q:** What’s the tradeoff between accuracy and speed in hybrid transformers?
    **A:** Full attention is more accurate but slower; linear attention is faster but may lose precision.

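43. **Q:** Why are random projections effective in approximating attention?
    **A:** They approximately preserve distances and inner products with high probability (Johnson–Lindenstrauss lemma). Random-feature methods such as Performer build on this idea; the code in this notebook does not use random projections.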

44. **Q:** How would you modify this code to train on a text dataset like IMDB reviews?
    **A:** Replace random tokens with actual dataset inputs and add a loss function + optimizer.

45. **Q:** Why is position encoding not included here?
    **A:** For simplicity, but real transformers need positional information to capture order.

46. **Q:** How could we add sinusoidal positional encodings here?
    **A:** Add a `PositionalEncoding` module that adds sinusoidal position vectors to the output of `self.embedding`, before the first `TransformerBlock`.

47. **Q:** Is linear attention applied per head in this code?
    **A:** Yes. `LinearAttention` receives Q, K, V shaped `(batch, n_heads, seq_len, d_k)`, and its `einsum` patterns use `...` to carry the batch and head axes, so each head is processed independently.

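48. **Q:** What is the effect of the positive feature map (`elu(x) + 1` or `ReLU`) in linear attention?
    **A:** It keeps the similarity weights non-negative. Dividing by their sum, as `LinearAttention` does with `Z`, then yields weights that behave like a probability distribution. The ReLU version in the second listing skips this normalisation.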

49. **Q:** Can linear attention replace full attention completely?
    **A:** Sometimes yes, but accuracy may drop in tasks needing long-range dependencies.

50. **Q:** How would you integrate memory-efficient attention mechanisms like FlashAttention into this model?
    **A:** Replace the attention computation with FlashAttention kernels.

---
