Iterate over unique tokens to avoid duplicate replacements for multivector embeddings (huggingface#3588)

* iterate over unique tokens to avoid duplicate replacements

* added test for multiple references to multi embedding

* adhere to black formatting

* reorder test post-rebase
lachlan-nicholson authored and Jimmy committed Apr 26, 2024
1 parent df54368 commit 9110dd8
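
For context on the fix: `_maybe_convert_prompt` expands a multi-vector token by appending its `_1`, `_2`, ... siblings and then calling `str.replace`, which already substitutes every occurrence of the token in the prompt. Looping over the raw token list therefore repeats the replacement once per duplicate, and the later passes also match the freshly inserted `<token>_1`, `<token>_2` suffixes, mangling the prompt. A minimal standalone sketch of the pre-fix behaviour (illustrative only, not the diffusers source; the `<cat>` tokens and the dict stand in for a loaded 3-vector embedding):

# Simplified model of the pre-fix loop; added_tokens stands in for
# tokenizer.added_tokens_encoder after a 3-vector embedding is loaded.
added_tokens = {"<cat>": 0, "<cat>_1": 1, "<cat>_2": 2}

def convert(prompt, tokens):
    for token in tokens:
        if token in added_tokens:
            replacement = token
            i = 1
            while f"{token}_{i}" in added_tokens:
                replacement += f" {token}_{i}"
                i += 1
            prompt = prompt.replace(token, replacement)  # replaces *every* occurrence
    return prompt

prompt = "<cat> <cat>"
print(convert(prompt, prompt.split()))       # duplicate passes mangle the prompt
print(convert(prompt, set(prompt.split())))  # "<cat> <cat>_1 <cat>_2 <cat> <cat>_1 <cat>_2"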
Showing 2 changed files with 14 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/diffusers/loaders.py
@@ -462,7 +462,8 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
             `str`: The converted prompt
         """
         tokens = tokenizer.tokenize(prompt)
-        for token in tokens:
+        unique_tokens = set(tokens)
+        for token in unique_tokens:
             if token in tokenizer.added_tokens_encoder:
                 replacement = token
                 i = 1
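
With the change, each added token is expanded exactly once regardless of how many times it occurs in the prompt. A hedged sketch of the resulting behaviour, mirroring the new test below (`pipe` is assumed to be a pipeline with `{"<cat>": torch.ones(3, 32)}` loaded via `load_textual_inversion`):

# expected conversion after the fix (assumed setup, see the test case below)
converted = pipe._maybe_convert_prompt("<cat> <cat>", pipe.tokenizer)
assert converted == "<cat> <cat>_1 <cat>_2 <cat> <cat>_1 <cat>_2"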
12 changes: 12 additions & 0 deletions tests/pipelines/test_pipelines.py
@@ -722,6 +722,18 @@ def test_text_inversion_download(self):
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
         assert out.shape == (1, 128, 128, 3)
 
+        # multiple references to multi embedding
+        ten = {"<cat>": torch.ones(3, 32)}
+        pipe.load_textual_inversion(ten)
+
+        assert (
+            pipe._maybe_convert_prompt("<cat> <cat>", pipe.tokenizer) == "<cat> <cat>_1 <cat>_2 <cat> <cat>_1 <cat>_2"
+        )
+
+        prompt = "hey <cat> <cat>"
+        out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+        assert out.shape == (1, 128, 128, 3)
+
     def test_download_ignore_files(self):
         # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
         with tempfile.TemporaryDirectory() as tmpdirname:
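
To exercise just the updated test locally, something like the following should work from a diffusers checkout with pytest installed (the `-k` expression matches the test name above):

pytest tests/pipelines/test_pipelines.py -k test_text_inversion_download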
