Skip to content

Commit

Permalink
Interrogate: add option to include ranks in output
Browse files Browse the repository at this point in the history
Since the UI also allows users to specify ranks, it can be useful to show people what ranks are being returned by interrogate

Testing Steps:
* Navigate to img2img tab, use interrogate DeepBooru, verify tags appears as before.
* Navigate to Settings tab, enable new option, click "apply settings"
* Navigate to img2img, Interrogate DeepBooru again, verify that weights appear and are properly formatted
  • Loading branch information
HunterVacui committed Oct 10, 2022
1 parent 8acc901 commit 6ed4faa
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 10 deletions.
14 changes: 9 additions & 5 deletions modules/deepbooru.py
Expand Up @@ -3,7 +3,7 @@
from multiprocessing import get_context


def _load_tf_and_return_tags(pil_image, threshold):
def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
Expand Down Expand Up @@ -52,22 +52,26 @@ def _load_tf_and_return_tags(pil_image, threshold):
if result_dict[tag] >= threshold:
if tag.startswith("rating:"):
continue
result_tags_out.append(tag)
tag_formatted = tag.replace('_', ' ').replace(':', ' ')
if include_ranks:
result_tags_out.append(f'({tag_formatted}:{result_dict[tag]})')
else:
result_tags_out.append(tag_formatted)
result_tags_print.append(f'{result_dict[tag]} {tag}')

print('\n'.join(sorted(result_tags_print, reverse=True)))

return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
return ', '.join(result_tags_out)


def subprocess_init_no_cuda():
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def get_deepbooru_tags(pil_image, threshold=0.5):
def get_deepbooru_tags(pil_image, threshold=0.5, include_ranks=False):
context = get_context('spawn')
with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, )
f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, include_ranks)
ret = f.result() # will rethrow any exceptions
return ret
7 changes: 5 additions & 2 deletions modules/interrogate.py
Expand Up @@ -123,7 +123,7 @@ def generate_caption(self, pil_image):

return caption[0]

def interrogate(self, pil_image):
def interrogate(self, pil_image, include_ranks=False):
res = None

try:
Expand Down Expand Up @@ -156,7 +156,10 @@ def interrogate(self, pil_image):
for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
res += ", " + match
if include_ranks:
res += ", " + match
else:
res += f", ({match}:{score})"

except Exception:
print(f"Error interrogating", file=sys.stderr)
Expand Down
1 change: 1 addition & 0 deletions modules/shared.py
Expand Up @@ -232,6 +232,7 @@ def options_section(section_identifier, options_dict):
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
Expand Down
5 changes: 2 additions & 3 deletions modules/ui.py
Expand Up @@ -305,13 +305,12 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):


def interrogate(image):
prompt = shared.interrogator.interrogate(image)

prompt = shared.interrogator.interrogate(image, include_ranks=opts.interrogate_return_ranks)
return gr_show(True) if prompt is None else prompt


def interrogate_deepbooru(image):
prompt = get_deepbooru_tags(image)
prompt = get_deepbooru_tags(image, include_ranks=opts.interrogate_return_ranks)
return gr_show(True) if prompt is None else prompt


Expand Down

0 comments on commit 6ed4faa

Please sign in to comment.