# Lab 3.2.5: AWQ Quantization

**Module:** 3.2 - Model Quantization & Optimization  
**Time:** 1.5 hours  
**Difficulty:** ⭐⭐⭐☆☆

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand AWQ's activation-aware approach
- [ ] Apply AWQ quantization to a model
- [ ] Compare AWQ vs GPTQ performance and quality
- [ ] Choose the right quantization method for your use case

---

## 📚 Prerequisites

- Completed: Lab 3.2.4 (GPTQ Quantization)
- Library: `autoawq` (`pip install autoawq`)
- Hardware: CUDA GPU (DGX Spark recommended)

---

## 🌍 Real-World Context

**AWQ (Activation-aware Weight Quantization)** was published in 2023 and quickly became a popular alternative to GPTQ:

| Aspect | GPTQ | AWQ |
|--------|------|-----|
| Core idea | Hessian-based error correction | Protect salient weights |
| Quantization speed | Slower | Faster |
| Quality | Excellent | Comparable, often slightly better |
| Inference speed | Fast | Fast (comparable) |

**Key Insight:** Not all weights are equally important! AWQ identifies "salient" weights (those attached to the input channels that see the largest activations on calibration data) and protects them during quantization.
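
The AWQ paper reports that choosing salient weights by the magnitude of the activations they multiply works better than choosing them by weight magnitude alone. Below is a minimal sketch of that selection step on synthetic data. `find_salient_channels` and the outlier channel indices are hypothetical; a real run would collect activations from calibration text. For a `[out, in]` weight matrix, the salient weights are then the columns `W[:, idx]`.

```python
import torch


def find_salient_channels(activations: torch.Tensor, ratio: float = 0.01) -> torch.Tensor:
    """Return indices of the input channels with the largest mean |activation|.

    activations: calibration activations, shape [num_tokens, in_features]
    ratio:       fraction of input channels to treat as salient
    """
    channel_magnitude = activations.abs().mean(dim=0)          # [in_features]
    k = max(1, round(ratio * activations.shape[1]))            # at least one channel
    return torch.topk(channel_magnitude, k).indices


# Synthetic calibration activations: 512 channels, three with outlier magnitudes
torch.manual_seed(0)
acts = torch.randn(1000, 512)
acts[:, [7, 100, 300]] *= 20.0

salient = find_salient_channels(acts, ratio=0.01)
print(sorted(salient.tolist()))   # includes channels 7, 100 and 300
```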

---

## 🧒 ELI5: How AWQ Works

> **Imagine you're compressing a photo album...**
>
> **GPTQ approach:** Compress everything, then go back and fix the most blurry parts.
>
> **AWQ approach:** First, identify which photos are your favorites (most viewed).  
> Then, compress the boring photos aggressively but keep your favorites sharp!
>
> **In AI terms:**
> 1. Run calibration data through the model
> 2. Track which input channels see large activations; the weights attached to them are "salient"
> 3. Scale up the salient weight channels before quantization (protecting them)
> 4. Quantize all weights to 4-bit
> 5. Divide the matching input activations by the same scales (folded into the preceding layer) so the layer's output is unchanged
>
> The result: every weight ends up in 4 bits, but the important ones keep more of their precision!
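
Steps 3 and 5 rely on a simple identity: multiplying input channel j of the weight by s_j and dividing the matching input activation by s_j leaves a linear layer's output unchanged, so only the quantization error is affected. In the AWQ paper the division is folded into the preceding operator, so it adds no inference cost. A quick check with random tensors (the scale values are arbitrary, for illustration only):

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 128)          # weight [out, in]
x = torch.randn(128)              # one input vector
s = torch.rand(128) + 0.5         # arbitrary positive per-input-channel scales

y_ref = W @ x                     # original layer output
y_scaled = (W * s) @ (x / s)      # scale weight columns up, inputs down

print(torch.allclose(y_ref, y_scaled, atol=1e-4))   # True
```

With quantization in the loop, the layer computes Q(W·diag(s))·(x/s) instead of W·x, so scaling a salient column up before rounding can shrink its error relative to its size.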

---

## Part 1: Environment Setup

In [None]:
import torch
import gc
import time
import numpy as np
import matplotlib.pyplot as plt

print("=" * 60)
print("AWQ Quantization Lab")
print("=" * 60)

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Check for autoawq
try:
    from awq import AutoAWQForCausalLM
    HAS_AWQ = True
    print("\nautoawq: Available")
except ImportError:
    HAS_AWQ = False
    print("\nautoawq: Not installed")
    print("  Install with: pip install autoawq")

print("=" * 60)

In [None]:
# Helper functions
def get_gpu_memory():
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 1e9
    return 0

def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("Helpers loaded!")

---

## Part 2: Understanding AWQ's Salient Weight Protection

In [None]:
# Demonstrate the AWQ concept: protecting salient weights

def simulate_awq_quantize(weights, activations, bits=4, protect_ratio=0.01):
    """
    Simplified AWQ-style quantization.
    
    AWQ's key insight:
    1. Some weights produce large activations consistently
    2. These "salient" weights should be protected from quantization error
    3. We scale them up before quantization, then scale down at inference
    
    Args:
        weights: Weight matrix [out, in]
        activations: Sample activations [batch, in]
        bits: Quantization bits
        protect_ratio: Fraction of weights to protect (top salient)
    """
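    # Step 1: Compute "saliency" - how much each weight contributes to the output
    # (Real AWQ ranks input channels by mean activation magnitude and grid-searches
    # per-channel scales; |weight| * mean|activation| is a simplification for this demo)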
    activation_importance = activations.abs().mean(dim=0)  # [in]
    weight_importance = weights.abs()  # [out, in]
    saliency = weight_importance * activation_importance  # [out, in]
    
    # Step 2: Find salient weights (top percentile)
    threshold = torch.quantile(saliency.flatten(), 1 - protect_ratio)
    is_salient = saliency >= threshold
    
    # Step 3: Compute per-channel scales
    # Salient channels get scaled up to reduce relative quantization error
    channel_saliency = saliency.mean(dim=0)  # [in]
    scales = (channel_saliency / channel_saliency.mean()).clamp(min=0.5, max=2.0)
    
    # Step 4: Apply scales and quantize
    scaled_weights = weights * scales  # Scale up salient channels
    
    # Quantize
    qmax = 2 ** (bits - 1) - 1
    max_val = scaled_weights.abs().max(dim=1, keepdim=True)[0]
    quant_scale = max_val / qmax
    quant_scale = torch.clamp(quant_scale, min=1e-10)
    
    quantized = torch.round(scaled_weights / quant_scale).clamp(-qmax-1, qmax)
    dequantized = quantized * quant_scale
    
    # Undo scaling
    final_weights = dequantized / scales
    
    return final_weights, is_salient, scales


def simple_quantize(weights, bits=4):
    """Simple symmetric quantization without protection."""
    qmax = 2 ** (bits - 1) - 1
    max_val = weights.abs().max(dim=1, keepdim=True)[0]
    scale = max_val / qmax
    scale = torch.clamp(scale, min=1e-10)
    quantized = torch.round(weights / scale).clamp(-qmax-1, qmax)
    return quantized * scale


# Test with sample data
torch.manual_seed(42)
weights = torch.randn(256, 512)
activations = torch.randn(100, 512)  # 100 samples

# Compare simple vs AWQ
simple_result = simple_quantize(weights, bits=4)
awq_result, is_salient, scales = simulate_awq_quantize(weights, activations, bits=4)

# Compute errors
simple_error = (weights - simple_result).abs()
awq_error = (weights - awq_result).abs()

print("AWQ vs Simple Quantization")
print("=" * 50)
print(f"{'Metric':<25} {'Simple':>12} {'AWQ':>12}")
print("-" * 50)
print(f"{'Mean error':<25} {simple_error.mean():>12.6f} {awq_error.mean():>12.6f}")
print(f"{'Max error':<25} {simple_error.max():>12.6f} {awq_error.max():>12.6f}")
print(f"{'RMSE':<25} {simple_error.pow(2).mean().sqrt():>12.6f} {awq_error.pow(2).mean().sqrt():>12.6f}")

# Error on salient weights specifically
salient_simple_error = simple_error[is_salient].mean()
salient_awq_error = awq_error[is_salient].mean()
print(f"{'Error on salient weights':<25} {salient_simple_error:>12.6f} {salient_awq_error:>12.6f}")

# Positive = AWQ-style scaling lowered the error; random (outlier-free) activations may give a small gap
improvement = (salient_simple_error - salient_awq_error) / salient_simple_error * 100
print(f"\nChange in salient-weight error with AWQ-style scaling: {improvement:.1f}% reduction")
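
The function above picks scales with a fixed heuristic (`channel_saliency / mean`, clamped to [0.5, 2.0]). The AWQ paper instead parameterizes the scales as s = s_x^α, where s_x is the mean activation magnitude per input channel, and grid-searches α in [0, 1] to minimize the layer's output error on calibration data. Below is a minimal sketch of that search for a single linear layer, assuming per-row symmetric quantization (the real method quantizes per group, e.g. 128 weights) and synthetic outlier channels. Because α = 0 reproduces plain round-to-nearest (RTN), the searched result cannot be worse than RTN on the calibration set.

```python
import torch


def quantize_rows(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Symmetric per-row round-to-nearest quantization, returned dequantized."""
    qmax = 2 ** (bits - 1) - 1
    scale = (w.abs().amax(dim=1, keepdim=True) / qmax).clamp(min=1e-10)
    return torch.round(w / scale).clamp(-qmax - 1, qmax) * scale


def search_awq_scales(w: torch.Tensor, x: torch.Tensor, bits: int = 4, n_grid: int = 20):
    """Grid-search per-input-channel scales s = s_x ** alpha for one linear layer.

    w: [out, in] weight; x: [tokens, in] calibration activations.
    Returns (best_alpha, best_scales, best_mse) measured on the layer output.
    """
    s_x = x.abs().mean(dim=0)                  # mean |activation| per input channel
    y_ref = x @ w.T                            # full-precision layer output
    best_alpha, best_s, best_err = 0.0, torch.ones_like(s_x), float("inf")
    for i in range(n_grid + 1):
        alpha = i / n_grid
        s = s_x.pow(alpha).clamp(min=1e-4)
        s = s / (s.max() * s.min()).sqrt()     # keep scales centered around 1
        w_q = quantize_rows(w * s, bits) / s   # Q(W * s) with 1/s folded back in
        err = (x @ w_q.T - y_ref).pow(2).mean().item()
        if err < best_err:
            best_alpha, best_s, best_err = alpha, s, err
    return best_alpha, best_s, best_err


torch.manual_seed(0)
w = torch.randn(256, 512)
x = torch.randn(200, 512)
x[:, :5] *= 15.0                               # a few outlier channels (synthetic)

alpha, s, err = search_awq_scales(w, x)
rtn_err = (x @ quantize_rows(w).T - x @ w.T).pow(2).mean().item()
print(f"best alpha = {alpha:.2f}, AWQ-style MSE = {err:.4f}, plain RTN MSE = {rtn_err:.4f}")
```

Measuring error on the layer output (x @ Wᵀ) rather than on the weights themselves is the point: it is the activations that decide which weight errors actually matter.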

In [None]:
# Visualize AWQ's protection of salient weights

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Saliency distribution
ax = axes[0, 0]
saliency = (weights.abs() * activations.abs().mean(dim=0)).flatten()
ax.hist(saliency.numpy(), bins=50, alpha=0.7, color='steelblue', edgecolor='black')
threshold = torch.quantile(saliency, 0.99)
ax.axvline(threshold.item(), color='red', linestyle='--', linewidth=2, label=f'Top 1% threshold')
ax.set_xlabel('Saliency Score')
ax.set_ylabel('Count')
ax.set_title('Weight Saliency Distribution')
ax.legend()

# Scaling factors
ax = axes[0, 1]
ax.bar(range(len(scales)), scales.numpy(), alpha=0.7, color='coral')
ax.axhline(1.0, color='black', linestyle='--', linewidth=1)
ax.set_xlabel('Channel Index')
ax.set_ylabel('Scale Factor')
ax.set_title('AWQ Per-Channel Scales (first 50 channels)')
ax.set_xlim(0, 50)  # Show first 50 channels

# Error comparison
ax = axes[1, 0]
ax.hist(simple_error.flatten().numpy(), bins=50, alpha=0.5, label='Simple', color='coral')
ax.hist(awq_error.flatten().numpy(), bins=50, alpha=0.5, label='AWQ', color='steelblue')
ax.set_xlabel('Absolute Error')
ax.set_ylabel('Count')
ax.set_title('Error Distribution: Simple vs AWQ')
ax.legend()
ax.set_yscale('log')

# Error on salient vs non-salient
ax = axes[1, 1]
categories = ['Non-Salient', 'Salient (Top 1%)']
simple_errors = [
    simple_error[~is_salient].mean().item(),
    simple_error[is_salient].mean().item()
]
awq_errors = [
    awq_error[~is_salient].mean().item(),
    awq_error[is_salient].mean().item()
]

x = np.arange(len(categories))
width = 0.35
ax.bar(x - width/2, simple_errors, width, label='Simple', color='coral')
ax.bar(x + width/2, awq_errors, width, label='AWQ', color='steelblue')
ax.set_ylabel('Mean Absolute Error')
ax.set_title('Error by Weight Importance')
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.legend()

plt.tight_layout()
plt.show()

print("\nKey insight: AWQ-style scaling aims to lower the error on the weights that matter most, even if ordinary weights get slightly worse.")

---

## Part 3: AWQ Quantization with AutoAWQ

In [None]:
from transformers import AutoTokenizer

# Model selection
MODEL_NAME = "microsoft/phi-2"  # Good for demo
# MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # Production benchmark

print(f"Selected model: {MODEL_NAME}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer loaded. Vocab size: {len(tokenizer)}")

In [None]:
# AWQ Quantization

if HAS_AWQ:
    print("\nQuantizing with AWQ...")
    print("=" * 60)
    
    # AWQ configuration
    awq_config = {
        "zero_point": True,      # Use zero-point (asymmetric) quantization
        "q_group_size": 128,     # Group size (like GPTQ)
        "w_bit": 4,              # 4-bit weights
        "version": "GEMM",       # GEMM kernel (good default; GEMV targets batch size 1)
    }
    
    print(f"Config: {awq_config}")
    
    # Load model
    print(f"\nLoading model...")
    start_time = time.time()
    
    model = AutoAWQForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )
    
    print(f"Model loaded in {time.time() - start_time:.1f}s")
    print(f"GPU memory: {get_gpu_memory():.2f} GB")
    
    # Quantize
    print(f"\nRunning AWQ quantization...")
    start_time = time.time()
    
    model.quantize(
        tokenizer,
        quant_config=awq_config,
    )
    
    print(f"Quantization complete in {time.time() - start_time:.1f}s")
    print(f"GPU memory: {get_gpu_memory():.2f} GB")
    
else:
    print("autoawq not available. Showing expected workflow:")
    print("""
    from awq import AutoAWQForCausalLM
    
    model = AutoAWQForCausalLM.from_pretrained(MODEL_NAME)
    model.quantize(tokenizer, quant_config={"w_bit": 4, "q_group_size": 128})
    model.save_quantized("./my-model-awq-4bit")
    """)
    model = None
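
The `awq_config` in the cell above requests 4-bit weights with a zero point and a group size of 128. Below is a minimal sketch, in plain PyTorch, of what that kind of asymmetric, group-wise round-to-nearest quantization computes. It is not AutoAWQ's code: the scale search, the packing of 4-bit values into int32, and the CUDA kernels are omitted, and the storage estimate assumes one FP16 scale plus one 4-bit zero point per group.

```python
import torch


def quantize_groupwise_asym(w: torch.Tensor, bits: int = 4, group_size: int = 128):
    """Asymmetric (zero-point) round-to-nearest quantization over groups of input channels.

    w: [out, in] weight with in divisible by group_size.
    Returns (q, scale, zero, w_dequant); q holds integers in [0, 2**bits - 1].
    """
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // group_size, group_size)      # [out, n_groups, group]
    qmax = 2 ** bits - 1
    w_min = g.amin(dim=-1, keepdim=True)
    w_max = g.amax(dim=-1, keepdim=True)
    scale = ((w_max - w_min) / qmax).clamp(min=1e-5)          # one scale per group
    zero = torch.round(-w_min / scale).clamp(0, qmax)         # integer zero point per group
    q = (torch.round(g / scale) + zero).clamp(0, qmax)
    w_dequant = ((q - zero) * scale).reshape(out_f, in_f)
    return q, scale, zero, w_dequant


torch.manual_seed(0)
w = torch.randn(256, 512)
q, scale, zero, w_hat = quantize_groupwise_asym(w, bits=4, group_size=128)
print(q.min().item(), q.max().item(), tuple(scale.shape))    # codes stay within [0, 15]
print(f"mean abs error: {(w - w_hat).abs().mean():.4f}")

# Storage: 4-bit codes + one FP16 scale and one 4-bit zero point per 128 weights
bits_per_weight = 4 + (16 + 4) / 128
print(f"effective bits/weight: {bits_per_weight:.3f}")
```

Smaller groups track local weight ranges more closely (lower error) but store more scales and zero points, which is why 128 is a common compromise.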

In [None]:
# Test the AWQ model

if model is not None:
    test_prompt = "The key to machine learning is"
    
    print(f"Testing AWQ model...")
    print(f"Prompt: '{test_prompt}'")
    
    inputs = tokenizer(test_prompt, return_tensors="pt").to("cuda")  # AWQ kernels run on CUDA
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nGenerated: {response}")

---

## Part 4: AWQ vs GPTQ Comparison

In [None]:
# AWQ vs GPTQ comparison table

print("AWQ vs GPTQ Comparison")
print("=" * 70)
print("""
Approximate, illustrative figures for Llama 2 7B (exact numbers vary by source, calibration data, and hardware):

| Metric | FP16 | GPTQ-4bit | AWQ-4bit |
|--------|------|-----------|----------|
| Memory | 13.5 GB | 4.0 GB | 4.0 GB |
| MMLU | 46.0% | 45.5% | 45.7% |
| HellaSwag | 79.0% | 78.5% | 78.8% |
| Winogrande | 74.0% | 73.5% | 73.8% |
| Quant Time | - | ~45 min | ~15 min |
| Inference Speed | 1.0x | 2.5x | 2.5x |

Key differences:
1. AWQ is ~3x faster to quantize
2. AWQ often scores slightly higher (about 0.2-0.3 points on the benchmarks above)
3. Inference speed is comparable
4. Memory usage is identical

When to use each:
- AWQ: Default choice for most use cases
- GPTQ: When you need GPTQ-specific options such as act-order (desc_act)
- Either: Both are production-ready!
""")

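The memory rows in the table are mostly arithmetic. Below is a rough weight-only estimate, assuming Llama 2 7B's commonly reported ~6.74B parameters, FP16 embeddings and LM head (each 32000 x 4096), and one FP16 scale plus one 4-bit zero point per group of 128 weights. KV cache, activations, and framework overhead are ignored, and `estimate_weight_memory_gb` is a hypothetical helper.

```python
def estimate_weight_memory_gb(n_params, n_unquantized=0, w_bit=16, group_size=0,
                              scale_bits=16, zero_bits=0):
    """Rough weight-only memory estimate in GB (1 GB = 1e9 bytes).

    n_params:      total parameter count
    n_unquantized: parameters kept in FP16 (e.g. embeddings and LM head)
    w_bit:         bits per quantized weight
    group_size:    weights sharing one scale/zero point (0 = no per-group metadata)
    """
    n_quantized = n_params - n_unquantized
    bits_per_weight = w_bit
    if group_size:
        bits_per_weight += (scale_bits + zero_bits) / group_size
    total_bytes = n_quantized * bits_per_weight / 8 + n_unquantized * 2   # FP16 = 2 bytes
    return total_bytes / 1e9


N_PARAMS = 6_738_415_616          # Llama 2 7B (approximate, commonly reported)
N_EMBED = 2 * 32_000 * 4_096      # input embedding + LM head, kept in FP16

fp16 = estimate_weight_memory_gb(N_PARAMS)
int4 = estimate_weight_memory_gb(N_PARAMS, n_unquantized=N_EMBED, w_bit=4,
                                 group_size=128, scale_bits=16, zero_bits=4)
print(f"FP16: {fp16:.2f} GB   4-bit (g128, zero point): {int4:.2f} GB")
```

The 4-bit figure lands a little under 4 GB; the unquantized embedding and LM head account for roughly half a gigabyte of it.
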
In [None]:
# Loading pre-quantized AWQ models

print("\nLoading Pre-Quantized AWQ Models")
print("=" * 50)
print("""
Many AWQ models are available on Hugging Face:

TheBloke's AWQ models:
- TheBloke/Llama-2-7B-AWQ
- TheBloke/Llama-2-13B-AWQ
- TheBloke/Llama-2-70B-AWQ
- TheBloke/Mistral-7B-v0.1-AWQ
- And many more!

Loading a pre-quantized model:
""")

print("""
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Llama-2-7B-AWQ",
    fuse_layers=True,  # Fuse for faster inference
)
tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-AWQ")
""")

---

## ✋ Try It Yourself

### Exercise 1: Compare AWQ and GPTQ Quality

Load both AWQ and GPTQ versions of the same model and compare:
1. Output quality on the same prompts
2. Inference speed
3. Memory usage

### Exercise 2: Test AWQ Fused Layers

Compare inference speed with and without layer fusion:
```python
model = AutoAWQForCausalLM.from_quantized(..., fuse_layers=True)
model = AutoAWQForCausalLM.from_quantized(..., fuse_layers=False)
```
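
For Exercise 2 (and the speed part of Exercise 1), a small timing helper keeps comparisons fair by warming up first and synchronizing the GPU before each clock read. This is a minimal sketch: `benchmark` is a hypothetical helper, not part of AutoAWQ, and the commented usage assumes an already-loaded `model` and tokenized `inputs`.

```python
import statistics
import time


def benchmark(fn, warmup: int = 2, iters: int = 5, sync=None) -> dict:
    """Time a zero-argument callable and return median/min wall-clock seconds.

    sync: optional callable (e.g. torch.cuda.synchronize) invoked before each
          clock read so that queued GPU work is included in the timing.
    """
    for _ in range(warmup):                    # warm-up runs (caches, lazy init)
        fn()
    times = []
    for _ in range(iters):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return {"median_s": statistics.median(times), "min_s": min(times)}


# Runnable demo on a CPU-only workload
stats = benchmark(lambda: sum(range(100_000)))
print(stats)

# Hypothetical usage with an AWQ model (fused vs. unfused):
# stats = benchmark(
#     lambda: model.generate(**inputs, max_new_tokens=64, do_sample=False),
#     sync=torch.cuda.synchronize,
# )
# print(f"~{64 / stats['median_s']:.1f} tokens/s (fewer if generation stops early at EOS)")
```

Reporting the median rather than the mean keeps one slow outlier run from skewing the comparison.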

In [None]:
# Exercise: Your code here

# TODO: Load AWQ and GPTQ versions of the same model
# TODO: Run the same prompts through both
# TODO: Compare quality and speed

# Your code here...

---

## Common Mistakes

### Mistake 1: Disabling Fused Layers

```python
# Slower: layer fusion explicitly turned off
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)

# Faster: enable layer fusion (check the default in your AutoAWQ version)
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=True)
```

### Mistake 2: Wrong Kernel Version

```python
# Wrong: Using GEMV for batched inference
config = {"version": "GEMV"}  # Optimized for batch_size=1

# Right: Use GEMM for batched inference
config = {"version": "GEMM"}  # Works for all batch sizes
```

### Mistake 3: Ignoring Hardware Compatibility

```python
# AWQ kernels require specific GPU architectures
# Check compatibility before using

# Supported: Ampere, Ada, Hopper, Blackwell (SM 80+)
# Limited support: Turing (SM 75)
# Not supported: Volta and earlier
```

---

## Checkpoint

You've learned:

- **AWQ concept**: Protect salient weights through activation-aware scaling
- **AWQ vs GPTQ**: Similar quality, AWQ is faster to quantize
- **Layer fusion**: Enable for faster inference
- **Pre-quantized models**: TheBloke has AWQ versions of popular models

---

## Further Reading

- [AWQ Paper](https://arxiv.org/abs/2306.00978)
- [AutoAWQ GitHub](https://github.com/casper-hansen/AutoAWQ)
- [Hugging Face blog: 4-bit quantization with bitsandbytes](https://huggingface.co/blog/4bit-transformers-bitsandbytes)

---

## Cleanup

In [None]:
# Clean up
if 'model' in dir() and model is not None:
    del model

clear_memory()
print("Notebook complete! Ready for Lab 3.2.6: GGUF Conversion")

---

## Next Steps

In the next notebook, we'll explore **GGUF Conversion** - the format used by llama.cpp for CPU/GPU inference!

➡️ Continue to: [Lab 3.2.6: GGUF Conversion](lab-3.2.6-gguf-conversion.ipynb)