In [None]:
# Run this cell to get nice real time updating GPU widget
!pip install jupyterlab-nvdashboard
# ttps://developer.nvidia.com/blog/gpu-dashboards-in-jupyter-lab/

In [1]:
from datetime import datetime
from os import path

In [3]:
import torch

In [4]:
import ipywidgets as widgets

In [5]:
from kandinsky2 import get_kandinsky2

In [6]:
model = get_kandinsky2(
    'cuda', 
    task_type='text2img', 
    cache_dir='/tmp/kandinsky2', 
    model_version='2.1', 
    use_flash_attention=False
)



making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.


In [7]:
def torch_gc():
   with torch.cuda.device("cuda"):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

In [8]:
im_per_iter = widgets.Checkbox(
    value=False,
    description='Display image per iteration',
)
display(im_per_iter)

Checkbox(value=False, description='Display image per iteration')

In [9]:
def formatted_now():
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

In [10]:
save_imgs = widgets.Checkbox(
    value=False,
    description='Save all out images',
)
display(save_imgs)
save_path = widgets.Text(value="", description='Save path:')
display(save_path)
img_name_prefix = widgets.Text(value="formatted_now", description='Name prefix:')
display(img_name_prefix)

Checkbox(value=False, description='Save all out images')

Text(value='', description='Save path:')

Text(value='formatted_now', description='Name prefix:')

In [11]:
prompt = widgets.Textarea(value="", description='Prompt:')
display(prompt)

Textarea(value='', description='Prompt:')

In [12]:
seed = widgets.IntText(
    value=-1,
    description='Seed:',
    disabled=False
)
display(seed)

IntText(value=-1, description='Seed:')

In [13]:
steps = widgets.IntSlider(
    value=30,
    step=1,
    description='Steps:',
)
display(steps)

IntSlider(value=30, description='Steps:')

In [14]:
n_iter = widgets.IntSlider(
    value=6,
    min=1,
    max=20,
    step=1,
    description='Total images:',
)
display(n_iter)

IntSlider(value=6, description='Total images:', max=20, min=1)

In [15]:
batch_size = widgets.IntSlider(
    value=1,
    min=1,
    max=10,
    step=1,
    description='Batch size:',
)
display(batch_size)

IntSlider(value=1, description='Batch size:', max=10, min=1)

In [16]:
cfg_scale = widgets.FloatSlider(
    value=7.5,
    min=1.0,
    max=20.0,
    step=0.25,
    description='Cfg scale:',
    readout_format='.1f',
)
display(cfg_scale)

FloatSlider(value=7.5, description='Cfg scale:', max=20.0, min=1.0, readout_format='.1f', step=0.25)

In [17]:
height = widgets.IntSlider(
    value=768,
    min=128,
    max=4096,
    step=2,
    description='Height:',
)
display(height)

IntSlider(value=768, description='Height:', max=4096, min=128, step=2)

In [18]:
width = widgets.IntSlider(
    value=768,
    min=128,
    max=4096,
    step=2,
    description='Width:',
)
display(width)

IntSlider(value=768, description='Width:', max=4096, min=128, step=2)

In [19]:
sampler = widgets.Dropdown(
    options=['ddim_sampler', 'p_sampler', 'plms_sampler'],
    value='p_sampler',
    description='Sampler:',
    disabled=False,
)
display(sampler)

Dropdown(description='Sampler:', index=1, options=('ddim_sampler', 'p_sampler', 'plms_sampler'), value='p_samp…

In [32]:
def images_processing(images):
    for postfix, img in enumerate(images):
        display(img)
        if save_imgs.value:
            prefix = locals().get(img_name_prefix.value, formatted_now)()
            fname = f"{prefix}_{postfix}.png"
            s_path = path.join(save_path.value, fname)
            print(s_path)
            img.save(s_path)

In [33]:
out = widgets.Output()

@out.capture(clear_output=True)
def txt2img(event):
    if seed.value != -1:
        torch.manual_seed(seed.value)
        torch.cuda.manual_seed_all(seed.value)
    
    images = []
    for cur_iter in range(n_iter.value):
        image_iter = model.generate_text2img(
            prompt.value,
            num_steps=steps.value,
            batch_size=batch_size.value,
            guidance_scale=cfg_scale.value,
            h=height.value,
            w=width.value,
            sampler=sampler.value, 
            prior_cf_scale=4,
            prior_steps="5"
        )
        torch_gc()
        if im_per_iter.value:
            images_processing(image_iter)
        else:
            images.extend(image_iter)


    if not im_per_iter.value and images:
        images_processing(images)

In [34]:
generate = widgets.Button(description="Generate")
display(generate)
generate.on_click(txt2img)
display(out)

Button(description='Generate', style=ButtonStyle())

Output()