In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from baa import (
    QuantizedLinearLayerWithActivation,
    replace_linear_layer_with_activation,
    register_linear_layer_forward_hook,
    device_map,
    get_hidden_states_input,
    get_weights,
    add_custom_name_to_linear_layers,
    remove_all_hooks,
    chat_with_model,
    print_memory_usage,
    AccuracyBenchmark,
)
from baa.singletons import hidden_states, names
from datasets import load_dataset
import torch
from dotenv import load_dotenv

# Read key=value pairs from a local .env file (if one exists) into the process
# environment before any model/dataset download happens.
# NOTE(review): presumably this supplies a Hugging Face token for gated
# checkpoints — confirm which variables the .env is expected to define.
load_dotenv()

In [None]:
# Checkpoint under test. The small SmolLM model keeps iteration fast; the
# larger Llama checkpoint below can be swapped in by uncommenting it.
# model_name = "meta-llama/Llama-3.2-3B-Instruct"
model_name = "HuggingFaceTB/SmolLM-135M"

# Tokenizer and model are loaded from the same checkpoint id. `device_map` is
# the placement setting exported by the `baa` package.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
)

In [None]:
# Show the model's repr as the cell output so the layer structure can be
# inspected before any linear layers are replaced further down.
model

In [None]:
# response = chat_with_model(model, tokenizer, "Hi there how are you?")

In [None]:
# Evaluation corpus: the test split of the raw WikiText-2 configuration.
dataset = load_dataset(
    path="wikitext",
    name="wikitext-2-raw-v1",
    split="test",
)

# One benchmark object is built here and reused for both the baseline and the
# post-replacement evaluation, so both runs go through the same harness.
benchmark = AccuracyBenchmark(model, tokenizer, dataset)

In [None]:
# Baseline evaluation followed by in-place layer replacement. Everything runs
# under inference_mode, so no autograd state is recorded.
with torch.inference_mode():
    # The activation-capturing forward hook is currently disabled; re-enable it
    # (and the hidden_states print below) when debugging captured inputs.
    # register_linear_layer_forward_hook(model, get_hidden_states_input)
    print("Original model accuracy:", benchmark.evaluate(sample_size=200))
    # print(hidden_states)

    # NOTE(review): presumably this call is what populates the shared `names`
    # singleton used on the next line — confirm in baa.
    add_custom_name_to_linear_layers(model)

    # Exclude every layer whose name does not contain "mlp"; only names
    # containing "mlp" are left out of the exclusion list.
    exclude_list = [name for name in names if "mlp" not in name]
    print("exclude_list:", exclude_list)

    # Sanity check on the shared hidden_states container.
    print(
        f"hidden_states is empty: {not bool(hidden_states)}"
    )  # empty dicts resolve to False

    # Replace the non-excluded linear layers using 4-bit weights and 4-bit
    # activations. The return value is not used; later cells keep working with
    # the same `model` object.
    replace_linear_layer_with_activation(
        base_model=model,
        quantizer_class=QuantizedLinearLayerWithActivation,
        weight_bits=4,
        activation_bits=4,
        exclude_list=exclude_list,
        quantized=True,
    )

    # Detach any hooks left on the model before it is evaluated again.
    remove_all_hooks(model)

    # Ask the CUDA caching allocator to release unused cached blocks.
    torch.cuda.empty_cache()

In [None]:
# Show the model's repr again, now after the replacement cell has run, so the
# layer structure can be compared with the earlier display.
# NOTE(review): presumably the "mlp" linear layers now appear as
# QuantizedLinearLayerWithActivation — confirm from the output.
model

In [None]:
# Second evaluation with the same benchmark object and sample size as the
# baseline run, after the layer-replacement cell has been executed.
with torch.inference_mode():
    quantized_accuracy = benchmark.evaluate(sample_size=200)
    print("Quantized model accuracy:", quantized_accuracy)


In [None]:
# Qualitative check: send one fixed prompt through the chat helper and keep
# whatever it returns in `response`.
response = chat_with_model(
    model,
    tokenizer,
    "Hi there how are you?",
)

In [None]:
# Print the largest and smallest entries of the first decoder layer's
# attention key-projection weight.
# NOTE(review): this layer's name presumably lacks "mlp", which would put it
# on the exclude list and leave these weights unquantized — confirm.
k_proj_weight = model.model.layers[0].self_attn.k_proj.weight
print(k_proj_weight.max())
print(k_proj_weight.min())