In [1]:
from transformers import MllamaConfig, MllamaForConditionalGeneration
from accelerate import init_empty_weights
from torch.nn import ModuleList
from collections import OrderedDict

In [2]:
# Checkpoint identifier handed to MllamaConfig.from_pretrained in the next cell
# (presumably a Hugging Face Hub repo id -- confirm access/licence before running).
model_id = 'meta-llama/Llama-3.2-11B-Vision-Instruct'

In [3]:
# Fetch the model configuration that corresponds to `model_id`.
config = MllamaConfig.from_pretrained(model_id)

# Instantiate the architecture inside accelerate's init_empty_weights context.
# NOTE(review): per the accelerate docs this creates parameters on the "meta"
# device, so only the module structure is built and no real weights are
# allocated -- which is all the module-name walk further down needs.
with init_empty_weights():
    model = MllamaForConditionalGeneration(config)

In [4]:
def gather_named_children(mdl, result: list[str], parent: list[str]|None=None, recurse: bool=True):
    if parent is None:
        parent = []

    for n, m in mdl.named_children():
        names = list(parent) # copy
        names.append(n)

        is_list = isinstance(m, ModuleList)
        has_params = len(list(m.parameters(recurse=False))) > 0
        has_buffers = len(list(m.buffers(recurse=False))) > 0

        if has_params or has_buffers or not recurse:
            #print(f"{'.'.join(names)}")
            result.append('.'.join(names))

        if recurse:
            gather_named_children(m, result, parent=names, recurse=not is_list)

# Accumulator that gather_named_children fills in place with dotted module names.
result: list[str] = []
gather_named_children(model, result)

In [5]:
# Assign every collected module name to 0, preserving the discovery order.
# NOTE(review): 0 is presumably the device index (first GPU) -- confirm against
# wherever this device_map is consumed.
device_map = OrderedDict.fromkeys(result, 0)

# Bare expression so the notebook displays the mapping as the cell output.
device_map

OrderedDict([('vision_model', 0),
             ('vision_model.patch_embedding', 0),
             ('vision_model.gated_positional_embedding', 0),
             ('vision_model.gated_positional_embedding.tile_embedding', 0),
             ('vision_model.pre_tile_positional_embedding', 0),
             ('vision_model.pre_tile_positional_embedding.embedding', 0),
             ('vision_model.post_tile_positional_embedding', 0),
             ('vision_model.post_tile_positional_embedding.embedding', 0),
             ('vision_model.layernorm_pre', 0),
             ('vision_model.layernorm_post', 0),
             ('vision_model.transformer.layers.0', 0),
             ('vision_model.transformer.layers.1', 0),
             ('vision_model.transformer.layers.2', 0),
             ('vision_model.transformer.layers.3', 0),
             ('vision_model.transformer.layers.4', 0),
             ('vision_model.transformer.layers.5', 0),
             ('vision_model.transformer.layers.6', 0),
             ('visi