Skip to content

Commit

Permalink
extra networks UI
Browse files Browse the repository at this point in the history
rework of hypernets: rather than via settings, hypernets are added directly to prompt as <hypernet:name:weight>
  • Loading branch information
AUTOMATIC1111 committed Jan 21, 2023
1 parent e33cace commit 40ff6db
Show file tree
Hide file tree
Showing 25 changed files with 765 additions and 214 deletions.
Binary file added html/card-no-preview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 11 additions & 0 deletions html/extra-networks-card.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<div class='card' {preview_html} onclick='return cardClicked({prompt}, {allow_negative_prompt})'>
<div class='actions'>
<div class='additional'>
<ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick='return saveCardPreview(event, {tabname}, {local_preview})'>replace preview</a>
</ul>
</div>
<span class='name'>{name}</span>
</div>
</div>

8 changes: 8 additions & 0 deletions html/extra-networks-no-cards.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<div class='nocards'>
<h1>Nothing here. Add some content to the following directories:</h1>

<ul>
{dirs}
</ul>
</div>

60 changes: 60 additions & 0 deletions javascript/extraNetworks.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

function setupExtraNetworksForTab(tabname){
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')

gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh'))
gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close'))
}

var activePromptTextarea = null;
var activePositivePromptTextarea = null;

function setupExtraNetworks(){
setupExtraNetworksForTab('txt2img')
setupExtraNetworksForTab('img2img')

function registerPrompt(id, isNegative){
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");

if (activePromptTextarea == null){
activePromptTextarea = textarea
}
if (activePositivePromptTextarea == null && ! isNegative){
activePositivePromptTextarea = textarea
}

textarea.addEventListener("focus", function(){
activePromptTextarea = textarea;
if(! isNegative) activePositivePromptTextarea = textarea;
});
}

registerPrompt('txt2img_prompt')
registerPrompt('txt2img_neg_prompt', true)
registerPrompt('img2img_prompt')
registerPrompt('img2img_neg_prompt', true)
}

onUiLoaded(setupExtraNetworks)

function cardClicked(textToAdd, allowNegativePrompt){
textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea

textarea.value = textarea.value + " " + textToAdd
updateInput(textarea)

return false
}

function saveCardPreview(event, tabname, filename){
textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
button = gradioApp().getElementById(tabname + '_save_preview')

textarea.value = filename
updateInput(textarea)

button.click()

event.stopPropagation()
event.preventDefault()
}
2 changes: 2 additions & 0 deletions javascript/hints.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ titles = {
"\U0001F5D1": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt",
"\u{1f4d2}": "Paste available values into the field",
"\u{1f3b4}": "Show extra networks",


"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
Expand Down
9 changes: 4 additions & 5 deletions javascript/ui.js
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
return [prompt, negative_prompt]
}



opts = {}
onUiUpdate(function(){
if(Object.keys(opts).length != 0) return;
Expand Down Expand Up @@ -239,11 +237,14 @@ onUiUpdate(function(){
return
}


prompt.parentElement.insertBefore(counter, prompt)
counter.classList.add("token-counter")
prompt.parentElement.style.position = "relative"

textarea.addEventListener("input", () => update_token_counter(id_button));
textarea.addEventListener("input", function(){
update_token_counter(id_button);
});
}

registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
Expand All @@ -261,10 +262,8 @@ onUiUpdate(function(){
})
}
}

})


onOptionsChanged(function(){
elem = gradioApp().getElementById('sd_checkpoint_hash')
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
Expand Down
7 changes: 3 additions & 4 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def train_embedding(self, args: dict):
def train_hypernetwork(self, args: dict):
try:
shared.state.begin()
initial_hypernetwork = shared.loaded_hypernetwork
shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
Expand All @@ -491,16 +491,15 @@ def train_hypernetwork(self, args: dict):
except Exception as e:
error = e
finally:
shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
except AssertionError as msg:
shared.state.end()
return TrainResponse(info = "train embedding error: {error}".format(error = error))
return TrainResponse(info="train embedding error: {error}".format(error=error))

def get_memory(self):
try:
Expand Down
147 changes: 147 additions & 0 deletions modules/extra_networks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import re
from collections import defaultdict

from modules import errors

extra_network_registry = {}


def initialize():
extra_network_registry.clear()


def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network


class ExtraNetworkParams:
def __init__(self, items=None):
self.items = items or []


class ExtraNetwork:
def __init__(self, name):
self.name = name

def activate(self, p, params_list):
"""
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
Passes arguments related to this extra network in params_list.
User passes arguments by specifying this in his prompt:
<name:arg1:arg2:arg3>
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
separated by colon.
Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
in this case, all effects of this extra networks should be disabled.
Can be called multiple times before deactivate() - each new call should override the previous call completely.
For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
params_list will be:
[
ExtraNetworkParams(items=["agm", "1.1"]),
ExtraNetworkParams(items=["ray"])
]
"""
raise NotImplementedError

def deactivate(self, p):
"""
Called at the end of processing for housekeeping. No need to do anything here.
"""

raise NotImplementedError


def activate(p, extra_network_data):
"""call activate for extra networks in extra_network_data in specified order, then call
activate for all remaining registered networks with an empty argument list"""

for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
print(f"Skipping unknown extra network: {extra_network_name}")
continue

try:
extra_network.activate(p, extra_network_args)
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")

for extra_network_name, extra_network in extra_network_registry.items():
args = extra_network_data.get(extra_network_name, None)
if args is not None:
continue

try:
extra_network.activate(p, [])
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name}")


def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
deactivate for all remaining registered networks"""

for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
continue

try:
extra_network.deactivate(p)
except Exception as e:
errors.display(e, f"deactivating extra network {extra_network_name}")

for extra_network_name, extra_network in extra_network_registry.items():
args = extra_network_data.get(extra_network_name, None)
if args is not None:
continue

try:
extra_network.deactivate(p)
except Exception as e:
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")


re_extra_net = re.compile(r"<(\w+):([^>]+)>")


def parse_prompt(prompt):
res = defaultdict(list)

def found(m):
name = m.group(1)
args = m.group(2)

res[name].append(ExtraNetworkParams(items=args.split(":")))

return ""

prompt = re.sub(re_extra_net, found, prompt)

return prompt, res


def parse_prompts(prompts):
res = []
extra_data = None

for prompt in prompts:
updated_prompt, parsed_extra_data = parse_prompt(prompt)

if extra_data is None:
extra_data = parsed_extra_data

res.append(updated_prompt)

return res, extra_data

21 changes: 21 additions & 0 deletions modules/extra_networks_hypernet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from modules import extra_networks
from modules.hypernetworks import hypernetwork


class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('hypernet')

def activate(self, p, params_list):
names = []
multipliers = []
for params in params_list:
assert len(params.items) > 0

names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

hypernetwork.load_hypernetworks(names, multipliers)

def deactivate(p, self):
pass
12 changes: 3 additions & 9 deletions modules/generation_parameters_copypaste.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict):
from modules import ui

settings_map = {
'sd_hypernetwork': 'Hypernet',
'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash',
Expand Down Expand Up @@ -275,13 +273,9 @@ def parse_generation_parameters(x: str):
if "Clip skip" not in res:
res["Clip skip"] = "1"

if "Hypernet strength" not in res:
res["Hypernet strength"] = "1"

if "Hypernet" in res:
hypernet_name = res["Hypernet"]
hypernet_hash = res.get("Hypernet hash", None)
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
hypernet = res.get("Hypernet", None)
if hypernet is not None:
res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
Expand Down
Loading

6 comments on commit 40ff6db

@mykeehu
Copy link
Contributor

@mykeehu mykeehu commented on 40ff6db Jan 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like this innovation, but this update stopped my webUI because it was also using the monkeypatch extension, and that's what I got after the update:

None
Traceback (most recent call last):
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\launch.py", line 317, in <module>
    start()
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\launch.py", line 312, in start
    webui.webui()
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\webui.py", line 162, in webui
    initialize()
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\webui.py", line 91, in initialize
    shared.reload_hypernetworks()
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\extensions\Hypernetwork-MonkeyPatch-Extension\patches\shared.py", line 9, in reload_hypernetworks
    load_hypernetwork(opts.sd_hypernetwork)
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\extensions\Hypernetwork-MonkeyPatch-Extension\patches\hypernetwork.py", line 479, in load_hypernetwork
    if shared.loaded_hypernetwork is not None:
AttributeError: module 'modules.shared' has no attribute 'loaded_hypernetwork'. Did you mean: 'loaded_hypernetworks'?

Do you have a private discord group with the extension developers, where you let them know before such a change so they can prepare the extension for the revision? Because the user will see that it doesn't work in SD after the update. Plus the model makers can't use the extensions until then.

@etherealxx
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like this innovation, but this update stopped my webUI because it was also using the monkeypatch extension, and that's what I got after the update:

None
Traceback (most recent call last):
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\launch.py", line 317, in <module>
    start()
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\launch.py", line 312, in start
    webui.webui()
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\webui.py", line 162, in webui
    initialize()
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\webui.py", line 91, in initialize
    shared.reload_hypernetworks()
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\extensions\Hypernetwork-MonkeyPatch-Extension\patches\shared.py", line 9, in reload_hypernetworks
    load_hypernetwork(opts.sd_hypernetwork)
  File "H:\Stable-Diffusion-Automatic\stable-diffusion-webui\extensions\Hypernetwork-MonkeyPatch-Extension\patches\hypernetwork.py", line 479, in load_hypernetwork
    if shared.loaded_hypernetwork is not None:
AttributeError: module 'modules.shared' has no attribute 'loaded_hypernetwork'. Did you mean: 'loaded_hypernetworks'?

Do you have a private discord group with the extension developers, where you let them know before such a change so they can prepare the extension for the revision? Because the user will see that it doesn't work in SD after the update. Plus the model makers can't use the extensions until then.

have you updated your monkeypatch extension? because i got the same problem few days ago, i told the devs, and hours later they fixed the problem

@aria1th
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently its fixed as aria1th/Hypernetwork-MonkeyPatch-Extension@3dbb5ab

But, current original webui won't be able to train hypernetwork, and cannot generate proper preview from it since modules.processing will remove any hypernetwork loaded, and it only applies hypernetwork when its present in txt2img prompt.

I did a hacky but clean fix by aria1th/Hypernetwork-MonkeyPatch-Extension@3dbb5ab#diff-2615363dc2dc5d0416f2d0a04afe86e995ad961f6a02e86678318ffd590e6573R28

because my extension name is monkey patch, but how should it be implemented in webui?

@Kilvoctu
Copy link
Contributor

@Kilvoctu Kilvoctu commented on 40ff6db Jan 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit appears to break hypernetwork API for new Web UI setups. I logged a bug report here #7036. In short, there seems to be currently no way to send hypernetwork into payload for API. The issue persists through current commit of f2eae61

If you could provide backwards compatibility or a new API process, that'd be much appreciated.

edit: Want to add that I do like the changes in this commit. Just that I'd like API support to return.

@ice051128
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What a stupid commit, this basically makes hypernetworks completely unusable as u cannot choose em from the dropdown menu anymore

@mykeehu
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ice051128 Below the Generate button is the icon from which you can select, from the gallery instead of the menu

Please sign in to comment.