In [None]:
%reload_ext autoreload
%autoreload 2

import math
import gc
import importlib
import torch
import numpy as np
import transformers
from diffusers import AutoencoderKL, UNet2DConditionModel
from PIL import Image
from IPython.display import display, clear_output

from autumn.notebook import *
from autumn.math import *
from autumn.images import *
from autumn.guidance import *
from autumn.scheduling import *
from autumn.prompting import PromptEncoder

In [None]:
%%settings

# model can be a local path or a huggingface model name
model = "stabilityai/stable-diffusion-xl-base-0.9"
XL_MODEL = True # this notebook in its current state is only really guaranteed to work with SDXL
vae_scale = 0.13025 # found in the model config

run_ids = range(1)

# run context will contain only context.run_id, anything returned in this dictionary will be added to it
add_run_context = lambda context: {}

p0 = "photograph of a very cute dog"
np0 = "blurry, ugly, indistinct, jpeg artifacts, watermark, text, signature"

# Generally: e2_prompt has the strongest effect, but terms from e1_prompt will show up in results.
# e2_pool_prompts has an influence but very indirect idk
# the model is trained w/ all encodings coming from the same prompt

prompts = lambda context: {
    # Prompts for encoder 1
    "encoder_1": [np0, p0],
    # Prompts for encoder 2; defaults to e1_prompts if None
    "encoder_2": None,
    # Prompts for pooled encoding of encoder 2; defaults to e2_prompts if None
    "encoder_2_pooled": None
}

# Method by which predictions for different prompts will be recombined to make one noise prediction
combine_predictions = lambda context: scaled_CFG(
    difference_scales = [
        (1, 0, sigmoid(0.1 * vae_scale, 3 * vae_scale))
    ], 
    steering_scale = id_, base_scale = id_, total_scale = id_
)

seed = lambda context: 42069

# these get multiplied by 64
width_height = lambda context: (16, 16)

steps = lambda context: 25

# step context will contain run_id & step_index, anything returned in this dictionary will be added to it
add_step_context = lambda context: {}

#distort = sigmoid(0.319, 0.987, 0.0304, 0.025)
#distort = scale_f(distort, 127, 1.0)
embedding_distortion = lambda context: None

dynamic_thresholding = lambda context: False
dynthresh_percentile = lambda context: 0.995
dynthresh_target = lambda context: 7

naive_rescaling = lambda context: False

save_output = lambda context: True
save_approximates = lambda context: False
save_raw = lambda context: False

#!# Settings above this line will be replaced with the contents of settings.py if it exists. #!#

main_dtype = torch.float64
unet_dtype = torch.float16
vae_dtype = torch.float32
encoder_dtype = torch.float16

torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# # # models # # #
# Load the VAE, UNet and prompt encoders from `model`, timing each stage.

model_source = model

with Timer("total"):
    with Timer("vae"):
        # VAE weights are loaded in vae_dtype and moved to the GPU immediately.
        vae = AutoencoderKL.from_pretrained(
            model_source,
            subfolder="vae",
            torch_dtype=vae_dtype,
        ).to(device="cuda")

        # (as of torch 2.2.2) VAE compilation yields negligible gains, at least for decoding once/step
        #vae = torch.compile(vae, mode="default", fullgraph=True)

    with Timer("unet"):
        unet = UNet2DConditionModel.from_pretrained(
            model_source,
            subfolder="unet",
            torch_dtype=unet_dtype,
        ).to(device="cuda")

        # torch.compile is lazy: the actual compilation happens on the first call of unet_c.
        # (as of torch 2.2.2) "default" mode gave the best result on the author's machine;
        # avoid it if resolutions change often, since each new shape triggers a recompile.
        # NOTE(review): the run cell calls `unet`, not `unet_c`, so this compiled wrapper is currently unused.
        unet_c = torch.compile(unet, mode="default", fullgraph=True)

    with Timer("clip"):
        p_encoder = PromptEncoder(model_source, XL_MODEL, torch_dtype=encoder_dtype)


In [None]:
# # # run # # #

variance_range = (0.00085, 0.012) # should come from model config!
forward_noise_schedule = default_variance_schedule(variance_range).to(main_dtype) # beta
forward_signal_product = torch.cumprod((1 - forward_noise_schedule), dim=0) # alpha_bar


def partial_signal_product(s, t):
    """Return alpha_bar_t / alpha_bar_s, the product of (1 - beta_i) for i in s+1..t inclusive.

    It is computed directly as a product of the per-step signal factors, not as a quotient of
    cumulative products. Because forward_signal_product[t] includes step t itself, the slice
    must run through index t (``[s+1:t+1]``). The previous ``[s+1:t]`` slice dropped the
    factor for step t. When t <= s the slice is empty and the result is 1.
    """
    return torch.prod((1 - forward_noise_schedule)[s + 1:t + 1])


part_noise = (1 - forward_signal_product).sqrt() # sigma
part_signal = forward_signal_product.sqrt() # mu?

def step_by_noise(latents, noise, from_timestep, to_timestep):
    """Move ``latents`` from ``from_timestep`` to ``to_timestep`` along the direction ``noise``.

    The signal part is rescaled by the square root of ``partial_signal_product`` between the
    two timesteps. ``noise`` is weighted by
    ``part_noise[to] - part_noise[from] * (signal scale)``.

    - When ``from_timestep < to_timestep`` (forward), the signal is shrunk.
    - Otherwise (backward), it is grown.
    - When the timesteps are equal, the latents are returned unchanged (empty product, zero
      noise weight).

    With ``noise`` being a noise prediction, the backward case has the form of a deterministic
    DDIM-style update.
    """
    forward = from_timestep < to_timestep
    ratio = (
        1 / partial_signal_product(from_timestep, to_timestep).sqrt()
        if forward
        else partial_signal_product(to_timestep, from_timestep).sqrt()
    )
    noise_weight = part_noise[to_timestep] - part_noise[from_timestep] / ratio
    return latents / ratio + noise * noise_weight

def noisiness(at_timestep):
    """Noise coefficient relating ``at_timestep`` to the last schedule index, 999.

    Returns ``part_noise[999] - part_noise[at_timestep] / sqrt(partial_signal_product(at_timestep, 999))``.

    NOTE(review): step_by_noise's forward branch multiplies ``part_noise[from]`` by that square
    root, whereas this function divides by it. Confirm which is intended. The function is not
    called anywhere in the visible notebook.
    """
    final_index = 999
    root_ratio = partial_signal_product(at_timestep, final_index).sqrt()
    return part_noise[final_index] - part_noise[at_timestep] / root_ratio

def report_passthrough(c, x, s):
    """Show a bar-chart histogram of ``x / vae_scale`` with ``mshow``, then return ``x`` unchanged.

    Args:
        c: unused; kept so the signature matches whatever calls this as a pass-through hook.
        x: tensor to histogram; it is copied to CPU as float32 before binning.
        s: multiplier applied to the bin edges ``bins``.

    ``bins`` is a module-level name not defined in the visible notebook. It is presumably a 1-D
    tensor of bin edges (``len(bins)`` sizes the plot) -- TODO confirm.
    Also prints the largest bin count.
    """
    counts = torch.histogram(x.float().cpu() / vae_scale, bins=bins*s).hist

    canvas_height = 100
    canvas_width = (len(bins) + 2) * 5
    canvas = torch.ones([canvas_height, canvas_width])

    peak = counts.max()
    print(f"max {peak}")
    counts /= peak

    # Each bar is 2 pixels wide in a 5-pixel slot; bars stand on row canvas_height - 11.
    baseline_row = canvas_height - 11
    bar_span = canvas_height - 21
    for bar in range(len(bins) - 1):
        top_row = canvas_height - int(bar_span * counts[bar].item() + 11)
        column = 5 * (bar + 1)
        canvas[top_row:baseline_row, column + 2:column + 4] = 0

    canvas[:, canvas_width // 2] = 0.5   # vertical line at the horizontal center
    canvas[canvas_height - 10, :] = 0.5  # horizontal line just below the bars

    mshow(canvas)

    return x


for run_id in run_ids:
    run_context = Context()
    run_context.run_id = run_id
    # The settings cell documents that entries of the returned dict are added to the context.
    for key, value in (add_run_context(run_context) or {}).items():
        setattr(run_context, key, value)

    try:
        _seed = int(seed(run_context))
    except (TypeError, ValueError):
        # Only conversion failures fall back to 0; interrupts and other errors still propagate.
        _seed = 0
        print(f"non-integer seed, run {run_id}. replaced with 0.")

    torch.manual_seed(_seed)
    np.random.seed(_seed)

    run_context.steps = steps(run_context)

    # NOTE(review): the loop below reads diffusion_timesteps[step_index + 1] for every step,
    # so this is expected to hold steps + 1 entries -- confirm against linspace_timesteps.
    diffusion_timesteps = linspace_timesteps(run_context.steps, timestep_max(run_context), timestep_min(run_context), timestep_power(run_context))

    # `method` is expected from the settings; fall back to euler so a fresh kernel does not
    # hit a NameError. Validate it up front so an unknown value cannot reuse a stale prediction.
    sampling_method = globals().get("method", "euler")
    if sampling_method not in ("euler", "heun", "rk2", "rk4"):
        raise ValueError(f"unknown sampling method {sampling_method!r}; expected one of euler, heun, rk2, rk4")

    run_prompts = prompts(run_context)

    # One UNet batch entry per encoder_1 prompt.
    unet_batch_size = len(run_prompts["encoder_1"])

    (all_penult_states, enc2_pooled) = p_encoder.encode(run_prompts["encoder_1"], run_prompts["encoder_2"], run_prompts["encoder_2_pooled"])

    distortion = embedding_distortion(run_context)
    if distortion is not None:
        all_penult_states = svd_distort_embeddings(all_penult_states.to(torch.float32), distortion).to(encoder_dtype)

    width, height = width_height(run_context)

    # Values below 64 are given in units of 64 pixels (see settings).
    if width < 64:
        width *= 64
    if height < 64:
        height *= 64

    with torch.no_grad():
        vae_dim_scale = 2 ** (len(vae.config.block_out_channels) - 1)
        latent_shape = (1, unet.config.in_channels, height // vae_dim_scale, width // vae_dim_scale)

        latents = torch.zeros(latent_shape, device="cuda", dtype=main_dtype)

        # One fixed noise sample per run. It noises the initial latents and is handed to
        # combine_predictions at every step. (Previously this line was commented out, so
        # `true_noise` only existed if it leaked from an earlier kernel session.)
        true_noise = torch.randn_like(latents)

        # Noise the zero latents from the last scheduled timestep to the first one.
        latents = step_by_noise(latents, true_noise, diffusion_timesteps[-1], diffusion_timesteps[0])

        original_size = (height, width)
        target_size = (height, width)
        crop_coords_top_left = (0, 0)

        # SDXL size/crop conditioning: (original_size, crop_coords_top_left, target_size),
        # repeated for every batch entry. See the SDXL paper for details.
        add_time_ids = torch.tensor([list(original_size + crop_coords_top_left + target_size)], dtype=unet_dtype).repeat(unet_batch_size, 1).to("cuda")

        added_cond_kwargs = {"text_embeds": enc2_pooled.to(unet_dtype), "time_ids": add_time_ids}
        # The prompt embeddings do not change during the run, so cast them once.
        encoder_hidden_states = all_penult_states.to(unet_dtype)

        out_index = 0
        with Timer("core loop"):
            for step_index in range(run_context.steps):
                step_context = Context(run_context)
                step_context.step_index = step_index
                for key, value in (add_step_context(step_context) or {}).items():
                    setattr(step_context, key, value)

                t_now = diffusion_timesteps[step_index]
                t_next = diffusion_timesteps[step_index + 1]

                def predict_noise(noisy_latents, step=t_now):
                    """Run the UNet on every prompt in the batch at ``step`` and combine the predictions."""
                    predictions = unet(
                        noisy_latents.repeat(unet_batch_size, 1, 1, 1).to(unet_dtype),
                        step,
                        encoder_hidden_states=encoder_hidden_states,
                        return_dict=False,
                        added_cond_kwargs=added_cond_kwargs,
                    )[0]

                    return combine_predictions(step_context)(predictions.to(main_dtype), true_noise)

                if sampling_method == "euler":
                    final_prediction = predict_noise(latents)

                elif sampling_method == "heun":
                    # Average the predictions at the start and at a tentative end of the step.
                    first_prediction = predict_noise(latents)
                    tentative_result = step_by_noise(latents, first_prediction, t_now, t_next)
                    second_prediction = predict_noise(tentative_result, t_next)
                    final_prediction = (first_prediction + second_prediction) / 2

                elif sampling_method == "rk2":
                    # Midpoint method: use the prediction made at the (integer) half step.
                    half_step = (t_next + t_now) // 2
                    first_prediction = predict_noise(latents)
                    tentative_result = step_by_noise(latents, first_prediction, t_now, half_step)
                    final_prediction = predict_noise(tentative_result, half_step)

                else:  # "rk4" (validated above)
                    half_step = (t_next + t_now) // 2

                    prediction_1 = predict_noise(latents)
                    result_1 = step_by_noise(latents, prediction_1, t_now, half_step)

                    prediction_2 = predict_noise(result_1, half_step)
                    result_2 = step_by_noise(latents, prediction_2, t_now, half_step)

                    prediction_3 = predict_noise(result_2, half_step)
                    result_3 = step_by_noise(latents, prediction_3, t_now, t_next)

                    prediction_4 = predict_noise(result_3, t_next)

                    final_prediction = (prediction_1 + 2 * (prediction_2 + prediction_3) + prediction_4) / 6

                # Jump all the way to the last scheduled timestep to get the current clean-latent estimate.
                pred_original_sample = step_by_noise(latents, final_prediction, t_now, diffusion_timesteps[-1])

                latents = step_by_noise(latents, final_prediction, t_now, t_next)

                if save_raw(step_context):
                    save_raw_latents(pred_original_sample)
                if save_approximates(step_context):
                    save_approx_decode(pred_original_sample, out_index)
                    out_index += 1

            images_pil = pilify(pred_original_sample, vae)

            for im in images_pil:
                display(im)

            if save_output(run_context):
                for n, image in enumerate(images_pil):
                    image.save(f"{settings_directory}/{run_id}_final{n}.png")


In [None]:
# # # save # # #
# Path is imported explicitly rather than relying on a star import to provide it.
from pathlib import Path

# NOTE(review): `run_id` and `images_pil` are whatever the run cell left behind, i.e. only the
# last run of `run_ids` is saved here.
output_directory = Path(daily_directory) / f"{settings_id}_{run_id}"
# parents=True also creates daily_directory itself, so one mkdir is enough.
output_directory.mkdir(exist_ok=True, parents=True)

for n, image in enumerate(images_pil):
    image.save(output_directory / f"{n}.png")

In [None]:
# Re-decode the last clean-latent estimate after an ad-hoc affine tweak (x * 5 - 4) and show it.
# NOTE(review): the meaning of the 5 / -4 constants is not documented. This also overwrites the
# global `images_pil` that the save cell writes out.
images_pil = pilify(5 * pred_original_sample - 4, vae)

for image in images_pil:
    display(image)