[Feat] Enable State Dict For Textual Inversion Loader (huggingface#3439)
* enable state dict for textual inversion loader

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* add tests

* fix tests

* fix tests

* fix tests

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2 people authored and Jimmy committed Apr 26, 2024
1 parent 94a7ff2 commit e08af4d
Showing 2 changed files with 97 additions and 33 deletions.
71 changes: 38 additions & 33 deletions src/diffusers/loaders.py
@@ -470,7 +470,7 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
 
     def load_textual_inversion(
         self,
-        pretrained_model_name_or_path: Union[str, List[str]],
+        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
         token: Optional[Union[str, List[str]]] = None,
         **kwargs,
     ):
@@ -485,7 +485,7 @@ def load_textual_inversion(
         </Tip>
 
         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
                Can be either:
@@ -494,6 +494,8 @@ def load_textual_inversion(
                     - A path to a *directory* containing textual inversion weights, e.g.
                       `./my_text_inversion_directory/`.
                     - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
 
                 Or a list of those elements.
             token (`str` or `List[str]`, *optional*):
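For reference, the smallest state dict the new docstring bullet describes maps a placeholder token directly to its embedding tensor. A minimal sketch (the 32-dim vector mirrors the toy text encoder used in the tests below; real Stable Diffusion 1.x checkpoints carry the text encoder's hidden size, 768):

import torch

# Minimal textual inversion state dict: one placeholder token -> one embedding.
# 32 dims matches the tiny test encoder; SD 1.x would use 768.
state_dict = {"<my-token>": torch.ones(32)}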
@@ -618,7 +620,7 @@ def load_textual_inversion(
             "framework": "pytorch",
         }
 
-        if isinstance(pretrained_model_name_or_path, str):
+        if not isinstance(pretrained_model_name_or_path, list):
             pretrained_model_name_or_paths = [pretrained_model_name_or_path]
         else:
             pretrained_model_name_or_paths = pretrained_model_name_or_path
@@ -643,16 +645,38 @@ def load_textual_inversion(
         token_ids_and_embeddings = []
 
         for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            # 1. Load textual inversion file
-            model_file = None
-            # Let's first try to load .safetensors weights
-            if (use_safetensors and weight_name is None) or (
-                weight_name is not None and weight_name.endswith(".safetensors")
-            ):
-                try:
+            if not isinstance(pretrained_model_name_or_path, dict):
+                # 1. Load textual inversion file
+                model_file = None
+                # Let's first try to load .safetensors weights
+                if (use_safetensors and weight_name is None) or (
+                    weight_name is not None and weight_name.endswith(".safetensors")
+                ):
+                    try:
+                        model_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            resume_download=resume_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            use_auth_token=use_auth_token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                        )
+                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                    except Exception as e:
+                        if not allow_pickle:
+                            raise e
+
+                        model_file = None
+
+                if model_file is None:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
-                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        weights_name=weight_name or TEXT_INVERSION_NAME,
                         cache_dir=cache_dir,
                         force_download=force_download,
                         resume_download=resume_download,
@@ -663,28 +687,9 @@ def load_textual_inversion(
                         subfolder=subfolder,
                         user_agent=user_agent,
                     )
-                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                except Exception as e:
-                    if not allow_pickle:
-                        raise e
-
-                    model_file = None
-
-            if model_file is None:
-                model_file = _get_model_file(
-                    pretrained_model_name_or_path,
-                    weights_name=weight_name or TEXT_INVERSION_NAME,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    user_agent=user_agent,
-                )
-                state_dict = torch.load(model_file, map_location="cpu")
+                    state_dict = torch.load(model_file, map_location="cpu")
+            else:
+                state_dict = pretrained_model_name_or_path
 
             # 2. Load token and embedding correctly from file
             loaded_token = None
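Taken together, these loader changes let callers skip the file round-trip entirely. A hedged usage sketch, assuming a standard Stable Diffusion pipeline (the checkpoint id and the 768-dim embeddings are illustrative for SD 1.x, not taken from this diff):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# An in-memory state dict now works where a repo id or file path did before;
# the dict key becomes the placeholder token.
pipe.load_textual_inversion({"<my-concept>": torch.randn(768)})

# Lists mix in as well: anything that is not already a list gets wrapped,
# so single dicts and lists of dicts take the same code path.
pipe.load_textual_inversion([
    {"<concept-a>": torch.randn(768)},
    {"<concept-b>": torch.randn(768)},
])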
59 changes: 59 additions & 0 deletions tests/pipelines/test_pipelines.py
@@ -663,6 +663,65 @@ def test_text_inversion_download(self):
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
         assert out.shape == (1, 128, 128, 3)
 
+        # single token state dict load
+        ten = {"<x>": torch.ones((32,))}
+        pipe.load_textual_inversion(ten)
+
+        token = pipe.tokenizer.convert_tokens_to_ids("<x>")
+        assert token == num_tokens + 10, "Added token must be at spot `num_tokens`"
+        assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32
+        assert pipe._maybe_convert_prompt("<x>", pipe.tokenizer) == "<x>"
+
+        prompt = "hey <x>"
+        out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+        assert out.shape == (1, 128, 128, 3)
+
+        # multi embedding state dict load
+        ten1 = {"<xxxxx>": torch.ones((32,))}
+        ten2 = {"<xxxxxx>": 2 * torch.ones((1, 32))}
+
+        pipe.load_textual_inversion([ten1, ten2])
+
+        token = pipe.tokenizer.convert_tokens_to_ids("<xxxxx>")
+        assert token == num_tokens + 11, "Added token must be at spot `num_tokens`"
+        assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32
+        assert pipe._maybe_convert_prompt("<xxxxx>", pipe.tokenizer) == "<xxxxx>"
+
+        token = pipe.tokenizer.convert_tokens_to_ids("<xxxxxx>")
+        assert token == num_tokens + 12, "Added token must be at spot `num_tokens`"
+        assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
+        assert pipe._maybe_convert_prompt("<xxxxxx>", pipe.tokenizer) == "<xxxxxx>"
+
+        prompt = "hey <xxxxx> <xxxxxx>"
+        out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+        assert out.shape == (1, 128, 128, 3)
+
+        # auto1111 multi-token state dict load
+        ten = {
+            "string_to_param": {
+                "*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])
+            },
+            "name": "<xxxx>",
+        }
+
+        pipe.load_textual_inversion(ten)
+
+        token = pipe.tokenizer.convert_tokens_to_ids("<xxxx>")
+        token_1 = pipe.tokenizer.convert_tokens_to_ids("<xxxx>_1")
+        token_2 = pipe.tokenizer.convert_tokens_to_ids("<xxxx>_2")
+
+        assert token == num_tokens + 13, "Added token must be at spot `num_tokens`"
+        assert token_1 == num_tokens + 14, "Added token must be at spot `num_tokens`"
+        assert token_2 == num_tokens + 15, "Added token must be at spot `num_tokens`"
+        assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
+        assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
+        assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
+        assert pipe._maybe_convert_prompt("<xxxx>", pipe.tokenizer) == "<xxxx> <xxxx>_1 <xxxx>_2"
+
+        prompt = "hey <xxxx>"
+        out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+        assert out.shape == (1, 128, 128, 3)
+
     def test_download_ignore_files(self):
         # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
         with tempfile.TemporaryDirectory() as tmpdirname:
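The last test block exercises the Automatic1111-style checkpoint layout, where the embedding tensor sits under `string_to_param["*"]` and `name` holds the placeholder token; rows beyond the first become numbered alias tokens. A sketch of the expansion the assertions imply (the helper logic shown is illustrative, not the loader's internals):

import torch

# Automatic1111-style payload: one (num_vectors, dim) tensor for the token.
a1111 = {
    "name": "<xxxx>",
    "string_to_param": {"*": torch.stack([3 * torch.ones(32), 4 * torch.ones(32), 5 * torch.ones(32)])},
}

embedding = a1111["string_to_param"]["*"]  # shape (3, 32)
name = a1111["name"]

# Each extra row gets a numbered alias, which is why the test expects
# _maybe_convert_prompt to rewrite "<xxxx>" as "<xxxx> <xxxx>_1 <xxxx>_2".
tokens = [name] + [f"{name}_{i}" for i in range(1, embedding.shape[0])]
print(tokens)  # ['<xxxx>', '<xxxx>_1', '<xxxx>_2']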
