# Imports

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
from baa import (
    QuantizedLinearLayerWithActivation,
    replace_linear_layer_with_activation,
    register_linear_layer_forward_hook,
    device_map,
    get_hidden_states_input,
    get_weights,
    add_custom_name_to_linear_layers,
    remove_all_hooks,
    chat_with_model,
    print_memory_usage,
    AccuracyBenchmark,
)
from baa.singletons import hidden_states, names
from datasets import load_dataset
import torch
from dotenv import load_dotenv
import gc

# Read settings from a local .env file into the process environment
# (presumably the Hugging Face access token for gated checkpoints — confirm).
load_dotenv()

# Load Model

In [None]:
def delete_model():
    """Drop the global ``model`` and release the memory it was holding.

    Does nothing when no global ``model`` exists yet (e.g. the first call in
    a fresh kernel). Otherwise the model is moved to the CPU on a best-effort
    basis, the global name is removed, and both the Python garbage collector
    and the CUDA caching allocator are asked to release what they can.
    """
    global model
    try:
        stale = model
    except NameError:
        # Nothing has been loaded yet, so there is nothing to free.
        return

    try:
        stale.to("cpu")
    except Exception as err:
        # Best effort only: a failed move must not stop the reference from
        # being dropped below, otherwise the old model would stay alive.
        print("Could not move model to CPU before deleting it:", err)

    del model, stale
    gc.collect()
    torch.cuda.empty_cache()


def reload_original_model(name="meta-llama/Llama-3.2-3B-Instruct"):
    """Discard the current model and load a fresh copy of a checkpoint.

    Rebinds the notebook-level globals ``model``, ``tokenizer`` and
    ``model_name``.

    Args:
        name: Hugging Face checkpoint identifier to load. Pass e.g.
            "HuggingFaceTB/SmolLM-135M" for a quicker run instead of editing
            this function. Defaults to the Llama 3.2 3B Instruct checkpoint.
    """
    global model, tokenizer, model_name
    # Release the previous copy before loading the next one.
    delete_model()
    model_name = name
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Load the baseline model and tokenizer into the notebook globals.
reload_original_model()

In [None]:
# Bare expression: the notebook renders the model's repr as the cell output.
model

# Load Benchmark

In [None]:
# Test split of the "wikitext-2-raw-v1" configuration of the wikitext dataset.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
# Later cells point this benchmark at each quantized model by assigning
# `benchmark.model`.
benchmark = AccuracyBenchmark(model, tokenizer, dataset)

## Run Benchmark on Original Model

In [None]:
# Baseline score; the quantized runs below use the same sample_size=200.
original_model_accuracy = benchmark.evaluate(sample_size=200)
print("Original Model Accuracy:", original_model_accuracy)

In [None]:
# Maps a plot title to its (bits, accuracies) series. Filled by
# plot_accuracies() and read back by plot_final_data().
# (A `global` statement is a no-op at module level, so none is needed here.)
plot_data_dict = {}

In [None]:
def plot_final_data(bits=(8, 6, 5, 4, 3, 2), title=None):
    """Overlay every recorded accuracy series on a single figure.

    Draws the original model's accuracy (global ``original_model_accuracy``)
    as a dashed horizontal reference line, annotates its value, and adds one
    line per entry of the global ``plot_data_dict``.

    Args:
        bits: Bit widths used as x-axis tick positions.
        title: Figure title. When omitted it is built from the global
            ``model_name`` so the figure always names the checkpoint that
            was actually loaded.
    """
    palette = ["blue", "green", "red", "orange"]
    if title is None:
        title = f"Accuracy vs Bits - {model_name}"

    # Reference line for the unquantized model.
    plt.axhline(
        y=original_model_accuracy,
        color="purple",
        linestyle="--",
        label="Original Model",
    )
    # Print the baseline value slightly above the line, at the smallest tick.
    plt.text(
        min(bits, default=2),
        original_model_accuracy + 0.01,
        str(original_model_accuracy),
        color="purple",
    )

    # The loop uses its own names so the `bits` argument is still intact for
    # the tick positions below.
    for idx, (series_name, series) in enumerate(plot_data_dict.items()):
        series_bits, series_accuracies = series
        plt.plot(
            series_bits,
            series_accuracies,
            "o-",
            # Cycle through the palette so more than four series still plot.
            color=palette[idx % len(palette)],
            label=series_name,
        )

    plt.xlabel("Bits")
    plt.xticks(bits)
    plt.ylabel("Accuracy")
    plt.ylim(0, 1)  # Set y-axis limits from 0 to 1
    plt.grid(True)
    plt.title(title)
    # Labels are attached to the artists above, so the legend no longer
    # depends on the order in which they were drawn.
    plt.legend()
    plt.show()


def plot_accuracies(bits, accuracies, title="Accuracy vs Bits"):
    """Plot one accuracy-per-bit-width curve and remember it for later.

    The series is stored in the global ``plot_data_dict`` under ``title``;
    an existing entry with the same title is overwritten.

    Args:
        bits: Bit widths, used as x values and as x-axis tick positions.
        accuracies: Accuracy values matching ``bits`` one to one.
        title: Figure title and key under which the series is recorded.
    """
    plot_data_dict[title] = (bits, accuracies)

    plt.plot(bits, accuracies, "o-")
    plt.title(title)
    plt.xlabel("Bits")
    plt.ylabel("Accuracy")
    plt.xticks(bits)
    # Fixed 0-1 range keeps the separate figures visually comparable.
    plt.ylim(0, 1)
    plt.grid(True)
    plt.show()

# Full Model Quantization

In [None]:
def quantize_full_model(bits: int, model=None):
    """Replace the model's linear layers with quantized ones, excluding none.

    Args:
        bits: Bit width passed as both ``weight_bits`` and
            ``activation_bits``.
        model: Model to modify. Defaults to the notebook's current global
            ``model``, looked up when the function is called.
    """
    if model is None:
        # Resolve lazily. A `model=model` default is evaluated once, when the
        # cell defining this function runs, and would keep that model object
        # alive (and stale) even after delete_model() removed the global.
        model = globals()["model"]
    add_custom_name_to_linear_layers(model)
    replace_linear_layer_with_activation(
        base_model=model,
        quantizer_class=QuantizedLinearLayerWithActivation,
        weight_bits=bits,
        activation_bits=bits,
        exclude_list=[],
        quantized=True,
    )
    torch.cuda.empty_cache()


# Bit widths to sweep, widest first.
bit_selection = [8, 6, 5, 4, 3, 2]
accuracies = []

for bits in bit_selection:
    # Each width starts from a freshly loaded model before quantizing.
    reload_original_model()
    quantize_full_model(bits, model)
    benchmark.model = model
    with torch.inference_mode():
        accuracy = benchmark.evaluate(sample_size=200)
    accuracies.append(accuracy)
    print("Quantized Model Accuracy with", bits, "bits:", accuracy)

plot_accuracies(bit_selection, accuracies, title="Full Model Quantization")

# Quantize Attention Only

In [None]:
def quantize_attention(bits: int, model=None):
    """Quantize only the linear layers whose name contains "self_attn".

    Every entry of the shared ``names`` list that does not contain
    "self_attn" is passed as ``exclude_list``.
    NOTE(review): assumes ``names`` holds the layer names registered by
    ``add_custom_name_to_linear_layers`` — confirm.

    Args:
        bits: Bit width passed as both ``weight_bits`` and
            ``activation_bits``.
        model: Model to modify. Defaults to the notebook's current global
            ``model``, looked up when the function is called.
    """
    if model is None:
        # Resolve lazily. A `model=model` default is evaluated once, when the
        # cell defining this function runs, and would keep that model object
        # alive (and stale) even after delete_model() removed the global.
        model = globals()["model"]
    add_custom_name_to_linear_layers(model)
    exclude_list = [name for name in names if "self_attn" not in name]
    replace_linear_layer_with_activation(
        base_model=model,
        quantizer_class=QuantizedLinearLayerWithActivation,
        weight_bits=bits,
        activation_bits=bits,
        exclude_list=exclude_list,
        quantized=True,
    )
    torch.cuda.empty_cache()


# Bit widths to sweep, widest first.
bit_selection = [8, 6, 5, 4, 3, 2]
accuracies = []

for bits in bit_selection:
    # Each width starts from a freshly loaded model before quantizing.
    reload_original_model()
    quantize_attention(bits, model)
    benchmark.model = model
    with torch.inference_mode():
        accuracy = benchmark.evaluate(sample_size=200)
    accuracies.append(accuracy)
    print("Quantized Model Accuracy with", bits, "bits:", accuracy)

plot_accuracies(bit_selection, accuracies, title="Attention Quantization")

# Quantize MLP Only

In [None]:
def quantize_mlp(bits: int, model=None):
    """Quantize only the linear layers whose name contains "mlp".

    Every entry of the shared ``names`` list that does not contain "mlp" is
    passed as ``exclude_list``.
    NOTE(review): assumes ``names`` holds the layer names registered by
    ``add_custom_name_to_linear_layers`` — confirm.

    Args:
        bits: Bit width passed as both ``weight_bits`` and
            ``activation_bits``.
        model: Model to modify. Defaults to the notebook's current global
            ``model``, looked up when the function is called.
    """
    if model is None:
        # Resolve lazily. A `model=model` default is evaluated once, when the
        # cell defining this function runs, and would keep that model object
        # alive (and stale) even after delete_model() removed the global.
        model = globals()["model"]
    add_custom_name_to_linear_layers(model)
    exclude_list = [name for name in names if "mlp" not in name]
    replace_linear_layer_with_activation(
        base_model=model,
        quantizer_class=QuantizedLinearLayerWithActivation,
        weight_bits=bits,
        activation_bits=bits,
        exclude_list=exclude_list,
        quantized=True,
    )
    torch.cuda.empty_cache()


# Bit widths to sweep, widest first.
bit_selection = [8, 6, 5, 4, 3, 2]
accuracies = []

for bits in bit_selection:
    # Each width starts from a freshly loaded model before quantizing.
    reload_original_model()
    quantize_mlp(bits, model)
    benchmark.model = model
    with torch.inference_mode():
        accuracy = benchmark.evaluate(sample_size=200)
    accuracies.append(accuracy)
    print("Quantized Model Accuracy with", bits, "bits:", accuracy)

plot_accuracies(bit_selection, accuracies, title="MLP Quantization")

# Quantize LM Head Only

In [None]:
def quantize_lm_head(bits: int, model=None):
    """Quantize only the linear layers whose name contains "lm_head".

    Every entry of the shared ``names`` list that does not contain "lm_head"
    is passed as ``exclude_list``.
    NOTE(review): assumes ``names`` holds the layer names registered by
    ``add_custom_name_to_linear_layers`` — confirm.

    Args:
        bits: Bit width passed as both ``weight_bits`` and
            ``activation_bits``.
        model: Model to modify. Defaults to the notebook's current global
            ``model``, looked up when the function is called.
    """
    if model is None:
        # Resolve lazily. A `model=model` default is evaluated once, when the
        # cell defining this function runs, and would keep that model object
        # alive (and stale) even after delete_model() removed the global.
        model = globals()["model"]
    add_custom_name_to_linear_layers(model)
    exclude_list = [name for name in names if "lm_head" not in name]
    replace_linear_layer_with_activation(
        base_model=model,
        quantizer_class=QuantizedLinearLayerWithActivation,
        weight_bits=bits,
        activation_bits=bits,
        exclude_list=exclude_list,
        quantized=True,
    )
    torch.cuda.empty_cache()


# Bit widths to sweep, widest first.
bit_selection = [8, 6, 5, 4, 3, 2]
accuracies = []

for bits in bit_selection:
    # Each width starts from a freshly loaded model before quantizing.
    reload_original_model()
    quantize_lm_head(bits, model)
    benchmark.model = model
    with torch.inference_mode():
        accuracy = benchmark.evaluate(sample_size=200)
    accuracies.append(accuracy)
    print("Quantized Model Accuracy with", bits, "bits:", accuracy)

plot_accuracies(bit_selection, accuracies, title="LM Head Quantization")

In [None]:
# Keep copies of the current values before the presentation tweaks below.
tmp_accuracy = original_model_accuracy
tmp_dict = plot_data_dict.copy()

# Two decimals keep the baseline value short when it is drawn on the figure.
original_model_accuracy = round(original_model_accuracy, 2)
# Relabel the attention series for the legend. Guarded so that re-running
# this cell after the key was already renamed does not raise KeyError.
if "Attention Quantization" in plot_data_dict:
    plot_data_dict["Self Attn Quantization"] = plot_data_dict.pop(
        "Attention Quantization"
    )

In [None]:
# Draw the combined comparison figure from the series in plot_data_dict.
plot_final_data()