# Notebook Core Params

In [None]:
num_split = 0   # Which notebook core is this? 0-indexed; selects this notebook's share of the dataset
num_splits = 3  # How many notebooks are running, each processing its own split of the dataset?
device_num = 0  # Which GPU (CUDA device index) this notebook uses

# Fail fast on a misconfigured core instead of silently running on a wrong or empty split
if not 0 <= num_split < num_splits:
    raise ValueError(
        'num_split must satisfy 0 <= num_split < num_splits, got num_split=%r, num_splits=%r'
        % (num_split, num_splits)
    )

# Dataset

In [None]:
import rp
from icecream import ic
from easydict import EasyDict
import copy

In [None]:
class TestSample:
    """
    One dataset entry: an image plus its text prompt, file path and name.

    The path is stored in absolute form (converted with rp.get_absolute_path).
    """

    def __init__(self, image, prompt:str, path:str, name:str):
        self.path = rp.get_absolute_path(path)
        self.image = image
        self.prompt = prompt
        self.name = name

    def copy(self):
        """Return a shallow copy: attributes such as the image object are shared, not duplicated."""
        return copy.copy(self)

    def display(self):
        """Print the prompt and the path on separate lines, then show the image."""
        print(self.prompt, self.path, sep='\n')
        rp.display_image(self.image)

In [None]:
class CroppedCocoDataset:
    """
    A list-like collection of TestSample objects from the cropped RefCOCO export.

    When ``samples`` is None, samples are loaded from ``dataset_dir``. Its file
    ``cropped.txt`` lists one sample per line as ``<image file name> <prompt>``
    (the prompt may contain spaces), and each image file name is resolved relative
    to ``dataset_dir``. When ``samples`` is given (e.g. when slicing), it is used as is
    and nothing is loaded.

    Args:
        samples: Optional pre-built list of samples.
        prompt_replacements: Not used by this class; accepted for backward compatibility.
        dataset_dir: Directory holding ``cropped.txt`` and the image files.
    """

    #Xiang made this dataset
    DEFAULT_DATASET_DIR = '/nfs/ws1/datasets/RefCOCO'

    def __init__(self, samples=None, prompt_replacements=None, dataset_dir=DEFAULT_DATASET_DIR):
        self.dataset_dir = dataset_dir
        if samples is None:
            samples = self._load_samples(dataset_dir)
        self.samples = samples

    @staticmethod
    def _load_samples(dataset_dir):
        """Read cropped.txt in dataset_dir and load every listed image as a TestSample."""
        with rp.SetCurrentDirectoryTemporarily(dataset_dir):
            lines = rp.text_file_to_string('cropped.txt').strip().splitlines()
            # Echo the first few entries as a quick check of the file format
            print(rp.line_join(lines[:5]))

            image_names, prompts = [], []
            for line in lines:
                fields = line.split(maxsplit=1)
                if not fields:
                    continue  # tolerate blank lines
                if len(fields) != 2:
                    raise ValueError('Expected "<image name> <prompt>" in cropped.txt, got %r' % (line,))
                image_names.append(fields[0])
                prompts.append(fields[1])

            # Resolved while the working directory is dataset_dir, so names are relative to it
            image_paths = [rp.get_absolute_path(name) for name in image_names]
            images = rp.load_images(image_paths, use_cache=True, show_progress=True)

            return [
                TestSample(image, prompt, path, name)
                for image, prompt, path, name in zip(images, prompts, image_paths, image_names)
            ]

    @property
    def images(self):
        return [s.image for s in self.samples]

    @property
    def prompts(self):
        return [s.prompt for s in self.samples]

    @property
    def names(self):
        return [s.name for s in self.samples]

    @property
    def image_paths(self):
        """Absolute image path of every sample (stored on each sample as ``path``)."""
        return [s.path for s in self.samples]

    def __getitem__(self, index):
        """Return the sample at an integer index, or a new dataset of the same type for a slice."""
        selected = self.samples[index]
        if isinstance(index, slice):
            subset = type(self)(selected)
            subset.dataset_dir = self.dataset_dir
            return subset
        return selected

    def split(self, num_divisions, division_index):
        """
        Split the dataset into num_divisions contiguous parts and return part division_index.

        Every part has len(self) // num_divisions samples, except the last one, which
        also takes the remainder. If num_divisions > len(self), all parts but the last
        are empty.

        num_divisions: Number of parts to divide the dataset into (>= 1).
        division_index: 0-indexed part to return (0 <= division_index < num_divisions).

        Raises ValueError when either argument is out of range.
        """
        if num_divisions < 1:
            raise ValueError('num_divisions must be >= 1, got %r' % (num_divisions,))
        if not 0 <= division_index < num_divisions:
            raise ValueError(
                'division_index must be in [0, %r), got %r' % (num_divisions, division_index)
            )
        division_size = len(self) // num_divisions
        start_index = division_size * division_index
        if division_index == num_divisions - 1:
            end_index = len(self)  # the last part absorbs the remainder
        else:
            end_index = start_index + division_size
        return self[start_index:end_index]

    def __len__(self):
        return len(self.samples)

    def __iter__(self):
        return iter(self.samples)

    def __repr__(self):
        return f'{type(self).__name__}(len={len(self)})'

In [None]:
# data=cropped_ref_voc_dataset.split(num_splits,num_split)

In [None]:
# Load the full dataset, then keep only this notebook core's share of it
cropped_coco_dataset = CroppedCocoDataset()
data = cropped_coco_dataset.split(num_divisions=num_splits, division_index=num_split)

In [None]:
{*cropped_coco_dataset.prompts}  # distinct prompts across the whole dataset

In [None]:
data[slice(3, 4)]  # one-sample slice of this split (a dataset, shown via its repr)

In [None]:
# Show six randomly chosen samples from this split (prompt, path, then image)
for _ in range(6):
    rp.random_element(data).display()

# Peekaboo

In [None]:
import source.stable_diffusion as sd
# Build the model only once per kernel: re-running this cell keeps the existing `s`
if 's' not in globals():
    s = sd.StableDiffusion(device=device_num)  # Initialize the singleton

In [None]:
import rp
import source.peekaboo as peekaboo
from source.peekaboo import run_peekaboo
from source.clip import get_clip_logits
from source.stable_diffusion_labels import get_mean_embedding, BaseLabel, SimpleLabel, MeanLabel
import torch
torch.cuda.set_device(device=device_num)  # make this notebook's GPU the default CUDA device

In [None]:
# Mario vs Luigi Part 1
# NOTE(review): results_collection is not referenced by any cell visible here;
# possibly left over from an earlier experiment.
results_collection = list()

In [None]:
def get_clip_logits_per_image(prompt, images):
    """Return the CLIP logit of ``prompt`` for each image in ``images``, in order."""
    per_image = []
    for picture in images:
        # get_clip_logits is called with a one-prompt list; keep that prompt's logit
        per_image.append(get_clip_logits(picture, [prompt])[0])
    return per_image

def random_colors(length=100):
    """Return a list of ``length`` independent colors from rp.random_rgb_float_color()."""
    palette = []
    for _ in range(length):
        palette.append(rp.random_rgb_float_color())
    return palette

def get_score(foreground, alpha, prompt:str, colors:list):
    """
    Score a segmentation with CLIP by compositing the masked foreground over solid colors.

    The mask is binarized at 0.5. The foreground is resized to the mask's dimensions,
    and a preview (foreground over one random color) is displayed. The mask is then
    dilated (rp.cv_dilate, size 5, circular), the foreground is composited over every
    color in ``colors``, and the CLIP logits of ``prompt`` are averaged across those
    composites.

    Args:
        foreground: The image that was segmented; may differ in size from ``alpha``.
        alpha: The predicted mask.
        prompt: Text the composites are scored against.
        colors: Background colors; reuse the same list when comparing several results.

    Returns:
        rp.mean of the per-composite CLIP logits.
    """
    alpha = alpha > .5
    # Bring the image to the mask's resolution once, so the preview and the scored
    # composites are both built from pixel-aligned inputs.
    foreground = rp.cv_resize_image(foreground, rp.get_image_dimensions(alpha))
    rp.display_image(rp.blend_images(rp.random_rgb_float_color(), foreground, alpha))
    # Grow the mask slightly (presumably to keep a margin around the object's border)
    alpha = rp.cv_dilate(alpha, 5, circular=True)

    assert rp.is_image(foreground) and rp.is_image(alpha)
    # Same argument order as the preview: the color underneath, the foreground on top
    # where the mask is set. Reversing it would paint the color over the object.
    images = [rp.blend_images(color, foreground, alpha) for color in colors]
    scores = get_clip_logits_per_image(prompt, images)
    score = rp.mean(scores)
    return score

In [None]:
def ranked_results(results):
    """
    Score every result with get_score and return (scores, results), highest score first.

    All results are scored against one shared set of random background colors so the
    scores are comparable. Each score is printed as it is computed.
    """
    shared_colors = random_colors()

    def score_one(result):
        value = get_score(
            foreground=result.image,
            alpha=result.alphas[0],
            prompt=result.p_name,
            colors=shared_colors,
        )
        print(value)
        return value

    scores = [score_one(result) for result in results]
    ascending_scores, ascending_results = rp.sync_sort(scores, results)

    # Flip the sync_sort order so that the first entry is the best (highest) one
    return ascending_scores[::-1], ascending_results[::-1]

def display_ranked_results(scores, results):
    """Display each result as [image resized to 256x256 | mask], labeled '<prompt> : <score>'."""
    for rank_score, result in zip(scores, results):
        preview = rp.cv_resize_image(result.image, (256, 256))
        side_by_side = rp.horizontally_concatenated_images(preview, result.alphas[0])
        caption = "%s : %f" % (result.p_name, rank_score)
        rp.display_image(rp.labeled_image(side_by_side, caption))

In [None]:
thumbnails=[]  # every thumbnail produced by get_result_thumbnail, in creation order

def get_result_thumbnail(results, out_dir='thumbnails_data8'):
    """
    Render, save, display and return a summary thumbnail for one peekaboo run.

    The thumbnail is [image | mask | text panel]. The image and mask are resized to
    512x512, and the text panel is fitted within 512x512. The text lists the prompt,
    the loss method(s), the run's output folder and its settings. The image is saved
    as '<method> ____ <prompt> ____ <counter>.png' in ``out_dir``, where counter is the
    number of files already in that directory. It is also appended to the module-level
    ``thumbnails`` list.

    Args:
        results: Output of run_peekaboo; read through both attribute and item access.
        out_dir: Directory the PNG is written to (passed to rp.make_directory first).

    Returns:
        The thumbnail image.
    """
    height = width = 512
    image = rp.cv_resize_image(results.image, (height, width))
    # Nearest-neighbour keeps the mask's values crisp instead of interpolating them
    alpha = rp.cv_resize_image(results.alphas[0], (height, width), interp='nearest')

    method_parts = []
    if results.clip_coef:
        method_parts.append("CLIP")
    if results.use_stable_dream_loss:
        method_parts.append("StableDreamLoss")
    method = " + ".join(method_parts)

    path = rp.get_relative_path(results.output_folder)
    name = results.p_name

    # Hyper-parameters shown in the text panel, each listed once, in display order
    setting_names = [
        'representation',
        'LEARNING_RATE',
        'NUM_ITER',
        'GRAVITY',
        'clip_coef',
        'use_stable_dream_loss',
        'GUIDANCE_SCALE',
        'min_step',
        'max_step',
    ]
    settings = [x + ': ' + str(results[x]) for x in setting_names]

    text = rp.line_join([name, ' ', method, ' ', path, '', 'Settings:', *settings])
    text = rp.wrap_string_to_width(text, 60)
    text = rp.cv_text_to_image(text, monospace=False)
    text = rp.resize_image_to_fit(text, height, width)
    image = rp.horizontally_concatenated_images(image, alpha, text)

    rp.make_directory(out_dir)
    # A '/' in the prompt would otherwise be treated as a directory separator in the file name
    file_safe_name = str(name).replace('/', '_')
    out_name = '%s ____ %s ____ %i.png' % (method, file_safe_name, len(rp.get_all_files(out_dir)))
    out_path = rp.path_join(out_dir, out_name)
    rp.save_image(image, out_path)

    thumbnails.append(image)
    rp.display_image(image)
    return image

In [None]:
# One entry per sample in this split: [[prompt], [image path], sample].
# Prompt and path are wrapped in lists because the driver loop iterates over several of each.
experiment_input_data = [
    [[sample.prompt], [sample.path], sample]
    for sample in data.samples
]

In [None]:
# Named hyper-parameter presets; every preset is run on every input.
# The key becomes part of the run name, and the value is passed to run_peekaboo as keyword arguments.
experiment_setting_presets = dict(
    clip_raster_bilateral=dict(
        representation="raster bilateral",
        LEARNING_RATE=1.0,
        NUM_ITER=100,
        GRAVITY=0.05,
        clip_coef=500,
        use_stable_dream_loss=False,
    ),
    raster_bilateral=dict(  # This one is also good!
        representation="raster bilateral",
        LEARNING_RATE=1.0,
        GUIDANCE_SCALE=200,
        NUM_ITER=100,
        GRAVITY=0.05,
        # min_step=10,
        # max_step=600,
    ),
    # Disabled presets below; each needs a name (name=dict(...)) before it can be re-enabled.
    # dict(
    #     representation="raster bilateral",
    #     LEARNING_RATE=1e-0,
    #     GUIDANCE_SCALE=200,
    #     # NUM_ITER=500,
    #     GRAVITY=0.05,
    #     min_step=200,
    #     max_step=400,
    # ),
    # dict(
    #     representation="raster bilateral",
    #     LEARNING_RATE=1e-0,
    #     # NUM_ITER=500,
    #     GRAVITY=0.05,
    #     clip_coef=500,
    #     use_stable_dream_loss=True,
    #     GUIDANCE_SCALE=200,
    #     min_step=10,
    #     max_step=600,
    # ),
    # dict(NUM_ITER=500),
    # pure_fourier=dict(representation="fourier"),
    # default=dict(),
    # pure_raster=dict(representation="raster", LEARNING_RATE=1, GUIDANCE_SCALE=200, GRAVITY=.1),
)

In [None]:
# experiment_setting_presets=[dict()]

# Flatten the inputs into jobs: one [prompt, image path, sample] per combination.
# The inputs are shuffled once here, so every pass below visits the jobs in the same order.
things_to_try = [
    [prompt, url, sample]
    for prompts, urls, sample in rp.shuffled(experiment_input_data)
    for prompt in prompts
    for url in urls
]

if not things_to_try or not experiment_setting_presets:
    # The loop below never ends on its own; with nothing to run it would spin forever
    # without output (e.g. when this notebook's split is empty), so report it instead.
    rp.fansi_print(
        'No experiments to run for split %i of %i (%i jobs, %i presets)'
        % (num_split, num_splits, len(things_to_try), len(experiment_setting_presets)),
        'yellow',
        'bold',
    )
else:
    i = 0  # experiment counter across all passes
    # Rerun every job with every preset, pass after pass, until the kernel is interrupted
    while True:
        # things_to_try=rp.shuffled(things_to_try)
        for prompt, url, sample in things_to_try:
            sample_name = sample.name
            for preset_name, preset in experiment_setting_presets.items():
                i += 1
                rp.fansi_print('EXPERIMENT NUMBER: %i' % i, 'green', 'bold')
                rp.fansi_print(preset_name, 'green', 'bold')
                rp.fansi_print(sample_name, 'green', 'bold')

                rp.ic(prompt, url, preset)

                results = run_peekaboo(
                    name=sample_name + '.' + prompt + '.' + preset_name,
                    label=SimpleLabel(prompt),
                    image=url,
                    **preset,
                    output_folder_name='peekaboo_results_coco_nocrop',
                )

                get_result_thumbnail(results)