In [None]:
# This script will generate the golden data for Llama-4-Scout-17B-16E and Llama-4-Maverick-17B-128E
# which can be used for logit verification / testing
# NOTE: to change the model size, change the MODEL_SIZE variable in cell 3

In [None]:
!python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!python3 -m pip install tokenizers -U
!python3 -m pip install transformers -U

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import jsonlines

# Which Llama 4 variant to generate golden data for: "scout" or "maverick".
MODEL_SIZE = "scout"
# Fail fast on an unsupported value before any model is downloaded.
assert MODEL_SIZE in ("scout", "maverick")

In [None]:
# Pick the Hugging Face checkpoint matching MODEL_SIZE, then load its
# tokenizer and weights.

if MODEL_SIZE == "scout":
  model_id = "meta-llama/Llama-4-Scout-17B-16E"
else:
  model_id = "meta-llama/Llama-4-Maverick-17B-128E"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# The weights are requested in fp32 via `torch_dtype`.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="float32",
    trust_remote_code=True,
)

# Name of the JSONL file the golden data is written to; the suffix encodes the
# expert count of the selected variant.
if MODEL_SIZE == "scout":
  model_size_to_num_experts = "16e"
else:
  model_size_to_num_experts = "128e"
output_path = f"golden_data_llama4-17b-{model_size_to_num_experts}.jsonl"


# Prompts to run through the model, and the list that collects one record per prompt.
prompt_texts = ["I love to"]
all_data_to_save = []


# For every prompt: tokenize it, run one forward pass, and record the prompt,
# its token ids and the resulting logits in `all_data_to_save`.
for prompt_text in prompt_texts:
  # Encode the prompt text (returned as a PyTorch tensor with a batch dimension of 1)
  input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
  print(f"Input ids are {input_ids}")

  # Get the logits for the prompt (a single forward pass; nothing is generated)
  with torch.no_grad():
    # NOTE: `use_cache=False` is needed, otherwise you'll get an error complaining about mixing
    # BF16 and FP32
    outputs = model(input_ids, use_cache=False)

    # Convert logits to fp32 on the torch side *before* handing them to numpy:
    # numpy has no bfloat16 dtype, so calling `.numpy()` on a BF16 tensor
    # raises instead of converting. For fp32 logits the cast is a no-op.
    logits = outputs.logits.to(torch.float32).cpu().numpy()

  # Prepare data to be saved
  data_to_save = {
      "prompt": prompt_text,
      "tokens": input_ids.tolist()[0],  # squeeze batch (of 1) out
      "logits": logits.tolist()[0],  # squeeze batch (of 1) out + convert numpy array to list for JSON serialization
  }
  all_data_to_save.append(data_to_save)

# Write all collected records to the JSONL output file.
with jsonlines.open(output_path, mode="w") as writer:
  writer.write_all(all_data_to_save)


print(f"Data saved to {output_path}")