Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions tests/utils/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,12 +302,16 @@ def main(config, test_args): # pylint: disable=W0621
"Comparing up to the smaller vocab size."
)
min_vocab_size = min(full_train_logits.shape[-1], golden_logits.shape[-1])

start_index = 1 if test_args.skip_first_token else 0
# shape [seq_len, vocab_size]
train_logits_slice = full_train_logits[0, :token_size, :min_vocab_size]
golden_logits_slice = golden_logits[:token_size, :min_vocab_size]
max_logging.log("\n[logits: token 2]")
max_logging.log(f"{golden_logits_slice[2]=}")
max_logging.log(f"{train_logits_slice[2]=}")
train_logits_slice = full_train_logits[0, start_index:token_size, :min_vocab_size]
golden_logits_slice = golden_logits[start_index:token_size, :min_vocab_size]

if train_logits_slice.shape[0] > 2:
max_logging.log(f"\n[logits: token {start_index + 2}]")
max_logging.log(f"{golden_logits_slice[2]=}")
max_logging.log(f"{train_logits_slice[2]=}")

# Calculate absolute and relative differences for detailed reporting
abs_diff = jnp.abs(train_logits_slice - golden_logits_slice)
Expand Down Expand Up @@ -337,17 +341,18 @@ def main(config, test_args): # pylint: disable=W0621
model_probabilities = jax.nn.softmax(train_logits_slice, axis=-1)
golden_probabilities = jax.nn.softmax(golden_logits_slice, axis=-1)

max_logging.log("\n[probability: token 1]")
max_logging.log(f"{golden_probabilities[1]=}")
max_logging.log(f"{model_probabilities[1]=}")
if golden_probabilities.shape[0] > 1:
max_logging.log(f"\n[probability: token {start_index + 1}]")
max_logging.log(f"{golden_probabilities[1]=}")
max_logging.log(f"{model_probabilities[1]=}")

kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1)
max_kl_div_val = jax.numpy.max(kl_div)
max_kl_div_idx = jax.numpy.argmax(kl_div)
max_logging.log(
f"\n[KL divergence]\n"
f"KL divergence = {kl_div}, max KL divergence = {max_kl_div_val} at index {max_kl_div_idx}, "
f"the corresponding token id is {ids[0, max_kl_div_idx]}"
f"the corresponding token id is {ids[0, max_kl_div_idx + start_index]}"
)

if jax.process_index() == 0 and test_args.output_logits_path:
Expand Down Expand Up @@ -465,7 +470,12 @@ def main(config, test_args): # pylint: disable=W0621

# --- Compare all logits in the sequence (for the first batch item) ---
# Unsqueeze to add batch dimension for check_kl_divergence: [1, seq, vocab]
check_kl_divergence(mt_logits_torch[0].unsqueeze(0), hf_logits_torch[0].unsqueeze(0), atol=test_args.max_kl_div)
start_index = 1 if test_args.skip_first_token else 0
check_kl_divergence(
mt_logits_torch[0, start_index:].unsqueeze(0),
hf_logits_torch[0, start_index:].unsqueeze(0),
atol=test_args.max_kl_div,
)
if jax.process_index() == 0 and test_args.output_logits_path:
data_to_save = {
"mt_logits": mt_logits_torch[0].tolist(),
Expand Down Expand Up @@ -504,6 +514,13 @@ def main(config, test_args): # pylint: disable=W0621
parser.add_argument("--output_logits_path", type=str, required=False, default="")
parser.add_argument("--gcs_output_logits_path", type=str, required=False, default="")
parser.add_argument("--clip_logits_epsilon", type=float, required=False, default=None)
parser.add_argument(
"--skip_first_token",
Comment thread
shuningjin marked this conversation as resolved.
action="store_true",
required=False,
default=False,
help="Skip the first token during comparison to ignore BOS/init mismatches.",
)
test_args, _ = parser.parse_known_args()

# Remove args defined in this test file to avoid error from pyconfig
Expand All @@ -519,6 +536,7 @@ def main(config, test_args): # pylint: disable=W0621
"--output_logits_path",
"--gcs_output_logits_path",
"--clip_logits_epsilon",
"--skip_first_token",
]
for arg in to_remove_args:
model_args = [s for s in model_args if not s.startswith(arg)]
Expand All @@ -527,6 +545,10 @@ def main(config, test_args): # pylint: disable=W0621
assert (
test_args.atol is not None or test_args.max_kl_div is not None
), "At least one of --atol or --max_kl_div must be specified to define the test criteria."

if test_args.run_hf_model and test_args.clip_logits_epsilon is not None:
raise ValueError("--clip_logits_epsilon is not supported when running HF model on-the-fly (run_hf_model=True).")

if cfg.use_multimodal:
assert not test_args.run_hf_model, (
"Multimodal does not support running hf model on-the-fly, please generate hf golden logits "
Expand Down
Loading