In [1]:
from datetime import datetime
from os import path
from io import BytesIO
from PIL import Image
import torch
import ipywidgets as widgets
from kandinsky2 import get_kandinsky2

In [6]:
# Build the Kandinsky model once for the whole notebook. The first positional
# argument selects the device; the remaining options are kept together so they
# are easy to tweak in one place.
_model_options = {
    'task_type': 'text2img',
    'cache_dir': '/tmp/kandinsky2',
    'model_version': '2.1',
    'use_flash_attention': False,
}
model = get_kandinsky2('cuda', **_model_options)

def torch_gc():
    """Ask PyTorch to release cached CUDA memory.

    Runs ``torch.cuda.empty_cache()`` followed by ``torch.cuda.ipc_collect()``
    inside a ``torch.cuda.device("cuda")`` context. Returns ``None``.
    """
    cuda_scope = torch.cuda.device("cuda")
    with cuda_scope:
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()



making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.


In [2]:
# Checkbox read by the generation callback: when ticked, each iteration's
# images are shown as soon as they are produced instead of all at the end.
im_per_iter = widgets.Checkbox(description='Display image per iteration', value=False)
display(im_per_iter)

def formatted_now():
    """Return the current local time as a ``YYYY-MM-DD_HH-MM-SS`` string."""
    moment = datetime.now()
    return f"{moment:%Y-%m-%d_%H-%M-%S}"

# Saving controls: on/off switch, destination directory and file-name prefix.
save_imgs = widgets.Checkbox(description='Save all out images', value=False)
save_path = widgets.Text(description='Save path:', value="")
img_name_prefix = widgets.Text(description='Name prefix:', value="formatted_now")

# Two identical upload buttons, one per source image.
# `accept` filters the file picker (e.g. '.txt', '.pdf', 'image/*', 'image/*,.pdf');
# `multiple=False` restricts each button to a single file.
_upload_options = {'accept': 'image/*', 'multiple': False}
img1 = widgets.FileUpload(**_upload_options)
img2 = widgets.FileUpload(**_upload_options)

# Render the controls in the same order as they were declared.
for _control in (save_imgs, save_path, img_name_prefix, img1, img2):
    display(_control)

# Raw buffers and decoded PIL images for the two uploads. `imgmix` reads
# `pil1` / `pil2` when the Generate button is clicked. They start as None
# because the FileUpload widgets hold no files while this cell is executing
# (indexing `value[0]` at that point raises IndexError); the observers below
# fill them in as soon as a file is uploaded.
raw1 = raw2 = None
pil1 = pil2 = None

preview1 = widgets.Image(format='png', width=300, height=400)
preview2 = widgets.Image(format='png', width=300, height=400)
display(widgets.Box([preview1, preview2]))


def _first_upload_bytes(upload):
    """Return the bytes of the first file held by ``upload``, or ``None`` if it is empty."""
    if not upload.value:
        return None
    content = upload.value[0]["content"]
    # The upload content is memoryview-like (it exposes `.tobytes()`); both the
    # Image widget and BytesIO are given a plain bytes object.
    return content.tobytes() if hasattr(content, "tobytes") else bytes(content)


def _sync_img1(change=None):
    """Refresh ``preview1``, ``raw1`` and ``pil1`` from the first upload widget."""
    global raw1, pil1
    data = _first_upload_bytes(img1)
    if data is None:
        raw1 = pil1 = None
        return
    preview1.value = data
    raw1 = BytesIO(data)
    pil1 = Image.open(raw1)


def _sync_img2(change=None):
    """Refresh ``preview2``, ``raw2`` and ``pil2`` from the second upload widget."""
    global raw2, pil2
    # Reads img2 — the previous code built the second image from img1 again.
    data = _first_upload_bytes(img2)
    if data is None:
        raw2 = pil2 = None
        return
    preview2.value = data
    raw2 = BytesIO(data)
    pil2 = Image.open(raw2)


# Re-decode whenever the user uploads (or replaces) a file.
img1.observe(_sync_img1, names='value')
img2.observe(_sync_img2, names='value')

# Pick up files that are already present (no-op on a freshly created widget).
_sync_img1()
_sync_img2()

# Mixing weights for the two source images; both sliders share the same range
# and default, so the common options live in one place. The values are passed
# to `model.mix_images` as the weights list when Generate is clicked.
_weight_slider_options = dict(
    value=0.5,
    min=0.0,
    max=2.0,
    step=0.1,
    readout_format='.1f',
)

# Labels previously read "weigth" (typo).
img_w1 = widgets.FloatSlider(description='IMG1 weight:', **_weight_slider_options)
display(img_w1)
img_w2 = widgets.FloatSlider(description='IMG2 weight:', **_weight_slider_options)
display(img_w2)

# Generation parameters. Every widget is created first and then rendered in
# declaration order by the loop at the bottom.

# Random seed; the generation callback treats -1 as "do not seed".
seed = widgets.IntText(description='Seed:', value=-1, disabled=False)

# Number of sampling steps handed to the model as `num_steps`.
steps = widgets.IntSlider(description='Steps:', value=30, step=1)

# How many times the generation loop runs.
n_iter = widgets.IntSlider(description='Total images:', value=6, min=1, max=20, step=1)

# Value forwarded to the model as `guidance_scale`.
cfg_scale = widgets.FloatSlider(
    description='Cfg scale:', value=7.5, min=1.0, max=20.0, step=0.25, readout_format='.1f'
)

# Output size, forwarded to the model as `h` and `w`.
height = widgets.IntSlider(description='Height:', value=768, min=128, max=4096, step=2)
width = widgets.IntSlider(description='Width:', value=768, min=128, max=4096, step=2)

# Sampler name forwarded to the model.
sampler = widgets.Dropdown(
    description='Sampler:',
    options=['ddim_sampler', 'p_sampler', 'plms_sampler'],
    value='p_sampler',
    disabled=False,
)

for _control in (seed, steps, n_iter, cfg_scale, height, width, sampler):
    display(_control)

def images_processing(images):
    """Display every image and, when saving is enabled, write each one to disk.

    Args:
        images: iterable of objects that can be passed to ``display`` and that
            expose ``save(path)`` (presumably PIL images returned by the model
            — TODO confirm).

    File names are ``<prefix>_<index>.png`` inside the directory typed into the
    "Save path" box (current directory when the box is empty). The prefix is:

    * the current timestamp when the "Name prefix" box is empty or still holds
      its default text ``formatted_now``;
    * ``<typed text>_<timestamp>`` otherwise, so a custom prefix is honoured
      while successive calls still cannot overwrite each other's files.

    Previously the box content was looked up in ``locals()`` and called: any
    ordinary text silently fell back to the timestamp, and text that matched a
    local variable name raised because that variable is not callable.
    """
    import os  # local import: only `os.path` is imported at the top of the notebook

    for postfix, img in enumerate(images):
        display(img)
        if not save_imgs.value:
            continue

        requested = img_name_prefix.value.strip()
        stamp = formatted_now()
        if requested in ("", "formatted_now"):
            prefix = stamp
        else:
            prefix = f"{requested}_{stamp}"

        target_dir = save_path.value
        if target_dir:
            # Saving into a directory that does not exist yet would fail.
            os.makedirs(target_dir, exist_ok=True)

        fname = f"{prefix}_{postfix}.png"
        s_path = path.join(target_dir, fname)
        print(s_path)
        img.save(s_path)

def center_crop(image):
    """Crop ``image`` to the largest centred square.

    Reads ``image.size`` as ``(width, height)``, uses the shorter side as the
    square's edge, and returns whatever ``image.crop`` yields for the
    ``(left, top, right, bottom)`` box centred on the image.
    """
    img_w, img_h = image.size
    side = min(img_w, img_h)
    box = (
        (img_w - side) / 2,  # left
        (img_h - side) / 2,  # top
        (img_w + side) / 2,  # right
        (img_h + side) / 2,  # bottom
    )
    return image.crop(box)

# Slider whose value is passed (as a string) to the model's `prior_steps` option.
prior_steps = widgets.IntSlider(description='Prior steps:', value=5, min=1, max=50, step=1)
display(prior_steps)

# Output area that captures everything printed/displayed by the generation callback.
out = widgets.Output()

@out.capture(clear_output=True)
def imgmix(event):
    """Button callback: blend the two source images with the current widget settings.

    Args:
        event: argument supplied by the button's click handler; not used.

    Seeds torch (CPU and all CUDA devices) unless the seed widget holds -1,
    then calls ``model.mix_images`` once per requested iteration, freeing CUDA
    cache after each call. Results are either shown per iteration or collected
    and shown together at the end, depending on the "per iteration" checkbox.
    Output is captured into ``out``, which is cleared on every click.
    """
    chosen_seed = seed.value
    if chosen_seed != -1:
        torch.manual_seed(chosen_seed)
        torch.cuda.manual_seed_all(chosen_seed)

    print([img_w1.value, img_w2.value])

    deferred = []
    for _ in range(n_iter.value):
        sources = [pil1, pil2]
        weights = [img_w1.value, img_w2.value]
        batch = model.mix_images(
            sources,
            weights,
            num_steps=steps.value,
            batch_size=1,  # one image per call; the loop controls the total
            guidance_scale=cfg_scale.value,
            h=height.value,
            w=width.value,
            sampler=sampler.value,
            prior_cf_scale=4,
            prior_steps=str(prior_steps.value),
        )
        torch_gc()

        if not im_per_iter.value:
            # Hold the images back so they are shown together after the loop.
            deferred.extend(batch)
            continue
        images_processing(batch)

    if deferred and not im_per_iter.value:
        images_processing(deferred)

# Wire the Generate button to the mixing callback, then render the button
# followed by the output area that receives the callback's results.
generate = widgets.Button(description="Generate")
generate.on_click(imgmix)
display(generate, out)

Checkbox(value=False, description='Display image per iteration')

Checkbox(value=False, description='Save all out images')

Text(value='', description='Save path:')

Text(value='formatted_now', description='Name prefix:')

FileUpload(value=(), accept='image/*', description='Upload')

FileUpload(value=(), accept='image/*', description='Upload')