In [None]:
%reload_ext autoreload
%autoreload 2

import gc
import importlib
import torch
import numpy as np
import transformers
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, EulerDiscreteScheduler
from PIL import Image
from IPython.display import display, clear_output

from autumn.notebook import *
from autumn.math import *
from autumn.images import *
from autumn.prompting import PromptEncoder

In [None]:
# # # torch / cuda # # #

# Let cuBLAS use TF32 tensor-core math for float32 matmuls where the GPU supports it
# (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True

# Working dtype for the UNet, the prompt encoder and the sampling latents.
# The VAE is loaded separately in float32.
DTYPE = torch.float16

def save_raw_latents(latents, size=(1024, 1024), path="out/tmp_raw_latents.png"):
    """Write a grayscale 2x2 mosaic of the first four latent channels to a PNG.

    The whole tensor is min-max normalized jointly to [0, 1] and then quantized
    to uint8. Every value therefore lands in 0..255 instead of wrapping around,
    and relative brightness stays comparable between channels.

    Args:
        latents: tensor indexed as ``latents[batch][channel]`` -> (H, W). The
            mosaic uses channels 0-3, so at least four channels are required.
        size: (width, height) passed to ``Image.resize``; defaults to the
            previous fixed 1024x1024.
        path: output file. Every batch element is written to this same path,
            so only the last one is kept (the notebook samples a batch of 1).
    """
    values = latents.detach().float()
    lmin = values.min()
    # clamp_min avoids 0/0 when the tensor is constant.
    lrange = (values.max() - lmin).clamp_min(1e-8)
    normalized = ((values - lmin) / lrange).clamp(0, 1)
    pixels = (normalized.cpu().numpy() * 255).round().astype("uint8")

    ims = []
    for lat in pixels:
        # Stack ch0 above ch1 and ch2 above ch3, then place those two columns
        # side by side -> one (2H, 2W) single-channel image.
        left = np.concatenate([lat[0], lat[1]])
        right = np.concatenate([lat[2], lat[3]])
        grid = np.concatenate([left, right], axis=1)
        im = Image.fromarray(grid).resize(size=size, resample=Image.NEAREST)
        ims.append(im)

    for im in ims:
        im.save(path)

# Hand-picked linear projection from the 4 latent channels to (R, G, B), used by
# save_approx_decode as a cheap preview instead of running the VAE decoder.
# Row i is the RGB contribution of latent channel i; the per-row notes are the
# author's empirical guesses about what each channel encodes.
approximation_matrix = [
    #  R      G      B
    [ 0.85,  0.85,  0.6 ],  # ch 0: seems to be mainly value
    [-0.35,  0.2,   0.5 ],  # ch 1: mainly blue? maybe a little green, def not red
    [ 0.15,  0.15,  0   ],  # ch 2: yellow, but mainly encoding texture not color, i think
    [ 0.15, -0.35, -0.35],  # ch 3: inverted value? but also red
]

def save_approx_decode(latents, size=(1024, 1024), path="out/tmp_approx_decode.png"):
    """Write a cheap RGB preview of latents (projected via approximation_matrix) to a PNG.

    Steps:
      1. The latents are min-max normalized jointly to [0, 1].
      2. Each batch element is projected from latent channels to RGB with the
         module-level ``approximation_matrix``.
      3. Each projected image is min-max normalized again, so the uint8
         conversion cannot wrap around.

    Args:
        latents: tensor indexed as ``latents[batch]`` -> (C, H, W), where C must
            equal the number of rows in ``approximation_matrix``.
        size: (width, height) passed to ``Image.resize``; defaults to the
            previous fixed 1024x1024.
        path: output file. Every batch element is written to this same path,
            so only the last one is kept (the notebook samples a batch of 1).
    """
    values = latents.detach().float()
    lmin = values.min()
    # clamp_min avoids 0/0 when the tensor is constant.
    values = (values - lmin) / (values.max() - lmin).clamp_min(1e-8)

    # Built once, on the latents' own device/dtype rather than a hard-coded "cuda".
    apx_mat = torch.tensor(approximation_matrix, dtype=values.dtype, device=values.device)

    ims = []
    for lat in values:
        # (C, H, W) x (C, 3) -> (3, H, W)
        approx_decode = torch.einsum("...lhw,lr -> ...rhw", lat, apx_mat)
        # The projection has negative weights, so its output can leave [0, 1];
        # rescale the decoded image itself before quantizing.
        dmin = approx_decode.min()
        approx_decode = (approx_decode - dmin) / (approx_decode.max() - dmin).clamp_min(1e-8)
        im_data = approx_decode.permute(1, 2, 0).cpu().numpy() * 255
        im_data = im_data.round().astype("uint8")
        im = Image.fromarray(im_data).resize(size=size, resample=Image.NEAREST)
        ims.append(im)

    for im in ims:
        im.save(path)


In [None]:
%%settings

# Tunable run settings. Per the marker comment near the bottom of this cell,
# everything above that marker is replaced by the contents of settings.py when
# that file exists.

# Diffusers model repo id or local directory; the models cell loads its vae/,
# unet/ and scheduler/ subfolders. Placeholder value: set before running.
model = "model/name/goes/here"
# Passed to PromptEncoder. NOTE(review): the run cell always builds SDXL-style
# added conditioning (time ids + pooled text embeds) regardless of this flag.
XL_MODEL = True

# Positive prompt.
p0 = "sixteenth century painting of a very cute dog rolling around in a grassy field, oil on canvas, neo-cubist glitch art, extremely detailed maya render of a fractal"

# Negative prompt; used as the baseline branch in combine_predictions below.
np0 = "blurry, ugly, indistinct, jpeg artifacts, watermark, text, signature"

# Empirically: e2_prompts has the strongest effect, but terms from e1_prompts still
# show up in results. e2_pool_prompts has some influence, but only indirectly.

# Prompts for encoder 1. The run cell uses one UNet batch entry per prompt, in this order.
e1_prompts = [np0, p0]

# Prompts for encoder 2; defaults to e1_prompts if None
e2_prompts = None

# Prompts for pooled encoding of encoder 2; defaults to e2_prompts if None
e2_pool_prompts = None

# Method by which predictions for different prompts will be recombined to make one noise prediction.
# With e1_prompts = [np0, p0], p[0] is the negative-prompt prediction and p[1] the
# positive one (assuming the encoder keeps prompt order). That makes this
# classifier-free guidance with strength `scale`. Note that `scale` is looked up
# when the lambda is called, not when it is defined.
scale = 10
combine_predictions = lambda p: p[0] + scale * (p[1] - p[0])

# Any value int() rejects (e.g. None) makes the run cell pick a random seed.
seed = 42069

# Image size in units of 64 px (converted to pixels at the bottom of this cell): 16 => 1024
height = 16
width = 16

# Number of scheduler (denoising) steps.
steps = 25

# Disabled latent-distortion settings, kept for experimentation.
#distort = sigmoid(0.319, 0.987, 0.0304, 0.025)
#distort = scale_f(distort, 127, 1.0)

# If not None, the run cell passes the text embeddings through
# svd_distort_embeddings with this value.
embedding_distortion = None

#!# Settings above this line will be replaced with the contents of settings.py if it exists. #!#

# Convert the 64-px units above to pixels.
height *= 64;
width *= 64;

In [None]:
# # # models # # #

# Every component below is loaded from this one source.
model_source = model

with Timer("total"):
    with Timer("vae"):
        # Loaded in float32 rather than DTYPE, presumably to avoid fp16 overflow
        # in the VAE. TODO confirm for the chosen model.
        vae = AutoencoderKL.from_pretrained(
            model_source, subfolder="vae", torch_dtype=torch.float32
        )
        vae.to(device="cuda")

        # Not compiled: (as of torch 2.2.2) VAE compilation yields negligible
        # gains, at least for decoding once per step.

    with Timer("unet"):
        unet = UNet2DConditionModel.from_pretrained(
            model_source, subfolder="unet", torch_dtype=DTYPE
        )
        unet.to(device="cuda")

        # Compilation does not actually happen until the first call of unet_c.
        # (as of torch 2.2.2) mode="default" gave the best result on the author's machine.
        # Avoid this if resolutions change a lot (new input shapes can trigger recompiles).
        # NOTE(review): the run cell calls the eager `unet`, so unet_c is currently unused.
        unet_c = torch.compile(unet, mode="default", fullgraph=True)

    with Timer("clip"):
        p_encoder = PromptEncoder(model_source, XL_MODEL, torch_dtype=DTYPE)

    with Timer("scheduler"):
        # Loaded from model_source like the other components. A scheduler holds
        # no weights, so no torch_dtype is passed.
        scheduler = EulerDiscreteScheduler.from_pretrained(
            model_source, subfolder="scheduler"
        )


In [None]:
# # # run # # #

run_id = uuid.uuid4()

# A seed that int() cannot convert (e.g. None) means "pick a random one".
# Only conversion errors are caught. A bare except would also swallow
# KeyboardInterrupt and unrelated bugs.
try:
    _seed = int(seed)
except (TypeError, ValueError):
    _seed = random.randint(1, 2**31 - 1)

torch.manual_seed(_seed)
np.random.seed(_seed)

# One UNet batch entry per encoder-1 prompt; predictions are chunked back in this order.
unet_batch_size = len(e1_prompts)

with torch.no_grad():
    # Text encoding is inference only, so it runs under no_grad too.
    (all_penult_states, enc2_pooled) = p_encoder.encode(e1_prompts, e2_prompts, e2_pool_prompts)

    if embedding_distortion is not None:
        all_penult_states = svd_distort_embeddings(
            all_penult_states.to(torch.float32), embedding_distortion
        ).to(DTYPE)

    # Spatial downsampling factor of the VAE: 2 ** (number of blocks - 1).
    vae_scale = 2 ** (len(vae.config.block_out_channels) - 1)

    # A single starting latent (batch of 1); it is repeated per prompt inside the loop.
    latents = torch.randn(
        (1, unet.config.in_channels, height // vae_scale, width // vae_scale),
        device="cuda",
        dtype=DTYPE,
    )

    scheduler.set_timesteps(steps)
    latents = latents * scheduler.init_noise_sigma

    # SDXL size/crop conditioning: no resize and no crop, so original == target size.
    original_size = (height, width)
    target_size = (height, width)
    crop_coords_top_left = (0, 0)
    add_time_ids = list(original_size + crop_coords_top_left + target_size)

    # The time ids plus the pooled text embedding must exactly fill the UNet's
    # added-embedding input. Fail early with a clear message instead of a
    # shape error deep inside the UNet.
    passed_add_embed_dim = (
        unet.config.addition_time_embed_dim * len(add_time_ids)
        + p_encoder.text_encoder_2.config.projection_dim
    )
    expected_add_embed_dim = unet.add_embedding.linear_1.in_features
    if passed_add_embed_dim != expected_add_embed_dim:
        raise ValueError(
            f"Added-condition embedding size mismatch: the UNet expects {expected_add_embed_dim}, "
            f"but {len(add_time_ids)} time ids plus text_encoder_2's projection_dim give "
            f"{passed_add_embed_dim}. Check that the model and the XL_MODEL setting agree."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=DTYPE).repeat(unet_batch_size, 1).to("cuda")
    added_cond_kwargs = {"text_embeds": enc2_pooled, "time_ids": add_time_ids}

    with Timer("core loop"):
        for step in scheduler.timesteps:
            latents_expanded = latents.repeat(unet_batch_size, 1, 1, 1)
            latents_expanded = scheduler.scale_model_input(latents_expanded, timestep=step)

            # NOTE(review): this calls the eager `unet`; the compiled `unet_c`
            # from the models cell is not used.
            noise_prediction = unet(
                latents_expanded,
                step,
                return_dict=False,
                encoder_hidden_states=all_penult_states,
                added_cond_kwargs=added_cond_kwargs,
            )[0]

            # Split back into one prediction per prompt and merge them per the settings.
            predictions_split = noise_prediction.chunk(unet_batch_size)
            noise_prediction = combine_predictions(predictions_split)

            # To distort latents before stepping, pass
            # svd_distort(latents.to(torch.float32), distort).to(DTYPE) instead of `latents`.
            sched_out = scheduler.step(noise_prediction, step, latents)

            latents = sched_out.prev_sample
            # Per-step previews of the scheduler's current clean-sample estimate.
            save_raw_latents(sched_out.pred_original_sample)
            save_approx_decode(sched_out.pred_original_sample)

        images_pil = PILify(sched_out.pred_original_sample, vae)

        for im in images_pil:
            display(im)


In [None]:
# # # save # # #

# Make sure today's output directory exists before writing into it.
Path(daily_directory).mkdir(parents=True, exist_ok=True)

# One PNG per generated image, named <settings_id>_<index>.png.
for n, image in enumerate(images_pil):
    image.save(f"{daily_directory}/{settings_id}_{n}.png")