In [1]:
%reload_ext autoreload
%autoreload 2

import math
import gc
import importlib
import torch
import torch.nn as nn
from torch.nn import functional as func
import numpy as np
import transformers
from diffusers import UNet2DConditionModel
from PIL import Image
from matplotlib import pyplot as plt
from IPython.display import display, clear_output

from autumn.notebook import *
from autumn.math import *
from autumn.images import *
from autumn.guidance import *
from autumn.scheduling import *
from autumn.solvers import *
from autumn.py import *

from models.clip import PromptEncoder
from models.sdxl import Decoder

notebook reloaded
math reloaded


In [None]:
%%settings

# Most settings are specified as lambdas that take in a "context" dict that can provide the current step index, run index, &c
# (the run cell calls them with a Context carrying run_id; per-step settings also see step_index).

# model needs to be a local path actually, easiest to use HF lib to download the model into the HF cache
# NOTE(review): the models cell loads from `decoder_model`, `noise_predictor_model` and `base_model`;
# none of those names is assigned in this cell (only `model` is) — confirm settings.py / the
# %%settings magic provides them.
model = "stabilityai/stable-diffusion-xl-base-0.9"
XL_MODEL = True # this notebook in its current state is only really guaranteed to work with SDXL
# NOTE(review): latent_scale is not referenced by any cell visible in this notebook.
latent_scale = 0.13025 # found in the model config

# One generation run per id; each id is exposed to the settings lambdas as context.run_id.
run_ids = range(5)

# run context will contain only context.run_id, anything returned in this dictionary will be added to it
# NOTE(review): the run cell currently calls add_run_context(run_context) and discards the result — confirm.
add_run_context = lambda context: {}

p0 = "photograph of a very cute dog"
#np0 = "blurry, ugly, indistinct, jpeg artifacts, watermark, text, signature"

# The length of the "encoder_1" list sets the noise-predictor batch size in the run cell.
prompts = lambda context: {
    # Prompts for encoder 1
    "encoder_1": [p0],
    # Prompts for encoder 2; defaults to same as encoder_1 if None
    "encoder_2": None,
    # Prompts for pooled encoding of encoder 2; defaults to same as encoder_2 if None
    "encoder_2_pooled": None
}

# Diffusion step rule. The run cell lower-cases it and compares it against "standard", "stupid",
# "cfg++", "tnr", "tnrb" and "cons".
# NOTE(review): the run cell calls `method(step_context)`, which fails on a plain string, and "custom"
# matches none of the branches above — confirm the intended value / step function.
method = "custom"

# Method by which predictions for different prompts will be recombined to make one noise prediction, for "custom" method.
combine_predictions = lambda context: true_noise_removal(context, [1])

# Per-run seed; the run cell falls back to 0 when this does not convert to int.
seed = lambda context: 42069 + context.run_id

# these get multiplied by 64
# (only values below 64 are multiplied; the run cell uses values >= 64 as pixel sizes directly)
width_height = lambda context: (16, 16)

# Number of sampler steps; the run cell builds steps + 1 timestep boundaries from it.
steps = lambda context: 15

# for scaling by a sqrt or **2 curve, &c
# (all three are passed to linspace_timesteps in the run cell)
timestep_power = lambda c: 1
timestep_max = lambda c: 999
timestep_min = lambda c: 0

# differential equation solver. see autumn/solvers.py
# (called as solver(get_derivative, take_step, latents) once per step)
solver_step = lambda c: rk4_step

# step context will contain run_id & step_index, anything returned in this dictionary will be added to it
# NOTE(review): the run cell currently calls add_step_context(step_context) and discards the result — confirm.
add_step_context = lambda context: {}

# When not None, passed with each prompt embedding (context.embedding_index) to svd_distort_embeddings.
embedding_distortion = lambda context: None

# Output switches: final images (save_output), per-step approximate decodes, and raw per-step latents.
save_output = lambda context: True
save_approximates = lambda context: False
save_raw = lambda context: False

def modify_initial_latents(context, latents):
    """Hook called once per run with the initially noised latents.

    The run cell ignores the return value, so any change must be made to `latents` in place.
    """
    pass

#!# Settings above this line will be replaced with the contents of settings.py if it exists. #!#

# Devices: UNet and latents on main_device; VAE decoder on decoder_device; clip_device is the first
# entry of the device pair handed to PromptEncoder.
main_device = "cuda:0"
decoder_device = "cuda:1"
clip_device = "cuda:1"
# dtype of the latents and of the noise-schedule arithmetic
main_dtype = torch.float64
# dtype of the UNet weights and of its inputs
noise_predictor_dtype = torch.float16
# NOTE(review): decoder_dtype is not referenced by any cell visible in this notebook.
decoder_dtype = torch.float32
prompt_encoder_dtype = torch.float16

# Allow reduced-precision (TF32 / "medium") float32 matmuls in exchange for speed.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("medium")

In [None]:
# # # models # # #
# Load the VAE decoder, the SDXL UNet noise predictor and the prompt encoder, timing each load
# individually and the cell as a whole.
#
# NOTE(review): `decoder_model`, `noise_predictor_model` and `base_model` are not assigned in the
# visible settings cell (it only sets `model`) — confirm they come from settings.py / the magic.

# Inference only: autograd is switched off globally for the rest of the kernel session.
torch.set_grad_enabled(False)

with Timer("total"):
    with Timer("decoder"):
        decoder = Decoder()
        decoder.load_safetensors(decoder_model)  # weights are read from a safetensors file
        decoder.to(device=decoder_device)
        # Optional, disabled:
        # decoder = torch.compile(decoder, mode="default", fullgraph=True)

    with Timer("noise_predictor"):
        noise_predictor = UNet2DConditionModel.from_pretrained(
            noise_predictor_model,
            subfolder="unet",
            torch_dtype=noise_predictor_dtype,
        )
        noise_predictor.to(device=main_device)
        # Optional, disabled. torch.compile is lazy: compilation happens on the first call of
        # noise_predictor. As of torch 2.2.2, mode="default" gave the best result on the author's
        # machine. Don't enable it if you'll be changing resolutions a lot.
        # noise_predictor = torch.compile(noise_predictor, mode="default", fullgraph=True)

    with Timer("clip"):
        prompt_encoder = PromptEncoder(
            base_model,
            XL_MODEL,
            (clip_device, main_device),
            prompt_encoder_dtype,
        )


In [None]:
# # # run # # #

variance_range = (0.00085, 0.012) # should come from model config!
forward_noise_schedule = default_variance_schedule(variance_range).to(main_dtype) # beta_t
forward_noise_total = forward_noise_schedule.cumsum(dim=0) # running sum of beta_t
forward_signal_product = torch.cumprod((1 - forward_noise_schedule), dim=0) # alpha_bar_t = prod_{i<=t} (1 - beta_i)


def partial_signal_product(s, t):
    """Return alpha_bar_t / alpha_bar_s = prod_{i=s+1}^{t} (1 - beta_i), for s <= t.

    Computed directly as a product over the forward schedule, not as a quotient of two cumulative
    products. Returns 1 (empty product) when s >= t.
    """
    # forward_signal_product[t] includes the (1 - beta_t) factor, so the quotient must include
    # index t as well: slice up to t + 1. Slicing only to t drops that factor.
    return torch.prod((1 - forward_noise_schedule)[s + 1:t + 1])


part_noise = (1 - forward_signal_product).sqrt() # sigma_t = sqrt(1 - alpha_bar_t)
part_signal = forward_signal_product.sqrt() # sqrt(alpha_bar_t), the signal coefficient

def get_signal_ratio(from_timestep, to_timestep):
    """Factor r such that latents / r rescales the signal from `from_timestep` to `to_timestep`.

    With P = partial_signal_product:
      * forward  (from < to):  r = 1 / sqrt(P(from, to))  (r >= 1, signal shrinks)
      * backward (from >= to): r = sqrt(P(to, from))      (r <= 1, signal grows)
    If P(s, t) == alpha_bar_t / alpha_bar_s, both cases equal sqrt(alpha_bar_from / alpha_bar_to).
    Equal timesteps take the backward branch.
    """
    going_forward = from_timestep < to_timestep
    if not going_forward:
        return partial_signal_product(to_timestep, from_timestep).sqrt()
    return 1 / partial_signal_product(from_timestep, to_timestep).sqrt()

def step_by_noise(latents, noise, from_timestep, to_timestep):
    """Move `latents` between timesteps, treating `noise` as their noise component.

    Returns latents / r + noise * (sigma_to - sigma_from / r), where
    r = get_signal_ratio(from_timestep, to_timestep) and sigma = part_noise.
    If latents == a * x0 + sigma_from * noise, the result is (a / r) * x0 + sigma_to * noise.
    """
    ratio = get_signal_ratio(from_timestep, to_timestep)
    noise_scale = part_noise[to_timestep] - part_noise[from_timestep] / ratio
    return latents / ratio + noise * noise_scale

def stupid_simple_step_by_noise(latents, noise, from_timestep, to_timestep):
    """Like step_by_noise, but the noise coefficient ignores the sigma schedule.

    Returns latents / r + noise * (1 - 1 / r), with r = get_signal_ratio(from_timestep, to_timestep).
    """
    ratio = get_signal_ratio(from_timestep, to_timestep)
    inverse_ratio = 1 / ratio
    return latents / ratio + noise * (1 - inverse_ratio)

def cfgpp_step_by_noise(latents, combined, base, from_timestep, to_timestep):
    """step_by_noise variant that removes `combined` noise but re-adds `base` noise.

    Returns latents / r + base * sigma_to - combined * (sigma_from / r), where
    r = get_signal_ratio(from_timestep, to_timestep) and sigma = part_noise.
    """
    ratio = get_signal_ratio(from_timestep, to_timestep)
    renoise = base * part_noise[to_timestep]
    denoise = combined * (part_noise[from_timestep] / ratio)
    return latents / ratio + renoise - denoise

def tnr_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep):
    """Step with separately weighted base and difference noise terms.

    With r = get_signal_ratio(from_timestep, to_timestep) and sigma = part_noise, returns
    latents / r + base_term * (sigma_to - sigma_from / r) + diff_term * (sigma_from / r).
    """
    ratio = get_signal_ratio(from_timestep, to_timestep)
    diff_scale = part_noise[from_timestep] / ratio
    base_scale = part_noise[to_timestep] - diff_scale
    return latents / ratio + base_term * base_scale + diff_term * diff_scale

def tnrb_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep):
    """Variant of tnr_step_by_noise in which `diff_term` is added without any scaling.

    With r = get_signal_ratio(from_timestep, to_timestep) and sigma = part_noise, returns
    latents / r + base_term * (sigma_to - sigma_from / r) + diff_term.
    """
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    base_coefficient = part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio
    return latents / signal_ratio + base_term * base_coefficient + diff_term

def shuffle_step(latents, first_noise, second_noise, timestep, intermediate_timestep):
    """Correct `latents` by the difference of two noise samples, scaled via `intermediate_timestep`.

    Returns latents + (first_noise - second_noise) * (sigma_i * r - sigma_t), where
    r = get_signal_ratio(timestep, intermediate_timestep), sigma_i = part_noise[intermediate_timestep]
    and sigma_t = part_noise[timestep].
    """
    # The previous inline branch compared undefined names (from_timestep / to_timestep) and raised
    # NameError; get_signal_ratio implements exactly the intended forward/backward branches.
    signal_ratio = get_signal_ratio(timestep, intermediate_timestep)
    return latents + (first_noise - second_noise) * (part_noise[intermediate_timestep] * signal_ratio - part_noise[timestep])

for run_id in run_ids:
    # ---- per-run setup ----
    run_context = Context()
    run_context.run_id = run_id
    # The settings cell documents that entries returned by add_run_context are added to the run
    # context; merge them (None is tolerated for hooks that return nothing).
    for key, value in (add_run_context(run_context) or {}).items():
        setattr(run_context, key, value)

    # Only conversion failures fall back to seed 0; other errors raised by the seed setting surface.
    try:
        _seed = int(seed(run_context))
    except (TypeError, ValueError, OverflowError):
        _seed = 0
        print(f"non-integer seed, run {run_id}. replaced with 0.")

    torch.manual_seed(_seed)
    np.random.seed(_seed % 2**32)  # numpy's legacy seeding only accepts 0 <= seed < 2**32

    run_context.steps = steps(run_context)

    # steps + 1 timestep boundaries between timestep_max and timestep_min
    diffusion_timesteps = linspace_timesteps(run_context.steps+1, timestep_max(run_context), timestep_min(run_context), timestep_power(run_context))

    run_prompts = prompts(run_context)

    # One noise-predictor batch entry per encoder_1 prompt.
    noise_predictor_batch_size = len(run_prompts["encoder_1"])

    (all_penult_states, enc2_pooled) = prompt_encoder.encode(run_prompts["encoder_1"], run_prompts["encoder_2"], run_prompts["encoder_2_pooled"])

    # Optionally distort each prompt embedding (done in main_dtype, stored back in the UNet dtype).
    for index in range(all_penult_states.shape[0]):
        run_context.embedding_index = index
        distortion = embedding_distortion(run_context)
        if distortion is not None:
            all_penult_states[index] = svd_distort_embeddings(all_penult_states[index].to(main_dtype), distortion).to(noise_predictor_dtype)

    width, height = width_height(run_context)

    # Values below 64 are given in units of 64 pixels.
    if width < 64:
        width *= 64
    if height < 64:
        height *= 64

    decoder_dim_scale = 2 ** 3  # pixel -> latent spatial downscale
    latent_shape = (1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale)

    latents = torch.zeros(latent_shape, device=main_device, dtype=main_dtype)

    # A single noise sample (leading axis of length 1) that is reused at every step below.
    noises = torch.randn((1, *latent_shape), device=main_device, dtype=main_dtype)

    # Noise the zero latents from the last schedule entry to the first one.
    latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[0])

    # The hook may edit latents in place; a returned tensor, if any, replaces them.
    modified_latents = modify_initial_latents(run_context, latents)
    if modified_latents is not None:
        latents = modified_latents

    original_size = (height, width)
    target_size = (height, width)
    crop_coords_top_left = (0, 0)

    # SDXL size/crop micro-conditioning ("add_time_ids"); see the SDXL paper for the details.
    add_time_ids = torch.tensor(
        [list(original_size + crop_coords_top_left + target_size)], dtype=noise_predictor_dtype
    ).repeat(noise_predictor_batch_size, 1).to(main_device)  # follow main_device, not a hard-coded "cuda"

    added_cond_kwargs = {"text_embeds": enc2_pooled.to(noise_predictor_dtype), "time_ids": add_time_ids}

    out_index = 0
    # Decoded after the loop; initialised so steps == 0 decodes the initial latents instead of
    # failing or silently reusing a previous run's sample.
    pred_original_sample = latents
    with Timer("core loop"):
        for step_index in range(run_context.steps):
            step_context = Context(run_context)
            step_context.step_index = step_index
            for key, value in (add_step_context(step_context) or {}).items():
                setattr(step_context, key, value)

            noise = noises[0]

            start_timestep = index_interpolate(diffusion_timesteps, step_index).round().int()
            end_timestep = index_interpolate(diffusion_timesteps, step_index + 1).round().int()

            # Schedule values exposed to the settings lambdas. TODO refactor this
            step_context.end_noise = part_noise[end_timestep]
            step_context.end_signal = part_signal[end_timestep]
            # start_* describe the start of the step, so they index with start_timestep.
            step_context.start_noise = part_noise[start_timestep]
            step_context.start_signal = part_signal[start_timestep]
            step_context.signal_ratio = get_signal_ratio(start_timestep, end_timestep)
            step_context.start = start_timestep
            step_context.end = end_timestep
            step_context.forward_noise_total = forward_noise_total

            def timestep_at(offset):
                """Rounded integer timestep at schedule position step_index + offset."""
                return index_interpolate(diffusion_timesteps, step_index + offset).round().int()

            def predict_noise(latents, step=0):
                """UNet noise prediction for every prompt at schedule position step_index + step."""
                return noise_predictor(
                    latents.repeat(noise_predictor_batch_size, 1, 1, 1).to(noise_predictor_dtype),
                    timestep_at(step),
                    encoder_hidden_states=all_penult_states.to(noise_predictor_dtype),
                    return_dict=False,
                    added_cond_kwargs=added_cond_kwargs
                )[0]

            def standard_predictor(combiner):
                """Derivative fn returning (predictions, true noise, combiner(predictions, noise))."""
                def _predict(latents, step=0):
                    predictions = predict_noise(latents, step)
                    return predictions, noise, combiner(predictions, noise)
                return _predict

            def constructive_predictor(combiner):
                """Like standard_predictor, but re-noises latents from timestep 0 before predicting
                and also hands the latents to the combiner."""
                def _predict(latents, step=0):
                    noised = step_by_noise(latents, noise, 0, timestep_at(step))
                    predictions = predict_noise(noised, step)
                    return predictions, noise, combiner(latents, predictions, noise)
                return _predict

            # Step functions: (latents, noises, start, end) -> new latents. `noises` is the
            # (predictions, true_noise, combined_prediction) triple from a predictor; start/end are
            # offsets added to step_index.
            def standard_diffusion_step(latents, noises, start, end):
                predictions, true_noise, combined_prediction = noises
                return step_by_noise(latents, combined_prediction, timestep_at(start), timestep_at(end))

            def stupid_simple_step(latents, noises, start, end):
                predictions, true_noise, combined_prediction = noises
                return stupid_simple_step_by_noise(latents, combined_prediction, timestep_at(start), timestep_at(end))

            def cfgpp_diffusion_step(choose_base, choose_combined):
                def _diffusion_step(latents, noises, start, end):
                    return cfgpp_step_by_noise(latents, choose_combined(noises), choose_base(noises), timestep_at(start), timestep_at(end))
                return _diffusion_step

            def tnr_diffusion_step(latents, noises, start, end):
                predictions, true_noise, combined_prediction = noises
                return tnr_step_by_noise(latents, combined_prediction, predictions[0], timestep_at(start), timestep_at(end))

            def tnrb_diffusion_step(latents, noises, start, end):
                predictions, true_noise, combined_prediction = noises
                return tnrb_step_by_noise(latents, combined_prediction, predictions[0], timestep_at(start), timestep_at(end))

            def constructive_step(latents, noises, start, end):
                predictions, true_noise, combined_prediction = noises
                return latents + combined_prediction

            def select_prediction(index):
                return lambda noises: noises[0][index]

            select_true_noise = lambda noises: noises[1]
            select_combined = lambda noises: noises[2]

            # `method` is a plain string in the default settings but a lambda elsewhere; accept both.
            selected_method = method(step_context) if callable(method) else method
            diffusion_method = selected_method.lower()

            step_functions = {
                "standard": standard_diffusion_step,
                "stupid": stupid_simple_step,
                "cfg++": cfgpp_diffusion_step(select_prediction(0), select_combined),
                "tnr": tnr_diffusion_step,
                "tnrb": tnrb_diffusion_step,
                "cons": constructive_step,
            }
            # An unmatched method used to leave take_step unassigned (NameError on a fresh kernel,
            # or a stale step function from an earlier run). Fail loudly instead.
            # NOTE(review): "custom" (the settings-cell default) has no step function here — confirm
            # which one it should use.
            if diffusion_method not in step_functions:
                raise ValueError(
                    f"unknown diffusion method {diffusion_method!r}; expected one of {sorted(step_functions)}"
                )
            take_step = step_functions[diffusion_method]

            combiner = combine_predictions(step_context)
            if diffusion_method == "cons":
                get_derivative = constructive_predictor(combiner)
            else:
                get_derivative = standard_predictor(combiner)

            solver = solver_step(step_context)

            latents = solver(get_derivative, take_step, latents)

            # Preview of the clean sample: strip the known noise down to the last timestep
            # (except on the final step or for the "cons" method, where latents are used as-is).
            if step_index < run_context.steps - 1 and diffusion_method != "cons":
                pred_original_sample = step_by_noise(latents, noise, diffusion_timesteps[step_index+1], diffusion_timesteps[-1])
            else:
                pred_original_sample = latents

            if save_raw(step_context):
                save_raw_latents(pred_original_sample)
            if save_approximates(step_context):
                save_approx_decode(pred_original_sample, out_index)
                out_index += 1

        images_pil = pilify(pred_original_sample.to(device=decoder_device), decoder)

        for im in images_pil:
            display(im)

        if save_output(run_context):
            for n, image in enumerate(images_pil):
                image.save(f"{settings_directory}/{n}_{run_id:05d}.png")


In [None]:
# # # save # # #
# Copy the images of the most recent run into <daily_directory>/<settings_id>_<run_id>/.
# NOTE(review): `images_pil` and `run_id` are whatever the run cell left behind, so only the LAST
# run in `run_ids` is saved here.
from pathlib import Path  # not imported explicitly in the imports cell

run_output_directory = Path(daily_directory) / f"{settings_id}_{run_id}"
# parents=True also creates daily_directory itself when it is missing.
run_output_directory.mkdir(exist_ok=True, parents=True)

for n, image in enumerate(images_pil):
    image.save(str(run_output_directory / f"{n}.png"))

In [None]:
# Scratch: rescale the hand-picked coefficient endpoints 0.1 and 0.01 (the same `* 30 / step_count`
# form is used in the plotting cell below) to a 1000-step schedule.
# Uses `n_steps` so the `steps` settings lambda is not overwritten; assigning `steps = 1000` made
# the next run of the run cell fail with "'int' object is not callable".
n_steps = 1000
0.1 * 30 / n_steps, 0.01 * 30 / n_steps

In [None]:
tuple(forward_noise_schedule[i].item() for i in (-1, 0))  # beta at the last and first timestep

In [None]:
tuple(forward_signal_product[i].item() for i in (-1, 0))  # alpha_bar at the last and first timestep

In [None]:
tuple(part_signal[i].item() for i in (-1, 0))  # sqrt(alpha_bar) at the last and first timestep

In [None]:
tuple(part_noise[i].item() for i in (-1, 0))  # sigma at the last and first timestep

In [None]:
get_signal_ratio(from_timestep=500, to_timestep=510)  # forward ratio over a 10-timestep span

In [None]:
# TODO:
# - plot everything; figure out the zeros formula from the physical notes
# - calculate the sum of the diff/correction term
# - refactor settings for combiner / method specification
# - ensure TNR-corrected CFG still works

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Forward-process schedule over all timesteps: beta_t, sqrt(alpha_bar_t) and sigma_t.
fig, ax = plt.subplots()
ax.plot(forward_noise_schedule, label="betas")
ax.plot(part_signal, label="signal")
ax.plot(part_noise, label="noise")
ax.set(title="Forward noise schedule", xlabel="timestep", ylabel="value")
ax.legend();

In [None]:
# Compare per-step schedule quantities against hand-picked coefficients, one point per step of an
# s_c-step schedule.
s_c = 20
s = linspace_timesteps(s_c+1, 999, 0, 1)

fig, ax = plt.subplots()
fig.set_figheight(12)

# Hand-picked coefficients: lerp between 0.1*30/s_c and 0.01*30/s_c by n / (s_c - 1).
tnr = [lerp(0.1 * 30 / s_c, 0.01 * 30 / s_c, (1 - (s_c - n - 1) / (s_c - 1))) for n in range(s_c)]
# Label fixed: the values are scaled by 0.6, not 0.5 as the old label claimed.
ax.plot([(-1 + 1 / get_signal_ratio(a,b))*0.6 for a,b in pairs(s)], label="(-(1 - s_b / s_a)) * 0.6")
ax.plot([(forward_noise_total[a] - forward_noise_total[b]) for a,b in pairs(s)], label="beta sum")
ax.plot(tnr, label="hand-picked numbers")
# s_b / s_a + (n_b - n_a * s_b / s_a) per step; kept for interactive inspection.
total = [1 / get_signal_ratio(a,b) + part_noise[b] - part_noise[a] / get_signal_ratio(a,b) for a,b in pairs(s)]

beta = forward_noise_schedule
alpha = 1 - beta
alpha_bar = alpha.cumprod(dim=0)
sqrt_alpha_bar = alpha_bar.sqrt()

# Sample a per-timestep series at every schedule point except the last one.
select = lambda series: [series[t] for t in s[:-1]]

ax.plot(select(beta*7), label="beta * 7")

# `sarbs` was only defined in a later cell, so this cell failed on Restart & Run All. Recompute it
# here with that cell's formula (it uses foo = 20, the same as s_c).
sarbs_preview = [lerp(2/s_c, 0.8/s_c, (n/(s_c-1))**0.5) for n in range(s_c)]
ax.plot(sarbs_preview, label="better hand-picked numbers")
ax.set(title="Per-step schedule quantities", xlabel="step")
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5));

In [None]:
# Scratch setup for the iterative-refinement experiment in the next cell: fixed-seed noise, an
# all-zero accumulator, and a bare noise-prediction helper.
# NOTE(review): relies on height, width, decoder_dim_scale, diffusion_timesteps, all_penult_states,
# added_cond_kwargs and noise_predictor_batch_size left over from the run cell's last iteration.
torch.manual_seed(999)
#torch.manual_seed(234235333)

latent_shape = (1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale)

latents = torch.zeros(latent_shape, device=main_device, dtype=main_dtype)
result = torch.zeros(latent_shape, device=main_device, dtype=main_dtype)

# Leading axis of length 1 mirrors the run cell's noise layout; only noises[0] is used.
noises = torch.randn((1, *latent_shape), device=main_device, dtype=main_dtype)

# Noise the zero latents from the last schedule entry to the first one.
latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[0])


def p(latents, step):
    """Run the UNet once on `latents` at timestep `step` with the run cell's prompt conditioning."""
    batched = latents.repeat(noise_predictor_batch_size, 1, 1, 1).to(noise_predictor_dtype)
    outputs = noise_predictor(
        batched,
        step,
        encoder_hidden_states=all_penult_states.to(noise_predictor_dtype),
        return_dict=False,
        added_cond_kwargs=added_cond_kwargs,
    )
    return outputs[0]


def sho(a):
    """imshow a[0] after flattening its first two axes together and transposing the result."""
    return plt.imshow(a[0].flatten(0, 1).t().cpu())


def means(tensors):
    """Return [(std, mean), ...] as Python floats, one pair per tensor in `tensors`."""
    return [(std.item(), mean.item()) for std, mean in (torch.std_mean(t) for t in tensors)]



In [None]:


# Experiment: build an image by repeatedly re-noising the accumulated `result` (from timestep 0) to a
# decreasing timestep n, predicting the noise there, and adding the residual
# (true noise - prediction) scaled by 2 * (sum of betas between n and the previous n).
n = 999
n_prev = 999
s_size = 50  # timestep decrement per iteration
foo = 20     # maximum number of iterations (also the length of `sarbs`)

# Earlier hand-tuned per-iteration scales; not used by the update below, kept for experimentation.
arbitrary_numbers = [25,10,8,2,0.5,0.4,0.3,0.3,0.2,0.1]
arbitrary_numbers = [a/2 for a in arbitrary_numbers]
sarbs = [lerp(2/foo,0.8/foo,(n/(foo-1))**0.5) for n in range(foo)]

show_intermediate = False  # set True to decode and display `result` after every iteration

# Fresh zeros; `result *= 0` would keep any NaN/inf left over from a previous run of this cell.
result = torch.zeros_like(result)

bsum = forward_noise_schedule.cumsum(0)  # cumulative beta (same values as forward_noise_total)

for _ in range(foo):
    n_prev = n
    n = max(n - s_size, 0)

    prediction = p(step_by_noise(result, noises[0], 0, n), n)

    diff = noises[0] - prediction
    result = result + diff * 2 * (bsum[n_prev] - bsum[n])

    if show_intermediate:
        for im in pilify(result.to(device=decoder_device), decoder):
            display(im)

    if n == 0:
        break

images_pil = pilify(result.to(device=decoder_device), decoder)

for im in images_pil:
    display(im)
    

In [None]:
fig, ax = plt.subplots()
ax.plot(sarbs)
ax.set(title="sarbs (hand-picked per-iteration scales)", xlabel="iteration", ylabel="scale");