Comma backtrack padding #2192

Merged

19 changes: 18 additions & 1 deletion modules/sd_hijack.py
@@ -107,6 +107,8 @@ def __init__(self, wrapped, hijack):
        self.tokenizer = wrapped.tokenizer
        self.token_mults = {}

        self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]

        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
@@ -136,6 +138,7 @@ def tokenize_line(self, line, used_custom_terms, hijack_comments):
        fixes = []
        remade_tokens = []
        multipliers = []
        last_comma = -1

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
@@ -144,6 +147,20 @@ def tokenize_line(self, line, used_custom_terms, hijack_comments):

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                if token == self.comma_token:
                    last_comma = len(remade_tokens)
                elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
                    last_comma += 1
                    reloc_tokens = remade_tokens[last_comma:]
                    reloc_mults = multipliers[last_comma:]

                    remade_tokens = remade_tokens[:last_comma]
                    length = len(remade_tokens)

                    rem = int(math.ceil(length / 75)) * 75 - length
                    remade_tokens += [id_end] * rem + reloc_tokens
                    multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults

                if embedding is None:
                    remade_tokens.append(token)
                    multipliers.append(weight)
@@ -284,7 +301,7 @@ def forward(self, text):
        while max(map(len, remade_batch_tokens)) != 0:
            rem_tokens = [x[75:] for x in remade_batch_tokens]
            rem_multipliers = [x[75:] for x in batch_multipliers]

            self.hijack.fixes = []
            for unfiltered in hijack_fixes:
                fixes = []
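Taken together, the tokenize_line changes work like this: whenever the running token stream reaches a multiple of 75 tokens (the chunk size used in this codebase) and a comma token was emitted within the last opts.comma_padding_backtrack tokens, everything after that comma is relocated to the start of the next chunk and the current chunk is filled out with end-of-text tokens, so the phrase following the comma is not split across a chunk boundary. Below is a minimal standalone sketch of that idea; the function name, the comma_id/pad_id parameters, and the plain-list input are illustrative, not part of the PR.

```python
import math


def backtrack_pad(token_ids, comma_id, pad_id, backtrack=20, chunk=75):
    """Minimal sketch of comma backtrack padding (illustrative, not the PR code).

    Walk the prompt tokens; whenever the output length hits a multiple of
    `chunk`, check whether a comma was emitted within the last `backtrack`
    tokens. If so, relocate everything after that comma to the start of the
    next chunk and pad the current chunk with `pad_id`.
    """
    out = []
    last_comma = -1
    for tok in token_ids:
        if tok == comma_id:
            last_comma = len(out)
        elif (backtrack != 0 and max(len(out), 1) % chunk == 0
              and last_comma != -1 and len(out) - last_comma <= backtrack):
            cut = last_comma + 1               # keep the comma itself in this chunk
            reloc = out[cut:]                  # phrase that started after the comma
            out = out[:cut]
            pad = int(math.ceil(len(out) / chunk)) * chunk - len(out)
            out += [pad_id] * pad + reloc      # pad to the boundary, then re-append
        out.append(tok)
    return out


# Tiny demonstration with a 5-token chunk instead of 75: token 4 follows the
# comma, so it is pushed into the next chunk and a pad token fills its place.
print(backtrack_pad([1, 2, 3, 0, 4, 5], comma_id=0, pad_id=9, backtrack=3, chunk=5))
# -> [1, 2, 3, 0, 9, 4, 5]
```

With the defaults introduced in this PR (chunk size 75, backtrack window 20), a phrase that would straddle the 75-token boundary is instead moved whole into the next chunk at the cost of a few padding tokens; as in the diff above, the per-token multipliers are padded with 1.0 in the same positions.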
1 change: 1 addition & 0 deletions modules/shared.py
@@ -226,6 +226,7 @@ def options_section(section_identifier, options_dict):
    "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
    "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
    "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
    "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
    "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
    'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
    "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
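As a worked example of what the new slider value means (hypothetical numbers, not taken from the PR): with the default of 20, suppose the 75-token boundary is reached and the last comma was emitted at position 60 of the running token list.

```python
# Hypothetical numbers illustrating the default comma_padding_backtrack of 20.
backtrack = 20      # value of the new opts.comma_padding_backtrack slider
boundary = 75       # tokens emitted so far, i.e. exactly one full chunk
last_comma = 60     # index at which the last ',' token was emitted

triggers = backtrack != 0 and boundary - last_comma <= backtrack  # 15 <= 20 -> True
keep = last_comma + 1        # 61 tokens stay in the current chunk (comma included)
padding = 75 - keep          # 14 end-of-text tokens pad the chunk to 75
relocated = boundary - keep  # 14 tokens move to the start of the next chunk
print(triggers, keep, padding, relocated)  # True 61 14 14
```

Setting the slider to 0 disables the behaviour (the condition in sd_hijack.py requires comma_padding_backtrack != 0), while the maximum of 74 lets the backtrack reach back to almost anywhere in the current 75-token chunk.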