In [None]:
from icecream import ic
import numpy as np
import rp
import torch
import torch.nn as nn
import source.stable_diffusion as sd
from easydict import EasyDict
from source.learnable_textures import LearnableImageFourier
from source.stable_diffusion_labels import NegativeLabel
from itertools import chain
import time

In [None]:
#ONLY GOOD PROMPTS HERE
example_prompts = rp.load_yaml_file('source/example_prompts.yaml')

#Each prompt is a plain string - swap in any text you like. By default the four are looked up by name in the example prompts file.
prompt_w, prompt_x, prompt_y, prompt_z = rp.gather(
    example_prompts,
    'miku froggo lipstick pyramids'.split(),
)

negative_prompt = ''

#Summarise the configuration so the chosen prompts are visible in the cell output
print('Example prompts:', ', '.join(example_prompts))
print()
print('Negative prompt:',repr(negative_prompt))
print()
print('Chosen prompts:')
for _prompt_caption, _prompt_text in [
    ('    prompt_w =', prompt_w),
    ('    prompt_x =', prompt_x),
    ('    prompt_y =', prompt_y),
    ('    prompt_z =', prompt_z),
]:
    print(_prompt_caption, _prompt_text)

In [None]:
#Load Stable Diffusion only once per kernel: re-running this cell reuses an already-loaded `s`
if 's' not in globals():
    gpu='cuda:0'
    model_name="CompVis/stable-diffusion-v1-4"
    s=sd.StableDiffusion(gpu,model_name)
device=s.device

In [None]:
#One NegativeLabel per prompt; all four share the same negative prompt
label_w, label_x, label_y, label_z = [
    NegativeLabel(prompt, negative_prompt)
    for prompt in (prompt_w, prompt_x, prompt_y, prompt_z)
]

In [None]:
#Parameters (this section takes vram)

#Select Learnable Image Size (this has big VRAM implications!):
SIZE=256
def learnable_image_maker():
    """Build a fresh 256x256 LearnableImageFourier placed on the diffusion model's device."""
    return LearnableImageFourier(height=256, width=256, hidden_dim=256, num_features=128).to(s.device)
#Higher-resolution alternative - uncomment to replace the factory above (it also sets SIZE=512):
# learnable_image_maker = lambda: LearnableImageFourier(height=512,width=512,num_features=256,hidden_dim=256,scale=20).to(s.device);SIZE=512

#The two independently learned images: the bottom layer and the layer that gets rotated on top of it
factor_base, factor_rotator = learnable_image_maker(), learnable_image_maker()

In [None]:
brightness=3 # Gain applied to the overlay product when CLEAN_MODE is True (edit this to change it)

CLEAN_MODE = True # If it's False, we augment the images by randomly simulating how good a random printer might be when making the overlays...

def simulate_overlay(bottom, top):
    """
    Combine two image tensors as if `top` were laid over `bottom`, multiplicatively.

    Computes tanh(clamp(bottom**exp * top**exp * gain, 0, 99)).

    In CLEAN_MODE, exp=1 and gain is the module-level `brightness`, with no black-level blending.
    Otherwise, each call draws a fresh random exp, gain and black level from rp.random_float:
    - presumably exp is uniform in [.5,1], gain in [1,5] and black in [0,.5] (TODO confirm rp semantics);
    - each input is blended toward `black` by its own random amount, as a crude augmentation
      for imperfect printing.

    Args:
        bottom: image tensor for the lower layer (presumably values in [0,1] - TODO confirm).
        top: image tensor for the upper layer, broadcast-compatible with `bottom`.

    Returns:
        Tensor of the broadcast shape. tanh of a value clamped to [0, 99] lies in [0, 1].
    """
    if CLEAN_MODE:
        #Read the global knob; the original shadowed it with a local constant, so editing `brightness` did nothing
        exp, gain, black = 1, brightness, 0
    else:
        exp=rp.random_float(.5,1)
        gain=rp.random_float(1,5)
        black=rp.random_float(0,.5)
        bottom=rp.blend(bottom,black,rp.random_float())
        top=rp.blend(top,black,rp.random_float())
    return (bottom**exp * top**exp * gain).clamp(0,99).tanh()

def _rotated_overlay(quarter_turns):
    """Return a zero-argument callable that renders bottom overlaid with top rotated by quarter_turns*90 degrees."""
    # rot90 over dims [1,2] - presumably the H,W axes of a CHW image (TODO confirm LearnableImageFourier output layout)
    return lambda: simulate_overlay(factor_base(), factor_rotator().rot90(k=quarter_turns,dims=[1,2]))

#w, x, y, z are the same two factors with the top one turned 0, 1, 2 and 3 quarter turns
learnable_image_w, learnable_image_x, learnable_image_y, learnable_image_z = [
    _rotated_overlay(quarter_turns) for quarter_turns in range(4)
]


#Only the two factor images are trainable; all four views share their parameters
params=chain(*(factor.parameters() for factor in (factor_base, factor_rotator)))
optim=torch.optim.SGD(params,lr=1e-4)

In [None]:
num=4 # NOTE(review): not read anywhere in the visible notebook; kept in case another cell uses it
nums=[0,1,2,3]

#Uncommenting one of the lines will disable some of the prompts, in case you don't want to use all four for some reason (like the Summer/Winter example)
# nums=[0  ,2,3]
# nums=[    2  ]
# nums=[0,1,2]
# nums=[1]
# nums=[0,1]
# nums=[0,2]


labels=[label_w,label_x,label_y,label_z]
learnable_images=[learnable_image_w,learnable_image_x,learnable_image_y,learnable_image_z]

#The weight coefficients for each prompt. For example, if we have [0,1,2,1], then prompt_w will provide no influence and prompt_y will have 1/2 the total influence
weights=[1,1,1,1]

#Keep only the prompts selected by `nums`, keeping labels, images and weights aligned
labels=[labels[i] for i in nums]
learnable_images=[learnable_images[i] for i in nums]
weights=[weights[i] for i in nums]

weights=rp.as_numpy_array(weights)
#A zero total (all-zero weights or an empty `nums`) would make every weight nan, and a negative total would silently
#flip every sign - either would corrupt the noise_coef passed to train_step, so fail loudly instead
assert weights.sum()>0, 'The weights selected by nums must have a positive sum, but got %s'%(list(weights),)
#Rescale so the weights average to 1: equal weights all stay at 1, and only the relative proportions matter
weights=weights/weights.sum()
weights=weights*len(weights)

In [None]:
#Timelapse buffer: every image shown during training is appended here so save_run can write it out later
ims=list()

In [None]:
@torch.no_grad() #Display-only: never build an autograd graph, whether called inside or outside the training loop
def get_display_image():
    """
    Render a single tiled preview of the current optimisation state.

    The tiles are every active overlay view (one per entry of `learnable_images`),
    followed by the raw bottom factor and the raw top factor.

    `length=len(learnable_images)` is presumably the number of tiles per row (TODO confirm rp.tiled_images).
    With that reading, the overlays form the first row and the two factors the next.
    If CLEAN_MODE is False, the overlay tiles include that call's random printer augmentation.

    Returns:
        The tiled image produced by rp.tiled_images.
    """
    frames=[rp.as_numpy_image(image()) for image in learnable_images]
    frames.append(rp.as_numpy_image(factor_base()))
    frames.append(rp.as_numpy_image(factor_rotator()))
    return rp.tiled_images(
        frames,
        length=len(learnable_images),
        border_thickness=0,
    )

In [None]:
NUM_ITER=10000

#Set the minimum and maximum noise timesteps for the dream loss (aka score distillation loss)
s.max_step=MAX_STEP=990
s.min_step=MIN_STEP=10

display_eta=rp.eta(NUM_ITER, title='Status: ')

DISPLAY_INTERVAL = 200

#The %i placeholder must actually be filled in, otherwise the literal text "%i" is printed
print('Every %i iterations we display an image in the form [[image_w, image_x, image_y, image_z], [bottom_image, top_image]] where'%DISPLAY_INTERVAL)
print('    image_w = bottom_image * top_image')
print('    image_x = bottom_image * top_image.rot90()')
print('    image_y = bottom_image * top_image.rot180()')
print('    image_z = bottom_image * top_image.rot270()')
print()
print('Interrupt the kernel at any time to return the currently displayed image')

from IPython.display import clear_output #Imported once here rather than on every wipe inside the loop

iter_num=0 #Defined up front so the interrupt handler can always report it, even if interrupted before the first iteration
try:
    for iter_num in range(NUM_ITER):
        display_eta(iter_num) #Print the remaining time

        preds=[]
        #Each iteration optimises against a single randomly chosen (label, view, weight) triple
        for label,learnable_image,weight in rp.random_batch(list(zip(labels,learnable_images,weights)),1):
            pred=s.train_step(
                label.embedding,
                learnable_image()[None], #Add a leading batch dimension

                #PRESETS (uncomment one):
                noise_coef=.1*weight,guidance_scale=60,#10
                # noise_coef=0,image_coef=-.01,guidance_scale=50,
                # noise_coef=0,image_coef=-.005,guidance_scale=50,
                # noise_coef=.1,image_coef=-.010,guidance_scale=50,
                # noise_coef=.1,image_coef=-.005,guidance_scale=50,
                # noise_coef=.1*weight, image_coef=-.005*weight, guidance_scale=50,
            )
            preds+=list(pred)

        with torch.no_grad():
            if iter_num and not iter_num%(DISPLAY_INTERVAL*50):
                #Wipe the slate every 50 displays so they don't get cut off
                clear_output()

            if not iter_num%DISPLAY_INTERVAL:
                im = get_display_image()
                ims.append(im)
                rp.display_image(im)

        optim.step()
        optim.zero_grad()
except KeyboardInterrupt:
    #If train_step had already accumulated gradients for the interrupted step, discard them
    #so that re-running this cell does not apply stale gradients
    optim.zero_grad()
    print()
    print('Interrupted early at iteration %i'%iter_num)
    with torch.no_grad():
        im = get_display_image()
    ims.append(im)
    rp.display_image(im)

In [None]:
#Show each learned factor on its own, without any rotation or overlay
for _factor_caption, _factor_image in [('Bottom image:', factor_base), ('Top image:', factor_rotator)]:
    print(_factor_caption)
    rp.display_image(rp.as_numpy_image(_factor_image()))

In [None]:
def save_run(name):
    """
    Write every frame in the global `ims` timelapse into untracked/rotator_multiplier_runs/<name>.

    If that folder already exists, '_<integer unix time>' is appended to the folder name so an
    earlier run is not overwritten. Frames are written as ims_0000.png, ims_0001.png, and so on,
    in the order they were collected. The chosen folder is printed at the end.
    """
    destination = "untracked/rotator_multiplier_runs/%s" % name
    if rp.path_exists(destination):
        destination = destination + '_%i' % time.time()
    rp.make_directory(destination)

    frame_filenames = ['ims_%04i.png' % index for index, _ in enumerate(ims)]
    with rp.SetCurrentDirectoryTemporarily(destination):
        rp.save_images(ims, frame_filenames, show_progress=True)

    print()
    print('Saved timelapse to folder:', repr(destination))

save_run('untitled') #You can give it a good custom name if you want!