Skip to content

Commit

Permalink
add support for switching model checkpoints at runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Sep 17, 2022
1 parent b8be33d commit 247f58a
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 58 deletions.
2 changes: 1 addition & 1 deletion modules/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[height]", str(p.height))
x = x.replace("[sampler]", sd_samplers.samplers[p.sampler_index].name)

x = x.replace("[model_hash]", shared.sd_model_hash)
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat())

if cmd_opts.hide_ui_dir_config:
Expand Down
2 changes: 1 addition & 1 deletion modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def infotext(iteration=0, position_in_batch=0):
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model_hash else shared.sd_model_hash),
"Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
Expand Down
148 changes: 148 additions & 0 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import glob
import os.path
import sys
from collections import namedtuple
import torch
from omegaconf import OmegaConf


from ldm.util import instantiate_from_config

from modules import shared

CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash'])
checkpoints_list = {}

try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.

from transformers import logging

logging.set_verbosity_error()
except Exception:
pass


def list_models():
checkpoints_list.clear()

model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir)

def modeltitle(path, h):
abspath = os.path.abspath(path)

if abspath.startswith(model_dir):
name = abspath.replace(model_dir, '')
else:
name = os.path.basename(path)

if name.startswith("\\") or name.startswith("/"):
name = name[1:]

return f'{name} [{h}]'

cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)

if os.path.exists(model_dir):
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
h = model_hash(filename)
title = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h)


def model_hash(filename):
try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()

file.seek(0x100000)
m.update(file.read(0x10000))
return m.hexdigest()[0:8]
except FileNotFoundError:
return 'NOFILE'


def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info

if len(checkpoints_list) == 0:
print(f"Checkpoint {model_checkpoint} not found and no other checkpoints found", file=sys.stderr)
return None

checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None:
print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

return checkpoint_info


def load_model_weights(model, checkpoint_file, sd_model_hash):
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")

pl_sd = torch.load(checkpoint_file, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]

model.load_state_dict(sd, strict=False)

if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)

if not shared.cmd_opts.no_half:
model.half()

model.sd_model_hash = sd_model_hash
model.sd_model_checkpint = checkpoint_file


def load_model():
from modules import lowvram, sd_hijack
checkpoint_info = select_checkpoint()

sd_config = OmegaConf.load(shared.cmd_opts.config)
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)

This comment has been minimized.

Copy link
@RetGal

RetGal Sep 17, 2022

Only load if checkpoint_info is not None


if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
else:
sd_model.to(shared.device)

sd_hijack.model_hijack.hijack(sd_model)

sd_model.eval()

print(f"Model loaded.")
return sd_model


def reload_model_weights(sd_model):
from modules import lowvram, devices
checkpoint_info = select_checkpoint()

if sd_model.sd_model_checkpint == checkpoint_info.filename:
return

if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)

load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)

if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)

print(f"Weights loaded.")
return sd_model
19 changes: 14 additions & 5 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
import modules.styles
import modules.interrogate
import modules.memmon
import modules.sd_models

sd_model_file = os.path.join(script_path, 'model.ckpt')
if not os.path.exists(sd_model_file):
sd_model_file = "models/ldm/stable-diffusion-v1/model.ckpt"
default_sd_model_file = sd_model_file

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=os.path.join(sd_path, sd_model_file), help="path to checkpoint of model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth')
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
Expand Down Expand Up @@ -88,13 +89,17 @@ def nextjob(self):

face_restorers = []

modules.sd_models.list_models()


class Options:
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None):
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange

data = None
hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None
Expand Down Expand Up @@ -150,6 +155,7 @@ def __init__(self, default=None, label="", component=None, component_args=None):
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}),
}

def __init__(self):
Expand Down Expand Up @@ -180,6 +186,10 @@ def load(self, filename):
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)

def onchange(self, key, func):
item = self.data_labels.get(key)
item.onchange = func


opts = Options()
if os.path.exists(config_filename):
Expand All @@ -188,7 +198,6 @@ def load(self, filename):
sd_upscalers = []

sd_model = None
sd_model_hash = ''

progress_print_out = sys.stdout

Expand Down
5 changes: 5 additions & 0 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,12 @@ def run_settings(*args):
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
continue

oldval = opts.data.get(key, None)
opts.data[key] = value

if oldval != value and opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()

up.append(comp.update(value=value))

opts.save(shared.config_filename)
Expand Down
61 changes: 10 additions & 51 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,8 @@

from modules.paths import script_path

import torch
from omegaconf import OmegaConf

import signal

from ldm.util import instantiate_from_config

from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.ui
Expand All @@ -24,6 +19,7 @@
import modules.lowvram
import modules.txt2img
import modules.img2img
import modules.sd_models


modules.codeformer_model.setup_codeformer()
Expand All @@ -33,29 +29,17 @@
esrgan.load_models(cmd_opts.esrgan_models_path)
realesrgan.setup_realesrgan()

queue_lock = threading.Lock()

def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model [{shared.sd_model_hash}] from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]

model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if cmd_opts.opt_channelslast:
model = model.to(memory_format=torch.channels_last)
model.eval()
return model
def wrap_queued_call(func):
def f(*args, **kwargs):
with queue_lock:
res = func(*args, **kwargs)

return res

queue_lock = threading.Lock()
return f


def wrap_gradio_gpu_call(func):
Expand All @@ -80,33 +64,8 @@ def f(*args, **kwargs):

modules.scripts.load_scripts(os.path.join(script_path, "scripts"))

try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.

from transformers import logging

logging.set_verbosity_error()
except Exception:
pass

with open(cmd_opts.ckpt, "rb") as file:
import hashlib
m = hashlib.sha256()

file.seek(0x100000)
m.update(file.read(0x10000))
shared.sd_model_hash = m.hexdigest()[0:8]

sd_config = OmegaConf.load(cmd_opts.config)
shared.sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
shared.sd_model = (shared.sd_model if cmd_opts.no_half else shared.sd_model.half())

if cmd_opts.lowvram or cmd_opts.medvram:
modules.lowvram.setup_for_low_vram(shared.sd_model, cmd_opts.medvram)
else:
shared.sd_model = shared.sd_model.to(shared.device)

modules.sd_hijack.model_hijack.hijack(shared.sd_model)
shared.sd_model = modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))


def webui():
Expand Down

0 comments on commit 247f58a

Please sign in to comment.