In [1]:
import copy
import gc
import torch

from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

In [2]:
def quant_with_scale_dtype(model, scale_dtype, group_size):
    """Quantize ``model`` and round its weight scales through ``scale_dtype``.

    The model is handed to ``quantize_`` with the
    ``int8_dynamic_activation_int4_weight`` configuration. Unless
    ``scale_dtype`` is ``torch.float32``, the scale tensor attached to the
    weight of every ``torch.nn.Linear`` is then cast to ``scale_dtype``, cast
    back to ``torch.float32`` and copied over the existing scale in place, so
    the stored values only keep the precision that ``scale_dtype`` can
    represent.

    Args:
        model: Module to quantize. The return value of ``quantize_`` is not
            captured, so this function relies on it modifying ``model``.
        scale_dtype: ``torch.dtype`` the scales are round-tripped through.
            ``torch.float32`` skips the round trip entirely.
        group_size: Forwarded as ``group_size`` to
            ``int8_dynamic_activation_int4_weight``.

    Returns:
        The same ``model`` object that was passed in.
    """
    quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))

    if scale_dtype == torch.float32:
        # No round trip requested; presumably the scales are already float32
        # (TODO confirm against the torchao version in use).
        return model

    linear_layers = (mod for mod in model.modules() if isinstance(mod, torch.nn.Linear))
    for linear in linear_layers:
        # NOTE(review): assumes every Linear weight exposes
        # `original_weight_tensor.layout_tensor.scale` after quantize_ —
        # this attribute path is torchao-internal; verify when upgrading.
        scale = linear.weight.data.original_weight_tensor.layout_tensor.scale
        rounded = scale.to(scale_dtype).to(torch.float32)
        scale.copy_(rounded)

    return model

def generate_with_scale_dtype(model_id, prompts, scale_dtype, group_size=256, max_new_tokens=100):
    """Load a causal LM, quantize it, and greedily generate text for each prompt.

    The model is loaded in float32 onto CUDA, passed through
    ``quant_with_scale_dtype`` and wrapped in a ``text-generation`` pipeline.
    Once generation finishes (or fails), the model, tokenizer and pipeline
    references are dropped and cached CUDA memory is released, so that
    calling this function repeatedly does not keep earlier models alive on
    the GPU.

    Args:
        model_id: Identifier passed to ``AutoModelForCausalLM.from_pretrained``
            and ``AutoTokenizer.from_pretrained``.
        prompts: Iterable of prompt strings; one generation is run per prompt.
        scale_dtype: ``torch.dtype`` forwarded to ``quant_with_scale_dtype``.
        group_size: Quantization group size forwarded to
            ``quant_with_scale_dtype``.
        max_new_tokens: Upper bound on newly generated tokens per prompt.

    Returns:
        A list with the ``"generated_text"`` of the first pipeline result for
        each prompt, in prompt order.
    """
    generation_kwargs = {
        "do_sample": False,  # greedy decoding, so outputs are comparable across runs
        "max_new_tokens": max_new_tokens,
    }

    # Pre-bind so the cleanup in `finally` is valid even if loading fails early.
    model = tokenizer = base_pipe = None
    try:
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository — only use with model ids you trust.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cuda",
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        model = quant_with_scale_dtype(model, scale_dtype, group_size)

        base_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        generated = [base_pipe(p, **generation_kwargs)[0]["generated_text"] for p in prompts]
    finally:
        # Drop every local reference to the GPU-resident objects, collect any
        # reference cycles that would keep them alive, then hand the cached
        # blocks back to the driver before the next model is loaded.
        model = tokenizer = base_pipe = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return generated

In [None]:
# Generate text for the same prompts with the quantization scales rounded
# through three different dtypes, then print the results side by side.
prompts = [
    "Once upon a time,",
    "A script to print 1 through 10 in python is: ",
    "Q: What is the difference between integration and differentiation? A:",
    "Q: How does gradient descent work? A:",
]
responses = {}  # scale dtype -> list of generated texts, in prompt order

for scale_dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(f"Scale DType: {scale_dtype}")
    responses[scale_dtype] = generate_with_scale_dtype(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", prompts, scale_dtype, group_size=32
    )

# Walk the three result lists in lockstep: one group of lines per prompt.
separator = " - - - - - - "
per_prompt = zip(responses[torch.float32], responses[torch.float16], responses[torch.bfloat16])
for from_f32, from_f16, from_bf16 in per_prompt:
    print(f" f32: {from_f32}")
    print(separator)
    print(f" f16: {from_f16}")
    print(separator)
    print(f" bf16: {from_bf16}")
    # Two blank lines between prompts.
    print("\n")

Scale DType: torch.float32
Scale DType: torch.float16
Scale DType: torch.bfloat16
