# TorchAO Per-Tensor 4-bit Quantization for Qwen3-0.6B

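This notebook demonstrates **per-tensor 4-bit quantization** of model weights. The quantizer used throughout is a small manual implementation; a variant built on TorchAO's low-level primitives is included for reference.

**Key Challenge**: TorchAO's `Int4WeightOnlyConfig` only supports per-group quantization (group_size=32/64/128/256).
For true per-tensor quantization (a single scale per weight matrix), the scale is computed directly from the whole tensor and can optionally be passed to TorchAO's `quantize_affine`/`dequantize_affine` primitives.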

**Contents**:
- PTQ (Post-Training Quantization): Apply int4 per-tensor quantization to pretrained model
- QAT (Quantization-Aware Training): Train with fake quantization + STE
- Inference with quantized model

**Per-Tensor vs Per-Group**:
| Aspect | Per-Tensor | Per-Group (TorchAO default) |
|--------|-----------|---------------------------|
| Scale count | 1 per weight tensor | out_features × (in_features/group_size) |
| Accuracy | Lower | Higher |
| Simplicity | Higher | Lower |
| Hardware compat | Broader | Specific (tinygemm) |
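
To make the scale-count row concrete, the following minimal sketch counts the scale factors for one weight matrix under each scheme. The helper `scale_count` and the 2048 × 1024 shape are illustrative assumptions, not part of TorchAO.

```python
from typing import Optional


def scale_count(out_features: int, in_features: int,
                group_size: Optional[int] = None) -> int:
    """Number of scale factors needed for one weight matrix.

    group_size=None means per-tensor: one scale for the whole matrix.
    """
    if group_size is None:
        return 1
    if in_features % group_size != 0:
        raise ValueError("in_features must be divisible by group_size")
    # One scale per (output row, group of input columns)
    return out_features * (in_features // group_size)


# Illustrative shape (an assumption for this example)
out_features, in_features = 2048, 1024
per_tensor = scale_count(out_features, in_features)
per_group = scale_count(out_features, in_features, group_size=128)
print(f"per-tensor scales: {per_tensor}, per-group scales: {per_group}")
```

With a single scale, one outlier weight sets the step size for every other weight in the matrix, which is the main reason per-tensor accuracy tends to be lower.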

## 0) Configuration

In [1]:
# ---- Config (edit these) ----
MODEL_NAME = 'Qwen/Qwen3-0.6B'
DEVICE = 'auto'  # 'cuda', 'mps', 'cpu', or 'auto'
SYMMETRIC = True  # Symmetric quantization (recommended for weights)
SKIP_LM_HEAD = True  # Skip quantizing the language model head

## 1) Install Dependencies

In [2]:
# Install TorchAO and dependencies
!pip install -q "torchao>=0.7.0" "transformers>=4.51.0" accelerate datasets

## 2) Imports & Environment Check

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm
import math

# TorchAO primitives (optional - we have manual fallback)
TORCHAO_AVAILABLE = False
try:
    from torchao.quantization.quant_primitives import (
        quantize_affine,
        dequantize_affine,
        choose_qparams_affine,
        MappingType,
        ZeroPointDomain,
    )
    TORCHAO_AVAILABLE = True
    print("TorchAO available - can use either TorchAO or manual quantization")
except ImportError:
    print("TorchAO not available - using manual quantization implementation")

print(f'torch: {torch.__version__}')
print(f'cuda: {torch.cuda.is_available()}')
print(f'mps: {torch.backends.mps.is_available()}')

# Auto device selection
def get_device(device_str='auto'):
    if device_str == 'auto':
        if torch.cuda.is_available():
            return torch.device('cuda')
        elif torch.backends.mps.is_available():
            return torch.device('mps')
        return torch.device('cpu')
    return torch.device(device_str)

device = get_device(DEVICE)
print(f'Using device: {device}')

TorchAO available - can use either TorchAO or manual quantization
torch: 2.9.0+cu126
cuda: True
mps: False
Using device: cuda


## 3) Per-Tensor Int4 Quantization Functions

Manual implementation of per-tensor int4 quantization.

**Why manual?** TorchAO's `quantize_affine`/`dequantize_affine` with large `block_size` can have issues. Our manual implementation is simpler and more robust.

**Formula**:
- Symmetric: `scale = max(|W|) / 7`, `q = clamp(round(W / scale), -8, 7)`, `W_dq = q * scale`
- Asymmetric: `scale = (max - min) / 15`, `zp = round(-min / scale - 8)`, `q = clamp(round(W / scale) + zp, -8, 7)`, `W_dq = (q - zp) * scale`
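
As a worked instance of the two formulas, here is a small pure-Python sketch on five hand-picked values. The function names and the sample weights are assumptions made for illustration; the notebook's own tensor implementation follows in the next cell.

```python
QMIN, QMAX = -8, 7  # signed 4-bit range


def clamp(q: int) -> int:
    return max(QMIN, min(QMAX, q))


def quant_symmetric(values):
    """scale = max|w| / 7, zero point fixed at 0."""
    scale = max(abs(v) for v in values) / QMAX
    q = [clamp(round(v / scale)) for v in values]
    dq = [qi * scale for qi in q]
    return scale, q, dq


def quant_asymmetric(values):
    """scale = (max - min) / 15, zero point shifts min onto -8."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / (QMAX - QMIN)
    zp = round(-lo / scale + QMIN)
    q = [clamp(round(v / scale) + zp) for v in values]
    dq = [(qi - zp) * scale for qi in q]
    return scale, zp, q, dq


weights = [-0.52, -0.15, 0.0, 0.33, 1.4]  # illustrative values

scale_s, q_sym, dq_sym = quant_symmetric(weights)
scale_a, zp, q_asym, dq_asym = quant_asymmetric(weights)

print("symmetric :", q_sym)
print("asymmetric:", q_asym, "zero point:", zp)
```

Because the sample values are skewed toward the positive side, the symmetric codes leave most of the negative range unused, while the asymmetric codes span the full range from -8 to 7.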

In [4]:
# Int4 quantization parameters
QUANT_MIN = -8  # 4-bit signed min
QUANT_MAX = 7   # 4-bit signed max

def quantize_per_tensor_int4_manual(weight: torch.Tensor, symmetric: bool = True):
    """
    Quantize weight to int4 with per-tensor scale (manual implementation).

    This is a simpler, more robust implementation that doesn't rely on
    TorchAO's block_size semantics which may not work for full-tensor sizes.

    Args:
        weight: Float weight tensor of shape (out_features, in_features)
        symmetric: Use symmetric quantization (zero_point = 0)

    Returns:
        weight_q: Quantized weight (int8 storage, int4 values)
        scale: Per-tensor scale factor (scalar)
        zero_point: Per-tensor zero point (scalar, 0 for symmetric)
    """
    # Compute scale for per-tensor quantization
    if symmetric:
        # Symmetric: scale = max(|w|) / qmax
        w_abs_max = weight.abs().max().clamp(min=1e-8)
        scale = w_abs_max / QUANT_MAX
        zero_point = torch.tensor(0, dtype=torch.int32, device=weight.device)
    else:
        # Asymmetric: scale = (max - min) / (qmax - qmin)
        w_min = weight.min()
        w_max = weight.max()
        scale = (w_max - w_min).clamp(min=1e-8) / (QUANT_MAX - QUANT_MIN)
        zero_point = torch.round(-w_min / scale + QUANT_MIN).to(torch.int32)

    # Quantize: q = round(w / scale) + zp, clamped to [qmin, qmax]
    if symmetric:
        weight_q = torch.round(weight / scale).clamp(QUANT_MIN, QUANT_MAX).to(torch.int8)
    else:
        weight_q = (torch.round(weight / scale) + zero_point).clamp(QUANT_MIN, QUANT_MAX).to(torch.int8)

    return weight_q, scale, zero_point


def dequantize_per_tensor_int4_manual(
    weight_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
):
    """
    Dequantize int4 weight back to float (manual implementation).

    Args:
        weight_q: Quantized weight tensor (int8)
        scale: Per-tensor scale (scalar)
        zero_point: Per-tensor zero point (scalar)
        output_dtype: Output float dtype

    Returns:
        Dequantized float weight
    """
    # Dequantize: w = (q - zp) * scale
    weight_dq = (weight_q.to(output_dtype) - zero_point.to(output_dtype)) * scale.to(output_dtype)
    return weight_dq


# Also try TorchAO primitives with different approach
def quantize_per_tensor_int4_torchao(weight: torch.Tensor, symmetric: bool = True):
    """
    Quantize using TorchAO's quantize_affine with per-tensor granularity.

    The scale and zero point are computed here (one value for the whole
    tensor) and passed in with block_size equal to the full weight shape.
    """
    from torchao.quantization.quant_primitives import quantize_affine

    # For per-tensor, we compute our own scale and use quantize_affine

    if symmetric:
        w_abs_max = weight.abs().max().clamp(min=1e-8)
        scale = (w_abs_max / QUANT_MAX).reshape(1, 1)
        zero_point = torch.zeros(1, 1, dtype=torch.int32, device=weight.device)
    else:
        w_min = weight.min()
        w_max = weight.max()
        scale = ((w_max - w_min).clamp(min=1e-8) / (QUANT_MAX - QUANT_MIN)).reshape(1, 1)
        zero_point = torch.round(-w_min / scale + QUANT_MIN).to(torch.int32).reshape(1, 1)

    # A block that spans the whole tensor means one scale/zero_point for all elements
    block_size = tuple(weight.shape)

    weight_q = quantize_affine(
        weight,
        block_size,
        scale,
        zero_point,
        torch.int8,
        QUANT_MIN,
        QUANT_MAX,
    )

    return weight_q, scale.squeeze(), zero_point.squeeze()


# Use manual implementation (more robust)
quantize_per_tensor_int4 = quantize_per_tensor_int4_manual
dequantize_per_tensor_int4 = dequantize_per_tensor_int4_manual


# Test the functions
print("Testing quantization functions...")
test_weight = torch.randn(64, 128)
w_q, scale, zp = quantize_per_tensor_int4(test_weight, symmetric=SYMMETRIC)
w_dq = dequantize_per_tensor_int4(w_q, scale, zp)

print(f"Original shape: {test_weight.shape}, dtype: {test_weight.dtype}")
print(f"Quantized shape: {w_q.shape}, dtype: {w_q.dtype}")
print(f"Scale: {scale.item():.6f} (shape: {scale.shape})")
print(f"Zero point: {zp.item()} (shape: {zp.shape})")
print(f"Quantized range: [{w_q.min().item()}, {w_q.max().item()}]")
print(f"Dequantized range: [{w_dq.min().item():.4f}, {w_dq.max().item():.4f}]")
print(f"Original range: [{test_weight.min().item():.4f}, {test_weight.max().item():.4f}]")
print(f"Reconstruction MSE: {F.mse_loss(test_weight, w_dq).item():.6f}")

# Verify the quantization is working correctly
print(f"\n--- Sanity Check ---")
print(f"Scale * QUANT_MAX = {(scale * QUANT_MAX).item():.4f}")
print(f"Original abs max = {test_weight.abs().max().item():.4f}")
print(f"These should be approximately equal for symmetric quantization")

Testing quantization functions...
Original shape: torch.Size([64, 128]), dtype: torch.float32
Quantized shape: torch.Size([64, 128]), dtype: torch.int8
Scale: 0.519766 (shape: torch.Size([]))
Zero point: 0 (shape: torch.Size([]))
Quantized range: [-7, 7]
Dequantized range: [-3.6384, 3.6384]
Original range: [-3.5560, 3.6384]
Reconstruction MSE: 0.022520

--- Sanity Check ---
Scale * QUANT_MAX = 3.6384
Original abs max = 3.6384
These should be approximately equal for symmetric quantization


## 4) Fake Quantize for QAT (with STE)

For Quantization-Aware Training, we need a differentiable fake quantize operation.
The Straight-Through Estimator (STE) passes gradients through the quantization unchanged.
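
The need for STE can be seen on a single scalar weight. The sketch below is a simplified, pure-Python illustration: it assumes a fixed scale of 0.1 (the notebook recomputes the scale from the tensor on every call) and a toy squared-error loss, and the names `fake_quant`, `target`, and `lr` are hypothetical.

```python
def fake_quant(w: float, scale: float, qmin: int = -8, qmax: int = 7) -> float:
    """Quantize then dequantize a single scalar weight."""
    q = max(qmin, min(qmax, round(w / scale)))
    return q * scale


scale, target, lr = 0.1, 0.5, 0.1
w = 0.234

# True derivative of the rounding staircase, estimated by central differences:
# it is zero whenever w is not sitting on a rounding boundary.
eps = 1e-4
numeric_grad = (fake_quant(w + eps, scale) - fake_quant(w - eps, scale)) / (2 * eps)

# STE: treat d(fake_quant)/dw as 1, so dL/dw = dL/d(fake_quant).
for _ in range(20):
    w_fq = fake_quant(w, scale)
    grad_out = 2.0 * (w_fq - target)  # derivative of (w_fq - target)**2 w.r.t. w_fq
    w -= lr * grad_out * 1.0          # the factor 1.0 is the straight-through estimate

print(f"numeric gradient of fake_quant: {numeric_grad}")
print(f"fake-quantized weight after training: {fake_quant(w, scale)}")
```

With the true (zero) gradient the weight would never move; with the straight-through estimate the latent weight drifts until its quantized value matches the target.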

In [5]:
class PerTensorInt4FakeQuantize(torch.autograd.Function):
    """
    Fake quantize with Straight-Through Estimator (STE).

    Forward: quantize -> dequantize (simulates quantization effect)
    Backward: pass gradients through unchanged (STE)
    """

    @staticmethod
    def forward(ctx, weight, symmetric=True):
        # Quantize and immediately dequantize using our manual implementation
        w_q, scale, zp = quantize_per_tensor_int4(weight.detach(), symmetric)
        w_dq = dequantize_per_tensor_int4(w_q, scale, zp, output_dtype=weight.dtype)

        # Save for potential gradient clipping (optional)
        ctx.save_for_backward(weight)

        return w_dq

    @staticmethod
    def backward(ctx, grad_output):
        # STE: pass gradients through unchanged
        return grad_output, None


def fake_quant_int4(weight: torch.Tensor, symmetric: bool = True) -> torch.Tensor:
    """
    Apply fake int4 quantization with STE.

    Args:
        weight: Float weight tensor
        symmetric: Use symmetric quantization

    Returns:
        Fake-quantized weight (same dtype as input)
    """
    return PerTensorInt4FakeQuantize.apply(weight, symmetric)


# Test gradient flow
print("Testing STE gradient flow...")
test_w = torch.randn(32, 64, requires_grad=True)
w_fq = fake_quant_int4(test_w)
loss = w_fq.sum()
loss.backward()
print(f"Gradient exists: {test_w.grad is not None}")
print(f"Gradient shape: {test_w.grad.shape}")
print(f"Gradient mean: {test_w.grad.mean().item():.4f} (should be ~1.0 for STE)")

# Verify the fake quantization output
print(f"\n--- Fake Quantization Sanity Check ---")
print(f"Input dtype: {test_w.dtype}")
print(f"Output dtype: {w_fq.dtype}")
print(f"Input range: [{test_w.min().item():.4f}, {test_w.max().item():.4f}]")
print(f"Output range: [{w_fq.min().item():.4f}, {w_fq.max().item():.4f}]")
print(f"MSE between input and fake-quantized: {F.mse_loss(test_w.detach(), w_fq.detach()).item():.6f}")

Testing STE gradient flow...
Gradient exists: True
Gradient shape: torch.Size([32, 64])
Gradient mean: 1.0000 (should be ~1.0 for STE)

--- Fake Quantization Sanity Check ---
Input dtype: torch.float32
Output dtype: torch.float32
Input range: [-3.2562, 3.0628]
Output range: [-3.2562, 3.2562]
MSE between input and fake-quantized: 0.018185


## 5) Int4Linear Module (for QAT)

A drop-in replacement for `nn.Linear` that applies per-tensor int4 fake quantization during forward pass.

In [6]:
class Int4Linear(nn.Module):
    """
    Linear layer with per-tensor int4 fake quantization.

    During training (QAT): applies fake quantization with STE
    During inference: uses fake-quantized weights (or can be converted to real int4)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        symmetric: bool = True,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.symmetric = symmetric

        # Full-precision weights (updated by optimizer)
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in = self.in_features
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply fake quantization to weights
        w_q = fake_quant_int4(self.weight, self.symmetric)
        return F.linear(x, w_q, self.bias)

    @classmethod
    def from_linear(cls, linear: nn.Linear, symmetric: bool = True) -> 'Int4Linear':
        """Create Int4Linear from existing nn.Linear."""
        int4_linear = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            symmetric=symmetric,
        )
        int4_linear.weight.data = linear.weight.data.clone()
        if linear.bias is not None:
            int4_linear.bias.data = linear.bias.data.clone()
        return int4_linear

    def extra_repr(self) -> str:
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, symmetric={self.symmetric}'


# Test Int4Linear
print("Testing Int4Linear...")
linear = nn.Linear(128, 64)
int4_linear = Int4Linear.from_linear(linear, symmetric=SYMMETRIC)
x = torch.randn(2, 128)
y = int4_linear(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(int4_linear)

Testing Int4Linear...
Input shape: torch.Size([2, 128])
Output shape: torch.Size([2, 64])
Int4Linear(in_features=128, out_features=64, bias=True, symmetric=True)


## 6) Replace Linear Layers with Int4Linear

Utility function to replace all `nn.Linear` layers in a model with `Int4Linear` for QAT.

In [7]:
def replace_linear_with_int4(
    model: nn.Module,
    skip_patterns: list = None,
    symmetric: bool = True,
) -> nn.Module:
    """
    Replace nn.Linear layers with Int4Linear for QAT.

    Args:
        model: PyTorch model
        skip_patterns: List of name patterns to skip (e.g., ['lm_head'])
        symmetric: Use symmetric quantization

    Returns:
        Modified model (in-place)
    """
    skip_patterns = skip_patterns or []
    replaced_count = 0
    skipped_count = 0

    # Collect modules to replace (can't modify during iteration)
    replacements = []

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Check skip patterns
            if any(pattern in name for pattern in skip_patterns):
                skipped_count += 1
                continue

            replacements.append((name, module))

    # Perform replacements
    for name, module in replacements:
        # Navigate to parent module
        parts = name.split('.')
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)

        # Replace
        int4_linear = Int4Linear.from_linear(module, symmetric=symmetric)
        setattr(parent, parts[-1], int4_linear)
        replaced_count += 1

    print(f"Replaced {replaced_count} Linear layers with Int4Linear")
    print(f"Skipped {skipped_count} layers (patterns: {skip_patterns})")

    return model


def count_parameters(model: nn.Module) -> dict:
    """Count model parameters."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'total': total, 'trainable': trainable}

## 7) Load Qwen3-0.6B Model

In [8]:
print(f"Loading model: {MODEL_NAME}")

# Load in float32 for quantization
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Ensure pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

params = count_parameters(model)
print(f"Model loaded: {params['total']:,} parameters")
print(f"Model dtype: {next(model.parameters()).dtype}")

# Check baseline perplexity BEFORE quantization
print("\n--- Baseline check (before quantization) ---")
sample_text_short = "The quick brown fox jumps over the lazy dog."
model.eval()
model.to(device)
with torch.no_grad():
    inputs = tokenizer(sample_text_short, return_tensors='pt').to(device)
    outputs = model(inputs['input_ids'], labels=inputs['input_ids'])
    baseline_loss = outputs.loss.item()
    baseline_ppl = torch.exp(outputs.loss).item()
print(f"Baseline loss: {baseline_loss:.4f}")
print(f"Baseline perplexity: {baseline_ppl:.2f}")
print(f"Logits range: [{outputs.logits.min().item():.2f}, {outputs.logits.max().item():.2f}]")

# Move back to CPU for quantization
model.to('cpu')

Loading model: Qwen/Qwen3-0.6B



`torch_dtype` is deprecated! Use `dtype` instead!



Model loaded: 596,049,920 parameters
Model dtype: torch.float32

--- Baseline check (before quantization) ---
Baseline loss: 2.1961
Baseline perplexity: 8.99
Logits range: [-21.31, 21.76]


Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layer

## 8) PTQ: Apply Per-Tensor Int4 Quantization

Post-Training Quantization: quantize weights and measure reconstruction error.

In [9]:
def apply_ptq_int4(
    model: nn.Module,
    skip_patterns: list = None,
    symmetric: bool = True,
) -> tuple:
    """
    Apply PTQ int4 quantization to model weights.
    """
    skip_patterns = skip_patterns or []
    quantized_weights = {}
    total_mse = 0.0
    num_quantized = 0

    with torch.no_grad():
        for name, param in tqdm(list(model.named_parameters()), desc="Quantizing"):
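            # Only quantize 2D weight matrices. Note: this also matches
            # model.embed_tokens.weight. If lm_head is tied to the embedding,
            # it shares that tensor, so skipping 'lm_head' does not keep the
            # output projection in full precision.
            if 'weight' not in name or param.dim() != 2:
                continue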

            # Check skip patterns
            if any(pattern in name for pattern in skip_patterns):
                continue

            # Store original for comparison
            original = param.data.clone()

            # Quantize
            w_q, scale, zp = quantize_per_tensor_int4(param.data, symmetric)
            w_dq = dequantize_per_tensor_int4(w_q, scale, zp, output_dtype=param.dtype)

            # Debug: Check for NaN/Inf
            if torch.isnan(w_dq).any() or torch.isinf(w_dq).any():
                print(f"WARNING: NaN/Inf in {name}!")
                print(f"  scale: {scale}, zp: {zp}")
                print(f"  w_q range: [{w_q.min()}, {w_q.max()}]")
                continue

            # Calculate MSE
            mse = F.mse_loss(original, w_dq).item()
            total_mse += mse
            num_quantized += 1

            # Store quantized info
            quantized_weights[name] = {
                'weight_int4': w_q.cpu(),
                'scale': scale.cpu() if hasattr(scale, 'cpu') else scale,
                'zero_point': zp.cpu() if hasattr(zp, 'cpu') else zp,
                'mse': mse,
            }

            # Replace weight with dequantized version
            param.data.copy_(w_dq)

    avg_mse = total_mse / num_quantized if num_quantized > 0 else 0
    print(f"\nQuantized {num_quantized} weight tensors")
    print(f"Average reconstruction MSE: {avg_mse:.6f}")

    return model, quantized_weights


# Apply PTQ
skip = ['lm_head'] if SKIP_LM_HEAD else []
model_ptq, quant_weights = apply_ptq_int4(model, skip_patterns=skip, symmetric=SYMMETRIC)

# Show per-layer MSE for first few layers
print("\nPer-layer MSE (first 10):")
for i, (name, info) in enumerate(list(quant_weights.items())[:10]):
    print(f"  {name}: MSE={info['mse']:.6f}")

# Debug: Check model weights after PTQ
print("\n--- Debug: Checking model weights after PTQ ---")
for name, param in list(model_ptq.named_parameters())[:5]:
    if param.dim() == 2:
        print(f"{name}:")
        print(f"  shape: {param.shape}, dtype: {param.dtype}")
        print(f"  range: [{param.min().item():.4f}, {param.max().item():.4f}]")
        print(f"  has NaN: {torch.isnan(param).any().item()}, has Inf: {torch.isinf(param).any().item()}")

Quantizing:   0%|          | 0/310 [00:00<?, ?it/s]


Quantized 197 weight tensors
Average reconstruction MSE: 0.000263

Per-layer MSE (first 10):
  model.embed_tokens.weight: MSE=0.000169
  model.layers.0.self_attn.q_proj.weight: MSE=0.000550
  model.layers.0.self_attn.k_proj.weight: MSE=0.000303
  model.layers.0.self_attn.v_proj.weight: MSE=0.000051
  model.layers.0.self_attn.o_proj.weight: MSE=0.000264
  model.layers.0.mlp.gate_proj.weight: MSE=0.000256
  model.layers.0.mlp.up_proj.weight: MSE=0.000231
  model.layers.0.mlp.down_proj.weight: MSE=0.000282
  model.layers.1.self_attn.q_proj.weight: MSE=0.000521
  model.layers.1.self_attn.k_proj.weight: MSE=0.000209

--- Debug: Checking model weights after PTQ ---
model.embed_tokens.weight:
  shape: torch.Size([151936, 1024]), dtype: torch.float32
  range: [-0.3184, 0.2274]
  has NaN: False, has Inf: False
model.layers.0.self_attn.q_proj.weight:
  shape: torch.Size([2048, 1024]), dtype: torch.float32
  range: [-0.5525, 0.6445]
  has NaN: False, has Inf: False
model.layers.0.self_attn.k_pro
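
A quick plausibility check on the reported MSE: when rounding errors are spread roughly uniformly over one quantization step, the expected squared error is about `scale**2 / 12`. The sketch below applies that rule of thumb to the `model.embed_tokens.weight` figures printed above. It assumes the largest-magnitude dequantized value (0.3184) equals `max(|W|)`, which holds for the symmetric scheme used here.

```python
QUANT_MAX = 7

abs_max = 0.3184         # largest magnitude in the printed embed_tokens range
reported_mse = 0.000169  # embed_tokens entry in the per-layer MSE list

scale = abs_max / QUANT_MAX
predicted_mse = scale ** 2 / 12  # uniform rounding-noise approximation

rel_diff = abs(predicted_mse - reported_mse) / reported_mse
print(f"scale ~ {scale:.5f}")
print(f"predicted MSE ~ {predicted_mse:.6f}, reported MSE = {reported_mse:.6f}")
print(f"relative difference ~ {rel_diff:.1%}")
```

Agreement at this level suggests the quantizer is doing what the formula says. It does not say whether an error of that size is tolerable for the model, which is what the perplexity check in the next section measures.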

## 9) Evaluate Perplexity

Measure model quality after quantization using perplexity on sample text.

In [10]:
@torch.no_grad()
def calculate_perplexity(
    model: nn.Module,
    tokenizer,
    text: str,
    device: torch.device,
    max_length: int = 512,
    debug: bool = False,
) -> float:
    """
    Calculate perplexity on given text.
    """
    model.eval()
    model.to(device)

    # Tokenize
    encodings = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=max_length,
    ).to(device)

    input_ids = encodings['input_ids']

    if debug:
        print(f"Input shape: {input_ids.shape}")
        print(f"Model device: {next(model.parameters()).device}")
        print(f"Input device: {input_ids.device}")

    # Forward pass
    outputs = model(input_ids, labels=input_ids)
    loss = outputs.loss

    if debug:
        print(f"Loss: {loss.item():.4f}")
        print(f"Logits shape: {outputs.logits.shape}")
        print(f"Logits has NaN: {torch.isnan(outputs.logits).any().item()}")
        print(f"Logits has Inf: {torch.isinf(outputs.logits).any().item()}")
        print(f"Logits range: [{outputs.logits.min().item():.4f}, {outputs.logits.max().item():.4f}]")

    perplexity = torch.exp(loss).item()
    return perplexity


# Sample text for evaluation
sample_text = """The quick brown fox jumps over the lazy dog.
Machine learning is a subset of artificial intelligence that enables computers to learn from data.
Deep neural networks have revolutionized many fields including computer vision and natural language processing.
Quantization is a technique to reduce model size by using lower precision representations for weights and activations."""

# Calculate perplexity with debug
print("=" * 60)
print("Perplexity Evaluation (with debug)")
print("=" * 60)
ppl = calculate_perplexity(model_ptq, tokenizer, sample_text, device, debug=True)
print(f"\nPerplexity after PTQ: {ppl:.2f}")

# If perplexity is too high, check a single layer's quantization in detail
if ppl > 1000:
    print("\n--- HIGH PERPLEXITY DEBUG ---")
    print("Checking first quantized layer in detail...")

    first_layer_name = list(quant_weights.keys())[0]
    info = quant_weights[first_layer_name]
    print(f"\nLayer: {first_layer_name}")
    print(f"Quantized weight dtype: {info['weight_int4'].dtype}")
    print(f"Scale: {info['scale']}")
    print(f"Zero point: {info['zero_point']}")

    # Re-dequantize and check
    w_q = info['weight_int4']
    scale = info['scale']
    zp = info['zero_point']
    w_dq = dequantize_per_tensor_int4(w_q, scale, zp)
    print(f"Dequantized range: [{w_dq.min().item():.6f}, {w_dq.max().item():.6f}]")
    print(f"Dequantized has NaN: {torch.isnan(w_dq).any().item()}")

Perplexity Evaluation (with debug)
Input shape: torch.Size([1, 63])
Model device: cuda:0
Input device: cuda:0
Loss: 17.4848
Logits shape: torch.Size([1, 63, 151936])
Logits has NaN: False
Logits has Inf: False
Logits range: [-23.7292, 20.1738]

Perplexity after PTQ: 39224336.00

--- HIGH PERPLEXITY DEBUG ---
Checking first quantized layer in detail...

Layer: model.embed_tokens.weight
Quantized weight dtype: torch.int8
Scale: 0.0454799123108387
Zero point: 0
Dequantized range: [-0.318359, 0.227400]
Dequantized has NaN: False
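
Perplexity is the exponential of the mean token loss, so the numbers above can be cross-checked by hand. The sketch below uses the losses printed in this notebook and, as an added reference point, the loss of a uniform guess over the vocabulary (`ln(vocab_size)`). The two losses were measured on different texts, so they are not a like-for-like comparison.

```python
import math

baseline_loss = 2.1961  # short sentence, before quantization (section 7)
ptq_loss = 17.4848      # longer sample text, after PTQ (this section)
vocab_size = 151936     # last dimension of the logits shape printed above

baseline_ppl = math.exp(baseline_loss)
ptq_ppl = math.exp(ptq_loss)
uniform_loss = math.log(vocab_size)  # loss of guessing every token equally

print(f"baseline perplexity ~ {baseline_ppl:.2f}")
print(f"PTQ perplexity      ~ {ptq_ppl:.3e}")
print(f"uniform-guess loss  ~ {uniform_loss:.2f}")
```

A loss above `ln(vocab_size)` means the quantized model does worse than assigning equal probability to every token, i.e. it is confidently wrong rather than merely uncertain.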


## 10) Inference with Quantized Model

In [11]:
def generate_text(
    model: nn.Module,
    tokenizer,
    prompt: str,
    device: torch.device,
    max_new_tokens: int = 100,
    temperature: float = 0.7,
    do_sample: bool = True,
    use_chat_template: bool = True,
) -> str:
    """
    Generate text with the model.
    """
    model.eval()
    model.to(device)

    # Prepare input
    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
        messages = [{'role': 'user', 'content': prompt}]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = tokenizer(text, return_tensors='pt').to(device)
    else:
        inputs = tokenizer(prompt, return_tensors='pt').to(device)

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature if do_sample else 1.0,
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode (skip input tokens)
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:],
        skip_special_tokens=True,
    )

    return response


# Test generation
prompts = [
    "Explain quantum computing in simple terms:",
    "What is the capital of France?",
    "Write a short poem about AI:",
]

print("=" * 60)
print("Inference with Per-Tensor Int4 Quantized Model")
print("=" * 60)

for prompt in prompts:
    print(f"\nPrompt: {prompt}")
    response = generate_text(
        model_ptq, tokenizer, prompt, device,
        max_new_tokens=50, do_sample=False,
    )
    print(f"Response: {response}")
    print("-" * 40)

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Inference with Per-Tensor Int4 Quantized Model

Prompt: Explain quantum computing in simple terms:
Response: 于 mycket固定的 anyABтьсяvantโปรดleixtATHER הוד host действ严重的encing一组 sunshineinus播DirecthinmultipkingullETER greatestattracksculusaches一组 instEngineKM greatest 'doch"Kind entfer次会议 Ez兴 ”

 greatest greatest Lansing Reasoneh
----------------------------------------

Prompt: What is the capital of France?
Response:  Google notitorountryouveâuinnenutantageiefs боль�âuquameld世    


---ان fitsapor�除 alllessikt�отockey pertymmaled Backgroundtbodyunteer่于lectibastockmaf_if lỗi�ighted我把incesахumphicana语ider
----------------------------------------

Prompt: Write a short poem about AI:
Response: 碧ikباحث�ick人生的 Notes �感激raries lưu�兮因子 Dictptive uważa由出色的geführt0gebratries후 năng中国 BloomriereанияuntaTube用心天ialesiskaai[List一组 quantbbleei yahoooubTHING BetteraddAllspieliculturaltml nghìn
----------------------------------------


## 11) QAT: Prepare Model for Training

Replace linear layers with Int4Linear for Quantization-Aware Training.

In [12]:
# Reload fresh model for QAT
print("Loading fresh model for QAT...")
model_qat = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)

# Replace linear layers
skip = ['lm_head'] if SKIP_LM_HEAD else []
model_qat = replace_linear_with_int4(model_qat, skip_patterns=skip, symmetric=SYMMETRIC)

# Move to device
model_qat = model_qat.to(device)

# Count Int4Linear layers
int4_count = sum(1 for m in model_qat.modules() if isinstance(m, Int4Linear))
print(f"\nTotal Int4Linear layers: {int4_count}")

Loading fresh model for QAT...
Replaced 196 Linear layers with Int4Linear
Skipped 1 layers (patterns: ['lm_head'])

Total Int4Linear layers: 196


## 12) QAT Training Loop (Demo)

A minimal QAT training example. For full training, use a proper dataset and training loop.

In [18]:
def train_qat_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    gradient_accumulation_steps: int = 1,
) -> float:
    """
    Single QAT training step.
    """
    model.train()

    outputs = model(input_ids=input_ids, labels=labels)
    loss = outputs.loss / gradient_accumulation_steps
    loss.backward()

    return loss.item() * gradient_accumulation_steps


def train_qat_demo(
    model: nn.Module,
    tokenizer,
    device: torch.device,
    num_steps: int = 1000,
    learning_rate: float = 1e-5,
):
    """
    Demo QAT training loop.
    """
    # Sample training texts
    train_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning enables computers to learn from data.",
        "Neural networks are inspired by biological neurons.",
        "Deep learning has transformed artificial intelligence.",
        "Quantization reduces model size while preserving accuracy.",
    ]

    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    model.train()
    losses = []

    print(f"Starting QAT demo training for {num_steps} steps...")

    for step in range(num_steps):
        # Cycle through the training texts in order
        text = train_texts[step % len(train_texts)]

        # Tokenize
        encodings = tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=128,
            padding='max_length',
        ).to(device)

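        input_ids = encodings['input_ids']
        # NOTE: padded positions are kept in the labels and no attention_mask
        # is passed to the model, so padding tokens contribute to the loss.
        labels = input_ids.clone()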

        # Training step
        optimizer.zero_grad()
        loss = train_qat_step(model, optimizer, input_ids, labels)
        optimizer.step()

        losses.append(loss)

        if (step + 1) % 5 == 0:
            print(f"Step {step + 1}/{num_steps}, Loss: {loss:.4f}")

    print(f"\nQAT demo training complete!")
    print(f"Initial loss: {losses[0]:.4f}, Final loss: {losses[-1]:.4f}")

    return losses


# Run demo training
losses = train_qat_demo(model_qat, tokenizer, device, num_steps=10)

Starting QAT demo training for 10 steps...
Step 5/10, Loss: 5.3194
Step 10/10, Loss: 1.0685

QAT demo training complete!
Initial loss: 1.1880, Final loss: 1.0685
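
The demo pads every sample to 128 tokens and copies `input_ids` into `labels`, so padding positions are scored like real tokens. A common remedy is to set the label at padded positions to -100, the value the Hugging Face causal-LM loss ignores. The sketch below shows the idea on plain lists; the helper name and the token ids are hypothetical.

```python
IGNORE_INDEX = -100  # label value skipped by the Hugging Face causal-LM loss


def mask_padding_labels(input_ids, attention_mask):
    """Copy input_ids into labels, replacing padded positions with IGNORE_INDEX."""
    return [
        [tok if keep == 1 else IGNORE_INDEX for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]


# Hypothetical token ids; 0 stands in for the pad token.
input_ids = [[11, 12, 13, 0, 0]]
attention_mask = [[1, 1, 1, 0, 0]]

labels = mask_padding_labels(input_ids, attention_mask)
print(labels)
```

With tensors, the same effect can be had with `input_ids.masked_fill(attention_mask == 0, -100)`, together with passing `attention_mask` to the model call.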


## 13) Save/Load Quantized Checkpoint

In [14]:
def save_quantized_checkpoint(
    model: nn.Module,
    quantized_weights: dict,
    path: str,
    model_name: str = None,
):
    """
    Save quantized model checkpoint.
    """
    checkpoint = {
        'format': 'torchao_int4_per_tensor',
        'model_name': model_name,
        'quantization': {
            'bits': 4,
            'granularity': 'per_tensor',
            'symmetric': SYMMETRIC,
            'quant_min': QUANT_MIN,
            'quant_max': QUANT_MAX,
        },
        'model_state_dict': model.state_dict(),
        'quantized_weights': quantized_weights,
    }

    torch.save(checkpoint, path)
    print(f"Saved checkpoint to: {path}")


def load_quantized_checkpoint(path: str, model: nn.Module = None):
    """
    Load quantized model checkpoint.
    """
    checkpoint = torch.load(path, map_location='cpu')

    print(f"Loaded checkpoint format: {checkpoint['format']}")
    print(f"Model name: {checkpoint['model_name']}")
    print(f"Quantization config: {checkpoint['quantization']}")

    if model is not None:
        model.load_state_dict(checkpoint['model_state_dict'])
        print("Loaded model state dict")

    return checkpoint


# Example save (uncomment to run)
# save_quantized_checkpoint(
#     model_ptq,
#     quant_weights,
#     'qwen3_int4_per_tensor.pt',
#     model_name=MODEL_NAME,
# )
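
The checkpoint stores `weight_int4` as int8 tensors, so each 4-bit value still occupies a full byte. The sketch below shows one way two values could share a byte. The layout (low nibble first, zero padding for odd lengths) is an arbitrary choice for illustration and is not a TorchAO storage format.

```python
def pack_int4(values):
    """Pack signed int4 values (-8..7) two per byte, low nibble first."""
    vals = list(values)
    if len(vals) % 2:
        vals.append(0)  # pad to an even count
    out = bytearray()
    for lo, hi in zip(vals[0::2], vals[1::2]):
        # Masking with 0xF keeps the two's-complement low 4 bits
        out.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(out)


def unpack_int4(packed, count):
    """Inverse of pack_int4; count drops any padding value."""
    vals = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            # Nibbles 8..15 represent the negative values -8..-1
            vals.append(nibble - 16 if nibble >= 8 else nibble)
    return vals[:count]


values = [-8, -1, 0, 7, 3]
packed = pack_int4(values)
restored = unpack_int4(packed, len(values))
print(len(values), "values ->", len(packed), "bytes; round trip ok:", restored == values)
```

Packing of this kind halves storage relative to the int8 tensors saved above; any speedup at inference time would additionally require kernels that read the packed layout.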

## 14) Summary

This notebook demonstrated:

1. **Per-tensor int4 quantization** with a manual implementation, plus an optional variant built on TorchAO's `quantize_affine` primitive
2. **PTQ (Post-Training Quantization)**: Direct weight quantization with reconstruction error analysis
3. **QAT (Quantization-Aware Training)**: Fake quantization with STE for training
4. **Inference**: Generation with quantized model

### Key Takeaways

- TorchAO's `Int4WeightOnlyConfig` only supports per-group quantization (group_size=32/64/128/256)
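- For **true per-tensor** quantization with `quantize_affine`, use `block_size = tuple(weight.shape)`
- Per-tensor trades accuracy for simplicity and broader hardware compatibility: in this run, PTQ raised the loss on the sample text to 17.48 (perplexity ≈ 3.9 × 10^7) and generations were incoherent
- QAT can help recover accuracy lost from aggressive quantization, though the 10-step demo here is too short to show it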

### Next Steps

- Train with larger dataset for better QAT results
- Add LoRA for parameter-efficient fine-tuning
- Compare accuracy vs per-group quantization
- Export to target deployment format

In [15]:
print("Notebook complete!")

Notebook complete!
