In [5]:
from hqq.core.quantize import (
    BaseQuantizeConfig, 
    HQQLinear
)
from hqq.models.hf.base import AutoHQQHFModel
from lm_eval.models.huggingface import HFLM

import gc
import lm_eval
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

%matplotlib inline

In [2]:
def quantize(model: torch.nn.Module, group_size: int, scale_dtype: torch.dtype) -> torch.nn.Module:
    """Quantize `model` to 4 bits with HQQ in round-to-nearest mode.

    The quantization config uses `nbits=4`, `axis=1`, `quant_zero=False` and
    the given `group_size`; HQQ's optimization step is switched off so the
    result is plain RTN. Quantization runs on the "cuda" device with a
    float32 compute dtype.

    Args:
        model: Model handed to `AutoHQQHFModel.quantize_model`. The caller
            keeps using this same object afterwards, so the call is expected
            to modify it in place.
        group_size: Quantization group size forwarded to the HQQ config.
        scale_dtype: Precision to simulate for the per-group `scale` and
            `zero` metadata. For anything other than `torch.float32`, both
            tensors are cast to this dtype and back to float32.

    Returns:
        The `model` object that was passed in.
    """
    quant_config = BaseQuantizeConfig(nbits=4, group_size=group_size, axis=1, quant_zero=False)

    # Turning off the optimizer leaves plain round-to-nearest quantization.
    quant_config["weight_quant_params"]["optimize"] = False

    AutoHQQHFModel.quantize_model(
        model,
        quant_config=quant_config,
        compute_dtype=torch.float32,
        device="cuda",
    )

    # float32 is the native precision of the metadata: nothing to simulate.
    if scale_dtype == torch.float32:
        return model

    # Round-trip through the low-precision dtype so the values carry its
    # rounding error while the tensors themselves stay float32.
    for module in model.modules():
        if not isinstance(module, HQQLinear):
            continue
        for key in ('scale', 'zero'):
            module.meta[key] = module.meta[key].to(scale_dtype).to(torch.float32)

    return model
    

In [None]:
# Sweep configuration. A wider sweep is kept in the comments for reference;
# a group_size of None stands for the unquantized baseline.
#group_sizes = [None, 32, 64, 128, 256]
group_sizes = [32]
scale_dtypes = [torch.float32] #, torch.float16, torch.bfloat16]

result_rows = []

for group_size in group_sizes:
    # The baseline has no scales, so it is evaluated once with scale_dtype=None.
    for scale_dtype in scale_dtypes if group_size is not None else [None]:
        print(f"Testing group_size={group_size}, scale_dtype={scale_dtype}...")

        # A fresh model is loaded for every configuration so that each run
        # starts from unquantized weights.
        hflm = HFLM(
            #pretrained = "microsoft/Phi-3-mini-4k-instruct",
            pretrained = "google/gemma-2b",
            device = "cuda",
            max_length = 2048,
        )

        if group_size is not None and scale_dtype is not None:
            quantize(hflm.model, group_size = group_size, scale_dtype = scale_dtype)

        # Drop whatever became garbage during loading/quantization.
        gc.collect()
        torch.cuda.empty_cache()

        # Gather the scales of every quantized layer. The baseline model has
        # no HQQLinear modules, and torch.concat rejects an empty list, so the
        # statistics are only produced when there is something to summarize.
        scale_tensors = [
            module.meta['scale'].flatten()
            for module in hflm.model.modules()
            if isinstance(module, HQQLinear)
        ]
        if scale_tensors:
            all_scales = torch.concat(scale_tensors).float().cpu().numpy()

            counts, bins = np.histogram(all_scales)
            fig, ax = plt.subplots()
            ax.stairs(counts, bins)
            ax.set(
                title = f"Scale distribution (group_size={group_size}, scale_dtype={scale_dtype})",
                xlabel = "meta['scale'] value",
                ylabel = "count",
            )
            plt.show()

            # A Series yields one summary over all scales; wrapping the tensor
            # in a list inside a DataFrame would create one column per scale.
            print(pd.Series(all_scales, name = "scale").describe())
        else:
            print("No HQQLinear modules found; skipping scale statistics.")

        results = lm_eval.simple_evaluate(
            model = hflm,
            tasks = ["wikitext"],
            num_fewshot = 0,
        )['results']

        # Release the model before the next configuration loads its own copy.
        del hflm
        gc.collect()
        torch.cuda.empty_cache()

        result_rows.append({
            "group_size": group_size,
            "scale_dtype": scale_dtype,
            "word_perplexity": results['wikitext']['word_perplexity,none']
        })

        print(result_rows[-1])

result_df = pd.DataFrame.from_records(result_rows)
result_df

In [4]:
# Rebuild the results frame from the collected rows, then lay the metric out
# with one row per scale dtype and one column per group size (input order kept).
result_df = pd.DataFrame.from_records(result_rows)
result_df.pivot_table(
    index='scale_dtype',
    columns='group_size',
    sort=False,
)

Unnamed: 0_level_0,word_perplexity,word_perplexity,word_perplexity,word_perplexity
group_size,32.0,64.0,128.0,256.0
scale_dtype,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
torch.float32,16.600419,18.161994,18.835378,19.852184
torch.float16,16.598585,18.159762,18.833922,19.848994
torch.bfloat16,16.614811,18.143062,18.842773,19.871647


In [5]:
# Show the raw results table: one row per evaluated (group_size, scale_dtype) run.
result_df

Unnamed: 0,group_size,scale_dtype,word_perplexity
0,,,15.937838
1,32.0,torch.float32,16.600419
2,32.0,torch.float16,16.598585
3,32.0,torch.bfloat16,16.614811
4,64.0,torch.float32,18.161994
5,64.0,torch.float16,18.159762
6,64.0,torch.bfloat16,18.143062
7,128.0,torch.float32,18.835378
8,128.0,torch.float16,18.833922
9,128.0,torch.bfloat16,18.842773
