@@ -27,6 +27,18 @@
"PYTHONPATH": "${workspaceFolder}/src/diffusers/src:${workspaceFolder}/src/k-diffusion:${workspaceFolder}/src:${env.PYTHONPATH}",
"PYTORCH_ENABLE_MPS_FALLBACK": "1",
}
},
{
"name": "Python: Structured Diffusion Play",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/scripts/structured_diffusion_play.py",
"console": "integratedTerminal",
"justMyCode": false,
"env": {
"PYTHONPATH": "${workspaceFolder}/src/diffusers/src:${workspaceFolder}/src/k-diffusion:${workspaceFolder}/src:${env.PYTHONPATH}",
"PYTORCH_ENABLE_MPS_FALLBACK": "1",
}
}
]
}
@@ -9,7 +9,8 @@
from helpers.schedule_params import get_alphas, get_alphas_cumprod, get_betas
from helpers.get_seed import get_seed
from helpers.latents_to_pils import LatentsToPils, make_latents_to_pils
from helpers.embed_text import ClipCheckpoint, ClipImplementation, Embed, get_embedder
from helpers.clip_identifiers import ClipCheckpoint, ClipImplementation
from helpers.embed_text import Embed, get_embedder
from k_diffusion.external import DiscreteSchedule
from k_diffusion.sampling import get_sigmas_karras, sample_dpmpp_2m

@@ -17,14 +17,18 @@
from helpers.schedule_params import get_alphas, get_alphas_cumprod, get_betas
from helpers.get_seed import get_seed
from helpers.latents_to_pils import LatentsToPils, make_latents_to_pils
from helpers.embed_text import ClipCheckpoint, ClipImplementation, Embed, get_embedder
from helpers.clip_identifiers import ClipCheckpoint, ClipImplementation
from helpers.embed_text import Embed, get_embedder
from helpers.tokenize_text import CountTokens, get_token_counter
from helpers.structured_diffusion import get_structured_embedder, StructuredEmbed, StructuredEmbedding

from typing import List
from PIL import Image
import time

half = True
cfg_enabled = True
structured_diffusion = True

n_rand_seeds = 10
seeds = [
@@ -114,6 +118,14 @@
device=device,
torch_dtype=torch_dtype
)
count_tokens: CountTokens = get_token_counter(
impl=clip_impl,
ckpt=clip_ckpt,
)
sembed: StructuredEmbed = get_structured_embedder(
embed=embed,
count_tokens=count_tokens,
)
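# assumption about helpers.structured_diffusion (not shown here): the structured embedder wraps the
# plain CLIP embedder and token counter, and returns an uncond embedding, the cond embeddings
# (possibly several noun-phrase embeddings per prompt), and np_arities, which the denoiser below
# presumably uses to know how many embeddings each prompt contributed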

schedule_template = KarrasScheduleTemplate.Mastering
schedule: KarrasScheduleParams = get_template_schedule(schedule_template, unet_k_wrapped)
@@ -133,11 +145,11 @@
print(f"sigmas (quantized):\n{', '.join(['%.4f' % s.item() for s in sigmas_quantized])}")

# prompt='Emad Mostaque high-fiving Gordon Ramsay'
prompt = 'artoria pendragon (fate), carnelian, 1girl, general content, upper body, white shirt, blonde hair, looking at viewer, medium breasts, hair between eyes, floating hair, green eyes, blue ribbon, long sleeves, light smile, hair ribbon, watercolor (medium), traditional media'
prompt = 'two blue sheep and a red goat'
# prompt = 'artoria pendragon (fate), carnelian, 1girl, general content, upper body, white shirt, blonde hair, looking at viewer, medium breasts, hair between eyes, floating hair, green eyes, blue ribbon, long sleeves, light smile, hair ribbon, watercolor (medium), traditional media'
# prompt = "masterpiece character portrait of a blonde girl, full resolution, 4k, mizuryuu kei, akihiko. yoshida, Pixiv featured, baroque scenic, by artgerm, sylvain sarrailh, rossdraws, wlop, global illumination, vaporwave"

unprompts = [''] if cfg_enabled else []
prompts = [*unprompts, prompt]
prompts = [prompt]

sample_path='out'
intermediates_path='intermediates'
@@ -152,13 +164,21 @@
height = width
latents_shape = (batch_size * num_images_per_prompt, unet.in_channels, height // 8, width // 8)
with no_grad():
text_embeddings: Tensor = embed(prompts)
chunked = text_embeddings.chunk(text_embeddings.size(0))
if cfg_enabled:
uc, c = chunked
if structured_diffusion:
structured_embedding: StructuredEmbedding = sembed(prompts, gimme_uncond=cfg_enabled)
uc = structured_embedding.uncond
c = structured_embedding.embeds
np_arities = structured_embedding.np_arities
else:
uc = None
c, = chunked
unprompts = [''] if cfg_enabled else []
text_embeddings: Tensor = embed([*unprompts, *prompts])
chunked = text_embeddings.chunk(text_embeddings.size(0))
if cfg_enabled:
uc, c = chunked
else:
uc = None
c, = chunked
np_arities = None

batch_tic = time.perf_counter()
for seed in seeds:
@@ -167,7 +187,12 @@

tic = time.perf_counter()

denoiser: Denoiser = denoiser_factory(uncond=uc, cond=c, cond_scale=cond_scale)
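# np_arities is only set on the structured-diffusion path; passing it makes the factory construct a
# StructuredDiffusionDenoiser, while None falls back to the usual parallel-CFG / no-CFG denoisers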
denoiser: Denoiser = denoiser_factory(
uncond=uc,
cond=c,
cond_scale=cond_scale,
np_arities=np_arities,
)
noise_sampler = BrownianTreeNoiseSampler(
latents,
sigma_min=sigma_min,
@@ -0,0 +1,60 @@
import torch
from torch import Tensor, no_grad

from helpers.device import DeviceLiteral, get_device_type
from helpers.clip_identifiers import ClipCheckpoint, ClipImplementation
from helpers.embed_text import Embed, get_embedder
from helpers.tokenize_text import CountTokens, get_token_counter
from helpers.structured_diffusion import get_structured_embedder, StructuredEmbed, StructuredEmbedding
from typing import List

device_type: DeviceLiteral = get_device_type()
device = torch.device(device_type)
torch_dtype=torch.float32
cfg_enabled=True

model_name = (
# 'CompVis/stable-diffusion-v1-4'
# 'hakurei/waifu-diffusion'
'runwayml/stable-diffusion-v1-5'
# 'stabilityai/stable-diffusion-2'
# 'stabilityai/stable-diffusion-2-1'
# 'stabilityai/stable-diffusion-2-base'
# 'stabilityai/stable-diffusion-2-1-base'
)

sd2_768_models = { 'stabilityai/stable-diffusion-2', 'stabilityai/stable-diffusion-2-1' }
sd2_base_models = { 'stabilityai/stable-diffusion-2-base', 'stabilityai/stable-diffusion-2-1-base' }
sd2_models = { *sd2_768_models, *sd2_base_models }

laion_embed_models = { *sd2_models }
penultimate_clip_hidden_state_models = { *sd2_models }

needs_laion_embed = model_name in laion_embed_models
needs_penultimate_clip_hidden_state = model_name in penultimate_clip_hidden_state_models

clip_impl = ClipImplementation.HF
clip_ckpt = ClipCheckpoint.LAION if needs_laion_embed else ClipCheckpoint.OpenAI
clip_subtract_hidden_state_layers = 1 if needs_penultimate_clip_hidden_state else 0
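# SD2-family checkpoints condition on OpenCLIP (LAION-trained) text embeddings taken from the
# penultimate hidden state, hence the LAION checkpoint and the one-layer subtraction above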
embed: Embed = get_embedder(
impl=clip_impl,
ckpt=clip_ckpt,
subtract_hidden_state_layers=clip_subtract_hidden_state_layers,
device=device,
torch_dtype=torch_dtype
)
count_tokens: CountTokens = get_token_counter(
impl=clip_impl,
ckpt=clip_ckpt,
)
sembed: StructuredEmbed = get_structured_embedder(
embed=embed,
count_tokens=count_tokens,
)

prompt = 'two blue sheep with a red car'
prompts: List[str] = [prompt]
with no_grad():
structured_embedding: StructuredEmbedding = sembed(prompts, gimme_uncond=cfg_enabled)
uc = structured_embedding.uncond
c = structured_embedding.embeds
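# illustrative inspection only: c may stack several noun-phrase embeddings per prompt, and
# np_arities (consumed by the denoiser in the main play script) records how many each contributed
print(uc.shape, c.shape, structured_embedding.np_arities)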
@@ -1,6 +1,8 @@
from functools import partial
from .diffusers_denoiser import DiffusersSDDenoiser
from torch import Tensor, cat
from typing import Optional, Protocol, NamedTuple
import torch
from torch import Tensor, LongTensor, cat, tensor
from typing import Optional, Protocol, NamedTuple, List
from abc import ABC, abstractmethod

class Denoiser(Protocol):
@@ -51,6 +53,7 @@ def get_cfg_conds(self, x: Tensor, sigma: Tensor) -> CFGConds:

class ParallelCFGDenoiser(AbstractCFGDenoiser):
cond_in: Tensor
noised_latent_copies_needed: int
def __init__(
self,
denoiser: DiffusersSDDenoiser,
@@ -59,12 +62,14 @@ def __init__(
cond_scale: float,
):
self.cond_in = cat([uncond, cond])
self.noised_latent_copies_needed = self.cond_in.size(dim=0)
super().__init__(
denoiser=denoiser,
cond_scale=cond_scale,
)

def get_cfg_conds(self, x: Tensor, sigma: Tensor) -> CFGConds:
x_in = x.expand(self.cond_in.size(dim=0), -1, -1, -1)
x_in = x.expand(self.noised_latent_copies_needed, -1, -1, -1)
del x
uncond, cond = self.denoiser(input=x_in, sigma=sigma, encoder_hidden_states=self.cond_in).chunk(self.cond_in.size(dim=0))
return CFGConds(uncond, cond)
@@ -78,6 +83,27 @@ def __init__(self, denoiser: DiffusersSDDenoiser, cond: Tensor):
def __call__(self, x: Tensor, sigma: Tensor) -> Tensor:
return self.denoiser(input=x, sigma=sigma, encoder_hidden_states=self.cond)

class StructuredDiffusionDenoiser(ParallelCFGDenoiser):
def __init__(
self,
denoiser: DiffusersSDDenoiser,
uncond: Tensor,
cond: Tensor,
np_arities: List[int],
cond_scale: float = 1.0,
):
super().__init__(
denoiser=denoiser,
uncond=uncond,
cond=cond,
cond_scale=cond_scale,
)
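# assumption: np_arities holds one count per prompt (how many noun-phrase embeddings it contributed).
# it is forwarded to the UNet via partial(), and the noised latent is replicated once for the uncond
# pass plus once per prompt, rather than once per embedding as in ParallelCFGDenoiser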
np_arities: LongTensor = tensor(np_arities, dtype=torch.long, device=cond.device)
self.denoiser = partial(denoiser, np_arities=np_arities)
self.uncond = uncond
self.cond = cond
self.noised_latent_copies_needed = 1 + np_arities.size(0)

class DenoiserFactory():
denoiser: DiffusersSDDenoiser
# this is a workaround which caters for some wacky experiments
@@ -94,7 +120,19 @@ def __call__(
cond: Tensor,
uncond: Optional[Tensor] = None,
cond_scale: float = 1.0,
np_arities: Optional[List[int]] = None,
) -> Denoiser:
if np_arities is not None:
# structured diffusion
assert uncond is not None and cond_scale > 1, 'structured diffusion only implemented for CFG; please provide uncond and set cond_scale > 1'
assert self.one_at_a_time is False, "structured diffusion only implemented for parallel usage, i.e. submitting uncond and cond simultaneously"
return StructuredDiffusionDenoiser(
uncond=uncond,
denoiser=self.denoiser,
cond=cond,
np_arities=np_arities,
cond_scale=cond_scale,
)
if uncond is None or cond_scale is None:
return NoCFGDenoiser(
denoiser=self.denoiser,
@@ -0,0 +1,10 @@
from enum import Enum, auto

class ClipImplementation(Enum):
HF = auto()
OpenCLIP = auto()
# OpenAI CLIP and clip-anytorch not implemented

class ClipCheckpoint(Enum):
OpenAI = auto()
LAION = auto()
@@ -1,8 +1,8 @@
from torch import Tensor, FloatTensor
from torch import Tensor, FloatTensor, LongTensor
from diffusers.models import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from k_diffusion.external import DiscreteEpsDDPMDenoiser, DiscreteVDDPMDenoiser
from typing import Union
from typing import Optional, Union
import torch

class DiffusersSDDenoiser(DiscreteEpsDDPMDenoiser):
@@ -18,12 +18,17 @@ def get_eps(
timestep: Union[Tensor, float, int],
encoder_hidden_states: Tensor,
return_dict: bool = True,
np_arities: Optional[LongTensor] = None,
) -> Tensor:
# don't pass np_arities arg if we don't need to. it's not supported in mainline diffusers,
# so let's not make it hard for ourselves to switch branches/versions
structured_diffusion_args = {} if np_arities is None else { 'np_arities': np_arities }
out: UNet2DConditionOutput = self.inner_model(
sample.to(self.inner_model.dtype),
timestep.to(self.inner_model.dtype),
encoder_hidden_states=encoder_hidden_states.to(self.inner_model.dtype),
return_dict=return_dict,
**structured_diffusion_args
)
return out.sample.to(self.sampling_dtype)

@@ -1,23 +1,15 @@
import torch
from typing import Callable, Union, Iterable
from typing import Callable
from typing_extensions import TypeAlias
from torch import Tensor, LongTensor, no_grad
from enum import Enum, auto
from .log_level import log_level
from .device import DeviceType
from .clip_identifiers import ClipImplementation, ClipCheckpoint
from .tokenize_text import get_hf_tokenizer
from .prompt_type import Prompts

Prompts: TypeAlias = Union[str, Iterable[str]]
Embed: TypeAlias = Callable[[Prompts], Tensor]

class ClipImplementation(Enum):
HF = auto()
OpenCLIP = auto()
# OpenAI CLIP and clip-anytorch not implemented

class ClipCheckpoint(Enum):
OpenAI = auto()
LAION = auto()

def get_embedder(
impl: ClipImplementation,
ckpt: ClipCheckpoint,
@@ -27,24 +19,22 @@ def get_embedder(
) -> Embed:
match(impl):
case ClipImplementation.HF:
from transformers import CLIPTextModel, PreTrainedTokenizer, CLIPTokenizer, logging
from transformers import CLIPTextModel, PreTrainedTokenizer, logging
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.tokenization_utils_base import BatchEncoding
match(ckpt):
case ClipCheckpoint.OpenAI:
model_name = 'openai/clip-vit-large-patch14'
tokenizer_extra_args = {}
encoder_extra_args = {}
extra_args = {}
case ClipCheckpoint.LAION:
# model_name = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
model_name = 'stabilityai/stable-diffusion-2'
tokenizer_extra_args = {'subfolder': 'tokenizer'}
encoder_extra_args = {'subfolder': 'text_encoder'}
extra_args = {'subfolder': 'text_encoder'}
case _:
raise "never heard of '{ckpt}' ClipCheckpoint."
tokenizer: PreTrainedTokenizer = CLIPTokenizer.from_pretrained(model_name, **tokenizer_extra_args)
tokenizer: PreTrainedTokenizer = get_hf_tokenizer(ckpt=ckpt)
with log_level(logging.ERROR):
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(model_name, torch_dtype=torch_dtype, **encoder_extra_args).to(device).eval()
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(model_name, torch_dtype=torch_dtype, **extra_args).to(device).eval()

def embed(prompts: Prompts) -> Tensor:
tokens: BatchEncoding = tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")
@@ -0,0 +1,4 @@
from typing import Union, Iterable
from typing_extensions import TypeAlias

Prompts: TypeAlias = Union[str, Iterable[str]]