Skip to content

Commit

Permalink
initial SD3 support
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Jun 16, 2024
1 parent a7116aa commit 5b2a60b
Show file tree
Hide file tree
Showing 14 changed files with 333 additions and 44 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.

- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- Spandrel - https://github.com/chaiNNer-org/spandrel implementing
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
Expand Down
5 changes: 5 additions & 0 deletions configs/sd3-inference.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
model:
target: modules.models.sd3.sd3_model.SD3Inferencer
params:
shift: 3
state_dict: null
4 changes: 3 additions & 1 deletion extensions-builtin/Lora/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model):
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
else:
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)

for name, module in cond_stage_model.named_modules():
network_name = name.replace(".", "_")
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
Expand Down
3 changes: 2 additions & 1 deletion modules/models/sd3/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import torch
import torch.nn as nn
from einops import rearrange, repeat
from other_impls import attention, Mlp
from modules.models.sd3.other_impls import attention, Mlp


class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding"""
Expand Down
14 changes: 7 additions & 7 deletions modules/models/sd3/sd3_impls.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
### Impls of the SD3 core diffusion model and VAE

import torch, math, einops
from mmdit import MMDiT
from modules.models.sd3.mmdit import MMDiT
from PIL import Image


Expand Down Expand Up @@ -46,16 +46,16 @@ def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):

class BaseModel(torch.nn.Module):
"""Wrapper around the core MM-DiT model"""
def __init__(self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix=""):
def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
super().__init__()
# Important configuration values can be quickly determined by checking shapes in the source file
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
context_embedder_config = {
"target": "torch.nn.Linear",
"params": {
Expand Down
166 changes: 166 additions & 0 deletions modules/models/sd3/sd3_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import contextlib
import os
from typing import Mapping

import safetensors
import torch

import k_diffusion
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat

from modules import shared, modelloader, devices

CLIPG_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_g.safetensors"
CLIPG_CONFIG = {
"hidden_act": "gelu",
"hidden_size": 1280,
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
}

CLIPL_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_l.safetensors"
CLIPL_CONFIG = {
"hidden_act": "quick_gelu",
"hidden_size": 768,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12,
}

T5_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/t5xxl_fp16.safetensors"
T5_CONFIG = {
"d_ff": 10240,
"d_model": 4096,
"num_heads": 64,
"num_layers": 24,
"vocab_size": 32128,
}


class SafetensorsMapping(Mapping):
def __init__(self, file):
self.file = file

def __len__(self):
return len(self.file.keys())

def __iter__(self):
for key in self.file.keys():
yield key

def __getitem__(self, key):
return self.file.get_tensor(key)


class SD3Cond(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.tokenizer = SD3Tokenizer()

with torch.no_grad():
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)

self.weights_loaded = False

def forward(self, prompts: list[str]):
res = []

for prompt in prompts:
tokens = self.tokenizer.tokenize_with_weights(prompt)
l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)

res.append({
'crossattn': lgt_out[0].to(devices.device),
'vector': vector_out[0].to(devices.device),
})

return res

def load_weights(self):
if self.weights_loaded:
return

clip_path = os.path.join(shared.models_path, "CLIP")

clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
with safetensors.safe_open(clip_g_file, framework="pt") as file:
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))

clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
with safetensors.safe_open(t5_file, framework="pt") as file:
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

self.weights_loaded = True

def encode_embedding_init_text(self, init_text, nvpt):
return torch.tensor([[0]], device=devices.device) # XXX


class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
def __init__(self, inner_model, sigmas):
super().__init__(sigmas, quantize=shared.opts.enable_quantization)
self.inner_model = inner_model

def forward(self, input, sigma, **kwargs):
return self.inner_model.apply_model(input, sigma, **kwargs)


class SD3Inferencer(torch.nn.Module):
def __init__(self, state_dict, shift=3, use_ema=False):
super().__init__()

self.shift = shift

with torch.no_grad():
self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
self.first_stage_model.dtype = self.model.diffusion_model.dtype

self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)

self.cond_stage_model = SD3Cond()
self.cond_stage_key = 'txt'

self.parameterization = "eps"
self.model.conditioning_key = "crossattn"

self.latent_format = SD3LatentFormat()
self.latent_channels = 16

def after_load_weights(self):
self.cond_stage_model.load_weights()

def ema_scope(self):
return contextlib.nullcontext()

def get_learned_conditioning(self, batch: list[str]):
return self.cond_stage_model(batch)

def apply_model(self, x, t, cond):
return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])

def decode_first_stage(self, latent):
latent = self.latent_format.process_out(latent)
return self.first_stage_model.decode(latent)

def encode_first_stage(self, image):
latent = self.first_stage_model.encode(image)
return self.latent_format.process_in(latent)

def create_denoiser(self):
return SD3Denoiser(self, self.model.model_sampling.sigmas)
3 changes: 2 additions & 1 deletion modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

if p.scripts is not None:
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
Expand Down
Loading

0 comments on commit 5b2a60b

Please sign in to comment.