In [None]:
%reload_ext autoreload
%autoreload 2

import gc
import importlib
import math
from pathlib import Path

import torch
import torch.nn as nn
from torch.nn import functional as func
import numpy as np
import transformers
from diffusers import UNet2DConditionModel
from PIL import Image
from IPython.display import display, clear_output

from autumn.notebook import *
from autumn.math import *
from autumn.images import *
from autumn.guidance import *
from autumn.scheduling import *
from autumn.solvers import *
from autumn.prompting import PromptEncoder
from autumn.vae import SDXL_VAE

In [None]:
%%settings

# model can be a local path or a huggingface model name
model = "stabilityai/stable-diffusion-xl-base-0.9"
XL_MODEL = True # this notebook in its current state is only really guaranteed to work with SDXL
vae_scale = 0.13025 # found in the model config

run_ids = range(5)

# run context will contain only context.run_id, anything returned in this dictionary will be added to it
add_run_context = lambda context: {}

p0 = "photograph of a very cute dog"
#np0 = "blurry, ugly, indistinct, jpeg artifacts, watermark, text, signature"

# Generally: the encoder_2 prompts have the strongest effect, but terms from the encoder_1 prompts will show up in results.
# The encoder_2_pooled prompts have an influence too, but a much more indirect one.
# The model was trained with all encodings coming from the same prompt.

prompts = lambda context: {
    # Prompts for encoder 1; the UNet batch size is the number of prompts in this list
    "encoder_1": [p0],
    # Prompts for encoder 2; defaults to the encoder_1 prompts if None
    "encoder_2": None,
    # Prompts for pooled encoding of encoder 2; defaults to the encoder_2 prompts if None
    "encoder_2_pooled": None
}

# Method by which predictions for different prompts will be recombined to make one noise prediction
combine_predictions = lambda context: true_noise_removal(context, [1])

seed = lambda context: 42069 + context.run_id

# values below 64 get multiplied by 64 (so (16, 16) means 1024x1024 pixels)
width_height = lambda context: (16, 16)

steps = lambda context: 15

# for scaling by a sqrt or **2 curve, &c
timestep_power = lambda c: 1
timestep_max = lambda c: 999
timestep_min = lambda c: 0

# differential equation solver. see autumn/solvers.py
solver_step = lambda c: rk4_step

# Sampling strategy applied by the run cell at every step (a plain string, not a lambda):
#   "foo"   - solver_step + combine_predictions (what the defaults above are set up for)
#   "cfg"   - classifier-free guidance; also needs `cfg_scale = lambda c: <float>` and at least
#             two encoder_1 prompts (predictions[0] is the base/unconditional term)
#   "cfg++" - CFG++; also needs `lda = lambda c: <float>` and at least two encoder_1 prompts
method = "foo"

# step context will contain run_id & step_index, anything returned in this dictionary will be added to it
add_step_context = lambda context: {}

embedding_distortion = lambda context: None

save_output = lambda context: True
save_approximates = lambda context: False
save_raw = lambda context: False

#!# Settings above this line will be replaced with the contents of settings.py if it exists. #!#

main_dtype = torch.float64
unet_dtype = torch.float16
vae_dtype = torch.float32
encoder_dtype = torch.float16

torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("medium")

In [None]:
# # # models # # #
# Loads the VAE, UNet and prompt encoder(s) from `model` and moves the VAE/UNet to CUDA.
# Later cells read the globals `vae`, `unet` and `p_encoder` defined here.

model_source = model

with Timer("total"):
    with Timer("vae"):
        vae = SDXL_VAE()
        vae.load_safetensors(model_source)
        vae.to(device="cuda")
    
        # (as of torch 2.2.2) VAE compilation yields negligible gains, at least for decoding once/step
        #vae = torch.compile(vae, mode="default", fullgraph=True)
    
    with Timer("unet"):
        # loaded directly in unet_dtype (set in the settings cell)
        unet = UNet2DConditionModel.from_pretrained(
            model_source, subfolder="unet", torch_dtype=unet_dtype
        )
        unet.to(device="cuda")
    
        # compilation will not actually happen until the compiled wrapper `unet_c` is first called
        # (as of torch 2.2.2) "default" provides the best result on my machine
        # don't use this if you're gonna be changing resolutions a lot
        # NOTE(review): the run cell calls `unet`, not `unet_c`, so this wrapper is never invoked in
        # the visible cells and no compilation happens. Use `unet_c` there to benefit, or drop this line.
        unet_c = torch.compile(unet, mode="default", fullgraph=True)
    
    with Timer("clip"):
        # NOTE(review): XL_MODEL is passed through, presumably to select SDXL's dual text encoders — TODO confirm
        p_encoder = PromptEncoder(model_source, XL_MODEL, torch_dtype=encoder_dtype)


In [None]:
# # # run # # #

variance_range = (0.00085, 0.012) # should come from model config!
forward_noise_schedule = default_variance_schedule(variance_range).to(main_dtype) # beta
forward_signal_product = torch.cumprod((1 - forward_noise_schedule), dim=0) # alpha_bar
# alpha_bar_t / alpha_bar_s, computed more directly from the forward noise as
# prod_{i=s+1}^{t} (1 - beta_i). The slice must end at t+1 so that beta_t itself is included;
# ending at t dropped that factor and made the ratio disagree with forward_signal_product
# (and hence part_noise / part_signal) at index t.
partial_signal_product = lambda s, t: torch.prod((1 - forward_noise_schedule)[s+1:t+1])
part_noise = (1 - forward_signal_product).sqrt() # sigma: coefficient of the noise at each timestep
part_signal = forward_signal_product.sqrt() # sqrt(alpha_bar): coefficient of the clean signal at each timestep

def step_by_noise(latents, noise, from_timestep, to_timestep):
    """Move `latents` from `from_timestep` to `to_timestep` along the fixed direction `noise`.

    The signal part is rescaled using `partial_signal_product` (intended to be
    alpha_bar_t / alpha_bar_s). The `noise` coefficient is adjusted so it changes from
    part_noise[from_timestep] to part_noise[to_timestep]. The step can go towards
    either a larger timestep (forward) or a smaller one (backward).
    """
    if from_timestep >= to_timestep:
        # backward: ratio = sqrt(product from to_timestep to from_timestep)
        ratio = partial_signal_product(to_timestep, from_timestep).sqrt()
    else:
        # forward: ratio = 1 / sqrt(product from from_timestep to to_timestep)
        ratio = 1 / partial_signal_product(from_timestep, to_timestep).sqrt()
    rescaled_signal = latents / ratio
    noise_coefficient = part_noise[to_timestep] - part_noise[from_timestep] / ratio
    return rescaled_signal + noise * noise_coefficient

def shuffle_step(latents, first_noise, second_noise, timestep, intermediate_timestep):
    """Replace the `first_noise` component of `latents` at `timestep` with `second_noise`.

    Closed form of two `step_by_noise` calls: step from `timestep` to
    `intermediate_timestep` using `first_noise`, then step back to `timestep` using
    `second_noise`. The signal rescalings cancel, leaving only the noise swap, scaled by
    (part_noise[intermediate] * signal_ratio - part_noise[timestep]).
    """
    # The direction must be decided from this function's own timesteps
    # (`from_timestep` / `to_timestep` do not exist in this scope).
    if timestep < intermediate_timestep: # forward
        signal_ratio = 1 / partial_signal_product(timestep, intermediate_timestep).sqrt()
    else: # backward
        signal_ratio = partial_signal_product(intermediate_timestep, timestep).sqrt()
    return latents + (first_noise - second_noise) * (part_noise[intermediate_timestep] * signal_ratio - part_noise[timestep])

# Sampling strategies understood by the per-step dispatch in the loop below.
_SAMPLING_METHODS = ("foo", "cfg++", "cfg")

# Fail fast, before any prompt encoding or UNet work, if the `method` setting is unknown.
# An unrecognised value used to silently skip every denoising step.
if method not in _SAMPLING_METHODS:
    raise ValueError(f"unknown method {method!r}; expected one of {_SAMPLING_METHODS}")


def _merge_into_context(context, additions):
    """Set every key/value pair of `additions` as an attribute on `context`.

    Implements the settings contract that anything returned by `add_run_context` /
    `add_step_context` is added to the context; a `None` return adds nothing.
    """
    for key, value in (additions or {}).items():
        setattr(context, key, value)


# One sampling run per run_id: seed, encode the prompts, build the starting noise, denoise for
# run_context.steps steps, then decode / display / save the final latents.
for run_id in run_ids:
    run_context = Context()
    run_context.run_id = run_id
    _merge_into_context(run_context, add_run_context(run_context))
    
    # Only the int() conversion is best-effort; exceptions raised by the seed setting itself propagate.
    raw_seed = seed(run_context)
    try:
        _seed = int(raw_seed)
    except (TypeError, ValueError, OverflowError):
        _seed = 0
        print(f"non-integer seed, run {run_id}. replaced with 0.")
    
    torch.manual_seed(_seed)
    np.random.seed(_seed)

    run_context.steps = steps(run_context)

    # steps+1 boundary timesteps; presumably ordered from timestep_max to timestep_min — see linspace_timesteps
    diffusion_timesteps = linspace_timesteps(run_context.steps+1, timestep_max(run_context), timestep_min(run_context), timestep_power(run_context))

    run_prompts = prompts(run_context)
    
    # one UNet batch entry per encoder_1 prompt
    unet_batch_size = len(run_prompts["encoder_1"])
    
    (all_penult_states, enc2_pooled) = p_encoder.encode(run_prompts["encoder_1"], run_prompts["encoder_2"], run_prompts["encoder_2_pooled"])

    for index in range(all_penult_states.shape[0]):
        run_context.embedding_index = index
        distortion = embedding_distortion(run_context)
        if distortion is not None:
            all_penult_states[index] = svd_distort_embeddings(all_penult_states[index].to(main_dtype), distortion).to(unet_dtype)

    width, height = width_height(run_context)

    # values below 64 are interpreted as multiples of 64 pixels
    if (width < 64): width *= 64
    if (height < 64): height *= 64
    
    with torch.no_grad():
        vae_dim_scale = 2 ** 3
    
        latents = torch.zeros(
            (1, unet.config.in_channels, height // vae_dim_scale, width // vae_dim_scale),
            device="cuda",
            dtype=main_dtype
        )

        modify_initial_latents(run_context, latents)
        
        # a single noise tensor, reused for every step (noises[0])
        noises = torch.randn(
            #(run_context.steps, 1, unet.config.in_channels, height // vae_dim_scale, width // vae_dim_scale),
            (1, 1, unet.config.in_channels, height // vae_dim_scale, width // vae_dim_scale),
            device="cuda",
            dtype=main_dtype
        )

        # noise the initial latents from the last schedule entry up to the first one
        latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[0])
        
        original_size = (height, width)
        target_size = (height, width)
        crop_coords_top_left = (0, 0)

        # incomprehensible var name tbh go read the sdxl paper if u want to Understand
        add_time_ids = torch.tensor([list(original_size + crop_coords_top_left + target_size)], dtype=unet_dtype).repeat(unet_batch_size,1).to("cuda")
    
        added_cond_kwargs = {"text_embeds": enc2_pooled.to(unet_dtype), "time_ids": add_time_ids}

        # fallback so the decode below has a value even when run_context.steps == 0
        pred_original_sample = latents
        
        out_index = 0
        with Timer("core loop"):
            for step_index in range(run_context.steps):
                step_context = Context(run_context)
                step_context.step_index = step_index
                _merge_into_context(step_context, add_step_context(step_context))

                #lerp_term = (part_signal[diffusion_timesteps[step_index]] + part_signal[diffusion_timesteps[step_index+1]]) / 2
                #step_context.sqrt_signal = part_signal[diffusion_timesteps[step_index+1]] ** 0.5
                #step_context.pnoise = (1-part_noise[diffusion_timesteps[step_index+1]]) ** 0.5
                #step_context.lerp_by_noise = lambda a, b: lerp(a, b, part_signal[diffusion_timesteps[step_index+1]] ** 0.5)

                noise = noises[0]


                start_timestep = index_interpolate(diffusion_timesteps, step_index).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + 1).round().int()

                
                step_context.sqrt_signal = part_signal[end_timestep] ** 0.5
                step_context.pnoise = (1-part_noise[end_timestep]) ** 0.5
                # bind this step's end_timestep now; a plain closure would see later reassignments of the global
                step_context.lerp_by_noise = lambda a, b, _t=end_timestep: lerp(a, b, part_signal[_t] ** 0.5)
                
                #latents = step_by_noise(latents, noise, diffusion_timesteps[-1], diffusion_timesteps[step_index])
                #latents = step_by_noise(latents, noise, diffusion_timesteps[-1], start_timestep)
                
                def predict_noise(latents, step=0):
                    """Run the UNet at the (fractional) schedule position step_index + step and combine the per-prompt predictions."""
                    predictions = unet(
                        latents.repeat(unet_batch_size, 1, 1, 1).to(unet_dtype),
                        index_interpolate(diffusion_timesteps, step_index + step).round().int(), 
                        encoder_hidden_states=all_penult_states.to(unet_dtype),
                        return_dict=False, 
                        added_cond_kwargs=added_cond_kwargs
                    )[0]
    
                    return combine_predictions(step_context)(predictions.to(main_dtype), noise)
                
                def standard_diffusion_step(latents, noise, start, end):
                    """step_by_noise between schedule positions step_index + start and step_index + end."""
                    start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                    end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                    return step_by_noise(latents, noise, start_timestep, end_timestep)

                def cfgpp_diffusion_step(latents, start, end, lda):
                    """CFG++ step: x0 from the guided eps, re-noised with the base (predictions[0]) eps."""
                    start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                    end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()

                    predictions = unet(
                        latents.repeat(unet_batch_size, 1, 1, 1).to(unet_dtype),
                        start_timestep, 
                        encoder_hidden_states=all_penult_states.to(unet_dtype),
                        return_dict=False, 
                        added_cond_kwargs=added_cond_kwargs
                    )[0]

                    eps = predictions[0] + lda * (predictions[1] - predictions[0])
                    
                    x_c = (latents - part_noise[start_timestep] * eps) / part_signal[start_timestep]
                    return part_signal[end_timestep] * x_c + part_noise[end_timestep] * predictions[0]

                def cfg_diffusion_step(latents, start, end, cfg_scale):
                    """Classic CFG step: x0 and the re-noising both use the guided eps."""
                    start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                    end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()

                    predictions = unet(
                        latents.repeat(unet_batch_size, 1, 1, 1).to(unet_dtype),
                        start_timestep, 
                        encoder_hidden_states=all_penult_states.to(unet_dtype),
                        return_dict=False, 
                        added_cond_kwargs=added_cond_kwargs
                    )[0]

                    eps = predictions[0] + cfg_scale * (predictions[1] - predictions[0])
                    
                    x_c = (latents - part_noise[start_timestep] * eps) / part_signal[start_timestep]
                    return part_signal[end_timestep] * x_c + part_noise[end_timestep] * eps

                def foo_step(latents, start, end):
                    """Solver-driven step: estimate the noise with solver_step, then step from `start` to `end` with it."""
                    noise_prediction = solver_step(step_context)(predict_noise, standard_diffusion_step, latents)
                    return standard_diffusion_step(latents, noise_prediction, start, end)
                
                # `method` was validated against _SAMPLING_METHODS before the run loop
                if method == "foo":
                    latents = foo_step(latents, 0, 1)
                elif method == "cfg++":
                    latents = cfgpp_diffusion_step(latents, 0, 1, lda(step_context))
                elif method == "cfg":
                    latents = cfg_diffusion_step(latents, 0, 1, cfg_scale(step_context))
                
                # preview: project the current latents to the final timestep with the initial noise
                if step_index < run_context.steps - 1:
                    pred_original_sample = step_by_noise(latents, noise, diffusion_timesteps[step_index+1], diffusion_timesteps[-1])
                    #pred_original_sample = step_by_noise(latents, noise, end_timestep, diffusion_timesteps[-1])
                else:
                    pred_original_sample = latents
                
                #latents = step_by_noise(pred_original_sample, noises[0], diffusion_timesteps[-1], diffusion_timesteps[step_index])
                #latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[step_index])

                #latents = pred_original_sample
                
                if save_raw(step_context):
                    save_raw_latents(pred_original_sample)
                if save_approximates(step_context):
                    save_approx_decode(pred_original_sample, out_index)
                    out_index += 1

                #if step_index > run_context.steps - 4:

            images_pil = pilify(pred_original_sample, vae)
    
            for im in images_pil:
                display(im)

            if save_output(run_context):
                for n in range(len(images_pil)):
                    images_pil[n].save(f"{settings_directory}/{run_id}_final{n}.png")


In [None]:
# # # save # # #
# Copies the images of the most recent run (the `images_pil` / `run_id` left over from the run
# cell's last iteration) into `{daily_directory}/{settings_id}_{run_id}/`.
# NOTE(review): earlier runs are not copied here; only the run cell's own `save_output` saves cover them.

run_directory = Path(daily_directory) / f"{settings_id}_{run_id}"
# parents=True also creates daily_directory itself if it is missing
run_directory.mkdir(exist_ok=True, parents=True)

for n, image in enumerate(images_pil):
    image.save(run_directory / f"{n}.png")