# SOTA Optimization for Transformers: Training & Inference

This notebook demonstrates modern optimization techniques for transformer models across training and inference, highlighting key ideas like FlashAttention, QLoRA, vLLM, and torch.compile.
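FlashAttention computes exact attention in tiles, so the full T×T score matrix never has to be stored in GPU high-bandwidth memory. The sketch below does not reimplement that tiling. It is a minimal check that compares a naive attention, which materializes the score matrix, with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+). On supported GPUs and dtypes, the fused function may dispatch to a FlashAttention kernel. Which backend actually runs depends on the device, dtype, and PyTorch version. The tensor shapes are arbitrary illustration values.

```python
import math
import torch
import torch.nn.functional as F

def naive_attention(q, k, v, causal=True):
    # Materializes the full (T x T) score matrix: O(T^2) memory per head.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if causal:
        T = q.size(-2)
        future = torch.ones(T, T, dtype=torch.bool, device=q.device).triu(1)
        scores = scores.masked_fill(future, float("-inf"))
    return scores.softmax(dim=-1) @ v

torch.manual_seed(0)
B, H, T, D = 2, 4, 16, 32  # batch, heads, sequence length, head dim
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

ref = naive_attention(q, k, v)
# Fused attention; may use a FlashAttention backend on supported hardware/dtypes.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print("max abs diff:", (ref - fused).abs().max().item())
```

Both paths compute the same exact attention, so their outputs agree to floating-point tolerance. The fused kernel's advantage lies in memory traffic and speed, not in the result.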

## Learning objectives
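- Apply mixed precision, activation checkpointing, and FSDP (Fully Sharded Data Parallel) to reduce training memory and cost
- Use post-training quantization (dynamic and static) and quantization-aware training (QAT) for efficient inference
- Understand the key/value (KV) cache and paged attention (as used in vLLM)
- Explore the benefits of speculative decoding and multi-query attention (MQA)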
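The sketch below makes the KV-cache and multi-query attention objectives concrete. It is a minimal plain-PyTorch sketch that assumes a single attention layer with random weights, no batching, and no output projection. It checks that incremental decoding with a cache reproduces full causal attention. It also uses a single K/V head shared by all query heads (MQA), which makes the per-token cache `n_heads` times smaller than in standard multi-head attention. The cache here is one contiguous tensor. vLLM's PagedAttention goes further and stores the cache in fixed-size blocks that need not be contiguous in memory, which this sketch does not model.

```python
import math
import torch

torch.manual_seed(0)
d_model, n_heads, d_head, T = 64, 4, 16, 10

# Multi-query attention (MQA): n_heads query heads share ONE key/value head.
Wq = torch.randn(d_model, n_heads * d_head) / math.sqrt(d_model)
Wk = torch.randn(d_model, d_head) / math.sqrt(d_model)
Wv = torch.randn(d_model, d_head) / math.sqrt(d_model)

def split_heads(q):
    # (Tq, n_heads * d_head) -> (n_heads, Tq, d_head)
    return q.view(q.size(0), n_heads, d_head).transpose(0, 1)

def attend(q, k, v, causal):
    # q: (n_heads, Tq, d_head); k, v: (Tk, d_head), broadcast across all heads
    scores = q @ k.T / math.sqrt(d_head)          # (n_heads, Tq, Tk)
    if causal:
        Tq, Tk = scores.shape[-2:]
        future = torch.ones(Tq, Tk, dtype=torch.bool).triu(1)
        scores = scores.masked_fill(future, float("-inf"))
    return scores.softmax(dim=-1) @ v              # (n_heads, Tq, d_head)

x = torch.randn(T, d_model)  # hidden states for T tokens

# Reference: causal attention over the full sequence in one pass.
full = attend(split_heads(x @ Wq), x @ Wk, x @ Wv, causal=True)

# Incremental decoding: project only the newest token and append its K/V to the cache.
k_cache = torch.empty(0, d_head)
v_cache = torch.empty(0, d_head)
outputs = []
for t in range(T):
    xt = x[t:t + 1]                               # (1, d_model)
    k_cache = torch.cat([k_cache, xt @ Wk])       # (t + 1, d_head)
    v_cache = torch.cat([v_cache, xt @ Wv])
    # No mask needed: the cache only holds positions <= t.
    outputs.append(attend(split_heads(xt @ Wq), k_cache, v_cache, causal=False))
incremental = torch.cat(outputs, dim=1)           # (n_heads, T, d_head)

print("cached decode matches full pass:", torch.allclose(full, incremental, atol=1e-5))
# Per token and layer, MHA caches 2 * n_heads * d_head values; MQA caches 2 * d_head.
print("cached values per token: MHA", 2 * n_heads * d_head, "vs MQA", 2 * d_head)
```

With a cache, each decoding step projects only one new token and reuses all earlier keys and values. Without a cache, every step would recompute K/V for the whole prefix.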

## References (selected)
- FlashAttention: Dao et al., 2022 (arXiv:2205.14135); FlashAttention-2: Dao, 2023 (arXiv:2307.08691)
- QLoRA: Dettmers et al., 2023 (arXiv:2305.14314)
- vLLM (PagedAttention): Kwon et al., 2023, "Efficient Memory Management for Large Language Model Serving with PagedAttention" (arXiv:2309.06180)
- GPTQ: Frantar et al., 2022 (arXiv:2210.17323)
- AWQ: Lin et al., 2023 (arXiv:2306.00978)
- ZeRO: Rajbhandari et al., 2020 (arXiv:1910.02054)


In [None]:
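# Mixed precision + torch.compile demo on a toy feed-forward block
# MixedPrecisionTrainer and torch_compile_model are helpers from this notebook's local modules.
import torch
import torch.nn as nn
from train_optimizations import MixedPrecisionTrainer
from inference_optimizations import torch_compile_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class Tiny(nn.Module):
    """Transformer-style MLP block: d -> 4d -> d."""
    def __init__(self, d=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d)
        )

    def forward(self, x):
        return self.net(x)

# Move the model to its device before compiling it and building the optimizer.
model = Tiny().to(device)
model = torch_compile_model(model)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
# use_bf16=False: autocast to fp16 rather than bf16 (per the flag name)
trainer = MixedPrecisionTrainer(model, opt, use_bf16=False)

x = torch.randn(8, 256, device=device)
y = torch.randn(8, 256, device=device)

def loss_fn(pred, tgt):
    # pred may be fp16/bf16 from autocast; cast it to match the fp32 target.
    return nn.functional.mse_loss(pred.float(), tgt)

for _ in range(3):
    pred = trainer.forward_with_autocast(model, x)  # forward pass under autocast
    loss = loss_fn(pred, y)
    val = trainer.step(loss)  # backward pass and optimizer update, handled by the trainer
    print('Loss:', val)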
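`MixedPrecisionTrainer` comes from this notebook's local `train_optimizations` module, and its internals are not shown here. For comparison, here is a sketch of a standard mixed-precision loop written with PyTorch's own `torch.autocast` and `GradScaler`. It uses fp16 with loss scaling when CUDA is available and falls back to bf16 without scaling on CPU. `torch.compile` is left out to keep the example minimal. The model, data, and learning rate are illustrative stand-ins.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 has a narrow exponent range, so its gradients need loss scaling;
# bf16 keeps fp32's exponent range and is used here on CPU without scaling.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

# Parameters stay in fp32 (the "master" weights); autocast casts per-op.
model = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256)).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# torch.amp.GradScaler is the newer spelling; older releases only have torch.cuda.amp.GradScaler.
GradScaler = getattr(torch.amp, "GradScaler", None) or torch.cuda.amp.GradScaler
scaler = GradScaler(enabled=(amp_dtype == torch.float16))  # disabled => plain backward/step

x = torch.randn(8, 256, device=device)
y = torch.randn(8, 256, device=device)

losses = []
for step in range(3):
    opt.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device, dtype=amp_dtype):
        pred = model(x)                              # linear layers run in amp_dtype
    loss = nn.functional.mse_loss(pred.float(), y)   # loss computed in fp32, outside autocast
    scaler.scale(loss).backward()                    # scales the loss up (no-op when disabled)
    scaler.step(opt)                                 # unscales grads; skips the step on inf/nan
    scaler.update()                                  # adjusts the scale factor for the next step
    losses.append(loss.item())
    print(f"step {step} loss: {losses[-1]:.4f}")
```

Computing the loss in fp32 outside the autocast region avoids reducing in low precision. The scaler only matters for fp16, which is why the sketch disables it on the bf16 path.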


In [None]:
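# Quantization path: dynamic post-training quantization (PTQ) of a small MLP head
# ptq_dynamic_linear_only is a local helper; per its name, it quantizes only the nn.Linear layers.
import torch
import torch.nn as nn
from quantization_toolkit import ptq_dynamic_linear_only

# Quantize a model in eval mode (inference only).
head = nn.Sequential(nn.Linear(256, 256), nn.GELU(), nn.Linear(256, 128)).eval()
qhead = ptq_dynamic_linear_only(head)
print('Quantized module:', qhead)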
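`ptq_dynamic_linear_only` comes from the local `quantization_toolkit`, and its exact behavior is not shown. PyTorch's built-in eager-mode API offers a comparable, hedged equivalent: `torch.ao.quantization.quantize_dynamic`. That function replaces `nn.Linear` modules with versions that store int8 weights and quantize activations on the fly at inference time. The sketch assumes an int8 CPU backend (fbgemm on x86 or qnnpack on ARM) is available. It then checks how far the quantized outputs drift from fp32.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# Select an available int8 CPU backend (fbgemm on x86, qnnpack on ARM).
engines = torch.backends.quantized.supported_engines
torch.backends.quantized.engine = "fbgemm" if "fbgemm" in engines else "qnnpack"

torch.manual_seed(0)
head = nn.Sequential(nn.Linear(256, 256), nn.GELU(), nn.Linear(256, 128)).eval()

# Dynamic PTQ: weights are converted to int8 ahead of time; activation scales are
# computed at runtime on each call. Non-Linear modules (GELU) stay in fp32.
qhead = quantize_dynamic(head, {nn.Linear}, dtype=torch.qint8)  # returns a copy by default

x = torch.randn(8, 256)
with torch.no_grad():
    ref = head(x)
    out = qhead(x)

print(qhead)
print("max abs error vs fp32:", (ref - out).abs().max().item())
```

Dynamic quantization needs no calibration data, which makes it the simplest PTQ path. Static PTQ and QAT additionally quantize activations with precomputed scales.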
