In [5]:
# %%
import os, json, math, re, pathlib, itertools
from typing import Dict, Tuple, Optional
import numpy as np
import torch
import torch.nn.functional as F
from safetensors.torch import load_file as load_safetensors

# Disable autograd globally: the visible cells only load, inspect and export
# weights, so no gradients are needed.
torch.set_grad_enabled(False)

# <<< EDIT IF NEEDED >>>
# Path to the safetensors checkpoint to read (relative to the notebook's
# working directory).
model_weights = "../weights/downloads/model.safetensors" 
# Root directory that receives the exported .npy files; created up front so
# later cells can write into it.
out_dir = "./data"
os.makedirs(out_dir, exist_ok=True)


1. Load the safetensors weights (a flat dict keyed by parameter name)

In [6]:
# Read every tensor in the checkpoint into a flat mapping of
# parameter name -> tensor (names are joined with '.', which the helper
# cells below split on to rebuild the hierarchy).
state = load_safetensors(model_weights)


2. Helper functions for printing out and understanding model weights

In [7]:
def make_dictionary_tree(dict, split='.'):
    """Turn a flat mapping with separator-joined keys into a nested dict tree.

    Each key is split on ``split``; the first component becomes a key of the
    returned tree and the remainder is expanded recursively, so
    ``{"a.b.c": x, "a.d": y}`` becomes ``{"a": {"b": {"c": x}, "d": y}}``.
    Values that are themselves dicts are expanded recursively as subtrees.

    At every level where *all* keys are decimal integers (e.g. layer indices)
    the entries are ordered numerically instead of by insertion order, so
    ``"2"`` sorts before ``"10"``.

    Args:
        dict: Flat mapping of string keys to leaf values. (The name shadows
            the builtin; it is kept for backward compatibility.)
        split: Separator between hierarchy levels in the keys.

    Returns:
        A new nested dict; leaves are the original values.

    Raises:
        ValueError: If a key is used both as a leaf and as a prefix of other
            keys (e.g. ``"a"`` and ``"a.b"``), which cannot be represented in
            a single tree. Keys in the message are relative to the level at
            which the conflict was found.
    """
    # NOTE: the builtin `dict` is shadowed by the parameter inside this
    # function, hence `typing.Dict` in the isinstance checks below.
    tree = {}

    for key, value in dict.items():
        # partition() splits on the first separator only; deeper levels are
        # handled by the recursive call on the subtree.
        head, separator, remainder = key.partition(split)

        if not separator:
            if head in tree:
                # A subtree was already started for this name; overwriting it
                # would silently drop every key nested under it.
                raise ValueError(
                    f"key {key!r} is both a leaf and a prefix of other keys"
                )
            tree[head] = value
            continue

        subtree = tree.setdefault(head, {})
        if not isinstance(subtree, Dict):
            raise ValueError(
                f"key {key!r} nests under {head!r}, which is already a leaf value"
            )
        subtree[remainder] = value

    # if all items are numbers, sort them by numeric order
    if all(re.match(r'^\d+$', str(k)) for k in tree):
        tree = {k: tree[k] for k in sorted(tree, key=int)}

    # Expand the still-flat subtrees (their keys are the remainders collected
    # above). Only existing keys are reassigned, so iterating is safe.
    for key, value in tree.items():
        if isinstance(value, Dict):
            tree[key] = make_dictionary_tree(value, split=split)

    return tree

def dictionary_tree_as_string(tree, prefix='', key_prefix=''):
    """Render a nested dict tree as a box-drawing text outline.

    Branch nodes (dict values) are printed as a header line followed by an
    ``in:`` line with the node's dotted path (omitted for top-level nodes,
    where the path is just the key), a ``children:`` line listing the child
    keys, and then the recursively rendered children. Leaf nodes are printed
    as ``key: <shape> [<dtype>]`` using the value's ``shape`` and ``dtype``
    attributes.

    Args:
        tree: Nested dict; leaves must expose ``.shape`` and ``.dtype``.
        prefix: Text prepended to every emitted line (the guide lines
            inherited from the ancestors).
        key_prefix: Dotted path of the node that owns ``tree``; empty at the
            top level.

    Returns:
        The rendered outline, one node per line, each line terminated by a
        newline. An empty tree renders as the empty string.
    """
    if not tree:
        # Nothing to draw; also avoids indexing the "last" item of nothing.
        return ""

    chunks = []
    last_index = len(tree) - 1

    for index, (key, value) in enumerate(tree.items()):
        is_last = index == last_index
        corner = "└" if is_last else "├"

        if not isinstance(value, dict):
            chunks.append(f"{prefix}{corner}──── {key}: {value.shape} [{value.dtype}]\n")
            continue

        # Join with '.' only when there is a parent path, so nested paths do
        # not start with a stray leading dot.
        path = f"{key_prefix}.{key}" if key_prefix else str(key)
        # Below the last sibling there is nothing left to connect to, so its
        # children are indented with blanks instead of a vertical guide.
        child_prefix = f"{prefix}    " if is_last else f"{prefix}│   "

        chunks.append(f"{prefix}{corner}───┬ {key}\n")
        if key_prefix:
            chunks.append(f"{child_prefix}├─ in: {path}\n")
        chunks.append(f"{child_prefix}├─ children: {list(value.keys())}\n")
        chunks.append(dictionary_tree_as_string(value, child_prefix, path))

    return "".join(chunks)


4. Dump every weight tensor to a `.npy` file, mirroring the key hierarchy as directories

In [8]:
def save_npy_dump(
    state_tree,
    out_dir: str,
):
    """Write every leaf tensor of a nested tree to disk as a ``.npy`` file.

    Dict values are treated as sub-directories: the function recurses into
    them with ``out_dir/<key>`` as the new target directory. Every other
    value is saved as ``out_dir/<key>.npy`` (any ``.`` in the key is replaced
    by ``_``) after being cast to float32 and moved to the CPU.

    Args:
        state_tree: Nested dict whose leaves are torch tensors.
        out_dir: Directory that receives the files for this tree level; it
            is created on demand when the level contains at least one leaf.

    Returns:
        None. The function is used for its side effect of writing files.
    """
    for name, node in state_tree.items():
        if not isinstance(node, dict):
            # Leaf: make sure the target directory exists, then save.
            os.makedirs(out_dir, exist_ok=True)
            target = os.path.join(out_dir, name.replace('.', '_') + ".npy")
            # Everything is stored as float32 — presumably so dtypes NumPy
            # cannot represent (e.g. bfloat16) can still be exported.
            np.save(target, node.to(torch.float32).cpu().numpy())
            continue

        # Branch: descend, mirroring the key as a sub-directory.
        save_npy_dump(node, os.path.join(out_dir, name))


In [9]:
# Rebuild the key hierarchy from the flat state dict and export every tensor
# under `out_dir`, one .npy file per parameter.
save_npy_dump(
    make_dictionary_tree(state),
    out_dir
)
