Commit

Merge pull request AUTOMATIC1111#4 from Klace/img2img_integration
Img2img integration
Klace committed Feb 1, 2023
2 parents 2c1bb46 + c88108f commit e3b1a85
Showing 7 changed files with 94 additions and 27 deletions.
aes_scores.json: 1 addition, 0 deletions

Large diffs are not rendered by default.

exif_data.json: 1 addition, 0 deletions

Large diffs are not rendered by default.

modules/img2img.py: 2 additions, 1 deletion
@@ -76,7 +76,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processed_image.save(os.path.join(output_dir, filename))


def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts)

is_batch = mode == 5
@@ -132,6 +132,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, …
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
image_cfg_scale=image_cfg_scale,
width=width,
height=height,
restore_faces=restore_faces,
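The modules/img2img.py change threads the new image_cfg_scale value through img2img() into the processing object, next to the existing cfg_scale. A minimal sketch of the resulting call site, with placeholder values and only a handful of the real arguments (running it assumes the webui environment is importable):

    from PIL import Image
    from modules.processing import StableDiffusionProcessingImg2Img

    init_img = Image.new("RGB", (512, 512))  # placeholder init image
    p = StableDiffusionProcessingImg2Img(
        init_images=[init_img],
        cfg_scale=7.5,            # text guidance weight
        image_cfg_scale=1.5,      # new: image guidance weight for instruct-pix2pix
        denoising_strength=0.75,
    )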
modules/processing.py: 46 additions, 14 deletions
@@ -11,6 +11,8 @@
import cv2
from skimage import exposure
from typing import Any, Dict, List, Optional
from torch import autocast


import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
@@ -186,7 +188,11 @@ def depth2img_image_conditioning(self, source_image):
return conditioning

def edit_image_conditioning(self, source_image):
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
#source_image = 2 * torch.tensor(np.array(source_image)).float() / 255 - 1
#source_image = rearrange(source_image, "h w c -> 1 c h w").to(shared.device)
#source_image = rearrange(source_image, "h w c -> 1 c h w").to(shared.device)
#conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
conditioning_image = self.sd_model.encode_first_stage(source_image).mode()

return conditioning_image

@@ -450,11 +456,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter…
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Eta": (None),
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
}
@@ -622,15 +631,17 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"

print(f"c = {c} and uc = {uc}")
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)

x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
devices.test_for_nans(x, "vae")
#for x in x_samples_ddim:
# devices.test_for_nans(x, "vae")

x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
#x_samples_ddim = 255.0 * rearrange(x_samples_ddim, "1 c h w -> h w c")

del samples_ddim

@@ -645,7 +656,7 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)

#x_sample = 255.0 * rearrange(x_sample, "1 c h w -> h w c")
if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
@@ -868,8 +879,8 @@ def save_intermediate(image, index):
save_intermediate(image, i)

image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
image = np.array(image).astype(np.float32) / 255.0 - 1
#image = np.moveaxis(image, 2, 0)
batch_images.append(image)

decoded_samples = torch.from_numpy(np.array(batch_images))
@@ -901,7 +912,7 @@ def save_intermediate(image, index):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None

def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, image_cfg_scale: float = 7.5, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs)

self.init_images = init_images
@@ -916,6 +927,7 @@ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, …
self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding
self.inpainting_mask_invert = inpainting_mask_invert
self.image_cfg_scale=image_cfg_scale
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
self.mask = None
self.nmask = None
@@ -983,9 +995,16 @@ def init(self, all_prompts, all_seeds, all_subseeds):

if add_color_corrections:
self.color_corrections.append(setup_color_correction(image))

image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
width, height = image.size
factor = self.width / max(width, height)
factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
width = int((width * factor) // 64) * 64
height = int((height * factor) // 64) * 64
image = ImageOps.fit(image, (width, height), method=Image.Resampling.LANCZOS)

#image = 2 * torch.tensor(np.array(image)).float() / 255 - 1
#image = np.array(image).astype(np.float32) / 255.0
#image = np.moveaxis(image, 2, 0)

imgs.append(image)

@@ -1002,10 +1021,22 @@ def init(self, all_prompts, all_seeds, all_subseeds):
batch_images = np.array(imgs)
else:
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

image = torch.from_numpy(batch_images)
image = 2. * image - 1.
image = image.to(shared.device)

#image = torch.from_numpy(batch_images)
#width, height = image.size
#factor = 512 / max(width, height)
###factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
#width = int((width * factor) // 64) * 64
#height = int((height * factor) // 64) * 64
#image = ImageOps.fit(image, (width, height), method=Image.Resampling.LANCZOS)
##image = 2. * image - 1.
#image = rearrange(image, "h w c -> 1 c h w")
#image = image.to(shared.device)
#image = torch.from_numpy(batch_images)
#image = 2. * image - 1.
image = 2 * torch.tensor(np.array(image)).float() / 255 - 1
image = rearrange(image, "h w c -> 1 c h w").to(shared.device)
#image = image.to(shared.device)

self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

@@ -1032,6 +1063,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)

def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):

x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

if self.initial_noise_multiplier != 1.0:
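The modules/processing.py changes port the instruct-pix2pix preprocessing and conditioning path: the init image is snapped to dimensions divisible by 64, pixels are rescaled from [0, 255] to [-1, 1], and the edit conditioning uses the mode of the VAE posterior (a deterministic latent) instead of a sampled, scaled encoding. A minimal sketch of that path, assuming sd_model is the loaded LatentDiffusion model whose encode_first_stage() returns a DiagonalGaussianDistribution:

    import math
    import numpy as np
    import torch
    from PIL import Image, ImageOps
    from einops import rearrange

    def prepare_edit_conditioning(sd_model, image: Image.Image, target_width: int) -> torch.Tensor:
        # Snap both sides to multiples of 64 while roughly keeping the aspect
        # ratio, mirroring the new resize logic in init() above.
        width, height = image.size
        factor = target_width / max(width, height)
        factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        image = ImageOps.fit(image, (width, height), method=Image.Resampling.LANCZOS)

        # Scale pixels from [0, 255] to [-1, 1] and move channels first.
        x = 2 * torch.tensor(np.array(image)).float() / 255 - 1
        x = rearrange(x, "h w c -> 1 c h w")

        # Condition on the posterior mode: a deterministic, unscaled latent,
        # as in the rewritten edit_image_conditioning().
        return sd_model.encode_first_stage(x).mode()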
modules/sd_samplers_common.py: 3 additions, 0 deletions
@@ -3,6 +3,7 @@
import torch
from PIL import Image
import torchsde._brownian.brownian_interval
from einops import rearrange
from modules import devices, processing, images, sd_vae_approx

from modules.shared import opts, state
@@ -38,8 +39,10 @@ def single_sample_to_image(sample, approximation=None):
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]

x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample, "1 c h w -> h w c")
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)

return Image.fromarray(x_sample)


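The rearrange line added to single_sample_to_image() performs the same channel reordering as the existing np.moveaxis line that follows it. A small self-contained check of that equivalence (note that running both lines on the same tensor, as the diff appears to, would apply the 255 scaling twice):

    import numpy as np
    import torch
    from einops import rearrange

    x = torch.rand(1, 3, 8, 8)           # a decoded sample shaped [1, c, h, w]
    a = rearrange(x, "1 c h w -> h w c").numpy()
    b = np.moveaxis(x[0].numpy(), 0, 2)  # same reordering via numpy
    assert np.allclose(a, b)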
modules/sd_samplers_kdiffusion.py: 38 additions, 11 deletions
@@ -1,6 +1,7 @@
from collections import deque
import torch
import inspect
import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common

@@ -57,17 +58,17 @@ def __init__(self, model):
self.init_latent = None
self.step = 0

def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)

for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

denoised[i] += cond_scale * (x_out[cond_index] - denoised_uncond[i]) + image_scale * (denoised_uncond[i] - x_out[cond_index])
return denoised

def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_scale):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException

@@ -76,10 +77,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):

batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]

x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
x_in = einops.repeat(x, "1 ... -> n ...", n=3)
sigma_in = einops.repeat(sigma, "1 ... -> n ...", n=3)
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [image_cond])

denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params)
@@ -88,7 +88,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
sigma_in = denoiser_params.sigma

if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
cond_in = torch.cat([tensor, uncond, uncond])

if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
@@ -115,7 +115,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])

denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_scale)

if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
@@ -124,6 +124,32 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):

return denoised

class CFGDenoiserIp2p(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.mask = None
self.nmask = None
self.init_latent = None
self.step = 0

def forward(self, z, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
image_cond_in = einops.repeat(image_cond, "1 ... -> n ...", n=3)

conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
cond_in = torch.cat([tensor, uncond])
cfg_cond = {
"c_crossattn": [cond_in],
"c_concat": [image_cond_in],
}
out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
return out_uncond + cond_scale * (out_cond - out_img_cond) + 1.5 * (out_img_cond - out_uncond)


class TorchHijack:
def __init__(self, sampler_noises):
@@ -265,7 +291,8 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, …
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
'cond_scale': p.cfg_scale,
'image_scale': p.image_cfg_scale
}, disable=False, callback=self.callback_state, **extra_params_kwargs))

return samples
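The new CFGDenoiserIp2p implements instruct-pix2pix's three-way classifier-free guidance: the latent and sigma are repeated three times (text+image conditioned, image-only conditioned, fully unconditional), the model runs once on the batch, and the three predictions are blended; the modified combine_denoised() takes the same image weight via its image_scale argument. A minimal sketch of the blending rule, assuming the three outputs have already been split with .chunk(3) as in the class above:

    import torch

    def combine_ip2p(out_cond: torch.Tensor,
                     out_img_cond: torch.Tensor,
                     out_uncond: torch.Tensor,
                     cfg_scale: float,
                     image_cfg_scale: float) -> torch.Tensor:
        # Text guidance pushes away from the image-conditioned prediction;
        # image guidance pushes the image-conditioned prediction away from
        # the fully unconditional one.
        return (out_uncond
                + cfg_scale * (out_cond - out_img_cond)
                + image_cfg_scale * (out_img_cond - out_uncond))

Note that CFGDenoiserIp2p hard-codes the image weight at 1.5, while sample_img2img() passes p.image_cfg_scale through to combine_denoised().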
modules/ui.py: 3 additions, 1 deletion
@@ -765,7 +765,8 @@ def copy_image(img):

elif category == "cfg":
with FormGroup():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.5, elem_id="img2img_cfg_scale")
image_cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='Image CFG Scale', value=1.5, elem_id="img2img_cfg_scale")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")

elif category == "seed":
@@ -861,6 +862,7 @@ def select_img2img_tab(tab):
batch_count,
batch_size,
cfg_scale,
image_cfg_scale,
denoising_strength,
seed,
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
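The modules/ui.py change adds an "Image CFG Scale" slider beside the existing CFG slider and appends its value to the argument list passed to img2img(). A standalone sketch of the two controls, assuming a plain gr.Blocks() context instead of the webui's FormGroup; note the diff reuses elem_id="img2img_cfg_scale" for both sliders, so a distinct, hypothetical id is used here for the second one:

    import gradio as gr

    with gr.Blocks() as demo:
        cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5,
                              label='CFG Scale', value=7.5,
                              elem_id="img2img_cfg_scale")
        image_cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5,
                                    label='Image CFG Scale', value=1.5,
                                    elem_id="img2img_image_cfg_scale")  # hypothetical id

Calling demo.launch() renders the pair for a quick visual check of the new control.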
