In [7]:
from VidUtil.torch_utils import inspect_checkpoint

# Checkpoint file to inspect.
# NOTE(review): this is an absolute, machine-specific path — adjust it when running elsewhere.
checkpoint_path = "/home/denninge/CamC2V/ckpts/256_cami2v.pt"
# NOTE(review): inspect_checkpoint is a project helper that is not shown here. Later cells treat
# its result as dict-like (they call .keys() on it and index it with 'state_dict').
checkpoint_info = inspect_checkpoint(checkpoint_path, return_dict=True)

  model = torch.load(ckpt_path, map_location=map_location)


In [8]:
# Display the top-level entries of the inspection result.
checkpoint_info.keys()

dict_keys(['metadata', 'state_dict', 'totals', 'options'])

In [9]:
# Collect the key names of the checkpoint's 'state_dict' entry; they are grouped by block in a later cell.
parameter_names = list(checkpoint_info['state_dict'].keys())

In [14]:
import re
from collections import defaultdict
from pprint import pprint

def extract_input_block_layers(keys, start_idx=0, end_idx=8, root="model.diffusion_model.middle_block"):
    """
    Group parameter keys that live under ``{root}.{i}`` for every block index
    ``i`` in ``[start_idx, end_idx]`` (both ends inclusive).

    Despite the function name, the blocks that are scanned are chosen solely by
    ``root``. Its default is "model.diffusion_model.middle_block"; pass
    ``root="model.diffusion_model.input_blocks"`` to scan the input blocks.

    Args:
        keys: iterable of parameter-name strings (e.g. state-dict keys).
        start_idx: first block index to collect (inclusive).
        end_idx: last block index to collect (inclusive).
        root: dotted prefix that must be followed by ".{i}.{submodule}".

    Returns a nested dict with one entry per index in the requested range
    (indices without any matching key still get an empty entry):
    {
      i: {
        "all": [full_param_key, ...],                      # all raw keys for this block, in input order
        "by_submodule": { submodule: [keys...], ... },     # grouped by the token right after {root}.i
        "unique_submodules": {submodule, ...}              # convenience set
      },
      ...
    }
    """
    # Start of string, root, dot, (i), dot, capture next token (no dot), then the rest.
    # Example match groups for "<root>.2.0.emb_layers.1.bias":
    #   block_i = "2", submodule = "0", remainder = "emb_layers.1.bias"
    pattern = re.compile(rf"^{re.escape(root)}\.(\d+)\.([^\.]+)(?:\.(.*))?$")

    # One container per requested index, created up front so that callers can
    # index every i in the range even when nothing matched.
    results = {
        i: {
            "all": [],
            "by_submodule": defaultdict(list),
            "unique_submodules": set(),
        }
        for i in range(start_idx, end_idx + 1)
    }

    for k in keys:
        m = pattern.match(k)
        if not m:
            continue

        # The index group is r"\d+", so int() cannot fail here.
        block_i = int(m.group(1))
        submodule = m.group(2)

        # Indices outside [start_idx, end_idx] have no container and are skipped.
        entry = results.get(block_i)
        if entry is None:
            continue

        entry["all"].append(k)
        entry["by_submodule"][submodule].append(k)
        entry["unique_submodules"].add(submodule)

    # Turn defaultdicts into plain dicts so that later lookups of unknown
    # submodules raise KeyError instead of silently creating entries.
    for entry in results.values():
        entry["by_submodule"] = dict(entry["by_submodule"])

    return results

# --------- Example usage ----------
# Keep the scanned prefix in one place so that the printed header, the output
# file name and the keys that are actually collected cannot drift apart.
root = "model.diffusion_model.middle_block"
block_label = root.rsplit(".", 1)[-1]  # last dotted component, here "middle_block"

blocks = extract_input_block_layers(parameter_names, start_idx=0, end_idx=8, root=root)

# Print a concise summary and dump each block's keys to its own text file.
# Iterating over the returned dict keeps the loop in sync with start_idx/end_idx.
for i, block in blocks.items():
    print(f"\n=== {block_label}.{i} ===")
    print(f"Total params: {len(block['all'])}")
    with open(f"../model_architecture/{block_label}_{i}.txt", "w") as f:
        f.writelines(k + "\n" for k in block["all"])
    print("Submodules:", sorted(block["unique_submodules"]))


=== input_blocks.0 ===
Total params: 26
Submodules: ['emb_layers', 'in_layers', 'out_layers', 'temopral_conv']

=== input_blocks.1 ===
Total params: 29
Submodules: ['norm', 'proj_in', 'proj_out', 'transformer_blocks']

=== input_blocks.2 ===
Total params: 34
Submodules: ['norm', 'proj_in', 'proj_out', 'transformer_blocks']

=== input_blocks.3 ===
Total params: 26
Submodules: ['emb_layers', 'in_layers', 'out_layers', 'temopral_conv']

=== input_blocks.4 ===
Total params: 0
Submodules: []

=== input_blocks.5 ===
Total params: 0
Submodules: []

=== input_blocks.6 ===
Total params: 0
Submodules: []

=== input_blocks.7 ===
Total params: 0
Submodules: []

=== input_blocks.8 ===
Total params: 0
Submodules: []
