In [4]:
import os
import sys

# Restrict this notebook to one GPU. Set this before torch is imported / CUDA is
# initialized; changing it later in the session has no effect on device visibility.
CUDA_DEVICE = "3"
os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_DEVICE

# Make the parent of the notebook's working directory importable so `src.*` resolves.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
if not sys.path or sys.path[0] != parent_dir:
    # Guarded so re-running this cell does not keep prepending duplicate entries.
    sys.path.insert(0, parent_dir)

import torch
from src.modeling_paligemma import PaliGemmaForConditionalGeneration

model_id = "google/paligemma2-3b-mix-448"
# Load the checkpoint in bfloat16 and switch to inference mode.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval()



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# Print every parameter name of the loaded model; these are the strings that
# `lora_filter` matches against.
for param_name, _ in model.named_parameters():
    print(param_name)

vision_tower.vision_model.embeddings.patch_embedding.weight
vision_tower.vision_model.embeddings.patch_embedding.bias
vision_tower.vision_model.embeddings.position_embedding.weight
vision_tower.vision_model.encoder.layers.0.layer_norm1.weight
vision_tower.vision_model.encoder.layers.0.layer_norm1.bias
vision_tower.vision_model.encoder.layers.0.self_attn.k_proj.weight
vision_tower.vision_model.encoder.layers.0.self_attn.k_proj.bias
vision_tower.vision_model.encoder.layers.0.self_attn.v_proj.weight
vision_tower.vision_model.encoder.layers.0.self_attn.v_proj.bias
vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight
vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.bias
vision_tower.vision_model.encoder.layers.0.self_attn.out_proj.weight
vision_tower.vision_model.encoder.layers.0.self_attn.out_proj.bias
vision_tower.vision_model.encoder.layers.0.layer_norm2.weight
vision_tower.vision_model.encoder.layers.0.layer_norm2.bias
vision_tower.vision_model.encoder.layers.0.

In [33]:
def lora_filter(model, layers="all", layer_types=None, include_vision=True, include_language=True, vision_layers=None, language_layers=None):
    """
    Filters and formats parameter names for LoRA application from a PyTorch model.

    Parameters are visited in ``model.named_parameters()`` order. Bias
    parameters and normalization weights are always skipped. Each kept name
    has its trailing ``.weight`` removed so it names the owning module
    (e.g. ``vision_tower.vision_model.embeddings.patch_embedding``).

    Args:
        model: The PyTorch model. Only ``model.named_parameters()`` is used.
        layers (str or list, optional): Global layer selection.
            If "all", every parameter of each included tower is considered,
            including non-layer parameters such as embeddings.
            If a list (or tuple/set/range) of ints, only parameters under
            ``.layers.<i>.`` for those indices are considered, in both towers,
            unless a tower-specific list below is given. Any other value
            selects no tower parameters. Defaults to "all".
        layer_types (list, optional): A list of layer types to include.
            Supported types: "self_attn", "mlp", "embeddings", "projector",
            "q_proj", "k_proj", "v_proj", "o_proj" ("o_proj" selects the
            vision tower's ``out_proj``). A single string is treated as a
            one-element list. If None or empty, no type filtering is applied
            to the towers (and the projector is not included).
            Defaults to None.
        include_vision (bool, optional): Whether to include parameters from the
            vision tower (names containing "vision_tower"). Defaults to True.
        include_language (bool, optional): Whether to include parameters from the
            language model (names containing "language_model"). Defaults to True.
        vision_layers (list, optional): Layer indices for the vision tower.
            If provided, overrides ``layers`` for the vision tower, including
            when ``layers == "all"``. Indices that do not exist in the model
            simply match nothing. Defaults to None.
        language_layers (list, optional): Same as ``vision_layers``, for the
            language model. Defaults to None.

    Returns:
        list: A list of filtered and formatted parameter names, in the
              order they appear in the model's named_parameters. Projector
              weights ("multi_modal_projector") are added whenever
              "projector" is requested, independent of the layer selection
              and of the include_vision/include_language flags.
    """
    selected_types = _normalize_layer_types(layer_types)
    include_projector = "projector" in selected_types

    # None means "every parameter of the tower"; otherwise a set of layer indices.
    # No hard-coded layer counts: an index absent from the model matches no name.
    vision_selection = _resolve_layer_selection(layers, vision_layers)
    language_selection = _resolve_layer_selection(layers, language_layers)

    filtered_names = []
    for name, _ in model.named_parameters():
        if ".bias" in name or _is_norm_param(name):
            continue

        if include_vision and "vision_tower" in name:
            keep = _in_selected_layers(name, vision_selection) and _matches_layer_types(
                name, selected_types, embed_key="embeddings", out_proj_key="out_proj"
            )
        elif include_language and "language_model" in name:
            keep = _in_selected_layers(name, language_selection) and _matches_layer_types(
                name, selected_types, embed_key="embed_tokens", out_proj_key="o_proj"
            )
        else:
            keep = False

        if keep:
            filtered_names.append(_module_name(name))

        if include_projector and "multi_modal_projector" in name:
            filtered_names.append(_module_name(name))

    return filtered_names


def _normalize_layer_types(layer_types):
    """Return the requested layer types as a set; an empty set means 'no type filter'."""
    if not layer_types:
        return set()
    if isinstance(layer_types, str):
        return {layer_types}
    return set(layer_types)


def _resolve_layer_selection(layers, tower_layers):
    """Resolve one tower's layer selection: None selects everything, else a set of indices."""
    if tower_layers is not None:
        # A tower-specific list always wins, even over layers == "all".
        return {idx for idx in tower_layers if isinstance(idx, int)}
    if isinstance(layers, str) and layers == "all":
        return None
    if isinstance(layers, (list, tuple, set, range)):
        return {idx for idx in layers if isinstance(idx, int)}
    return set()


def _in_selected_layers(name, selection):
    """True if `name` passes a tower's layer selection (None accepts every parameter)."""
    if selection is None:
        return True
    # The trailing '.' keeps index 1 from matching layers.10, layers.11, ...
    return any(f".layers.{idx}." in name for idx in selection)


def _matches_layer_types(name, selected_types, embed_key, out_proj_key):
    """True if `name` matches any requested layer type, using this tower's module naming."""
    if not selected_types:
        return True
    if "mlp" in selected_types and "mlp" in name:
        return True
    if "embeddings" in selected_types and embed_key in name:
        return True
    if "self_attn" not in name:
        return False
    # Map the generic projection type names onto this tower's module names.
    proj_modules = {"q_proj": "q_proj", "k_proj": "k_proj", "v_proj": "v_proj", "o_proj": out_proj_key}
    if "self_attn" in selected_types:
        wanted = list(proj_modules.values())
    else:
        wanted = [proj_modules[t] for t in selected_types if t in proj_modules]
    return any(module in name for module in wanted)


def _is_norm_param(name):
    """True for normalization weights, which are never returned as LoRA targets."""
    lowered = name.lower()
    # "layer_norm" is the SigLIP encoder spelling seen in this model; "layernorm" and a
    # final "norm.weight" are assumed from HF Gemma/SigLIP naming -- TODO confirm
    # against src.modeling_paligemma.
    return "layer_norm" in lowered or "layernorm" in lowered or lowered.endswith("norm.weight")


def _module_name(param_name):
    """Strip only a trailing '.weight' so the name refers to the owning module."""
    suffix = ".weight"
    return param_name[: -len(suffix)] if param_name.endswith(suffix) else param_name

In [36]:
# Select only embedding weights across both towers: the vision tower's patch and
# position embeddings plus the language model's token embeddings.
a = lora_filter(
    model,
    layers="all",
    layer_types=["embeddings"],
    include_vision=True,
    include_language=True,
)
print(a)

['vision_tower.vision_model.embeddings.patch_embedding', 'vision_tower.vision_model.embeddings.position_embedding', 'language_model.model.embed_tokens']
