Skip to content

Commit

Permalink
Comma backtrack padding (#2192)
Browse files — browse the repository at this point in the history
Comma backtrack padding
Committed by hentailord85ez on Oct 11, 2022
1 parent 8617396 commit 5e2627a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
19 changes: 18 additions & 1 deletion modules/sd_hijack.py
Expand Up @@ -107,6 +107,8 @@ def __init__(self, wrapped, hijack):
self.tokenizer = wrapped.tokenizer
self.token_mults = {}

self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]

tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
Expand Down Expand Up @@ -136,6 +138,7 @@ def tokenize_line(self, line, used_custom_terms, hijack_comments):
fixes = []
remade_tokens = []
multipliers = []
last_comma = -1

for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
Expand All @@ -144,6 +147,20 @@ def tokenize_line(self, line, used_custom_terms, hijack_comments):

embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

if token == self.comma_token:
last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]

remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)

rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults

if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
Expand Down Expand Up @@ -284,7 +301,7 @@ def forward(self, text):
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]

self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
Expand Down
1 change: 1 addition & 0 deletions modules/shared.py
Expand Up @@ -227,6 +227,7 @@ def options_section(section_identifier, options_dict):
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
Expand Down

0 comments on commit 5e2627a

Please sign in to comment.