In [1]:
# !git clone https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder ../models/kandinsky-2-2-decoder

In [2]:
# !git clone https://huggingface.co/kandinsky-community/kandinsky-2-2-prior ../models/kandinsky-2-2-prior

In [3]:
# download the motion-module checkpoints (e.g. checkpoint-65000.ckpt) from https://drive.google.com/drive/folders/1GYMJ6ZJMljikSPkbJQNIbORqtdJjHBD0?usp=sharing into ../models/motion-modules/

In [4]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import datetime
import inspect

import cv2
import torch
import numpy as np
from PIL import Image
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
from diffusers import DDIMScheduler, DDPMScheduler, VQModel
from diffusers.pipelines.kandinsky import MultilingualCLIP

from kandimate.models.unet import UNet3DConditionModel
from kandimate.pipelines.pipeline_kandimation import KandimatePipeline
from kandimate.utils.util import save_videos_grid

%load_ext autoreload
%autoreload 2

In [5]:
def show_gif_in_jupyter(gif_path, width=380):
    """Render a GIF inline in the notebook via an HTML ``<img>`` tag.

    Args:
        gif_path: Path or URL of the GIF; it is written into the tag's ``src``
            attribute (after HTML escaping).
        width: Display width, written into the tag's ``width`` attribute.

    Returns:
        An ``IPython.display.HTML`` object; make it the last expression of a
        cell to display it.
    """
    import html
    from IPython import display

    # Escape both values so paths containing quotes, '&' or '<' cannot break
    # out of the attribute and corrupt (or inject into) the generated HTML.
    src = html.escape(str(gif_path), quote=True)
    width_attr = html.escape(str(width), quote=True)
    return display.HTML(f'<img src="{src}" width="{width_attr}">')

In [6]:
# Runtime device and local checkouts of the Kandinsky 2.2 prior/decoder repos.
device = 'cuda'
prior_path, decoder_path = (
    f'../models/kandinsky-2-2-{part}' for part in ('prior', 'decoder')
)

In [7]:
# Text branch, both loaded from the prior repo: the CLIP projection text
# encoder (kept in fp16) and its matching tokenizer.
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    prior_path, subfolder='text_encoder', torch_dtype=torch.float16
)
tokenizer = CLIPTokenizer.from_pretrained(prior_path, subfolder='tokenizer')

In [8]:
# Noise scheduler: the DDPM configuration stored in the decoder repo.
scheduler = DDPMScheduler.from_pretrained(
    pretrained_model_name_or_path=decoder_path,
    subfolder='scheduler',
)
# Disabled alternative - a hand-configured DDIM scheduler:
# scheduler = DDIMScheduler(
#     beta_start=0.00085,
#     beta_end=0.012,
#     beta_schedule="linear",
# )

In [9]:
# VQ autoencoder ("movq") from the decoder repo, loaded directly in fp16.
movq = VQModel.from_pretrained(decoder_path, torch_dtype=torch.float16, subfolder="movq")

In [10]:
# Motion-module (temporal layer) settings, forwarded to
# UNet3DConditionModel.from_pretrained_2d as `unet_additional_kwargs`.
unet_additional_kwargs = dict(
    use_motion_module=True,
    # UNet resolution levels that receive a motion module - presumably
    # downsampling factors; TODO confirm against UNet3DConditionModel.
    motion_module_resolutions=[1, 2, 4, 8],
    motion_module_mid_block=False,
    motion_module_decoder_only=False,
    motion_module_kwargs=dict(
        num_layers=2,
        num_attention_heads=8,
        temporal_position_encoding=True,
        # Presumably the longest clip (in frames) the temporal position
        # encoding supports; TOTAL_FRAMES should not exceed it - TODO confirm.
        temporal_position_encoding_max_len=24,
    ),
)

# Build the 3D (video) UNet from the decoder's 2D UNet weights plus the motion
# modules configured in `unet_additional_kwargs`, then cast everything to fp16.
# The motion-module weights are not part of the 2D checkpoint (hence the
# "missing keys" report); they are loaded separately from a motion checkpoint.
unet = UNet3DConditionModel.from_pretrained_2d(
    decoder_path,
    subfolder="unet",
    unet_additional_kwargs=unet_additional_kwargs,
).to(dtype=torch.float16)

loaded temporal unet's pretrained weights from ../models/kandinsky-2-2-decoder/unet ...
### missing keys: 1512; 
### unexpected keys: 0;
### Motion Module Parameters: 321.823488 M


In [11]:
# Load trained motion-module (temporal) weights into the 3D UNet.
motion_module = '../models/motion-modules/checkpoint-65000.ckpt'
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
motion_module_state_dict = torch.load(motion_module, map_location="cpu")
# Training checkpoints nest the weights under 'state_dict'; also accept a bare state dict.
raw_state_dict = motion_module_state_dict.get('state_dict', motion_module_state_dict)

# Keys presumably carry a leading "module." from a DistributedDataParallel
# wrapper. Strip it only as a prefix: str.replace would also mangle any
# key that merely contains "module." somewhere in the middle.
ddp_prefix = 'module.'
state_dict = {
    (name[len(ddp_prefix):] if name.startswith(ddp_prefix) else name): tensor
    for name, tensor in raw_state_dict.items()
}

missing, unexpected = unet.load_state_dict(state_dict, strict=False)
# Every checkpoint key must land on a UNet parameter. `missing` is expected to
# be non-empty (presumably the non-motion weights already loaded from the 2D UNet).
assert len(unexpected) == 0, f"unexpected keys in motion-module checkpoint: {unexpected[:5]}"
len(missing), len(unexpected)

(724, 0)

In [12]:
# Assemble the Kandimate text-to-video pipeline from the individually loaded
# components, then move it to `device`.
pipeline = KandimatePipeline(
    movq=movq,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
)
pipeline = pipeline.to(device)

In [13]:
# Prior pipeline (fp16): maps a prompt and a negative prompt to the image
# embeddings that `pipeline` consumes as `image_embeds` / `negative_image_embeds`.
pipe_prior = KandinskyV22PriorPipeline.from_pretrained(prior_path, torch_dtype=torch.float16)
pipe_prior = pipe_prior.to(device)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [14]:
# --- Generation settings ---
IMG_H = 768
IMG_W = 768
TOTAL_FRAMES = 16  # presumably must not exceed temporal_position_encoding_max_len - TODO confirm

guidance_scale = 4
num_inference_steps = 25

# Each seed yields one GIF per (negative prompt, prompt) pair plus one grid GIF.
seeds = [10788741199826055526, 16372571278361863751, 6519455744612555650]

prompts = [
    'pretty anime girl looking at the camera, cinematic, 4k, 8k',
    'pretty girl looking at the camera, cinematic, 8k, 4k',
    'pretty redhead girl looking at the camera, cinematic, extremely high detail, 8k, 4k, HQ',
#     'train driving down a mountain railroad, cinematic, 4k, 8k',
#     'iron man is landing, avengers, cinematic, 4k, 8k',
]
# Named negative prompts; the name becomes part of the output directory.
n_prompts = {
#     'no': '',
#     'short': 'low quality, bad quality',
    'long': 'lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers',
}

OUTPUT_DIR = 'samples/generation'
GRID_N_ROWS = 4  # passed to save_videos_grid as n_rows for the per-seed grid


def prompt_to_filename(text, max_words=10):
    """Filename stem for a prompt: drop '/' and join its first `max_words` space-separated words with '-'."""
    return "-".join(text.replace("/", "").split(" ")[:max_words])


for seed in seeds:
    samples = []
    for neg_name, negative_prompt in n_prompts.items():
        savedir = f'{OUTPUT_DIR}/{seed}_{neg_name}'
        # Harmless if save_videos_grid already creates parent directories.
        os.makedirs(savedir, exist_ok=True)
        for prompt in prompts:
            # Re-seed for every prompt so each one starts from the same RNG state.
            # The same generator drives the prior first and then the video pipeline,
            # so the video noise depends on how much randomness the prior consumed.
            generator = torch.Generator(device=device).manual_seed(seed)

            # A negative prompt is given to the prior, so its second output is
            # presumably that prompt's embedding rather than a zero embedding - confirm.
            image_emb, negative_image_emb = pipe_prior(
                prompt=prompt, negative_prompt=negative_prompt, generator=generator,
            ).to_tuple()

            sample = pipeline(
                prompt,
                image_embeds=image_emb.to(dtype=torch.float16),
                negative_image_embeds=negative_image_emb.to(dtype=torch.float16),
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=IMG_W,
                height=IMG_H,
                video_length=TOTAL_FRAMES,
                generator=generator,
                use_progress_bar=True,
            ).videos

            samples.append(sample)
            # Keep the loop variable `prompt` intact; the slug only names the file.
            gif_path = f"{savedir}/{prompt_to_filename(prompt)}.gif"
            save_videos_grid(sample, gif_path)

    # One grid per seed with every sample, in generation order (skip if there are none,
    # since torch.cat rejects an empty list).
    if samples:
        gif_path = f"{OUTPUT_DIR}/{seed}_grid.gif"
        save_videos_grid(torch.cat(samples), gif_path, n_rows=GRID_N_ROWS)

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]