Skip to content

Commit

Permalink
[Init] Make sure shape mismatches are caught early (huggingface#2847)
Browse files Browse the repository at this point in the history
Improve init
  • Loading branch information
patrickvonplaten authored and Jimmy committed Apr 26, 2024
1 parent 7ab41b9 commit a5bf681
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,10 +579,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" those weights or else make sure your checkpoint file is correct."
)

empty_state_dict = model.state_dict()
for param_name, param in state_dict.items():
accepts_dtype = "dtype" in set(
inspect.signature(set_module_tensor_to_device).parameters.keys()
)

if empty_state_dict[param_name].shape != param.shape:
    # Fail on the first mismatching parameter instead of letting the
    # tensor assignment below run with an incompatible value.
    # The message reports the expected *shape*; interpolating the tensor
    # itself would put its full repr into the error text.
    raise ValueError(
        f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
    )

if accepts_dtype:
set_module_tensor_to_device(
model, param_name, param_device, value=param, dtype=torch_dtype
Expand Down
24 changes: 24 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,30 @@ def test_one_request_upon_cached(self):

diffusers.utils.import_utils._safetensors_available = True

def test_weight_overwrite(self):
    """Check `from_pretrained` behaviour when `in_channels` is overridden to 9.

    Two expectations are pinned:

    1. With default loading options the call raises `ValueError` and the
       message contains "Cannot load".
    2. With `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`
       the call succeeds and the resulting config reports `in_channels == 9`.

    NOTE(review): this presumes the checkpoint was saved with a different
    `in_channels`, so the override produces a shape mismatch -- confirm
    against the hub repo.
    """
    repo_id = "hf-internal-testing/tiny-stable-diffusion-torch"

    with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
        UNet2DConditionModel.from_pretrained(
            repo_id,
            subfolder="unet",
            cache_dir=tmpdirname,
            in_channels=9,
        )

    # Make sure the failure is the explicit load error rather than some
    # unrelated ValueError raised along the way.
    self.assertIn("Cannot load", str(error_context.exception))

    with tempfile.TemporaryDirectory() as tmpdirname:
        model = UNet2DConditionModel.from_pretrained(
            repo_id,
            subfolder="unet",
            cache_dir=tmpdirname,
            in_channels=9,
            low_cpu_mem_usage=False,
            ignore_mismatched_sizes=True,
        )

    # The override must be reflected in the loaded model's config.
    self.assertEqual(model.config.in_channels, 9)


class ModelTesterMixin:
def test_from_save_pretrained(self):
Expand Down

0 comments on commit a5bf681

Please sign in to comment.