In [None]:
# Smoke test: confirms the kernel is executing cells.
print('HELO')

In [None]:
from icecream import ic
import numpy as np
import rp
import torch
import torch.nn as nn
import source.stable_diffusion as sd
from easydict import EasyDict
from source.learnable_textures import (LearnableImageFourier,
                                       LearnableImageFourierBilateral,
                                       LearnableImageRaster,
                                       LearnableImageRasterBilateral,
                                       LearnableTexturePackFourier,
                                       LearnableTexturePackRaster)

In [None]:
# Build the Stable Diffusion wrapper only once per kernel; re-running this cell
# reuses the existing `s` instead of loading the weights again.
if 's' not in globals():
    model_name = "CompVis/stable-diffusion-v1-4"
    gpu = 'cuda:0'
    s = sd.StableDiffusion(gpu, model_name)
device = s.device

In [None]:
class BaseLabel:
    """A named conditioning target: a human-readable ``name`` plus the ``embedding`` fed to the model.

    Later on embeddings might get more sophisticated (e.g. averaging several prompts), and
    labels might carry visualization colors or relations to other labels.
    """

    def __init__(self, name:str, embedding:torch.Tensor):
        self.name = name
        self.embedding = embedding

    def get_sample_image(self):
        """Render one image from this label's embedding using the global model ``s``.

        Runs without gradient tracking and asserts the result is an image (``rp.is_image``).
        """
        with torch.no_grad():
            sample = s.embeddings_to_imgs(self.embedding)[0]
        assert rp.is_image(sample)
        return sample

    def __repr__(self):
        return '%s(name=%s)' % (self.__class__.__name__, self.name)
        
class SimpleLabel(BaseLabel):
    """Label whose embedding is the text embedding of its own ``name``, moved to ``device``."""

    def __init__(self, name:str):
        text_embedding = s.get_text_embeddings(name)
        super().__init__(name, text_embedding.to(device))

class NegativeLabel(BaseLabel):
    """Label whose text embedding is computed with a custom negative (unconditional) prompt.

    ``name`` may carry its own negative part after a ``'---'`` separator, e.g.
    ``'a cat --- blurry, low quality'``; that part is appended to ``negative_prompt``.

    Args:
        name: the positive prompt, optionally followed by ``'---'`` and extra negative text.
        negative_prompt: base negative prompt used while computing the embedding.
    """

    def __init__(self, name:str, negative_prompt=''):
        if '---' in name:
            name, additional_negative_prompt = name.split('---', maxsplit=1)
            # Trim whitespace around the separator so it neither leaks into self.name
            # nor produces a leading space when the base negative prompt is empty.
            name = name.strip()
            additional_negative_prompt = additional_negative_prompt.strip()
            negative_prompt = ' '.join(
                part for part in (negative_prompt, additional_negative_prompt) if part
            )

        self.negative_prompt = negative_prompt
        # s.uncond_text is shared state on the global model wrapper (presumably read by
        # get_text_embeddings for the unconditional branch); swap it only for this call
        # and always restore it, even if embedding fails.
        old_uncond_text = s.uncond_text
        try:
            s.uncond_text = negative_prompt
            embedding = s.get_text_embeddings(name).to(device)
            super().__init__(name, embedding)
        finally:
            s.uncond_text = old_uncond_text

In [None]:
#ONLY GOOD PROMPTS HERE
# Curated prompt library, addressable by attribute (good_prompts.mario) or key.
# The prompt text is used verbatim as model input, typos included - do not "fix" it.
good_prompts = EasyDict({
    'kitten_in_box': 'A orange cute kitten in a cardboard box in times square',
    'botw_landscape': 'The Legend of Zelda landscape atmospheric, hyper realistic, 8k, epic composition, cinematic, octane render, artstation landscape vista photography by Carr Clifton & Galen Rowell, 16K resolution, Landscape veduta photo by Dustin Lefevre & tdraw, 8k resolution, detailed landscape painting by Ivan Shishkin, DeviantArt, Flickr, rendered in Enscape, Miyazaki, Nausicaa Ghibli, Breath of The Wild, 4k detailed post processing, artstation, rendering by octane, unreal engine —ar 16:9',
    'magic_emma_watson': 'ultra realistic photo portrait of Emma Watson cosmic energy, colorful, painting burst, beautiful symmetrical face, nonchalant kind look, realistic round eyes, tone mapped, intricate, elegant, highly detailed, digital painting, artstation, concept art, smooth, sharp focus, illustration, dreamy magical atmosphere, art by artgerm and greg rutkowski and alphonse mucha, 4k, 8k',
    'yorkshire_terrier_santa': 'Insanely detailed studio portrait shot photo of intricately detailed beautiful yorkshire terrier dressed as santa clause, smirking mischievously at the camera with mischievous detailed yellow green eyes , very detailed, rim light, photo, rim light, ultra-realistic, photorealistic, hyper detailed, photography, shot on Canon DSLR, f/2. 8 , photography by Felix Kunze and Annie Leibovitz and retouched by Pratik Naik',
    'norwegian_winter_girl': 'professional portrait photograph of a gorgeous Norwegian girl in winter clothing with long wavy blonde hair, freckles, gorgeous symmetrical face, cute natural makeup, wearing elegant warm winter fashion clothing, ((standing outside in snowy city street)), mid shot, central image composition, (((professionally color graded))), (((bright soft diffused light)))',
    'magic_forest_temple': '8 k concept art from a hindu temple lost in the jungle by david mattingly and samuel araya and michael whelan and dave mckean and richard corben. realistic matte painting with photorealistic hdr volumetric lighting. composition and layout inspired by gregory crewdson. ',
    'sailing_ship': 'a big sailing ship in heavy sea, hypermaximalistic, high details, cinematic, 8k resolution, beautiful detailed, insanely intricate details, artstation trending, octane render, unreal engine',
    'bioshock_lighthouse': 'giant standalone lighthouse from bioshock infinite in england 1 9 century, half - ruined, covered by mold, staying in 2 kilometers far from a coast, opposite the dark cave - crack of giant rocks. when you see this lighthouse it makes you anxious. deep ones is living under this. view from sea, and view from the coast, by greg rutkowski',
    'two_bunnys_hugging': 'photo of bunny hugging another bunny, dramatic light, pale sunrise, cinematic lighting',
    'thomas_tank_military': 'thomas the tank engine as a military tank, intricate, highly detailed, centered, digital painting, artstation, concept art, smooth, sharp focus, illustration, artgerm, tomasz alen kopera, peter mohrbacher, donato giancola, joseph christian leyendecker, wlop, boris vallejo',
    'wolf_on_rock': 'a wolf with a tail, standing heroically on a rock. adventurous, new adventure, with a tail, forest, rocks, stream, ripples, tribal armor, female, wolf wolf wolf, atmospheric lighting, stunning, brave. by makoto shinkai, stanley artgerm lau, wlop, rossdraws, james jean, andrei riabovitchev, marc simonetti, krenz cushart, sakimichan, d & d trending on artstation, digital art. ',
    'lolita_dress_girl': 'lolita dress, angelic pretty, award winning photograph trending on artstation',
    'lolita_dress_magical_elf': 'lolita dress, angelic pretty, portrait of magical lolita woman elf elven,  hyperrealism photography hdr 4k 3d, dreamy and ethereal, fantasy, intricate, elegant, many rainbow bubbles, rose tones, highly detailed, artstation, concept art, cyberpunk wearing, smooth, sharp focus, illustration, art by artgerm and greg rutkowskiand alphonse mucha',
    'pencil_giraffe_head': 'an intricate detailed hb pencil sketch of a giraffe head',
    'pencil_penguin': 'an intricate detailed hb pencil sketch of a penguin',
    'pencil_violin': 'an intricate detailed hb pencil sketch of a violin',
    'pencil_orca_whale': 'an orca whale spouting water intricate detailed hb pencil sketch of an black white spotted orca whale',
    'pencil_cow': 'an intricate detailed hb pencil sketch of a black white spotted cow',
    'pencil_walrus': 'an intricate detailed hb pencil sketch of a walrus',
    'pencil_cat_head': 'an sketch of a cat head',
    'ape_with_gun': 'detailed science - fiction character portrait of a silverback gorilla shooting a alien gun in space, intricate, wild, highly detailed, digital painting, artstation, concept art, smooth, sharp focus, illustration, art by artgerm and greg rutkowski and alphonse mucha',
    'human_skeleton': 'weta disney pixar movie still macro close photo of a skeleton with triopan cones for hands. his hands are triopan cones. : : by weta, greg rutkowski, wlop, ilya kuvshinov, rossdraws, artgerm, octane render, iridescent, bright morning, anime, liosh, mucha : :',
    'gold_coins': 'an old wooden table covered in gold coins and treasure, detailed oil painting, trending on Artstation',
    'golf_ball_in_forest': 'photo of a golf ball in a magical forest. dof. Bokeh. By greg rutkowski. Nikon D850. Award winning',
    'bear_in_forest': 'photo of a brown bear attacking the camera. Nikon D850. Award winning. Scary teeth claws full body shot cinematic movie',
    'elephant_in_circus': 'photo of a elephant in a magical circus. dof. Bokeh. By greg rutkowski. Nikon D850. Award winning.',
    'mickey_mouse': 'mickey mouse oil on canvas, artstation trending',
    'mushroom': 'a mushroom in a magical forest. dof. Bokeh. By greg rutkowski. Nikon D850. Award winning',
    'mario': 'mario 3d nintendo video game',
    'burger': 'big juicy hamburger with cheese and tomato and lettuice. Sesame seed bun. Advertisement beautiful dlsr hdr bokeh. ',
    'darth_vader': 'photo of a ultra realistic darth vader dramatic light, muscle, cinematic lighting, battered, low angle, static, 4k, hyper realistic, focused, extreme details, bokeh blackground, cinematic, masterpiece, intricate artwork, details,',
    'gandalf': 'Gandalf the Grey Wizard in Moonlight by Alan Lee, Glowing staff, full body concept art, intricate clothing, micro detail, octane render, 4K, art station',
    'fantasy_city': 'an ultra detailed matte painting of the quaint capital city of galic, grid shaped city cobblestone streets, fantasy city, light snowfall, wind, inspiring renaissance architecture, ultrawide lense, aerial photography, unreal engine, exquisite detail, 8 k, art by greg rutkowski and alphonse mucha',
    'green_elf_girl': 'a highly detailed portrait painting of a beautiful healer elf female male, long brown hair with braids and green highlights, long elf ears, asian decent, facial tribal markings, by greg rutkowski and alphonse mucha, sharp focus, matte, concept art, artstation, digital painting',
    'pikachu': 'Manga cover illustration of an extremely cute and adorable beautiful pikachu running through a flower field, summer vibrance, 3d render diorama by hayao miyazaki, official studio ghibli still, color graflex macro photograph, pixiv, daz studio 3d',
    'spring': 'Photographic spring season, artstation trending Nikon D850. Award winning. A bee on a flower.',
    'fall': 'Photographic fall season, artstation trending Nikon D850. Award winning. Giant Orange maple leaf. Leaf pile. Pumpkins. Halloween.',
    'winter': 'Photographic winter season, artstation trending Nikon D850. Award winning. Snow and mountains. Snowman and log cabin. Snowflakes. Ice. Icicles. Cold.',
    'summer': 'Photographic summer season, artstation trending Nikon D850. Award winning. Sun. Hot. Beach. Desert sand. Picnic. Umbrella from sun. ',
    'miku': 'Hatsune miku, gorgeous, amazing, elegant, intricate, highly detailed, digital painting, artstation, concept art, sharp focus, illustration, art by ross tran',
    'pyramids': 'An anthropomorphic beautiful great futuristic pyramid civilisation in a desert, gold, sphinx, dungeon temple, fine art, award winning, intricate, elegant, sharp focus, octane render, hyperrealistic, cinematic lighting, highly detailed, digital painting, 8 k concept art, art by jamie hewlett and z. w. gu, masterpiece, trending on artstation, 8 k',
    'dinosaur': 'A t - rex in star wars, movie still frame, hd, remastered, cinematic lighting',
    'lipstick': 'lipstick. product photo. glamour photography. 2 0 1 8. ',
})

In [None]:
# Sanity check: render one sample image for a single prompt.
green_elf_label = NegativeLabel(good_prompts.green_elf_girl)
rp.display_image(green_elf_label.get_sample_image())

In [None]:
#ONLY GOOD PROMPTS HERE
# One prompt per rotation slot. Only the active (uncommented) pick per slot takes effect;
# earlier picks are kept as comments for quick swapping.
# prompt_w = good_prompts.thomas_tank_military
prompt_w = good_prompts.human_skeleton

# prompt_y = good_prompts.kitten_in_box
# prompt_y = good_prompts.darth_vader
prompt_y = good_prompts.mickey_mouse

prompt_x = good_prompts.pikachu

# prompt_z = good_prompts.mario
prompt_z = good_prompts.norwegian_winter_girl


# prompt_a = good_prompts.gandalf

In [None]:
#ONLY GOOD PROMPTS HERE
# Four-seasons set: each rotation slot gets one season.
prompt_w, prompt_y, prompt_x, prompt_z = (
    good_prompts.summer,
    good_prompts.fall,
    good_prompts.winter,
    good_prompts.spring,
)

In [None]:

# Pick the four prompts (w, x, y, z in order) by key from the library.
selected_prompt_names = ['miku', 'pyramids', 'dinosaur', 'lipstick']
prompt_w, prompt_x, prompt_y, prompt_z = rp.gather(good_prompts, selected_prompt_names)

In [None]:
# Shared negative prompt for all four labels (empty = none).
negative_prompt = ''
label_w = NegativeLabel(prompt_w, negative_prompt=negative_prompt)
label_x = NegativeLabel(prompt_x, negative_prompt=negative_prompt)
label_y = NegativeLabel(prompt_y, negative_prompt=negative_prompt)
label_z = NegativeLabel(prompt_z, negative_prompt=negative_prompt)

In [None]:
print("Factors")
# Show one model sample per label, in w, x, y, z order.
for factor_label in (label_w, label_x, label_y, label_z):
    rp.display_image(factor_label.get_sample_image())

In [None]:
#Parameters (this section takes vram)

#Select Learnable Image Type:
def learnable_image_maker():
    """Create a fresh learnable image on the diffusion model's device."""
    return LearnableImageFourier().to(s.device)
# Alternative configuration:
# def learnable_image_maker():
#     return LearnableImageFourier(height=512,width=512,num_features=512,hidden_dim=512,scale=20).to(s.device)

# The two optimized factors: a fixed base and a layer that gets rotated in 90-degree steps.
factor_base = learnable_image_maker()
factor_rotator = learnable_image_maker()

In [None]:
brightness=3

CLEAN_MODE=False
def simulate_overlay(bottom, top):
    if CLEAN_MODE:
        exp=1
        brightness=3
        black=0
    else:
        exp=rp.random_float(.5,1)
        brightness=rp.random_float(1,5)
        black=rp.random_float(0,.5)
        bottom=rp.blend(bottom,black,rp.random_float())
        top=rp.blend(top,black,rp.random_float())
    return (bottom**exp * top**exp * brightness).clamp(0,99).tanh()

def _make_rotated_overlay(k):
    """Return a zero-arg callable rendering the base overlaid with the rotator turned k quarter-turns.

    factor_base, factor_rotator and simulate_overlay are looked up at call time (as the
    original lambdas did), so redefining them in another cell is picked up.
    """
    def render():
        # rot90 over dims [1, 2]: assumes images are (C, H, W) - TODO confirm.
        return simulate_overlay(factor_base(), factor_rotator().rot90(k=k, dims=[1, 2]))
    return render

# One composite image per rotation slot: w=0, x=1, y=2, z=3 quarter-turns.
learnable_image_w = _make_rotated_overlay(0)
learnable_image_x = _make_rotated_overlay(1)
learnable_image_y = _make_rotated_overlay(2)
learnable_image_z = _make_rotated_overlay(3)

from itertools import chain

# Materialize as a list: the optimizer consumes any iterable it is given, so a bare
# chain() iterator would leave the global `params` silently empty for any later use.
params=list(chain(
    factor_base.parameters(),
    factor_rotator.parameters(),
))
optim=torch.optim.SGD(params,lr=1e-4)

In [None]:
num = 4  # total number of factor slots
# Indices of the slots to train; alternatives: [0,2,3], [2], [0,1,2], [1], [0,1]
nums = [0, 1, 2, 3]

all_labels = (label_w, label_x, label_y, label_z)
all_learnable_images = (learnable_image_w, learnable_image_x, learnable_image_y, learnable_image_z)
all_weights = (1, 1, 3, 1)  # relative per-slot weights

labels = [all_labels[i] for i in nums]
learnable_images = [all_learnable_images[i] for i in nums]
weights = rp.as_numpy_array([all_weights[i] for i in nums])

# Rescale so the selected weights sum to their count, i.e. average to 1.
weights = weights / weights.sum()
weights = weights * len(weights)

In [None]:
# Recorded preview frames: `pims` for decoded predictions, `ims` for tiled learnable images.
pims, ims = [], []

In [None]:
CLEAN_MODE = True  # disable the random overlay augmentation in simulate_overlay

In [None]:
NUM_ITER=100000
s.max_step=MAX_STEP=990
# s.min_step=MIN_STEP=450
s.min_step=MIN_STEP=10

# When True, linearly anneal the timestep from MAX_STEP down to MIN_STEP over the run
# by pinning s.min_step == s.max_step each iteration. Off by default.
ANNEAL_TIMESTEPS=False

PREVIEW_EVERY=200           # iterations between tiled previews (appended to `ims`)
CLEAR_OUTPUT_EVERY=200*100  # iterations between clearing this cell's output
SHOW_PRED_PREVIEWS=False    # also decode and show `preds` (appended to `pims`)

# Imported once up front rather than inside the training loop.
from IPython.display import clear_output

et=rp.eta(NUM_ITER)

# folder='sd_previewer_results2/'+prompt[:100]+rp.random_namespace_hash()
# rp.make_folder(folder)

for iter_num in range(NUM_ITER):
    if ANNEAL_TIMESTEPS:
        step = rp.blend(MAX_STEP,MIN_STEP,iter_num/NUM_ITER)
        s.min_step = s.max_step = int(step)

    et(iter_num)

    # Batch size 1: presumably one random (label, image, weight) triple per iteration.
    preds=[]
    for label,learnable_image,weight in rp.random_batch(list(zip(labels,learnable_images,weights)),1):
        pred=s.train_step(
            label.embedding,
            learnable_image()[None],  # prepend a size-1 axis (presumably the batch dim)

            #PRESETS (uncomment one):
            noise_coef=.1*weight,guidance_scale=60,#10
            # noise_coef=0,image_coef=-.01,guidance_scale=50,
            # noise_coef=0,image_coef=-.005,guidance_scale=50,
            # noise_coef=.1,image_coef=-.010,guidance_scale=50,
            # noise_coef=.1,image_coef=-.005,guidance_scale=50,
            # noise_coef=.1*weight, image_coef=-.005*weight, guidance_scale=50,
        )
        preds+=list(pred)

    with torch.no_grad():
        if not iter_num%CLEAR_OUTPUT_EVERY:
            clear_output()
        if not iter_num%PREVIEW_EVERY:
            # The composites first, then the two raw factors; `length` presumably sets
            # how many tiles go in each row - TODO confirm rp.tiled_images semantics.
            im=rp.tiled_images(
                [
                    *[rp.as_numpy_image(image()) for image in learnable_images],
                    rp.as_numpy_image(factor_base()),
                    rp.as_numpy_image(factor_rotator()),
                ],
                length=len(learnable_images),
                border_thickness=0,
            )
            ims.append(im)
            # rp.save_image(im,folder+'/%06i.png'%iter_num)
            rp.display_image(im)

        if SHOW_PRED_PREVIEWS and not iter_num%PREVIEW_EVERY:
            pim=rp.tiled_images([
                *rp.as_numpy_images(s.decode_latents(torch.stack(preds))),
            ])
            pims.append(pim)
            rp.display_image(pim)

    # Apply the gradients left on the factors by s.train_step (assumed to backprop
    # internally - TODO confirm), then reset them for the next iteration.
    optim.step()
    optim.zero_grad()

In [None]:
def save_run(name, ims=None, pims=None, root="untracked/rotator_multiplier_runs"):
    """Save a run's recorded preview frames as numbered PNG files.

    Frames go to ``<root>/<name>``. If that folder already exists, a ``_<unix seconds>``
    suffix is appended, plus a further ``_<counter>`` if that also exists, so an earlier
    run is never overwritten.

    Args:
        name: run name; becomes the folder name.
        ims: tiled training previews, saved as ``ims_0000.png``, ...; defaults to the
            notebook-global ``ims``.
        pims: prediction previews, saved as ``pims_0000.png``, ...; defaults to the
            notebook-global ``pims``.
        root: parent directory for all runs.

    Returns:
        The folder the frames were written to.
    """
    import time

    if ims is None:
        ims = globals()['ims']
    if pims is None:
        pims = globals()['pims']

    folder = "%s/%s" % (root, name)
    if rp.path_exists(folder):
        # A timestamp alone collides when called repeatedly within the same second.
        stamped = folder + '_%i' % time.time()
        folder, counter = stamped, 1
        while rp.path_exists(folder):
            folder = '%s_%i' % (stamped, counter)
            counter += 1
    rp.make_directory(folder)

    pims_names = ['pims_%04i.png' % i for i in range(len(pims))]
    ims_names = ['ims_%04i.png' % i for i in range(len(ims))]
    with rp.SetCurrentDirectoryTemporarily(folder):
        rp.save_images(pims, pims_names, show_progress=True)
        rp.save_images(ims, ims_names, show_progress=True)
    return folder


save_run('dino_miku_lipstick_pyramid')

In [None]:
# Tile and show the predictions from the most recent training iteration.
# NOTE(review): `preds` hold whatever s.train_step returned; the (disabled) preview in the
# training loop runs them through s.decode_latents first, so these may be latents rather
# than RGB images - TODO confirm this raw view is intended.
pred_arrays = rp.as_numpy_array(rp.as_numpy_images(preds))
im = rp.tiled_images([*pred_arrays])
rp.display_image(im)

In [None]:
# Shape of the last iteration's predictions stacked along a new leading axis.
torch.stack(preds, dim=0).shape

In [None]:
# Release unused cached GPU memory held by PyTorch's allocator (live tensors are untouched).
torch.cuda.empty_cache()

In [None]:
# Show the most recently recorded preview frame.
latest_frame = ims[-1]
rp.display_image(latest_frame)

In [None]:
# Browse all recorded preview frames in order.
rp.display_image_slideshow(ims)

In [None]:
import os

# Encode every recorded preview frame into an MP4 under videos/ with a random file name.
# Create the folder first: nothing in this notebook makes it, and it is not clear from
# here whether rp.save_video_mp4 creates missing parent directories (TODO confirm).
os.makedirs('videos', exist_ok=True)
out = rp.save_video_mp4(ims, 'videos/%s.mp4' % rp.random_namespace_hash())
print(out)