In [None]:
# Switch the kernel's working directory to the parent folder; later cells use
# paths relative to it (e.g. './finetune/output/...').
# NOTE: not idempotent - every re-run of this cell moves one more level up.
%cd ..

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from copy import deepcopy

import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from omegaconf.dictconfig import DictConfig

from kandinsky2 import CONFIG_2_1, Kandinsky2_1 

### Helper functions

In [3]:
def show_image(image, figsize=(5, 5), cmap=None, title='', xlabel=None, ylabel=None, axis=False):
    """Render one image in a new pyplot figure and show it.

    Args:
        image: Image passed to ``plt.imshow``.
        figsize: Passed to ``plt.figure`` as ``figsize``.
        cmap: Passed to ``plt.imshow`` as ``cmap``.
        title: Passed to ``plt.title``.
        xlabel: Passed to ``plt.xlabel``.
        ylabel: Passed to ``plt.ylabel``.
        axis: Passed to ``plt.axis``.
    """
    plt.figure(figsize=figsize)
    plt.imshow(image, cmap=cmap)

    # Decorations are applied in a fixed order: title, x label, y label, axis.
    decorations = (
        (plt.title, title),
        (plt.xlabel, xlabel),
        (plt.ylabel, ylabel),
        (plt.axis, axis),
    )
    for apply, value in decorations:
        apply(value)

    plt.show()

def show_images(images, n_rows=1, title='', figsize=(5, 5), cmap=None, xlabel=None, ylabel=None, axis=False):
    """Display a sequence of images on a grid with ``n_rows`` rows.

    The number of columns is ``ceil(len(images) / n_rows)``, so every image
    gets a cell even when the count is not a multiple of ``n_rows``. Grid
    cells that receive no image are blanked out.

    Args:
        images: Sequence of images accepted by ``imshow``.
        n_rows: Number of grid rows; must be at least 1.
        title: Title applied to every subplot.
        figsize: Figure size; forwarded to ``plt.subplots`` (or to
            ``show_image`` when there is a single image).
        cmap: Colormap forwarded to ``imshow``.
        xlabel: X label applied to every subplot.
        ylabel: Y label applied to every subplot.
        axis: Forwarded to ``Axes.axis`` for every subplot that holds an image.

    Raises:
        ValueError: If ``images`` is empty or ``n_rows`` is smaller than 1.
    """
    n_images = len(images)
    if n_rows < 1:
        raise ValueError(f"n_rows must be >= 1, got {n_rows}")
    if n_images == 0:
        raise ValueError("show_images() needs at least one image")

    # Ceiling division: a floor here would silently drop the trailing images
    # (and yield zero columns when there are fewer images than rows).
    n_cols = -(-n_images // n_rows)

    if n_rows == n_cols == 1:
        show_image(images[0], title=title, figsize=figsize, cmap=cmap, xlabel=xlabel, ylabel=ylabel, axis=axis)
        return

    # squeeze=False always returns a 2-D array of axes, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for ax, img in zip(axes, images):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.axis(axis)

    # Hide the grid cells that did not receive an image.
    for ax in axes[n_images:]:
        ax.axis('off')

    # Lay out after titles/labels exist so they are taken into account.
    fig.tight_layout(pad=0.0)
    plt.show()
        
def create_model(unet_path, cache_root, task_type, device, use_fp16=True):
    """Build a ``Kandinsky2_1`` model from weights stored under ``<cache_root>/2_1``.

    Args:
        unet_path: Decoder checkpoint to load; ``None`` selects
            ``decoder_fp16.ckpt`` from the weights folder.
        cache_root: Directory that contains the ``2_1`` weights folder.
        task_type: Forwarded to ``Kandinsky2_1`` as ``task_type``.
        device: Forwarded to ``Kandinsky2_1``.
        use_fp16: Value written to ``model_config["use_fp16"]``.

    Returns:
        The constructed ``Kandinsky2_1`` instance.
    """
    weights_dir = os.path.join(cache_root, "2_1")
    text_encoder_dir = os.path.join(weights_dir, "text_encoder")

    # Work on a private copy so the shared CONFIG_2_1 template is not mutated.
    config = DictConfig(deepcopy(CONFIG_2_1))

    model_overrides = {
        "up": False,
        "use_fp16": use_fp16,
        "inpainting": False,
        "cache_text_emb": False,
        "use_flash_attention": False,
    }
    for option, value in model_overrides.items():
        config["model_config"][option] = value

    # The tokenizer and the text encoder are read from the same folder.
    config["tokenizer_name"] = text_encoder_dir
    config["text_enc_params"]["model_path"] = text_encoder_dir
    config["prior"]["clip_mean_std_path"] = os.path.join(weights_dir, "ViT-L-14_stats.th")
    config["image_enc_params"]["ckpt_path"] = os.path.join(weights_dir, "movq_final.ckpt")

    if unet_path is None:
        decoder_ckpt = os.path.join(weights_dir, "decoder_fp16.ckpt")
    else:
        decoder_ckpt = unet_path
    prior_ckpt = os.path.join(weights_dir, "prior_fp16.ckpt")

    return Kandinsky2_1(config, decoder_ckpt, prior_ckpt, device, task_type=task_type)

### Initialize the Kandinsky 2.1 model (check that `cache_dir` contains the `2_1` weights folder; set `unet_path = None` to load the default decoder)

In [4]:
# Fine-tuned decoder checkpoint. create_model falls back to
# <cache_dir>/2_1/decoder_fp16.ckpt when this is None.
unet_path = './finetune/output/sapsan/decoder_fp16.ckpt'

# create_model reads the remaining weights from <cache_dir>/2_1.
cache_dir = '/tmp/kandinsky2'
task_type = 'text2img'
device = 'cuda'

# Arguments follow create_model's positional order:
# (unet_path, cache_root, task_type, device).
model = create_model(unet_path, cache_dir, task_type, device)

making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.


### Set your `class_prompt` and `instance_prompt`

In [5]:
# Both values are interpolated into the generation prompt further down
# ("Photo of a {instance_prompt} {class_prompt}. ...").
# NOTE(review): presumably these must match the prompts used when the decoder
# at `unet_path` was fine-tuned - confirm.
class_prompt = 'train' # Generic object category (e.g. train, bear)
instance_prompt = 'sapsan' # Unique identifier of the specific object (e.g. sapsan, *)

### Generation with p_sampler

In [7]:
prompt = f'''Photo of a {instance_prompt} {class_prompt}. photorealistic'''

# Sampling options reused for every round below.
p_sampler_options = dict(
    num_steps=50,
    batch_size=4,
    guidance_scale=7.5,
    h=768,
    w=768,
    sampler='p_sampler',
    prior_cf_scale=2,
    prior_steps='4',
)

# Run the generation twice; each batch is drawn on a grid with two rows.
for _ in range(2):
    images = model.generate_text2img(prompt, **p_sampler_options)
    show_images(images, n_rows=2, figsize=(15, 15))

### Generation with ddim_sampler

In [None]:
# Sampling options reused for every round below.
# Requires `prompt` from the p_sampler cell above.
ddim_sampler_options = dict(
    num_steps=50,
    batch_size=4,
    guidance_scale=7.5,
    h=768,
    w=768,
    sampler='ddim_sampler',
    prior_cf_scale=4,
    prior_steps='ddim25',
)

# Run the generation twice; each batch is drawn on a grid with two rows.
for _ in range(2):
    images = model.generate_text2img(prompt, **ddim_sampler_options)
    show_images(images, n_rows=2, figsize=(15, 15))