Stable Diffusion 是一种文本条件隐式扩散模型（text-conditioned latent diffusion model）。

In [None]:
'''环境准备'''
import torch
import requests
from PIL import Image
from io import BytesIO
from matplotlib import pyplot as plt

from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionDepth2ImgPipeline
)

In [None]:
def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on 4xx/5xx here; otherwise PIL would be asked to decode an error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')

# img_url = ''
# init_img = download_image(img_url).resize((512,512))

# Pick the compute backend in priority order: MPS, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
'''从文本生成图像'''
# Load the pretrained text-to-image pipeline, then move it to the chosen device.
model_id = 'stabilityai/stable-diffusion-2-1-base'
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to(device)
# Alternative: load the fp16 version instead.
# pipe = StableDiffusionPipeline.from_pretrained(model_id,
# revision='fp16', torch_dtype=torch.float16).to(device)

# Attention slicing: slower, but reduces GPU memory use.
# pipe.enable_attention_slicing()

In [None]:
# Generate an image from a text prompt, seeding the generator for repeatability.
generator = torch.Generator(device=device).manual_seed(42)

pipe_output = pipe(
    prompt='winter cityscape',
    # The pipeline keyword is `negative_prompt`; the earlier spelling
    # `negative_promote` is not a recognised argument, so the negative
    # prompt was never applied.
    negative_prompt='low quality',
    height=480,
    width=640,
    guidance_scale=8,
    num_inference_steps=35,
    generator=generator,
)

# Last expression of the cell: the notebook displays the generated image.
pipe_output.images[0]

In [None]:
# Show the effect of increasing guidance_scale: one panel per scale value.
cfg_scales = [1.1, 8, 12]
prompt = 'A dog with a pink hat'
fig, axs = plt.subplots(1, len(cfg_scales), figsize=(16, 5))
for ax, scale in zip(axs, cfg_scales):
    im = pipe(
        prompt,
        guidance_scale=scale,
        # Re-seed for every panel so the only thing that varies is the scale.
        generator=torch.Generator(device=device).manual_seed(42),
    ).images[0]
    # Axes objects draw images with imshow(); they have no show() method.
    ax.imshow(im)
    ax.set_title(f'CFG Scale{scale}')

In [None]:
'''官网代码https://huggingface.co/stabilityai/stable-diffusion-2-1-base'''

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch

# Sample from the model card: build the pipeline with the Euler discrete
# scheduler loaded from the checkpoint's "scheduler" subfolder, in float16.
model_id = "stabilityai/stable-diffusion-2-1-base"

scheduler = EulerDiscreteScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
)
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to("cuda")

# Run one prompt and write the first returned image to disk.
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")
