In [None]:
import torch
import numpy as np
from tqdm import tqdm
from torch import autocast
from einops import rearrange

from stablediffusion.ldm.models.diffusion.ddim import DDIMSampler
from stablediffusion.ldm.models.diffusion.plms import PLMSSampler
from stablediffusion.scripts.txt2img import load_model_from_config

device = torch.device("cuda")

prompt_i = "a painting of a virus monster playing guitar"
prompt_j = "a forested landscape"

w_i = 0.5
w_j = 0.5

config = "stablediffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml"

timesteps = 1000

n = 1 # Number of samples / batch size
ch = 4 # Latent channels
f = 8 # Downsample factor
h = 512 # Image height
w = 512 # Image width

scale = 7.5 # Unconditional guidance scale
ddim_eta = 0.0 # 0.0 corresponds to deterministic sampling
shape = [ch, h // f, w // f]

b = n

model = load_model_from_config(config, 'sd-v2-1.ckpt')
model = model.to(device)
model = PLMSSampler(model)

with torch.no_grad():
    with autocast('cuda'):
        with model.ema_scope():
            uc = model.get_learned_conditioning(n * [""])
            c_i = model.get_learned_conditioning(n * [prompt_i])
            c_j = model.get_learned_conditioning(n * [prompt_j])

@torch.no_grad()
def p_sample(model, x, c, ts, index, old_eps=None, t_next=None):
    outs = model.p_sample_plms(x, c, ts, index=index, unconditional_guidance_scale=scale, unconditional_conditioning=uc,)
    x, _, e_t = outs
    old_eps.append(e_t)
    if len(old_eps) >= 4:
        old_eps.pop(0)

    return old_eps 

with torch.no_grad():
    with autocast('cuda'):
        with model.ema_scope():
            # Initialize sample x_T to N(0,I)
            x = torch.randn((n, ch, h // f, w // f)).to(device)

            model.make_schedule(ddim_num_steps=timesteps, ddim_eta=ddim_eta, verbose=False)
            timesteps = model.ddim_timesteps
            time_range = np.flip(timesteps)
            total_steps = timesteps.shape[0]
            e_ti = []
            e_tj = []
            for i, step in enumerate(tqdm(time_range, desc='PLMS Sampler', total=total_steps)):
                index = total_steps - i - 1
                ts = torch.full((b,), step, device=device, dtype=torch.long)
                ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
                
                # Compute conditional scores for each concept c_i
                e_ti = p_sample(model, x, c_i, ts, index, e_ti, ts_next) 
                e_tj = p_sample(model, x, c_j, ts, index, e_tj, ts_next)
                e_i = e_ti[-1]
                e_j = e_tj[-1]

                # Compute unconditional score
                e_t = p_sample(model, x, uc, ts, index, e_t, ts_next)
                e = e_t[-1]
                
                # Sampling
                mean = x - (e + w_i * (e_i - e) + w_j * (e_j - e))
                covar = model.betas[ts]
                x = torch.normal(mean, covar*torch.eye(h // f, w // f)) # Sampling

            x = model.decode_first_stage([x])
            x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)

            count = 0
            for sample in x:
                sample = 255 * rearrange(sample.cpu().numpy(), 'c h w -> h w c')
                img = Image.fromarray(sample.astype(np.uint8))
                img.save(os.path.join(sample_path, f"output_{count}.png"))
                count += 1

: 

In [3]:
import os
import torch
import numpy as np

from omegaconf import OmegaConf
from stablediffusion.scripts.txt2img import load_model_from_config

configpath = "stablediffusion/configs/stable-diffusion/v1-inference.yaml"

config = OmegaConf.load(configpath)

model = load_model_from_config(config, 'sd-v1-4_768.ckpt')

Loading model from sd-v1-4_768.ckpt
Global Step: 470000
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.14.self_attn.k_proj.bias', 'vision_model.encoder.layers.9.self_attn.q_proj.bias', 'vision_model.encoder.layers.2.self_attn.v_proj.weight', 'vision_model.encoder.layers.4.layer_norm1.bias', 'vision_model.encoder.layers.15.self_attn.out_proj.bias', 'vision_model.encoder.layers.2.self_attn.k_proj.weight', 'vision_model.encoder.layers.13.self_attn.v_proj.weight', 'vision_model.encoder.layers.8.self_attn.q_proj.bias', 'vision_model.encoder.layers.17.self_attn.k_proj.weight', 'vision_model.encoder.layers.0.self_attn.out_proj.weight', 'vision_model.encoder.layers.11.self_attn.q_proj.bias', 'vision_model.encoder.layers.23.self_attn.k_proj.bias', 'vision_model.encoder.layers.6.layer_norm1.weight', 'vision_model.encoder.layers.2.mlp.fc1.bias', 'vision_model.encoder.layers.11.self_attn.k_proj.weight', 'vision_model.encoder.layers.1.self_attn.v_proj.bia

In [5]:
for module in model.modules():
    print(module)

<generator object Module.modules at 0x7f06c032c900>
