Skip to content

Commit

Permalink
revise hypernetwork support
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Dec 14, 2022
1 parent 03a84e1 commit f9ebe7c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
6 changes: 3 additions & 3 deletions modules/api/api.py
Expand Up @@ -413,14 +413,14 @@ def invocations(self, req: InvocationsRequest):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.img2imgapi(req.img2img_payload)
shared.opts.data = default_options
return response
return response
elif req.task == 'extras-single-image':
response = self.extras_single_image_api(req.extras_single_payload)
shared.opts.data = default_options
shared.opts.data = default_options
return response
elif req.task == 'extras-batch-images':
response = self.extras_batch_images_api(req.extras_batch_payload)
shared.opts.data = default_options
shared.opts.data = default_options
return response
elif req.task == 'sd-models':
return self.get_sd_models()
Expand Down
22 changes: 17 additions & 5 deletions modules/hypernetworks/hypernetwork.py
Expand Up @@ -20,6 +20,8 @@

from collections import defaultdict, deque
from statistics import stdev, mean
import requests
import json


optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
Expand Down Expand Up @@ -255,11 +257,21 @@ def load(self, filename):

def list_hypernetworks(path):
res = {}
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name + f"({sd_models.model_hash(filename)})"] = filename
if shared.cmd_opts.pureui:
response = requests.get(f'{shared.api_endpoint}/sd/hypernetwork')
if response.status_code == 200:
hypernetwork_names = json.loads(response.text)
for hypernetwork_name in sorted(hypernetwork_names):
filename = 'f{hypernetwork_name}.pt'
# Prevent a hypothetical "None.pt" from being listed.
if not hypernetwork_name.startswith("None"):
res[hypernetwork_name] = filename
else:
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name + f"({sd_models.model_hash(filename)})"] = filename
return res


Expand Down

0 comments on commit f9ebe7c

Please sign in to comment.