In [20]:
import torch
import specdecodes.models.llm.modeling_llama as modeling_llama

from hqq.models.hf.base import AutoHQQHFModel
from hqq.models.base import get_all_children_from_model, forward_device_hooked, find_parent, name_to_linear_tag, is_leaf_module
from hqq.core.quantize import *
from hqq.core.utils import cleanup
from hqq.models.base import _QUANT_LAYERS

def get_modules_by_substring(model, substring: str, ignore: list = []) -> list:
    """
    Collect `(qualified_name, module)` pairs from `model.named_modules()`.

    Only the final dot-separated component of each qualified name is
    inspected. A pair is kept when that component is not listed in `ignore`
    and contains `substring`. Every entry yielded by `named_modules()` is
    considered (no leaf/container filtering is applied), so a container module
    is returned when its own name matches.

    Args:
        model: Object exposing `named_modules()` that yields `(name, module)` pairs.
        substring: Text that must occur in the last name component.
        ignore: Last-component names to skip entirely.

    Returns:
        List of `(qualified_name, module)` tuples in `named_modules()` order.
    """
    hits = []
    for qualified_name, submodule in model.named_modules():
        # Text after the final "." (the whole name when there is no dot).
        tail = qualified_name.rpartition(".")[2]
        if tail in ignore:
            continue
        if substring in tail:
            hits.append((qualified_name, submodule))
    return hits

def get_linear_from_model(model, ignore: list = []) -> list:
    """
    Collect `(qualified_name, module)` pairs whose module type is quantizable.

    A pair from `model.named_modules()` is kept when the module's exact type
    (compared with `type(...)`, so subclasses do not match) is a member of
    `_QUANT_LAYERS` and the last dot-separated component of its name is not
    listed in `ignore`.

    Args:
        model: Object exposing `named_modules()` that yields `(name, module)` pairs.
        ignore: Last-component names to exclude from the result.

    Returns:
        List of `(qualified_name, module)` tuples in `named_modules()` order.
    """
    return [
        (qualified_name, submodule)
        for qualified_name, submodule in model.named_modules()
        if type(submodule) in _QUANT_LAYERS
        and qualified_name.rpartition(".")[2] not in ignore
    ]

In [1]:
"layers" in ["layers"]

True

In [4]:
def get_modules_by_substring(model, substring: str, ignore: list = []) -> list:
    """
    Collect `(qualified_name, module)` pairs from `model.named_modules()`.

    Only the final dot-separated component of each qualified name is
    inspected. A pair is kept when that component is not listed in `ignore`
    and contains `substring`. Every entry yielded by `named_modules()` is
    considered (no leaf/container filtering is applied), so a container module
    is returned when its own name matches.

    Args:
        model: Object exposing `named_modules()` that yields `(name, module)` pairs.
        substring: Text that must occur in the last name component.
        ignore: Last-component names to skip entirely.

    Returns:
        List of `(qualified_name, module)` tuples in `named_modules()` order.
    """
    hits = []
    for qualified_name, submodule in model.named_modules():
        # Text after the final "." (the whole name when there is no dot).
        tail = qualified_name.rpartition(".")[2]
        if tail in ignore:
            continue
        if substring in tail:
            hits.append((qualified_name, submodule))
    return hits

In [2]:
# Checkpoint identifier, weight dtype and target device used for the load below.
# NOTE(review): this cell relies on `torch` and `modeling_llama` from the import
# cell; on a fresh kernel that cell has to be run first.
model_path = 'meta-llama/Llama-2-7b-chat-hf'
dtype = torch.bfloat16
device = 'cuda:0'

# Load the causal-LM through the `modeling_llama` module imported from
# specdecodes.models.llm, placing the weights directly on `device` in `dtype`.
model = modeling_llama.LlamaForCausalLM.from_pretrained(
    model_path, 
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
    device_map=device,
    # NOTE(review): underscore-prefixed keyword; Hugging Face documents the
    # public `attn_implementation` argument for from_pretrained -- confirm this
    # spelling is intentional for the project's LlamaForCausalLM.
    _attn_implementation="sdpa",
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.24s/it]


In [21]:
# List every (name, module) pair whose module type is in _QUANT_LAYERS, with nothing ignored.
get_linear_from_model(model)

[('model.layers.0.self_attn.q_proj',
  Linear(in_features=4096, out_features=4096, bias=False)),
 ('model.layers.0.self_attn.k_proj',
  Linear(in_features=4096, out_features=4096, bias=False)),
 ('model.layers.0.self_attn.v_proj',
  Linear(in_features=4096, out_features=4096, bias=False)),
 ('model.layers.0.self_attn.o_proj',
  Linear(in_features=4096, out_features=4096, bias=False)),
 ('model.layers.0.mlp.gate_proj',
  Linear(in_features=4096, out_features=11008, bias=False)),
 ('model.layers.0.mlp.up_proj',
  Linear(in_features=4096, out_features=11008, bias=False)),
 ('model.layers.0.mlp.down_proj',
  Linear(in_features=11008, out_features=4096, bias=False)),
 ('model.layers.1.self_attn.q_proj',
  Linear(in_features=4096, out_features=4096, bias=False)),
 ('model.layers.1.self_attn.k_proj',
  Linear(in_features=4096, out_features=4096, bias=False)),
 ('model.layers.1.self_attn.v_proj',
  Linear(in_features=4096, out_features=4096, bias=False)),
 ('model.layers.1.self_attn.o_proj',
 

In [7]:
# List every (name, module) pair whose last name component contains 'self_attn'.
get_modules_by_substring(model, 'self_attn')

[('model.layers.0.self_attn',
  LlamaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  )),
 ('model.layers.1.self_attn',
  LlamaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  )),
 ('model.layers.2.self_attn',
  LlamaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  ))

In [12]:
# Check whether the first module matched by the substring 'layers' is a
# torch.nn.modules.container.ModuleList.

# [0] selects the first (name, module) pair, [1] the module object itself.
a = get_modules_by_substring(model, 'layers')[0][1]
isinstance(a, torch.nn.modules.container.ModuleList)

True

In [10]:
# Call hqq's `find_parent` helper with the bare name 'layers'; the recorded
# output for this cell shows the top-level LlamaForCausalLM being returned.
find_parent(model, 'layers')

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_e