In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from baa import (
    QuantizedLinearLayerWithActivation,
    replace_linear_layer_with_activation,
    register_linear_layer_forward_hook,
    device_map,
    get_hidden_states_input,
    remove_all_hooks,
    chat_with_model,
    AccuracyBenchmark,
)
from baa.singletons import hidden_states
from datasets import load_dataset
import torch

In [None]:
# Checkpoint under test. Its weights are placed according to baa's `device_map`.
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
)

In [None]:
# Qualitative sanity check of the unquantized model on a short chat prompt.
prompt = "Hi there how are you?"
response = chat_with_model(model, tokenizer, prompt)

In [None]:
# Evaluation corpus: the test split of WikiText-2 (raw variant).
dataset = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    split="test",
)
# The benchmark holds a reference to `model`, so later evaluations see any
# in-place changes made to it.
benchmark = AccuracyBenchmark(model, tokenizer, dataset)

In [None]:
# Calibrate on the baseline model, swap its linear layers for quantized ones,
# then re-evaluate with the same benchmark.
# NOTE: `model` is modified in place by the replacement, so re-running this cell
# without reloading the model would quantize an already-replaced model.
EVAL_SAMPLE_SIZE = 200  # samples per evaluation; identical for both runs so scores are comparable

with torch.no_grad():
    # Forward hooks on the linear layers presumably record their inputs into the
    # `baa.singletons.hidden_states` singleton during the baseline evaluation -- TODO confirm.
    register_linear_layer_forward_hook(model, get_hidden_states_input)
    try:
        original_accuracy = benchmark.evaluate(sample_size=EVAL_SAMPLE_SIZE)
        print("Original model accuracy:", original_accuracy)

        # Summarize instead of printing every captured tensor, which floods the output.
        captured = len(hidden_states) if hasattr(hidden_states, "__len__") else "unknown"
        print("Captured hidden-state entries:", captured)

        # NOTE(review): the quantizer presumably reads the `hidden_states` singleton
        # directly (no explicit `hidden_states=` argument is passed) -- confirm in baa.
        replace_linear_layer_with_activation(
            base_model=model,
            quantizer_class=QuantizedLinearLayerWithActivation,
            exclude_list=[],
            quantized=True,
        )
    finally:
        # Always detach the hooks. Without this, a failure in evaluation or
        # replacement would leave them attached to `model`.
        remove_all_hooks(model)

    torch.cuda.empty_cache()
    quantized_accuracy = benchmark.evaluate(sample_size=EVAL_SAMPLE_SIZE)
    print("Quantized model accuracy:", quantized_accuracy)


In [None]:
# Send the same prompt to the now-quantized model for a side-by-side comparison
# with the earlier, unquantized response.
response = chat_with_model(
    model,
    tokenizer,
    "Hi there how are you?",
)