In [1]:
import os.path
from glob import glob

import torch

In [2]:
from gpt_oss_simplified import *

In [3]:
from safetensors_layer_grabber import extract_layer_state_dict

In [4]:
from summary_stats import *

In [5]:
# Local checkpoint directory; '~' is expanded to the current user's home directory.
# Later cells glob this directory for *.safetensors shards and load the tokenizer from it.
MODEL_DIRECTORY_PATH = os.path.expanduser('~/models/gpt-oss-20b/')

In [6]:
# Collect every safetensors shard in the model directory.
# glob() returns matches in arbitrary order, so sort to make the shard order
# (and anything downstream that depends on it) deterministic across runs.
safetensors_file_names = sorted(glob(os.path.join(MODEL_DIRECTORY_PATH, '*.safetensors')))
if not safetensors_file_names:
    # Fail here with a clear message instead of letting a later load_state_dict
    # call complain about missing keys when the directory is wrong or empty.
    raise FileNotFoundError(f'No .safetensors files found in {MODEL_DIRECTORY_PATH}')

------

In [None]:
# Build the top-k router module with its default constructor arguments;
# its weights are loaded in the next cell.
router = GptOssTopKRouter()

In [None]:
# Load the checkpoint entries for layer 7's MLP router into the module.
# extract_layer_state_dict presumably returns only that layer's tensors with keys
# the module accepts — confirm in safetensors_layer_grabber.
# The state dict is passed inline, so no notebook variable keeps it alive afterwards.
router.load_state_dict(
    extract_layer_state_dict(
        safetensors_file_names=safetensors_file_names,
        layer_name='model.layers.7.mlp.router'
    )
)

In [None]:
# Load the saved fixture for the router and list its keys. The cells below read
# 'hidden_states' and 'return' from it — presumably a captured input and the
# reference output of the original implementation (TODO confirm provenance).
# NOTE(review): torch.load is pickle-based, so only open fixtures from a trusted source.
# The path is relative, so it resolves against the notebook's working directory.
parameters = torch.load('model.layers.7.mlp.router.pt')
parameters.keys()

In [None]:
# Run the freshly loaded router on the recorded hidden states.
output = router(parameters['hidden_states'])

In [None]:
# Compare the router output with the fixture's 'return' entry (presumably the
# reference module's output — verify). See summary_stats for what is reported.
summary_stats(output, parameters['return'])

------

In [None]:
# Build the RMSNorm module with its default constructor arguments;
# its weights are loaded in the next cell.
post_attention_layernorm = GptOssRMSNorm()

In [None]:
# Load the checkpoint entries for layer 23's post-attention layernorm into the module.
# extract_layer_state_dict presumably strips the layer prefix from the keys —
# confirm in safetensors_layer_grabber.
# The state dict is passed inline, so no notebook variable keeps it alive afterwards.
post_attention_layernorm.load_state_dict(
    extract_layer_state_dict(
        safetensors_file_names=safetensors_file_names,
        layer_name='model.layers.23.post_attention_layernorm'
    )
)

In [None]:
# Load the saved fixture for this layernorm and list its keys. The cells below read
# 'hidden_states' and 'return' — presumably a captured input and reference output
# (TODO confirm provenance).
# NOTE(review): torch.load is pickle-based, so only open fixtures from a trusted source.
# The path is relative, so it resolves against the notebook's working directory.
parameters = torch.load('model.layers.23.post_attention_layernorm.pt')
parameters.keys()

In [None]:
# Apply the loaded layernorm to the recorded hidden states.
output = post_attention_layernorm(parameters['hidden_states'])

In [None]:
# Compare the layernorm output with the fixture's 'return' entry (presumably the
# reference output — verify). See summary_stats for what is reported.
summary_stats(output, parameters['return'])

------

In [None]:
# Build the rotary-embedding module with its default constructor arguments.
# Unlike the other modules checked in this notebook, no checkpoint weights are loaded into it.
rotary_embedding = GptOssRotaryEmbedding()

In [None]:
# Load the saved fixture for the rotary embedding and list its keys. The cells below
# read 'x', 'position_ids' and 'return' — presumably captured inputs and the
# reference output (TODO confirm provenance).
# NOTE(review): torch.load is pickle-based, so only open fixtures from a trusted source.
# The path is relative, so it resolves against the notebook's working directory.
parameters = torch.load('model.rotary_emb.pt')
parameters.keys()

In [None]:
# Call the rotary embedding on the recorded inputs. Both arguments are positional,
# so their order must match the module's forward signature.
output = rotary_embedding(parameters['x'], parameters['position_ids'])

In [None]:
# Compare the rotary-embedding output with the fixture's 'return' entry (presumably
# the reference output — verify). See summary_stats for what is reported.
summary_stats(output, parameters['return'])

------

In [None]:
# Build the experts module with its default constructor arguments;
# its weights are loaded in the next cell.
experts = GptOssExperts()

In [None]:
# Load the checkpoint entries for layer 4's MLP experts into the module.
# extract_layer_state_dict presumably strips the layer prefix from the keys —
# confirm in safetensors_layer_grabber.
# The state dict is passed inline, so no notebook variable keeps it alive afterwards.
experts.load_state_dict(
    extract_layer_state_dict(
        safetensors_file_names=safetensors_file_names,
        layer_name='model.layers.4.mlp.experts'
    )
)

In [None]:
# Load the saved fixture for the experts and list its keys. The cells below read
# 'hidden_states', 'router_indices', 'routing_weights' and 'return' — presumably
# captured inputs and the reference output (TODO confirm provenance).
# NOTE(review): torch.load is pickle-based, so only open fixtures from a trusted source.
# The path is relative, so it resolves against the notebook's working directory.
parameters = torch.load('model.layers.4.mlp.experts.pt')
parameters.keys()

In [None]:
# Run the experts on the recorded inputs. All three arguments are positional,
# so their order must match the module's forward signature.
output = experts(parameters['hidden_states'], parameters['router_indices'], parameters['routing_weights'])

In [None]:
# Compare the experts output with the fixture's 'return' entry (presumably the
# reference output — verify). See summary_stats for what is reported.
summary_stats(output, parameters['return'])

---------

In [None]:
# Build the attention module with its default constructor arguments;
# its weights are loaded in the next cell.
self_attn = GptOssAttention()

In [None]:
# Load the checkpoint entries for layer 8's self-attention into the module.
# extract_layer_state_dict presumably strips the layer prefix from the keys —
# confirm in safetensors_layer_grabber.
# The state dict is passed inline, so no notebook variable keeps it alive afterwards.
self_attn.load_state_dict(
    extract_layer_state_dict(
        safetensors_file_names=safetensors_file_names,
        layer_name='model.layers.8.self_attn'
    )
)

In [None]:
# Load the saved fixture for self-attention and list its keys. The cells below read
# 'hidden_states', 'attention_mask', 'position_embeddings' and 'return' — presumably
# captured inputs and the reference output (TODO confirm provenance).
# NOTE(review): torch.load is pickle-based, so only open fixtures from a trusted source.
# The path is relative, so it resolves against the notebook's working directory.
parameters = torch.load('model.layers.8.self_attn.pt')
parameters.keys()

In [None]:
# Run attention on the recorded inputs. All three arguments are positional,
# so their order must match the module's forward signature.
output = self_attn(parameters['hidden_states'], parameters['attention_mask'], parameters['position_embeddings'])

In [None]:
# Compare the attention output with the fixture's 'return' entry (presumably the
# reference output — verify). See summary_stats for what is reported.
summary_stats(output, parameters['return'])

---------

In [None]:
# Build the MLP module with its default constructor arguments;
# its weights are loaded in the next cell.
mlp = GptOssMLP()

In [None]:
# Load the checkpoint entries for the whole of layer 7's MLP into the module.
# extract_layer_state_dict presumably strips the layer prefix from the keys —
# confirm in safetensors_layer_grabber.
# The state dict is passed inline, so no notebook variable keeps it alive afterwards.
mlp.load_state_dict(
    extract_layer_state_dict(
        safetensors_file_names=safetensors_file_names,
        layer_name='model.layers.7.mlp'
    )
)

In [None]:
# Load the saved fixture for the MLP and list its keys. The cells below read
# 'hidden_states' and 'return' — presumably a captured input and the reference
# output (TODO confirm provenance).
# NOTE(review): torch.load is pickle-based, so only open fixtures from a trusted source.
# The path is relative, so it resolves against the notebook's working directory.
parameters = torch.load('model.layers.7.mlp.pt')
parameters.keys()

In [None]:
# Run the loaded MLP on the recorded hidden states.
output = mlp(parameters['hidden_states'])

In [None]:
# Compare the MLP output with the fixture's 'return' entry (presumably the
# reference output — verify). See summary_stats for what is reported.
summary_stats(output, parameters['return'])

------

In [None]:
# Build a full decoder layer with its default constructor arguments;
# its weights are loaded in the next cell.
layers_23 = GptOssDecoderLayer()

In [None]:
# Load every checkpoint entry under decoder layer 23 into the module.
# extract_layer_state_dict presumably strips the layer prefix from the keys —
# confirm in safetensors_layer_grabber.
# The state dict is passed inline, so no notebook variable keeps it alive afterwards.
layers_23.load_state_dict(
    extract_layer_state_dict(
        safetensors_file_names=safetensors_file_names,
        layer_name='model.layers.23'
    )
)

In [None]:
# Load the saved fixture for decoder layer 23 and list its keys. The cells below read
# 'hidden_states', 'attention_mask', 'position_embeddings' and 'return' — presumably
# captured inputs and the reference output (TODO confirm provenance).
# NOTE(review): torch.load is pickle-based, so only open fixtures from a trusted source.
# The path is relative, so it resolves against the notebook's working directory.
parameters = torch.load('model.layers.23.pt')
parameters.keys()

In [None]:
# Run the decoder layer on the recorded inputs. All three arguments are positional,
# so their order must match the module's forward signature.
output = layers_23(parameters['hidden_states'], parameters['attention_mask'], parameters['position_embeddings'])

In [None]:
# Compare the decoder-layer output with the fixture's 'return' entry (presumably the
# reference output — verify). See summary_stats for what is reported.
summary_stats(output, parameters['return'])

------

In [7]:
# Build the full causal-LM model with its default constructor arguments;
# its weights are loaded in the next cell.
model = GptOssForCausalLM()

In [8]:
# Load the whole checkpoint into the model. The empty layer_name presumably acts as
# an empty key prefix that selects every tensor — confirm in safetensors_layer_grabber.
# The recorded cell output reports that all keys matched.
# The state dict is passed inline, so no notebook variable keeps a second copy of the
# weights alive after loading.
model.load_state_dict(
    extract_layer_state_dict(
        safetensors_file_names=safetensors_file_names,
        layer_name=''
    )
)

<All keys matched successfully>

In [10]:
# Prompt token ids as a (1, 7) int64 tensor: a batch containing one sequence.
# torch.tensor with an explicit dtype replaces the legacy torch.LongTensor
# constructor; the resulting values, dtype and shape are unchanged.
input_ids = torch.tensor(
    [[40, 6423, 290, 10915, 328, 2615, 382]],
    dtype=torch.long,
)
input_ids

tensor([[   40,  6423,   290, 10915,   328,  2615,   382]])

In [11]:
# The prompt has no padding, so every position is attended to.
# Deriving the mask from input_ids (instead of typing out one True per token)
# keeps its shape in sync automatically whenever the prompt is edited.
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
attention_mask

tensor([[True, True, True, True, True, True, True]])

In [12]:
# Generation is inference only: disable autograd at the call site so no graph or
# intermediate activations are retained while decoding.
# This is harmless if generate() already disables gradients internally.
with torch.no_grad():
    output_token_sequences = generate(model, input_ids, attention_mask)
output_token_sequences

tensor([[   40,  6423,   290, 10915,   328,  2615,   382,   290,  4215,   328,
          5396,   885, 10335,   326,  3240,  3692,   410,  4066,   256,  7306]])

------

In [13]:
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
# Load the tokenizer from the same local directory that holds the model weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIRECTORY_PATH)

In [15]:
# Turn each generated token sequence in the batch back into text.
list(map(tokenizer.decode, output_token_sequences))

['I believe the meaning of life is the sum of human\'s understanding and experience."**  \n   Here']