Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://maxtext-dataset

export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=81920'

# Test whether the forward pass logits match the golden logits
# default golden_logits_path=/deps/src/MaxText/test_assets/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
python3 -m tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check model_name=${MODEL_NAME} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product sparse_matmul=True megablox=True per_device_batch_size=1 max_target_length=4 max_prefill_predict_length=4 dtype=float32 --atol=0.1 --rtol=0.1 --max_kl_div=3e-4
Expand Down
4 changes: 4 additions & 0 deletions src/MaxText/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class MatmulPrecision(str, Enum):
DEFAULT = "default"
HIGH = "high"
HIGHEST = "highest"
# same as default
BFLOAT16 = "bfloat16"
# same as highest
FLOAT32 = "float32"


class QuantizationType(str, Enum):
Expand Down
Loading