# Custom version of Stable Diffusion

In [None]:
!pip install transformers diffusers ftfy

In [None]:
from pathlib import Path
from huggingface_hub import notebook_login
# Prompt for a Hugging Face token only when none is saved yet. Recent
# huggingface_hub releases store it at $HF_HOME/token (HF_HOME defaults to
# ~/.cache/huggingface); older releases used ~/.huggingface/token, so accept either.
import os
_hf_home = Path(os.environ.get('HF_HOME', Path.home()/'.cache'/'huggingface')).expanduser()
if not any(p.exists() for p in (Path.home()/'.huggingface'/'token', _hf_home/'token')): notebook_login()

In [None]:
import torch
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
import math

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer

In [None]:
# The tokenizer only maps text to token ids, so it takes no dtype argument;
# only the text-encoder weights are loaded in fp16 and moved to the GPU.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to("cuda")

In [None]:
from diffusers import AutoencoderKL, UNet2DConditionModel

# VAE: a separately released one, fine-tuned for more steps than the VAE in the
# original release. UNet: taken from the v1-4 checkpoint. Both load as fp16 on the GPU.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-ema",
    torch_dtype=torch.float16,
).to("cuda")
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="unet",
    torch_dtype=torch.float16,
).to("cuda")

In [None]:
from diffusers import LMSDiscreteScheduler
# Explicit noise-schedule parameters (presumably matching the SD v1 training
# schedule — confirm against the checkpoint's scheduler config).
scheduler = LMSDiscreteScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)

### Making Text Embeddings for the UNet

In [None]:
prompt = ["a photograph of an astronaut riding a horse"]

In [None]:
token_info = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt")

token_embs = text_encoder(token_info.input_ids.to("cuda"))[0];
token_embs.shape

In [None]:
uncond_info = tokenizer([""] * len(prompt), padding="max_length", truncation=True, return_tensors="pt")

uncond_embs = text_encoder(uncond_info.input_ids.to("cuda"))[0]
uncond_embs.shape

In [None]:
text_embs = torch.cat([uncond_embs, token_embs])
text_embs.shape

### Initialize Latents & Scheduler

In [None]:
height = 512
width = 512
steps = 50

In [None]:
torch.manual_seed(100)
latents = torch.randn(len(prompt), unet.in_channels, height // 8, width // 8).to("cuda").half()
latents.shape

In [None]:
scheduler.set_timesteps(50)

In [None]:
plt.plot(scheduler.timesteps, scheduler.sigmas[:-1])

In [None]:
latents = latents * scheduler.init_noise_sigma

### The Loop

In [None]:
guidance_scale = 7.5

In [None]:
for i, t in enumerate(tqdm(scheduler.timesteps)):
    input = torch.cat([latents] * 2)
    input = scheduler.scale_model_input(input, t)

    # predict the noise residual
    with torch.no_grad(): pred = unet(input, t, encoder_hidden_states=text_embs).sample

    # perform guidance
    pred_uncond, pred_text = pred.chunk(2)
    pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

    # compute the "previous" noisy sample
    updated_info = scheduler.step(pred, t, latents)
    latents = updated_info.prev_sample

### Decoding the Latents into an Image

In [None]:
with torch.no_grad():
    im_data = vae.decode(latents * 1 / 0.18215).sample[0]

In [None]:
norm_im_data = (im_data * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()
norm_im_data.shape

In [None]:
rgb_im_data = (norm_im_data * 255).round().astype("uint8")

In [None]:
from PIL import Image

In [None]:
Image.fromarray(rgb_im_data).resize((256, 256))

## Methods

In [None]:
# Default prompt: a one-element batch.
prompt = [
    "a photograph of an astronaut riding a horse",
]

In [None]:
def make_token_embs(promt):
    """Encode a batch of prompts with the CLIP text encoder.

    Args:
        promt: list of prompt strings. (The parameter keeps this spelling so
            existing keyword callers keep working.)

    Returns:
        The text encoder's first output for the batch (presumably the last
        hidden state, one embedding per token), on the GPU.
    """
    # Pad/truncate to the tokenizer's fixed max length so the result lines up
    # with empty-prompt embeddings produced the same way.
    token_info = tokenizer(promt, padding="max_length", truncation=True, return_tensors="pt")
    # Inference only: don't record an autograd graph through the encoder.
    with torch.no_grad():
        token_embs = text_encoder(token_info.input_ids.to("cuda"))[0]
    return token_embs

In [None]:
def gen_image(token_embs, height=512, width=512, steps=50, gd=7.5, seed=100, get_all=False, return_preview=False):
    """Run the classifier-free-guided denoising loop and return the latents.

    Args:
        token_embs: prompt embeddings as returned by `make_token_embs`; the
            first dimension is used as the batch size.
        height, width: output image size in pixels (latents are 1/8 of this).
        steps: number of scheduler timesteps.
        gd: classifier-free guidance scale.
        seed: seed passed to `torch.manual_seed` before drawing the noise.
        get_all: if True, return a list with one tensor per step instead of
            only the final latents.
        return_preview: only used with `get_all`; collect the scheduler's
            `pred_original_sample` estimate at each step instead of the noisy
            `prev_sample` latents.

    Returns:
        The final latents tensor, or a list of per-step tensors if `get_all`.
    """
    # Empty-prompt embeddings for the unconditional half of the guidance batch.
    uncond_info = tokenizer([""] * len(token_embs), padding="max_length", truncation=True, return_tensors="pt")
    with torch.no_grad():
        uncond_embs = text_encoder(uncond_info.input_ids.to("cuda"))[0]
    # Order matters: `pred.chunk(2)` below assumes unconditional rows come first.
    text_embs = torch.cat([uncond_embs, token_embs])

    torch.manual_seed(seed)
    latents = torch.randn(len(token_embs), unet.in_channels, height // 8, width // 8).to("cuda").half()

    scheduler.set_timesteps(steps)
    latents = latents * scheduler.init_noise_sigma
    latents_list = []

    for t in tqdm(scheduler.timesteps):
        # One UNet call evaluates both the unconditional and the prompt branch.
        model_input = torch.cat([latents] * 2)
        model_input = scheduler.scale_model_input(model_input, t)

        # predict the noise residual
        with torch.no_grad():
            pred = unet(model_input, t, encoder_hidden_states=text_embs).sample

        # classifier-free guidance, using this call's `gd` (not a global scale)
        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + gd * (pred_text - pred_uncond)

        # compute the "previous" noisy sample
        updated_info = scheduler.step(pred, t, latents)
        latents = updated_info.prev_sample

        if get_all:
            latents_list.append(updated_info.pred_original_sample if return_preview else latents)

    return latents_list if get_all else latents

In [None]:
def decode_latents(latents, scale_factor=1.0):
    """Decode latents with the VAE and return the first image as a PIL image.

    Args:
        latents: latent tensor with a leading batch dimension; only element 0
            of the decoded batch is converted.
        scale_factor: resize factor applied to the decoded image. Each side is
            kept at least 1 pixel so small factors don't produce an empty size.

    Returns:
        A `PIL.Image.Image` (RGB, uint8).
    """
    with torch.no_grad():
        # Divide out the latent scaling constant used throughout this notebook.
        im_data = vae.decode(latents / 0.18215).sample[0]

    # Map the nominal [-1, 1] decoder range into [0, 1] and go CHW -> HWC.
    # Convert to float32 first so the 0..255 scaling and rounding are not
    # done in fp16.
    norm_im_data = (im_data.float() * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    rgb_im_data = (norm_im_data * 255).round().astype("uint8")
    im = Image.fromarray(rgb_im_data)

    new_size = (max(1, int(im.width * scale_factor)), max(1, int(im.height * scale_factor)))
    return im.resize(new_size)

In [None]:
def show_latents_grid(latents_list, cols=8, scale_factor=1.0):
    """Decode every latent tensor and tile the images into one grid image.

    Tiles are laid out left-to-right, top-to-bottom, `cols` per row. The cell
    size comes from the first decoded image.

    Returns:
        A new RGB `PIL.Image.Image` of size (cols * tile_w, n_rows * tile_h).
    """
    tiles = list(map(lambda lat: decode_latents(lat, scale_factor), latents_list))

    tile_w, tile_h = tiles[0].size
    n_rows = math.ceil(len(tiles) / cols)
    canvas = Image.new('RGB', size=(cols * tile_w, n_rows * tile_h))

    for idx, tile in enumerate(tiles):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))

    return canvas

### Usage

In [None]:
prompt = ["a photograph of an astronaut riding a horse"]
latents = gen_image(make_token_embs(prompt))
decode_latents(latents, scale_factor=0.5)

In [None]:
prompt = ["a photograph of an astronaut riding a horse"]
latents_list = gen_image(make_token_embs(prompt), steps=20, get_all=True)
show_latents_grid(latents_list, scale_factor=0.3)

In [None]:
prompt = ["a photograph of an astronaut riding a horse"]
latents_list = gen_image(make_token_embs(prompt), steps=20, get_all=True, return_preview=True)
show_latents_grid(latents_list, scale_factor=0.3)

In [None]:
decode_latents(latents_list[-2], scale_factor=0.5)