<a href="https://colab.research.google.com/github/MohammadNouman17/Digit-Recognition/blob/main/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo for paper "First Order Motion Model for Image Animation"
To try the demo, press the two play buttons in order, then scroll to the bottom. Note that loading may take several minutes.

In [None]:
import IPython.display
import PIL.Image
import cv2
import ffmpeg
import imageio
import io
import ipywidgets
import numpy
import os.path
import requests
import skimage.transform
import warnings
from base64 import b64encode
from demo import load_checkpoints, make_animation  # type: ignore (local file)
from google.colab import files, output
from IPython.display import HTML, Javascript
from shutil import copyfileobj
from skimage import img_as_ubyte
from tempfile import NamedTemporaryFile
from tqdm.auto import tqdm

# Silence every warning category process-wide for the rest of the session.
warnings.filterwarnings("ignore")

# Create the working directory "user" if it is not already there.
os.makedirs("user", exist_ok=True)

# Stylesheet for all widget classes used below (.resource, .loader, ...).
_UI_STYLESHEET = """
<style>
.widget-box > * {
    flex-shrink: 0;
}
.widget-tab {
    min-width: 0;
    flex: 1 1 auto;
}
.widget-tab .p-TabBar-tabLabel {
    font-size: 15px;
}
.widget-upload {
    background-color: tan;
}
.widget-button {
    font-size: 18px;
    width: 160px;
    height: 34px;
    line-height: 34px;
}
.widget-dropdown {
    width: 250px;
}
.widget-checkbox {
    width: 650px;
}
.widget-checkbox + .widget-checkbox {
    margin-top: -6px;
}
.input-widget .output_html {
    text-align: center;
    width: 266px;
    height: 266px;
    line-height: 266px;
    color: lightgray;
    font-size: 72px;
}
.title {
    font-size: 20px;
    font-weight: bold;
    margin: 12px 0 6px 0;
}
.warning {
    display: none;
    color: red;
    margin-left: 10px;
}
.warn {
    display: initial;
}
.resource {
    cursor: pointer;
    border: 1px solid gray;
    margin: 5px;
    width: 160px;
    height: 160px;
    min-width: 160px;
    min-height: 160px;
    max-width: 160px;
    max-height: 160px;
    -webkit-box-sizing: initial;
    box-sizing: initial;
}
.resource:hover {
    border: 6px solid crimson;
    margin: 0;
}
.selected {
    border: 6px solid seagreen;
    margin: 0;
}
.input-widget {
    width: 266px;
    height: 266px;
    border: 1px solid gray;
}
.input-button {
    width: 268px;
    font-size: 15px;
    margin: 2px 0 0;
}
.output-widget {
    width: 256px;
    height: 256px;
    border: 1px solid gray;
}
.output-button {
    width: 258px;
    font-size: 15px;
    margin: 2px 0 0;
}
.uploaded {
    width: 256px;
    height: 256px;
    border: 6px solid seagreen;
    margin: 0;
}
.label-or {
    align-self: center;
    font-size: 20px;
    margin: 16px;
}
.loading {
    align-items: center;
    width: fit-content;
}
.loader {
    margin: 32px 0 16px 0;
    width: 48px;
    height: 48px;
    min-width: 48px;
    min-height: 48px;
    max-width: 48px;
    max-height: 48px;
    border: 4px solid whitesmoke;
    border-top-color: gray;
    border-radius: 50%;
    animation: spin 1.8s linear infinite;
}
.loading-label {
    color: gray;
}
.video {
    margin: 0;
}
.comparison-widget {
    width: 256px;
    height: 256px;
    border: 1px solid gray;
    margin-left: 2px;
}
.comparison-label {
    color: gray;
    font-size: 14px;
    text-align: center;
    position: relative;
    bottom: 3px;
}
@keyframes spin {
    from { transform: rotate(0deg); }
    to { transform: rotate(360deg); }
}
</style>
"""

# Inject the stylesheet into the notebook output.
display(HTML(_UI_STYLESHEET))

# Function to get the thumbnail from a video file
def thumbnail(file):
    return imageio.get_reader(file, mode='I', format='FFMPEG').get_next_data()

def create_image(i, j):
    """Return an Image widget showing ``demo/images/{i}{j}.png``.

    The widget is tagged with the shared ``resource`` CSS class and a
    per-item ``resource-image{i}{j}`` class.
    """
    suffix = f'{i}{j}'
    widget = ipywidgets.Image.from_file(f'demo/images/{suffix}.png')
    for css_class in ('resource', f'resource-image{suffix}'):
        widget.add_class(css_class)
    return widget

def create_video(i):
    """Return an Image widget showing the first frame of ``demo/videos/{i}.mp4``.

    The frame from :func:`thumbnail` is converted RGB -> BGR (OpenCV's channel
    order) and PNG-encoded for the widget. The widget is tagged with the shared
    ``resource`` class and a per-item ``resource-video{i}`` class.

    Raises:
        ValueError: if OpenCV fails to PNG-encode the frame.
    """
    path = f'demo/videos/{i}.mp4'
    frame_bgr = cv2.cvtColor(thumbnail(path), cv2.COLOR_RGB2BGR)
    ok, png = cv2.imencode('.png', frame_bgr)
    # imencode reports failure via its flag; don't silently show an empty image.
    if not ok:
        raise ValueError(f'could not PNG-encode the first frame of {path}')
    # tobytes() replaces the deprecated ndarray.tostring() (removed in recent NumPy).
    video_widget = ipywidgets.Image(value=png.tobytes(), format='png')
    video_widget.add_class('resource')
    video_widget.add_class(f'resource-video{i}')
    return video_widget

def create_title(title):
    """Return a Label showing *title*, styled by the ``title`` CSS class (bold, 20px)."""
    heading = ipywidgets.Label(value=title)
    heading.add_class('title')
    return heading

# Functions for handling the output (download, convert, and back)
def download_output(button):
    """Click handler for Download: send ``output.mp4`` to the browser via Colab.

    While the download is prepared, the result panel (``complete``) is swapped
    for the ``loading`` panel. The ``finally`` swaps it back even if the
    download raises (e.g. ``output.mp4`` is missing), so the UI is never
    left stuck on the spinner.
    """
    complete.layout.display = 'none'
    loading.layout.display = ''
    try:
        files.download('output.mp4')
    finally:
        loading.layout.display = 'none'
        complete.layout.display = ''

def convert_output(button):
    """Click handler for Convert: re-encode ``output.mp4`` to 1920x1080 and download it.

    The ffmpeg filter chain scales the video to 1080x1080 (Lanczos), then pads
    it to 1920x1080 with the picture offset 420 px from the left, i.e.
    horizontally centred ((1920 - 1080) / 2 = 420). The result is written to
    ``scaled.mp4``, overwriting any previous file.

    The loading panel is shown during the work, and the ``finally`` restores
    the result panel even if ffmpeg or the download fails.
    """
    complete.layout.display = 'none'
    loading.layout.display = ''
    try:
        ffmpeg.input('output.mp4').output('scaled.mp4', vf='scale=1080x1080:flags=lanczos,pad=1920:1080:420:0').overwrite_output().run()
        files.download('scaled.mp4')
    finally:
        loading.layout.display = 'none'
        complete.layout.display = ''

def back_to_main(button):
    """Click handler for Back: hide the result panel, then reveal the selection panel."""
    for panel, display_value in ((complete, 'none'), (main, '')):
        panel.layout.display = display_value

# UI setup for selecting images and videos
label_or = ipywidgets.Label('or')
label_or.add_class('label-or')

# Gallery tab i shows demo/images/{i}{j}.png for j in range(image_lengths[i]).
image_titles = ['Peoples', 'Cartoons', 'Dolls', 'Game of Thrones', 'Statues']
image_lengths = [8, 4, 8, 9, 4]

image_tab = ipywidgets.Tab()
image_tab.children = [ipywidgets.HBox([create_image(i, j) for j in range(length)]) for i, length in enumerate(image_lengths)]
for i, title in enumerate(image_titles):
    image_tab.set_title(i, title)

# Preview area plus an upload button for a user-supplied source image.
input_image_widget = ipywidgets.Output()
input_image_widget.add_class('input-widget')
upload_input_image_button = ipywidgets.FileUpload(accept='image/*', button_style='primary')
upload_input_image_button.add_class('input-button')
image_part = ipywidgets.HBox([ipywidgets.VBox([input_image_widget, upload_input_image_button]), label_or, image_tab])

# Single tab holding thumbnails of demo/videos/0.mp4 .. 4.mp4.
video_tab = ipywidgets.Tab()
video_tab.children = [ipywidgets.HBox([create_video(i) for i in range(5)])]
video_tab.set_title(0, 'All Videos')

# Preview area plus an upload button for a user-supplied driving video.
input_video_widget = ipywidgets.Output()
input_video_widget.add_class('input-widget')
upload_input_video_button = ipywidgets.FileUpload(accept='video/*', button_style='primary')
upload_input_video_button.add_class('input-button')
video_part = ipywidgets.HBox([ipywidgets.VBox([input_video_widget, upload_input_video_button]), label_or, video_tab])

# Model dropdown and settings
model = ipywidgets.Dropdown(
    description="Model:",
    options=['vox', 'vox-adv', 'taichi', 'taichi-adv', 'nemo', 'mgif', 'fashion', 'bair']
)
# Hidden by the `.warning` CSS rule until the `warn` class is added.
warning = ipywidgets.HTML('<b>Warning:</b> Upload your own images and videos (see README)')
warning.add_class('warning')
model_part = ipywidgets.HBox([model, warning])

relative = ipywidgets.Checkbox(description="Relative keypoint displacement", value=True)
adapt_movement_scale = ipywidgets.Checkbox(description="Adapt movement scale", value=True)
generate_button = ipywidgets.Button(description="Generate", button_style='primary')

# Main layout for the interface
main = ipywidgets.VBox([create_title('Choose Image'), image_part, create_title('Choose Video'), video_part, create_title('Settings'), model_part, relative, adapt_movement_scale, generate_button])

# Loading panel: spinner, hint text, and an Output area for progress text.
loader = ipywidgets.Label()
loader.add_class("loader")
loading_label = ipywidgets.Label("This may take several minutes to process…")
loading_label.add_class("loading-label")
progress_bar = ipywidgets.Output()
loading = ipywidgets.VBox([loader, loading_label, progress_bar])
loading.add_class('loading')

# Result panel: rendered video plus Download / Convert / Back buttons.
output_widget = ipywidgets.Output()
output_widget.add_class('output-widget')
download = ipywidgets.Button(description='Download', button_style='primary')
download.add_class('output-button')
download.on_click(download_output)

convert = ipywidgets.Button(description='Convert to 1920×1080', button_style='primary')
convert.add_class('output-button')
convert.on_click(convert_output)

# `on_click` is a Button method, not a trait. Passing it as a constructor
# keyword registers no handler (traitlets only warns, and warnings are
# suppressed above), so it must be attached explicitly.
back = ipywidgets.Button(description='Back', button_style='primary')
back.add_class('output-button')
back.on_click(back_to_main)

complete = ipywidgets.VBox([output_widget, download, convert, back])

def _uploaded_content(upload):
    """Return the bytes of the last file held by FileUpload *upload*, or None if empty.

    ipywidgets 7 exposes ``value`` as ``{filename: {'content': bytes, ...}}``;
    ipywidgets 8 exposes a tuple of dicts whose ``'content'`` is a memoryview.
    Both shapes are accepted.
    """
    value = upload.value
    if not value:
        return None
    entries = list(value.values()) if isinstance(value, dict) else list(value)
    return bytes(entries[-1]['content'])


# Function to generate animation
def generate(button):
    """Click handler for Generate: animate the uploaded image with the uploaded video.

    The source image and driving video are read from the two FileUpload
    buttons. ``ipywidgets.Output`` has no ``value``, so reading
    ``input_image_widget.value`` always raised AttributeError. Image and
    frames are resized to 256x256, and the selected model is run. The result
    is written to ``output.mp4`` at the driving video's frame rate, and the
    UI switches to the result panel.

    If either upload is missing, the ``warning`` label is shown and nothing
    else happens. On failure, the selection panel is restored and the
    exception re-raised.
    """
    image_bytes = _uploaded_content(upload_input_image_button)
    video_bytes = _uploaded_content(upload_input_video_button)
    if image_bytes is None or video_bytes is None:
        warning.add_class('warn')  # `.warn` overrides `.warning { display: none }`
        return
    warning.remove_class('warn')

    main.layout.display = 'none'
    complete.layout.display = 'none'
    loading.layout.display = ''
    try:
        # convert('RGB') drops any alpha/palette so the image has 3 channels.
        source_image = skimage.transform.resize(
            numpy.asarray(PIL.Image.open(io.BytesIO(image_bytes)).convert('RGB')),
            (256, 256),
        )

        # The FFMPEG reader is given a path, so persist the upload under user/.
        video_path = os.path.join('user', 'driving.mp4')
        with open(video_path, 'wb') as fh:
            fh.write(video_bytes)
        with imageio.get_reader(video_path, mode='I', format='FFMPEG') as reader:
            fps = reader.get_meta_data().get('fps', 24)
            driving_video = [
                skimage.transform.resize(frame, (256, 256))[..., :3]
                for frame in reader
            ]

        # Capture printed text/progress inside the loading panel's Output area.
        with progress_bar:
            print(f"Loading model: {model.value}")
            # NOTE(review): argument layout kept from the original code; demo.py
            # is not visible here. Confirm it matches the load_checkpoints /
            # make_animation signatures there.
            checkpoints = load_checkpoints(model.value)
            output_video = make_animation(
                source_image,
                driving_video,
                checkpoints,
                relative.value,
                adapt_movement_scale.value,
            )

        # Keep the driving video's frame rate instead of a hard-coded 24 fps.
        imageio.mimsave('output.mp4', [img_as_ubyte(frame) for frame in output_video], fps=fps)
    except Exception:
        # Don't leave the user on the spinner: go back to the selection panel.
        main.layout.display = ''
        raise
    finally:
        progress_bar.clear_output()
        loading.layout.display = 'none'

    output_widget.clear_output(True)
    with output_widget:
        video_widget = ipywidgets.Video.from_file('output.mp4', autoplay=False, loop=False)
        video_widget.add_class('video')
        display(video_widget)

    # '' restores the container's default (flex) display, matching the other handlers.
    complete.layout.display = ''

# Set button action
generate_button.on_click(generate)

# Display all three panels in one container. Only the selection panel starts
# visible; `loading` and `complete` are toggled by the click handlers.
# Previously only `main` was displayed, so the spinner and the result video
# could never appear.
loading.layout.display = 'none'
complete.layout.display = 'none'
display(ipywidgets.VBox([main, loading, complete]))
