In [None]:
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from baa import PerplexityBenchmark, get_memory_usage

In [None]:
# Identifier of the checkpoint to analyze; it is passed to `from_pretrained`
# below. Keep exactly one assignment active; the commented-out lines are
# alternative models that can be swapped in.
# NOTE: the plot title and the output file name in the plotting cell are
# hard-coded for Llama 3.2 3B Instruct and must be edited by hand when a
# different model is selected here.
# model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
model_name = "meta-llama/Llama-3.2-3B-Instruct"
# model_name = "meta-llama/Llama-3.1-8B-Instruct"

In [None]:
# Load the causal-LM checkpoint selected by `model_name`.
# `device_map="auto"` leaves device placement of the weights to the library.
# NOTE(review): no dtype is passed, so the weights load in the library's
# default precision; presumably this determines the per-component memory
# figures computed below -- confirm if the numbers are meant for a specific dtype.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

In [None]:
# Dump the module tree. The cells below access `model.model.layers[i].mlp`,
# `model.model.layers[i].self_attn` and `model.lm_head`; this output is the
# place to check that those attribute paths exist for the selected model.
print(model)

In [None]:
def _usage_in_mb(module):
    """Return `get_memory_usage(module)` scaled down by 1024**2 (bytes -> MB, presumably)."""
    return get_memory_usage(module) / 1024**2


decoder_layers = model.model.layers
num_decoder_layers = len(decoder_layers)
first_decoder_layer = decoder_layers[0]

# Per-layer components are measured on layer 0 only and multiplied by the
# number of layers, i.e. every decoder layer is taken to have the same size.
# "Full Model" is measured on `model.model`; `model.lm_head` is a separate
# attribute and is reported on its own as "Model Head".
component_memory_usage = {
    "Full Model": _usage_in_mb(model.model),
    "MLP": _usage_in_mb(first_decoder_layer.mlp) * num_decoder_layers,
    "Self Attention": _usage_in_mb(first_decoder_layer.self_attn) * num_decoder_layers,
    "Model Head": _usage_in_mb(model.lm_head),
    # "Decoder Blocks": get_memory_usage(model.model.layers[:]) / 1024**2,
}

# Text summary, one line per component, two decimals.
for component, usage in component_memory_usage.items():
    print(f"{component}: {usage:.2f} MB")

In [None]:
# Global matplotlib styling for every figure created after this cell:
# 16 pt text everywhere except the legend (14 pt), rendered at 300 dpi.
# No font family is configured here, so matplotlib's default family is used.
mpl.rcParams.update(
    {
        "font.size": 16,
        "axes.titlesize": 16,
        "axes.labelsize": 16,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "legend.fontsize": 14,
        "figure.dpi": 300,
    }
)

In [None]:
# Bar labels and heights (MB), in the insertion order of `component_memory_usage`.
labels = list(component_memory_usage)
memory_usage = list(component_memory_usage.values())

# Sanity check: MLP + Self Attention + Model Head should add up to roughly the
# "Full Model" figure. "Full Model" is measured on `model.model`, which does not
# contain `model.lm_head`; presumably the head stands in for an input embedding
# of equal size inside `model.model` -- TODO confirm for the selected model.
# All values are already in MB, so the tolerance is expressed in MB as well, and
# the difference is checked in both directions. (The previous check compared MB
# values against 30 * 1024**2 without abs(), so it could never fail.)
TOLERANCE_MB = 30
full_model_mb = component_memory_usage["Full Model"]
components_total_mb = sum(
    usage_mb
    for name, usage_mb in component_memory_usage.items()
    if name != "Full Model"
)
assert abs(components_total_mb - full_model_mb) < TOLERANCE_MB, (
    f"Components sum to {components_total_mb:.2f} MB but the full model is "
    f"{full_model_mb:.2f} MB (tolerance {TOLERANCE_MB} MB)"
)

# One fixed color per bar, in the same order as `labels`.
# cmap = matplotlib.colormaps["tab10"]
# colors = [cmap(i) for i in range(len(component_memory_usage))]
colors = ["tab:blue", "tab:green", "tab:red", "tab:orange"]

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(labels, memory_usage, color=colors)

# Annotate each bar with its value, anchored just above the bar top.
for i, usage in enumerate(memory_usage):
    ax.text(i, usage, f"{usage:.0f} MB", ha="center", va="bottom", fontsize=14)

ax.set_ylabel("Memory Usage (MB)")
# NOTE: title and file name are hard-coded for Llama 3.2 3B Instruct; update
# them by hand if `model_name` is changed.
ax.set_title(
    "Memory Usage of Model Components - Llama 3.2 3B Instruct", fontweight="bold"
)
ax.tick_params(axis="x", labelrotation=45)

# Save to the visualizations folder, creating it first so a fresh checkout
# does not fail with FileNotFoundError.
output_path = Path("visualizations") / "memory_usage-llama3.2-3b-instruct.png"
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches="tight")

plt.show()