diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index c60c8f8cc9..07732d2bd4 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -233,6 +233,10 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig: if model_name != "default": # First try relative to base config path model_config_path = os.path.join(os.path.dirname(config_path), "models", f"{model_name}.yml") + # Try looking for "models" under "src/maxtext/configs/" + if not os.path.isfile(model_config_path): + model_config_path = os.path.join(os.path.dirname(os.path.dirname(config_path)), "models", f"{model_name}.yml") + if not os.path.isfile(model_config_path): # Fallback to default location within package dir_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/tests/unit/pyconfig_test.py b/tests/unit/pyconfig_test.py index 6a0353153e..335d2e92b2 100644 --- a/tests/unit/pyconfig_test.py +++ b/tests/unit/pyconfig_test.py @@ -20,7 +20,7 @@ from MaxText import pyconfig from MaxText.pyconfig import resolve_config_path from MaxText.globals import MAXTEXT_PKG_DIR -from tests.utils.test_helpers import get_test_config_path +from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path class PyconfigTest(unittest.TestCase): @@ -85,6 +85,18 @@ def test_overriding_model(self): self.assertEqual(config.base_emb_dim, 1024) self.assertEqual(config.base_mlp_dim, 24576) + def test_overriding_model_in_sft(self): + # TODO: Update MAXTEXT_PKG_DIR after repo restructuring is complete. + config = pyconfig.initialize( + [os.path.join("maxtext.trainers.post_train.sft.train_sft"), get_post_train_test_config_path("sft")], + skip_jax_distributed_system=True, + model_name="llama3.1-8b", + override_model_config=True, + ) + + self.assertEqual(config.base_emb_dim, 4096) + self.assertEqual(config.base_mlp_dim, 14336) + def test_resolve_config_path(self): self.assertEqual(resolve_config_path("foo"), os.path.join("src", "foo")) self.assertEqual(resolve_config_path(__file__), __file__) diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index d8603041e6..656e0e1c37 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -35,6 +35,15 @@ def get_test_config_path(): return os.path.join(MAXTEXT_CONFIGS_DIR, base_cfg) +def get_post_train_test_config_path(sub_type="sft"): + """Return absolute path to the chosen test config file. + + Returns `decoupled_base_test.yml` when decoupled, otherwise `base.yml`. + """ + base_cfg = "rl.yml" if sub_type == "rl" else "sft.yml" + return os.path.join(MAXTEXT_CONFIGS_DIR, "post_train", base_cfg) + + def get_test_dataset_path(cloud_path=None): """Return the dataset path for tests. @@ -70,5 +79,6 @@ def get_test_base_output_directory(cloud_path=None): __all__ = [ "get_test_base_output_directory", "get_test_config_path", + "get_post_train_test_config_path", "get_test_dataset_path", ]