In [1]:
#@title Install requirements
import base64
from io import BytesIO
import json
import mimetypes
import os
from PIL import Image
import requests
import time

In [2]:
#@title Define helper functions

def image_to_bytes(
        img: "Image.Image",
        format="PNG"
):
    """Encode an image into bytes in the given format.

    Args:
        img: image object exposing ``save(fp, format=...)`` (e.g. a PIL image).
        format: format name passed through to ``img.save`` (default "PNG").

    Returns:
        The encoded image bytes.
    """
    # The string annotation refers to the PIL Image *class*; the bare name
    # `Image` is the PIL module, which is not a valid type.
    with BytesIO() as im_file:
        img.save(im_file, format=format)
        return im_file.getvalue()

def get_image_format(
        image_path : str
):
    """Return the upper-cased image format name implied by a file's extension.

    The name is the subtype of the MIME type guessed by ``mimetypes`` from the
    path alone (the file is not opened), e.g. "photo.png" -> "PNG".

    Raises:
        ValueError: if no MIME type can be guessed, or the guessed type is not
            an ``image/*`` type (e.g. ".txt"), since its subtype would not be
            an image format name.
    """
    image_mime_type = mimetypes.guess_type(image_path)[0]
    if image_mime_type is None:
        raise ValueError(f"Unknown image mime type for {image_path}")
    major_type, _, subtype = image_mime_type.partition("/")
    if major_type != "image":
        raise ValueError(f"Not an image mime type ({image_mime_type}) for {image_path}")
    return subtype.upper()

def resize_image_to_bytes(
        image_path : str,
        size : "tuple[int, int] | None" = None
):
    """Open an image file, resize it, and return it encoded in its original format.

    Args:
        image_path: path to the image; its extension selects the output
            format via ``get_image_format``.
        size: target ``(width, height)``. When ``None``, the closest supported
            resolution from ``get_closest_valid_dims`` is used.

    Returns:
        The encoded bytes of the resized image.
    """
    # Pillow opens files lazily; the context manager guarantees the file
    # handle is released even if something fails before the data is loaded.
    with Image.open(image_path) as source:
        format = get_image_format(image_path)
        if size is None:
            width, height = get_closest_valid_dims(source)
        else:
            width, height = size
        # resize() returns a new image, so it stays valid after the source closes.
        resized = source.resize((width, height))
    return image_to_bytes(resized, format=format)

def get_closest_valid_dims(
        image : "Image.Image"
):
    """Pick the supported output resolution closest in aspect ratio to ``image``.

    Supported ``(width, height)`` pairs are 1024x576 (landscape), 576x1024
    (portrait) and 768x768 (square). An image snaps to the 16:9 / 9:16 size
    only when its aspect ratio lies beyond the midpoint between 1:1 and that
    ratio; otherwise it snaps to square.

    Args:
        image: object with a ``size`` attribute of ``(width, height)``,
            e.g. a PIL image.

    Returns:
        A ``(width, height)`` tuple.

    Raises:
        ValueError: if either dimension is not positive.
    """
    w, h = image.size
    if w <= 0 or h <= 0:
        raise ValueError(f"Image dimensions must be positive, got {w}x{h}")
    aspect_ratio = w / h
    portrait_aspect_ratio = 9 / 16
    landscape_aspect_ratio = 16 / 9
    portrait_aspect_ratio_midpoint = (portrait_aspect_ratio + 1) / 2
    landscape_aspect_ratio_midpoint = (landscape_aspect_ratio + 1) / 2
    if aspect_ratio < 1.0:
        # portrait: taller than the midpoint toward 9:16 -> 576x1024
        return (576, 1024) if aspect_ratio < portrait_aspect_ratio_midpoint else (768, 768)
    # landscape or exactly square: wider than the midpoint toward 16:9 -> 1024x576
    return (1024, 576) if aspect_ratio > landscape_aspect_ratio_midpoint else (768, 768)


def image_to_valid_bytes(
        image_path : str
        ):
    """Resize an image file to the closest supported resolution and encode it.

    The target size comes from ``get_closest_valid_dims`` and the output
    format from the file extension via ``get_image_format``.

    Args:
        image_path: path to the source image.

    Returns:
        The encoded bytes of the resized image.
    """
    # Pillow opens files lazily; the context manager guarantees the file
    # handle is released even if something fails before the data is loaded.
    with Image.open(image_path) as source:
        width, height = get_closest_valid_dims(source)
        format = get_image_format(image_path)
        # resize() returns a new image, so it stays valid after the source closes.
        resized = source.resize((width, height))
    return image_to_bytes(resized, format=format)

In [None]:
#@title Set up credentials

import getpass
# @markdown To get your API key visit https://platform.stability.ai/account/keys
# Use the STABILITY_KEY environment variable when it is set, so the notebook
# can run non-interactively (Restart & Run All); otherwise prompt for it.
STABILITY_KEY = os.environ.get("STABILITY_KEY") or getpass.getpass('Enter your API Key')
if not STABILITY_KEY:
    # Fail here rather than with an opaque HTTP 401 from the REST cell.
    raise ValueError("No Stability API key provided")



In [31]:
#@title Define input

#@markdown - Drag and drop an image into the Files panel on the left
#@markdown - Right-click it and choose Copy path
#@markdown - Paste that path into the init_image field below
#@markdown <br><br>

# Path to the source image. Its extension must map to an image MIME type,
# and the image is resized to a supported resolution before upload.
init_image = "/content/img.jpg" #@param {type:"string"}
# seed, cfg_scale and motion_bucket_id are sent as string multipart form
# fields by the REST cell. NOTE(review): their exact semantics (e.g. whether
# seed=0 means "random") are defined by the Stability API; confirm there.
# The decode cell later overwrites `seed` with the value from the response.
seed = 0 #@param {type:"integer"}
cfg_scale = 2.5 #@param {type:"number"}
motion_bucket_id = 40 #@param {type:"integer"}

In [None]:
#@title Use REST API

headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {STABILITY_KEY}"
}
host = "https://api.stability.ai/v2alpha/generation/image-to-video"

# Per-request network timeout in seconds, so a stalled connection cannot hang
# the cell forever. The overall polling deadline is WORKER_TIMEOUT below.
request_timeout = 60

# Resize the init image to a supported resolution, encoded in its file format.
init_image_bytes = image_to_valid_bytes(init_image)
image_mime_type = mimetypes.guess_type(init_image)[0]
files = {
    "image": ("file", init_image_bytes, image_mime_type),
}
params = {
    "seed": seed,
    "cfg_scale": cfg_scale,
    "motion_bucket_id": motion_bucket_id,
}

# Send each parameter as a plain multipart form field (no filename); booleans
# are lower-cased to "true"/"false".
for k, v in params.items():
    if isinstance(v, bool):
        v = str(v).lower()
    files[k] = (None, str(v).encode('utf-8'))

print(f"Sending REST request to {host}...")

response = requests.post(
    host,
    headers=headers,
    files=files,
    timeout=request_timeout,
)

if not response.ok:
    raise Exception(f"HTTP {response.status_code}: {response.text}")

#
# Process async response
#
response_dict = response.json()
request_id = response_dict.get("id", None)
assert request_id is not None, "Expected id in response"

# Poll until the result is ready (a successful non-202 status) or the
# deadline passes. `response` holds the final result for the decode cell.
timeout = int(os.getenv("WORKER_TIMEOUT", 500))
start = time.time()
status_code = 202
while True:
    response = requests.get(
        f"{host}/result/{request_id}",
        headers=headers,
        timeout=request_timeout,
    )
    if not response.ok:
        raise Exception(f"HTTP {response.status_code}: {response.text}")
    status_code = response.status_code
    if status_code != 202:
        # Result is ready: don't sleep, and don't report a timeout on success.
        break
    if time.time() - start > timeout:
        raise Exception(f"Timeout after {timeout} seconds")
    time.sleep(2)



In [33]:
#@title Decode response
json_data = response.json()

# Check the finish reason before touching the payload. NOTE(review): a
# filtered result may not carry a usable "video" field (confirm against the
# API docs); decoding first would hide the real reason behind a KeyError.
finish_reason = json_data["finishReason"]
if finish_reason == 'CONTENT_FILTERED':
    raise Warning("Video failed NSFW classifier")

video = base64.b64decode(json_data["video"])
# Replace the input `seed` with the one from the response (presumably the
# seed actually used), so the saved filename reflects it.
seed = json_data["seed"]

In [None]:
#@title Save and display result

filename = f"video_{seed}.mp4"
with open(filename, "wb") as f:
    f.write(video)
print(f"Saved video {filename}")

import IPython.display
# Read the saved file back inside a context manager so the handle is closed.
with open(filename, "rb") as f:
    mp4 = f.read()
# Embed the video inline as a base64 data URL.
data_url = "data:video/mp4;base64," + base64.b64encode(mp4).decode()
IPython.display.display(IPython.display.HTML(f'<video controls loop><source src="{data_url}" type="video/mp4"></video>'))