In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from baa import PerplexityBenchmark, get_llm_memory_usage, device_map
from datasets import load_dataset

In [None]:
# Checkpoint under study, plus the bitsandbytes 4-bit recipe used to build the
# quantized variant of it.
model_name = "HuggingFaceTB/SmolLM-360M-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # load weights in 4-bit precision
    bnb_4bit_quant_type="nf4",              # NF4 (4-bit NormalFloat) quantization type
    bnb_4bit_use_double_quant=True,         # nested quantization of the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for the actual computation
)

In [None]:
# Three handles on the same checkpoint:
#   model           -> reference copy, loaded without a quantization config
#   quantized_model -> loaded with the 4-bit `bnb_config`
#   franken_model   -> unquantized copy whose submodules later cells replace
#                      with the quantized model's counterparts
load_kwargs = {"device_map": device_map}

model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, **load_kwargs
)
franken_model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Report each model's footprint; the divisor converts the helper's result to MB
# (presumably it returns bytes — confirm in `baa.get_llm_memory_usage`).
for label, llm in (("Model", model), ("Quantized model", quantized_model)):
    print(f"{label} memory usage: {get_llm_memory_usage(llm) / 1024 ** 2:.2f} MB")

## Benchmarking both models on the same dataset with perplexity

In [None]:
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

# Both benchmarks share the tokenizer and the evaluation corpus, so the only
# variable between them is the model.
shared_bench_kwargs = {"tokenizer": tokenizer, "dataset": dataset}
benchmark = PerplexityBenchmark(model=model, **shared_bench_kwargs)
quantized_benchmark = PerplexityBenchmark(model=quantized_model, **shared_bench_kwargs)

# NOTE(review): if `evaluate(sample_size=...)` draws a random sample, seed it so
# the two models are scored on the same texts — confirm in `baa`.
for label, bench in (("Original", benchmark), ("Quantized", quantized_benchmark)):
    print(f"{label} model perplexity: {bench.evaluate(sample_size=200):.2f}")

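## Frankensteining the two models together

Replace selected modules of a full-precision copy with their 4-bit counterparts, then measure how perplexity and memory usage change.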

In [None]:
# Replace the frankenmodel's token-embedding module with the quantized model's.
# After this assignment both models reference the same embedding module object.
# NOTE(review): the quantized model's non-quantized modules may be stored in a
# different dtype than franken_model's — confirm the mixed-dtype forward is intended.
franken_model.model.embed_tokens = quantized_model.model.embed_tokens

# Use a dedicated benchmark instead of re-pointing `benchmark.model`, so
# `benchmark` keeps evaluating the original full-precision model. This also
# avoids relying on PerplexityBenchmark picking up a swapped `.model` attribute.
franken_benchmark = PerplexityBenchmark(
    model=franken_model,
    tokenizer=tokenizer,
    dataset=dataset,
)
print(
    f"Frankenmodel perplexity after embedding replacement: {franken_benchmark.evaluate(sample_size=200):.2f}"
)

In [None]:
# Replace the entire self-attention module (every head at once) of the first
# N_ATTN_LAYERS_TO_SWAP decoder layers with the quantized model's version.
# Modules swapped in earlier cells (e.g. the embedding) stay swapped, so the
# perplexity below reflects the cumulative modifications.
N_ATTN_LAYERS_TO_SWAP = 30

num_decoder_layers = len(franken_model.model.layers)
assert 0 <= N_ATTN_LAYERS_TO_SWAP <= num_decoder_layers, (
    f"Cannot swap {N_ATTN_LAYERS_TO_SWAP} attention modules: "
    f"model has only {num_decoder_layers} decoder layers"
)

for layer_idx in range(N_ATTN_LAYERS_TO_SWAP):
    franken_model.model.layers[layer_idx].self_attn = (
        quantized_model.model.layers[layer_idx].self_attn
    )

# Dedicated benchmark so `benchmark` is not silently re-pointed at the frankenmodel.
franken_benchmark = PerplexityBenchmark(
    model=franken_model,
    tokenizer=tokenizer,
    dataset=dataset,
)
print(
    f"Frankenmodel perplexity after replacing attention in the first "
    f"{N_ATTN_LAYERS_TO_SWAP} layers: {franken_benchmark.evaluate(sample_size=200):.2f}"
)

In [None]:
# Footprint of the patched model, comparable to the numbers reported for the
# original and quantized models (divisor presumably converts bytes to MB —
# confirm in `baa.get_llm_memory_usage`).
franken_mem = get_llm_memory_usage(franken_model) / 1024 ** 2
print(f"Frankenmodel memory usage: {franken_mem:.2f} MB")