In [None]:
import rp
import torch
from IPython.display import clear_output
from diffusers import DiffusionPipeline, DDIMScheduler
from icecream import ic

In [None]:
# Checkpoint selection.
#
# Defaults shared by every checkpoint; an option below only overrides what it needs.
# Defining them up front means every later cell can rely on `scheduler_kwargs` and
# `guidance_scale` existing no matter which model line is uncommented.
scheduler_kwargs = {}
guidance_scale = 7.5  # the pipeline default noted elsewhere in this notebook

# Uncomment ONE model:
# model_ckpt = "stabilityai/stable-diffusion-2-base"
# model_ckpt = "stabilityai/stable-diffusion-2-1-base"
# model_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
# model_ckpt = "ByteDance/sd2.1-base-zsnr-laionaes5"            ; scheduler_kwargs = dict(timestep_spacing="trailing", rescale_betas_zero_snr=True)  # Zero-terminal SNR; can use rescale_betas_zero_snr=True and timestep_spacing="trailing"
model_ckpt = "ByteDance/sd2.1-base-zsnr-laionaes6-perceptual" ; scheduler_kwargs = dict(timestep_spacing="trailing", rescale_betas_zero_snr=True); guidance_scale = 0  # Zero-terminal SNR + it uses guidance_scale=0

In [None]:
# DDIM scheduler built from the checkpoint's own "scheduler" subfolder, with the
# per-model overrides chosen in the config cell layered on top.
# (Optional extra that was being experimented with: thresholding = True)
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler", **scheduler_kwargs)

# Load the pipeline in half precision with the scheduler above and no safety checker,
# then place it on the GPU.
pipe = DiffusionPipeline.from_pretrained(
    model_ckpt,
    torch_dtype=torch.float16,
    safety_checker=None,
    scheduler=scheduler,
).to("cuda")

# Print which prediction type the scheduler config reports.
ic(pipe.scheduler.config.prediction_type)

In [None]:
# Run the pipeline 20 times on an empty prompt with guidance_scale=0 and collect
# every returned image, then replace the progress output with one tiled overview.
# Alternative prompt that was tried: "A photo geisha isolated on a solid green background"
images = []
for _ in range(20):
    out = pipe("", guidance_scale=0)  # Default: 7.5
    images.extend(out.images)

clear_output()
rp.display_image(rp.tiled_images(images, length=10))

In [None]:
# Sanity check: show the device the pipeline reports after the pipe.to("cuda") call above.
pipe.device

In [None]:
# Sanity check: show the pipeline's dtype (it was loaded with torch_dtype=torch.float16).
pipe.dtype

In [None]:
# Seed for the starting noise. A dedicated generator makes the latents -- and every
# image derived from them in later cells -- reproducible across kernel restarts and
# cell re-runs, without touching torch's global RNG state.
# Change NOISE_SEED to explore a different noise sample.
NOISE_SEED = 0
noise_generator = torch.Generator().manual_seed(NOISE_SEED)

# The UNet config's sample_size is used as the latent resolution. It may be a single
# int (square latents) or a (height, width) pair depending on the checkpoint.
sample_size = pipe.unet.config.sample_size
if isinstance(sample_size, int):
    latent_height = latent_width = sample_size
else:
    latent_height, latent_width = sample_size

# Image-space size is the latent size multiplied by the VAE scale factor.
height = latent_height * pipe.vae_scale_factor
width  = latent_width  * pipe.vae_scale_factor
latent_num_channels = pipe.vae.config.latent_channels
batch_size = 1

# Draw the noise on the CPU, then match the pipeline's dtype and device so it can be
# passed straight to pipe(latents=...).
pure_noise_latents = torch.randn(
    batch_size,
    latent_num_channels,
    latent_height,
    latent_width,
    generator=noise_generator,
)
pure_noise_latents = pure_noise_latents.to(pipe.dtype).to(pipe.device)

ic(height, pipe.vae_scale_factor, latent_height)
ic(pure_noise_latents.shape, pure_noise_latents.dtype, pure_noise_latents.device)

In [None]:
# Optional debugging aids: peek at the first latent values / pin the global RNGs.
# import random
# print(pure_noise_latents.flatten()[:10])
# random.seed(123)
# torch.random.manual_seed(213)

# Generate a single image starting from the explicit noise latents instead of letting
# the pipeline sample its own starting noise.
pipe_output = pipe(
    "A cute puppy on an ocean",
    # "A photo geisha isolated on a solid green background",

    # Use the per-model value from the config cell (0 for the active checkpoint) so this
    # cell stays consistent with the rolling-latents loop, which passes the same variable.
    guidance_scale=guidance_scale,

    latents=pure_noise_latents,
)

image = pipe_output.images[0]
rp.display_image(image)

In [None]:
def latent_as_image(latent, contrast=1/3, scale=8):
    """
    Turn one latent into an array suitable for rp.display_image.

    Steps: convert with rp.as_numpy_image, multiply by `contrast`, add 1/2
    (so a latent value of 0 ends up at 0.5), then resize with
    rp.cv_resize_image using nearest-neighbour interpolation.

    latent:   a single latent (no batch axis), e.g. pure_noise_latents[0]
    contrast: multiplier applied to the latent values before the 1/2 offset
    scale:    forwarded to rp.cv_resize_image as its size argument
              (presumably a scale factor -- confirm against rp's docs)
    """
    preview = rp.as_numpy_image(latent) * contrast
    preview += 1/2
    return rp.cv_resize_image(preview, scale, interp='nearest')

# Preview the first (and, with batch_size = 1, only) latent of the noise batch.
rp.display_image(latent_as_image(pure_noise_latents[0]))

In [None]:
def roll_torch_images(torch_images, dx=0, dy=0):
    """
    Circularly shift a batch of images given in BCHW form.

    dx shifts along the last axis and dy along the second-to-last axis.
    Values pushed past one edge wrap around to the opposite edge, so the
    result has the same shape and contents as the input, just rotated.
    """
    shifts = (dx, dy)
    axes = (-1, -2)  # last axis pairs with dx, second-to-last with dy
    return torch_images.roll(shifts=shifts, dims=axes)

# Preview frames of the noise latent rolled by 0..19 along the last axis.
# (An rp.with_alpha_checkerboard(frame, tile_size=128) backdrop was tried here and left disabled.)
image_slideshow = [
    latent_as_image(roll_torch_images(pure_noise_latents, dx)[0])
    for dx in range(20)
]

rp.display_image_slideshow(image_slideshow)

In [None]:
# For each roll offset of the same starting noise, run the pipeline and record a frame
# showing the rolled latent preview next to the generated image.
image_slideshow = []

shifts = range(64)
display_eta = rp.eta(len(shifts))
for i, dx in enumerate(shifts):
    rolled_latents = roll_torch_images(pure_noise_latents, dx)

    # NOTE(review): the prompt reads "on an the beach" -- looks like a typo; kept verbatim
    # so the generated images do not change.
    pipe_output = pipe(
        "A cute puppy on an the beach",
        latents=rolled_latents,
        guidance_scale=guidance_scale,
    )
    pipe_image = pipe_output.images[0]

    # Left: the latent preview. Right: the pipeline's image for that latent.
    frame = rp.horizontally_concatenated_images(
        latent_as_image(rolled_latents[0]),
        pipe_image,
    )
    image_slideshow.append(frame)

    # Swap the previous iteration's output for the progress readout and the newest frame.
    clear_output()
    display_eta(i)
    rp.display_image(frame)

# Wipe the live preview, write all frames to a uniquely named mp4, then show the slideshow.
clear_output()
video_target = rp.get_unique_copy_path('rolling_latents.mp4')
saved_video_path = rp.save_video_mp4(image_slideshow, video_target, framerate=15)
rp.fansi_print("Saved video: " + saved_video_path, 'green')
rp.display_image_slideshow(image_slideshow)