From f9ebe7ccdfa4cd39affd411f59a180d60497ce90 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Wed, 14 Dec 2022 15:49:38 +0800 Subject: [PATCH] revise hypernetwork support --- modules/api/api.py | 6 +++--- modules/hypernetworks/hypernetwork.py | 22 +++++++++++++++++----- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 67e2cb95661..9b5aca227fa 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -413,14 +413,14 @@ def invocations(self, req: InvocationsRequest): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() response = self.img2imgapi(req.img2img_payload) shared.opts.data = default_options - return response + return response elif req.task == 'extras-single-image': response = self.extras_single_image_api(req.extras_single_payload) - shared.opts.data = default_options + shared.opts.data = default_options return response elif req.task == 'extras-batch-images': response = self.extras_batch_images_api(req.extras_batch_payload) - shared.opts.data = default_options + shared.opts.data = default_options return response elif req.task == 'sd-models': return self.get_sd_models() diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c406ffb379a..a7fcc4d93e3 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -20,6 +20,8 @@ from collections import defaultdict, deque from statistics import stdev, mean +import requests +import json optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} @@ -255,11 +257,21 @@ def load(self, filename): def list_hypernetworks(path): res = {} - for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)): - name = os.path.splitext(os.path.basename(filename))[0] - # Prevent a hypothetical "None.pt" from being listed. - if name != "None": - res[name + f"({sd_models.model_hash(filename)})"] = filename + if shared.cmd_opts.pureui: + response = requests.get(f'{shared.api_endpoint}/sd/hypernetwork') + if response.status_code == 200: + hypernetwork_names = json.loads(response.text) + for hypernetwork_name in sorted(hypernetwork_names): + filename = 'f{hypernetwork_name}.pt' + # Prevent a hypothetical "None.pt" from being listed. + if not hypernetwork_name.startswith("None"): + res[hypernetwork_name] = filename + else: + for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)): + name = os.path.splitext(os.path.basename(filename))[0] + # Prevent a hypothetical "None.pt" from being listed. + if name != "None": + res[name + f"({sd_models.model_hash(filename)})"] = filename return res