In [1]:
import copy
import gc
import torch

from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

In [2]:
# Adapted from torchao's int8_dynamic_activation_int4_weight, modified to accept a scale_dtype that differs from the activation dtype
from torchao.quantization.quant_primitives import (
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
    ZeroPointDomain,
    MappingType,
    int_scaled_matmul,
    quantize_affine_hqq,
    FP8_TYPES,
    choose_qparams_affine_fpx,
    quantize_affine_fpx,
    dequantize_affine_fpx,
)

def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
    """Quantize activations ``x`` to int8 with asymmetric mapping and a per-token block size.

    Used as the dynamic input-quantization function of the custom int4-weight layers.

    Args:
        x: activation tensor to quantize.

    Returns:
        The torchao affine-quantized version of ``x``.
    """
    # These torchao helpers were used but never imported, so calling this raised NameError.
    # NOTE(review): import paths match the torchao release that still exposes `layout_tensor`; confirm for your version.
    from torchao.dtypes import to_affine_quantized_intx
    from torchao.quantization.utils import _get_per_token_block_size

    mapping_type = MappingType.ASYMMETRIC
    target_dtype = torch.int8
    return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype)

def apply_int8_dynamic_activation_int4_weight_quant_custom(weight, group_size=32, scale_dtype=torch.float32):
    """Quantize a linear ``weight`` to int4 with symmetric per-group scales stored in ``scale_dtype``.

    The weight values are stored as int8 restricted to [-8, 7]. One scale is shared by every
    ``group_size`` consecutive elements along the last dimension (block size ``(1, group_size)``,
    which assumes a 2-D weight). The result is wrapped so that inputs are dynamically quantized
    to int8 (see ``_int8_asymm_per_token_quant``) when the layer runs.

    Args:
        weight: linear-layer weight tensor.
        group_size: number of consecutive last-dim elements sharing one scale.
        scale_dtype: dtype used for the quantization scales (the reason this copy exists).

    Returns:
        The wrapped quantized weight, or ``weight`` itself (unchanged) when its last dimension
        is not divisible by ``group_size``.
    """
    if weight.shape[-1] % group_size != 0:
        # Groups must tile the last dimension exactly; such layers are left unquantized.
        return weight

    # These torchao helpers were used but never imported, so this path raised NameError.
    # Imported after the guard so the no-op path above needs nothing extra.
    # NOTE(review): import paths match the torchao release that still exposes `layout_tensor`; confirm for your version.
    from torchao.dtypes import to_affine_quantized_intx
    from torchao.quantization.linear_activation_quantized_tensor import to_linear_activation_quantized

    # weight settings
    mapping_type = MappingType.SYMMETRIC
    block_size = (1, group_size)
    target_dtype = torch.int8
    eps = torch.finfo(torch.float32).eps
    quant_min = -8  # int4 range stored in an int8 container
    quant_max = 7

    # input settings
    input_quant_func = _int8_asymm_per_token_quant

    weight = to_affine_quantized_intx(
        weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype=scale_dtype
    )
    weight = to_linear_activation_quantized(weight, input_quant_func)
    return weight

def int8_dynamic_activation_int4_weight_custom(group_size=32, scale_dtype=torch.float32):
    """Build a per-module callback for ``quantize_`` that swaps a linear layer's weight for its
    int4-weight / int8-dynamic-activation quantized form, with scales in ``scale_dtype``.

    Args:
        group_size: number of weight elements sharing one scale.
        scale_dtype: dtype for the quantization scales.

    Returns:
        A function taking a linear module, replacing its ``weight`` in place and returning the module.
    """
    def insert_subclass(lin):
        quantized_weight = apply_int8_dynamic_activation_int4_weight_quant_custom(
            lin.weight, group_size, scale_dtype
        )
        # Frozen: the quantized weight is not meant to be trained.
        lin.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
        return lin

    return insert_subclass

In [3]:
def quant_with_scale_dtype(model, scale_dtype, group_size):
    """Return a quantized deep copy of ``model`` whose weight scales are rounded through ``scale_dtype``.

    The copy is quantized with torchao's stock ``int8_dynamic_activation_int4_weight``. For a
    non-float32 ``scale_dtype``, every quantized Linear's scale tensor is cast to ``scale_dtype``
    and back to float32, then written back in place. This emulates the precision lost by storing
    scales in a narrower dtype.

    Args:
        model: module to quantize; it is deep-copied and left untouched.
        scale_dtype: dtype to round the scales through; ``torch.float32`` skips the rounding.
        group_size: number of weight elements per quantization group.

    Returns:
        The quantized copy.
    """
    qmodel = copy.deepcopy(model)
    quantize_(qmodel, int8_dynamic_activation_int4_weight(group_size=group_size))

    if scale_dtype == torch.float32:
        return qmodel

    for module in qmodel.modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        # Layers torchao left unquantized (presumably last dim not divisible by group_size) have
        # a plain weight without `original_weight_tensor`; previously this raised AttributeError.
        original_weight_tensor = getattr(module.weight.data, "original_weight_tensor", None)
        if original_weight_tensor is None:
            continue
        # NOTE(review): newer torchao releases name this `tensor_impl` instead of `layout_tensor`; accept either.
        impl = getattr(original_weight_tensor, "tensor_impl", None)
        if impl is None:
            impl = original_weight_tensor.layout_tensor
        new_scales = impl.scale.to(scale_dtype).to(torch.float32)
        impl.scale.copy_(new_scales)

    return qmodel

def record_activations(model, run_func, verbose=True):
    """Capture the output of every ``torch.nn.Linear`` in ``model`` while ``run_func`` runs.

    Args:
        model: module whose Linear submodules get forward hooks.
        run_func: callable invoked once as ``run_func(model)``. It should trigger the forward
            pass(es); callers written as ``lambda _: ...`` keep working.
        verbose: if true (the default, matching the previous behaviour), print what ``run_func`` returned.

    Returns:
        A dict mapping each Linear's qualified module name to a list of its outputs, moved to
        CPU, with one entry per forward call in call order.
    """
    recorded_activations = {}

    def _make_hook(key):
        # Bind the module name now; forward hooks are called as hook(module, inputs, output).
        def hook(_module, _inputs, output):
            recorded_activations.setdefault(key, []).append(output.to("cpu"))
        return hook

    hooks = [
        mod.register_forward_hook(_make_hook(name))
        for name, mod in model.named_modules()
        if isinstance(mod, torch.nn.Linear)
    ]
    try:
        with torch.no_grad():
            # Previously this passed the builtin `input` function by accident.
            outputs = run_func(model)
    finally:
        # Always detach hooks; otherwise a failed run leaves them on the model and later runs double-record.
        for h in hooks:
            h.remove()

    if verbose:
        print(outputs)
    return recorded_activations

def _load_fp32_causal_lm(model_id, device_map):
    """Load ``model_id`` in float32 so every case starts from the same full-precision weights."""
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device_map,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )


def _free_memory():
    """Collect dropped Python objects and release cached, now-unused CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


def record_activations_by_dtype(model_id, group_size=32, max_new_tokens=1,
                                prompt="Once upon a time, ", device_map="cuda"):
    """Record Linear-layer activations for the fp32 model and three quantized variants.

    Keys of the returned dict:
        "base"  - the unquantized float32 model,
        "f32s"  - int8-dynamic-activation / int4-weight quantized, scales kept in float32,
        "f16s"  - same, scales rounded through float16,
        "bf16s" - same, scales rounded through bfloat16.

    Each model runs a text-generation pipeline on ``prompt`` with ``do_sample=False``. The model
    is reloaded from ``model_id`` for every case and released before the next case is loaded.
    This includes the pipeline's own reference to the model.

    Args:
        model_id: Hugging Face model id or path.
        group_size: weight quantization group size.
        max_new_tokens: tokens to generate per run.
        prompt: generation prompt.
        device_map: passed to ``from_pretrained``.

    Returns:
        A dict mapping each case key to the result of ``record_activations`` for that model.
    """
    generation_kwargs = {
        "do_sample": False,
        "max_new_tokens": max_new_tokens,
    }
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    base_model = _load_fp32_causal_lm(model_id, device_map)
    base_pipe = pipeline("text-generation", model=base_model, tokenizer=tokenizer)
    results = {
        "base": record_activations(base_pipe.model, lambda _: base_pipe(prompt, **generation_kwargs)),
    }
    # The pipeline holds its own reference to the model; deleting only the model frees nothing.
    del base_pipe, base_model
    _free_memory()

    cases = [
        ("f32s", torch.float32),
        ("f16s", torch.float16),
        ("bf16s", torch.bfloat16),
    ]
    for key, dtype in cases:
        base_model = _load_fp32_causal_lm(model_id, device_map)
        quant_model = quant_with_scale_dtype(base_model, dtype, group_size)
        del base_model
        _free_memory()

        quant_pipe = pipeline("text-generation", model=quant_model, tokenizer=tokenizer)
        results[key] = record_activations(quant_model, lambda _: quant_pipe(prompt, **generation_kwargs))
        # Drop the pipe as well, or this quantized model stays resident while the next case loads.
        del quant_pipe, quant_model
        _free_memory()

    return results

def _abs_error_stats(approx, reference):
    """Return ``(mean, max, L2 norm)`` of ``|approx - reference|`` as 0-dim tensors."""
    abs_err = (approx - reference).abs()
    return (abs_err.mean(), abs_err.max(), torch.linalg.vector_norm(abs_err))


def compare_activations(base_acts, f32s_acts, quant_acts):
    """Measure how far two quantized activations deviate from the reference activation.

    Args:
        base_acts: reference activation from the unquantized model.
        f32s_acts: activation from the quantized model with float32 scales.
        quant_acts: activation from the quantized model with rounded scales.

    Returns:
        A list of two ``(mean, max, L2 norm)`` tuples of absolute error (0-dim tensors):
        first ``f32s_acts`` vs ``base_acts``, then ``quant_acts`` vs ``base_acts``.
    """
    return [
        _abs_error_stats(f32s_acts, base_acts),
        _abs_error_stats(quant_acts, base_acts),
    ]

In [None]:
# One fp32 reference run plus three quantized runs (f32 / f16 / bf16 scales); models are loaded onto CUDA.
acts = record_activations_by_dtype("google/gemma-2-2b", group_size=32, max_new_tokens=1)

In [None]:
# For each Linear layer and forward call: (mean, max, L2 norm) of |activation - fp32 base activation|.
# The "f32s" row is the error of the quantized model with float32 scales (the reference point).
# The other row is the same quantized model with its scales rounded through that dtype.
for layer_name, base_acts in acts["base"].items():
    print(f"[{layer_name}]")
    f32s_acts = acts["f32s"][layer_name]

    for case in ["f16s", "bf16s"]:
        print(f"  [{case}]")
        quant_acts = acts[case][layer_name]
        for i, quant_act in enumerate(quant_acts):
            f32s_stats, case_stats = compare_activations(base_acts[i], f32s_acts[i], quant_act)
            print(f"    f32s  (mean, max, l2): {tuple(float(v) for v in f32s_stats)}")
            print(f"    {case:<5} (mean, max, l2): {tuple(float(v) for v in case_stats)}")