diff --git a/Dockerfile b/Dockerfile
index 7b5b48d12..86635af1b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,6 +2,7 @@
 ARG TARGET=base
 ARG BASE_IMAGE=ubuntu:22.04
+ARG BASE_IMAGE_COLOCATED=us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:2025_10_06-python_3.10-jax_0.6.2
 
 FROM ${BASE_IMAGE} AS base
 
@@ -94,18 +95,47 @@ ARG EXTRAS=
 # Install a custom jaxlib that includes backport of Pathways shared memory feature.
 # PR: https://github.com/openxla/xla/pull/31417
 # Needed until Jax is upgraded to 0.8.0 or newer.
-ARG INSTALL_PATHWAYS_JAXLIB=false
+ARG INSTALL_PATHWAYS_JAXLIB=true
 # Ensure we install the TPU version, even if building locally.
 # Jax will fallback to CPU when run on a machine without TPU.
 RUN uv pip install -qq --prerelease=allow .[core,tpu] && uv cache clean
 RUN if [ -n "$EXTRAS" ]; then uv pip install -qq .[$EXTRAS] && uv cache clean; fi
+
+COPY jaxlib-0.6.2.dev20251021-cp310-cp310-manylinux2014_x86_64.whl .
+
+# Install the prebuilt jaxlib wheel from the path it was copied to inside the container.
 RUN if [ "$INSTALL_PATHWAYS_JAXLIB" = "true" ]; then \
-    uv pip install --prerelease=allow "jaxlib==0.5.3.dev20250918" \
-    --find-links https://storage.googleapis.com/axlearn-wheels/wheels.html; \
+    # uv pip install --prerelease=allow "jaxlib==0.6.2.dev20251020" \
+    # --find-links https://storage.googleapis.com/axlearn-wheels/wheels.html; \
+    uv pip install jaxlib-0.6.2.dev20251021-cp310-cp310-manylinux2014_x86_64.whl; \
     fi
 
 COPY . .
 
+################################################################################
+# Colocated Python container spec.                                             #
+################################################################################
+
+FROM ${BASE_IMAGE_COLOCATED} AS colocated-python
+
+WORKDIR /app
+COPY . .
+
+# Install the additional user-provided dependencies, strictly enforcing the rules
+# from the base image's constraints file.
+RUN \
+    echo "--> Installing user-provided dependencies..." && \
+    uv pip install ".[core,gcp]" -c /opt/venv/server_constraints.txt && \
+    \
+    # Verify that the colocated_python_cpu_client is present.
+    echo "--> Verifying JAX patch integrity..." && \
+    python -c "from jax._src.lib import _jax; _jax.colocated_python_cpu_client" && \
+    echo "--> JAX patch verification successful." && \
+    \
+    # Clean the cache to keep the image slim.
+    uv cache clean
+
+
 ################################################################################
 # GPU container spec.                                                          #
 ################################################################################
@@ -125,4 +155,4 @@ COPY . .
 # Final target spec.                                                           #
 ################################################################################
 
-FROM ${TARGET} AS final
+FROM ${TARGET} AS final
\ No newline at end of file
diff --git a/axlearn/cloud/gcp/bundler.py b/axlearn/cloud/gcp/bundler.py
index 8070f3257..6af17b735 100644
--- a/axlearn/cloud/gcp/bundler.py
+++ b/axlearn/cloud/gcp/bundler.py
@@ -98,11 +98,65 @@ class ArtifactRegistryBundler(DockerBundler):
 
     TYPE = "artifactregistry"
 
+    @config_class
+    class Config(DockerBundler.Config):
+        """Configures ArtifactRegistryBundler.
+
+        Attributes:
+            colocated_image_required: Whether to also build and push the colocated Python image.
+            colocated_image_name: Name of the colocated Python image.
+            colocated_dockerfile: Dockerfile used to build the colocated Python image.
+        """
+ colocated_image_required: bool = False + colocated_image_name: str = None + #colocated_dockerfile: str = None + + @classmethod def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> DockerBundler.Config: - cfg = super().from_spec(spec, fv=fv) + cfg: ArtifactRegistryBundler.Config = super().from_spec(spec, fv=fv) cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv) cfg.dockerfile = cfg.dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv) + cfg.colocated_image_required = cfg.colocated_image_required or gcp_settings("colocated_image_required", required=False, fv=fv) + cfg.colocated_image_name = cfg.colocated_image_name or gcp_settings("colocated_image_name", required=False, fv=fv) + #cfg.colocated_dockerfile = cfg.colocated_dockerfile or gcp_settings("colocated_dockerfile", required=False, fv=fv) + return cfg + + def _build_and_push(self, *args, **kwargs): + cfg = self.config + subprocess.run( + ["gcloud", "auth", "configure-docker", registry_from_repo(cfg.repo)], + check=True, + ) + + actual_name = cfg.image + #actual_dockerfile=cfg.dockerfile + actual_target=cfg.target + if bool(cfg.colocated_image_required): + + #cfg.dockerfile=cfg.colocated_dockerfile + cfg.image=cfg.colocated_image_name + cfg.target="colocated-python" + + colocated_bundler_class = ColocatedArtifactRegistryBundler(cfg=cfg) + colocated_image_name = colocated_bundler_class.bundle(tag=cfg.image) + + #cfg.dockerfile=actual_dockerfile + cfg.image=actual_name + cfg.target=actual_target + + return super()._build_and_push(*args, **kwargs) + + +class ColocatedArtifactRegistryBundler(DockerBundler): + """A DockerBundler that reads configs from gcp_settings, and auths to Artifact Registry.""" + + @classmethod + def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> DockerBundler.Config: + cfg: ColocatedArtifactRegistryBundler.Config = super().from_spec(spec, fv=fv) + cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv) + cfg.dockerfile = cfg.colocated_dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv) return cfg def _build_and_push(self, *args, **kwargs): @@ -111,6 +165,7 @@ def _build_and_push(self, *args, **kwargs): ["gcloud", "auth", "configure-docker", registry_from_repo(cfg.repo)], check=True, ) + return super()._build_and_push(*args, **kwargs) @@ -263,4 +318,4 @@ def with_tpu_extras(bundler: Bundler.Config) -> Bundler.Config: if __name__ == "__main__": common_flags() bundler_main_flags() - app.run(bundler_main) + app.run(bundler_main) \ No newline at end of file diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 93ac8a9e4..f1fe6a0f4 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -35,6 +35,7 @@ ) from axlearn.common.config import REQUIRED, Required, config_class from axlearn.common.utils import Nested +from axlearn.cloud.gcp.config import gcp_settings # The port used by pathways proxy server. # The specific value is not important, as long as clients and servers use the same port. @@ -45,19 +46,25 @@ # The port used by pathways worker server. # The specific value is not important, as long as clients and servers use the same port. _PATHWAYS_WORKER_PORT = 29001 +_COLOCATED_CONTAINER_PORT = 50051 # Pin to specific pathways image version for stable release. # There is no guarantee that this image will work with newer Jax releases. 
# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625 # This image extends GRPC timeout for long context models. -_PATHWAYS_IMAGE_TAG = "shm_proxy_settings" +_PATHWAYS_IMAGE_TAG = "2025-10-03" + # The docker image used by pathways proxy container. +# pylint: disable=line-too-long _PATHWAYS_PROXY_IMAGE = ( - f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}" + # pylint: disable=line-too-long + f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/proxy_server:{_PATHWAYS_IMAGE_TAG}" ) # The docker image used by pathways resource manager container and worker container. _PATHWAYS_SERVER_IMAGE = ( - f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}" + # pylint: disable=line-too-long + f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/server:{_PATHWAYS_IMAGE_TAG}" ) + # The container name of pathways resourcemanager. _PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm" # The container name of pathways proxy. @@ -67,6 +74,9 @@ # The k8s replicatedJob name for pathways-worker pods. _PATHWAYS_WORKER_REPLICATED_JOB_NAME = "pathways-worker" +_COLOCATED_PYTHON_SIDECAR_NAME = "colocated-python-sidecar" + + # Add node-selector for cpu workload to avoid sharing nodes with system services. _PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY = "axlearn/nodepool_type" _PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE = "workload" @@ -75,6 +85,12 @@ # While workers will share #workers * _PATHWAYS_BACK_OFF_LIMIT total times. _PATHWAYS_BACK_OFF_LIMIT = 32 +FLAGS = flags.FLAGS + +def get_colocated_python_image(colocated_image_name, fv: flags.FlagValues = FLAGS) -> str: + repo = gcp_settings("docker_repo", required=False, fv=fv) + return repo+"/"+colocated_image_name+":"+colocated_image_name + def parse_xla_flag_value(value: str) -> Union[int, bool, str]: """Attempts to convert an XLA flag string value to int. @@ -135,7 +151,6 @@ def get_xla_options( """ return {k: v for k, v in xla_options.items() if k.startswith("xla_")} - def round_up_to_power_of_2(n): """ Rounds an integer up to the nearest power of 2. @@ -173,6 +188,7 @@ class Config(BaseReplicatedJob.Config): pathways_xla_flags: list[str] = [] pathways_head_cpu: Optional[str] = None pathways_head_mem: Optional[str] = None + colocated_image: Optional[str] = None @classmethod def define_flags(cls, fv): @@ -201,12 +217,19 @@ def define_flags(cls, fv): "Memory request for pathways-head container in GiB. 
Default is 16GiB", **common_kwargs, ) + flags.DEFINE_string( + "colocated_image", + None, + "Colocated Image Name", + **common_kwargs, + ) @classmethod def set_defaults(cls, fv): super().set_defaults(fv) fv.set_default("pathways_head_cpu", fv.pathways_head_cpu or "1") fv.set_default("pathways_head_mem", fv.pathways_head_mem or "16") + fv.set_default("colocated_image", fv.colocated_image or None) @classmethod def default_config(cls): @@ -311,29 +334,29 @@ def _build_pathways_head_container(self) -> dict: } ) - # pylint: disable=line-too-long - env_list.append( - { - "name": "NUM_REPLICAS", - "valueFrom": { - "fieldRef": { - "fieldPath": "metadata.annotations['jobset.sigs.k8s.io/replicatedjob-replicas']" - } - }, - } - ) + # # pylint: disable=line-too-long + # env_list.append( + # { + # "name": "NUM_REPLICAS", + # "valueFrom": { + # "fieldRef": { + # "fieldPath": "metadata.annotations['jobset.sigs.k8s.io/replicatedjob-replicas']" + # } + # }, + # } + # ) # pylint: enable=line-too-long - env_list.append( - { - "name": "REPLICA_ID", - "valueFrom": { - "fieldRef": { - "fieldPath": "metadata.annotations['jobset.sigs.k8s.io/job-index']" - } - }, - } - ) + # env_list.append( + # { + # "name": "REPLICA_ID", + # "valueFrom": { + # "fieldRef": { + # "fieldPath": "metadata.annotations['jobset.sigs.k8s.io/job-index']" + # } + # }, + # } + # ) head_container["env"] = env_list @@ -373,6 +396,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}", f"--server_port={_PATHWAYS_PROXY_PORT}", f"--gcs_scratch_location={staging_location}", + "--sidecar_name=external", ] cmd_args.extend(xla_flags_from_options(self._xla_options).split()) @@ -426,6 +450,22 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: ), ] + def _colocated_python_container(self): + cfg: PathwaysReplicatedJob.Config = self.config + return dict( + name=_COLOCATED_PYTHON_SIDECAR_NAME, + image=get_colocated_python_image(cfg.colocated_image), + restartPolicy="Always", + env=[ + { + "name": "GRPC_SERVER_ADDRESS", + "value": f"0.0.0.0:{_COLOCATED_CONTAINER_PORT}", + }, + ], + imagePullPolicy="Always", + ports=[dict(containerPort=_COLOCATED_CONTAINER_PORT)], + ) + def _build_pathways_head_pod(self) -> Nested[Any]: """Builds a pathways head pod. The pod includes a head container, a proxy container and a resource manager container. @@ -604,6 +644,7 @@ def _build_pathways_worker_pod( ) -> Nested[Any]: """Conoverts a worker pod to a new pod for the 'pathways-workers' role.""" cfg: TPUReplicatedJob.Config = self._inner.config + pathways_cfg: PathwaysReplicatedJob.Config = self.config # pylint: disable-next=protected-access pod = self._inner._build_pod() worker_pod = copy.deepcopy(pod) @@ -619,6 +660,10 @@ def _build_pathways_worker_pod( pod_spec["containers"] = [ self._build_pathways_worker_container(pathways_worker_replicated_job_index) ] + + if pathways_cfg.colocated_image: + pod_spec["initContainers"] = [self._colocated_python_container()] + worker_pod["spec"] = pod_spec # Service account for nodes. 
@@ -1056,4 +1101,4 @@ def __call__(self) -> Nested[Any]: size=system.vms_per_slice + 1, leaderTemplate=self.build_leader_pod(), workerTemplate=self.build_worker_pod(), - ) + ) \ No newline at end of file diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 369926e2e..d9988862f 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -307,6 +307,7 @@ async def _async_serialize( ) # pylint: disable=protected-access spec_has_metadata = { + "0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl._spec_has_metadata, "0.6.2": lambda: serialization.ts_impl._spec_has_metadata, "0.5.3": lambda: serialization._spec_has_metadata, }[jax.__version__]() @@ -487,6 +488,7 @@ async def cb(index: array.Index, device: jax.Device): requested_domain = ts.IndexTransform(input_shape=shape)[index].domain restricted_domain = t.domain.intersect(requested_domain) estimate_read_memory_footprint = { + "0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl.estimate_read_memory_footprint, "0.6.2": lambda: serialization.ts_impl.estimate_read_memory_footprint, "0.5.3": lambda: serialization.estimate_read_memory_footprint, }[jax.__version__]() @@ -568,6 +570,7 @@ async def cb(index: array.Index, device: jax.Device): # pylint: disable=protected-access create_async_array_from_callback = { + "0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl._create_async_array_from_callback, "0.6.2": lambda: serialization.ts_impl._create_async_array_from_callback, "0.5.3": lambda: serialization.create_async_array_from_callback, }[jax.__version__]() @@ -653,6 +656,7 @@ def serialize( commit_futures = [[] for _ in range(len(tensorstore_specs))] async_serialize = { + "0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl.async_serialize, "0.6.2": lambda: serialization.ts_impl.async_serialize, "0.5.3": lambda: serialization.async_serialize, }[jax.__version__]() diff --git a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt index df3ea721d..cc4022370 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt @@ -24,6 +24,7 @@ evalers['eval_train'].input.processor.vocab_cfg.extra_ids: 0 evalers['eval_train'].input.processor.vocab_cfg.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['eval_train'].input.processor.vocab_cfg.reverse_extra_ids: True evalers['eval_train'].input.processor.vocab_cfg.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['eval_train'].input.processor.vocab_cfg.use_fast_tokenizer: False evalers['eval_train'].input.source.dataset_name: 'librispeech:2.1.0' evalers['eval_train'].input.source.download: False evalers['eval_train'].input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset' @@ -51,6 +52,7 @@ evalers['eval_dev_clean'].input.processor.vocab_cfg.extra_ids: 0 evalers['eval_dev_clean'].input.processor.vocab_cfg.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['eval_dev_clean'].input.processor.vocab_cfg.reverse_extra_ids: True evalers['eval_dev_clean'].input.processor.vocab_cfg.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['eval_dev_clean'].input.processor.vocab_cfg.use_fast_tokenizer: 
False evalers['eval_dev_clean'].input.source.dataset_name: 'librispeech:2.1.0' evalers['eval_dev_clean'].input.source.download: False evalers['eval_dev_clean'].input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset' @@ -78,6 +80,7 @@ evalers['eval_dev_other'].input.processor.vocab_cfg.extra_ids: 0 evalers['eval_dev_other'].input.processor.vocab_cfg.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['eval_dev_other'].input.processor.vocab_cfg.reverse_extra_ids: True evalers['eval_dev_other'].input.processor.vocab_cfg.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['eval_dev_other'].input.processor.vocab_cfg.use_fast_tokenizer: False evalers['eval_dev_other'].input.source.dataset_name: 'librispeech:2.1.0' evalers['eval_dev_other'].input.source.download: False evalers['eval_dev_other'].input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset' @@ -105,6 +108,7 @@ evalers['decoder_dev_clean'].input.processor.vocab_cfg.extra_ids: 0 evalers['decoder_dev_clean'].input.processor.vocab_cfg.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['decoder_dev_clean'].input.processor.vocab_cfg.reverse_extra_ids: True evalers['decoder_dev_clean'].input.processor.vocab_cfg.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['decoder_dev_clean'].input.processor.vocab_cfg.use_fast_tokenizer: False evalers['decoder_dev_clean'].input.source.dataset_name: 'librispeech:2.1.0' evalers['decoder_dev_clean'].input.source.download: False evalers['decoder_dev_clean'].input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset' @@ -122,6 +126,7 @@ evalers['decoder_dev_clean'].metric_calculator.vocab.extra_ids: 0 evalers['decoder_dev_clean'].metric_calculator.vocab.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['decoder_dev_clean'].metric_calculator.vocab.reverse_extra_ids: True evalers['decoder_dev_clean'].metric_calculator.vocab.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['decoder_dev_clean'].metric_calculator.vocab.use_fast_tokenizer: False evalers['decoder_dev_clean'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['decoder_dev_clean'].summary_writer.write_every_n_steps: 1 evalers['decoder_dev_other'].eval_dtype: 'jax.numpy.float32' @@ -140,6 +145,7 @@ evalers['decoder_dev_other'].input.processor.vocab_cfg.extra_ids: 0 evalers['decoder_dev_other'].input.processor.vocab_cfg.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['decoder_dev_other'].input.processor.vocab_cfg.reverse_extra_ids: True evalers['decoder_dev_other'].input.processor.vocab_cfg.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['decoder_dev_other'].input.processor.vocab_cfg.use_fast_tokenizer: False evalers['decoder_dev_other'].input.source.dataset_name: 'librispeech:2.1.0' evalers['decoder_dev_other'].input.source.download: False evalers['decoder_dev_other'].input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset' @@ -157,6 +163,7 @@ evalers['decoder_dev_other'].metric_calculator.vocab.extra_ids: 0 evalers['decoder_dev_other'].metric_calculator.vocab.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['decoder_dev_other'].metric_calculator.vocab.reverse_extra_ids: True evalers['decoder_dev_other'].metric_calculator.vocab.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' 
+evalers['decoder_dev_other'].metric_calculator.vocab.use_fast_tokenizer: False evalers['decoder_dev_other'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['decoder_dev_other'].summary_writer.write_every_n_steps: 1 evalers['decoder_test_clean'].eval_dtype: 'jax.numpy.float32' @@ -175,6 +182,7 @@ evalers['decoder_test_clean'].input.processor.vocab_cfg.extra_ids: 0 evalers['decoder_test_clean'].input.processor.vocab_cfg.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['decoder_test_clean'].input.processor.vocab_cfg.reverse_extra_ids: True evalers['decoder_test_clean'].input.processor.vocab_cfg.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['decoder_test_clean'].input.processor.vocab_cfg.use_fast_tokenizer: False evalers['decoder_test_clean'].input.source.dataset_name: 'librispeech:2.1.0' evalers['decoder_test_clean'].input.source.download: False evalers['decoder_test_clean'].input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset' @@ -192,6 +200,7 @@ evalers['decoder_test_clean'].metric_calculator.vocab.extra_ids: 0 evalers['decoder_test_clean'].metric_calculator.vocab.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['decoder_test_clean'].metric_calculator.vocab.reverse_extra_ids: True evalers['decoder_test_clean'].metric_calculator.vocab.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['decoder_test_clean'].metric_calculator.vocab.use_fast_tokenizer: False evalers['decoder_test_clean'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['decoder_test_clean'].summary_writer.write_every_n_steps: 1 evalers['decoder_test_other'].eval_dtype: 'jax.numpy.float32' @@ -210,6 +219,7 @@ evalers['decoder_test_other'].input.processor.vocab_cfg.extra_ids: 0 evalers['decoder_test_other'].input.processor.vocab_cfg.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['decoder_test_other'].input.processor.vocab_cfg.reverse_extra_ids: True evalers['decoder_test_other'].input.processor.vocab_cfg.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['decoder_test_other'].input.processor.vocab_cfg.use_fast_tokenizer: False evalers['decoder_test_other'].input.source.dataset_name: 'librispeech:2.1.0' evalers['decoder_test_other'].input.source.download: False evalers['decoder_test_other'].input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset' @@ -227,6 +237,7 @@ evalers['decoder_test_other'].metric_calculator.vocab.extra_ids: 0 evalers['decoder_test_other'].metric_calculator.vocab.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['decoder_test_other'].metric_calculator.vocab.reverse_extra_ids: True evalers['decoder_test_other'].metric_calculator.vocab.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['decoder_test_other'].metric_calculator.vocab.use_fast_tokenizer: False evalers['decoder_test_other'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['decoder_test_other'].summary_writer.write_every_n_steps: 1 evalers['decoder_train'].eval_dtype: 'jax.numpy.float32' @@ -245,6 +256,7 @@ evalers['decoder_train'].input.processor.vocab_cfg.extra_ids: 0 evalers['decoder_train'].input.processor.vocab_cfg.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['decoder_train'].input.processor.vocab_cfg.reverse_extra_ids: True evalers['decoder_train'].input.processor.vocab_cfg.sentencepiece_model_file: 
'$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['decoder_train'].input.processor.vocab_cfg.use_fast_tokenizer: False evalers['decoder_train'].input.source.dataset_name: 'librispeech:2.1.0' evalers['decoder_train'].input.source.download: False evalers['decoder_train'].input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset' @@ -262,6 +274,7 @@ evalers['decoder_train'].metric_calculator.vocab.extra_ids: 0 evalers['decoder_train'].metric_calculator.vocab.klass: 'seqio.vocabularies.SentencePieceVocabulary' evalers['decoder_train'].metric_calculator.vocab.reverse_extra_ids: True evalers['decoder_train'].metric_calculator.vocab.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +evalers['decoder_train'].metric_calculator.vocab.use_fast_tokenizer: False evalers['decoder_train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['decoder_train'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -276,6 +289,7 @@ input.processor.vocab_cfg.extra_ids: 0 input.processor.vocab_cfg.klass: 'seqio.vocabularies.SentencePieceVocabulary' input.processor.vocab_cfg.reverse_extra_ids: True input.processor.vocab_cfg.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +input.processor.vocab_cfg.use_fast_tokenizer: False input.source.dataset_name: 'librispeech:2.1.0' input.source.download: False input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset' diff --git a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt index f0c611528..08ca99368 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt @@ -20,6 +20,7 @@ input.processor.vocab_cfg.extra_ids: 0 input.processor.vocab_cfg.klass: 'seqio.vocabularies.SentencePieceVocabulary' input.processor.vocab_cfg.reverse_extra_ids: True input.processor.vocab_cfg.sentencepiece_model_file: '$DATA_DIR/tokenizers/sentencepiece/librispeech_bpe_1024.model' +input.processor.vocab_cfg.use_fast_tokenizer: False input.source.dataset_name: 'librispeech:2.1.0' input.source.download: False input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset' diff --git a/colocated_commands.txt b/colocated_commands.txt new file mode 100644 index 000000000..e99fd7017 --- /dev/null +++ b/colocated_commands.txt @@ -0,0 +1,69 @@ +#### Different names for these images because they are in same repository ##### + + +export NAME=axlearn-img +export COLOCATED_NAME=colocated-img +export CKPT_BUCKET_NAME=<> +export CLUSTER_NAME=<> + +axlearn gcp bundle --name=$NAME \ + --bundler_spec=allow_dirty=True \ + --bundler_type=artifactregistry \ + --bundler_spec=dockerfile=Dockerfile \ + --bundler_spec=image=tpu \ + --bundler_spec=target=tpu \ + --bundler_spec=colocated_image_required=True \ + --bundler_spec=colocated_image_name=$COLOCATED_NAME + + + +axlearn gcp launch run --cluster=$CLUSTER_NAME \ + --runner_name gke_tpu_pathways \ + --name=$NAME \ + --instance_type=tpu-v5p-32 \ + --num_replicas=1 \ + --bundler_spec=allow_dirty=True \ + --bundler_type=artifactregistry \ + --bundler_spec=image=tpu \ + --bundler_spec=dockerfile=Dockerfile \ + --bundler_spec=target=tpu \ + 
--colocated_image=$COLOCATED_NAME \ + -- TPU_PREMAPPED_BUFFER_SIZE=34359738368 python3 test_benchmark.py --ckpt_path $CKPT_BUCKET_NAME + + +#### Commands to build images separately ###### + +export NAME=lk-axlearnimg13 +export COLOCATED_NAME=colocated-image23 +export CKPT_BUCKET_NAME=<> +export CLUSTER_NAME=<> + +### colocated image ##### +axlearn gcp bundle --name=$COLOCATED_NAME \ + --bundler_spec=allow_dirty=True \ + --bundler_type=artifactregistry \ + --bundler_spec=dockerfile=Dockerfile \ + --bundler_spec=image=$COLOCATED_NAME \ + --bundler_spec=target=colocated-python + +### axlearn image ##### +axlearn gcp bundle --name=$NAME \ + --bundler_spec=allow_dirty=True \ + --bundler_type=artifactregistry \ + --bundler_spec=dockerfile=Dockerfile \ + --bundler_spec=image=tpu \ + --bundler_spec=target=tpu + +axlearn gcp launch run --cluster=$CLUSTER_NAME \ + --runner_name gke_tpu_pathways \ + --name=$NAME \ + --instance_type=tpu-v5p-32 \ + --num_replicas=1 \ + --bundler_spec=allow_dirty=True \ + --bundler_type=artifactregistry \ + --bundler_spec=image=tpu \ + --bundler_spec=dockerfile=Dockerfile \ + --bundler_spec=target=tpu \ + --colocated_image=$COLOCATED_NAME \ + -- TPU_PREMAPPED_BUFFER_SIZE=34359738368 python3 test_benchmark.py --ckpt_path $CKPT_BUCKET_NAME + diff --git a/jaxlib-0.6.2.dev20251021-cp310-cp310-manylinux2014_x86_64.whl b/jaxlib-0.6.2.dev20251021-cp310-cp310-manylinux2014_x86_64.whl new file mode 100644 index 000000000..b5ff3d121 Binary files /dev/null and b/jaxlib-0.6.2.dev20251021-cp310-cp310-manylinux2014_x86_64.whl differ diff --git a/pyproject.toml b/pyproject.toml index 41ab5d0df..dfe7be409 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ requires-python = ">=3.10" # Minimal requirments for axlearn/common/config.py. dependencies = [ "attrs>=23.1.0", # We use `type` in `attrs.field` - "numpy==1.26.4", # verified with tensorflow 2.14 RaggedTensor + "numpy==2.1.1", # verified with tensorflow 2.14 RaggedTensor ] [project.optional-dependencies] @@ -25,6 +25,9 @@ core = [ "importlab==0.8.1", # breaks pytype on 0.8 "jax==0.6.2", "jaxlib==0.6.2", + "ml-dtypes==0.5.1", + "jax==0.6.2", + "jaxlib==0.6.2", "ml-dtypes>=0.5,<0.6", "msgpack==1.1.0", # for checkpointing. "nltk==3.7", # for text preprocessing @@ -34,16 +37,18 @@ core = [ "protobuf>=3.20.3", "tensorboard-plugin-profile==2.20.4", # This has both x86 and arm64 wheels. Underneath the hood it uses tensorflow-macos since 2.13. - "tensorflow==2.17.1", + "tensorflow==2.19.0.1", "tensorflow-datasets>=4.9.2", + "tensorflow-io>=0.37.2", # for tensorflow-2.16. Note that 0.37.0 results in "pure virtual method called". + "tensorflow-metadata==1.17.2", # Otherwise Seqio will report no core for tfds "tensorflow-io>=0.37.1", # for tensorflow-2.16. Note that 0.37.0 results in "pure virtual method called". 
"tensorflow-metadata>=1.0.0", # Otherwise Seqio will report no core for tfds - "tensorflow_text==2.17.0", # implied by seqio, but also used directly for text processing + "tensorflow_text==2.19.0", # implied by seqio, but also used directly for text processing "tensorstore>=0.1.63", # used for supporting GDA checkpoints "toml", # for config management "typing-extensions==4.12.2", - "scipy==1.12.0", # to avoid "module 'scipy.linalg' has no attribute 'tril'" - "seqio==0.0.18", # used for inputs + "scipy==1.15.3", # to avoid "module 'scipy.linalg' has no attribute 'tril'" + "seqio==0.0.18.1", # used for inputs "aqtp==0.8.2", # Updated from 0.4.0; compatible with Python 3.10 "flax==0.10.2", # for AQT, param converter and adapter. "prefixed==0.9.0", # For formatting file sizes, param counts, etc. @@ -66,7 +71,6 @@ dev = [ "pylint==2.17.7", "pytest", # test runner "pytest-xdist", # pytest plugin for test parallelism - "pytest-timeout", # pytest plugin for forcing timeout of tests "pytype==2022.4.22", # type checking "scikit-learn==1.5.2", # test-only # Fix AttributeError: module 'scipy.linalg' has no attribute 'tril' and related scipy import errors. @@ -74,7 +78,7 @@ dev = [ "sentencepiece != 0.1.92", "tqdm", # test-only "timm==0.6.12", # DiT Dependency test-only - "torch>=2.1.1", # test-only + "torch>=1.12.1", # test-only "torchvision==0.16.1", # test-only "safetensors<=0.5.3", # TODO: Remove once torch dependency is >=2.3.0 "transformers==4.51.3", # test-only @@ -108,6 +112,7 @@ gcp = [ tpu = [ "axlearn[gcp]", "jax[tpu]==0.6.2", # must be >=0.4.19 for compat with v5p. + "jax[tpu]==0.6.2", # must be >=0.4.19 for compat with v5p. "pathwaysutils==0.1.1", # For JAX+Pathways single-controller accelerator coordinator. ] # Vertex AI tensorboard. TODO(markblee): Merge with `gcp`. @@ -135,9 +140,11 @@ gpu = [ "triton==2.1.0", "jax[cuda12]==0.6.2", "nvidia-ml-py>=12.560.30", + "jax[cuda12]==0.6.2", + "nvidia-ml-py>=12.560.30", # pin nccl version, otherwise jax[cuda12] will pull latest version "nvidia-nccl-cu12==2.27.5", - "nvidia-cudnn-cu12>=9.8.0.87" # Pin CuDNN to at least 9.8 for Jax >= 0.6.2 + "nvidia-cudnn-cu12>=9.8.0.87", # Pin CuDNN to at least 9.8 for Jax >= 0.6.2 ] # Open API inference. open_api = [ @@ -190,7 +197,6 @@ line-length = 100 target-version = 'py39' [tool.pytest.ini_options] -timeout = 300 addopts = "-rs -s -p no:warnings --junitxml=test-results/testing.xml" markers = [ "gs_login: tests needing GS login.", @@ -210,13 +216,9 @@ junit_family="xunit2" line_length = 100 profile = "black" -[tool.uv.pip] +[tool.uv] find-links = [ + "https://storage.googleapis.com/axlearn-wheels/wheels.html", "https://storage.googleapis.com/jax-releases/libtpu_releases.html", "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html", -] - -[tool.uv] -override-dependencies = [ - "ml-dtypes>=0.5,<0.6", -] +] \ No newline at end of file diff --git a/test_benchmark.py b/test_benchmark.py new file mode 100644 index 000000000..7f64299b8 --- /dev/null +++ b/test_benchmark.py @@ -0,0 +1,361 @@ +#!/usr/bin/env python3 +""" +Standalone script to preload a model from GCS using Colocated Python. + +This script reads the checkpoint index to determine the model structure and creates +appropriate TensorSpec objects for preloading. 
+ +Usage: + python load_model_colocated.py --ckpt_path gs://your-bucket/path/to/checkpoint +""" + +import argparse +import asyncio +import functools +import os +import sys +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, Sequence + +import jax +import jax.numpy as jnp +import pathwaysutils +from jax._src.mesh import thread_resources +from jax.experimental import colocated_python, mesh_utils +from jax.experimental.array_serialization import serialization as array_serialization +from jax.experimental.array_serialization import tensorstore_impl + +from axlearn.common import utils +from axlearn.common.array_serialization import _async_deserialize +from axlearn.common.checkpointer import parse_step_from_dir, read_index_file +from axlearn.common.utils import TensorSpec, infer_mesh_shape +import logging + + +def _colocated_deserialize( + shardings: Sequence[jax.sharding.NamedSharding], + tensorstore_specs: Sequence[Dict[str, Any]], + global_shapes: Sequence[tuple], + dtypes: Sequence[jnp.dtype], +): + # concurrent_bytes = 1099511627776 + concurrent_bytes = 34359738368 * 6 # multiple of 32GB + cpu_devices = colocated_python.colocated_cpu_devices(jax.devices()) + print(f"{cpu_devices=}") + + if len(cpu_devices) > 1: + cpu_mesh = colocated_python.colocated_cpu_devices(thread_resources.env.physical_mesh) + cpu_shardings = [ + jax.sharding.NamedSharding(cpu_mesh, sharding.spec) for sharding in shardings + ] + else: + cpu_shardings = [ + jax.sharding.SingleDeviceSharding(cpu_devices[0]) for sharding in shardings + ] + + def output_spec_fn(): + return [ + jax.ShapeDtypeStruct(shape=shape, dtype=dtype, sharding=sharding) + for shape, dtype, sharding in zip(global_shapes, dtypes, cpu_shardings) + ] + + @colocated_python.colocated_python + def run_deserializer(): + # Object should be created once per process. 
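+        # Note: because of the colocated_python decorator (specialized to cpu_devices
+        # further below), this body executes inside the colocated Python CPU sidecar
+        # process rather than the trainer process, so the byte limiters and thread
+        # pools are created fresh in that process.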
+ # pylint: disable=protected-access + # print("Print statement inside colocated") + logging.info("Logging statement inside colocated") + sys.stderr.write("Stdder statement in colocated") + start_colocated_time=time.perf_counter() + byte_limiter = tensorstore_impl._LimitInFlightBytes(concurrent_bytes) + h2d_limiter = tensorstore_impl._LimitInFlightBytes(concurrent_bytes) + thread_pool = ThreadPoolExecutor(1) + multi_thread_pool = ThreadPoolExecutor(2) + + future_arrays = jax.tree.map( + functools.partial( + _async_deserialize, + byte_limiter=byte_limiter, + h2d_limiter=h2d_limiter, + single_thread_pool=thread_pool, + multi_thread_pool=multi_thread_pool, + ), + cpu_shardings, + tensorstore_specs, + global_shapes, + dtypes, + ) + + async def gather_func(): + return await asyncio.gather(*future_arrays) + + result = asyncio.run(gather_func()) + logging.info(f"Deserialize took {time.perf_counter() - start_colocated_time:.2f} seconds") + return result + + run_deserializer = run_deserializer.specialize( + devices=cpu_devices, + out_specs_fn=output_spec_fn, + ) + + # Try running in the current event loop if one exists, otherwise create new one + result = run_deserializer() + return result + + +def create_mesh(mesh_shape=(1, 1, 1, 1, 1, -1)): + """Create a JAX mesh for distributed computation.""" + inferred_mesh_shape = infer_mesh_shape(mesh_shape) + print(f"Using mesh shape {inferred_mesh_shape} for {len(jax.devices())} devices") + devices = mesh_utils.create_device_mesh(inferred_mesh_shape) + return jax.sharding.Mesh(devices, ("pipeline", "data", "expert", "fsdp", "seq", "model")) + + +def create_state_spec_from_checkpoint(ckpt_path: str): + """Create a NestedTensorSpec from checkpoint index information.""" + index = read_index_file(ckpt_path) + print(f"Read checkpoint index with {len(index)} entries") + + state_spec = {} + + for path, value in index: + if path == "step": + continue + + # Filter out learner state + if is_learner_path(path): + continue + + if isinstance(value, dict) and "shape" in value and "dtype" in value: + # pylint: disable=eval-used + shape = eval(value["shape"]) if isinstance(value["shape"], str) else value["shape"] + dtype_str = value["dtype"] + + # Convert dtype string to jax dtype + dtype = getattr(jnp, dtype_str, jnp.float32) + if dtype == jnp.float32: + dtype = jnp.bfloat16 + + # Create nested dict structure from path + keys = path.split("/") + current = state_spec + for key in keys[:-1]: + if key not in current: + current[key] = {} + current = current[key] + + current[keys[-1]] = TensorSpec(shape=shape, dtype=dtype) + + return state_spec + + +def is_learner_path(path: str) -> bool: + """Check if a path is part of the learner state.""" + # Exclude all learner paths (optimizer state, ema, etc.) 
+ return path.startswith("learner/") + + +def create_checkpoint_spec_from_state(ckpt_dir: str, state_spec: dict): + """Create checkpoint spec following the pattern from TensorStoreStateStorage._get_spec.""" + + tensorstore_specs = [] + shapes = [] + dtypes = [] + shardings = [] + + # Get current mesh for creating shardings + mesh = thread_resources.env.physical_mesh + if not mesh.shape: + raise RuntimeError("Checkpoint restoration must take place within the context of a Mesh") + + # Process each tensor in the state spec + for path, value in utils.flatten_items(state_spec, separator="/"): + if isinstance(value, TensorSpec): + # Get dtype + dtype = getattr(value.dtype, "dtype", value.dtype) + + # Create storage path and tensorstore spec + gda_path = os.path.join(ckpt_dir, "gda", path) + tensorstore_spec = array_serialization.get_tensorstore_spec(gda_path) + + # Get inference-friendly partition spec based on tensor path and shape + model_axis_size = mesh.shape.get("model", 1) + # Replicate small 1D tensors that cannot be sharded. + if len(value.shape) == 1 and value.shape[0] < model_axis_size: + partition_spec = jax.sharding.PartitionSpec() + else: + partition_spec = jax.sharding.PartitionSpec("model") + + # Create sharding with the appropriate partition spec + sharding = jax.sharding.NamedSharding(mesh, partition_spec) + + tensorstore_specs.append(tensorstore_spec) + shapes.append(value.shape) + dtypes.append(dtype) + shardings.append(sharding) + + return tensorstore_specs, shardings, shapes, dtypes + + +def _default_deserialize( + shardings: Sequence[jax.sharding.NamedSharding], + tensorstore_specs: Sequence[Dict[str, Any]], + global_shapes: Sequence[tuple], + dtypes: Sequence[jnp.dtype], +): + # concurrent_bytes = 1099511627776 + concurrent_bytes = 34359738368 * 6 # multiple of 32GB + # Object should be created once per process. 
+ # pylint: disable=protected-access + byte_limiter = tensorstore_impl._LimitInFlightBytes(concurrent_bytes) + h2d_limiter = tensorstore_impl._LimitInFlightBytes(34359738368) + thread_pool = ThreadPoolExecutor(1) + multi_thread_pool = ThreadPoolExecutor(2) + + future_arrays = jax.tree.map( + functools.partial( + _async_deserialize, + byte_limiter=byte_limiter, + h2d_limiter=h2d_limiter, + single_thread_pool=thread_pool, + multi_thread_pool=multi_thread_pool, + ), + shardings, + tensorstore_specs, + global_shapes, + dtypes, + ) + + async def gather_func(): + return await asyncio.gather(*future_arrays) + result = asyncio.run(gather_func()) + return result + + +def load_model_default(ckpt_path: str): + """Main function to preload a model from GCS checkpoint.""" + step = parse_step_from_dir(ckpt_path) + print(f"Starting model preload from: {ckpt_path} (step {step})") + + if not ckpt_path.startswith("gs://"): + raise ValueError(f"Only GCS paths (gs://) are supported, got: {ckpt_path}") + + with create_mesh(): + print("Reading checkpoint structure...") + state_spec = create_state_spec_from_checkpoint(ckpt_path) + + print(f"Found {len(jax.tree_util.tree_leaves(state_spec))} tensors in checkpoint") + + tensorstore_specs, shardings, shapes, dtypes = create_checkpoint_spec_from_state( + ckpt_path, state_spec + ) + + print("Preloading checkpoint to TPU memory...") + start_time = time.perf_counter() + + restored_values = _default_deserialize( + shardings=shardings, + tensorstore_specs=tensorstore_specs, + global_shapes=shapes, + dtypes=dtypes, + ) + + preload_time = time.perf_counter() - start_time + print(f"Preload completed in {preload_time:.2f} seconds") + print(f"Preloaded {len(restored_values)} arrays") + + return restored_values + + +def load_model_colocated(ckpt_path: str): + """Main function to preload a model from GCS checkpoint.""" + step = parse_step_from_dir(ckpt_path) + print(f"Starting model preload from: {ckpt_path} (step {step})") + + if not ckpt_path.startswith("gs://"): + raise ValueError(f"Only GCS paths (gs://) are supported, got: {ckpt_path}") + + with create_mesh(): + print("Reading checkpoint structure...") + state_spec = create_state_spec_from_checkpoint(ckpt_path) + + print(f"Found {len(jax.tree_util.tree_leaves(state_spec))} tensors in checkpoint") + + tensorstore_specs, shardings, shapes, dtypes = create_checkpoint_spec_from_state( + ckpt_path, state_spec + ) + + print("Preloading checkpoint to CPU memory...") + start_time = time.perf_counter() + + preloaded_values = _colocated_deserialize( + shardings=shardings, + tensorstore_specs=tensorstore_specs, + global_shapes=shapes, + dtypes=dtypes, + ) + + preload_time = time.perf_counter() - start_time + print(f"Preload completed in {preload_time:.2f} seconds") + print(f"Preloaded {len(preloaded_values)} arrays") + + print("Transferring arrays to TPU...") + start_time = time.perf_counter() + + # Create a mesh with only local devices. + local_mesh = jax.sharding.Mesh(jax.local_devices(), ("model",)) + # Recreate shardings with the local mesh, which is what device_put expects. 
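+        # This assumes the PartitionSpecs built in create_checkpoint_spec_from_state
+        # only use the "model" axis (or are fully replicated), so they stay valid on
+        # this 1-D local mesh.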
+ local_shardings = [jax.sharding.NamedSharding(local_mesh, s.spec) for s in shardings] + restored_values = [jax.device_put(x, s) for x, s in zip(preloaded_values, local_shardings)] + + transfer_time = time.perf_counter() - start_time + print(f"Transfer completed in {transfer_time:.2f} seconds") + + return restored_values + + +def main(): + parser = argparse.ArgumentParser(description="Preload model from GCS checkpoint") + parser.add_argument( + "--ckpt_path", + required=True, + help="GCS path to checkpoint directory (e.g., gs://bucket/path/to/checkpoint)", + ) + args = parser.parse_args() + + if os.getenv("JAX_PLATFORMS") == "proxy": + pathwaysutils.initialize() + else: + jax.distributed.initialize() + + print(f"JAX devices: {jax.devices()}") + + print("--- Running colocated benchmark ---") + # Extract profile dir from ckpt_path. The profile dir should be gs://bucket/profiles/ + hostname = os.uname().nodename + profile_dir = f"gs://{args.ckpt_path.split('/')[2]}/profiles/{hostname}/colocated-test/" + #jax.profiler.start_trace(log_dir=profile_dir) + start_colocated_time = time.perf_counter() + loaded_values_colocated = load_model_colocated(ckpt_path=args.ckpt_path) + for x in loaded_values_colocated: + x.block_until_ready() + print(f"✅ Successfully loaded model from {args.ckpt_path}") + print(f"Deserialize took {time.perf_counter() - start_colocated_time:.2f} seconds") + print(f" Total parameters: {sum(x.size for x in loaded_values_colocated):,}") + #jax.profiler.stop_trace() + + # Exit early if on pathways + if os.getenv("JAX_PLATFORMS") == "proxy": + sys.exit(0) + + print("\n--- Running default benchmark ---") + start_default_time = time.perf_counter() + loaded_values_default = load_model_default(ckpt_path=args.ckpt_path) + print(f"✅ Successfully loaded model from {args.ckpt_path}") + print(f"Deserialize took {time.perf_counter() - start_default_time:.2f} seconds") + print(f" Total parameters: {sum(x.size for x in loaded_values_default):,}") + + +if __name__ == "__main__": + main() \ No newline at end of file
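A minimal sketch of the colocated-execution pattern that test_benchmark.py relies on, using only calls the script itself already imports (colocated_cpu_devices, the colocated_python decorator, and specialize); the function make_ones and its shapes are illustrative only:

import jax
import jax.numpy as jnp
from jax.experimental import colocated_python

cpu_devices = colocated_python.colocated_cpu_devices(jax.devices())

@colocated_python.colocated_python
def make_ones():
    # Runs on the colocated CPU host(s), not in the calling process.
    return [jnp.ones((4,), dtype=jnp.float32)]

def out_specs_fn():
    # The output spec must mirror the structure returned by make_ones.
    return [
        jax.ShapeDtypeStruct(
            shape=(4,),
            dtype=jnp.float32,
            sharding=jax.sharding.SingleDeviceSharding(cpu_devices[0]),
        )
    ]

make_ones = make_ones.specialize(devices=cpu_devices, out_specs_fn=out_specs_fn)
arrays = make_ones()  # a list with one jax.Array backed by the colocated CPU device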