In [None]:
import ipywidgets as widgets
from IPython.display import display
from io import BytesIO
from PIL import Image

from visual_concept_blending import VisualConceptBlending

In [None]:
class ImageUploader:
    """Drag-and-drop image picker built on ``ipywidgets.FileUpload``.

    Each uploaded file is decoded with PIL, resized to ``size`` and previewed in
    an ``ipywidgets.Output`` area. Decoded images are collected in
    ``self.uploaded_images``. That list object is never rebound: every new upload
    clears and refills it in place. A reference obtained from
    ``get_uploaded_images()`` *before* uploading therefore reflects later uploads.
    """

    def __init__(self, multiple=True, size=(512, 512)):
        """Create the upload button and preview area and wire them together.

        Args:
            multiple: Allow selecting several files at once (forwarded to FileUpload).
            size: (width, height) passed to ``PIL.Image.resize`` for every upload.
        """
        self.size = tuple(size)
        self.upload_widget = widgets.FileUpload(accept='image/*', multiple=multiple)
        self.output = widgets.Output()
        # Never rebind this list: callers may hold the object returned by
        # get_uploaded_images() and expect it to fill in after they upload.
        self.uploaded_images = []
        self.upload_widget.observe(self.on_upload_change, names='value')

    @staticmethod
    def _iter_uploads(value):
        """Return the per-file records of a ``FileUpload.value`` as a list.

        ipywidgets 8 stores a tuple of dicts; ipywidgets 7 stored a dict mapping
        file name -> record. Both record layouts carry the raw data under ``'content'``.
        """
        if isinstance(value, dict):
            return list(value.values())
        return list(value)

    @staticmethod
    def _upload_name(file_info):
        """Best-effort file name of an upload record (ipywidgets 8 or 7 layout)."""
        return file_info.get('name') or file_info.get('metadata', {}).get('name', '<unnamed>')

    def on_upload_change(self, change):
        """Decode the newly selected files and refresh the preview area.

        Replaces the previous batch. The image list is emptied in place, then each
        decodable file is appended and previewed. Files PIL cannot read are
        reported in the preview area and skipped, rather than aborting the batch.
        """
        self.output.clear_output(wait=True)
        self.uploaded_images.clear()  # in place, so outside references stay live
        for file_info in self._iter_uploads(change['new']):
            try:
                # BytesIO accepts both the bytes (ipywidgets 7) and the
                # memoryview (ipywidgets 8) content payloads.
                image = Image.open(BytesIO(file_info['content'])).resize(self.size)
            except OSError as exc:  # PIL.UnidentifiedImageError is an OSError subclass
                with self.output:
                    print(f"Skipped {self._upload_name(file_info)}: {exc}")
                continue
            self.uploaded_images.append(image)
            with self.output:
                display(image)

    def display(self):
        """Render the upload button followed by the preview area."""
        # Inside a method body `display` resolves to the module-level IPython
        # function, not to this method (class attributes are not in function scope).
        display(self.upload_widget, self.output)

    def get_uploaded_images(self):
        """Return the live list of decoded images (the same object on every call)."""
        return self.uploaded_images

## Load Source Image (use either option 1 or option 2)
### 1. Load by Specifying Path

In [None]:
src_img_path = "<path>"
# Open in a context manager so Pillow releases the file handle promptly
# (Image.open can keep it open, e.g. for multi-frame images). convert() returns
# an in-memory copy, so src_img stays valid after the file is closed.
with Image.open(src_img_path) as src_file:
    src_img = src_file.convert('RGB').resize((512, 512))

### 2. Load by Drag and Drop
Run the next cell and upload one image in the widget it shows. Then run the cell after it.

In [None]:
# Only the first uploaded image becomes the source image, so restrict the
# picker to a single file to avoid an accidental multi-file selection.
image_uploader = ImageUploader(multiple=False)
image_uploader.display()

# `results` is the uploader's live list. It is empty until a file is chosen
# and is refilled in place on every upload, so read it only after uploading.
results = image_uploader.get_uploaded_images()

In [None]:
# Fail with an actionable message instead of a bare IndexError when nothing
# has been uploaded yet.
if not results:
    raise ValueError("No source image uploaded yet: upload one in the widget above, then re-run this cell.")
src_img = results[0].convert('RGB').resize((512, 512))

## Load Reference Images (use either option 1 or option 2)
### 1. Load by Specifying Path

In [None]:
ref_imgs_paths = ["<path1>", "<path2>"] # two or more paths to reference images
if len(ref_imgs_paths) < 2:
    raise ValueError(f"Expected two or more reference image paths, got {len(ref_imgs_paths)}.")

# Open each file in a context manager so its handle is released right away
# (Image.open can keep it open, e.g. for multi-frame images). convert() returns
# an independent in-memory copy that remains valid after the `with` exits.
ref_imgs = []
for p in ref_imgs_paths:
    with Image.open(p) as ref_file:
        ref_imgs.append(ref_file.convert('RGB').resize((512, 512)))

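### 2. Load by Drag and Drop
Run the next cell and upload two or more reference images in the widget it shows. Then run the cell after it.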

In [None]:
# NOTE(review): this rebinds `image_uploader` and `results`, which the
# source-image drag-and-drop cells also use. After this cell runs, `results`
# holds the reference uploads. Re-running the source cell that reads
# `results[0]` would then pick a reference image as the source.
image_uploader = ImageUploader(multiple=True)   # several references may be selected at once

# Grab the live list first; it fills in once files are uploaded in the widget.
results = image_uploader.get_uploaded_images()
image_uploader.display()

In [None]:
# The reference set needs two or more images (as noted for the path-based option).
if len(results) < 2:
    raise ValueError(
        f"Expected two or more reference images, got {len(results)}: "
        "upload them in the widget above, then re-run this cell."
    )
ref_imgs = [img.convert('RGB').resize((512, 512)) for img in results]

## Generate Images

In [None]:
# Generation settings (all forwarded unchanged to VisualConceptBlending / run()).
common = True       # constructor flag VisualConceptBlending(common=...)
theta = 0.3         # run(theta=...)
depth_scale = 0.0   # run(depth_scale=...)
SEED = 168          # fixed seed so the output is reproducible

ip = VisualConceptBlending(common=common)

output_img, depth_map = ip.run(
    src_img,
    ref_imgs,
    seed=SEED,
    theta=theta,
    num_samples=1,
    depth_scale=depth_scale,
)
# run() hands back an indexable collection of samples; keep the single one
# requested above and show it at 512x512.
output_img = output_img[0].resize((512, 512))

output_img