In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import transformers
from transformers import AutoModelForCausalLM 
from dotenv import load_dotenv
import gc
import torch
import pandas as pd

from baa import add_custom_name_to_linear_layers, get_memory_usage, count_parameters


In [None]:
# Load variables from a local `.env` file into the process environment.
# NOTE(review): presumably this supplies the Hugging Face access token needed to
# download the gated meta-llama checkpoints profiled below — confirm.
load_dotenv()

In [None]:
# Hugging Face Hub checkpoints to profile, listed smallest to largest
# (135M -> 3B -> 8B, per the repo names).
# NOTE(review): the meta-llama repos are gated; loading them presumably requires an
# HF token in the environment — confirm.
model_names = [
    "HuggingFaceTB/SmolLM-135M-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
]

In [None]:
# Initialise `model` so the profiling loop can release the previously loaded model
# on its very first iteration without hitting a NameError.
model = None

In [None]:
# Empty results table: one row per (model, component) measurement.
# "Memory (Bytes)" holds the raw byte count; convert to MiB only for display.
df = pd.DataFrame(
    columns=[
        "Model",
        "Component",
        "Parameters",
        "Memory (Bytes)",
    ]
)

In [None]:
def _modules_tagged(root, tag):
    """Collect the submodules of ``root`` whose ``custom_name`` contains ``tag``.

    Args:
        root: Model whose ``named_modules()`` are scanned.
        tag: Substring to look for, e.g. "self_attn", "mlp" or "lm_head".

    Returns:
        List of matching modules in ``named_modules()`` order. Modules without a
        ``custom_name`` attribute are treated as having an empty name and skipped.
    """
    return [
        module
        for _, module in root.named_modules()
        if tag in getattr(module, "custom_name", "")
    ]


# Profile each checkpoint: total parameter count and memory per component.
# Rows are collected in a plain list and turned into a DataFrame once at the end
# instead of `pd.concat`-ing inside the loop (quadratic, every row gets index 0,
# and columns stay object dtype from the empty seed frame). Rebuilding `df` here
# also makes the cell safe to re-run without appending duplicate rows.
rows = []
for model_name in model_names:
    # Release *all* references to the previous model before loading the next one.
    # `include_list`, `include` and `layer` from the previous iteration still point
    # at its modules, so `del model` alone would keep it alive (including any GPU
    # memory it holds) while the next checkpoint loads. Rebinding to None instead
    # of `del` also works when these names do not exist yet.
    model = include_list = include = layer = None
    gc.collect()
    torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    # Presumably attaches a `custom_name` attribute to the linear layers; the
    # component filters below match on it — confirm against `baa`.
    add_custom_name_to_linear_layers(model)

    include_list = {
        "Full_Model": [model],
        "Self_Attention": _modules_tagged(model, "self_attn"),
        "MLP": _modules_tagged(model, "mlp"),
        "LM_Head": _modules_tagged(model, "lm_head"),
    }

    for name, include in include_list.items():
        print(f"Model: {model_name} - {name}")
        total_params = 0
        total_memory = 0
        for layer in include:
            total_params += count_parameters(layer)
            total_memory += get_memory_usage(layer)
        print(f"Total Parameters: {total_params}")
        print(f"Total Memory: {total_memory / 1024 / 1024:.2f} MB")

        rows.append((model_name, name, total_params, total_memory))

df = pd.DataFrame(rows, columns=["Model", "Component", "Parameters", "Memory (Bytes)"])
df.to_csv("model_memory_per_component.csv", index=False)
