In [1]:
import os
import sys

import numpy as np
import torch
from IPython.display import display
from PIL import Image
from tqdm import tqdm

from stable_diffusion_pytorch import model_loader, pipeline
from stable_diffusion_pytorch.samplers.k_lms import KLMSSampler
from stable_diffusion_pytorch.tokenizer import Tokenizer
import stable_diffusion_pytorch.util as util

In [2]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Record the runtime environment so the outputs below can be tied to versions/hardware.
environment_report = (
    ("Python: ", sys.version),
    ("PyTorch: ", torch.__version__),
    ("Device: ", device),
)
for label, value in environment_report:
    print(label, value)

Python:  3.10.8 | packaged by conda-forge | (main, Nov 24 2022, 14:07:00) [MSC v.1916 64 bit (AMD64)]
PyTorch:  2.0.0+cpu
Device:  cpu


In [7]:
# Generation settings.
prompt = "1girl, purple hair, genshin, high quality, masterpiece, raiden shogun, japanese, kimono"
prompts = [prompt]

# Negative prompt for classifier-free guidance. It is repeated once per prompt so the
# conditional and unconditional batches always have the same size. A single-element
# list would mismatch as soon as `prompts` holds more than one entry.
uncond_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
uncond_prompts = [uncond_prompt or ""] * len(prompts)

# Source image for image-to-image. Set the SD_INPUT_IMAGE environment variable to use a
# different file instead of editing this machine-specific absolute path.
INPUT_IMAGE_PATH = os.environ.get(
    "SD_INPUT_IMAGE",
    "C:\\Users\\frank\\Documents\\카카오톡 받은 파일\\KakaoTalk_20220604_234317644.png",
)
# Image.open is lazy, so the image is decoded inside `with` and the file handle gets
# closed. Converting to RGB lets palette/grayscale PNGs work with the 3-channel
# preprocessing in the encoder cell.
with Image.open(INPUT_IMAGE_PATH) as source_image:
    input_images = [source_image.convert("RGB")]

strength = 0.8          # how strongly the input image is re-noised before sampling
do_cfg = True           # classifier-free guidance on/off
cfg_scale = 7.5         # guidance weight applied to (cond - uncond)
height = 512            # output size; should be a multiple of 8 because the latent
width = 512             # grid is (height // 8, width // 8)
n_inference_steps = 50
seed = 42
use_jit = True          # load pre-traced TorchScript models from nvai-jit/
# Export re-traces the eager models, so it only makes sense with use_jit=False.
# Change the leading False to True to (re)write the nvai-jit/*.pt files.
export = False and not use_jit

In [5]:
# Load the four sub-models used below: 'clip' (tokens -> text context), 'encoder'
# (image + noise -> latents), 'diffusion' (latents, context, time embedding -> model
# output) and 'decoder' (latents -> image).
if use_jit:
    # These are the TorchScript files written by the `export` branches of the later
    # cells. map_location loads their tensors onto the selected device regardless of
    # where they were traced. Without it, a CUDA export cannot be loaded on a CPU-only
    # machine, and a CPU export is not moved to the GPU.
    models = dict()
    for name in ['clip', 'decoder', 'diffusion', 'encoder']:
        models[name] = torch.jit.load(f"nvai-jit/{name}.pt", map_location=device)
        models[name].eval()
else:
    # Eager models. oneDNN fusion is a global TorchScript setting; it presumably
    # targets the graphs traced from these models when `export` is on.
    torch.jit.enable_onednn_fusion(True)
    models = model_loader.preload_models(device)

In [8]:
with torch.no_grad():
    tokenizer = Tokenizer()
    # Seeded CPU generator used for the noise sampled in the image-encoding cell.
    generator = torch.Generator(device='cpu')
    generator.manual_seed(seed)

    # Text conditioning for the positive prompt(s).
    tokens = tokenizer.encode_batch(prompts)
    tokens = torch.tensor(tokens, dtype=torch.long, device=device)
    cond_context = models['clip'](tokens)

    if do_cfg:
        # With guidance, the denoising loop duplicates the latents and splits the model
        # output with chunk(2). The context is therefore stacked as [cond; uncond].
        uncond_tokens = tokenizer.encode_batch(uncond_prompts or [""] * len(prompts))
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
        uncond_context = models['clip'](uncond_tokens)
        context = torch.cat([cond_context, uncond_context])
    else:
        # Without guidance the latents are not duplicated. The context must then hold
        # only the conditional half, or its batch size would not match.
        context = cond_context

    if export:
        with torch.jit.optimized_execution(True):
            clip_traced = torch.jit.trace(models['clip'], tokens)
            clip_traced.eval()
            clip_traced = torch.jit.freeze(clip_traced)
        clip_traced.save("nvai-jit/clip.pt")
        del clip_traced
        del clip_traced

In [9]:
# Free the text encoder; only `context` is needed from here on. The guard keeps a re-run
# of this cell from raising KeyError. Re-running the text-encoding cell now requires
# reloading the models first.
if 'clip' in models:
    del models['clip']

In [10]:
# Size of the text context; the batch dimension is doubled when do_cfg is on.
context.size()

torch.Size([2, 77, 768])

In [12]:
with torch.no_grad():
    # Image-to-image: encode the input image(s) to latents, then add noise scaled by
    # the sampler's initial scale for the chosen `strength`.
    sampler = KLMSSampler(n_inference_steps=n_inference_steps)
    # The encoder noise is shaped per prompt, so there must be exactly one image per
    # prompt.
    assert len(input_images) == len(prompts), "need exactly one input image per prompt"
    noise_shape = (len(prompts), 4, height // 8, width // 8)

    processed_input_images = []
    for input_image in input_images:
        # convert("RGB") yields an (H, W, 3) array for every source mode. A palette
        # ("P") or grayscale ("L") PNG would otherwise give a 2-D array and break the
        # channel handling. For RGBA it simply drops alpha.
        input_image = input_image.convert("RGB").resize((width, height))
        input_image = torch.tensor(np.array(input_image), dtype=torch.float32)
        input_image = util.rescale(input_image, (0, 255), (-1, 1))
        processed_input_images.append(input_image)
    input_images_tensor = torch.stack(processed_input_images).to(device)
    input_images_tensor = util.move_channel(input_images_tensor, to="first")
    # Check the (batch, 3, height, width) layout instead of rebinding the global
    # height/width from the tensor, which made the cell depend on its own side effects.
    assert tuple(input_images_tensor.shape) == (len(prompts), 3, height, width)

    # Use a fresh, locally seeded CPU generator so that re-running this cell reproduces
    # the same noise. Sampling on the CPU and then moving the tensor also works on CUDA,
    # where torch.randn rejects a CPU generator combined with device='cuda'. It keeps
    # the noise identical across devices too.
    generator = torch.Generator(device='cpu').manual_seed(seed)
    encoder_noise = torch.randn(noise_shape, generator=generator).to(device)
    print(input_images_tensor.shape, encoder_noise.shape)
    latents = models['encoder'](input_images_tensor, encoder_noise)

    latents_noise = torch.randn(noise_shape, generator=generator).to(device)
    sampler.set_strength(strength=strength)
    latents += latents_noise * sampler.initial_scale

    if export:
        with torch.jit.optimized_execution(True):
            encoder_traced = torch.jit.trace(models['encoder'], (input_images_tensor, encoder_noise))
            encoder_traced.eval()
            encoder_traced = torch.jit.freeze(encoder_traced)
        encoder_traced.save("nvai-jit/encoder.pt")
        del encoder_traced

torch.Size([1, 3, 512, 512]) torch.Size([1, 4, 64, 64])


In [None]:
# Free the image encoder; the initial `latents` are already computed. The guard keeps a
# re-run of this cell from raising KeyError. Re-running the image-encoding cell now
# requires reloading the models first.
if 'encoder' in models:
    del models['encoder']

In [None]:
# Size of the noised initial latents fed to the denoising loop.
latents.size()

In [None]:
# Inspect the sampler's timestep schedule (the sequence the denoising loop iterates over).
sampler.timesteps

In [None]:
# Inspect the sampler's noise levels; presumably one sigma per scheduled timestep (confirm against KLMSSampler).
sampler.sigmas

In [None]:
with torch.no_grad():
    # Denoising loop: one diffusion-model call and one sampler update per timestep.
    # get_input_scale() takes no timestep argument, so the sampler appears to track its
    # progress internally. Run this cell once per run of the image-encoding cell.
    timesteps = tqdm(sampler.timesteps)
    for step_idx, timestep in enumerate(timesteps):
        time_embedding = util.get_time_embedding(timestep).to(device)

        # Model input: the current latents scaled for this step. With guidance they are
        # duplicated along the batch axis to pair with the [cond; uncond] context.
        input_latents = latents * sampler.get_input_scale()
        if do_cfg:
            input_latents = input_latents.repeat(2, 1, 1, 1)
        output = models['diffusion'](input_latents, context, time_embedding)

        # Trace the diffusion model once, using the real first-step inputs as the example.
        if export and step_idx == 0:
            example_inputs = (input_latents, context, time_embedding)
            with torch.jit.optimized_execution(True):
                diffusion_traced = torch.jit.trace(models['diffusion'], example_inputs).eval()
                diffusion_traced = torch.jit.freeze(diffusion_traced)
            diffusion_traced.save("nvai-jit/diffusion.pt")
            del diffusion_traced
            print("Exported diffusion model.")

        # Classifier-free guidance: push the conditional output away from the
        # unconditional one by cfg_scale.
        if do_cfg:
            cond_out, uncond_out = output.chunk(2)
            output = cfg_scale * (cond_out - uncond_out) + uncond_out
        latents = sampler.step(latents, output)

In [None]:
# Free the diffusion model; the final `latents` are computed. The guard keeps a re-run of
# this cell from raising KeyError. Re-running the denoising loop now requires reloading
# the models first.
if 'diffusion' in models:
    del models['diffusion']

In [None]:
# Size of the (guided) model output from the last denoising step.
output.size()

In [None]:
# Size of the denoised latents that go into the decoder.
latents.size()

In [None]:
with torch.no_grad():
    # Decode the final latents and turn them into PIL images.
    res = models['decoder'](latents)
    print(res.shape)

    # [-1, 1] floats -> clamped [0, 255] -> channels-last -> uint8 numpy on the CPU.
    rescaled = util.rescale(res, (-1, 1), (0, 255), clamp=True)
    images = util.move_channel(rescaled, to="last").to('cpu', torch.uint8).numpy()
    results = list(map(Image.fromarray, images))

    if export:
        with torch.jit.optimized_execution(True):
            decoder_traced = torch.jit.trace(models['decoder'], latents).eval()
            decoder_traced = torch.jit.freeze(decoder_traced)
        decoder_traced.save("nvai-jit/decoder.pt")
        del decoder_traced

In [None]:
# Summarise the raw decoder output instead of dumping the whole tensor. min/max show
# how far values overshoot [-1, 1], which is why the uint8 conversion has to clamp.
res.shape, res.min().item(), res.max().item()

In [None]:
# Manual cross-check of the decoder post-processing: map [-1, 1] -> [0, 255] and keep
# the first image of the batch, still channels-first.
# .cpu() is needed because .numpy() fails on CUDA tensors. Clipping before the uint8
# cast matters too: decoder values outside [-1, 1] would otherwise wrap around rather
# than saturate.
t = np.floor((res.detach().cpu().numpy() + 1) * 127.5)
t = np.clip(t, 0, 255).astype(np.uint8)
t = t[0]

In [None]:
# Shape of the manually converted first image.
np.shape(t)

In [None]:
# CHW -> HWC for PIL. The guard makes the cell safe to re-run: transposing an
# already-HWC (H, W, 3) array again would silently give a (W, 3, H) array.
if t.shape[0] == 3 and t.shape[-1] != 3:
    t = t.transpose(1, 2, 0)

In [None]:
# Render the manually converted first image for a visual comparison with `results`.
Image.fromarray(t)

In [None]:
# Shape of the uint8, channels-last image batch produced by the decoding cell.
np.shape(images)

In [None]:
# Free the decoder; `images`/`results` are already computed. The guard keeps a re-run of
# this cell from raising KeyError. Re-running the decoding cell now requires reloading
# the models first.
if 'decoder' in models:
    del models['decoder']

In [None]:
# Per-image shape: every dimension of the batch array except the leading batch axis.
images.shape[1:]

In [None]:
# Render every generated PIL image inline, in batch order.
display(*results)