diff --git a/end_to_end/tpu/eval_assert.py b/end_to_end/tpu/eval_assert.py
index cd2d0cbb..07b5585a 100644
--- a/end_to_end/tpu/eval_assert.py
+++ b/end_to_end/tpu/eval_assert.py
@@ -14,6 +14,15 @@
 limitations under the License.
 """
 
+"""
+Example usage:
+python end_to_end/tpu/eval_assert.py avg_tflops metrics.txt 100
+python end_to_end/tpu/eval_assert.py avg_step_time metrics.txt 0.5 100
+python end_to_end/tpu/eval_assert.py final_loss metrics.txt 0.5 100
+"""
+
+
+
 # pylint: skip-file
 """Reads and asserts over target values"""
 from absl import app
@@ -34,11 +43,11 @@ def get_last_n_data(metrics_file, target, n=10):
   return last_n_data
 
 
-def test_final_loss(metrics_file, target_loss):
+def test_final_loss(metrics_file, target_loss, num_samples_str="10"):
   target_loss = float(target_loss)
+  num_samples = int(num_samples_str)
   with open(metrics_file, "r", encoding="utf8") as _:
-    use_last_n_data = 10
-    last_n_data = get_last_n_data(metrics_file, "learning/loss", use_last_n_data)
+    last_n_data = get_last_n_data(metrics_file, "learning/loss", num_samples)
     avg_last_n_data = sum(last_n_data) / len(last_n_data)
     print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}")
     print(f"Target loss is {target_loss}")
@@ -46,14 +55,77 @@ def test_final_loss(metrics_file, target_loss):
     print("Final loss test passed.")
 
 
+def test_avg_step_time(metrics_file, max_avg_step_time_str, num_samples_str="10"):
+  """Tests if the average of the last N step times is below a maximum threshold."""
+  max_avg_step_time = float(max_avg_step_time_str)
+  num_samples = int(num_samples_str)
+  metric_key = "perf/step_time_seconds"
+  last_n_step_times = get_last_n_data(metrics_file, metric_key, num_samples)
+
+  if not last_n_step_times:
+    raise ValueError(f"Metric '{metric_key}' not found or no data points in {metrics_file}.")
+
+  avg_last_n_step_time = sum(last_n_step_times) / len(last_n_step_times)
+
+  print(f"Found {len(last_n_step_times)} data points for '{metric_key}'.")
+  print(f"Mean of last {len(last_n_step_times)} step times is {avg_last_n_step_time:.4f} s")
+
+  assert (
+      avg_last_n_step_time < max_avg_step_time
+  ), f"Average step time {avg_last_n_step_time:.4f}s is not less than target {max_avg_step_time}s."
+  print("Average step time test passed.")
+
+
+def test_avg_tflops(metrics_file, min_avg_tflops_str, num_samples_str="10"):
+  """Tests if the average of the last N TFLOPs/sec values is above a minimum threshold."""
+  min_avg_tflops = float(min_avg_tflops_str)
+  num_samples = int(num_samples_str)
+  metric_key = "perf/per_device_tflops_per_sec"
+
+  last_n_tflops = get_last_n_data(metrics_file, metric_key, num_samples)
+
+  if not last_n_tflops:
+    raise ValueError(f"Metric '{metric_key}' not found or no data points in {metrics_file}.")
+
+  avg_last_n_tflops = sum(last_n_tflops) / len(last_n_tflops)
+
+  print(f"Found {len(last_n_tflops)} data points for '{metric_key}'.")
+  print(f"Mean TFLOPs/sec over the last {len(last_n_tflops)} steps is {avg_last_n_tflops:.2f}")
+
+  assert (
+      avg_last_n_tflops > min_avg_tflops
+  ), f"Average TFLOPs/sec {avg_last_n_tflops:.2f} is not greater than target {min_avg_tflops}."
+  print("Average TFLOPs/sec test passed.")
+
+
 def main(argv: Sequence[str]) -> None:
+  if len(argv) < 2:
+    print("Usage: python eval_assert.py <test_scenario> [test_vars...]")
+    print("Available scenarios: final_loss, avg_step_time, avg_tflops")
+    raise ValueError("Test scenario not specified.")
   _, test_scenario, *test_vars = argv
   if test_scenario == "final_loss":
-    test_final_loss(*test_vars)
+    if len(test_vars) < 2:
+      raise ValueError("Usage: final_loss <metrics_file> <target_loss> [num_samples]")
+    metrics_file, target_loss, *num_samples_opt = test_vars
+    num_samples = num_samples_opt[0] if num_samples_opt else "10"
+    test_final_loss(metrics_file, target_loss, num_samples)
+  elif test_scenario == "avg_step_time":
+    if len(test_vars) < 2:
+      raise ValueError("Usage: avg_step_time <metrics_file> <max_avg_step_time> [num_samples]")
+    metrics_file, max_avg_step_time, *num_samples_opt = test_vars
+    num_samples = num_samples_opt[0] if num_samples_opt else "10"
+    test_avg_step_time(metrics_file, max_avg_step_time, num_samples)
+  elif test_scenario == "avg_tflops":
+    if len(test_vars) < 2:
+      raise ValueError("Usage: avg_tflops <metrics_file> <min_avg_tflops> [num_samples]")
+    metrics_file, min_avg_tflops, *num_samples_opt = test_vars
+    num_samples = num_samples_opt[0] if num_samples_opt else "10"
+    test_avg_tflops(metrics_file, min_avg_tflops, num_samples)
   else:
-    raise ValueError(f"Unrecognized test_scenario {test_scenario}")
+    raise ValueError(f"Unrecognized test_scenario '{test_scenario}'. Available: final_loss, avg_step_time, avg_tflops")
 
 
 if __name__ == "__main__":
diff --git a/end_to_end/tpu/test_sdxl_training_loss.sh b/end_to_end/tpu/test_sdxl_training_loss.sh
old mode 100644
new mode 100755
diff --git a/requirements.txt b/requirements.txt
index 1ca1dc79..defbb151 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,14 +19,14 @@ Pillow
 pylint
 pyink
 pytest==8.2.2
-tensorflow==2.17.0
+tensorflow>=2.17.0
 tensorflow-datasets>=4.9.6
 ruff>=0.1.5,<=0.2
 git+https://github.com/mlperf/logging.git
 opencv-python-headless==4.10.0.84
 orbax-checkpoint==0.10.3
 tokenizers==0.21.0
-huggingface_hub==0.24.7
+huggingface_hub==0.30.2
 transformers==4.48.1
 einops==0.8.0
 sentencepiece
diff --git a/requirements_with_jax_stable_stack.txt b/requirements_with_jax_stable_stack.txt
index ba21b477..80ad1434 100644
--- a/requirements_with_jax_stable_stack.txt
+++ b/requirements_with_jax_stable_stack.txt
@@ -8,7 +8,7 @@ ftfy
 git+https://github.com/mlperf/logging.git
 google-cloud-storage==2.17.0
 grain-nightly==0.0.10
-huggingface_hub==0.24.7
+huggingface_hub==0.30.2
 jax>=0.4.30
 jaxlib>=0.4.30
 Jinja2
diff --git a/setup.py b/setup.py
index 6cc926ac..c37f9080 100644
--- a/setup.py
+++ b/setup.py
@@ -97,7 +97,7 @@
     "filelock",
     "flax>=0.4.1",
     "hf-doc-builder>=0.3.0",
-    "huggingface-hub==0.24.7",
+    "huggingface-hub==0.30.2",
     "requests-mock==1.10.0",
     "importlib_metadata",
     "invisible-watermark>=0.2.0",
diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml
index 1768bbed..ac20ed6d 100644
--- a/src/maxdiffusion/configs/base14.yml
+++ b/src/maxdiffusion/configs/base14.yml
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
 gcs_metrics: True
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 10000000000 # Flushes Tensorboard
diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml
index 4ff025c4..a1450abe 100644
--- a/src/maxdiffusion/configs/base21.yml
+++ b/src/maxdiffusion/configs/base21.yml
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
 gcs_metrics: True
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 10000000000 # Flushes Tensorboard
diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml
index 4cf66f5d..ede91b10 100644
--- a/src/maxdiffusion/configs/base_2_base.yml
+++ b/src/maxdiffusion/configs/base_2_base.yml
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
 gcs_metrics: False
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 10000000000 # Flushes Tensorboard
diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml
index 187bade0..49146e10 100644
--- a/src/maxdiffusion/configs/base_flux_dev.yml
+++ b/src/maxdiffusion/configs/base_flux_dev.yml
@@ -18,6 +18,10 @@ run_name: ''
 metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 gcs_metrics: False
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml
index 7c11e698..6928b31d 100644
--- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml
+++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
 gcs_metrics: False
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 100
diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml
index 188074a5..52659330 100644
--- a/src/maxdiffusion/configs/base_flux_schnell.yml
+++ b/src/maxdiffusion/configs/base_flux_schnell.yml
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
 gcs_metrics: False
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 100
diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml
index 24cfe399..542b7957 100644
--- a/src/maxdiffusion/configs/base_xl.yml
+++ b/src/maxdiffusion/configs/base_xl.yml
@@ -18,6 +18,10 @@ run_name: ''
 metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 gcs_metrics: False
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml
index 60f6fb87..307e826a 100644
--- a/src/maxdiffusion/configs/base_xl_lightning.yml
+++ b/src/maxdiffusion/configs/base_xl_lightning.yml
@@ -18,6 +18,10 @@ run_name: ''
 metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 gcs_metrics: True
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 100
diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index 3dff39e3..a28bcff1 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -18,8 +18,9 @@
 # pylint: disable=bare-except, consider-using-generator
 """ Common Max Utils needed by multiple modules"""
 import functools
-from functools import reduce
+from functools import partial, reduce
 from contextlib import nullcontext
+from typing import Dict, Callable
 import json
 import yaml
 import os
@@ -564,6 +565,13 @@ def calculate_model_tflops(module: module_lib.Module, rngs: Union[PRNGKey, RNGSe
   return total_flops
 
 
+def get_train_step_partial_with_signature(train_step: Callable, pipeline: object, params: Dict, config: object) -> Callable:
+  """Binds pipeline/params/config into train_step and restores __name__, which functools.partial objects lack."""
+  partial_train = partial(train_step, pipeline=pipeline, params=params, config=config)
+  partial_train.__name__ = "train_step"
+  return partial_train
+
+
 def calculate_num_params_from_pytree(params):
   """Calculates number of parameters from a pytree"""
   params_sizes = jax.tree_util.tree_map(jax.numpy.size, params)
diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
index d22867e4..81fca740 100644
--- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
+++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
@@ -15,11 +15,25 @@
 """
 
 from abc import abstractmethod
+import time
+from typing import Any, Callable
 import jax
 from maxdiffusion import (max_utils, maxdiffusion_utils, max_logging)
 from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import (BaseStableDiffusionCheckpointer)
 
+# Helper for writing timestamped timing messages to an optional local log file.
+
+
+def _log_to_file(message: str, log_file: str = ""):
+  """Appends a timestamped message to log_file (if given) and mirrors it to max_logging."""
+  timestamp = time.strftime("%Y-%m-%d %H:%M:%S %Z", time.localtime())
+  full_message = f"[{timestamp}] {message}\n"
+  if log_file:
+    with open(log_file, "a") as f:
+      f.write(full_message)
+  max_logging.log(full_message.strip())
+
 
 class BaseStableDiffusionTrainer(BaseStableDiffusionCheckpointer):
 
@@ -67,6 +81,29 @@ def get_data_shardings(self):
   def create_scheduler(self, pipeline, params):
     pass
 
+  def _time_and_log_call(
+      self, func_obj: Callable[..., Any], *func_args: Any, description: str = "", **func_kwargs: Any
+  ) -> Any:
+    """
+    Times a function call, logs its duration, and returns its result.
+    """
+    if not description:
+      if hasattr(func_obj, "__name__"):
+        description = func_obj.__name__
+      elif hasattr(func_obj, "__call__") and hasattr(type(func_obj), "__name__"):
+        description = type(func_obj).__name__
+    log_file = ""
+
+    if self.config.write_timing_metrics and self.config.timing_metrics_file:
+      log_file = self.config.timing_metrics_file
+    _log_to_file(f"Starting: {description}...", log_file=log_file)
+    start_time = time.perf_counter()  # Use perf_counter for a monotonic, high-resolution duration measurement
+    result = func_obj(*func_args, **func_kwargs)
+    end_time = time.perf_counter()
+    duration = end_time - start_time
+    _log_to_file(f"Finished: {description} - Duration: {duration:.4f} seconds", log_file=log_file)
+    return result
+
   def calculate_tflops(self, pipeline, params):
     per_device_tflops = maxdiffusion_utils.calculate_unet_tflops(
         self.config, pipeline, (self.config.per_device_batch_size * jax.local_device_count()), self.rng, train=True
@@ -75,22 +112,28 @@ def calculate_tflops(self, pipeline, params):
     return per_device_tflops
 
   def start_training(self):
-    # Hook
     self.pre_training_steps()
     # Load checkpoint - will load or create states
-    pipeline, params = self.load_checkpoint()
+    pipeline, params = self._time_and_log_call(self.load_checkpoint)
    # create train states
     train_states = {}
     state_shardings = {}
 
-    vae_state, vae_state_mesh_shardings = self.create_vae_state(
-        pipeline=pipeline, params=params, checkpoint_item_name="vae_state", is_training=False
+    vae_state, vae_state_mesh_shardings = self._time_and_log_call(
+        self.create_vae_state,
+        # Arguments for create_vae_state
+        pipeline=pipeline,
+        params=params,
+        checkpoint_item_name="vae_state",
+        is_training=False,
     )
     train_states["vae_state"] = vae_state
     state_shardings["vae_state_shardings"] = vae_state_mesh_shardings
 
-    text_encoder_state, text_encoder_state_mesh_shardings = self.create_text_encoder_state(
+    text_encoder_state, text_encoder_state_mesh_shardings = self._time_and_log_call(
+        self.create_text_encoder_state,
+        # Arguments for create_text_encoder_state
         pipeline=pipeline,
         params=params,
         checkpoint_item_name="text_encoder_state",
@@ -99,8 +142,13 @@ def start_training(self):
     train_states["text_encoder_state"] = text_encoder_state
     state_shardings["text_encoder_state_shardings"] = text_encoder_state_mesh_shardings
     if hasattr(pipeline, "text_encoder_2"):
-      text_encoder_2_state, text_encoder_2_state_mesh_shardings = self.create_text_encoder_2_state(
-          pipeline, params, "text_encoder_2_state", is_training=self.config.train_text_encoder
+      text_encoder_2_state, text_encoder_2_state_mesh_shardings = self._time_and_log_call(
+          self.create_text_encoder_2_state,
+          # Arguments for create_text_encoder_2_state
+          pipeline=pipeline,
+          params=params,
+          checkpoint_item_name="text_encoder_2_state",
+          is_training=self.config.train_text_encoder,
       )
       train_states["text_encoder_2_state"] = text_encoder_2_state
       state_shardings["text_encoder_2_state_shardings"] = text_encoder_2_state_mesh_shardings
@@ -115,11 +163,12 @@ def start_training(self):
     self.per_device_tflops = per_device_tflops
 
     # Load dataset
-    data_iterator = self.load_dataset(pipeline, params, train_states)
+    data_iterator = self._time_and_log_call(self.load_dataset, pipeline, params, train_states)
     if self.config.dataset_type == "grain":
-      data_iterator = self.restore_data_iterator_state(data_iterator)
+      data_iterator = self._time_and_log_call(self.restore_data_iterator_state, data_iterator=data_iterator)
 
-    unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self.create_unet_state(
+    unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self._time_and_log_call(
+        self.create_unet_state,
         # ambiguous here, but if self.params.get("unet") doesn't exist
         # Then its 1 of 2 scenarios:
         # 1. unet state will be loaded directly from orbax
@@ -134,11 +183,13 @@ def start_training(self):
     data_shardings = self.get_data_shardings()
 
     # Compile train_step
-    p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
+    p_train_step = self._time_and_log_call(
+        self.compile_train_step, pipeline, params, train_states, state_shardings, data_shardings
+    )
     # Start training
-    train_states = self.training_loop(
-        p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler
+    train_states = self._time_and_log_call(
+        self.training_loop, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler
     )
     # 6. save final checkpoint
     # Hook
-    self.post_training_steps(pipeline, params, train_states)
+    self._time_and_log_call(self.post_training_steps, pipeline, params, train_states)
diff --git a/src/maxdiffusion/trainers/dreambooth_trainer.py b/src/maxdiffusion/trainers/dreambooth_trainer.py
index bf10dc65..40a40190 100644
--- a/src/maxdiffusion/trainers/dreambooth_trainer.py
+++ b/src/maxdiffusion/trainers/dreambooth_trainer.py
@@ -16,7 +16,6 @@
 
 from pathlib import Path
 import time
-from functools import partial
 import datetime
 import os
 import numpy as np
@@ -168,7 +167,7 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, data_shardings):
     self.rng, train_rngs = jax.random.split(self.rng)
     with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       p_train_step = jax.jit(
-          partial(_train_step, config=self.config, pipeline=pipeline, params=params),
+          max_utils.get_train_step_partial_with_signature(_train_step, pipeline=pipeline, params=params, config=self.config),
           in_shardings=(state_shardings["unet_state_shardings"], None, data_shardings, None),
           out_shardings=(state_shardings["unet_state_shardings"], None, None, None),
           donate_argnums=(0,),
diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py
index ed29fe91..b6a47c0d 100644
--- a/src/maxdiffusion/trainers/flux_trainer.py
+++ b/src/maxdiffusion/trainers/flux_trainer.py
@@ -285,14 +285,16 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, data_shardings):
     self.rng, train_rngs = jax.random.split(self.rng)
     guidance_vec = jnp.full((self.total_train_batch_size,), self.config.guidance_scale, dtype=self.config.activations_dtype)
     with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      train_step_partial = partial(
+          _train_step,
+          guidance_vec=guidance_vec,
+          pipeline=pipeline,
+          scheduler=train_states["scheduler"],
+          config=self.config,
+      )
+      train_step_partial.__name__ = "train_step"
       p_train_step = jax.jit(
-          partial(
-              _train_step,
-              guidance_vec=guidance_vec,
-              pipeline=pipeline,
-              scheduler=train_states["scheduler"],
-              config=self.config,
-          ),
+          train_step_partial,
           in_shardings=(
               state_shardings["flux_state_shardings"],
               data_shardings,
diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py
index 88c9733f..bae48a57 100644
--- a/src/maxdiffusion/trainers/sdxl_trainer.py
+++ b/src/maxdiffusion/trainers/sdxl_trainer.py
@@ -153,7 +153,7 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, data_shardings):
     self.rng, train_rngs = jax.random.split(self.rng)
     with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       p_train_step = jax.jit(
-          partial(_train_step, pipeline=pipeline, params=params, config=self.config),
+          max_utils.get_train_step_partial_with_signature(_train_step, pipeline=pipeline, params=params, config=self.config),
           in_shardings=(
               state_shardings["unet_state_shardings"],
               state_shardings["vae_state_shardings"],
diff --git a/src/maxdiffusion/trainers/stable_diffusion_trainer.py b/src/maxdiffusion/trainers/stable_diffusion_trainer.py
index 83158647..5844df3d 100644
--- a/src/maxdiffusion/trainers/stable_diffusion_trainer.py
+++ b/src/maxdiffusion/trainers/stable_diffusion_trainer.py
@@ -148,7 +148,7 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, data_shardings):
     self.rng, train_rngs = jax.random.split(self.rng)
     with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       p_train_step = jax.jit(
-          partial(_train_step, pipeline=pipeline, params=params, config=self.config),
+          max_utils.get_train_step_partial_with_signature(_train_step, pipeline=pipeline, params=params, config=self.config),
           in_shardings=(
               state_shardings["unet_state_shardings"],
               state_shardings["vae_state_shardings"],
diff --git a/src/maxdiffusion/utils/dynamic_modules_utils.py b/src/maxdiffusion/utils/dynamic_modules_utils.py
index f12c0b71..4477b210 100644
--- a/src/maxdiffusion/utils/dynamic_modules_utils.py
+++ b/src/maxdiffusion/utils/dynamic_modules_utils.py
@@ -25,9 +25,12 @@
 from typing import Dict, Optional, Union
 from urllib import request
 
-from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
+from huggingface_hub import HfFolder, hf_hub_download, model_info
+import huggingface_hub
 from packaging import version
 
+cached_download = None
+
 from .. import __version__
 from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
 
@@ -39,6 +42,24 @@
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+# https://github.com/huggingface/huggingface_hub/releases/tag/v0.26.0
+# `cached_download(), url_to_filename(), filename_to_url() methods are now completely removed.
+# From now on, you will have to use hf_hub_download() to benefit from the new cache layout.`
+if hasattr(huggingface_hub, "__version__"):
+  current_version = version.parse(huggingface_hub.__version__)
+  target_version = version.parse("0.26.0")
+
+  if current_version < target_version:
+    try:
+      from huggingface_hub import cached_download
+
+    except ImportError:
+      logger.error(
+          f"huggingface_hub version {current_version} is below 0.26.0, but 'cached_download' could not be imported. It might have been removed or deprecated in this version as well."
+      )
+else:
+  logger.error("Could not determine huggingface_hub version. Unable to conditionally import 'cached_download'.")
+
 
 def get_diffusers_versions():
   url = "https://pypi.org/pypi/diffusers/json"
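
Why the trainers above set `__name__` explicitly: `functools.partial` objects do not expose a `__name__` attribute, so `_time_and_log_call` (and other name-based tooling such as profiler traces) would otherwise fall back to the type name "partial". A minimal, self-contained sketch of that behavior; the `train_step` below is a hypothetical stand-in for the real one:

from functools import partial


def train_step(batch, pipeline=None, params=None, config=None):
  # Hypothetical stand-in for the real maxdiffusion train step.
  return batch


bound = partial(train_step, pipeline=None, params=None, config=None)
print(hasattr(bound, "__name__"))  # False: partial objects carry no __name__
bound.__name__ = "train_step"  # what get_train_step_partial_with_signature restores
print(bound.__name__)  # train_step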
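
The timing wrapper itself reduces to the following pattern (a sketch under the assumption that only wall-clock duration is needed; `time.perf_counter()` is used rather than `time.time()` because it is monotonic and unaffected by system clock adjustments):

import time


def time_and_log(func, *args, **kwargs):
  # Minimal sketch of the _time_and_log_call pattern: resolve a description
  # from __name__ with a type-name fallback, then time the call.
  description = getattr(func, "__name__", type(func).__name__)
  start = time.perf_counter()
  result = func(*args, **kwargs)
  print(f"Finished: {description} - Duration: {time.perf_counter() - start:.4f} seconds")
  return result


values = time_and_log(sorted, range(1_000_000), reverse=True)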