In [None]:
#| default_exp quantize.quantize_callback

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#|export
from fastai.callback.all import *
from fastcore.basics import store_attr
from fasterai.quantize.quantizer import Quantizer
from torch.ao.quantization.quantize_fx import convert_fx
import torch
import copy

## Overview

The `QuantizeCallback` enables Quantization-Aware Training (QAT) within the fastai training loop. QAT simulates quantization effects during training, allowing the model to adapt its weights for better accuracy after quantization.

**Why use QAT over post-training quantization?**
- Higher accuracy on the quantized model
- Model learns to be robust to quantization noise
- Especially beneficial for models sensitive to precision loss

**Trade-offs:**
- Requires retraining (not just calibration)
- Training is slower due to simulated quantization
- Only for situations where you can afford additional training time

In [None]:
#|export
class QuantizeCallback(Callback):
    """
    Simple callback for Quantization-Aware Training (QAT) in fastai.
    Uses the Quantizer class for configuration and conversion.
    """
    def __init__(self, 
                 quantizer=None,        # Provide custom quantizer
                 backend='x86',         # Target backend for quantization: 'x86', 'qnnpack'
                 use_per_tensor=False,  # Force per-tensor quantization
                 verbose=False          # Enable verbose output
                ):
        "Initialize the QAT callback."
        store_attr()
        self.original_model = None
    
    def before_fit(self) -> None:
        "Prepare model for quantization-aware training"
        # Save original model
        self.original_model = copy.deepcopy(self.learn.model)
        
        # Create quantizer if not provided
        if self.quantizer is None:
            self.quantizer = Quantizer(
                backend=self.backend,
                method="qat",
                use_per_tensor=self.use_per_tensor,
                verbose=self.verbose
            )
        
        # Get example inputs
        x, _ = self.learn.dls.one_batch()
        original_device = next(self.learn.model.parameters()).device
        
        # Temporarily move to CPU for preparation
        self.learn.model = self.learn.model.cpu()
        
        # Prepare model for QAT using the quantizer
        try:
            # First save the original state dict
            orig_state_dict = self.learn.model.state_dict()
            
            # Use the _prepare_model method from the quantizer
            prepared_model = self.quantizer._prepare_model(self.learn.model, x.cpu())
            
            # Move back to original device and update learner's model
            self.learn.model = prepared_model.to(original_device)
                
            if self.verbose:
                print("Model prepared for QAT successfully")
                
        except Exception as e:
            print(f"Error preparing model for QAT: {e}")
            import traceback
            traceback.print_exc()
            # Restore original model on error
            self.learn.model = self.original_model.to(original_device)
    
    def after_fit(self) -> None:
        "Convert QAT model to fully quantized model"
        # Get original device before try block to ensure it's available in except
        original_device = next(self.learn.model.parameters()).device
        
        try:
            if self.verbose:
                print("Converting QAT model to fully quantized model")
            
            # Set model to eval mode and move to CPU for conversion
            self.learn.model = self.learn.model.cpu().eval()
            
            # Save a copy of the trained QAT model
            self.qat_model = copy.deepcopy(self.learn.model)
            
            # Convert to quantized model
            quantized_model = convert_fx(self.learn.model)
            
            # Save the quantized model
            self.learn.quantized_model = quantized_model
            
            # Keep the quantized model as the active model
            # This is crucial - the quantized model IS the trained model
            self.learn.model = quantized_model
                
        except Exception as e:
            print(f"Error converting QAT model: {e}")
            import traceback
            traceback.print_exc()
            
            # If conversion fails, at least keep the QAT-trained model
            if hasattr(self, 'qat_model'):
                self.learn.model = self.qat_model.to(original_device)
                print("Conversion failed, but QAT-trained model was kept")

In [None]:
QuantizeCallback()

NameError: name 'store_attr' is not defined

In [None]:
show_doc(QuantizeCallback)

**Parameters:**

- `quantizer`: Optional custom `Quantizer` instance for advanced configuration
- `backend`: Target backend (`'x86'`, `'qnnpack'`) - only used if quantizer not provided
- `use_per_tensor`: Force per-tensor quantization to avoid conversion issues
- `verbose`: Enable detailed output during QAT

---

## Usage Example

```python
from fasterai.quantize.quantize_callback import QuantizeCallback

# Basic QAT with default settings
cb = QuantizeCallback(backend='x86', verbose=True)

# Train with QAT
learn.fit(5, cbs=[cb])

# After training, the quantized model is available at:
quantized_model = learn.quantized_model
```

### QAT Workflow

1. **before_fit**: Model is prepared for QAT (fake quantization nodes inserted)
2. **Training**: Model trains with simulated quantization effects
3. **after_fit**: Model is converted to fully quantized form

The final `learn.model` is the quantized model ready for CPU inference.

---

## See Also

- [Quantizer](quantizer.html) - Core quantization class with backend/method options
- [ONNX Exporter](../export/onnx_exporter.html) - Export quantized models for deployment
- [PyTorch Quantization Docs](https://pytorch.org/docs/stable/quantization.html) - Official PyTorch guide