# Imports and Setup

## Imports

In [None]:
#This notebook tests the following two functions:
from source.color_quantizer import quantize_image, quantize_images

In [None]:
import rp
import torch
import icecream
import numpy as np

## Config

In [None]:
%config InlineBackend.figure_format='retina'

#Make the pixels of Jupyter-displayed images 
# use nearest-neigbor interpolation
from IPython.core.display import display, HTML
display(HTML("""
<style>
img {
  image-rendering: auto;
  image-rendering: crisp-edges;
  image-rendering: pixelated;
}
</style>
"""));

In [None]:
# Run on CPU by default: it uses no VRAM, so training can keep using the GPU
# while this notebook runs. Set USE_CUDA = True to use the GPU when available.
USE_CUDA = False
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

icecream.ic(device);

In [None]:
def display_image(image):
    """Show one image inline, first converting it (numpy or torch) to a numpy image via rp."""
    rp.display_image(rp.as_numpy_image(image))

def display_images(images):
    """Convert a batch of images to numpy, tile them into a single grid, and show it inline."""
    numpy_images = rp.as_numpy_images(images)
    grid = rp.tiled_images(numpy_images)
    rp.display_image(grid)

# Target Images and Quantization Demos

In [None]:
# Candidate target images keyed by short name; insertion order is the order
# they are shown in by display_target_images.
# NOTE(review): these are third-party URLs and may go dead; every one is
# downloaded later in this notebook, so a broken link breaks Run All.
target_image_choices = {
    'fox'      : 'https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg',
    'magikarp' : 'https://static.pokemonpets.com/images/monsters-images-300-300/129-Magikarp.webp',
    'makeup'   : 'https://i.redd.it/vxvs3dgsbxw31.png',
    'snowflake': 'https://2s7gjr373w3x22jf92z99mgm5w-wpengine.netdna-ssl.com/wp-content/uploads/2020/11/snowflake_shutterstock_kichigin.jpg',
    'portal'   : 'https://static.wikia.nocookie.net/half-life/images/9/9d/Atlas_P-body_fhp2.jpg/revision/latest?cb=20110519013122&path-prefix=en',
}

def display_target_images():
    """
    Show every entry of target_image_choices as a labeled grid of 128x128 thumbnails,
    titled "Choices". Each image is loaded (downloaded) on every call.
    """
    labels, sources = zip(*target_image_choices.items())
    thumbnails = [rp.cv_resize_image(rp.load_image(source), (128,128)) for source in sources]
    grid = rp.tiled_images(rp.labeled_images(thumbnails, labels))
    rp.display_image(rp.labeled_image(grid, "Choices", size=30))

def load_target_image(target_image:str, target_height:int=None, target_width:int=None):
    """
    Load an image as a float RGB image, optionally center-cropped.

    Args:
        target_image: anything rp.load_image accepts (this notebook's choices are URLs).
        target_height, target_width: crop size. If both are None (the default) no crop
            is performed; if only one is given, the other dimension keeps the image's size.

    Returns:
        A copy of the loaded (and possibly cropped) image.
    """
    image = rp.load_image    (target_image, use_cache=True)
    image = rp.as_float_image(image)
    image = rp.as_rgb_image  (image)
    # The crop size used to come from globals this notebook never defines
    # (NameError on every call); it is now passed explicitly.
    if target_height is not None or target_width is not None:
        # assumes rp images are H x W x C arrays, as elsewhere in this notebook
        height, width = image.shape[:2]
        image = rp.crop_image(
            image,
            height if target_height is None else target_height,
            width  if target_width  is None else target_width,
            origin='center',
        )
    # Copy presumably so the result can't alias anything cached by use_cache=True
    return image.copy()

# Preview every candidate target image (loads each one)
display_target_images()

In [None]:
# Which entry of target_image_choices to quantize below (must be one of its keys)
target_image_choice = 'makeup'

In [None]:
# Load the selected image and normalize it to a float RGB image
target_image_url = target_image_choices[target_image_choice]
target_image = rp.as_float_image(rp.as_rgb_image(rp.load_image(target_image_url)))

rp.display_image(target_image)
icecream.ic(target_image.shape,target_image.max(),target_image.min(),target_image.dtype);

In [None]:
# Convert to a torch image on `device` for quantize_image. Note this rebinds
# target_image: the numpy version from the previous cell is replaced.
target_image = rp.as_torch_image(target_image).to(device)

In [None]:
# Quantization palette: RGB triples with channel values in [0, 1]
colors = [[0,0,0], [1,.5,0], [1,1,1],  [1,0,0], [1,0,.5]]

def solid_color_image(color, height=128, width=128):
    """
    Build a height x width x 3 swatch filled with `color`, captioned with its value.

    color: presumably a length-3 RGB sequence such as [1, .5, 0]; it is broadcast
        over every pixel, and its numpy string form is used as the caption.
    Returns the swatch as labeled by rp.labeled_image.
    """
    rgb = np.asarray(color)
    swatch = np.ones((height,width,3)) * rgb[None,None]
    return rp.labeled_image(swatch, str(rgb), size=20)

# Preview the palette as a row of labeled color swatches
rp.display_image(
    rp.labeled_image(
        rp.horizontally_concatenated_images(
            [solid_color_image(color) for color in colors]
        ),
        "Color Palette",
        size=30
    )
)

# Rebind `colors` from the Python list to a tensor on `device` for the quantizer
colors = torch.tensor(colors).to(device)

In [None]:
# Quantize the chosen target image with the `colors` palette and show the result
display_image(quantize_image(target_image, colors))

In [None]:
# Quantize every candidate image at once, as a single batch
batch = rp.load_images(target_image_choices.values())
# Resize to a common 256x256 so the images can be stacked into one array, and
# normalize each to float RGB (sources may differ in channel count/dtype).
batch = [rp.cv_resize_image(image, (256,256)) for image in batch]
batch = [rp.as_float_image(rp.as_rgb_image(image)) for image in batch]
batch = rp.as_numpy_array (batch)
# Put the batch on the same device as the `colors` palette, as done for target_image
batch = rp.as_torch_images(batch).to(device)

print("Original Images:")
display_images(batch)

quantized_batch = quantize_images(batch, colors)

print("Quantized Images:")
display_images(quantized_batch)

In [None]:
#Correctness Check
gradient_image=np.ones((512,512))*np.linspace(0,1,512)
rp.display_image(gradient_image)

palette=torch.tensor([[0],[.5],[1]])

quantized_gradient_image=gradient_image
quantized_gradient_image=torch.tensor(quantized_gradient_image)[None].float()
quantized_gradient_image=quantize_image(quantized_gradient_image,palette)

rp.display_image(rp.as_numpy_array(quantized_gradient_image)[0])