In [None]:
%reload_ext autoreload
%autoreload 2

import math
import gc
import importlib
import torch
import numpy as np
import transformers
from diffusers import AutoencoderKL, UNet2DConditionModel
from PIL import Image
from IPython.display import display, clear_output

from autumn.notebook import *
from autumn.math import *
from autumn.images import *
from autumn.guidance import *
from autumn.scheduling import *
from autumn.prompting import PromptEncoder

In [None]:
# # # torch / cuda # # #

# Allow CUDA matmuls to run in TF32.
torch.backends.cuda.matmul.allow_tf32 = True

# Precision used by each part of the pipeline.
vae_dtype = encoder_dtype = torch.float32  # VAE and prompt encoder weights
unet_dtype = torch.float16                 # UNet weights and UNet inputs
main_dtype = torch.float64                 # noise schedule, latents and sampler arithmetic

def save_raw_latents(latents):
    """Write a grayscale debug view of the first four latent channels to disk.

    The latents are divided by their maximum value and mapped linearly so that
    -1 -> 0 and +1 -> 255. For every batch element the first four channels are
    tiled into a 2x2 grid (channels 0 and 1 stacked in the left column,
    channels 2 and 3 in the right column), upscaled 4x with nearest-neighbour
    resampling and saved to ``out/tmp_raw_latents.bmp``.

    Every batch element is written to the same path, so with a batch larger
    than one only the last element remains on disk.

    Args:
        latents: tensor indexed as (batch, channel, height, width) with at
            least four channels.
    """
    import os

    # Scale so the largest value lands on 1; for roughly symmetric latents the
    # result sits in about [-1, 1].
    scaled = latents / latents.max()
    pixels = (scaled.float() * 127.5 + 127.5).round()
    # Anything that fell below -1 would wrap around when cast to uint8, so
    # saturate at the ends of the byte range instead.
    pixels = pixels.clamp(0, 255)
    pixels = pixels.detach().cpu().numpy().astype("uint8")

    os.makedirs("out", exist_ok=True)

    for lat in pixels:
        left_column = np.concatenate([lat[0], lat[1]])
        right_column = np.concatenate([lat[2], lat[3]])
        grid = np.concatenate([left_column, right_column], axis=1)
        im = Image.fromarray(grid)
        im = im.resize(size=(grid.shape[1]*4, grid.shape[0]*4), resample=Image.NEAREST)
        im.save("out/tmp_raw_latents.bmp")

# Linear map from the 4 latent channels (one row each) to 3 output channels
# (one column each). save_approx_decode contracts the latents with this matrix
# to build a rough colour preview without running the VAE.
# The output channels are presumably R, G, B -- the result is handed to
# Image.fromarray as a height x width x 3 uint8 array.
# The per-row notes are the author's observations of what each latent channel
# appears to encode.
approximation_matrix = [
    [0.85, 0.85, 0.6], # seems to be mainly value
    [-0.35, 0.2, 0.5], # mainly blue? maybe a little green, def not red
    [0.15, 0.15, 0], # yellow. but mainly encoding texture not color, i think
    [0.15, -0.35, -0.35] # inverted value? but also red
]

def save_approx_decode(latents, index):
    """Write a cheap colour preview of the latents to disk without the VAE.

    The latents are divided by their maximum value and mapped linearly so that
    -1 -> 0 and +1 -> 1. Each batch element is then projected from its latent
    channels onto three colour channels with ``approximation_matrix``, scaled
    to the byte range, upscaled 8x with nearest-neighbour resampling and saved
    to ``out/tmp_approx_decode.bmp``.

    Every batch element is written to the same path, so with a batch larger
    than one only the last element remains on disk.

    Args:
        latents: tensor indexed as (batch, channel, height, width); the channel
            count must match the number of rows in ``approximation_matrix``.
        index: currently unused; kept so existing callers that pass a running
            frame counter keep working.
    """
    import os

    # The division allocates a fresh tensor, so the in-place ops below never
    # touch the caller's latents.
    scaled = (latents / latents.max()).float().mul_(0.5).add_(0.5)

    # Build the projection matrix once, on whatever device the latents use.
    apx_mat = torch.tensor(approximation_matrix, device=scaled.device)

    os.makedirs("out", exist_ok=True)

    for lat in scaled:
        approx_decode = torch.einsum("...lhw,lr -> ...rhw", lat, apx_mat).mul_(255).round()
        # The matrix rows do not keep values inside [0, 255]; saturate rather
        # than letting the uint8 cast wrap around.
        approx_decode = approx_decode.clamp_(0, 255)
        im_data = approx_decode.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")
        im = Image.fromarray(im_data).resize(size=(im_data.shape[1]*8, im_data.shape[0]*8), resample=Image.NEAREST)
        im.save("out/tmp_approx_decode.bmp")


In [None]:
%%settings

# User-tunable settings for the notebook. Most entries are callables that
# receive a context object, so a value can vary per run or per step.

# model can be a local path or a huggingface model name
model = "stabilityai/stable-diffusion-xl-base-0.9"
XL_MODEL = True # this notebook in its current state is only really guaranteed to work with SDXL
vae_scale = 0.13025 # found in the model config

# The run cell performs one complete generation per id in this iterable.
run_ids = range(1)

# run context will contain only context.run_id, anything returned in this dictionary will be added to it
add_run_context = lambda context: {}

# Prompt strings referenced by `prompts` below.
p0 = "photograph of a very cute dog"
np0 = "blurry, ugly, indistinct, jpeg artifacts, watermark, text, signature"

# Generally: the encoder_2 prompts have the strongest effect, but terms from the
# encoder_1 prompts will show up in results.
# The encoder_2_pooled prompts have an influence, but a very indirect one.
# The model is trained with all encodings coming from the same prompt.

# The number of encoder_1 prompts sets the UNet batch size in the run cell;
# one prediction is made per prompt and the predictions are recombined by
# `combine_predictions`.
prompts = lambda context: {
    # Prompts for encoder 1
    "encoder_1": [np0, p0],
    # Prompts for encoder 2; defaults to the encoder_1 prompts if None
    "encoder_2": None,
    # Prompts for pooled encoding of encoder 2; defaults to the encoder_2 prompts if None
    "encoder_2_pooled": None
}

# Method by which predictions for different prompts will be recombined to make one noise prediction
# NOTE(review): the (1, 0, f) tuple presumably refers to prompt indices 1 and 0
# in the "encoder_1" list above -- confirm against scaled_CFG.
combine_predictions = lambda context: scaled_CFG(
    difference_scales = [
        (1, 0, sigmoid(0.1 * vae_scale, 3 * vae_scale))
    ], 
    steering_scale = id_, base_scale = id_, total_scale = id_
)

# Converted with int() by the run cell; if that fails the run uses seed 0.
seed = lambda context: 42069

# Values below 64 are multiplied by 64 by the run cell (so (16, 16) means
# 1024x1024 pixels); values of 64 or more are used unchanged.
width_height = lambda context: (16, 16)

# Number of sampling steps per run.
steps = lambda context: 25

# step context will contain run_id & step_index, anything returned in this dictionary will be added to it
add_step_context = lambda context: {}

# Optional distortion passed to svd_distort_embeddings together with the prompt
# embeddings; None disables it.
#distort = sigmoid(0.319, 0.987, 0.0304, 0.025)
#distort = scale_f(distort, 127, 1.0)
embedding_distortion = lambda context: None

# Optional per-step adjustments of the combined noise prediction
# (apply_dynthresh / apply_naive_rescale in the run cell).
dynamic_thresholding = lambda context: False
dynthresh_percentile = lambda context: 0.995
dynthresh_target = lambda context: 7

naive_rescaling = lambda context: False

# save_output is evaluated once per run; save_approximates and save_raw are
# evaluated on every step.
save_output = lambda context: True
save_approximates = lambda context: False
save_raw = lambda context: False

#!# Settings above this line will be replaced with the contents of settings.py if it exists. #!#

In [None]:
# # # models # # #

# Load the VAE, UNet and prompt encoder from `model`, timing each stage and
# the total.
model_source = model

with Timer("total"):
    with Timer("vae"):
        vae = AutoencoderKL.from_pretrained(
            model_source,
            subfolder="vae",
            torch_dtype=vae_dtype,
        ).to(device="cuda")

        # (as of torch 2.2.2) VAE compilation yields negligible gains, at least for decoding once/step
        #vae = torch.compile(vae, mode="default", fullgraph=True)

    with Timer("unet"):
        unet = UNet2DConditionModel.from_pretrained(
            model_source,
            subfolder="unet",
            torch_dtype=unet_dtype,
        ).to(device="cuda")

        # compilation will not actually happen until first use of unet
        # (as of torch 2.2.2) "default" provides the best result on my machine
        # don't use this if you're gonna be changing resolutions a lot
        unet_c = torch.compile(unet, mode="default", fullgraph=True)

    with Timer("clip"):
        p_encoder = PromptEncoder(model_source, XL_MODEL, torch_dtype=encoder_dtype)


In [None]:
# # # run # # #

variance_range = (0.00085, 0.012) # should come from model config!
forward_noise_schedule = default_variance_schedule(variance_range).to(main_dtype) # beta
forward_signal_product = torch.cumprod((1 - forward_noise_schedule), dim=0) # alpha_bar


def partial_signal_product(s, t):
    """Return alpha_bar_t / alpha_bar_s for schedule indices s < t.

    alpha_bar is the cumulative product of (1 - beta), so the ratio is the
    product of (1 - beta_i) for i = s+1 .. t inclusive -- hence the slice has
    to end at t + 1. This keeps the ratio consistent with the alpha_bar values
    looked up in step_signal_product below.
    """
    return torch.prod((1 - forward_noise_schedule)[s + 1:t + 1])


def _extend_context(context, extras):
    """Copy the entries of a settings-supplied dictionary onto `context` as attributes."""
    if isinstance(extras, dict):
        for key, value in extras.items():
            setattr(context, key, value)


for run_id in run_ids:
    run_context = Context()
    run_context.run_id = run_id
    # The settings cell documents that whatever add_run_context returns is
    # added to the run context.
    _extend_context(run_context, add_run_context(run_context))

    try:
        _seed = int(seed(run_context))
    except (TypeError, ValueError, OverflowError):
        # Only a failed int() conversion is treated as "non-integer seed";
        # anything else (including KeyboardInterrupt) propagates.
        _seed = 0
        print(f"non-integer seed, run {run_id}. replaced with 0.")

    torch.manual_seed(_seed)
    np.random.seed(_seed)

    run_prompts = prompts(run_context)

    # One UNet prediction is made per encoder_1 prompt.
    unet_batch_size = len(run_prompts["encoder_1"])

    (all_penult_states, enc2_pooled) = p_encoder.encode(run_prompts["encoder_1"], run_prompts["encoder_2"], run_prompts["encoder_2_pooled"])

    distortion = embedding_distortion(run_context)
    if distortion is not None:
        all_penult_states = svd_distort_embeddings(all_penult_states.to(torch.float32), distortion).to(unet_dtype)

    width, height = width_height(run_context)

    # Small values are given in units of 64 pixels.
    if (width < 64): width *= 64
    if (height < 64): height *= 64

    with torch.no_grad():
        # Spatial downscaling factor between image space and latent space.
        vae_dim_scale = 2 ** (len(vae.config.block_out_channels) - 1)

        latents = torch.randn(
            (1, unet.config.in_channels, height // vae_dim_scale, width // vae_dim_scale),
            device="cuda",
            dtype=main_dtype,
        )

        _steps = steps(run_context)
        run_context.steps = _steps

        # Evenly spaced timesteps from 999 downwards, rounded to integers.
        diffusion_timesteps = torch.arange(999, 0, -1000/_steps).round().int()

        # alpha_bar at each chosen timestep.
        step_signal_product = torch.from_numpy(np.interp(diffusion_timesteps, np.arange(0, len(forward_signal_product)), forward_signal_product))

        original_size = (height, width)
        target_size = (height, width)
        crop_coords_top_left = (0, 0)

        add_time_ids = list(original_size + crop_coords_top_left + target_size)

        # Sanity check: the size of the additional conditioning vector must
        # match what the UNet's add_embedding layer expects.
        passed_add_embed_dim = (unet.config.addition_time_embed_dim * len(add_time_ids) + p_encoder.text_encoder_2.config.projection_dim)
        expected_add_embed_dim = unet.add_embedding.linear_1.in_features

        if passed_add_embed_dim != expected_add_embed_dim:
            print("embed dim is messed up")

        add_time_ids = torch.tensor([add_time_ids], dtype=unet_dtype).repeat(unet_batch_size,1).to("cuda")

        added_cond_kwargs = {"text_embeds": enc2_pooled.to(unet_dtype), "time_ids": add_time_ids}

        out_index = 0
        with Timer("core loop"):
            for step_index in range(_steps):
                step_context = Context(run_context)
                step_context.step_index = step_index
                step_context.diffusion_timestep = diffusion_timesteps[step_index]
                # As for the run context: merge the returned dictionary.
                _extend_context(step_context, add_step_context(step_context))

                # Same latents for every prompt in the batch.
                latents_expanded = latents.repeat(unet_batch_size, 1, 1, 1)

                noise_prediction = unet(
                    latents_expanded.to(unet_dtype),
                    step_context.diffusion_timestep,
                    encoder_hidden_states=all_penult_states.to(unet_dtype),
                    return_dict=False,
                    added_cond_kwargs=added_cond_kwargs
                )[0]

                predictions_split = noise_prediction.to(main_dtype).chunk(unet_batch_size)

                noise_prediction = combine_predictions(step_context)(predictions_split)

                if dynamic_thresholding(step_context):
                    apply_dynthresh(predictions_split, noise_prediction, dynthresh_target(step_context), dynthresh_percentile(step_context))

                if naive_rescaling(step_context):
                    apply_naive_rescale(predictions_split, noise_prediction)

                # sqrt(1 - alpha_bar_t) and sqrt(alpha_bar_t) for the current step.
                current_part_noise = (1 - step_signal_product[step_index]).sqrt()
                current_part_signal = step_signal_product[step_index].sqrt()
                # On the last step there is no next timestep: no noise is
                # re-added and the latents are not rescaled.
                next_part_noise = 0
                signal_ratio = 1
                if step_index < _steps - 1:
                    next_part_noise = (1 - step_signal_product[step_index + 1]).sqrt()
                    # sqrt(alpha_bar_current / alpha_bar_next)
                    signal_ratio = partial_signal_product(diffusion_timesteps[step_index+1], diffusion_timesteps[step_index]).sqrt()

                # Estimate of the fully denoised latents implied by this prediction.
                pred_original_sample = (latents - current_part_noise * noise_prediction) / current_part_signal

                # Deterministic step to the next timestep: rescale the signal
                # and swap the current noise level for the next one.
                prev_sample = latents / signal_ratio + noise_prediction * (next_part_noise - current_part_noise / signal_ratio)

                latents = prev_sample

                if save_raw(step_context):
                    save_raw_latents(pred_original_sample)
                if save_approximates(step_context):
                    save_approx_decode(pred_original_sample, out_index)
                    out_index += 1

            # Decode the last step's denoised estimate.
            images_pil = pilify(pred_original_sample, vae)

            for im in images_pil:
                display(im)

            if save_output(run_context):
                for n in range(len(images_pil)):
                    images_pil[n].save(f"{settings_directory}/{run_id}_final{n}.png")


In [None]:
# Decode the run cell's final denoised estimate again after passing it through
# sigmoid(k, 2 / vae_scale) for two values of k, to compare the results.
# `images_pil` is left holding the images from the last value.
for sigmoid_arg in (5, 6):
    images_pil = pilify(sigmoid(sigmoid_arg, 2 / vae_scale)(pred_original_sample), vae)

    for im in images_pil:
        display(im)

In [None]:
# # # save # # #

# Store the current `images_pil` under <daily_directory>/<settings_id>_<run_id>/,
# one numbered PNG per image. `run_id` is whatever the run cell's loop left behind.
run_output_directory = f"{daily_directory}/{settings_id}_{run_id}"

for directory in (daily_directory, run_output_directory):
    Path(directory).mkdir(exist_ok=True, parents=True)

for n, image in enumerate(images_pil):
    image.save(f"{daily_directory}/{settings_id}_{run_id}/{n}.png")