# Setup
 - Point to your weights
 - Pick a tiny input shape for vectors
 - All outputs go to ./smolvla_test_vectors

In [13]:
# %%
import os, json, math, re, pathlib, itertools
from typing import Dict, Tuple, Optional
import numpy as np
import torch
import torch.nn.functional as F
from safetensors.torch import load_file as load_safetensors

torch.set_grad_enabled(False)

# <<< EDIT IF NEEDED >>>
model_weights = "../weights/downloads/model.safetensors"  # <= your path
out_dir = "./smolvla_test_vectors"
os.makedirs(out_dir, exist_ok=True)


1) Load safetensors and index keys

In [14]:
state = load_safetensors(model_weights, device=device)


2. Helper functions for printing out and understanding model weights

In [15]:
def make_dictionary_tree(dict, split='.'):
    tree = {}

    for key, value in dict.items():
        parts = key.split(split)
        tree_key = parts[0]
        if len(parts) == 1:
            tree[tree_key] = value
            continue
        else:
            subtree_key = split.join(parts[1:])
            
            if tree_key not in tree:
                tree[tree_key] = { subtree_key: value }
            else:
                tree[tree_key][subtree_key] = value

    
    # if all items are numbers, sort them by numeric order
    if all(re.match(r'^\d+$', str(k)) for k, v in tree.items()):
        tree = { k: v for k, v in sorted(tree.items(), key=lambda x: int(x[0])) }

    for key in tree.keys():
        if isinstance(tree[key], Dict):
            tree[key] = make_dictionary_tree(tree[key], split=split)    

    return tree

def dictionary_tree_as_string(tree, prefix='', key_prefix=''):
    items = list(tree.items())
    result = ""

    for key, value in items[:-1]:
        if isinstance(value, dict):
            result += f"{prefix}├───┬ {key}\n"
            if key_prefix:
                result += f"{prefix}|   ├─ in: {key_prefix}.{key}\n"
            if isinstance(value, dict):
                result += f"{prefix}|   ├─ children: {list(value.keys())}\n"
            result += dictionary_tree_as_string(
                value, 
                f"{prefix}│   ",
                f"{key_prefix}.{key}"
            )
        else:
            result += f"{prefix}├──── {key}: {value.shape} [{value.dtype}]\n"

    last_key, last_value = items[-1]
    if isinstance(last_value, dict):
        result += f"{prefix}└───┬ {last_key}:\n"
        if key_prefix:
            result += f"{prefix}    ├─ in: {key_prefix}.{last_key}\n"
        if isinstance(last_value, dict):
            result += f"{prefix}    ├─ children: {list(last_value.keys())}\n"
        result += dictionary_tree_as_string(
            last_value, 
            f"{prefix}    ",
            f"{key_prefix}.{last_key}"
        )
    else:
        result += f"{prefix}└──── {last_key}: {last_value.shape} [{last_value.dtype}]\n"

    return result


3. prints out the actual weights as a nice dictionary

In [6]:
state_tree = make_dictionary_tree(state)

tree_string = dictionary_tree_as_string(
    state_tree
)

with open(f"{out_dir}/model_shape.txt", "w") as f:
    f.write(tree_string)

print(tree_string)


└───┬ model:
    ├─ children: ['action_in_proj', 'action_out_proj', 'action_time_mlp_in', 'action_time_mlp_out', 'state_proj', 'vlm_with_expert']
    ├───┬ action_in_proj
    |   ├─ in: .model.action_in_proj
    |   ├─ children: ['bias', 'weight']
    │   ├──── bias: torch.Size([720]) [torch.float32]
    │   └──── weight: torch.Size([720, 32]) [torch.float32]
    ├───┬ action_out_proj
    |   ├─ in: .model.action_out_proj
    |   ├─ children: ['bias', 'weight']
    │   ├──── bias: torch.Size([32]) [torch.float32]
    │   └──── weight: torch.Size([32, 720]) [torch.float32]
    ├───┬ action_time_mlp_in
    |   ├─ in: .model.action_time_mlp_in
    |   ├─ children: ['bias', 'weight']
    │   ├──── bias: torch.Size([720]) [torch.float32]
    │   └──── weight: torch.Size([720, 1440]) [torch.float32]
    ├───┬ action_time_mlp_out
    |   ├─ in: .model.action_time_mlp_out
    |   ├─ children: ['bias', 'weight']
    │   ├──── bias: torch.Size([720]) [torch.float32]
    │   └──── weight: torch.S

4. Hone in on an attention layer

In [7]:
def save_selected_layer(
    state_tree,
    layer
):
    keys = layer.split('.')
    selected_layer = state_tree
    for key in keys:
        if key not in selected_layer:
            print (f"Layer {layer} not found in state tree. (stopped at key '{key}')")
            print ("Available keys at this level:", list(selected_layer.keys()))
            return 
        
        selected_layer = selected_layer[key]

    
    selected_layer_string = dictionary_tree_as_string(
        selected_layer,
        key_prefix=layer
    )

    with open(f"{out_dir}/selected_layer_{layer.replace('.', '_')}.txt", "w") as f:
        f.write(selected_layer_string)

layers_to_save = [
    'model.vlm_with_expert.vlm.model.vision_model.encoder.layers.6',
    'model.vlm_with_expert.vlm.model.text_model.layers.3',
    'model.vlm_with_expert.lm_expert.layers.2',
]

for layer in layers_to_save:
    save_selected_layer(
        state_tree,
        layer=layer
    )


5. Get the model weights for a given layer

In [9]:
state_shapes = {}

for state_name, tensor in state.items():
    shapes = tensor.shape

    for shape in shapes:
        if shape not in state_shapes:
            state_shapes[shape] = set()
        state_shapes[shape].add(f"{state_name}: {tensor.dtype}")