# Animation with Gradio UI

In [None]:
#@title Mount Google Drive
import os

# Decide where rendered animations are written. On Colab the user's Google
# Drive is mounted and a dedicated folder is used; anywhere else (or if the
# mount fails) output goes to the current working directory.
try:
    from google.colab import drive
    drive.mount('/content/gdrive')
    outputs_path = "/content/gdrive/MyDrive/AI/StabilityAnimations"
    # Plain-Python equivalent of `mkdir -p`; keeps the cell valid Python
    # outside of IPython as well.
    os.makedirs(outputs_path, exist_ok=True)
except Exception:
    # `except Exception` instead of a bare `except:` so that Ctrl-C / kernel
    # interrupts are not silently swallowed by the fallback.
    outputs_path = "."
print(f"Animations will be saved to {outputs_path}")

In [2]:
#@title Connect to the Stability API

import cv2
import datetime
import gradio as gr
import imageio
import json
import numpy as np
import os
import param
import shutil
import subprocess
import sys

from base64 import b64encode
from collections import OrderedDict
from IPython import display
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from types import SimpleNamespace
from typing import Dict, Generator, List, Optional, Union, Any, Sequence, Tuple


# install Stability SDK for Python
def _run_and_print(cmd):
    """Run *cmd*, print its combined stdout/stderr and return that text."""
    # stderr is merged into stdout so that git/pip error messages show up in
    # the cell output instead of being lost.
    output = subprocess.run(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    ).stdout.decode('utf-8', errors='replace')
    print(output)
    return output


path = Path('stability-sdk')
if os.path.exists("../src/stability_sdk"):
    sys.path.append("../src") # use local SDK src
else:
    # Invoke pip through the running interpreter so the package is installed
    # into the same environment this kernel imports from.
    pip_cmd = [sys.executable, '-m', 'pip']
    if path.exists():
        # A previous checkout exists: remove it and the installed package so
        # the fresh clone below is what ends up installed.
        shutil.rmtree(path)
        sub_p_res = _run_and_print(pip_cmd + ['uninstall', '-y', 'stability-sdk'])
    sub_p_res = _run_and_print(['git', 'clone', '-b', 'pharma.anims', '--recurse-submodules', 'https://github.com/Stability-AI/stability-sdk'])
    # Create an (empty) __init__.py in the interfaces directory of the clone
    # before installing it.
    Path("./stability-sdk/src/stability_sdk/interfaces/__init__.py").touch()
    sub_p_res = _run_and_print(pip_cmd + ['install', './stability-sdk'])

from stability_sdk import client
from stability_sdk.animation import (
    Animator,
    BasicSettings,
    AnimationSettings,
    KeyframedSettings,
    CoherenceSettings,
    ColorSettings,
    DepthwarpSettings,
    VideoInputSettings,
    AnimationArgs,
)


# GRPC endpoint and engines
# Both values are Colab form fields (`#@param`) and default to empty strings;
# presumably they have to be filled in before this cell is run.
GRPC_HOST = "" #@param {type:"string"}
API_KEY = "" #@param {type:"string"}

# Connect to Stability API
# `api` is the module-level client handle read by the UI callbacks
# (ensure_api, the project callbacks and the Animator in generate).
channel = client.open_channel(GRPC_HOST, api_key=API_KEY)
api = client.Api(channel=channel)

In [None]:
#@title Gradio UI
# Colab form fields: passed to demo.launch() as `inline` and `share`.
show_ui_in_notebook = True #@param {type:"boolean"}
create_shareable_link = False #@param {type:"boolean"}

# Written into every saved settings dict under "version" / "generator".
DATA_VERSION = "0.1"
DATA_GENERATOR = "alpha-test-notebook"

# One settings object per accordion in the Render tab. The UI controls write
# directly into these objects, and generate() merges them into AnimationArgs.
args_basic = BasicSettings()
args_anim = AnimationSettings()
args_kf = KeyframedSettings()
args_cohere = CoherenceSettings()
args_color = ColorSettings()
args_depth = DepthwarpSettings()
args_vid = VideoInputSettings()
arg_objs = (
    args_basic,
    args_anim,
    args_kf,
    args_cohere,
    args_color,
    args_depth,
    args_vid,
)
# Every Gradio control that mirrors a setting; filled while the Render tab is
# built and used as the output list of the project callbacks.
all_settings_controls = []
# Text shown in the "Animation prompts" box. The default is a Python dict
# literal (unquoted integer key), not strict JSON, so generate() reaches it
# through its non-JSON fallback.
animation_prompts = "{ 0: \"\" }"
negative_prompt = "blurry, low resolution"
# Passed to Animator as negative_prompt_weight; not exposed in the UI.
negative_prompt_weight = -1.0
# Set to True by the Stop button and polled by generate() between frames.
interrupt = False

# Cached project list and the currently active project (None until a project
# is created or loaded).
projects: List[client.Project] = []
project: client.Project = None


def ensure_api():
    """Raise a Gradio error if the module-level ``api`` client is ``None``."""
    if api is not None:
        return
    raise gr.Error("Not connected to Stability API")

def frames_to_video(frames_path: str, mp4_path: str, max_frames: int = 0, fps: int = 24):
    """Encode the numbered PNG frames in *frames_path* into an H.264 MP4 with ffmpeg.

    Args:
        frames_path: Directory holding ``frame_00000.png``, ``frame_00001.png``, ...
        mp4_path: Output video file; an existing file is overwritten (``-y``).
        max_frames: Maximum number of frames to encode. ``0`` (the default) or
            a negative value means "no limit"; the ``-frames:v`` option is then
            omitted, because passing a literal ``0`` to ffmpeg would produce an
            empty video.
        fps: Frame rate used both for reading the image sequence and for the
            output video.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status; the message is the
            decoded stderr text.
        OSError: the ``ffmpeg`` executable could not be started.
    """
    image_path = os.path.join(frames_path, "frame_%05d.png")

    cmd = [
        'ffmpeg',
        '-y',
        '-vcodec', 'png',
        '-r', str(fps),
        '-start_number', str(0),
        '-i', image_path,
    ]
    if max_frames > 0:
        cmd += ['-frames:v', str(max_frames)]
    cmd += [
        '-c:v', 'libx264',
        '-vf',
        f'fps={fps}',
        '-pix_fmt', 'yuv420p',
        '-crf', '17',
        '-preset', 'veryfast',
        mp4_path,
    ]

    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        # Decode so the log and the exception carry readable text, not a bytes repr.
        message = result.stderr.decode('utf-8', errors='replace')
        print(message)
        raise RuntimeError(message)

def get_default_project():
    """Return the initial settings dict saved with a newly created project."""
    return dict(version=DATA_VERSION, generator=DATA_GENERATOR)

def reset_args_to_defaults():
    """Restore every parameter of every settings object to its declared default."""
    for settings in arg_objs:
        declared = settings.param.objects()
        for key in declared:
            # "name" is skipped, as in the other loops over param.objects().
            if key != "name":
                setattr(settings, key, declared[key].default)

def accordion_from_args(name: str, args: param.Parameterized):
    """Create a collapsed accordion with one Gradio control per parameter of *args*.

    Each control is initialised with the parameter's default, writes its value
    back onto *args* whenever it changes, and is appended to
    ``all_settings_controls``. Must be called inside an active Gradio layout
    context.

    Args:
        name: Title of the accordion.
        args: Settings object whose parameters are exposed.
    """
    with gr.Accordion(name, open=False):
        for k, v in args.param.objects().items():
            if k == "name":
                continue
            # Integer is tested before Number so integers get precision=0
            # (in param, Integer derives from Number).
            if isinstance(v, param.Integer):
                t = gr.Number(label=v.label, value=v.default, interactive=True, precision=0)
            elif isinstance(v, param.ObjectSelector):
                t = gr.Dropdown(label=v.label, choices=v.objects, value=v.default, interactive=True)
            elif isinstance(v, param.Boolean):
                t = gr.Checkbox(label=v.label, value=v.default, interactive=True)
            elif isinstance(v, param.String):
                t = gr.Text(label=v.label, value=v.default, interactive=True)
            elif isinstance(v, param.Number):
                t = gr.Number(label=v.label, value=v.default, interactive=True)
            else:
                # No widget for this parameter type. Skip it instead of
                # falling through with `t` unbound or still pointing at the
                # previous control, which would wire that control to the
                # wrong setting and register it twice.
                continue
            # k is bound as a default argument to avoid late binding in the loop.
            t.change(lambda v, k=k: setattr(args, k, v), inputs=t)
            all_settings_controls.append(t)

def args_to_controls() -> dict:
    """Build Gradio updates that push the current settings into their controls.

    Parameters are paired with controls by case-insensitive label comparison;
    the first control with a matching label wins. The two prompt boxes are
    filled from the module-level prompt variables.

    Returns:
        Mapping of Gradio control -> ``gr.update(value=...)``.
    """
    updates = {}

    for settings in arg_objs:
        for key, declared in settings.param.objects().items():
            if key == "name":
                continue
            wanted = declared.label.lower()
            control = next(
                (c for c in all_settings_controls if c.label.lower() == wanted),
                None,
            )
            if control is not None:
                updates[control] = gr.update(value=getattr(settings, key))

    prompt_values = {
        "animation prompts": animation_prompts,
        "negative prompt": negative_prompt,
    }
    for control in all_settings_controls:
        label = control.label.lower()
        if label in prompt_values:
            updates[control] = gr.update(value=prompt_values[label])

    return updates

def render_tab():
    """Build the contents of the "Render" tab and wire up rendering.

    Creates the settings accordions, the prompt boxes, the Generate/Stop
    buttons and the image/video outputs, then registers the ``generate`` and
    ``stop`` callbacks. Must be called inside an active ``gr.Blocks`` context.
    """
    import ast  # local import: only used to parse Python-literal prompts safely

    with gr.Row():
        with gr.Column():
            accordion_from_args("Generation", args_basic)
            accordion_from_args("Animation", args_anim)
            accordion_from_args("Camera", args_kf)
            accordion_from_args("Coherence", args_cohere)
            accordion_from_args("Color", args_color)
            accordion_from_args("Depth", args_depth)
            accordion_from_args("Video Input", args_vid)

            with gr.Accordion("Prompts", open=True):
                prompt = gr.TextArea(label="Animation prompts", max_lines=8, value=animation_prompts, interactive=True)
                negative = gr.Textbox(label="Negative prompt", max_lines=1, value=negative_prompt, interactive=True)

                def update_prompts(v):
                    global animation_prompts
                    animation_prompts = v

                def update_negative(v):
                    global negative_prompt
                    negative_prompt = v

                prompt.change(update_prompts, inputs=prompt)
                negative.change(update_negative, inputs=negative)
                all_settings_controls.extend([prompt, negative])

            button = gr.Button("Generate")
            button_stop = gr.Button("Stop", visible=False)

        with gr.Column():
            image_out = gr.Image(label="image", visible=True)
            video_out = gr.Video(label="video", visible=False)

    def generate():
        """Render the animation for the active project, streaming frames to the UI.

        Yields one dict of component updates per rendered frame and a final
        one that shows the encoded video.

        Raises:
            gr.Error: no project is active, or the prompts cannot be parsed.
        """
        global animation_prompts, interrupt, project
        if not project:
            raise gr.Error("No project active!")

        # A Stop request left over from an earlier run must not cancel this one.
        interrupt = False

        # create local folder for the project
        project_folder_name = project.title.replace("/", "_").replace("\\", "_").replace(":", "")
        outdir = os.path.join(outputs_path, project_folder_name)
        os.makedirs(outdir, exist_ok=True)

        # each render gets a unique run index
        run_index = 0
        while True:
            project_settings_path = os.path.join(outdir, f"{project_folder_name} ({run_index}).json")
            if not os.path.exists(project_settings_path):
                break
            run_index += 1

        # gather up all the settings from sub-objects
        args_d = {}
        for a in arg_objs:
            args_d.update(a.param.values())
        args = AnimationArgs(**args_d)

        # The prompts normally arrive as text from the UI; accept an already
        # parsed dict as well so a programmatically set value does not crash.
        if isinstance(animation_prompts, dict):
            prompts = animation_prompts
        else:
            try:
                prompts = json.loads(animation_prompts)
            except json.JSONDecodeError as json_err:
                # Fallback for Python dict literals such as `{ 0: "" }`.
                # NOTE: this used to be eval(); the text comes straight from
                # the (possibly publicly shared) UI, so only literals are
                # accepted now and arbitrary expressions are rejected.
                try:
                    prompts = ast.literal_eval(animation_prompts)
                except Exception:
                    raise gr.Error("Invalid JSON or Python code for animation_prompts!") from json_err
        try:
            prompts = {int(k): v for k, v in prompts.items()}
        except (AttributeError, TypeError, ValueError) as e:
            raise gr.Error("animation_prompts must be a dictionary keyed by frame number!") from e

        # save settings to a dict
        save_dict = OrderedDict()
        save_dict['version'] = DATA_VERSION
        save_dict['generator'] = DATA_GENERATOR
        save_dict.update(args.param.values())
        # move the camera-motion settings to the end of the saved dict
        for k in ['angle', 'zoom', 'translation_x', 'translation_y', 'translation_z', 'rotation_x', 'rotation_y', 'rotation_z']:
            save_dict.move_to_end(k, last=True)
        save_dict['animation_prompts'] = prompts
        save_dict['negative_prompt'] = negative_prompt
        project.save_settings(save_dict)
        with open(project_settings_path, 'w') as f:
            json.dump(save_dict, f, indent=4)

        animator = Animator(
            api=api,
            animation_prompts=prompts,
            args=args,
            out_dir=outdir,
            negative_prompt=negative_prompt,
            negative_prompt_weight=negative_prompt_weight,
            resume=False,
        )

        # Counted separately from the loop variable so the code after the
        # loop is well-defined even when the renderer yields nothing.
        frames_rendered = 0
        for frame_idx, frame in enumerate(tqdm(animator.render(), initial=animator.start_frame_idx, total=args.max_frames)):
            frames_rendered = frame_idx + 1
            if interrupt:
                break

            yield {
                button: gr.update(visible=False),
                button_stop: gr.update(visible=True),
                image_out: gr.update(value=frame, label=f"frame {frame_idx}/{args.max_frames}", visible=True),
                video_out: gr.update(visible=False)
            }

        interrupt = False

        if frames_rendered == 0:
            # Nothing to encode; just restore the buttons.
            yield {
                button: gr.update(visible=True),
                button_stop: gr.update(visible=False),
                image_out: gr.update(),
                video_out: gr.update(),
            }
            return

        # Swap only the file extension (a ".json" inside the folder name must
        # not be touched).
        output_video = os.path.splitext(project_settings_path)[0] + ".mp4"
        frames_to_video(outdir, output_video, max_frames=frames_rendered, fps=24)
        yield {
            button: gr.update(visible=True),
            button_stop: gr.update(visible=False),
            image_out: gr.update(visible=False),
            video_out: gr.update(value=output_video, visible=True),
        }

    button.click(generate, inputs=[], outputs=[button, button_stop, image_out, video_out])

    def stop():
        """Ask the running render to stop after the current frame and reset the buttons."""
        global interrupt
        interrupt = True
        return {
            button: gr.update(visible=True),
            button_stop: gr.update(visible=False),
        }

    # queue=False: the request bypasses the event queue so it is handled while
    # generate() is still occupying the queue worker; otherwise the stop flag
    # would only be set after the render had already finished.
    button_stop.click(stop, inputs=[], outputs=[button, button_stop], queue=False)


# Top-level UI: a "Project" tab for creating, listing, loading and deleting
# projects through the Stability API, and a "Render" tab built by render_tab().
with gr.Blocks() as demo:
    gr.Markdown("Stability Animations Alpha Test")

    with gr.Tab("Project"):
        with gr.Column(variant="panel"):
            gr.Markdown("Create a new project")
            with gr.Row():
                new_project_title = gr.Text(label="Name", value="My amazing animation", interactive=True)
                button_create_project = gr.Button("Create")
        button_load_projects = gr.Button("Load Projects")
        with gr.Column(visible=False, variant="panel") as projects_row:
            gr.Markdown("Existing projects")
            with gr.Row():
                projects_dropdown = gr.Dropdown([p.title for p in projects], label="Projects", visible=True, interactive=True)
                with gr.Column():
                    button_load_project = gr.Button("Load")
                    button_delete_project = gr.Button("Delete")
        project_data_log = gr.Textbox(label="Status", visible=False)

        def find_project(title):
            """Return the cached project titled *title*, or raise a Gradio error."""
            for candidate in projects:
                if candidate.title == title:
                    return candidate
            # Reached when nothing is selected in the dropdown or the cached
            # list is stale; report it instead of leaking StopIteration.
            raise gr.Error(f"Project '{title}' not found")

        def create_project(title):
            """Create a project named *title*, make it active and reset all settings controls."""
            ensure_api()
            global project, projects
            # Refresh first so the duplicate check also works before
            # "Load Projects" has been clicked.
            projects = client.Project.list_projects(api)
            titles = [p.title for p in projects]
            if title in titles:
                raise gr.Error(f"Project with title '{title}' already exists")
            project = client.Project.create(api, title)
            settings = get_default_project()
            project.save_settings(settings)
            projects = client.Project.list_projects(api)
            log = f"Created project '{title}' with id {project.id}\n{json.dumps(settings)}"

            reset_args_to_defaults()
            returns = args_to_controls()
            returns[project_data_log] = gr.update(value=log, visible=True)
            returns[projects_dropdown] = gr.update(choices=[p.title for p in projects], visible=True, value=title)
            returns[projects_row] = gr.update(visible=len(projects) > 0)
            return returns

        def delete_project(title: str):
            """Delete the project named *title* and refresh the project list."""
            ensure_api()
            global project, projects
            doomed = find_project(title)
            doomed.delete()
            log = f"Deleted project '{title}' with id {doomed.id}"
            projects = client.Project.list_projects(api)
            # Only drop the active project when it is the one just deleted;
            # deleting some other project leaves the active one untouched.
            if project is not None and project.id == doomed.id:
                project = None
            return {
                projects_dropdown: gr.update(choices=[p.title for p in projects], visible=True),
                projects_row: gr.update(visible=len(projects) > 0),
                project_data_log: gr.update(value=log, visible=True)
            }

        def load_projects():
            """Fetch the project list and reveal the project picker when it is non-empty."""
            ensure_api()
            global projects
            projects = client.Project.list_projects(api)
            return {
                button_load_projects: gr.update(visible=len(projects)==0),
                projects_dropdown: gr.update(choices=[p.title for p in projects], visible=True),
                projects_row: gr.update(visible=len(projects) > 0)
            }

        def load_project(title: str):
            """Make the project named *title* active and load its saved settings into the UI."""
            ensure_api()
            global project, animation_prompts, negative_prompt
            project = find_project(title)
            data = project.load_settings()
            log = f"Loaded project '{title}' with id {project.id}\n{json.dumps(data, indent=4)}"

            # go through all the parameters and load their settings from the data
            for arg in arg_objs:
                for k in arg.param.objects():
                    if k != "name" and k in data:
                        # plain setattr, same as reset_args_to_defaults()
                        setattr(arg, k, data[k])
            if "animation_prompts" in data:
                loaded_prompts = data["animation_prompts"]
                # The prompts box and generate() work with the prompts as
                # text, so keep the global a string rather than a dict.
                if isinstance(loaded_prompts, str):
                    animation_prompts = loaded_prompts
                else:
                    animation_prompts = json.dumps(loaded_prompts)
            if "negative_prompt" in data:
                negative_prompt = data["negative_prompt"]

            # update the ui controls
            returns = args_to_controls()
            returns[project_data_log] = gr.update(value=log, visible=True)
            return returns

    with gr.Tab("Render"):
        render_tab()

    # all_settings_controls is only complete after render_tab() has run, so
    # the output lists are assembled here rather than next to the callbacks.
    create_project_outputs = [projects_dropdown, projects_row, project_data_log]
    create_project_outputs.extend(all_settings_controls)

    load_project_outputs = [project_data_log]
    load_project_outputs.extend(all_settings_controls)

    button_create_project.click(create_project, inputs=new_project_title, outputs=create_project_outputs)
    button_load_projects.click(load_projects, outputs=[button_load_projects, projects_dropdown, projects_row])
    button_load_project.click(load_project, inputs=projects_dropdown, outputs=load_project_outputs)
    button_delete_project.click(delete_project, inputs=projects_dropdown, outputs=[projects_dropdown, projects_row, project_data_log])

# Enable Gradio's event queue with a single worker (concurrency_count=1)
# before launching.
demo.queue(concurrency_count=1)
# `inline` and `share` come from the form fields at the top of this cell.
# NOTE(review): debug=True presumably keeps the cell running so errors are
# printed in the notebook — confirm against the installed Gradio version.
demo.launch(show_api=False, debug=True, inline=show_ui_in_notebook, height=768, share=create_shareable_link)