In [28]:
import torch
import transformers
import rich.table
import rich

In [54]:
MODEL_NAME = "google/flan-t5-base"

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)

# Choose the auto-class that matches the checkpoint's architecture:
# encoder-decoder (seq2seq) vs. decoder-only (causal LM).
if config.is_encoder_decoder:
    cls = transformers.AutoModelForSeq2SeqLM
else:
    cls = transformers.AutoModelForCausalLM

# Some tokenizers ship without a pad token; reuse EOS so padding has a valid id.
# This is done for both architectures: a missing pad token is not specific to
# encoder-decoder models (the original only patched that branch).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [55]:
# Load the same checkpoint at several precisions so the next cell can compare them.
# Two int8 variants are loaded on purpose: one through `quantization_config=
# BitsAndBytesConfig(...)` and one through the older `load_in_8bit=True` keyword
# (deprecated in recent transformers releases), to check both paths agree.
print("model_int8_config")
model_int8_config = cls.from_pretrained(
    MODEL_NAME,
    quantization_config=transformers.BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
    torch_dtype=torch.float16,  # dtype for weights bitsandbytes leaves unquantized
)
print("model_int8")
model_int8 = cls.from_pretrained(
    MODEL_NAME,
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
print("model_float32")
model_float32 = cls.from_pretrained(MODEL_NAME).cuda()
print("model_bfloat16")
model_bfloat16 = cls.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).cuda()
print("model_float16")
model_float16 = cls.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).cuda()

models_by_name = dict(
    model_int8=model_int8,
    model_int8_config=model_int8_config,
    model_bfloat16=model_bfloat16,
    model_float16=model_float16,
    model_float32=model_float32,
)

# Only a cell's last expression is displayed, so summarise every model in one value
# (parameters + buffers, in MiB) instead of a list of bare names that render nothing.
{
    name: f"{model.get_memory_footprint() / 2**20:,.1f} MiB"
    for name, model in models_by_name.items()
}


In [56]:
# Score every model on the same prompt: teacher-forced cross-entropy plus a short
# generation, so precision-induced drift (or NaN/inf losses) shows up side by side.
sample_text = "Question: What is the color of the moon? Answer: "
sample = tokenizer(sample_text, return_tensors="pt").to(0)

table = rich.table.Table("[purple]Model name", "[purple]Ce", "[purple]Generation", title="Precision Test", show_lines=True)

rich.print(f"[purple bold]Sample text: [/]{sample_text}")
with torch.inference_mode():  # evaluation only: no autograd graph needed
    for name, model in models_by_name.items():
        # The prompt itself is used as `labels`: next-token loss for a causal LM,
        # loss of reproducing the prompt as decoder target for an encoder-decoder.
        ce = model(**sample, labels=sample.input_ids).loss.item()

        # Fall back to EOS for padding without mutating the shared model config.
        pad_token_id = model.config.pad_token_id
        if pad_token_id is None:
            pad_token_id = model.config.eos_token_id

        # generate() returns a (batch, seq) tensor; take the single row so decode()
        # always receives a flat id sequence, for both architectures.
        gen_ids = model.generate(**sample, max_new_tokens=20, pad_token_id=pad_token_id)[0]

        if not model.config.is_encoder_decoder:
            # The output of a causal model also includes the input ids,
            # so we need to remove them.
            gen_ids = gen_ids[sample.input_ids.shape[-1]:]
        generation = tokenizer.decode(gen_ids, skip_special_tokens=True).strip().replace("\n", " ")

        table.add_row(name, f"{ce:0.3f}", generation)

rich.print(table)


: 