In [None]:
# Query the NVIDIA driver to check that a GPU is reachable before the CUDA models are loaded.
!nvidia-smi

Failed to initialize NVML: Unknown Error


In [1]:
from IPython.display import Video
from deforum_kandinsky import KandinskyV22Img2ImgPipeline, DeforumKandinsky
from diffusers import KandinskyV22PriorPipeline
from transformers import CLIPVisionModelWithProjection
from diffusers.models import UNet2DConditionModel
import imageio.v2 as iio
from PIL import Image
import numpy as np
import torch
from tqdm.notebook import tqdm
import ipywidgets as widgets
from tqdm.notebook import tqdm 
from IPython import display
from ipywidgets import Output



In [2]:
def frames2video(frames, output_path="video.mp4", fps=24, display=False):
    """Encode a sequence of generated frames into a video file.

    Args:
        frames: Iterable of images; each item is converted with ``np.array``
            before being appended to the writer.
        output_path: Destination file handed to ``iio.get_writer``.
        fps: Frame rate handed to ``iio.get_writer``.
        display: When true, render the written video in the notebook output.
    """
    # The context manager closes the writer even when a frame fails to encode,
    # so the output file handle is never leaked.
    with iio.get_writer(output_path, fps=fps) as writer:
        for frame in tqdm(frames):
            writer.append_data(np.array(frame))

    if display:
        # A bare ``Video(...)`` expression inside a function renders nothing,
        # and the ``display`` flag shadows the module-level ``display`` import,
        # so the display function is imported locally under another name.
        from IPython.display import display as show_in_notebook
        show_in_notebook(Video(url=output_path))

## Kandinsky 2.2

In [3]:
device = "cuda"

# Image encoder from the prior checkpoint: load, then cast to fp16 and move to the device.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    'kandinsky-community/kandinsky-2-2-prior',
    subfolder='image_encoder',
)
image_encoder = image_encoder.to(torch.float16).to(device)

# UNet from the decoder checkpoint, handled the same way.
unet = UNet2DConditionModel.from_pretrained(
    'kandinsky-community/kandinsky-2-2-decoder',
    subfolder='unet',
)
unet = unet.to(torch.float16).to(device)

# Prior pipeline, reusing the image encoder loaded above.
prior = KandinskyV22PriorPipeline.from_pretrained(
    'kandinsky-community/kandinsky-2-2-prior',
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
prior = prior.to(device)

# Img2img decoder pipeline, reusing the UNet loaded above.
decoder = KandinskyV22Img2ImgPipeline.from_pretrained(
    'kandinsky-community/kandinsky-2-2-decoder',
    unet=unet,
    torch_dtype=torch.float16,
)
decoder = decoder.to(device)

# Turn off the pipelines' own progress bars.
prior.set_progress_bar_config(disable=True)
decoder.set_progress_bar_config(disable=True)

## Define instance of Deforum

In [4]:
# Define the Deforum instance on top of the prior and img2img decoder pipelines.
# ``device`` is reused instead of repeating the 'cuda' literal so that the
# wrapper follows the pipelines if the target device is changed in one place.
deforum = DeforumKandinsky(
    prior=prior,
    decoder_img2img=decoder,
    device=device
)

## Create Default Animation

In [9]:
def create_animation_widgets():
    """Build the input form for a single animation segment.

    Returns:
        A ``widgets.VBox`` whose children are, in this order: the prompt text
        box, the negative-prompt text box, the duration slider and the
        animation-mode dropdown. Other code in this notebook reads the
        children by position, so the order must not change.
    """
    animation_modes = [
        "right", "left", "up", "down",
        "spin_clockwise", "spin_counterclockwise",
        "zoomin", "zoomout", "live",
    ]

    prompt = widgets.Text(
        description='Prompt:',
        layout=widgets.Layout(width='80%')
    )
    negative_prompt = widgets.Text(
        description='Neg Prompt:',
        # "quility" -> "quality": this default is forwarded to the model as the
        # negative prompt, so the misspelling was a functional defect.
        value="low quality, bad image, cropped, out of frame",
        layout=widgets.Layout(width='80%')
    )
    duration = widgets.FloatSlider(
        description='Duration:',
        min=0.25, max=60,
        value=5, step=0.25,
        layout=widgets.Layout(width='80%')
    )
    # Only one selector is needed. The RadioButtons widget that used to be
    # created here was overwritten by the dropdown before it was ever used.
    animation = widgets.Dropdown(
        options=animation_modes,
        value="right",
        description='Animation Mode:',
        # Let the description take the width it needs instead of being cut off.
        style={'description_width': 'initial'},
    )
    return widgets.VBox(children=(prompt, negative_prompt, duration, animation))

def create_video_settings():
    """Build the "Video Settings" panel.

    Returns:
        A ``widgets.VBox`` whose children are, in order: an HTML header, the
        width box, the height box, the FPS slider and the output-path text box.
    """
    def dimension_box(label):
        # Width and height share the same bounds, step and default value.
        return widgets.BoundedIntText(
            min=64,
            max=1e6,
            step=64,
            value=640,
            description=label,
            disabled=False
        )

    header = widgets.HTML("<h2>Video Settings</h2>")
    width_box = dimension_box('Width:')
    height_box = dimension_box('Height:')
    fps_slider = widgets.IntSlider(description='FPS', min=1, max=48, value=24, step=1)
    path_box = widgets.Text(description='output path:', value="video.mp4")

    return widgets.VBox(children=[header, width_box, height_box, fps_slider, path_box])

def create_animation_tabs():
    """Build the tabbed editor that holds one form per animation segment.

    Returns:
        A ``widgets.VBox`` with three children: the section header, the
        ``widgets.Tab`` container holding the forms (index 1), and the row
        with the "Add Animation" / "Remove Animation" buttons.
    """
    an_widgets = widgets.Tab(layout=widgets.Layout(width='90%', height='100%'))
    an_widgets.children = [create_animation_widgets()]

    def sync_titles(tabs):
        # Titles are assigned per tab index, so they are rewritten from the
        # prompt fields every time the set of tabs changes.
        for index, child in enumerate(tabs.children):
            tabs.set_title(index, child.children[0].value)

    def update(a, an_widgets=an_widgets):
        current = an_widgets.children
        # Append a fresh form when the last one already has a prompt, or when
        # every tab has been removed (indexing [-1] would raise otherwise).
        if not current or current[-1].children[0].value:
            an_widgets.children += (create_animation_widgets(),)
        sync_titles(an_widgets)

    def clear(a, an_widgets=an_widgets):
        children = list(an_widgets.children)
        selected = an_widgets.selected_index
        # Nothing selected (e.g. no tabs left): there is nothing to remove.
        if selected is None or not 0 <= selected < len(children):
            return
        children.pop(selected)
        an_widgets.children = tuple(children)
        # Re-label the remaining tabs so each title matches its own prompt,
        # rather than only blanking the title at index 0.
        sync_titles(an_widgets)

    add_button = widgets.Button(
        description='Add Animation',
        layout=widgets.Layout(width='44.75%')
    )
    add_button.style.button_color = "blue"
    add_button.on_click(update)

    clear_button = widgets.Button(
        description='Remove Animation',
        layout=widgets.Layout(width='44.75%')
    )
    clear_button.style.button_color = "red"
    clear_button.on_click(clear)

    return widgets.VBox([
        widgets.HTML("<h2>Animations</h2>"),
        an_widgets,
        widgets.HBox([add_button, clear_button])
    ])


def create_start_button(animation_tabs, video_widgets, deforum, animation_display):
    """Create the button that reads the UI state and renders the animation.

    Args:
        animation_tabs: Container whose ``children[1]`` holds one form per
            animation; each form's children are read in the order prompt,
            negative prompt, duration, animation mode.
        video_widgets: Container whose children are read as header, width,
            height, fps and (last) output path.
        deforum: Callable that returns an iterable of ``(frame, params)``
            pairs for the given settings; ``len(deforum)`` is used as the
            progress-bar total.
        animation_display: Output widget used as a context manager to show
            the live frame preview.

    Returns:
        The ``widgets.Button`` with the click handler attached.
    """
    def render_deforum(animation, animation_display, output_path, fps=24):
        # Collect every frame while replacing the preview with the newest one.
        frames = []
        for frame, _ in animation:
            frames.append(frame)
            with animation_display:
                display.clear_output(wait=True)
                display.display(frame)

        # Encode at the frame rate chosen in the UI. Previously the writer's
        # default rate was always used, whatever the FPS slider said.
        if output_path and output_path.endswith(".mp4"):
            frames2video(frames, output_path, fps=fps)
        else:
            # Anything that is not an .mp4 path falls back to the default file.
            frames2video(frames, fps=fps)

    def parse_args(_):
        children = animation_tabs.children[1].children
        prompts = []
        negative_prompts = []
        durations = []
        animations = []
        for child in children:
            prompt, negative_prompt, duration, animation = [x.value for x in child.children]
            # Forms with an empty prompt are skipped.
            if prompt:
                prompts.append(prompt)
                negative_prompts.append(negative_prompt)
                # NOTE(review): int() truncates the slider's 0.25-step value
                # (0.25 -> 0, 1.75 -> 1). Confirm whether `deforum` accepts
                # fractional durations before changing this.
                durations.append(int(duration))
                animations.append(animation)

        if not prompts:
            # Do not start the pipeline with empty inputs; tell the user instead.
            with animation_display:
                display.clear_output(wait=True)
                print("Nothing to render: enter at least one prompt.")
            return

        width, height, fps = [int(x.value) for x in video_widgets.children[1:-1]]
        output_path = video_widgets.children[-1].value
        animation = deforum(
            prompts=prompts,
            negative_prompts=negative_prompts,
            animations=animations,
            prompt_durations=durations,
            H=height,
            W=width,
            fps=fps,
            sampler="euler"
        )
        animation = tqdm(animation, total=len(deforum))
        render_deforum(animation, animation_display, output_path, fps=fps)

    button = widgets.Button(
        description='Start Rendering!',
        layout=widgets.Layout(width='90%')
    )
    button.on_click(parse_args)
    return button

In [10]:
# Build the UI pieces: video settings, per-animation tabs and the preview area.
video_widgets = create_video_settings()
animation_tabs = create_animation_tabs()
animation_display = widgets.Output()

# The start button reads the widgets above when it is clicked.
start_button = create_start_button(
    animation_tabs=animation_tabs,
    video_widgets=video_widgets,
    deforum=deforum,
    animation_display=animation_display,
)

# Show the preview first, followed by the settings, the tabs and the button.
display.display(animation_display, video_widgets, animation_tabs, start_button)

In [None]:
# Define a fresh Deforum instance for the scripted (non-widget) run.
# ``device`` is reused instead of repeating the 'cuda' literal so the wrapper
# stays on the same device as the pipelines it wraps.
deforum = DeforumKandinsky(
    prior=prior,
    decoder_img2img=decoder,
    device=device
)

# One animation mode and one duration per prompt; the three lists have the same length.
animation = deforum(
    prompts=[
        "winter forest, snowflakes, Van Gogh style",
        "spring forest, flowers, sun rays, Van Gogh style",
        "summer forest, lake, reflections on the water, summer sun, Van Gogh style",
        "autumn forest, rain, Van Gogh style",
        "winter forest, snowflakes, Van Gogh style",
    ],
    animations=['live', 'right', 'right', 'right', 'live'],
    prompt_durations=[1, 1, 1, 1, 1],
    H=640,
    W=640,
    fps=24,
    save_samples=False,
    linear_transition=True,
)

# Frames are collected here; the export cell below reads this list.
frames = []

# Output area that is overwritten with the newest frame and its parameters.
out = Output()
pbar = tqdm(animation, total=len(deforum))
display.display(out)
for frame, current_params in pbar:
    frames.append(frame)
    with out:
        display.clear_output(wait=True)
        display.display(frame)
        for key, value in current_params.items():
            print(f"{key}: {value}")

In [None]:
# Write the collected frames to disk, then show the file as the cell result.
# NOTE(review): the animation above was requested with fps=24 but is encoded
# at fps=4 here, so playback is slower than the prompt durations suggest.
# Confirm that this is intended.
frames2video(frames, output_path="output_2_2.mp4", fps=4)
Video("output_2_2.mp4")