In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from baa import PerplexityBenchmark, get_llm_memory_usage, device_map
from datasets import load_dataset

In [None]:
# Hub id of the checkpoint that the from_pretrained calls in this notebook load.
model_name = "HuggingFaceTB/SmolLM-1.7B-Instruct"

# bitsandbytes options, passed as `quantization_config` when loading the
# quantized copy of the model.
bnb_config = BitsAndBytesConfig(
    # Weight storage: 4-bit, NF4 quantization type, with double quantization on.
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    # Dtype used for computation on the 4-bit weights.
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
# Baseline copy: loaded without any quantization_config.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
)

# Same checkpoint, loaded through the bitsandbytes config defined above.
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map=device_map, quantization_config=bnb_config
)

# A second copy loaded with exactly the same arguments as `model`.
# NOTE(review): presumably this one gets modified later while `model` stays
# untouched as the reference -- that use is not visible in this section.
franken_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)

# Tokenizer that belongs to the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Print the model's structure so its submodules (e.g. the model.model.* and
# model.lm_head attributes used in the cells below) can be inspected.
print(model)

In [None]:
# Per-component memory report for the unquantized model, one line per entry.
# get_llm_memory_usage's result is divided by 1024**2 and labelled "MB".

# The two sub-blocks of the first decoder layer, keyed by attribute name.
first_layer = model.model.layers[0]
inner_self_attn_mapping = {
    "self_attn": first_layer.self_attn,
    "mlp": first_layer.mlp,
}

# (label, module) pairs in the order they are reported.
# Note: "Full Model" is measured on model.model (lm_head is reported on its own
# line), and "Attention Heads" covers the whole model.model.layers[:] slice,
# i.e. complete decoder layers rather than only their self_attn part.
report_entries = [
    ("Full Model", model.model),
    ("Embedding", model.model.embed_tokens),
    ("Attention Heads", model.model.layers[:]),
    *inner_self_attn_mapping.items(),
    ("Model Head", model.lm_head),
]

for label, module in report_entries:
    print(f"{label} Memory Usage: {get_llm_memory_usage(module) / 1024**2:.2f} MB")

In [None]:
# Memory per component, in MB (get_llm_memory_usage result / 1024**2).
# Key order matters: the plotting cell relies on the dict's insertion order.
bytes_per_mb = 1024**2
decoder_stack = model.model
layer_count = len(decoder_stack.layers)
layer_zero = decoder_stack.layers[0]

component_memory_usage = {
    # Measured on model.model; lm_head is a separate entry below.
    "Full Model": get_llm_memory_usage(decoder_stack) / bytes_per_mb,
    "Embedding": get_llm_memory_usage(decoder_stack.embed_tokens) / bytes_per_mb,
    "Model Head": get_llm_memory_usage(model.lm_head) / bytes_per_mb,
    # All decoder layers together (the whole layers[:] slice).
    "Attention Heads": get_llm_memory_usage(decoder_stack.layers[:]) / bytes_per_mb,
    # The next two are extrapolated: layer 0's figure times the number of
    # layers, which assumes every layer has the same size as layer 0.
    "Self Attention": get_llm_memory_usage(layer_zero.self_attn)
    / bytes_per_mb
    * layer_count,
    "MLP": get_llm_memory_usage(layer_zero.mlp) / bytes_per_mb * layer_count,
}

print(
    "\n".join(
        f"{name}: {size_mb:.2f} MB" for name, size_mb in component_memory_usage.items()
    )
)

In [None]:
import matplotlib.pyplot as plt

# Bar labels and heights (MB), in the insertion order of component_memory_usage.
labels = list(component_memory_usage)
memory_usage = list(component_memory_usage.values())

# Sanity check: the parts measured inside model.model should add up to the
# "Full Model" figure, which is itself measured on model.model.
# - Everything in component_memory_usage is already in MB, so the tolerance is
#   30 (MB); the previous 30*1024**2 bound could never be exceeded.
# - abs() catches a mismatch in either direction.
# - "Model Head" is left out on purpose: model.lm_head is not part of
#   model.model, so it is not included in the "Full Model" number.
# NOTE(review): assumes model.model holds essentially the embeddings plus the
# decoder layers (anything else being small) -- confirm against print(model).
tolerance_mb = 30
parts_mb = (
    component_memory_usage["Embedding"] + component_memory_usage["Attention Heads"]
)
full_mb = component_memory_usage["Full Model"]
assert abs(parts_mb - full_mb) < tolerance_mb, (
    f"Embedding + decoder layers = {parts_mb:.2f} MB, "
    f"but Full Model = {full_mb:.2f} MB (tolerance {tolerance_mb} MB)"
)

# The last two entries (per-layer figures scaled by the layer count) are drawn
# in red; every other bar is sky blue. One colour per bar avoids drawing the
# highlighted bars twice.
highlight_from = len(labels) - 2
bar_colors = [
    "red" if index >= highlight_from else "skyblue" for index in range(len(labels))
]

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(labels, memory_usage, color=bar_colors)
ax.set(
    xlabel="Model Components",
    ylabel="Memory Usage (MB)",
    title="Memory Usage of Different Model Components",
)
ax.tick_params(axis="x", labelrotation=45)
plt.show()