In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from optimum.quanto import quantize, qint8, freeze, qfloat8
import torch
from typing import List
from baa import PerplexityBenchmark, get_llm_memory_usage, device_map
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

In [None]:
# Checkpoint identifier passed to `from_pretrained` for both the model and the
# tokenizer. Uncomment one of the alternatives below to try a different model.
model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
# model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
# model_name = "meta-llama/Llama-3.2-3B-Instruct"

In [None]:
# Load the causal-LM weights (placed according to `device_map` imported from
# `baa`) and the matching tokenizer for the selected checkpoint.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
def run_benchmark(quantized=False, sample_size=200):
    """Print the perplexity of the notebook-level ``model`` on WikiText-2.

    Loads the ``test`` split of ``wikitext-2-raw-v1``, wraps the global
    ``model`` and ``tokenizer`` in a ``PerplexityBenchmark`` and prints the
    evaluated perplexity with two decimals.

    Args:
        quantized: Only affects the printed label. When true the line is
            prefixed with "Quantized Activations"; the evaluation itself is
            the same either way.
        sample_size: Forwarded to ``PerplexityBenchmark.evaluate``. Defaults
            to 200, the value that used to be hard-coded here.
    """
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    benchmark = PerplexityBenchmark(model, tokenizer, dataset)
    perplexity = benchmark.evaluate(sample_size=sample_size)
    # Build the whole label up front so the non-quantized line does not start
    # with a stray space (the old f-string printed " Model perplexity: ...").
    label = "Quantized Activations Model" if quantized else "Model"
    print(f"{label} perplexity: {perplexity:.2f}")

In [None]:
# Training split of the same WikiText-2 (raw) dataset that run_benchmark
# evaluates on (it uses the "test" split).
dataset_train = load_dataset(
    path="wikitext",
    name="wikitext-2-raw-v1",
    split="train",
)

In [None]:
# Quantize both weights and activations of `model` to int8, then freeze it.
# NOTE(review): optimum-quanto documents a calibration pass (the `Calibration`
# context manager) for activation quantization; none is run here - confirm
# whether one is needed before freezing.
quantize(model, weights=qint8, activations=qint8)
freeze(model)
# The model was quantized just above, so label this run as quantized
# (it was previously reported with quantized=False, i.e. as the plain model).
run_benchmark(quantized=True)

In [None]:
# Keep only the first 5000 rows of the training split for this run.
dataset_train = dataset_train.select(range(5000))
# SFT settings: read training text from the "text" column, use a
# max_seq_length of 512 and "sft_output" as the output directory.
sft_config = SFTConfig(
    dataset_text_field="text",
    max_seq_length=512,
    output_dir="sft_output",
)

# NOTE(review): `model` is quantized and frozen in an earlier cell before it
# is handed to the trainer. optimum-quanto's documented workflow tunes the
# model before freezing - confirm that training a frozen model is intended.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset_train,
    args=sft_config,
)

# Run supervised fine-tuning with the configuration above.
trainer.train()

In [None]:
# Re-run the perplexity benchmark on `model` after the training cell above,
# labelling the printed result as quantized.
run_benchmark(quantized=True)

In [None]:
# Display the shape of index 0 of the q_proj weight in layers[0].
# NOTE(review): looks like an exploratory probe of the (possibly quantized)
# weight tensor - confirm it is still needed in the final notebook.
model.model.layers[0].self_attn.q_proj.weight[0].shape

In [None]:
# Display the shape of index 1 of the q_proj weight in layers[0].
# NOTE(review): looks like an exploratory probe of the (possibly quantized)
# weight tensor - confirm it is still needed in the final notebook.
model.model.layers[0].self_attn.q_proj.weight[1].shape