Patch summary (reviewer preamble; this text precedes the first "diff --git" header and belongs to no hunk):
  * end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh: exports LIBTPU_INIT_ARGS with
    --xla_tpu_scoped_vmem_limit_kib=81920 ahead of the forward-pass logit check.
  * src/MaxText/configs/types.py: adds BFLOAT16 and FLOAT32 members to MatmulPrecision.
NOTE(review): line breaks and the two-space Python indentation were reconstructed from the
hunk line counts (6 context + 2 added, 6 context + 4 added); confirm against the target tree
before applying.

diff --git a/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh b/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh
index 7f3f9c950..fb9e72b2e 100644
--- a/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh
+++ b/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh
@@ -49,6 +49,8 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
 # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
 export DATASET_PATH=gs://maxtext-dataset
 
+export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=81920'
+
 # Test whether the forward pass logits match the golden logits
 # default golden_logits_path=/deps/src/MaxText/test_assets/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
 python3 -m tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check model_name=${MODEL_NAME} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product sparse_matmul=True megablox=True per_device_batch_size=1 max_target_length=4 max_prefill_predict_length=4 dtype=float32 --atol=0.1 --rtol=0.1 --max_kl_div=3e-4
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index 54bd7850c..b0e964c9e 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -61,6 +61,10 @@ class MatmulPrecision(str, Enum):
   DEFAULT = "default"
   HIGH = "high"
   HIGHEST = "highest"
+  # same as default
+  BFLOAT16 = "bfloat16"
+  # same as highest
+  FLOAT32 = "float32"
 
 
 class QuantizationType(str, Enum):