# PyTorch LLM Architecture & Data Flow — Simple, Intuitive Guide

Main takeaway: Modern LLMs are just stacks of smart “attention + MLP” blocks. You can swap attention styles, tweak feed-forward layers, add small adapters, and optimize memory/training to get big gains without changing the whole model. Below are the concepts explained simply with practical examples and when to use them.

***

## Custom Transformer Components

### Multi-Head Attention Variants

See also: [Attention visualization (video)](https://www.youtube.com/watch?app=desktop&v=RNF0FvRjGZk)

Think of attention as “looking up relevant info” across the sequence.

- Multi-Query Attention (MQA): All attention heads share a single Key/Value head, but each head keeps its own Queries.
  - Why: Shrinks the KV cache by a factor of the head count, so decoding needs less memory and memory bandwidth and runs faster.
  - Analogy: Many people asking questions (queries), but the library of facts (key/values) is shared.
  - Use when: You need faster server-side generation and can accept a small quality loss versus full MHA.

- Grouped Query Attention (GQA): Query heads are split into groups; each group shares one K/V head.
  - Why: Middle ground between full MHA and MQA: better quality than MQA, cheaper than MHA.
  - Use when: You want speed-ups with less accuracy drop than MQA.

- FlexAttention (PyTorch 2.5+): A flexible attention API where you can plug custom scores, masks, or biases.
  - Why: Experiment quickly (e.g., penalize repeated tokens, prioritize recency).
  - Example: Add a bias to prefer recent tokens in long-context chat; or block attention across segments.

Intuition example: Summarizing a story. With MHA, each head looks for different patterns (characters, places, chronology). With MQA/GQA, the heads still ask different questions but consult fewer libraries, so answers come faster. With FlexAttention, you can say “focus more on the ending” with a custom score modifier.
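The sharing pattern behind MHA, GQA, and MQA can be sketched with plain tensors. This is a minimal illustration, not how production kernels do it: the function name `grouped_query_attention` and all sizes are made up for the example, and `repeat_interleave` copies K/V in memory, which real implementations avoid.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    """q: (batch, n_q_heads, seq, head_dim); k, v: (batch, n_kv_heads, seq, head_dim)."""
    n_q_heads, n_kv_heads = q.shape[1], k.shape[1]
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    # Expand K/V so each KV head is reused by `group_size` consecutive query heads.
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.manual_seed(0)
batch, seq, head_dim, n_q_heads = 2, 16, 32, 8  # illustrative sizes
q = torch.randn(batch, n_q_heads, seq, head_dim)

outputs = {}
for name, n_kv_heads in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    k = torch.randn(batch, n_kv_heads, seq, head_dim)
    v = torch.randn(batch, n_kv_heads, seq, head_dim)
    outputs[name] = grouped_query_attention(q, k, v)
    kv_cache_floats = 2 * k.numel()  # K and V entries that would sit in the cache
    print(f"{name}: output {tuple(outputs[name].shape)}, KV cache floats {kv_cache_floats}")
```

The output shape is the same in all three cases; only the amount of K/V that has to be stored and read changes.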

***

### Advanced Feed-Forward Networks (FFN)
The feed-forward network (MLP) after each attention layer transforms every token's representation independently.

- SwiGLU (used in LLaMA/PaLM): A gated activation that replaces ReLU/GELU and often improves quality at a similar parameter count and compute cost.
  - Intuition: A smarter filter that lets useful signals pass more cleanly.
  - Use when: You want an accuracy bump without major changes.

- Mixture of Experts (MoE): Many expert MLPs; a router picks a few (often the top 1 or 2) per token.
  - Why: Scales parameter count without proportional compute per token.
  - Analogy: A call center routes each call to the best specialist.
  - Use when: You need larger capacity with controlled FLOPs and can afford the memory to hold every expert.
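Both FFN ideas fit in a few lines. The sketch below is illustrative only: the class names `SwiGLU` and `TopKMoE`, the layer sizes, and the simple Python loop over experts are assumptions made for readability. Real MoE layers add load-balancing losses and batched expert dispatch, which are left out here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """Gated FFN: down(silu(gate(x)) * up(x))."""
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_hidden, bias=False)
        self.w_up = nn.Linear(d_model, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

class TopKMoE(nn.Module):
    """Routes each token to its top_k experts and mixes their outputs."""
    def __init__(self, d_model, d_hidden, n_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList([SwiGLU(d_model, d_hidden) for _ in range(n_experts)])

    def forward(self, x):
        tokens = x.reshape(-1, x.shape[-1])                      # (n_tokens, d_model)
        weights, chosen = self.router(tokens).topk(self.top_k, dim=-1)
        weights = weights.softmax(dim=-1)                        # normalize over the chosen experts
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            token_idx, slot = (chosen == e).nonzero(as_tuple=True)
            if token_idx.numel() > 0:                            # only run experts that got tokens
                gate = weights[token_idx, slot].unsqueeze(-1)
                out[token_idx] += gate * expert(tokens[token_idx])
        return out.reshape(x.shape)

torch.manual_seed(0)
x = torch.randn(2, 5, 16)                                        # (batch, seq, d_model)
y_ffn = SwiGLU(16, 32)(x)
y_moe = TopKMoE(16, 32, n_experts=4, top_k=2)(x)
print(y_ffn.shape, y_moe.shape)
```

Each token only pays for `top_k` expert forward passes, while the layer holds `n_experts` sets of weights.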

***

## Positional Encodings
Attention is order-agnostic on its own, so the model needs an explicit signal for token order.

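- RoPE (Rotary): Encodes position by rotating Q/K vectors, so attention scores depend on relative position. Going far beyond the trained context length usually needs a scaling trick (e.g., position interpolation or NTK-aware scaling).
  - Analogy: Label each word by angle; relative positions fall out naturally.
  - Use when: You want strong long-context behavior; the rotation is applied to Q/K before attention, so it works with fused kernels such as FlashAttention.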

- ALiBi: Adds a per-head linear penalty to attention scores that grows with token distance, in place of positional embeddings; simple and extrapolates well.
  - Use when: You want simplicity and stable long-context generalization.

Example: In “Alice met Bob in Paris,” positions help map “met” to “Alice” and “Bob” correctly.
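To see why relative positions "fall out naturally" with RoPE, here is a small sketch. The helper name `apply_rope` is hypothetical, and it uses the interleaved-pair layout; some libraries pair channels differently, so treat the layout as an assumption. The check at the end places the same query and key vector at every position and confirms that the score depends only on the offset between them.

```python
import torch

def apply_rope(x, base=10000.0):
    """Rotate channel pairs of x, shaped (..., seq, head_dim), by position-dependent angles."""
    seq, dim = x.shape[-2], x.shape[-1]
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # (dim / 2,)
    angles = torch.arange(seq, dtype=torch.float32)[:, None] * inv_freq       # (seq, dim / 2)
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1)
    return rotated.flatten(-2)  # interleave the pairs back into (..., seq, head_dim)

torch.manual_seed(0)
seq, head_dim = 8, 64
q_vec, k_vec = torch.randn(head_dim), torch.randn(head_dim)

# Same content at every position, so any difference in scores comes from position alone.
q = apply_rope(q_vec.expand(seq, head_dim))
k = apply_rope(k_vec.expand(seq, head_dim))
scores = q @ k.T  # scores[i, j] for query position i and key position j

# Shifting both positions by one leaves the score unchanged: it depends only on i - j.
relative_only = torch.allclose(scores[1:, 1:], scores[:-1, :-1], atol=1e-4)
print("scores depend only on relative offset:", relative_only)
```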

***

## Advanced Normalization
Stabilize training and improve speed.

- RMSNorm: Rescales by the root-mean-square of the activations and skips LayerNorm's mean subtraction (and usually its bias).
  - Why: Slightly cheaper and often just as stable; widely used in LLaMA.
- LayerNorm axis: `normalized_shape` sets which trailing dimensions are normalized. LLMs normalize over the last (hidden) dimension, so every token is normalized on its own.
- Pre-LN vs Post-LN:
  - Pre-LN (norm before sublayer): More stable in deep models, faster to converge.
  - Post-LN (norm after): Original Transformer style, sometimes better final quality but trickier to train.
- GroupNorm: Niche for conv-like patterns; rarely used in vanilla LLMs.
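A hand-written RMSNorm plus the Pre-LN wiring might look like the sketch below. The class names `RMSNorm` and `PreLNBlock` and the `eps` value are illustrative choices; recent PyTorch releases also ship a built-in RMSNorm module, which this sketch does not rely on.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Divide by the root-mean-square over the last dimension, then apply a learned scale."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.weight * x / rms

class PreLNBlock(nn.Module):
    """Pre-LN residual wiring: x + sublayer(norm(x)). Post-LN would be norm(x + sublayer(x))."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

torch.manual_seed(0)
x = torch.randn(4, 10, 64) * 5 + 3          # deliberately off-center, large-scale input
y = RMSNorm(64)(x)
rms_after = y.pow(2).mean(dim=-1).sqrt()    # one value per token, all close to 1
block_out = PreLNBlock(64, nn.Linear(64, 64))(x)
print(rms_after.min().item(), rms_after.max().item(), block_out.shape)
```

Note that the mean of `y` is not forced to zero; only the scale is normalized, which is the difference from LayerNorm.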

***

## Parameter-Efficient Fine-Tuning (PEFT)
Get domain/task gains by training a tiny subset of weights.

- LoRA: Inject small low-rank matrices into attention/MLP weights; only train these.
  - Analogy: Add small “adapters” rather than rewiring the whole brain.
  - Use when: You have limited compute or want many task variants.

- AdaLoRA: Dynamically adjusts LoRA rank per layer.
  - Why: Spend capacity where it matters most.

- Adapters / Prefix Tuning:
  - Adapters: Tiny bottleneck layers added between blocks.
  - Prefix Tuning: Learn “virtual tokens” prepended to the sequence to steer behavior.
  - Use when: You need multi-task variants with quick swaps.

Example: Fine-tune a general LLM for finance Q&A using LoRA modules. Keep a separate LoRA for medical Q&A. Swap modules at runtime.
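A LoRA-wrapped linear layer can be sketched as follows. `LoRALinear`, `rank`, and `alpha` are names chosen for this example (libraries such as PEFT have their own APIs), and the init follows the common convention of one small random factor and one zero factor so that training starts from the base model's behavior.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base layer plus a trainable low-rank update scaled by alpha / rank."""
    def __init__(self, base: nn.Linear, rank=8, alpha=16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                      # the pretrained weights stay fixed
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init: no change at start
        self.scale = alpha / rank

    def forward(self, x):
        update = (x @ self.lora_a.T) @ self.lora_b.T     # low-rank path: in -> rank -> out
        return self.base(x) + self.scale * update

torch.manual_seed(0)
base = nn.Linear(64, 64)
layer = LoRALinear(base, rank=4)
x = torch.randn(3, 64)

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable parameters: {trainable} of {total}")
```

Swapping tasks at runtime then amounts to swapping the small `lora_a`/`lora_b` tensors while the base weights stay loaded.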

***

## Memory Optimization

- Gradient Checkpointing: Recompute certain activations during backward pass instead of storing them.
  - Why: Fit longer sequences/bigger models on the same GPU.
  - Trade-off: Saves memory, costs extra compute.
  - Methods:
    - Function-based: Wrap specific blocks in checkpoint().
    - Module-based: Checkpoint whole submodules.
    - Sequential: Apply across the transformer stack.
    - Selective: Only checkpoint the memory hogs (e.g., attention).

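- Mixed Precision + Grad Scaling:
  - Use float16/bfloat16 to speed up and reduce memory.
  - With float16, automatic loss scaling (a `GradScaler`) keeps small gradients from underflowing; bfloat16 has the same exponent range as float32 and usually does not need it.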

Example: Training 8k context on a 24GB GPU—turn on checkpointing for attention/MLP, use bfloat16, and accumulate gradients.
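The three memory tricks from the example (checkpointing, bfloat16 autocast, gradient accumulation) can be combined as in this toy sketch. `Block`, `TinyStack`, and all sizes are hypothetical stand-ins for a real transformer, and the loss is a placeholder.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.mlp(x)

class TinyStack(nn.Module):
    def __init__(self, dim=32, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            # Activations inside `block` are recomputed in backward instead of being stored.
            x = checkpoint(block, x, use_reentrant=False)
        return x

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyStack().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4                                   # effective batch = 4 micro-batches

optimizer.zero_grad()
for _ in range(accum_steps):
    x = torch.randn(2, 16, 32, device=device)     # one micro-batch of fake data
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        loss = model(x).float().pow(2).mean()     # placeholder loss
    (loss / accum_steps).backward()               # gradients add up across micro-batches
optimizer.step()                                  # one update for the whole accumulated batch
```

No `GradScaler` appears here because the sketch uses bfloat16; with float16 you would typically add one.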

***

## Custom Loss & Training Techniques

- Label Smoothing: Softens targets (e.g., 0.9 for correct class).
  - Why: Reduces overconfidence, improves calibration.
- Focal Loss: Downweights easy examples, focuses on hard ones.
  - Use when: Classes are imbalanced (e.g., toxicity classification). Be careful with noisy labels: mislabeled examples look hard and get upweighted.
- Contrastive Loss: Pull matched pairs together, push mismatched apart.
  - Use when: Learning embeddings for retrieval/RAG reranking.
- Custom LR Schedulers: Warmup + cosine decay is a solid default; tweak per model size and batch.

Example: For instruction tuning with noisy datasets, apply label smoothing to avoid overfitting spurious patterns.
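The losses and the schedule above are short enough to sketch directly. Label smoothing uses the built-in `label_smoothing` argument of `F.cross_entropy`; `focal_loss`, `info_nce`, and `warmup_cosine` are hypothetical helper names, and the values for `gamma`, `temperature`, and the step counts are placeholders rather than recommendations.

```python
import math
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    """Cross-entropy scaled by (1 - p_correct)^gamma, so confident examples count less."""
    log_p = F.log_softmax(logits, dim=-1).gather(-1, targets[:, None]).squeeze(-1)
    return (-((1 - log_p.exp()) ** gamma) * log_p).mean()

def info_nce(query_emb, doc_emb, temperature=0.05):
    """In-batch contrastive loss: row i of query_emb should match row i of doc_emb."""
    q = F.normalize(query_emb, dim=-1)
    d = F.normalize(doc_emb, dim=-1)
    logits = q @ d.T / temperature
    return F.cross_entropy(logits, torch.arange(len(q), device=q.device))

def warmup_cosine(step, warmup_steps, total_steps):
    """LR multiplier: linear warmup to 1.0, then cosine decay to 0.0."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

torch.manual_seed(0)
logits, targets = torch.randn(16, 5), torch.randint(0, 5, (16,))
plain = F.cross_entropy(logits, targets)
smoothed = F.cross_entropy(logits, targets, label_smoothing=0.1)
focal = focal_loss(logits, targets)

emb = torch.randn(8, 32)
matched, mismatched = info_nce(emb, emb), info_nce(emb, emb.roll(1, dims=0))

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: warmup_cosine(s, 10, 100))
print(plain.item(), smoothed.item(), focal.item(), matched.item(), mismatched.item())
```

In a training loop, `scheduler.step()` would be called once after each `optimizer.step()`.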

***

## Advanced PyTorch Utilities

- Forward Hooks: Peek at layer inputs/outputs to debug or record features.
  - Example: Measure head-level attention entropy to detect dead heads.
- Dynamic Freezing: Freeze early layers, train top layers/LoRA to speed fine-tuning.
- Custom Autograd Functions: Write fused ops or custom gradients for performance.
- Gradient Clipping: Prevents exploding gradients; clip by norm (e.g., 1.0) is common.
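Hooks, freezing, and clipping are all one-liners in practice. The toy model and the `save_output` hook below are made up for illustration; in a real LLM you would attach the hook to an attention or MLP submodule instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 4))

# Forward hook: record the activation coming out of the GELU layer.
captured = {}
def save_output(module, inputs, output):
    captured["hidden"] = output.detach()

handle = model[1].register_forward_hook(save_output)

# Freezing: the first layer receives no gradients and will not be updated.
for p in model[0].parameters():
    p.requires_grad_(False)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()

# Clipping rescales all gradients together if their combined norm exceeds max_norm.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
norm_after = torch.stack([p.grad.norm() for p in model[2].parameters()]).norm()

handle.remove()  # detach the hook once it is no longer needed
print(captured["hidden"].shape, norm_before.item(), norm_after.item())
```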

***

## Model Surgery and Weight Manipulation

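- Weight Initialization: Use stable schemes (e.g., GPT-2 scales residual-branch weights by 1/sqrt(N), where N is the number of residual layers, so activations do not grow with depth).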
- Pruning: Remove low-importance weights or heads; compress for deployment.
- Weight Averaging (Model Soups): Average the weights of several fine-tuned checkpoints (different seeds or hyperparameters) that all start from the same pretrained model; this often yields more robust models.

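Example: After fine-tuning the same base model with multiple seeds, average their weights to smooth out idiosyncrasies and boost generalization. Averaging models trained from different random initializations usually does not work.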
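Uniform weight averaging is just a mean over matching `state_dict` entries. In this sketch, `average_state_dicts` is a hypothetical helper, and the "fine-tuning" is faked by adding small noise to a shared base model so the example stays self-contained.

```python
import copy
import torch
import torch.nn as nn

def average_state_dicts(state_dicts):
    """Uniformly average floating-point tensors; other entries are copied from the first dict."""
    avg = copy.deepcopy(state_dicts[0])
    for key in avg:
        if avg[key].is_floating_point():
            avg[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
    return avg

torch.manual_seed(0)
base = nn.Linear(8, 2)                      # stands in for the shared pretrained model

runs = []
for seed in range(3):
    torch.manual_seed(seed)
    finetuned = copy.deepcopy(base)         # every run starts from the same weights
    with torch.no_grad():
        for p in finetuned.parameters():
            p.add_(0.01 * torch.randn_like(p))   # placeholder for real fine-tuning
    runs.append(finetuned.state_dict())

soup = copy.deepcopy(base)
soup.load_state_dict(average_state_dicts(runs))
print(soup.weight.shape)
```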

***

## Performance Optimization Tips

- torch.compile (PyTorch 2.x):
  - Speeds up the model with graph capture and kernel fusion.
  - Modes: default, reduce-overhead, max-autotune.
- Scaled Dot-Product Attention API:
  - Uses FlashAttention under the hood when available for big speed/memory wins.
- Memory-Efficient Attention:
  - Chunked or block-sparse strategies for long sequences.
- Efficient Sequence Packing:
  - Pack variable-length samples to reduce padding; boosts throughput.

Example: For 4k+ context training, enable scaled_dot_product_attention, pack batches to reduce pad, use torch.compile, and checkpoint.
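Sequence packing is mostly a bin-packing problem. The sketch below uses a simple first-fit-decreasing heuristic; `pack_sequences` and the sample lengths are invented for the example. A real pipeline would also need an attention mask (or per-document position resets) so packed samples do not attend to each other.

```python
def pack_sequences(lengths, max_len):
    """First-fit-decreasing packing: returns lists of sample indices, one list per packed row."""
    rows = []  # each entry: [remaining_capacity, [sample indices]]
    for idx in sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True):
        length = lengths[idx]
        if length > max_len:
            raise ValueError(f"sample {idx} has {length} tokens, more than max_len={max_len}")
        for row in rows:
            if row[0] >= length:            # first existing row with enough room
                row[0] -= length
                row[1].append(idx)
                break
        else:                               # no row had room: open a new one
            rows.append([max_len - length, [idx]])
    return [indices for _, indices in rows]

lengths = [900, 300, 120, 700, 60, 500, 1000, 24]   # token counts of eight samples
max_len = 1024
packed = pack_sequences(lengths, max_len)

real_tokens = sum(lengths)
print("rows without packing:", len(lengths), "utilization:", real_tokens / (len(lengths) * max_len))
print("rows with packing:   ", len(packed), "utilization:", real_tokens / (len(packed) * max_len))
```

Fewer rows for the same tokens means fewer wasted pad positions per step.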

***

## Putting It Together: Practical Patterns

- Fast Inference Server:
  - Switch to GQA/MQA, quantize to INT8/FP8 if supported, enable FlashAttention, compile model.
- Cheap Domain Adaptation:
  - Add LoRA on attention and MLP, use label smoothing, freeze base layers, train in bfloat16 with grad accumulation.
- Long-Context RAG:
  - RoPE with long-context scaling, custom attention bias for recency via FlexAttention, checkpoint attention, use retrieval-tuned contrastive loss for the retriever.

***

## Tiny Code Sketches (illustrative)

Multi-Query/Grouped Query idea:
```python
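import torch
import torch.nn as nn
d_model, n_heads, head_dim = 512, 8, 64           # illustrative sizes
n_kv_heads = 1                                    # 1 = MQA, 2 or 4 = GQA, n_heads = full MHA
W_q = nn.Linear(d_model, n_heads * head_dim)      # separate queries for every head
W_k = nn.Linear(d_model, n_kv_heads * head_dim)   # keys shared within each head group
W_v = nn.Linear(d_model, n_kv_heads * head_dim)   # values shared within each head group
x = torch.randn(2, 16, d_model)                   # (batch, seq, d_model)
Q, K, V = W_q(x), W_k(x), W_v(x)                  # K/V are n_heads / n_kv_heads times smaller than Q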
```

LoRA injection:
```python
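import torch
d_in, d_out, r, alpha = 512, 512, 8, 16                      # illustrative sizes; alpha / r is the usual LoRA scale
W_base = torch.randn(d_in, d_out)                            # frozen pretrained weight
A, B = torch.randn(d_in, r) * 0.01, torch.zeros(r, d_out)    # small trainable low-rank factors
W_eff = W_base + (alpha / r) * (A @ B)                       # B starts at zero, so W_eff == W_base at first
y = torch.randn(4, d_in) @ W_eff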
```

FlexAttention score modifier:
```python
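# scores = (Q @ K.T) / sqrt(d) + bias
# Recency bias: bias[i, j] = -lam * (i - j) for j <= i, so the penalty grows with distance.
# In FlexAttention this is written as a score_mod(score, b, h, q_idx, kv_idx) callable.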
```
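The same recency bias can be tried without FlexAttention by passing an additive float mask to `scaled_dot_product_attention`. This is a stand-in sketch, not the FlexAttention API: `recency_biased_attention` and `lam` are hypothetical names, and building the full `(seq, seq)` bias matrix costs memory that FlexAttention is designed to avoid.

```python
import torch
import torch.nn.functional as F

def recency_biased_attention(q, k, v, lam=0.1):
    """Causal attention whose scores are penalized by lam * (distance to the key token)."""
    seq = q.shape[-2]
    pos = torch.arange(seq, device=q.device)
    distance = (pos[:, None] - pos[None, :]).to(q.dtype)      # distance[i, j] = i - j
    bias = -lam * distance                                    # farther back = larger penalty
    bias = bias.masked_fill(distance < 0, float("-inf"))      # block attention to future tokens
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 6, 16) for _ in range(3))        # (batch, heads, seq, head_dim)

no_bias = recency_biased_attention(q, k, v, lam=0.0)          # reduces to plain causal attention
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
strong = recency_biased_attention(q, k, v, lam=1e4)           # each token sees almost only itself
print(torch.allclose(no_bias, causal, atol=1e-5), torch.allclose(strong, v, atol=1e-4))
```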

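Gradient checkpointing (sequential):
```python
from torch.utils.checkpoint import checkpoint_sequential
# Call inside forward(): checkpoint_sequential runs the blocks itself; it does not return a wrapped module.
# Assumes model.blocks is an nn.Sequential (or a list of modules) and x is the input tensor.
x = checkpoint_sequential(model.blocks, segments=4, input=x, use_reentrant=False)
```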

torch.compile:
```python
model = torch.compile(model, mode="max-autotune")
```

Gradient clipping:
```python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```

***

## When to Choose What

- Need lower latency and memory at inference: prefer GQA/MQA, FlashAttention, quantization.
- Need better quality at similar compute: try SwiGLU, better normalization (RMSNorm), careful LR schedule.
- Many tasks/clients on one base model: use LoRA/adapters or prefix tuning for each task.
- Limited GPU memory: mixed precision + checkpointing + sequence packing.
- Retrieval-heavy systems: use contrastive losses for retriever, FlexAttention to bias context usage, RoPE for long contexts.

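Use this as a menu: pick a few techniques based on your bottleneck (latency, memory, quality, data). Start with LoRA + RMSNorm + RoPE + FlashAttention + torch.compile (RMSNorm and RoPE are architecture choices, so they apply when you train from scratch or your base model already uses them), then iterate with FlexAttention or GQA/MQA as needed.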