TODO: Set up an exclusive alpha map, so that the model has to choose between alphas when multiple prompts are used on a single image. For example, a region labeled "Jean Luc Picard" cannot also be a region labeled "Emma Watson". If we want, we can later relax or experiment with those constraints, but let's keep it simple first. Eventually we want to build a map of the whole image with all classes this way.

To do that, add a loss term penalizing overlapping alpha (e.g. the dot product between the offending alpha masks).

If the total alpha dips below some threshold (or some other statistic does), it means the segmenter failed to find the object.

To combine masks, we can have another optimization that allows weighted averages of all previously found masks until we get one we like.

TODO: Average the results across multiple runs. What can we do with that? 

TODO: Add some kind of shape prior, such as superpixels. Instance segmentation can partly provide this, and so can bilateral blurs.

TODO: Multiple classes - use Stable Diffusion prompt subtraction.

TODO: Get mean prompt across texts

Ways to think about it:
- What can we chip away while keeping the given prompt?

In [None]:
import rp
import nerf.sd as sd
import numpy as np
import torch
import torch.nn as nn
from ryan.source.learnable_textures import LearnableTexturePackRaster,LearnableTexturePackFourier
from ryan.source.learnable_textures import LearnableImageRaster,LearnableImageFourier, LearnableImageRasterBilateral, LearnableImageFourierBilateral
import icecream
from IPython.display import clear_output
from easydict import EasyDict
from ryan.bilateral_blur import BilateralProxyBlur
import timm
from torchvision.transforms.functional import normalize
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.vision_transformer import vit_base_patch16_224_dino

In [None]:
#Build the Stable Diffusion wrapper only once per kernel session; re-running this cell reuses the existing `s`
if 's' not in globals():
    s=sd.StableDiffusion('cuda:2',"CompVis/stable-diffusion-v1-4")
    # s=sd.StableDiffusion('cuda:0',"/raid/xiangli/Codes/VOC-model/car")

#Tensors created in later cells are moved to the same device as the diffusion model
device=s.device

In [None]:
#SETTINGS
def make_learnable_image(height, width, num_channels, foreground=None):
    """
    Build the learnable image used to parametrize the alpha masks.

    The active parametrization is a LearnableImageFourierBilateral: a Fourier-feature
    image passed through a BilateralProxyBlur that is guided by ``foreground``.
    The blur settings come from the global ``bilateral_kwargs`` (hyperparameter cell).

    height, width: not used by the active parametrization; only the commented-out
        LearnableImageFourier alternative below takes them.
    num_channels: number of output channels (one per label in PeekabooSegmenter).
    foreground: the guide image handed to BilateralProxyBlur.
    """
    #Here we determine our image parametrization schema.
    #To try a different one, replace the return statement with one of:
    #   return LearnableImageRasterBilateral(bilateral_blur,num_channels) #A neural neural image
    #   return LearnableImageFourier(height,width,num_channels) #A neural neural image
    bilateral_blur = BilateralProxyBlur(foreground,**bilateral_kwargs)
    return LearnableImageFourierBilateral(bilateral_blur,num_channels) #A neural neural image

In [None]:
def blend_torch_images(foreground, background, alpha):
    """
    Alpha-composite ``foreground`` over ``background``.

    foreground, background: tensors of identical shape (C, H, W)
    alpha: an (H, W) matrix broadcast across every channel; 1 keeps the
        foreground pixel and 0 keeps the background pixel
    Returns foreground*alpha + background*(1-alpha), with shape (C, H, W).
    """
    assert foreground.shape==background.shape
    _, height, width = foreground.shape
    assert alpha.shape==(height,width), 'alpha is a matrix'

    inverse_alpha = 1-alpha
    return foreground*alpha + background*inverse_alpha

class PeekabooSegmenter(nn.Module):
    """
    One image plus one learnable alpha mask per label.

    Calling the module composites the image over a solid background color once per
    label, using that label's mask, and returns a (num_labels, 3, height, width) batch.
    """

    def __init__(self, image:np.ndarray, labels:list, size:int=256, name:str='Untitled'):
        super().__init__()

        #Every label must be a BaseLabel, and there must be at least one of them
        assert all(issubclass(type(label),BaseLabel) for label in labels)
        assert len(labels), 'Must have at least one class to segment'

        #Square images only for now
        self.height=self.width=size
        self.labels=labels
        self.name=name

        assert rp.is_image(image), 'Input should be a numpy image'
        #Resize, then force a 3-channel HWC float image with values in [0, 1]
        image=rp.as_float_image(rp.as_rgb_image(rp.cv_resize_image(image,(size,size))))
        assert image.shape==(size,size,3) and image.min()>=0 and image.max()<=1
        self.image=image

        #Torch copy of the image in CHW form, on the global device
        self.foreground=rp.as_torch_image(image).to(device)
        assert self.foreground.shape==(3, size, size)

        #Solid-color background, black to start with; recolored in place by set_background_color
        self.background=torch.zeros_like(self.foreground)

        #One learnable alpha channel per label, guided by the foreground image
        self.alphas=make_learnable_image(size,size,num_channels=self.num_labels,foreground=self.foreground)

    @property
    def num_labels(self):
        """How many labels (and therefore alpha channels) this segmenter has."""
        return len(self.labels)

    def set_background_color(self, color):
        """Fill the whole background with one (r, g, b) color; each component must lie in [0, 1]."""
        r,g,b = color
        assert 0<=r<=1 and 0<=g<=1 and 0<=b<=1
        for channel, value in enumerate((r,g,b)):
            self.background[channel]=value

    def randomize_background(self):
        """Switch the background to a random RGB color from rp.random_rgb_float_color()."""
        self.set_background_color(rp.random_rgb_float_color())

    def forward(self, alphas=None, return_alphas=False):
        """
        Composite the image over the background once per label.

        alphas: optional (num_labels, height, width) tensor with values in [0, 1];
            defaults to the current learnable alphas.
        return_alphas: if True, return (images, alphas) instead of just images.
        """
        alphas = self.alphas() if alphas is None else alphas

        assert alphas.shape==(self.num_labels, self.height, self.width)
        assert alphas.min()>=0 and alphas.max()<=1

        output_images=torch.stack([
            blend_torch_images(foreground=self.foreground, background=self.background, alpha=alpha)
            for alpha in alphas
        ])

        assert output_images.shape==(self.num_labels, 3, self.height, self.width) #In BCHW form

        return (output_images, alphas) if return_alphas else output_images

In [None]:
def display(self):
    """
    Show a diagnostic panel for a PeekabooSegmenter and return it as a numpy image.

    Left side: the segmenter's image, captioned with its name, its size and whichever
    hyperparameter globals are currently defined. Right side: a grid holding each label's
    alpha mask (labeled by name), followed by the composites of the image over every
    background color in `colors`.

    The function is attached to PeekabooSegmenter below, so the display can be changed by
    re-running this cell without redefining the class. The segmenter's background color is
    restored before returning, so calling this does not change the segmenter's state.
    """
    colors = [(0,0,0), (1,1,1)] #Background colors used for the composites; more can be added, e.g. (1,0,0), (0,1,0), (0,0,1)

    with torch.no_grad(): #Purely visual, so there is no need to build an autograd graph
        #Evaluate the alpha masks once and reuse them for every composite, so the masks and
        #composites shown always come from the same forward pass
        alpha_tensor = self.alphas()
        alphas = rp.as_numpy_array(alpha_tensor)
        assert alphas.shape==(self.num_labels, self.height, self.width)

        #set_background_color writes into self.background in place; save it so it can be restored
        saved_background = self.background.clone()
        composites = []
        try:
            for color in colors:
                self.set_background_color(color)
                column=rp.as_numpy_images(self(alpha_tensor))
                composites.append(column)
        finally:
            self.background.copy_(saved_background)

    label_names=[label.name for label in self.labels]

    stats_lines = [
        self.name,
        '',
        'H,W = %ix%i'%(self.height,self.width),
    ]

    def try_add_stat(stat_format, var_name):
        #Hyperparameters live in notebook globals; only list the ones that are currently defined
        if var_name in globals():
            stats_lines.append(stat_format%globals()[var_name])

    try_add_stat('Gravity: %.2e','GRAVITY'   )
    try_add_stat('Batch Size: %i','BATCH_SIZE')
    try_add_stat('Iter: %i','iter_num')
    try_add_stat('Init Iters: %i','INIT_ITERS')
    try_add_stat('Image Name: %s','image_filename')
    try_add_stat('Learning Rate: %.2e','LEARNING_RATE')
    try_add_stat('Guidance: %i%%','GUIDANCE_SCALE')

    stats_image=rp.labeled_image(self.image, rp.line_join(stats_lines), 
                                 size=15*len(stats_lines), 
                                 position='bottom', align='center')

    composite_grid=rp.grid_concatenated_images([
        rp.labeled_images(alphas,label_names),
        *composites
    ])

    output_image = rp.horizontally_concatenated_images(stats_image, composite_grid)

    rp.display_image(output_image)

    return output_image

PeekabooSegmenter.display=display

In [None]:
def get_mean_embedding(prompts:list):
    """
    Average the text embeddings of several prompts into a single embedding.

    Each prompt is encoded with s.get_text_embeds. The encodings are stacked along a new
    leading dimension and averaged over that dimension, and the result is moved to `device`.
    """
    embeddings = [s.get_text_embeds(prompt) for prompt in prompts]
    stacked = torch.stack(embeddings)
    return stacked.mean(dim=0).to(device)

class BaseLabel:
    """
    A named thing to segment, represented by a conditioning embedding.

    name: human-readable name, used for display and in repr()
    embedding: the embedding passed to the diffusion model (s.train_step / s.embed_to_img)
    """

    def __init__(self, name:str, embedding:torch.Tensor):
        #Subclasses build richer embeddings (e.g. averaged prompts). Colors for visualization
        #or relations between labels could be attached here later.
        self.name, self.embedding = name, embedding

    def get_sample_image(self):
        """Generate one image from this label's embedding, as a sanity check of what it depicts."""
        sample=s.embed_to_img(self.embedding)[0]
        assert rp.is_image(sample)
        return sample

    def __repr__(self):
        return '%s(name=%s)'%(self.__class__.__name__, self.name)
        
class SimpleLabel(BaseLabel):
    """A label whose name doubles as its prompt; the embedding is that prompt's text embedding on `device`."""
    def __init__(self, name:str):
        embedding=s.get_text_embeds(name).to(device)
        super().__init__(name, embedding)

class MeanLabel(BaseLabel):
    """
    A label whose embedding is the mean of several prompts' text embeddings.

    The prompts are normalized with rp.detuple (presumably so they can be passed either as
    separate arguments or as one list -- confirm) and averaged with get_mean_embedding.
    """
    #Test: rp.display_image(rp.horizontally_concatenated_images(MeanLabel('Dogcat','dog','cat').get_sample_image() for _ in range(1)))
    def __init__(self, name:str, *prompts):
        super().__init__(name, get_mean_embedding(rp.detuple(prompts)))
    

In [None]:
def log_cell(cell_title):
    """Print a cyan, underlined "<Cell: ...>" section header, then call rp.ptoc() (presumably reporting elapsed time)."""
    header="<Cell: %s>"%cell_title
    rp.fansi_print(header, 'cyan', 'underlined')
    rp.ptoc()
def log(x):
    """Print str(x) in yellow."""
    rp.fansi_print(str(x), 'yellow')

In [None]:
class PeekabooResults(EasyDict):
    """
    Container for everything a Peekaboo run produces.

    It behaves like a dict, but entries can also be read and written as attributes
    (results.thing instead of results['thing']).
    """

def _is_image_stack(value):
    """True if value is a non-empty ndarray whose first element is an image (e.g. timelapse frames)."""
    #Indexing value[0] on an empty or 0-d array would raise IndexError, so check the size first
    return isinstance(value, np.ndarray) and value.ndim>0 and len(value)>0 and rp.is_image(value[0])


def save_peekaboo_results(results,new_folder_path):
    """
    Save a PeekabooResults (or any mapping) into a brand-new folder.

    Each entry is written according to its type:
      - a single image          -> <key>.png
      - a non-empty image stack -> <key>/<index>.png
      - any other ndarray       -> <key>.npy (this includes empty arrays)
      - anything else           -> collected into params.json, stored as-is when it is
                                   JSON-serializable and as str(value) otherwise

    Raises AssertionError if new_folder_path already exists, so earlier results are never overwritten.
    """
    import json

    assert not rp.folder_exists(new_folder_path), 'Please use a different name, not %s'%new_folder_path
    rp.make_folder(new_folder_path)
    with rp.SetCurrentDirectoryTemporarily(new_folder_path):
        log("Saving PeekabooResults to "+new_folder_path)
        params={}
        for key, value in results.items():
            if rp.is_image(value):
                #Save a single image
                rp.save_image(value,key+'.png')
            elif _is_image_stack(value):
                #Save a folder of images
                rp.make_directory(key)
                with rp.SetCurrentDirectoryTemporarily(key):
                    for index, frame in enumerate(value):
                        rp.save_image(frame,str(index)+'.png')
            elif isinstance(value, np.ndarray):
                #Save a generic numpy array
                np.save(key+'.npy',value)
            else:
                try:
                    json.dumps({key:value}) #Probe whether the value is JSON-serializable
                    params[key]=value
                except (TypeError, ValueError, RecursionError):
                    #Best effort: keep a readable string form of anything json can't encode
                    params[key]=str(value)
        rp.save_json(params,'params.json',pretty=True)
        log("Done saving PeekabooResults to "+new_folder_path+"!")
                    
                

In [None]:
_dino_model_cache=None #Lazily-loaded DINO model, shared by all run_peekaboo calls


def _get_dino_model():
    """Return the pretrained DINO ViT-B/16 model, loading it only on the first call."""
    #A `"dino_model" not in dir()` check inside a function only sees that function's own
    #locals, so it can never find an already-loaded model; a module-level cache is used instead.
    global _dino_model_cache
    if _dino_model_cache is None:
        _dino_model_cache=vit_base_patch16_224_dino(True).eval() #Used for inference only
    return _dino_model_cache


def _compute_dino_map(image, contrast=4):
    """
    Estimate a rough foreground map for `image` from DINO features, and display a preview of it.

    The 14x14 spatial tokens are scored by their dot product with the classification token.
    The scores are rescaled with rp.full_range, inverted, resized to the image size,
    contrast-stretched around 0.5 by `contrast`, and clipped to [0, 1].

    Returns (dino_map, preview_image):
        dino_map: float tensor of shape (1, H, W) on `device`
        preview_image: numpy image showing the map next to a masked/outlined copy of `image`
    """
    dino_model=_get_dino_model()

    norm_image = rp.as_torch_image(rp.as_rgb_image(rp.as_float_image(rp.cv_resize_image(image, (224, 224)))))
    norm_image = normalize(norm_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    sample = norm_image[None]
    assert sample.shape==(1, 3, 224, 224)

    TOKEN_NUMBER = 0  # 0 is classification token, all others are spatial

    with torch.no_grad(): #Inference only; don't track gradients through DINO
        out = dino_model.forward_features(sample)
        vis = out[0, 1:].reshape(14, 14, -1)
        vis = vis @ out[0, TOKEN_NUMBER]
    vis = rp.full_range(vis)

    dino_map = 1 - rp.cv_resize_image(
        rp.as_numpy_array(vis), rp.get_image_dimensions(image)
    )

    dino_map = ((dino_map-1/2)*contrast)+1/2
    dino_map = np.clip(dino_map, 0, 1)

    preview_image=rp.horizontally_concatenated_images(
        dino_map,
        rp.blend_images(
            rp.blend_images(image, 0, (dino_map < 0.5) * 0.8),
            (0, 1, 0),
            rp.cv_dilate(rp.auto_canny(dino_map < 0.5),diameter=2),
        ),
    )
    rp.display_image(preview_image)

    assert dino_map.shape==rp.get_image_dimensions(image)

    dino_map = torch.Tensor(dino_map[None,:,:]).to(device) #Make it a torch image with 1 channel

    return dino_map, preview_image


def run_peekaboo(label, image, name=None):
    """
    Segment `label` in `image` with Peekaboo and return a PeekabooResults.

    A PeekabooSegmenter with one learnable alpha mask is optimized with SGD. Each training
    iteration composites the image over BATCH_SIZE random solid backgrounds and calls
    s.train_step with the label's embedding on every composite. It also backpropagates a
    "gravity" term, GRAVITY * sum(alphas), which drives the mask toward zero. When
    INIT_ITERS > 0, the mask is first fit to a DINO-based map for INIT_ITERS steps.

    Hyperparameters are read from notebook globals: GRAVITY, BATCH_SIZE, NUM_ITER,
    GUIDANCE_SCALE, LEARNING_RATE, INIT_ITERS and bilateral_kwargs. The global iter_num is
    updated as training progresses, and PeekabooSegmenter.display shows it.

    label: the BaseLabel to segment
    image: a numpy image; it is converted to float RGB
    name: title shown in the displays; defaults to label.name

    Pressing Ctrl+C during training (KeyboardInterrupt) stops early and returns the results so far.
    """
    global GRAVITY, BATCH_SIZE, NUM_ITER, GUIDANCE_SCALE, LEARNING_RATE, INIT_ITERS, bilateral_kwargs, iter_num

    assert rp.is_image(image)
    assert issubclass(type(label),BaseLabel)
    if name is None:
        name=label.name
    image=rp.as_rgb_image(rp.as_float_image(image))
    rp.tic()
    time_started=rp.get_current_date()


    log_cell('Get Hyperparameters') ########################################################################
    icecream.ic(GRAVITY, BATCH_SIZE, NUM_ITER, GUIDANCE_SCALE, LEARNING_RATE, INIT_ITERS, bilateral_kwargs)


    log_cell('Alpha Initializer') ########################################################################
    p=PeekabooSegmenter(image,labels=[label], name=name).to(device)

    blur_image=rp.as_numpy_image(p.alphas.bilateral_blur(p.foreground))
    rp.display_image(blur_image)

    p.display()


    log_cell('Create Optimizers') ########################################################################
    optim=torch.optim.SGD(list(p.parameters()),lr=LEARNING_RATE)


    log_cell('Create DINO Map') ########################################################################
    dino_map, dino_preview_image=_compute_dino_map(p.image)
    icecream.ic(dino_map.shape)


    log_cell('DINO Initialization Pretraining') ########################################################################
    #Pre-train the alphas to match the DINO map
    if INIT_ITERS:
        log("Initializing with DINO please wait...")

        initial_alphas=dino_map
        initial_losses=[]

        display_eta=rp.eta(INIT_ITERS)
        for init_iter in range(INIT_ITERS):
            display_eta(init_iter)

            alphas=p.alphas()
            loss = ((alphas - initial_alphas)**2).sum()
            initial_losses.append(float(loss))
            loss.backward()
            optim.step()
            optim.zero_grad()

        rp.line_graph_via_bokeh(initial_losses, title='Initialization Losses',xlabel='Iter',ylabel='Loss')
    else:
        log("Skipping DINO Initialization because INIT_ITERS=0")


    log_cell('Create Logs') ########################################################################
    iter_num=0
    timelapse_frames=[]


    log_cell('Do Training') ########################################################################
    preview_interval=max(1,NUM_ITER//10) #Show 10 preview images throughout training to prevent output from being truncated
    log("Will show preview images every %i iterations"%(preview_interval))

    try:
        display_eta=rp.eta(NUM_ITER)
        for step in range(NUM_ITER):
            display_eta(step)
            iter_num+=1

            alphas=p.alphas()

            for _ in range(BATCH_SIZE):
                p.randomize_background()
                composites=p()
                #Separate loop variable so the `label` parameter is not rebound
                for composite_label, composite in zip(p.labels, composites):
                    s.train_step(composite_label.embedding, composite[None], 
                                 guidance_scale=GUIDANCE_SCALE
                                )

            ((alphas.sum())*GRAVITY).backward()

            optim.step()
            optim.zero_grad()

            with torch.no_grad():
                if not step%preview_interval: 
                    timelapse_frames.append(p.display())
                    rp.ptoc()
    except KeyboardInterrupt:
        log("Interrupted early, returning current results...")

    #Re-evaluate the alphas after the final optimizer step. The `alphas` computed inside the loop
    #predate the last step (so they lag the final preview by one update), and they don't exist
    #at all if no iteration ran.
    with torch.no_grad():
        alphas=p.alphas()

    rp.ptoc()
    return PeekabooResults(
        #The main output is the alphas
        alphas=rp.as_numpy_array(alphas),

        #Keep track of hyperparameters used
        GRAVITY=GRAVITY,
        BATCH_SIZE=BATCH_SIZE,
        NUM_ITER=NUM_ITER,
        GUIDANCE_SCALE=GUIDANCE_SCALE,
        LEARNING_RATE=LEARNING_RATE,
        INIT_ITERS=INIT_ITERS,
        bilateral_kwargs=bilateral_kwargs,

        #Keep track of the inputs used
        label=label,
        image=image,

        #Record some extra info
        preview_image=p.display(),
        timelapse_frames=rp.as_numpy_array(timelapse_frames),
        dino_preview_image=dino_preview_image,
        blur_image=blur_image,
        height=p.height,
        width=p.width,
        p_name=p.name,

        git_hash=rp.get_current_git_hash(), 
        time_started=rp.r._format_datetime(time_started),
        time_completed=rp.r._format_datetime(rp.get_current_date()),
        device=device,
        computer_name=rp.get_computer_name(),
    )
    

In [None]:
def visualize_gradient(label, image):
    """
    Entry point for a gradient-visualization variant of the Peekaboo pipeline.

    It currently runs exactly the same pipeline as run_peekaboo: same hyperparameter
    globals, same training loop, same returned PeekabooResults. To avoid a ~200-line
    duplicate that has to be patched in two places, it simply delegates. Fork it into
    its own implementation once the gradient-visualization logic actually diverges.
    """
    return run_peekaboo(label, image)
    

In [None]:
#Xiang's Dataloader

import numpy as np
from PIL import Image
import os

#Pascal VOC class id -> name. Some names were swapped for more prompt-friendly wording;
#the alternatives that were tried are kept in comments.
voc_label_map = {
    0: 'background',
    1: 'aeroplane',
    2: 'bicycle',
    3: 'bird',
    4: 'boat',
    5: 'bottle',
    6: 'bus',
    7: 'car',
    8: 'cat',
    9: 'chair',
    10: 'cow',
    11: 'dining table',
    12: 'dog',
    13: 'horse',
    14: 'motorcycle',
    # 14: 'motorbike',
    # 15: 'person',
    15: 'man',
    16: 'potted plant',
    17: 'sheep',
    18: 'sofa',
    19: 'train',
    # 20: 'television monitor',
    20: 'monitor',
}

#Prompt for each VOC class id, e.g. 1 -> 'a aeroplane'
voc_prompts = {class_id: 'a '+class_name for class_id, class_name in voc_label_map.items()}


class SegDataset:
    """Base class for segmentation datasets (currently only a marker type)."""


class CroppedPascalDataset(SegDataset):
    """
    Pascal VOC segmentation samples whose images have been pre-cropped.

    Expected layout under `root`:
        ImageSets/Segmentation/<split>.txt   one sample id per line
        CroppedImages/cropped-<id>.jpg       the cropped image
        CroppedImages/cropped-<id>.jpg.txt   crop box on its first line, as "x1 x2 y1 y2"
        SegmentationClass/<id>.png           the full-size class mask, cropped on load
    Ids listed in the split file that have no cropped image are skipped.

    Integer indices (Python or NumPy) wrap around modulo len(self); sample ids (str) also work
    as keys. Iterating visits every sample exactly once.
    """

    def __init__(self, root, split="val"):
        self.root = root
        self.split = split
        self.file_list = self.get_file_list(f"{root}/ImageSets/Segmentation/{split}.txt")
        self.label_remap = voc_label_map

    def get_file_list(self, path):
        """Read the sample ids listed in `path`, keeping only those whose cropped image exists."""
        with open(path, "r") as fo:
            candidates = [line.strip() for line in fo]
        files = [x for x in candidates if os.path.exists(f"{self.root}/CroppedImages/cropped-{x}.jpg")]
        print(f'Found {len(files)} files.')
        return files

    @staticmethod
    def load_image(path, get_arr=True):
        """Open an image with PIL; return it as a numpy array unless get_arr is False."""
        image = Image.open(path)
        return np.array(image) if get_arr else image

    def __len__(self):
        return len(self.file_list)

    def __iter__(self):
        #Without __iter__, Python falls back to calling __getitem__(0), (1), ... until IndexError.
        #Integer indices wrap around, so that never happens and iteration would never end.
        for cur_idx in self.file_list:
            yield self[cur_idx]

    def __getitem__(self, item, get_arr=True):
        """
        Load one sample as an EasyDict with img, seg (cropped to match img), seg_labels
        (the class ids present, excluding 0 and 255), the file paths, and the names and
        prompts of the present classes.
        """
        if isinstance(item, (int, np.integer)):
            cur_idx = self.file_list[int(item) % len(self.file_list)]
        elif isinstance(item, str):
            cur_idx = item
        else:
            raise NotImplementedError(f"Invalid item dtype: {type(item)}")
        img_path = f"{self.root}/CroppedImages/cropped-{cur_idx}.jpg"
        seg_path = f"{self.root}/SegmentationClass/{cur_idx}.png"
        cropped_path = f"{self.root}/CroppedImages/cropped-{cur_idx}.jpg.txt"

        img = self.load_image(img_path, get_arr=get_arr)
        seg = self.load_image(seg_path, get_arr=get_arr)

        with open(cropped_path, 'r') as f:
            info = f.readline().split()
        x1, x2, y1, y2 = (int(v) for v in info[:4])

        if get_arr:
            seg = seg[y1:y2, x1:x2]
            assert seg.shape == img.shape[:2]
            seg_array = seg
        else:
            #PIL images can't be sliced; crop with a (left, upper, right, lower) box instead
            seg = seg.crop((x1, y1, x2, y2))
            assert seg.size == img.size
            seg_array = np.asarray(seg)

        #Ignore 0 (background) and 255 (presumably VOC's void/boundary marker)
        seg_labels = set(np.unique(seg_array)) - {0, 255}

        return EasyDict(
            img=img,
            seg=seg,
            seg_labels=seg_labels,
            cur_idx=cur_idx,
            img_path=img_path,
            seg_path=seg_path, 
            cropped_path=cropped_path,
            names=[voc_label_map[x] for x in seg_labels],
            prompts=[voc_prompts[x] for x in seg_labels],
        )


if __name__ == '__main__':
    data_dir = "/raid/datasets/pascal_voc/VOC2012"
    ds = CroppedPascalDataset(data_dir, split='seg0')

    from collections import defaultdict

    #Count how many of the first (up to) 50 samples contain each class id.
    #Capped at len(ds): integer indices wrap around, so going past the end would count samples twice.
    l_set = defaultdict(int)
    for i in range(min(50, len(ds))):
        for j in ds[i].seg_labels:
            l_set[j] += 1



In [None]:
def IOU(a,b):
    """
    Soft intersection-over-union between two masks of the same size.

    Both inputs go through rp.as_float_image and rp.as_grayscale_image, so boolean masks
    presumably become 0/1 and soft masks keep their values. intersection = a*b and
    union = a+b-a*b (a probabilistic OR). The result is sum(intersection)/sum(union),
    which lies in [0, 1] for inputs in [0, 1].

    If both masks are entirely zero the union is empty. The masks then agree perfectly, so
    1.0 is returned instead of computing 0/0 (which would give NaN).
    """
    assert rp.is_image(a)
    assert rp.is_image(b)
    a=rp.as_grayscale_image(rp.as_float_image(a))
    b=rp.as_grayscale_image(rp.as_float_image(b))
    assert a.shape==b.shape, 'IOU needs masks of the same size, got %s and %s'%(a.shape,b.shape)
    intersection=a*b
    union=a+b-intersection
    union_total=union.sum()
    if union_total==0:
        return 1.0
    return intersection.sum()/union_total

In [None]:
#Hyperparameters, set as global variables and read by the functions above
GRAVITY=1e-1/2 #Weight of the penalty on total alpha; 1e-2, 1e-1 and 1e-1*1.5 were also tried
BATCH_SIZE=1 #Random-background composites per training iteration
NUM_ITER=300 #Training iterations (use 3 for a quick smoke test)
GUIDANCE_SCALE=100 #100 is default
INIT_ITERS=0 #DINO pretraining iterations; 50 was also tried
LEARNING_RATE=1e-5
bilateral_kwargs=dict(
    kernel_size=3,
    tolerance=.08, #.1 was also tried
    sigma=5,
    iterations=40,
)

In [None]:
#Main experiment cell: run Peekaboo once with a hard-coded prompt and image, compute IOU
#metrics against a dataset mask, and save everything to peekaboo_results_kr/.
#NOTE(review): the dataset sample `x` is still drawn and its seg mask is used for the IOU
#metrics, but the prompt and the image are overridden below. The IOU numbers are therefore
#only meaningful when the image actually comes from `x`.
I=0 #Run counter; picks which entry of the `prompt` list is used
for x in ds:
    import random
    random.seed(31) #Re-seeded on every pass; presumably to make the sample pick below reproducible (assumes rp.random_index draws from `random` -- confirm)
    x=ds[rp.random_index(len(ds))] #Replaces the iterated sample with a randomly chosen one
    
    for name,prompt in zip(x.names,x.prompts):
        #Each assignment below overrides the previous one, so only the last is used;
        #the earlier ones are a history of prompts that were tried
        name=prompt='an old red mercedes benz car, vintage car'
        name=prompt='harry styles'
        name=prompt='danny devito'
        name=prompt='Arnold Schwarzenegger'
        name=prompt='The Harry Potter Boy'
        name=prompt='Harry Potter'
        name=prompt='minecraft steve'
        name=prompt='wall-e'
        name=prompt='white eve robot'
        name=prompt='c3po'
        name=prompt='rubiks cube'
        
        name=prompt=['yellow rubber duck','blue rubber duck','orange rubber duck','green rubber duck','pink rubber duck'];
        name=prompt=['red rubber duck',];
        name=prompt=['a can of mtn dew', 'a can of sprite', 'a can of dr pepper', 'dr pepper', 'sprite', 'mtn dew', 'orange', 'yellow soda', 'orange soda', 'sunkist orange', 'sunkist yellow']
        name=prompt=['strawberry jam jar']#,'skippy peanut butter']
        name=prompt=['makise kurisu anime girl']#,'skippy peanut butter']
        name=prompt=['Rintarou Okabe anime boy']#,'skippy peanut butter']
        name=prompt=['Beverly Crusher']#,'skippy peanut butter']
        name=prompt=['Jean Luc Picard']#,'skippy peanut butter']
        name=prompt=["Gyarados"]
        name=prompt=["A cat"]
        name=prompt=["Strawberry ice cream cone"]
        name=prompt=prompt[I%len(prompt)] #Pick one prompt from the list, cycling with the run counter
        I+=1
        
        
        output_folder='peekaboo_results_kr/%i.%s.%s'%(rp.millis(),x.cur_idx,name)
        print('NEW TEST!',output_folder)
        icecream.ic(output_folder,name,prompt,x.cur_idx,x.img_path,x.seg_path,x.cropped_path,x.names,x.prompts)
        label=SimpleLabel(prompt) #NOTE(review): unused; run_peekaboo below builds its own SimpleLabel(prompt)
        # results=run_peekaboo(label,x.img)
        
        #Test images that were tried; the last uncommented load_image wins
        # img=rp.load_image("https://www.padoniavets.com/sites/default/files/field/image/cats-and-dogs.jpg",use_cache=True)
        # img=rp.load_image("https://img.etimg.com/thumb/msid-79778298,width-650,imgsize-1290800,,resizemode-4,quality-100/sales-of-classic-cars-have-remained-positively-stable-in-2020-.jpg",use_cache=True)
        # img=rp.load_image("https://s2.r29static.com/bin/entry/bc7/0,46,460,460/1200x1200,80/1333127/image.jpg",use_cache=True)
        # img=rp.load_image("https://s.yimg.com/ny/api/res/1.2/HIg4XSC.pN9vPMx_MHTv5w--/YXBwaWQ9aGlnaGxhbmRlcjt3PTY0MA--/https://s.yimg.com/os/creatr-uploaded-images/2021-10/69705cc0-2ac8-11ec-973b-482c29ca5c68",use_cache=True)
        # img=rp.load_image("https://assets.teenvogue.com/photos/569e7d2a74da98670ff0ce1c/1:1/w_2159,h_2159,c_limit/MCDHAPO_EC797_H.JPG",use_cache=True)
        # img=rp.load_image("https://www.nme.com/wp-content/uploads/2021/08/pokemon_brilliant_diamond_and_shining_pearl_starters_turtwig_chimchar_piplup.jpeg",use_cache=True)
        # img=rp.load_image("https://images.genius.com/ce44822bee903c5e777c004c8bcaa1ef.300x300x1.jpg",use_cache=True)
        # img=rp.load_image("https://w0.peakpx.com/wallpaper/936/113/HD-wallpaper-wall-e-and-eve-and-e-wall-eve.jpg",use_cache=True)
        # img=rp.load_image("https://www.irishtimes.com/resizer/q5bI2nij6BAIDT8AZeKKpyvm8T4=/1600x1200/filters:format(jpg):quality(70)/cloudfront-eu-central-1.images.arcpublishing.com/irishtimes/UDNDFXTEMKENA5YILPWQMHCKXU.jpg",use_cache=True)
        # img=rp.load_image("ducks.png",use_cache=True)
        # img=rp.load_image("five_sodas.png",use_cache=True)
        # img=rp.load_image("jam_jelly_beans.png",use_cache=True)
        # img=rp.load_image("https://64.media.tumblr.com/909bb24eae0b533481627a2963528da7/5fa507f8401885d2-f2/s540x810/48b07eb60b44467ddfd4892b616ec8d6812ddea4.jpg",use_cache=True)
        # img=rp.load_image("https://www.slashfilm.com/img/gallery/star-trek-picard-season-3-features-the-best-material-yet-for-dr-crusher-according-to-gates-mcfadden-comic-con/crushers-role-in-star-trek-1658612419.jpg",use_cache=True)
        # img=rp.load_image("https://akns-images.eonline.com/eol_images/Entire_Site/2016917/rs_600x600-161017135949-600.finding-dory-2.101716.jpg")
        # img=rp.load_image("https://m.media-amazon.com/images/I/519l1cA-sTL.jpg")
        img=rp.load_image("https://www.shutterstock.com/image-photo/dog-cat-under-plaid-pet-260nw-726710023.jpg") #NOTE(review): downloaded, then immediately overridden by the next line
        img=rp.load_image("ice-cream-1.jpg") #Local file fetched by the `!wget` cell
        # img=rp.load_image("https://ima/ges-wixmp-ed30a86b8c4ca887773594c2.wixmp.com/f/bbfdfeed-2118-4cd3-873b-279837423fa9/dd8x8v4-62e42076-34e3-42ce-94c5-01547f4a7239.png?token=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ1cm46YXBwOjdlMGQxODg5ODIyNjQzNzNhNWYwZDQxNWVhMGQyNmUwIiwiaXNzIjoidXJuOmFwcDo3ZTBkMTg4OTgyMjY0MzczYTVmMGQ0MTVlYTBkMjZlMCIsIm9iaiI6W1t7InBhdGgiOiJcL2ZcL2JiZmRmZWVkLTIxMTgtNGNkMy04NzNiLTI3OTgzNzQyM2ZhOVwvZGQ4eDh2NC02MmU0MjA3Ni0zNGUzLTQyY2UtOTRjNS0wMTU0N2Y0YTcyMzkucG5nIn1dXSwiYXVkIjpbInVybjpzZXJ2aWNlOmZpbGUuZG93bmxvYWQiXX0.I5vSY99f6m8YOAdtCj6sk6ZbLtDIe-IIWwfhNswI678",use_cache=True)
        # img=rp.crop_image(img,width=rp.get_image_height(img),origin='center')
        
        x.prompts[0]=prompt #Record the prompt actually used (this mutates the sample's prompt list)
        
        
        results=run_peekaboo(SimpleLabel(prompt),img)
        
        
        #Add extra data to results
        results.x_prompt=prompt
        results.x_name=name
        # 
        results.x_seg = x.seg
        results.x_seg_labels = x.seg_labels
        results.x_cur_idx = x.cur_idx
        results.x_img_path = x.img_path
        results.x_seg_path = x.seg_path
        results.x_cropped_path = x.cropped_path
        results.x_names = x.names
        results.x_prompts = x.prompts
        
        _alpha=results.alphas[0] #Alpha mask of the single label
        _seg=x.seg!=0 #Ground-truth foreground = every nonzero class id. NOTE(review): this also counts 255, which the dataset excludes from seg_labels -- confirm that is intended
        _seg=rp.cv_resize_image(_seg,(results.height,results.width))
        #Soft IOU on the raw alpha, then IOU after thresholding the alpha at several levels
        results.IOU_Continuous=IOU(_alpha, _seg)
        results['IOU_>.9']=IOU(_alpha>.9, _seg)
        results['IOU_>.8']=IOU(_alpha>.8, _seg)
        results['IOU_>.7']=IOU(_alpha>.7, _seg)
        results['IOU_>.6']=IOU(_alpha>.6, _seg)
        results['IOU_>.5']=IOU(_alpha>.5, _seg)
        results['IOU_>.4']=IOU(_alpha>.4, _seg)
        results['IOU_>.3']=IOU(_alpha>.3, _seg)
        results['IOU_>.2']=IOU(_alpha>.2, _seg)
        results['IOU_>.1']=IOU(_alpha>.1, _seg)
        
        
                    
        save_peekaboo_results(results,output_folder)
        
        from IPython.display import clear_output
        
        1/0 #Deliberately raises ZeroDivisionError to stop after the first run (presumably so its output stays visible); the clear_output() below is never reached
        clear_output()

In [None]:
#Download the image that the experiment cell loads as "ice-cream-1.jpg" (IPython shell escape, not plain Python)
!wget https://becs-table.com.au/wp-content/uploads/2014/01/ice-cream-1.jpg

In [None]:
#Pick a random dataset sample for interactive inspection
x=rp.random_element(ds)

In [None]:
#Alternative hyperparameter configuration; running this cell overwrites the globals read by the functions above
GRAVITY=1e-2 #Weight of the penalty on total alpha; 1e-1 was also tried
BATCH_SIZE=2 #Random-background composites per training iteration
NUM_ITER=500 #Training iterations; 300 was also tried (use 3 for a quick smoke test)
GUIDANCE_SCALE=50 #100 is default
INIT_ITERS=0 #DINO pretraining iterations; 50 and 20 were also tried
LEARNING_RATE=1e-5
bilateral_kwargs=dict(
    kernel_size=3,
    tolerance=.1, #.08 was also tried
    sigma=5,
    iterations=40,
)