In [1]:
from hqq.core.quantize import (
    BaseQuantizeConfig, 
    HQQLinear
)
from hqq.models.hf.base import AutoHQQHFModel
from lm_eval.models.huggingface import HFLM

import gc
import lm_eval
import pandas as pd
import torch

In [2]:
def quantize(
    model: torch.nn.Module,
    group_size: int,
    scale_dtype: torch.dtype,
    nbits: int = 4,
) -> torch.nn.Module:
    """Quantize ``model`` with round-to-nearest and simulate a lower-precision scale/zero dtype.

    The model is handed to ``AutoHQQHFModel.quantize_model`` with HQQ's
    optimization step disabled, so weights are quantized by plain
    round-to-nearest (RTN). Afterwards the ``scale`` and ``zero`` tensors in
    every ``HQQLinear`` layer's ``meta`` dict are cast to ``scale_dtype`` and
    back again, which discards the precision a ``scale_dtype`` store would lose
    while leaving the tensors in their original dtype for computation.

    Args:
        model: Model whose linear layers should be quantized. The same object
            is returned; callers in this notebook rely on it being modified
            in place.
        group_size: Quantization group size forwarded to ``BaseQuantizeConfig``.
        scale_dtype: Dtype whose precision the scales and zero points should be
            limited to. If it equals the dtype the tensors already have, the
            round trip leaves them unchanged.
        nbits: Bit width of the quantized weights (defaults to 4).

    Returns:
        The ``model`` object that was passed in.
    """
    qconfig = BaseQuantizeConfig(
        nbits = nbits,
        group_size = group_size,
        axis = 1,
        quant_zero = False,
    )

    # Use RTN quantization (no HQQ algorithm).
    qconfig["weight_quant_params"]["optimize"] = False

    AutoHQQHFModel.quantize_model(model, quant_config=qconfig, compute_dtype=torch.float32, device="cuda")

    # Simulate low-precision scale dtypes by round-trip conversion.
    for linear in model.modules():
        if not isinstance(linear, HQQLinear):
            continue
        for key in ("scale", "zero"):
            original = linear.meta[key]
            # Cast back to the tensor's own dtype instead of a hard-coded
            # float32 so the metadata keeps whatever dtype quantization
            # produced; when scale_dtype already matches, `.to` is a no-op.
            linear.meta[key] = original.to(scale_dtype).to(original.dtype)

    return model
    

In [6]:
# Checkpoint under evaluation. For a quick smoke test, swap in the much
# smaller "TinyLlama/TinyLlama-1.1B-Chat-v1.0".
PRETRAINED_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"

# A group size of None means "unquantized baseline".
group_sizes = [None, 32, 64, 128, 256]
scale_dtypes = [torch.float32, torch.float16, torch.bfloat16]

result_rows = []

for group_size in group_sizes:
    # The baseline is not quantized, so it runs once with scale_dtype=None.
    for scale_dtype in scale_dtypes if group_size is not None else [None]:
        print(f"Testing group_size={group_size}, scale_dtype={scale_dtype}...")

        # A fresh model is loaded for every configuration because quantize()
        # modifies the model it is given.
        hflm = HFLM(
            pretrained = PRETRAINED_MODEL,
            device = "cuda",
            max_length = 2048,
        )

        try:
            if group_size is not None and scale_dtype is not None:
                quantize(hflm.model, group_size = group_size, scale_dtype = scale_dtype)

            # Drop whatever quantization left behind before evaluating.
            gc.collect()
            torch.cuda.empty_cache()

            results = lm_eval.simple_evaluate(
                model = hflm,
                tasks = ["wikitext"],
                num_fewshot = 0,
            )['results']
        finally:
            # Release this configuration's model even if quantization or
            # evaluation failed, and return the cached GPU blocks before the
            # next checkpoint is loaded.
            del hflm
            gc.collect()
            torch.cuda.empty_cache()

        result_rows.append({
            "group_size": group_size,
            "scale_dtype": scale_dtype,
            "word_perplexity": results['wikitext']['word_perplexity,none']
        })

        print(result_rows[-1])

result_df = pd.DataFrame.from_records(result_rows)
result_df

2024-06-25:01:50:40,617 INFO     [huggingface.py:162] Using device 'cuda'


Testing group_size=None, scale_dtype=None...




Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2024-06-25:01:50:45,105 INFO     [evaluator.py:131] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234
Repo card metadata block was not found. Setting CardData to empty.
2024-06-25:01:50:49,936 INFO     [task.py:395] Building contexts for wikitext on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:00<00:00, 1137.13it/s]
2024-06-25:01:50:49,994 INFO     [evaluator.py:362] Running loglikelihood_rolling requests
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

{'group_size': None, 'scale_dtype': None, 'word_perplexity': 10.807844752731176}


Unnamed: 0,group_size,scale_dtype,word_perplexity
0,,,10.807845


In [4]:
# Gather the per-run records into a frame, then lay the metric out with
# scale_dtype as rows and group_size as columns; sort=False leaves the
# row/column keys unsorted.
result_df = pd.DataFrame.from_records(result_rows)
result_df.pivot_table(
    columns="group_size",
    index="scale_dtype",
    sort=False,
)

Unnamed: 0_level_0,word_perplexity,word_perplexity,word_perplexity,word_perplexity
group_size,32,64,128,256
scale_dtype,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
torch.float32,11.057797,11.121785,11.337771,11.342766
torch.float16,11.05811,11.121726,11.337575,11.342822
torch.bfloat16,11.054658,11.125918,11.338817,11.338682


In [5]:
# Show the raw results: one row per evaluated (group_size, scale_dtype)
# configuration with its wikitext word perplexity.
result_df

Unnamed: 0,group_size,scale_dtype,word_perplexity
0,32,torch.float32,11.057797
1,32,torch.float16,11.05811
2,32,torch.bfloat16,11.054658
3,64,torch.float32,11.121785
4,64,torch.float16,11.121726
5,64,torch.bfloat16,11.125918
6,128,torch.float32,11.337771
7,128,torch.float16,11.337575
8,128,torch.bfloat16,11.338817
9,256,torch.float32,11.342766
