diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index a03c454e9244..eec8df8a714b 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -134,7 +134,7 @@ class AudioPipelineOutput(BaseOutput):
     audios: np.ndarray


-def is_safetensors_compatible(filenames, variant=None) -> bool:
+def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
     """
     Checking for safetensors compatibility:
     - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
@@ -150,9 +150,14 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:

     sf_filenames = set()

+    passed_components = passed_components or []
+
     for filename in filenames:
         _, extension = os.path.splitext(filename)

+        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
+            continue
+
         if extension == ".bin":
             pt_filenames.append(filename)
         elif extension == ".safetensors":
@@ -163,10 +168,8 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
         path, filename = os.path.split(filename)
         filename, extension = os.path.splitext(filename)

-        if filename == "pytorch_model":
-            filename = "model"
-        elif filename == f"pytorch_model.{variant}":
-            filename = f"model.{variant}"
+        if filename.startswith("pytorch_model"):
+            filename = filename.replace("pytorch_model", "model")
         else:
             filename = filename
@@ -196,24 +199,51 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
     weight_prefixes = [w.split(".")[0] for w in weight_names]
     # .bin, .safetensors, ...
     weight_suffixs = [w.split(".")[-1] for w in weight_names]
+    # -00001-of-00002
+    transformers_index_format = "\d{5}-of-\d{5}"
+
+    if variant is not None:
+        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
+        variant_file_re = re.compile(
+            f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+        )
+        # `text_encoder/pytorch_model.bin.index.fp16.json`
+        variant_index_re = re.compile(
+            f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
+        )

-    variant_file_regex = (
-        re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})")
-        if variant is not None
-        else None
+    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
+    non_variant_file_re = re.compile(
+        f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
     )
-    non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}")
+    # `text_encoder/pytorch_model.bin.index.json`
+    non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

     if variant is not None:
-        variant_filenames = {f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None}
+        variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
+        variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
+        variant_filenames = variant_weights | variant_indexes
     else:
         variant_filenames = set()

-    non_variant_filenames = {f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None}
+    non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
+    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
+    non_variant_filenames = non_variant_weights | non_variant_indexes

+    # all variant filenames will be used by default
     usable_filenames = set(variant_filenames)
+
+    def convert_to_variant(filename):
+        if "index" in filename:
+            variant_filename = filename.replace("index", f"index.{variant}")
+        elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
+            variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
+        else:
+            variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
+        return variant_filename
+
     for f in non_variant_filenames:
-        variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}"
+        variant_filename = convert_to_variant(f)
         if variant_filename not in usable_filenames:
             usable_filenames.add(f)
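For illustration, here is roughly how the new fp16 matching behaves. This is a sketch, not pipeline code: `weight_prefixes` and `weight_suffixs` are hand-picked stand-ins for the values the function derives from the library's known weight names.

import re

transformers_index_format = r"\d{5}-of-\d{5}"  # e.g. -00001-of-00002
weight_prefixes = ["diffusion_pytorch_model", "pytorch_model", "model"]
weight_suffixs = ["bin", "safetensors"]
variant = "fp16"

variant_file_re = re.compile(
    f"({'|'.join(weight_prefixes)})\\.({variant}|{variant}-{transformers_index_format})\\.({'|'.join(weight_suffixs)})$"
)

assert variant_file_re.match("diffusion_pytorch_model.fp16.bin")  # plain variant weight
assert variant_file_re.match("model.fp16-00001-of-00002.safetensors")  # sharded variant weight
assert variant_file_re.match("diffusion_pytorch_model.bin") is None  # non-variant weight

# convert_to_variant maps each non-variant name to its variant spelling, so an
# existing variant file takes precedence over the non-variant fallback:
#   "pytorch_model.bin.index.json"     -> "pytorch_model.bin.index.fp16.json"
#   "model-00001-of-00002.safetensors" -> "model.fp16-00001-of-00002.safetensors"
#   "diffusion_pytorch_model.bin"      -> "diffusion_pytorch_model.fp16.bin"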
@@ -292,6 +322,27 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p
     return class_obj, class_candidates


+def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, revision=None):
+    if custom_pipeline is not None:
+        if custom_pipeline.endswith(".py"):
+            path = Path(custom_pipeline)
+            # decompose into folder & file
+            file_name = path.name
+            custom_pipeline = path.parent.absolute()
+        else:
+            file_name = CUSTOM_PIPELINE_FILE_NAME
+
+        return get_class_from_dynamic_module(
+            custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
+        )
+
+    if class_obj != DiffusionPipeline:
+        return class_obj
+
+    diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
+    return getattr(diffusers_module, config["_class_name"])
+
+
 def load_sub_model(
     library_name: str,
     class_name: str,
@@ -779,7 +830,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         device_map = kwargs.pop("device_map", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
-        kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -794,8 +845,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 use_auth_token=use_auth_token,
                 revision=revision,
                 from_flax=from_flax,
+                use_safetensors=use_safetensors,
                 custom_pipeline=custom_pipeline,
+                custom_revision=custom_revision,
                 variant=variant,
+                **kwargs,
             )
         else:
             cached_folder = pretrained_model_name_or_path
@@ -810,29 +864,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             for folder in os.listdir(cached_folder):
                 folder_path = os.path.join(cached_folder, folder)
                 is_folder = os.path.isdir(folder_path) and folder in config_dict
-                variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path))
+                variant_exists = is_folder and any(
+                    p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
+                )
                 if variant_exists:
                     model_variants[folder] = variant

         # 3. Load the pipeline class, if using custom module then load it from the hub
         # if we load from explicit class, let's use it
-        if custom_pipeline is not None:
-            if custom_pipeline.endswith(".py"):
-                path = Path(custom_pipeline)
-                # decompose into folder & file
-                file_name = path.name
-                custom_pipeline = path.parent.absolute()
-            else:
-                file_name = CUSTOM_PIPELINE_FILE_NAME
-
-            pipeline_class = get_class_from_dynamic_module(
-                custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision
-            )
-        elif cls != DiffusionPipeline:
-            pipeline_class = cls
-        else:
-            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
-            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+        pipeline_class = _get_pipeline_class(
+            cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
+        )

         # DEPRECATED: To be removed in 1.0.0
         if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
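`_get_pipeline_class` factors out the three resolution paths that previously lived inline in `from_pretrained`. Roughly, as a usage sketch (the model id is a test repo used only for illustration):

from diffusers import DiffusionPipeline, StableDiffusionPipeline

# custom_pipeline set -> class comes from a local .py file or a community pipeline module
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe", custom_pipeline="lpw_stable_diffusion"
)

# concrete subclass -> that class is used as-is
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")

# bare DiffusionPipeline -> the class named by `_class_name` in model_index.json
pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")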
@@ -1095,6 +1137,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
+        custom_revision = kwargs.pop("custom_revision", None)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
@@ -1153,7 +1196,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             # this enables downloading schedulers, tokenizers, ...
             allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
             # also allow downloading config.json files with the model
-            allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
+            allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]

             allow_patterns += [
                 SCHEDULER_CONFIG_NAME,
@@ -1162,17 +1205,28 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                 CUSTOM_PIPELINE_FILE_NAME,
             ]

+            # retrieve passed components that should not be downloaded
+            pipeline_class = _get_pipeline_class(
+                cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
+            )
+            expected_components, _ = cls._get_signature_keys(pipeline_class)
+            passed_components = [k for k in expected_components if k in kwargs]
+
             if (
                 use_safetensors
                 and not allow_pickle
-                and not is_safetensors_compatible(model_filenames, variant=variant)
+                and not is_safetensors_compatible(
+                    model_filenames, variant=variant, passed_components=passed_components
+                )
             ):
                 raise EnvironmentError(
                     f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
                 )
             if from_flax:
                 ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
-            elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
+            elif use_safetensors and is_safetensors_compatible(
+                model_filenames, variant=variant, passed_components=passed_components
+            ):
                 ignore_patterns = ["*.bin", "*.msgpack"]

                 safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
@@ -1194,6 +1248,13 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                         f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
                     )

+            # Don't download any objects that are passed
+            allow_patterns = [
+                p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
+            ]
+            # Don't download index files of forbidden patterns either
+            ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
+
             re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
             re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
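The net effect of the two pattern tweaks is easiest to see on concrete values. A runnable sketch with hypothetical pattern lists (the real ones are assembled from the repo's folder names):

passed_components = ["safety_checker"]

# drop allow patterns that point into a passed component's folder
allow_patterns = ["safety_checker/config.json", "unet/config.json", "vae/*", "model_index.json"]
allow_patterns = [
    p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
]
assert allow_patterns == ["unet/config.json", "vae/*", "model_index.json"]

# ignoring a weight format also ignores its sharding index files
ignore_patterns = ["*.bin", "*.msgpack"]
ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
assert ignore_patterns == ["*.bin", "*.msgpack", "*.bin.index.*json", "*.msgpack.index.*json"]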
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 0525eaca50da..08cb03f55aaa 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -78,9 +78,7 @@ def test_one_request_upon_cached(self):

         with tempfile.TemporaryDirectory() as tmpdirname:
             with requests_mock.mock(real_http=True) as m:
-                DiffusionPipeline.download(
-                    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
-                )
+                DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-pipe", cache_dir=tmpdirname)

             download_requests = [r.method for r in m.request_history]
             assert download_requests.count("HEAD") == 15, "15 calls to files"
@@ -101,6 +99,55 @@ def test_one_request_upon_cached(self):
             len(cache_requests) == 2
         ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"

+    def test_less_downloads_passed_object(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            cached_folder = DiffusionPipeline.download(
+                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
+            )
+
+            # make sure the safety checker is not downloaded
+            assert "safety_checker" not in os.listdir(cached_folder)
+
+            # make sure the rest is downloaded
+            assert "unet" in os.listdir(cached_folder)
+            assert "tokenizer" in os.listdir(cached_folder)
+            assert "vae" in os.listdir(cached_folder)
+            assert "model_index.json" in os.listdir(cached_folder)
+            assert "scheduler" in os.listdir(cached_folder)
+            assert "feature_extractor" in os.listdir(cached_folder)
+
+    def test_less_downloads_passed_object_calls(self):
+        # TODO: For some reason this test fails on MPS where no HEAD call is made.
+        if torch_device == "mps":
+            return
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with requests_mock.mock(real_http=True) as m:
+                DiffusionPipeline.download(
+                    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
+                )
+
+            download_requests = [r.method for r in m.request_history]
+            # 15 - 2 because no call to config or model file for `safety_checker`
+            assert download_requests.count("HEAD") == 13, "13 calls to files"
+            # 17 - 2 because no call to config or model file for `safety_checker`
+            assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json"
+            assert (
+                len(download_requests) == 28
+            ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
+
+            with requests_mock.mock(real_http=True) as m:
+                DiffusionPipeline.download(
+                    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
+                )
+
+            cache_requests = [r.method for r in m.request_history]
+            assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
+            assert cache_requests.count("GET") == 1, "model info is only GET"
+            assert (
+                len(cache_requests) == 2
+            ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+
     def test_download_only_pytorch(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             # pipeline has Flax weights
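These request-count assertions pin down the user-facing behavior: any expected component that is passed explicitly (even as `None`) is skipped by the snapshot download entirely. A minimal sketch of that effect:

from diffusers import DiffusionPipeline

# no HEAD/GET requests are made for safety_checker/ because it is passed explicitly
cached_folder = DiffusionPipeline.download(
    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
# cached_folder contains unet/, vae/, ... but no safety_checker/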
@@ -165,6 +212,54 @@ def test_download_safetensors(self):
             # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
             assert not any(f.endswith(".bin") for f in files)

+    def test_download_safetensors_index(self):
+        for variant in ["fp16", None]:
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                tmpdirname = DiffusionPipeline.download(
+                    "hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
+                    cache_dir=tmpdirname,
+                    use_safetensors=True,
+                    variant=variant,
+                )
+
+                all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+                files = [item for sublist in all_root_files for item in sublist]
+
+                # None of the downloaded files should be a .bin file even if the repo also has some:
+                # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-indexes/tree/main/text_encoder
+                if variant is None:
+                    assert not any("fp16" in f for f in files)
+                else:
+                    model_files = [f for f in files if "safetensors" in f]
+                    assert all("fp16" in f for f in model_files)
+
+                assert len([f for f in files if ".safetensors" in f]) == 8
+                assert not any(".bin" in f for f in files)
+
+    def test_download_bin_index(self):
+        for variant in ["fp16", None]:
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                tmpdirname = DiffusionPipeline.download(
+                    "hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
+                    cache_dir=tmpdirname,
+                    use_safetensors=False,
+                    variant=variant,
+                )
+
+                all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+                files = [item for sublist in all_root_files for item in sublist]
+
+                # None of the downloaded files should be a safetensors file even if the repo also has some:
+                # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-indexes/tree/main/text_encoder
+                if variant is None:
+                    assert not any("fp16" in f for f in files)
+                else:
+                    model_files = [f for f in files if "bin" in f]
+                    assert all("fp16" in f for f in model_files)
+
+                assert len([f for f in files if ".bin" in f]) == 8
+                assert not any(".safetensors" in f for f in files)
+
     def test_download_no_safety_checker(self):
         prompt = "hello"
         pipe = StableDiffusionPipeline.from_pretrained(
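For orientation, the sharded layout these tests download is the transformers-style one the regex changes above are designed to recognize. A component folder saved this way looks roughly like the following (hypothetical two-shard fp16 text encoder, not the exact contents of the test repo):

text_encoder/
├── model.fp16-00001-of-00002.safetensors
├── model.fp16-00002-of-00002.safetensors
└── model.safetensors.index.fp16.json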
@@ -362,6 +457,33 @@ def test_download_broken_variant(self):

                     diffusers.utils.import_utils._safetensors_available = True

+    def test_local_save_load_index(self):
+        prompt = "hello"
+        for variant in [None, "fp16"]:
+            for use_safe in [True, False]:
+                pipe = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
+                    variant=variant,
+                    use_safetensors=use_safe,
+                    safety_checker=None,
+                )
+                pipe = pipe.to(torch_device)
+                generator = torch.manual_seed(0)
+                out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    pipe.save_pretrained(tmpdirname)
+                    pipe_2 = StableDiffusionPipeline.from_pretrained(
+                        tmpdirname, safe_serialization=use_safe, variant=variant
+                    )
+                    pipe_2 = pipe_2.to(torch_device)
+
+                generator = torch.manual_seed(0)
+
+                out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
+
+                assert np.max(np.abs(out - out_2)) < 1e-3
+
     def test_text_inversion_download(self):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
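Distilled from the round-trip test above, the end-user workflow is a plain save/load with a variant. A sketch, assuming the same test repo and fp16 variant:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe-indexes", variant="fp16", safety_checker=None
)
pipe.save_pretrained("./local-pipe", safe_serialization=True, variant="fp16")

# index files and shards are written with the variant spelling, so reloading
# with the same variant resolves them again
reloaded = StableDiffusionPipeline.from_pretrained("./local-pipe", variant="fp16")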