# Efficient Inference: FlashAttention & Quantization

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/efficient_inference.ipynb)

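Modern LLMs rely on efficiency techniques to run fast and fit in memory. This notebook covers two of them: FlashAttention, which cuts the memory cost of attention, and quantization, which shrinks the stored weights.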

## 1. FlashAttention (Conceptual Simulation)

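**Problem:** Standard attention materializes the full $N \times N$ score matrix $QK^T$, so its memory grows as $O(N^2)$ in the sequence length $N$.
**Solution:** FlashAttention computes attention in **tiles** (blocks) small enough to fit in fast on-chip SRAM, so the full $N \times N$ matrix is never written to GPU main memory. In the forward pass it combines **tiling** with an online softmax (a running max and sum per row). In the backward pass it uses **recomputation**, rebuilding attention tiles on the fly instead of storing them. The result is exact attention, not an approximation.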
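What lets FlashAttention process keys in tiles, without ever holding a full row of scores, is the *online softmax*: it keeps a running maximum $m$ and a running normalizer $\ell$, and rescales the partial results whenever a larger score shows up. Below is a minimal sketch for a single query vector in plain PyTorch. The function name `online_softmax_attention_row` is hypothetical, and the scores are plain dot products (no $1/\sqrt{d}$ scaling), matching the demo in this section.

```python
import torch

def online_softmax_attention_row(q, K, V, block_size=2):
    """Attention output for one query vector q, visiting K/V one block at a time.

    Only running statistics are kept; the full row of N scores is never stored.
    """
    m = torch.tensor(float("-inf"))   # running max of the scores seen so far
    l = torch.tensor(0.0)             # running sum of exp(score - m)
    acc = torch.zeros(V.shape[1])     # running sum of exp(score - m) * v
    for j in range(0, K.shape[0], block_size):
        s = K[j:j + block_size] @ q               # scores for this key block
        m_new = torch.maximum(m, s.max())
        correction = torch.exp(m - m_new)         # rescale old stats to the new max
        p = torch.exp(s - m_new)
        l = l * correction + p.sum()
        acc = acc * correction + p @ V[j:j + block_size]
        m = m_new
    return acc / l

torch.manual_seed(0)
N, d = 8, 4
q, K, V = torch.randn(d), torch.randn(N, d), torch.randn(N, d)
reference = torch.softmax(K @ q, dim=0) @ V
print(torch.allclose(online_softmax_attention_row(q, K, V), reference, atol=1e-5))
```

The first block starts from $m = -\infty$, so `correction` is 0 and the empty initial state contributes nothing. Every later block only rescales what has already been accumulated.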

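Here we simulate a simplified version of the tiling logic in plain PyTorch by tiling over query rows only. Real FlashAttention runs as a fused CUDA kernel and also tiles over keys and values.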

In [None]:
import torch

def standard_attention(Q, K, V):
    # Materializes the full N x N score matrix: O(N^2) memory.
    # (The usual 1/sqrt(d) scaling is omitted to keep the demo minimal.)
    scores = Q @ K.transpose(-2, -1)
    P = torch.softmax(scores, dim=-1)
    return P @ V

def tiled_attention_simulation(Q, K, V, block_size=2):
    # Simplified block-wise attention that tiles over query rows only.
    # Real FlashAttention also tiles over K/V inside on-chip SRAM and
    # uses an online softmax (running max and sum) to combine the tiles.
    N = Q.shape[0]
    output = Q.new_zeros((N, V.shape[-1]))

    # Process one block of query rows at a time
    for i in range(0, N, block_size):
        Q_block = Q[i:i+block_size]

        # Scores for this block against ALL keys: shape (block_size, N).
        # Peak score memory is O(block_size * N) instead of O(N^2).
        # Each row still sees every key, so an ordinary softmax is exact here;
        # tiling over the keys as well is what requires the online softmax.
        scores_block = Q_block @ K.transpose(-2, -1)
        P_block = torch.softmax(scores_block, dim=-1)
        output[i:i+block_size] = P_block @ V

    return output

# Test: both versions should agree up to floating-point error
torch.manual_seed(42)
N, d = 8, 4
Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)

print("Standard:\n", standard_attention(Q, K, V))
print("Tiled (Simulated):\n", tiled_attention_simulation(Q, K, V, block_size=4))
print("Match:", torch.allclose(standard_attention(Q, K, V), tiled_attention_simulation(Q, K, V, block_size=4)))
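The simulation above tiles over query rows only, so each block still computes a full `block_size x N` strip of scores. The sketch below goes one step closer to FlashAttention's forward pass: it tiles over both queries and keys/values, and it stitches the tiles together with a per-row online softmax. Treat it as an illustration under stated assumptions. `flash_attention_forward_sketch` is a hypothetical name, it runs on CPU with no kernel fusion, it has no backward pass or recomputation, and it omits the $1/\sqrt{d}$ scaling to match `standard_attention`.

```python
import torch

def flash_attention_forward_sketch(Q, K, V, block_q=4, block_k=4):
    """Block-wise attention forward pass that tiles over both queries and keys.

    Only a (block_q x block_k) tile of scores exists at any time; per-row
    running statistics (max m, normalizer l) combine the tiles exactly.
    """
    N = Q.shape[0]
    out = Q.new_empty((N, V.shape[1]))
    for i in range(0, N, block_q):
        Qi = Q[i:i + block_q]
        rows = Qi.shape[0]
        m = Q.new_full((rows, 1), float("-inf"))  # running row-wise max
        l = Q.new_zeros((rows, 1))                # running sum of exp(s - m)
        acc = Q.new_zeros((rows, V.shape[1]))     # running unnormalized output
        for j in range(0, K.shape[0], block_k):
            S = Qi @ K[j:j + block_k].T           # score tile: (rows, block_k)
            m_new = torch.maximum(m, S.max(dim=1, keepdim=True).values)
            correction = torch.exp(m - m_new)     # rescale old stats to the new max
            P = torch.exp(S - m_new)
            l = l * correction + P.sum(dim=1, keepdim=True)
            acc = acc * correction + P @ V[j:j + block_k]
            m = m_new
        out[i:i + block_q] = acc / l              # normalize once per query block
    return out

torch.manual_seed(42)
Q, K, V = torch.randn(8, 4), torch.randn(8, 4), torch.randn(8, 4)
reference = torch.softmax(Q @ K.T, dim=-1) @ V
print(torch.allclose(flash_attention_forward_sketch(Q, K, V, block_q=4, block_k=2), reference, atol=1e-5))
```

Normalizing by `l` once at the end, instead of after every tile, mirrors how the online softmax defers the division until all keys have been seen.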

## 2. Quantization (INT8)

Reducing precision from FP32 (4 bytes per value) to INT8 (1 byte per value) cuts weight memory by 4x, apart from the small overhead of storing the scale factors.
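A back-of-the-envelope check of the 4x figure, assuming a hypothetical 7-billion-parameter model and counting weights only (no activations, KV cache, or scale factors):

```python
# Bytes needed to store one value in each format
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}

def weight_memory_gb(num_params, dtype):
    """Weights-only memory in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

num_params = 7e9  # hypothetical 7B-parameter model
for dtype in BYTES_PER_PARAM:
    print(f"{dtype}: {weight_memory_gb(num_params, dtype):.1f} GB")
# fp32: 28.0 GB
# fp16: 14.0 GB
# int8: 7.0 GB
```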

### AbsMax Quantization (Symmetric)
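Map the range $[-\max|x|, \max|x|]$ linearly onto the signed integers $[-127, 127]$: $s = 127 / \max|x|$ and $x_{int} = \mathrm{round}(s \cdot x)$. Dequantize with $\hat{x} = x_{int} / s$.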

In [None]:
def absmax_quantize(x):
    # One scale for the whole tensor: the largest |x| maps to 127.
    # clamp(min=...) guards against division by zero for an all-zero tensor.
    scale = 127 / torch.max(torch.abs(x)).clamp(min=1e-8)
    x_quant = (x * scale).round().clamp(-127, 127).to(torch.int8)
    return x_quant, scale

def dequantize(x_quant, scale):
    return x_quant.float() / scale

weights = torch.tensor([0.1, -0.5, 1.2, -2.5, 0.0])
q, s = absmax_quantize(weights)
dq = dequantize(q, s)

print("Original:", weights)
print("Quantized (int8):", q)
print("Scale:", s)
print("Dequantized:", dq)
print("Mean absolute error:", (weights - dq).abs().mean().item())
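Each value is rounded to the nearest step of size $1/s$, where $s$ is the `scale` returned by `absmax_quantize`, so the reconstruction error is at most half a step: $|x - \hat{x}| \le 0.5/s$. The sketch below checks that bound on random data. It also shows the usual weakness of a single per-tensor scale: one large outlier inflates $\max|x|$ and makes the step coarser for every other value. Both functions are redefined here, unchanged in behavior, so the block runs standalone. The outlier value `100.0` is an arbitrary choice for illustration.

```python
import torch

def absmax_quantize(x):
    scale = 127 / torch.max(torch.abs(x))
    return (x * scale).round().clamp(-127, 127).to(torch.int8), scale

def dequantize(x_quant, scale):
    return x_quant.float() / scale

torch.manual_seed(0)
x = torch.randn(1000)
q, s = absmax_quantize(x)
err = (x - dequantize(q, s)).abs()
print("max error:", err.max().item(), "| bound 0.5/scale:", (0.5 / s).item())

# Append one large outlier and measure the error on the original 1000 values only
x_outlier = torch.cat([x, torch.tensor([100.0])])
q_o, s_o = absmax_quantize(x_outlier)
err_o = (x_outlier - dequantize(q_o, s_o)).abs()[:-1]
print("mean error without outlier:", err.mean().item())
print("mean error with outlier:   ", err_o.mean().item())
```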

### ZeroPoint Quantization (Asymmetric)
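Map the range $[\min(x), \max(x)]$ onto the unsigned integers $[0, 255]$. The zero point $z$ is the integer that represents $x = 0$, which lets the integer grid shift to fit data that are not centered on zero:

$$ x_{int} = \mathrm{round}\left( s \cdot x + z \right), \qquad s = \frac{255}{\max(x) - \min(x)}, \qquad z = \mathrm{round}\left( -s \cdot \min(x) \right) $$

Dequantize with $\hat{x} = (x_{int} - z) / s$.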

In [None]:
def zeropoint_quantize(x):
    x_range = x.max() - x.min()
    x_range = 1 if x_range == 0 else x_range  # avoid division by zero for constant tensors

    scale = 255 / x_range                      # integer steps per unit of x
    zeropoint = (-x.min() * scale).round()     # integer code that represents 0.0

    x_quant = (x * scale + zeropoint).round().clamp(0, 255).to(torch.uint8)
    return x_quant, scale, zeropoint

def zp_dequantize(x_quant, scale, zeropoint):
    return (x_quant.float() - zeropoint) / scale

q_zp, s_zp, z_zp = zeropoint_quantize(weights)
dq_zp = zp_dequantize(q_zp, s_zp, z_zp)

print("\nZeroPoint Quantized:", q_zp)
print("Dequantized:", dq_zp)
print("Mean absolute error:", (weights - dq_zp).abs().mean().item())
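Asymmetric quantization pays off when the data are not centered on zero. Post-ReLU activations, for example, are all non-negative, so absmax quantization leaves the negative half of $[-127, 127]$ unused, while zero-point quantization spreads all 256 levels over $[\min, \max]$. The sketch below compares the two on synthetic non-negative data. The ReLU-of-Gaussian input is an illustrative assumption, and both round-trip helpers are hypothetical names that reimplement the cells above so the block runs standalone.

```python
import torch

def absmax_roundtrip(x):
    # Symmetric: quantize to int8 and dequantize
    scale = 127 / x.abs().max()
    q = (x * scale).round().clamp(-127, 127).to(torch.int8)
    return q.float() / scale

def zeropoint_roundtrip(x):
    # Asymmetric: quantize to uint8 with a zero point and dequantize
    x_range = (x.max() - x.min()).clamp(min=1e-8)  # avoid division by zero
    scale = 255 / x_range
    zeropoint = (-x.min() * scale).round()
    q = (x * scale + zeropoint).round().clamp(0, 255).to(torch.uint8)
    return (q.float() - zeropoint) / scale

torch.manual_seed(0)
acts = torch.relu(torch.randn(10_000))  # one-sided data, like post-ReLU activations
err_sym = (acts - absmax_roundtrip(acts)).abs().mean().item()
err_asym = (acts - zeropoint_roundtrip(acts)).abs().mean().item()
print(f"absmax mean error:     {err_sym:.5f}")
print(f"zero-point mean error: {err_asym:.5f}")
```

Here $\min(x) = 0$, so the zero point is 0 and the asymmetric step $\max/255$ is about half the symmetric step $\max/127$, which roughly halves the rounding error.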