In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from baa import PerplexityBenchmark, get_memory_usage
from datasets import load_dataset

In [None]:
# Checkpoint to profile. Other checkpoints tried with this notebook:
#   "HuggingFaceTB/SmolLM-135M-Instruct"
#   "meta-llama/Llama-3.1-8B-Instruct"
#   "google/gemma-2b-it"
model_name = "meta-llama/Llama-3.2-3B-Instruct"

# bitsandbytes 8-bit quantization settings. This object has no effect unless it
# is passed as `quantization_config=` when the model is loaded.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

In [None]:
# Load the checkpoint named in `model_name` plus its tokenizer.
# `device_map="auto"` lets transformers/accelerate choose weight placement.
# To profile the 8-bit variant, add `quantization_config=bnb_config` here.
# NOTE(review): no `torch_dtype` is passed, so weights load in the library's
# default dtype, which directly affects the memory figures — confirm intended.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Dump the module hierarchy so the submodule paths used in the memory
# breakdown (model.model.layers[i].mlp / .self_attn, model.lm_head) can be
# checked. For torch modules str() and repr() produce the same text.
print(repr(model))

In [None]:
# Memory footprint, in MB, of the model's main building blocks.
# `get_memory_usage` comes from `baa`; it is presumably a byte count (hence the
# division by BYTES_PER_MB) — TODO confirm against its implementation.
# NOTE(review): "Full Model" measures `model.model` only, which does not include
# `model.lm_head`; "Model Head" is measured separately. If the checkpoint ties
# input/output embeddings, the head shares storage with the embedding table —
# check `model.config.tie_word_embeddings` before interpreting the numbers.
BYTES_PER_MB = 1024**2
decoder_layers = model.model.layers

component_memory_usage = {
    "Full Model": get_memory_usage(model.model) / BYTES_PER_MB,
    # Sum over every decoder layer instead of multiplying layer 0 by the layer
    # count, so the totals stay correct even if layers are not identical.
    "MLP": sum(get_memory_usage(layer.mlp) for layer in decoder_layers)
    / BYTES_PER_MB,
    "Self Attention": sum(get_memory_usage(layer.self_attn) for layer in decoder_layers)
    / BYTES_PER_MB,
    "Model Head": get_memory_usage(model.lm_head) / BYTES_PER_MB,
}

for component, usage in component_memory_usage.items():
    print(f"{component}: {usage:.2f} MB")

In [None]:
import matplotlib.pyplot as plt

# Allowed mismatch, in MB (the unit of `component_memory_usage`), for the
# sanity check below.
# NOTE(review): 30 MB is an empirical allowance for parts of `model.model` that
# are not counted in any component (presumably small modules such as norms) —
# confirm it holds for each checkpoint profiled.
TOLERANCE_MB = 30

labels = list(component_memory_usage.keys())
memory_usage = list(component_memory_usage.values())

# Sanity check: MLP + self-attention + model head should approximately equal
# the full model. Presumably the LM head is the same size as the token-embedding
# table inside `model.model` — TODO confirm.
# Components are selected by name (not position), both sides are in MB, and the
# difference is absolute, so the check can fail in either direction.
parts_mb = sum(
    component_memory_usage[name] for name in ("MLP", "Self Attention", "Model Head")
)
full_mb = component_memory_usage["Full Model"]
assert abs(parts_mb - full_mb) < TOLERANCE_MB, (
    f"components sum to {parts_mb:.2f} MB, full model is {full_mb:.2f} MB"
)

# Bar chart of per-component memory usage.
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(labels, memory_usage, color=["tab:blue", "tab:green", "tab:orange", "tab:red"])
ax.set_ylabel("Memory Usage (MB)")
# Derive the title from the configured checkpoint so switching `model_name`
# does not leave a stale, hard-coded model label on the figure.
ax.set_title(f"Memory Usage of Model Components - {model_name.split('/')[-1]}")
ax.tick_params(axis="x", labelrotation=45)
plt.show()