In [1]:
import networkx as nx

In [2]:
from networkx_reverse_topological_sort import (
    reverse_topological_generations,
    reverse_topological_sort,
)

In [3]:
from gpt_oss_simplified import GptOssForCausalLM

In [4]:
# Instantiate the model whose module tree is inspected in the cells below
# (named_modules() and the constructor hyperparameters of its leaf modules).
# NOTE(review): whether this allocates/initializes real weights depends on
# gpt_oss_simplified — confirm. The visible cells only read structure, so building
# it under `with torch.device('meta'):` might be cheaper if the class supports it.
model = GptOssForCausalLM()

In [5]:
# Mirror the module hierarchy as a directed graph: one node per dotted module path
# (the root model itself is ''), plus an edge from every parent path to each child path.
# Alongside, index modules by path, types by path, and modules by type.
submodule_path_graph = nx.DiGraph()
submodule_paths_to_sumbodules = {}  # (sic) spelling kept: later cells read this exact name
submodule_paths_to_submodule_types = {}
submodule_types_to_submodules = {}

for submodule_path, submodule in model.named_modules():
    submodule_path_graph.add_node(submodule_path)
    if submodule_path:
        # rpartition returns '' as the head for a top-level child such as 'model',
        # so every non-root path links to its immediate parent with one rule.
        parent_path = submodule_path.rpartition('.')[0]
        submodule_path_graph.add_edge(parent_path, submodule_path)

    module_cls = type(submodule)
    submodule_paths_to_sumbodules[submodule_path] = submodule
    submodule_paths_to_submodule_types[submodule_path] = module_cls
    submodule_types_to_submodules.setdefault(module_cls, []).append(submodule)

In [6]:
# reverse_topological_generations is a project helper; judging by its name and by the
# output below (only leaf paths such as '...q_proj'), its first group appears to be the
# leaf modules — confirm in networkx_reverse_topological_sort.
path_generations = reverse_topological_generations(submodule_path_graph)
first_generation = next(path_generations)
first_generation

['model.embed_tokens',
 'model.layers.0.self_attn.q_proj',
 'model.layers.0.self_attn.k_proj',
 'model.layers.0.self_attn.v_proj',
 'model.layers.0.self_attn.o_proj',
 'model.layers.0.mlp.router',
 'model.layers.0.mlp.experts',
 'model.layers.0.input_layernorm',
 'model.layers.0.post_attention_layernorm',
 'model.layers.1.self_attn.q_proj',
 'model.layers.1.self_attn.k_proj',
 'model.layers.1.self_attn.v_proj',
 'model.layers.1.self_attn.o_proj',
 'model.layers.1.mlp.router',
 'model.layers.1.mlp.experts',
 'model.layers.1.input_layernorm',
 'model.layers.1.post_attention_layernorm',
 'model.layers.2.self_attn.q_proj',
 'model.layers.2.self_attn.k_proj',
 'model.layers.2.self_attn.v_proj',
 'model.layers.2.self_attn.o_proj',
 'model.layers.2.mlp.router',
 'model.layers.2.mlp.experts',
 'model.layers.2.input_layernorm',
 'model.layers.2.post_attention_layernorm',
 'model.layers.3.self_attn.q_proj',
 'model.layers.3.self_attn.k_proj',
 'model.layers.3.self_attn.v_proj',
 'model.layers.3.

In [7]:
# Distinct module classes among the first-generation paths (a set, so each type appears once).
leaf_types = {submodule_paths_to_submodule_types[leaf_path] for leaf_path in first_generation}
leaf_types

{gpt_oss_simplified.GptOssExperts,
 gpt_oss_simplified.GptOssRMSNorm,
 gpt_oss_simplified.GptOssRotaryEmbedding,
 gpt_oss_simplified.GptOssTopKRouter,
 torch.nn.modules.linear.Linear,
 torch.nn.modules.sparse.Embedding}

In [8]:
# Print the first-generation modules whose classes are defined by PyTorch itself
# (rather than by gpt_oss_simplified); their repr shows the constructor arguments.
for leaf_path in first_generation:
    leaf_type = submodule_paths_to_submodule_types[leaf_path]
    # Compare the top-level package exactly: a bare startswith('torch') would also
    # accept third-party packages such as torchvision or torch_geometric.
    if leaf_type.__module__.partition('.')[0] == 'torch':
        print(submodule_paths_to_sumbodules[leaf_path])

Embedding(201088, 2880, padding_idx=199999)
Linear(in_features=2880, out_features=4096, bias=True)
Linear(in_features=2880, out_features=512, bias=True)
Linear(in_features=2880, out_features=512, bias=True)
Linear(in_features=4096, out_features=2880, bias=True)
Linear(in_features=2880, out_features=4096, bias=True)
Linear(in_features=2880, out_features=512, bias=True)
Linear(in_features=2880, out_features=512, bias=True)
Linear(in_features=4096, out_features=2880, bias=True)
Linear(in_features=2880, out_features=4096, bias=True)
Linear(in_features=2880, out_features=512, bias=True)
Linear(in_features=2880, out_features=512, bias=True)
Linear(in_features=4096, out_features=2880, bias=True)
Linear(in_features=2880, out_features=4096, bias=True)
Linear(in_features=2880, out_features=512, bias=True)
Linear(in_features=2880, out_features=512, bias=True)
Linear(in_features=4096, out_features=2880, bias=True)
Linear(in_features=2880, out_features=4096, bias=True)
Linear(in_features=2880, out_

In [9]:
import torch

In [10]:
def _attribute_getter(attribute_name):
    """Return a callable that reads ``attribute_name`` off a submodule instance."""
    return lambda submodule: getattr(submodule, attribute_name)


# For each supported PyTorch leaf-module type: constructor argument name -> callable
# that recovers that argument's value from an existing instance.
first_level_submodule_types_to_parameter_names_to_getters = {
    torch.nn.Embedding: {
        attribute_name: _attribute_getter(attribute_name)
        for attribute_name in ('num_embeddings', 'embedding_dim', 'padding_idx')
    },
    torch.nn.Linear: {
        'in_features': _attribute_getter('in_features'),
        'out_features': _attribute_getter('out_features'),
        # Linear keeps `bias` as a parameter tensor or None; the constructor flag is a bool.
        'bias': lambda submodule: submodule.bias is not None,
    },
}

In [11]:
# For every first-generation module whose class is defined by PyTorch, record the values
# of its constructor arguments (keys as registered in
# first_level_submodule_types_to_parameter_names_to_getters), keyed by module path.
first_level_submodule_paths_to_parameter_dicts = {}
for leaf_path in first_generation:
    leaf_type = submodule_paths_to_submodule_types[leaf_path]
    # Exact top-level package check; startswith('torch') would also accept torchvision etc.
    if leaf_type.__module__.partition('.')[0] != 'torch':
        continue
    try:
        parameter_names_to_getters = first_level_submodule_types_to_parameter_names_to_getters[leaf_type]
    except KeyError:
        # Fail loudly, but say which module and type need a getter entry.
        raise KeyError(
            f'No parameter getters registered for {leaf_type.__module__}.{leaf_type.__qualname__} '
            f'(module {leaf_path!r}); add an entry to '
            'first_level_submodule_types_to_parameter_names_to_getters.'
        ) from None
    leaf_module = submodule_paths_to_sumbodules[leaf_path]
    first_level_submodule_paths_to_parameter_dicts[leaf_path] = {
        parameter_name: getter(leaf_module)
        for parameter_name, getter in parameter_names_to_getters.items()
    }

In [12]:
import json

In [13]:
# Pretty-print the collected constructor arguments as 2-space-indented JSON.
parameter_dicts_json = json.dumps(first_level_submodule_paths_to_parameter_dicts, indent=2)
print(parameter_dicts_json)

{
  "model.embed_tokens": {
    "num_embeddings": 201088,
    "embedding_dim": 2880,
    "padding_idx": 199999
  },
  "model.layers.0.self_attn.q_proj": {
    "in_features": 2880,
    "out_features": 4096,
    "bias": true
  },
  "model.layers.0.self_attn.k_proj": {
    "in_features": 2880,
    "out_features": 512,
    "bias": true
  },
  "model.layers.0.self_attn.v_proj": {
    "in_features": 2880,
    "out_features": 512,
    "bias": true
  },
  "model.layers.0.self_attn.o_proj": {
    "in_features": 4096,
    "out_features": 2880,
    "bias": true
  },
  "model.layers.1.self_attn.q_proj": {
    "in_features": 2880,
    "out_features": 4096,
    "bias": true
  },
  "model.layers.1.self_attn.k_proj": {
    "in_features": 2880,
    "out_features": 512,
    "bias": true
  },
  "model.layers.1.self_attn.v_proj": {
    "in_features": 2880,
    "out_features": 512,
    "bias": true
  },
  "model.layers.1.self_attn.o_proj": {
    "in_features": 4096,
    "out_features": 2880,
    "bias": tr