In [None]:
import rp
import torch
from IPython.display import clear_output
from diffusers import DiffusionPipeline, DDIMScheduler
from icecream import ic
import numpy as np

In [None]:
# The text prompt shared by every generation cell below.
prompt = (
    "A cute puppy at the beach"
)
# Longer alternative prompt:
# prompt = "a delicate apple made of opal hung on branch in the early morning light, adorned with glistening dewdrops. in the background beautiful valleys, divine iridescent glowing, opalescent textures, volumetric light, ethereal, sparkling, light inside body, bioluminescence, studio photo, highly detailed, sharp focus, photorealism, photorealism, 8k, best quality"

In [None]:
# Uncomment exactly ONE model line. Every option must define all three of
#   model_ckpt, scheduler_kwargs, guidance_scale
# because the cells below read each of them.
# model_ckpt = "stabilityai/stable-diffusion-2-base"            ; scheduler_kwargs={} ; guidance_scale=7.5
model_ckpt = "stabilityai/stable-diffusion-2-1-base"          ; scheduler_kwargs={} ; guidance_scale=7.5
# model_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"       ; scheduler_kwargs={} ; guidance_scale=7.5
# model_ckpt = "ByteDance/sd2.1-base-zsnr-laionaes5"            ; scheduler_kwargs = dict(timestep_spacing="trailing",rescale_betas_zero_snr=True); guidance_scale=7.5  #Zero-terminal SNR; needs rescale_betas_zero_snr=True and timestep_spacing="trailing" (guidance_scale: TODO tune)
# model_ckpt = "ByteDance/sd2.1-base-zsnr-laionaes6-perceptual" ; scheduler_kwargs = dict(timestep_spacing="trailing",rescale_betas_zero_snr=True); guidance_scale=0 #Zero-terminal SNR + it uses guidance_scale=0

In [None]:
# Build a DDIM scheduler from the checkpoint's own scheduler config; the
# model-specific scheduler_kwargs chosen above override individual settings.
scheduler = DDIMScheduler.from_pretrained(
    model_ckpt,
    subfolder="scheduler",
    # thresholding = True,
    **scheduler_kwargs,
)

# Load the whole pipeline in half precision with the DDIM scheduler swapped in
# and the safety checker disabled, then move it to the GPU.
pipeline_kwargs = dict(
    scheduler=scheduler,
    safety_checker=None,
    torch_dtype=torch.float16,
)
pipe = DiffusionPipeline.from_pretrained(model_ckpt, **pipeline_kwargs).to("cuda")

# Show the scheduler's configured prediction_type.
ic(pipe.scheduler.config.prediction_type)

In [None]:
# Draw 5 samples with an empty prompt and guidance_scale=0 (default: 7.5),
# then show them tiled together.
images = []
for _sample_index in range(5):
    images.extend(pipe("", guidance_scale=0).images)

clear_output()
rp.display_image(rp.tiled_images(images, length=10))

In [None]:
# Sample one batch of Gaussian noise in latent space. Every experiment below starts
# from this tensor (or a transformed copy of it), so the runs are comparable.
latent_height = pipe.unet.config.sample_size
latent_width  = pipe.unet.config.sample_size
height = latent_height * pipe.vae_scale_factor
width  = latent_width  * pipe.vae_scale_factor
latent_num_channels = pipe.vae.config.latent_channels
batch_size = 1

# Seed the noise so Restart & Run All reproduces the same latents every time.
noise_seed = 0
noise_generator = torch.Generator().manual_seed(noise_seed)
pure_noise_latents = torch.randn(
    batch_size, latent_num_channels, latent_height, latent_width,
    generator=noise_generator,
)
pure_noise_latents = pure_noise_latents.to(device=pipe.device, dtype=pipe.dtype)

ic(height, pipe.vae_scale_factor, latent_height)
ic(pure_noise_latents.shape, pure_noise_latents.dtype, pure_noise_latents.device)

In [None]:
# Denoise the fixed noise latents with the prompt (guidance_scale=0; default: 7.5)
# and display the single resulting image.
pipe_output = pipe(
    prompt,
    latents=pure_noise_latents,
    guidance_scale=0,
)
image = pipe_output.images[0]
rp.display_image(image)

In [None]:
def latent_as_image(latent, contrast=1/3, scale=8):
    """
    Render a single latent (no batch dimension) as a viewable image.

    The latent is converted with rp.as_numpy_image, multiplied by `contrast`,
    shifted by +1/2 so a zero value maps to 0.5, then upscaled by `scale` with
    nearest-neighbour interpolation so each latent cell stays a crisp block.

    Args:
        latent: one latent tensor, e.g. pure_noise_latents[0]
            (presumably CHW with the VAE's latent channels - TODO confirm how
            rp.as_numpy_image maps the channels to colors).
        contrast: multiplier applied before centering at 1/2.
        scale: upscale factor passed to rp.cv_resize_image.

    Returns:
        The upscaled numpy image.
    """
    image = rp.as_numpy_image(latent)
    image = image * contrast + 1 / 2
    return rp.cv_resize_image(image, scale, interp='nearest')

rp.display_image(latent_as_image(pure_noise_latents[0]))

In [None]:
def roll_torch_images(torch_images, dx=0, dy=0):
    """
    Circularly shift a batch of BCHW images (or latents).

    dx shifts along the last axis (width) and dy along the second-to-last axis
    (height). Content pushed off one edge wraps around to the opposite edge.
    """
    return torch.roll(torch_images, shifts=(dx, dy), dims=(-1, -2))

# Preview: circularly shift the noise latents by 0..19 cells along the width axis
# and show each shifted latent as one slideshow frame.
image_slideshow = [
    latent_as_image(roll_torch_images(pure_noise_latents, dx)[0])
    # optional: rp.with_alpha_checkerboard(frame, tile_size=128)
    for dx in range(20)
]

rp.display_image_slideshow(image_slideshow)

In [None]:
# Shift the noise latents by whole latent cells (dx = 0..63) and denoise each
# shifted copy, to see how the generated image responds to a circular shift.
image_slideshow = []

shifts = range(64)
display_eta = rp.eta(len(shifts))
for frame_index, dx in enumerate(shifts):
    rolled_latents = roll_torch_images(pure_noise_latents, dx)
    pipe_output = pipe(
        prompt,
        guidance_scale=guidance_scale,
        latents=rolled_latents,
    )
    pipe_image = pipe_output.images[0]

    # Left: the shifted latent; right: the image generated from it.
    frame = rp.horizontally_concatenated_images(latent_as_image(rolled_latents[0]), pipe_image)
    image_slideshow.append(frame)

    # Live preview: replace the previous output with progress and the newest frame.
    clear_output()
    display_eta(frame_index)
    rp.display_image(frame)

clear_output()
output_path = rp.get_unique_copy_path('rolling_latents.mp4')
saved_video_path = rp.save_video_mp4(image_slideshow, output_path, framerate=15)
rp.fansi_print("Saved video: " + saved_video_path, 'green')
rp.display_image_slideshow(image_slideshow)

In [None]:
def roll_torch_latents_via_reencoding(latents, dx=0, dy=0):
    """
    Shift latents by (dx, dy) *pixels* by round-tripping them through the VAE.

    The latents are decoded to pixel space with the global `pipe.vae`, circularly
    shifted with roll_torch_images, and re-encoded. The re-encode takes the mean of
    the latent distribution rather than a sample, so the round trip adds no extra
    noise of its own. Because dx/dy are in pixels, shifts smaller than one latent
    cell (pipe.vae_scale_factor pixels) are possible.

    NOTE: The VAE expects latents that look like encoded images, which are NOT
    unit-normal. Pure Gaussian noise is out of distribution for the encoder and
    decoder, so the round trip may be lossy. Nevertheless, let's see how it does!

    Returns:
        The re-encoded latents, rescaled to the same convention as the input.
    """
    # In diffusers, the latents a pipeline denoises are VAE encodings multiplied by
    # vae.config.scaling_factor; the pipeline divides by it again before decoding.
    # Undo that scaling before the VAE round trip and redo it afterwards.
    scaling_factor = pipe.vae.config.scaling_factor

    with torch.no_grad():
        latents = latents.to(pipe.dtype).to(pipe.device)

        images = pipe.vae.decode(latents / scaling_factor).sample
        images = roll_torch_images(images, dx, dy)
        latents = pipe.vae.encode(images).latent_dist.mean

        return latents * scaling_factor

In [None]:
#Recursion test for redecoded_roll_torch_latents
image_slideshow=[]
recursed_latents = pure_noise_latents

recursion_iterations = 20
display_eta = rp.eta(recursion_iterations)
for i in range(recursion_iterations):    
    pipe_output = pipe(
        prompt,
        guidance_scale=guidance_scale,
        latents = recursed_latents,
    )
    pipe_image = pipe_output.images[0]

    frame = rp.horizontally_concatenated_images(
        latent_as_image(recursed_latents[0]),
        pipe_image,
    )
    frame = rp.labeled_image(frame, 'Recursion Iterations: '+str(i))
    image_slideshow.append(frame)

    clear_output()
    display_eta(i)
    rp.display_image(frame)

    recursed_latents = roll_torch_latents_via_reencoding(recursed_latents)


saved_video_path = rp.save_video_mp4(
    image_slideshow,
    rp.get_unique_copy_path('recused_latents.mp4'),
    framerate=15,
)
clear_output()
rp.fansi_print("Saved video: "+saved_video_path, 'green')
rp.display_image_slideshow(image_slideshow)

In [None]:
# Shift by whole *pixels* (0..29) via a VAE decode -> roll -> encode round trip,
# then denoise, to compare against rolling the latents directly.
image_slideshow = []
shifts = range(30)  # In pixels
display_eta = rp.eta(len(shifts))

for frame_index, dx in enumerate(shifts):
    shifted_latents = roll_torch_latents_via_reencoding(pure_noise_latents, dx)
    pipe_output = pipe(
        prompt,
        guidance_scale=guidance_scale,
        latents=shifted_latents,
    )
    pipe_image = pipe_output.images[0]

    side_by_side = rp.horizontally_concatenated_images(
        latent_as_image(shifted_latents[0]),
        pipe_image,
    )
    frame = rp.labeled_image(side_by_side, 'Shift: dx=' + str(dx) + " pixels")
    image_slideshow.append(frame)

    clear_output()
    display_eta(frame_index)
    rp.display_image(frame)

output_path = rp.get_unique_copy_path('reencode_shifted_latents.mp4')
saved_video_path = rp.save_video_mp4(image_slideshow, output_path, framerate=15)
clear_output()
rp.fansi_print("Saved video: " + saved_video_path, 'green')
rp.display_image_slideshow(image_slideshow)

In [None]:
# Zero one latent row at a time (across all channels) and denoise, to see how a
# single missing row of noise affects the generated image.
image_slideshow=[]

# One frame per latent row. latent_height comes from the UNet config rather than
# a hard-coded 64, so this also covers models with a different latent size.
row_nums = range(latent_height)
display_eta = rp.eta(len(row_nums))
for i,row_num in enumerate(row_nums):
    latents = pure_noise_latents.clone()

    latents[0,:,row_num,:]=0

    pipe_output = pipe(
        prompt,
        guidance_scale=guidance_scale,
        latents = latents,
    )
    pipe_image = pipe_output.images[0]

    frame = rp.horizontally_concatenated_images(
        latent_as_image(latents[0]),
        pipe_image,
    )
    frame = rp.labeled_image(frame, 'Set latent row #'+str(row_num)+" to 0")
    image_slideshow.append(frame)

    clear_output()
    display_eta(i)
    rp.display_image(frame)


saved_video_path = rp.save_video_mp4(
    image_slideshow,
    rp.get_unique_copy_path('row_deletion.mp4'),
    framerate=15,
)
clear_output()
rp.fansi_print("Saved video: "+saved_video_path, 'green')
rp.display_image_slideshow(
    rp.resize_images(image_slideshow, 1/4)
)

In [None]:
# Zero one latent column at a time (across all channels) and denoise, to see how
# a single missing column of noise affects the generated image.
image_slideshow=[]

# One frame per latent column. latent_width comes from the UNet config rather
# than a hard-coded 64, so this also covers models with a different latent size.
col_nums = range(latent_width)
display_eta = rp.eta(len(col_nums))
for i,col_num in enumerate(col_nums):
    latents = pure_noise_latents.clone()

    latents[0,:,:,col_num]=0

    pipe_output = pipe(
        prompt,
        guidance_scale=guidance_scale,
        latents = latents,
    )
    pipe_image = pipe_output.images[0]

    frame = rp.horizontally_concatenated_images(
        latent_as_image(latents[0]),
        pipe_image,
    )
    frame = rp.labeled_image(frame, 'Set latent col #'+str(col_num)+" to 0")
    image_slideshow.append(frame)

    clear_output()
    display_eta(i)
    rp.display_image(frame)


saved_video_path = rp.save_video_mp4(
    image_slideshow,
    rp.get_unique_copy_path('col_deletion.mp4'),
    framerate=15,
)
clear_output()
rp.fansi_print("Saved video: "+saved_video_path, 'green')
rp.display_image_slideshow(
    rp.resize_images(image_slideshow, size=1/4)
)

In [None]:
# Zero every latent column with index < upto_col_num (a growing left-hand band)
# and denoise, to see how the image adapts as more of the noise is removed.
image_slideshow=[]

# latent_width comes from the UNet config rather than a hard-coded 64, so this
# also covers models with a different latent size.
upto_col_nums = range(latent_width)
display_eta = rp.eta(len(upto_col_nums))
for i,upto_col_num in enumerate(upto_col_nums):
    latents = pure_noise_latents.clone()

    # Columns [0, upto_col_num) -> 0; upto_col_num == 0 leaves the latents untouched.
    latents[0,:,:,:upto_col_num]=0

    pipe_output = pipe(
        prompt,
        guidance_scale=guidance_scale,
        latents = latents,
    )
    pipe_image = pipe_output.images[0]

    frame = rp.horizontally_concatenated_images(
        latent_as_image(latents[0]),
        pipe_image,
    )
    frame = rp.labeled_image(frame, 'Set latent upto_col #'+str(upto_col_num)+" to 0")
    image_slideshow.append(frame)

    clear_output()
    display_eta(i)
    rp.display_image(frame)


saved_video_path = rp.save_video_mp4(
    image_slideshow,
    rp.get_unique_copy_path('upto_col_deletion.mp4'),
    framerate=15,
)
clear_output()
rp.fansi_print("Saved video: "+saved_video_path, 'green')
rp.display_image_slideshow(
    rp.resize_images(image_slideshow, size=1/4)
)

In [None]:
# Sub-cell shifts by linear interpolation: for a pixel shift dx, blend the latents
# rolled by floor(dx/s) and ceil(dx/s) latent cells, where s is the VAE scale factor.
# For independent noise, plain linear blending lowers the per-element variance
# ((1-alpha)^2 + alpha^2 < 1 for 0 < alpha < 1).
image_slideshow=[]

pixel_shifts = range(64)
display_eta = rp.eta(len(pixel_shifts))
for i, dx in enumerate(pixel_shifts):

    # Convert the pixel shift into (fractional) latent cells using the pipeline's
    # own VAE scale factor rather than a hard-coded 8.
    latent_shift = dx / pipe.vae_scale_factor
    alpha = latent_shift % 1
    rolled_latents_ceil  = roll_torch_images(pure_noise_latents, int(np.ceil (latent_shift)))
    rolled_latents_floor = roll_torch_images(pure_noise_latents, int(np.floor(latent_shift)))
    # NOTE(review): assumes rp.blend(a, b, alpha) == (1-alpha)*a + alpha*b - confirm.
    rolled_latents = rp.blend(rolled_latents_floor, rolled_latents_ceil, alpha)

    pipe_output = pipe(
        prompt,
        guidance_scale=guidance_scale,
        latents = rolled_latents,
    )
    pipe_image = pipe_output.images[0]

    frame = rp.horizontally_concatenated_images(
        latent_as_image(rolled_latents[0]),
        pipe_image,
    )
    frame = rp.labeled_image(frame, "Blended latent shift, dx="+str(dx)+" (pixel space)")

    image_slideshow.append(frame)

    clear_output()
    display_eta(i)
    rp.display_image(frame)

clear_output()
saved_video_path = rp.save_video_mp4(
    image_slideshow,
    rp.get_unique_copy_path('blended_rolling_latents.mp4'),
    framerate=15,
)
rp.fansi_print("Saved video: "+saved_video_path, 'green')
rp.display_image_slideshow(rp.resize_images(image_slideshow, size=1/4))

In [None]:
# Variance-preserving version of the blended shift: weight the floor/ceil rolls by
# sqrt(1-alpha) and sqrt(alpha) instead of (1-alpha) and alpha.
image_slideshow=[]

pixel_shifts = range(64)
display_eta = rp.eta(len(pixel_shifts))
for i, dx in enumerate(pixel_shifts):

    # Convert the pixel shift into (fractional) latent cells using the pipeline's
    # own VAE scale factor rather than a hard-coded 8.
    latent_shift = dx / pipe.vae_scale_factor
    alpha = latent_shift % 1
    rolled_latents_floor = roll_torch_images(pure_noise_latents, int(np.floor(latent_shift)))
    rolled_latents_ceil  = roll_torch_images(pure_noise_latents, int(np.ceil (latent_shift)))

    # Unlike simple alpha blending, this preserves the variance. For 0 < alpha < 1 the
    # two rolls differ by one cell, so each position mixes two different i.i.d. noise
    # samples: Var = (1-alpha) + alpha = 1. When alpha == 0 both rolls are identical
    # and the ceil term has weight 0.
    rolled_latents = ((1 - alpha) ** .5) * rolled_latents_floor + (alpha ** .5) * rolled_latents_ceil

    pipe_output = pipe(
        prompt,
        guidance_scale=guidance_scale,
        latents = rolled_latents,
    )
    pipe_image = pipe_output.images[0]

    frame = rp.horizontally_concatenated_images(
        latent_as_image(rolled_latents[0]),
        pipe_image,
    )
    frame = rp.labeled_image(frame, "Blended latent shift, dx="+str(dx)+" (pixel space)")

    image_slideshow.append(frame)

    clear_output()
    display_eta(i)
    # Print the stats AFTER clear_output(); before, they were wiped immediately.
    ic(dx, rolled_latents.mean(), rolled_latents.std())
    rp.display_image(frame)

clear_output()
saved_video_path = rp.save_video_mp4(
    image_slideshow,
    rp.get_unique_copy_path('variance_preserving_blended_rolling_latents.mp4'),
    framerate=15,
)
rp.fansi_print("Saved video: "+saved_video_path, 'green')
rp.display_image_slideshow(rp.resize_images(image_slideshow, size=1/4))

In [None]:
# Zero one patch of the noise latents (latent rows/cols 16:32, all channels) and
# circularly shift the result cell by cell, to see how the zeroed region shows up
# in the generated image as it moves.
image_slideshow=[]

# The patched latents do not depend on the shift, so build them once.
tiled_latents = pure_noise_latents.clone()
# Disabled tiling experiment: copy the top-left 32x32 latent quadrant into the others.
# tiled_latents[:,:,32:,:32]=tiled_latents[:,:,:32,:32]
# tiled_latents[:,:,32:,32:]=tiled_latents[:,:,:32,:32]
# tiled_latents[:,:,:32,32:]=tiled_latents[:,:,:32,:32]
tiled_latents[:,:,16:32,16:32] = 0

shifts = range(64)
display_eta = rp.eta(len(shifts))
for i, dx in enumerate(shifts):
    rolled_latents = roll_torch_images(tiled_latents, dx)

    pipe_output = pipe(
        prompt,
        guidance_scale=guidance_scale,
        latents = rolled_latents,
    )
    pipe_image = pipe_output.images[0]

    frame = rp.horizontally_concatenated_images(
        latent_as_image(rolled_latents[0]),
        pipe_image,
    )

    image_slideshow.append(frame)

    clear_output()
    display_eta(i)
    rp.display_image(frame)

clear_output()
saved_video_path = rp.save_video_mp4(
    image_slideshow,
    rp.get_unique_copy_path('tiled_rolling_latents.mp4'),
    framerate=15,
)
rp.fansi_print("Saved video: "+saved_video_path, 'green')
rp.display_image_slideshow(image_slideshow)