In [None]:
import json
import matplotlib as mpl
import matplotlib.pyplot as plt
import os
from pprint import pprint

In [None]:
# Load the benchmark results of the unquantized ("original") models, keyed by
# model name. Only *.json files are parsed so that stray files in the log folder
# (editor backups, .DS_Store, ...) do not crash json.load. Files are visited in
# sorted order so that, if two files report the same model name, the one that
# wins is deterministic (the later filename overwrites the earlier one).
folder_path_original_benchmarks = "logs/original_model_benchmarks"
original_benchmarks = {}
for filename in sorted(os.listdir(folder_path_original_benchmarks)):
    if not filename.endswith(".json"):
        continue
    with open(os.path.join(folder_path_original_benchmarks, filename), "r") as f:
        raw_data = json.load(f)
    model_name = raw_data["model_name"]
    original_benchmarks[model_name] = raw_data["original_model_benchmarks"]


In [None]:
def generate_data(
    folder_path: str = "logs/first_experiment",
    benchmarks: "dict | None" = None,
) -> dict:
    """Collect per-model quantization results from the experiment log files.

    Every ``*.json`` file in ``folder_path`` is read in sorted filename order,
    so the insertion order of the result (and hence line/legend order in the
    plots) is deterministic. Runs are grouped as::

        {
            model_name: {
                "original_model_accuracy": <original model's "wikitext_accuracy">,
                include_component: {
                    weight_bits: {
                        "include_component": include_component,
                        "weight_bits": weight_bits,
                        "quantized_model_benchmarks": {...},
                        "quantized_model_accuracy": <its "wikitext_accuracy">,
                    },
                },
            },
        }

    Args:
        folder_path: Directory holding the experiment result JSON files.
        benchmarks: Mapping from model name to that model's original (unquantized)
            benchmark dict, which must contain ``"wikitext_accuracy"``. Defaults
            to the notebook-level ``original_benchmarks``.

    Returns:
        The nested dictionary described above (empty if no JSON files exist).

    Raises:
        KeyError: If a result file names a model missing from ``benchmarks`` or
            lacks one of the expected keys.
    """
    if benchmarks is None:
        benchmarks = original_benchmarks

    data = {}
    for filename in sorted(os.listdir(folder_path)):
        # Ignore anything that is not a JSON result file.
        if not filename.endswith(".json"):
            continue
        with open(os.path.join(folder_path, filename), "r") as file:
            model_data = json.load(file)

        # Results are keyed by the model name stored inside the file, not by filename.
        model_name = model_data["model_name"]
        include_component = model_data["include_component"]
        weight_bits = model_data["weight_bits"]
        quantized_model_benchmarks = model_data["quantized_model_benchmarks"]

        model_entry = data.setdefault(
            model_name,
            {"original_model_accuracy": benchmarks[model_name]["wikitext_accuracy"]},
        )
        model_entry.setdefault(include_component, {})[weight_bits] = {
            "include_component": include_component,
            "weight_bits": weight_bits,
            "quantized_model_benchmarks": quantized_model_benchmarks,
            "quantized_model_accuracy": quantized_model_benchmarks["wikitext_accuracy"],
        }
    return data


# Build the nested results dict (model -> component -> weight bits -> run data)
# from the first-experiment log files.
data = generate_data()

In [None]:
# Dump the full nested results dict as a quick sanity check of what was loaded.
pprint(data)

In [None]:
# Wikitext accuracy of the quantized models as a function of weight bit width.
# One subplot per model; one line per quantized component (`include_component`),
# plus a dashed red reference line at the unquantized model's accuracy.
# `data` is only read, never mutated, so this cell can be re-run safely.

models_in_correct_order = [
    "HuggingFaceTB/SmolLM-135M-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
]

# squeeze=False keeps `axes` 2-D even when only one model is plotted.
fig, axes = plt.subplots(
    1, len(models_in_correct_order), figsize=(20, 5), squeeze=False
)
for axis, model_name in zip(axes[0], models_in_correct_order):
    model_data = data[model_name]
    original_model_accuracy = model_data["original_model_accuracy"]
    handles = []

    for component_name, component_data in model_data.items():
        # The scalar reference accuracy lives next to the per-component results.
        if component_name == "original_model_accuracy":
            continue
        # Order points by bit width; the loading order of runs is not meaningful.
        points = sorted(
            (bit_width, run["quantized_model_benchmarks"]["wikitext_accuracy"])
            for bit_width, run in component_data.items()
        )
        x, y = zip(*points)
        (line,) = axis.plot(x, y, marker="o", label=f"{component_name}")
        handles.append(line)

    handles.append(
        axis.axhline(
            y=original_model_accuracy,
            color="r",
            linestyle="--",
            label="Original Model Accuracy",
        )
    )

    # Fixed 0-0.6 range with ticks every 0.05 so the subplots are comparable.
    axis.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.05))
    axis.set_ylim(0, 0.6)

    axis.set_title(model_name)
    axis.set_xlabel("Bit Width")
    axis.set_ylabel("Quantized Model Accuracy")
    axis.legend(handles=handles, loc="lower right")

plt.tight_layout()
plt.show()

In [None]:
# Rebuild `data` from the log files so the MMLU plot below starts from a freshly
# loaded copy, independent of any changes earlier cells made to `data`.
data = generate_data()

In [None]:
# MMLU overall score of the quantized models as a function of weight bit width.
# One subplot per model; one line per quantized component (`include_component`),
# a dashed red line at the unquantized model's MMLU score and a dashed violet
# "Baseline" line. `data` is only read, never mutated, so this cell can be re-run
# safely whether or not earlier cells touched it.

# NOTE(review): presumably the random-guess score on four-option MMLU questions
# — confirm this is the intended baseline.
MMLU_BASELINE_ACCURACY = 0.25

models_in_correct_order = [
    "HuggingFaceTB/SmolLM-135M-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
]

# squeeze=False keeps `axes` 2-D even when only one model is plotted.
fig, axes = plt.subplots(
    1, len(models_in_correct_order), figsize=(20, 5), squeeze=False
)
for axis, model_name in zip(axes[0], models_in_correct_order):
    model_data = data[model_name]
    original_model_accuracy = original_benchmarks[model_name]["mmlu_results"][
        "overall_score"
    ]
    handles = []

    for component_name, component_data in model_data.items():
        # Skip the (wikitext) reference value stored next to the component results.
        if component_name == "original_model_accuracy":
            continue
        # Order points by bit width; the loading order of runs is not meaningful.
        points = sorted(
            (
                bit_width,
                run["quantized_model_benchmarks"]["mmlu_results"]["overall_score"],
            )
            for bit_width, run in component_data.items()
        )
        x, y = zip(*points)
        (line,) = axis.plot(x, y, marker="o", label=f"{component_name}")
        handles.append(line)

    handles.append(
        axis.axhline(
            y=original_model_accuracy,
            color="r",
            linestyle="--",
            label="Original Model Accuracy",
        )
    )
    handles.append(
        axis.axhline(
            y=MMLU_BASELINE_ACCURACY, color="violet", linestyle="--", label="Baseline"
        )
    )

    # Fixed 0-0.7 range with ticks every 0.05 so the subplots are comparable.
    axis.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.05))
    axis.set_ylim(0, 0.7)

    axis.set_title(model_name)
    axis.set_xlabel("Bit Width")
    axis.set_ylabel("Quantized Model Accuracy on MMLU tasks")
    axis.legend(handles=handles, loc="lower right")

plt.tight_layout()
plt.show()