<a href="https://colab.research.google.com/github/RichardSlater/ai-ml-playground/blob/main/vqgan%2Bclip/ai_animating_intermediate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Making Animations with Cutting-Edge Artificial Intelligence

This notebook uses VQGAN+CLIP to generate frames, then applies a series of animation techniques on top to zoom, rotate, and shift each frame by x (left/right) and y (up/down) pixels. Optionally, you can also use SRCNN to upscale the resulting images and Super-SloMo to interpolate additional frames for smoother motion.

## Technology

### Core Neural Networks

#### Vector Quantized Generative Adversarial Network (VQGAN)

Vector Quantized Generative Adversarial Network ("VQGAN") is a type of Generative Adversarial Network, which pits two neural networks against each other in a "Generate and Review" relationship. Guided by CLIP, VQGAN can turn a text prompt into a picture or scene.

Read more in [Taming Transformers for High-resolution Image Synthesis](https://compvis.github.io/taming-transformers/).

#### Contrastive Language–Image Pre-training (CLIP)

Contrastive Language–Image Pre-training ("CLIP") is a neural network that compares an image with a caption and scores how well the image matches the caption. In this notebook, CLIP is used to assess the images produced by VQGAN.

Read more in [CLIP: Connecting Text and Images](https://openai.com/blog/clip/).
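
As a rough illustration of the scoring idea, CLIP embeds the image and the caption as vectors and compares them with cosine similarity. The sketch below uses made-up three-dimensional vectors in place of real CLIP embeddings; `cosine_similarity` and every vector here are hypothetical and are not part of this notebook.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 means the same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Made-up stand-ins for embeddings; real ones would come from CLIP's encoders.
image_embedding = [0.9, 0.1, 0.2]
caption_embeddings = {
    "a photo of the moon": [0.8, 0.2, 0.1],
    "a bowl of fruit": [0.1, 0.9, 0.3],
}

# Score every caption against the image and keep the closest match.
scores = {caption: cosine_similarity(image_embedding, vector)
          for caption, vector in caption_embeddings.items()}
best_caption = max(scores, key=scores.get)
```

During generation the same kind of score is turned into a loss, so the image is nudged towards whatever makes CLIP rate it as a closer match for the prompt.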

### Additional Neural Networks

The following two neural networks are optional. You can happily create a video without them; however, by interpolating pixels and frames we can produce higher-resolution, better-quality videos.

#### Super-Resolution Convolutional Neural Network (SRCNN)

SRCNN is a convolutional neural network, in the sense that it uses the mathematical concept of convolution to infer information about a subject. SRCNN typically performs better than "normal" bicubic interpolation, resulting in clearer images.

Read more in [Original paper on SRCNN by Dong et al. (Image Super-Resolution Using Deep Convolutional Networks)](https://github.com/Mirwaisse/SRCNN).

#### Super-SloMo

Super-SloMo is another convolutional neural network that looks at two frames and creates intermediate frames, allowing a lower generated frame rate whilst maintaining the quality. As it takes 2-3 minutes to run 100 iterations of a frame on an NVIDIA P100, it would be prohibitively time-consuming to use VQGAN+CLIP alone to create a 30fps video; Super-SloMo helps close the gap.

Read more in "[Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation](https://arxiv.org/abs/1712.00080)" by Jiang H., Sun D., Jampani V., Yang M., Learned-Miller E. and Kautz J.
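
To put the timing argument in numbers, here is a small back-of-the-envelope sketch. The 10-second clip length, the 2.5 minutes per frame (the midpoint of the 2-3 minute range quoted above) and the 12fps generated rate are assumptions chosen for illustration; `render_hours` is a hypothetical helper, not part of the notebook.

```python
def render_hours(video_seconds, fps, minutes_per_frame):
    """Hours needed to generate every frame of a clip with VQGAN+CLIP."""
    frames = video_seconds * fps
    return frames * minutes_per_frame / 60


# Assumed clip: 10 seconds long at 2.5 minutes per generated frame.
direct_30fps = render_hours(10, 30, 2.5)     # every frame generated by VQGAN+CLIP
generated_12fps = render_hours(10, 12, 2.5)  # the remaining frames are interpolated
```

Under these assumptions, generating all 300 frames directly takes 12.5 hours, while generating 120 frames and interpolating the rest takes 5 hours, not counting the time Super-SloMo itself needs.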

## Credit

Notebook by [Katherine Crowson](https://github.com/crowsonkb) [[Twitter](https://twitter.com/RiversHaveWings)]. Zoom, pan, rotation, and keyframes features by [Chigozie Nri](https://github.com/chigozienri) [[Twitter](https://twitter.com/chigozienri)]. Adapted by the A. I. Whisperer and [Richard Slater](https://github.com/RichardSlater/).

```
# Licensed under the MIT License

# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
```

# Prepare Filesystem

In [None]:
#@title # Connect to Google Drive
#@markdown Use Google Drive to provide persistent storage; Google Colaboratory's
#@markdown `/content` is ephemeral and is lost when the runtime disconnects.
#@markdown *Note:* Google Colab VMs are small, and network IO is slow.
from google.colab import drive
drive.mount('/content/drive')

In [7]:
#@title # Processing options

upscale_frames = True #@param {type:"boolean"}
crop_frames = True #@param {type:"boolean"}
generate_videos = True #@param {type:"boolean"}
super_slomo = False #@param {type:"boolean"}

In [None]:
#@title # Set filesystem variables
#@markdown | Variable | Usage |
#@markdown |---|---|
#@markdown | `working_dir` | Directory where Git repos, `/steps`, `video.mp4`, etc. are stored |
#@markdown | `project_dir` | Persistent directory for project files; make sure you **connect to Google Drive above** |
#@markdown | `cache_dir`   | Persistent directory for caches of models to avoid repeatedly downloading them when the runtime is reset |
#@markdown | `models_dir`  | Location to keep the models at runtime, should be in the `/content` directory on Colab |

working_dir = '/content' #@param {type:"string"}
project_dir = '/content/drive/MyDrive/vqgan+clip/projects/testing' #@param {type:"string"}
cache_dir = '/content/drive/MyDrive/vqgan+clip/cache' #@param {type:"string"}
models_dir = '/content/models' #@param {type:"string"}

# Configuration
## Instructions for setting Animation parameters

| Parameter  |  Usage |
|---|---|
| `key_frames` | Using keyframes allows you to change animation parameters over time |
| `text_prompts` | Input to the neural network to generate a frame from; you can separate multiple prompts with "\|" |
| `width` | Width of the output, in pixels. This will be rounded down to a multiple of 16 |
| `height` | Height of the output, in pixels. This will be rounded down to a multiple of 16 |
| `trim_width` | Number of pixels to remove from the width |
| `trim_height` | Number of pixels to remove from the height |
| `vqgan_model` | Choice of model, must be downloaded in the download models cell |
| `interval` | How often to display the frame in the notebook (doesn't affect the actual output) |
| `initial_image` | Image to start with (relative path to file) |
| `target_images` | Image prompts to target, separated by "\|" (relative path to files) |
| `seed` | Random seed. If set to a positive integer the run will be repeatable (the same output for the same input each time); if set to -1 a random seed will be used. |
| `init_frame` | Frame to start from, allows resumption of a crashed session by uploading a backup of the last frame and setting `init_frame` and `initial_image` to the last good frame |
| `max_frames` | Number of frames for the animation |
| `angle` | Angle in degrees to rotate clockwise between each frame |
| `zoom` | Factor to zoom in each frame, 1 is no zoom, less than 1 is zoom out, more than 1 is zoom in (negative is uninteresting, just adds an extra 180 rotation beyond that in angle) |
| `translation_x` | Number of pixels to shift right each frame |
| `translation_y` | Number of pixels to shift down each frame |
| `iterations_per_frame` | Number of times to run the VQGAN+CLIP method each frame |
| `save_all_iterations` | Debugging, set False in normal operation |

## Process

On each frame, the network restarts, is fed a version of the output zoomed in by `zoom` as the initial image, rotated clockwise by `angle` degrees, translated horizontally by `translation_x` pixels, and translated vertically by `translation_y` pixels. Then it runs `iterations_per_frame` iterations of the VQGAN+CLIP method. 0 `iterations_per_frame` is supported, to help test out the transformations without changing the image.
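
The geometry of that per-frame transformation can be sketched as a single affine map about the image centre. This is a minimal illustration in plain Python, assuming image coordinates (x to the right, y downwards); `frame_transform` is a hypothetical helper, and the notebook's own implementation may differ, for example in how it fills pixels that move in from outside the frame.

```python
import math


def frame_transform(width, height, zoom=1.0, angle=0.0,
                    translation_x=0.0, translation_y=0.0):
    """Return a function mapping a pixel (x, y) to its position in the next frame.

    With y growing downwards, a positive angle appears clockwise on screen.
    """
    cx, cy = width / 2, height / 2
    cos_a = math.cos(math.radians(angle))
    sin_a = math.sin(math.radians(angle))

    def apply(x, y):
        dx, dy = x - cx, y - cy                  # move the centre to the origin
        rx = zoom * (cos_a * dx - sin_a * dy)    # rotate, then scale
        ry = zoom * (sin_a * dx + cos_a * dy)
        # move back to the centre and apply the pixel shift
        return rx + cx + translation_x, ry + cy + translation_y

    return apply


# Zoom-only example: a 368px by 640px frame with a zoom of 1.05.
move = frame_transform(368, 640, zoom=1.05)
corner = move(0, 0)  # the top-left pixel is pushed outwards, beyond the frame edge
```

With a zoom above 1 the centre pixel stays put while everything else moves away from it, which is what produces the "flying into the image" effect over successive frames.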

For `iterations_per_frame = 1` (recommended for more abstract effects), the resulting images will not have much to do with the prompts, but at least one prompt is still required.

In normal use, only the last iteration of each frame will be saved, but for trouble-shooting you can set `save_all_iterations` to True, and every iteration of each frame will be saved.

![](https://raw.githubusercontent.com/RichardSlater/ai-ml-playground/main/vqgan%2Bclip/assets/vqgan%2Bclip-flow-intermediate.png)

### Resolutions

#### Common resolutions

These are common resolutions for the platforms listed. Most are on the low-resolution side and may need to be upscaled before the platform accepts the video:
 - **TikTok resolution**: 340px by 570px
 - **YouTube (16:9) resolution**: 640px by 360px (360p) or 426px by 240px (240p)
 - **4:3 resolution**: 480px by 360px
 - **TikTok/Instagram (Square)**: 500px by 500px or 400px by 400px

Any height or width that is not divisible by 16 will be rounded down. You can avoid this entirely by selecting resolutions that are divisible by 16:
 - 144px by 256px
 - 288px by 512px
 - 432px by 768px (too large for a 16GB GPU)
 - 576px by 1024px (too large for a 24GB GPU)
 - 720px by 1280px (too large for a 40GB GPU)
 - 864px by 1536px (too large for a 40GB GPU)

Alternatively, you can generate 8px more than required and then crop the vestigial 8px off during post-processing.
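
The arithmetic behind both options can be sketched as follows. `rounded_down` and `plan_dimension` are hypothetical helper names used only for this illustration; the notebook performs the rounding internally.

```python
import math

BLOCK = 16  # dimensions are rounded down to a multiple of this


def rounded_down(pixels):
    """Size actually generated when `pixels` is requested."""
    return (pixels // BLOCK) * BLOCK


def plan_dimension(target):
    """Return (size_to_generate, pixels_to_trim) yielding exactly `target` pixels."""
    generate = math.ceil(target / BLOCK) * BLOCK
    return generate, generate - target


tiktok = (rounded_down(340), rounded_down(570))  # what 340px by 570px becomes
width_plan = plan_dimension(360)                 # oversize, then trim the excess
height_plan = plan_dimension(640)                # already a multiple of 16
```

Requesting 340px by 570px therefore yields 336px by 560px, whereas a 360px target is reached by generating 368px and trimming 8px, which corresponds to the `width = 368` and `trim_width = 8` defaults in the parameters cell.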

### Key Frames

If `key_frames` is set to True, you are able to change the parameters over the course of the run.
To do this, put the parameters in the following format:
`10:(0.5), 20: (1.0), 35: (-1.0)`

This means at frame 10, the value should be 0.5, at frame 20 the value should be 1.0, and at frame 35 the value should be -1.0. The value at every other frame will be linearly interpolated (that is, before frame 10, the value will be 0.5, between frame 10 and 20 the value will increase frame-by-frame from 0.5 to 1.0, between frame 20 and 35 the value will decrease frame-by-frame from 1.0 to -1.0, and after frame 35 the value will be -1.0).

This also works for `text_prompts`, e.g. `10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)`
will start with an Apple value of 1; once it hits frame 10 it will start decreasing in Apple and increasing in Orange until it hits frame 20. Note that Peach will have a value of 1 the whole time.

If `key_frames` is set to True, all of the parameters which can be key-framed must be entered in this format.
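
The interpolation rule for numeric parameters can be sketched in a few lines of plain Python. This is an illustration only: `parse_key_frames` and `value_at` are hypothetical helpers, they handle numeric values but not `text_prompts`, and the notebook's own parser may behave differently at the edges.

```python
import re


def parse_key_frames(text):
    """Parse '10:(0.5), 20: (1.0)' into {10: 0.5, 20: 1.0}."""
    pairs = re.findall(r"(\d+)\s*:\s*\(([^)]*)\)", text)
    if not pairs:
        raise ValueError(f"no key frames found in {text!r}")
    return {int(frame): float(value) for frame, value in pairs}


def value_at(key_frames, frame):
    """Linearly interpolate between key frames, holding the first and last values."""
    frames = sorted(key_frames)
    if frame <= frames[0]:
        return key_frames[frames[0]]
    if frame >= frames[-1]:
        return key_frames[frames[-1]]
    # find the pair of key frames that surrounds the requested frame
    for start, end in zip(frames, frames[1:]):
        if start <= frame <= end:
            fraction = (frame - start) / (end - start)
            return key_frames[start] + fraction * (key_frames[end] - key_frames[start])


zoom_keys = parse_key_frames("10:(0.5), 20: (1.0), 35: (-1.0)")
zoom_series = [value_at(zoom_keys, f) for f in (0, 10, 15, 20, 35, 40)]
```

For the example string above, frames 0 and 10 both give 0.5, frame 15 gives 0.75 (halfway between the first two key frames), and frames 35 and 40 both give -1.0.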

In [None]:
#@title # Parameters for Model and Animating

key_frames = True #@param {type:"boolean"}
text_prompts = "0:(shouting at the moon night time scene in unreal engine hyperrealistic: 1)" #@param {type:"string"}
width = 368 #@param {type:"slider", min:0, max:1536, step:1}
height = 640 #@param {type:"slider", min:0, max:1536, step:1}
trim_width = 8 #@param {type:"slider", min:0, max:1536, step:1}
trim_height = 0 #@param {type:"slider", min:0, max:1536, step:1}
interval = 1 #@param {type:"slider", min:0, max:100, step:1}
initial_image = "" #@param {type:"string"}
target_images = "" #@param {type:"string"}
seed = 183264307 #@param {type:"number"}
max_frames = 5 #@param {type:"slider", min:1, max:600, step:1}
angle = "0: (0)" #@param {type:"string"}
zoom = "0: (1.05)" #@param {type:"string"}
translation_x = "0: (0)" #@param {type:"string"}
translation_y = "0: (0)" #@param {type:"string"}
iterations_per_frame = "0: (3)" #@param {type:"string"}
save_all_iterations = False #@param {type:"boolean"}
superscale_resolution = "3x" #@param ["2x", "3x", "4x"] {type:"string"}
target_fps = 12 #@param {type:"slider", min:0, max:30, step:1}

init_frame = 0

In [None]:
#@title # Configure Models
#@markdown By default, the notebook downloads only S-FLCKR. Other models (ImageNet 16384, ImageNet 1024, COCO-Stuff, WikiArt 1024, WikiArt 16384 and FacesHQ) are not downloaded by default, to avoid wasting time on models you will not use. To use one of them, select it below so that it is downloaded.

imagenet_1024 = False #@param {type:"boolean"}
imagenet_16384 = False #@param {type:"boolean"}
coco = False #@param {type:"boolean"}
faceshq = False #@param {type:"boolean"}
wikiart_16384 = False #@param {type:"boolean"}
sflckr = True #@param {type:"boolean"}

#@markdown This is the model that will actually be used; it must be one of the models selected for download above.
vqgan_model = "sflckr" #@param ["vqgan_imagenet_f16_16384", "vqgan_imagenet_f16_1024", "wikiart_16384", "coco", "faceshq", "sflckr"]

#@markdown Note that some datasets are not compatible with commercial use:
#@markdown - imagenet_1024, imagenet_16384: research, non-commercial
#@markdown - coco: [various](https://github.com/nightrome/cocostuff#licensing)
#@markdown - faceshq: [CC-BY-NC-4.0](https://github.com/NVlabs/ffhq-dataset/blob/master/LICENSE.txt)
#@markdown - wikiart_16384: non-commercial
#@markdown - sflckr: public domain

model_names={
    "vqgan_imagenet_f16_16384": 'ImageNet 16384',
    "vqgan_imagenet_f16_1024":"ImageNet 1024", 
    "wikiart_1024":"WikiArt 1024",
    "wikiart_16384":"WikiArt 16384",
    "coco":"COCO-Stuff",
    "faceshq":"FacesHQ",
    "sflckr":"S-FLCKR"
}
model_name = model_names[vqgan_model]

----

You should not need to change anything below as all of the configuration has been moved above this line!

----

# Setup

In [None]:
#@title ## Setup Project
#@markdown We create a `NeuralProject` class as a utility class to simplify session save and resumption.

%pip install pyyaml

from os.path import join, isfile, split
from os import listdir, makedirs
from shutil import copy
import yaml


def dump(obj):
    for attr in dir(obj):
        if attr.startswith('__'):
            continue
        print("obj.%s = %r" % (attr, getattr(obj, attr)))


class NeuralProject:
    def __init__(self, project_dir):
        self.__project_dir = project_dir
        self.__project_file = join(project_dir, 'project.yaml')
        # Keep state per instance; mutable class-level defaults would be
        # shared between every NeuralProject object.
        self.__gpus = []
        self.__frames = {}
        self.__outputs = {}
        self.__config = {}
        makedirs(project_dir, exist_ok = True)


    def exists(self):
        return isfile(self.__project_file)


    def __get_stage_dir(self, stage):
        return join(self.__project_dir, stage)


    def add_gpu(self, gpu, uuid):
        gpu_uuid = f'{gpu}: {uuid}'
        
        if len(uuid) == 0:
            return

        if self.__gpus.count(gpu_uuid) == 0:
            self.__gpus.append(gpu_uuid)


    def add_artifact(self, stage, frame, artifact_file):
        self.__frames[stage] = frame

        stage_dir = self.__get_stage_dir(stage)
        makedirs(stage_dir, exist_ok=True)

        artifact_filename = split(artifact_file)[1]
        backup_file = join(stage_dir, artifact_filename) 
        copy(artifact_file, backup_file)

        self.__save_project()

        print(f'Stored frame {frame} from {stage} ({artifact_file} => {backup_file})')


    def add_output(self, stage, output_file):
        if stage not in self.__outputs:
            self.__outputs[stage] = []
        
        self.__outputs[stage].append(output_file)
        output_filename = split(output_file)[1]

        stage_outputs_dir = join(self.__project_dir, 'outputs', stage)
        makedirs(stage_outputs_dir, exist_ok=True)

        backup_file = join(stage_outputs_dir, output_filename)
        copy(output_file, backup_file)

        self.__save_project()

        print(f'Stored output from {stage} ({output_file} => {backup_file})')


    def add_config(self, key, value):
        self.__config[key] = value


    def get_config(self, key):
        return self.__config[key]


    def check_config(self, key):
        return key in self.__config.keys()


    def __save_project(self):
        data = dict(
            gpus = self.__gpus,
            frames = self.__frames,
            outputs = self.__outputs,
            config = self.__config
        )
        with open(self.__project_file, 'w') as outfile:
            yaml.dump(data, outfile, default_flow_style=False)


    def __load_project(self):
        with open(self.__project_file, 'r') as infile:
            data = yaml.safe_load(infile)

            self.__config = data["config"]
            self.__gpus = data["gpus"]
            self.__frames = data["frames"]
            self.__outputs = data["outputs"]

        print(f"Config: {self.__config}")
        print(f"GPUs: {self.__gpus}")
        print(f"Frames: {self.__frames}")
        print(f"Outputs: {self.__outputs}")


    def __check_stage(self, stage):
        last_frame = self.__frames[stage]
        stage_dir = self.__get_stage_dir(stage)
        dirty = False

        for i in range(1, last_frame):
            frame = join(stage_dir, f"{i:04d}.png")
            if not isfile(frame):
                print(f"  {frame} missing")
                dirty = True

        if dirty:
            print("At least one frame in the sequence is missing!")
        else:
            print("All frames appear to be in place.")


    def restore_project(self):
        self.__load_project()

        for stage in self.__frames.keys():
            self.__check_stage(stage)
            backup_dir = self.__get_stage_dir(stage)
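            restore_dir = join(working_dir, stage)
            # the runtime directory may not exist yet after a runtime reset
            makedirs(restore_dir, exist_ok=True)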
            for target_file in listdir(backup_dir):
                backup_file = join(backup_dir, target_file)
                restore_file = join(restore_dir, target_file)
                copy(backup_file, restore_file)


project = NeuralProject(project_dir)

notebook_config_vars = ["upscale_frames", "crop_frames", "super_slomo", "generate_videos", "key_frames", "text_prompts", 
    "width", "height", "trim_width", "trim_height", "interval", "initial_image", "target_images", "seed", "max_frames",
    "angle", "zoom", "translation_x", "translation_y", "iterations_per_frame", "save_all_iterations", "superscale_resolution",
    "target_fps", "vqgan_model"]

if (project.exists()):
    project.restore_project()
    for config_var_name in notebook_config_vars:
        if not project.check_config(config_var_name):
            print(f"The config variable '{config_var_name}' was not found in the project config.")
            continue

        config_value = project.get_config(config_var_name)
        if config_value != globals()[config_var_name]:
            print(f"WARNING: Loading {config_var_name} from file, overwriting the value '{globals()[config_var_name]}' with {config_value}")
            globals()[config_var_name] = config_value
else:
    for config_var_name in notebook_config_vars:
        project.add_config(config_var_name, globals()[config_var_name])

In [None]:
#@title ## Check GPU type

#@markdown Factory reset runtime if you don't have the desired GPU.
#@markdown
#@markdown | GPU  | Memory | Information |
#@markdown |---   |    ---:|---          |
#@markdown | P100 | 16GB   | Very good GPU, typically takes 2-3 minutes per frame of ~16k pixels |
#@markdown | T4   | 16GB   | Good GPU, typically takes 3-4 minutes per frame of ~16k pixels |

import xml.etree.ElementTree as ET

gpu_detail = !nvidia-smi --query --xml-format

root = ET.ElementTree(ET.fromstringlist(gpu_detail)).getroot()
for child in root:
    if child.tag == 'driver_version':
        print(f'NVIDIA Driver Version: {child.text}')
    elif child.tag == 'cuda_version':
        print(f'CUDA Version: {child.text}')
    elif child.tag == 'gpu':
        gpu_product = ''
        gpu_uuid = ''
        print(f'{child.attrib["id"]}:')
        for gpu_property in child:
            if gpu_property.tag == 'product_name':
                print(f'  Product Name : {gpu_property.text}')
                gpu_product = gpu_property.text
            elif gpu_property.tag == 'product_brand':
                print(f'  Brand        : {gpu_property.text}')
            elif gpu_property.tag == 'product_architecture':
                print(f'  Architecture : {gpu_property.text}')
            elif gpu_property.tag == 'uuid':
                print(f'  UUID         : {gpu_property.text}')
                gpu_uuid = gpu_property.text
            elif gpu_property.tag == 'fb_memory_usage':
                for mem in gpu_property:
                    if mem.tag == 'total':
                        print(f'  Total Memory : {mem.text}')
        # register the GPU once, after all of its properties have been read
        project.add_gpu(gpu_product, gpu_uuid)

In [None]:
# @title ## Library installation
# @markdown This cell will take a while because it downloads and installs multiple libraries.

from os import makedirs
from os.path import join
from urllib.request import urlretrieve

print("Downloading CLIP...")
!git clone https://github.com/openai/CLIP                                      &> /dev/null
 
print("Downloading Taming Transformers (VQGAN)...")
!git clone https://github.com/CompVis/taming-transformers                      &> /dev/null

print("Installing Python AI libraries...")
!pip install ftfy regex tqdm omegaconf pytorch-lightning                       &> /dev/null
!pip install kornia                                                            &> /dev/null
!pip install einops                                                            &> /dev/null
 
print("Installing libraries for handling metadata...")
!pip install stegano                                                           &> /dev/null
!apt install exempi                                                            &> /dev/null
!pip install python-xmp-toolkit                                                &> /dev/null
!pip install imgtag                                                            &> /dev/null
!pip install pillow==7.1.2                                                     &> /dev/null
 
print("Installing Python video creation libraries...")
!pip install imageio-ffmpeg                                                    &> /dev/null
!pip install ffmpeg-python                                                     &> /dev/null

print("Downloading SRCNN")
!git clone https://github.com/Mirwaisse/SRCNN.git                              &> /dev/null

print("Downloading Super-SloMo")
!git clone -q --depth 1 https://github.com/avinashpaliwal/Super-SloMo.git      &> /dev/null

path = join(working_dir, 'steps')
makedirs(path, exist_ok=True)

print("Installation finished.")

In [None]:
# @title ## Install caching framework

from os.path import isfile, abspath, join
from os import makedirs
from urllib.request import urlretrieve
from shutil import copy

if len(cache_dir.strip()) == 0:
    raise RuntimeError("cache_dir is not set, please set it in variables above.")

if len(models_dir.strip()) == 0:
    raise RuntimeError("models_dir is not set, please set it in variables above.")

model_cache_dir = abspath(join(cache_dir, 'models')) # the model cache should generally be on a persistent disk

# create directories if they don't exist.
makedirs(models_dir, exist_ok=True)
makedirs(cache_dir, exist_ok=True)
makedirs(model_cache_dir, exist_ok=True)

def download_vqgan_model(name, config_url, checkpoint_url):
    # create a directory to store the models in
    model_cache = join(model_cache_dir, name)
    makedirs(model_cache, exist_ok=True)

    # create variables for the config
    config_filename = f'{name}.yaml'
    config_file = join(models_dir, config_filename)
    config_cache = join(model_cache, config_filename)

    # create variables for the checkpoint
    checkpoint_filename = f'{name}.ckpt'
    checkpoint_file = join(models_dir, checkpoint_filename)
    checkpoint_cache = join(model_cache, checkpoint_filename)

    # test for availability of the files
    is_available = isfile(config_file) and isfile(checkpoint_file)
    is_cached = isfile(config_cache) and isfile(checkpoint_cache)

    # just return the name of the model if it's already available locally
    if (is_available):
        print(f'The model ({name}) is already available locally.')
        return name

    # copy the files from the cache and return the name of the model
    if (is_cached):
        print(f'The model ({name}) is available in the cache, copying locally.')
        copy(config_cache, config_file)
        copy(checkpoint_cache, checkpoint_file)
        return name

    # download everything, cache it and return the name of the model
    print(f'The model ({name}) was not found, downloading and caching.')
    urlretrieve(config_url, config_file)
    urlretrieve(checkpoint_url, checkpoint_file)
    copy(config_file, config_cache)
    copy(checkpoint_file, checkpoint_cache)
    return name

In [None]:
# @title ## Download required models

models = []

if imagenet_1024:
    downloded_model = download_vqgan_model(
        name='vqgan_imagenet_f16_1024',
        config_url='https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1',
        checkpoint_url='https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1')
    models.append(downloded_model)
if imagenet_16384:
    downloded_model = download_vqgan_model(
        name='vqgan_imagenet_f16_16384',
        config_url='https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1',
        checkpoint_url='https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1')
    models.append(downloded_model)
if coco:
    downloded_model = download_vqgan_model(
        name='coco',
        config_url='https://dl.nmkd.de/ai/clip/coco/coco.yaml',
        checkpoint_url='https://dl.nmkd.de/ai/clip/coco/coco.ckpt')
    models.append(downloded_model)
if faceshq:
    downloded_model = download_vqgan_model(
        name='faceshq',
        config_url='https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT',
        checkpoint_url='https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt')
    models.append(downloded_model)
if wikiart_16384:
    downloded_model = download_vqgan_model(
        name='wikiart_16384',
        config_url='http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml',
        checkpoint_url='http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt')
    models.append(downloded_model)
if sflckr:
    downloded_model = download_vqgan_model(
        name='sflckr',
        config_url='https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1',
        checkpoint_url='https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1')
    models.append(downloded_model)

print(models)

# these are so small we don't bother caching them
srcnn_models_url = 'https://raw.githubusercontent.com/justinjohn0306/SRCNN/master/models/'
for srcnn_model in ["2x", "3x", "4x"]:
    urlretrieve(f"{srcnn_models_url}model_{srcnn_model}.pth", join(models_dir, f"model_{srcnn_model}.pth"))

from os.path import exists
def download_from_google_drive(file_id, file_name):
  # download a file from the Google Drive link
  # TODO: convert this to pure Python
  !rm -f ./cookie
  !curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id={file_id}" > /dev/null
  confirm_text = !awk '/download/ {print $NF}' ./cookie
  confirm_text = confirm_text[0]
  !curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm={confirm_text}&id={file_id}" -o {file_name}
  
pretrained_model = 'SuperSloMo.ckpt'
if not exists(pretrained_model):
  download_from_google_drive('1IvobLDbRiBgZr3ryCRrWL8xDbMZ-KnpF', pretrained_model)

In [None]:
# @title ## Loading of libraries and definitions
 
import argparse
import math
from pathlib import Path
import sys
import os
import cv2
import pandas as pd
import numpy as np
import subprocess
from shutil import copy
from os.path import join
 
sys.path.append('./taming-transformers')

# Some models include transformers, others need explicit pip install
try:
    import transformers
except Exception:
    !pip install transformers
    import transformers

from IPython import display
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
 
from CLIP import clip
import kornia.augmentation as K
import imageio
from PIL import ImageFile
from imgtag import ImgTag    # metadata 
from libxmp import *         # metadata
import libxmp                # metadata
from stegano import lsb
import json
ImageFile.LOAD_TRUNCATED_IMAGES = True
 
def sinc(x):
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
 
 
def lanczos(x, a):
    cond = torch.logical_and(-a < x, x < a)
    out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
    return out / out.sum()
 
 
def ramp(ratio, width):
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])
    cur = 0
    for i in range(out.shape[0]):
        out[i] = cur
        cur += ratio
    return torch.cat([-out[1:].flip([0]), out])[1:-1]
 
 
def resample(input, size, align_corners=True):
    n, c, h, w = input.shape
    dh, dw = size
 
    input = input.view([n * c, 1, h, w])
 
    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])
 
    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])
 
    input = input.view([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
 
 
class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward
 
    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)
 
 
replace_grad = ReplaceGrad.apply
 
 
class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)
 
    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None
 
 
clamp_with_grad = ClampWithGrad.apply
 
 
def vector_quantize(x, codebook):
    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
    indices = d.argmin(-1)
    x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
    return replace_grad(x_q, x)
 
 
class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))
 
    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
 
 
def parse_prompt(prompt):
    vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])
 
 
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomSolarize(0.01, 0.01, p=0.7),
            K.RandomSharpness(0.3,p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2,p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))
        self.noise_fac = 0.1
 
 
    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
 
 
def load_vqgan_model(config_path, checkpoint_path):
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model
 
 
def resize_image(image, out_size):
    ratio = image.size[0] / image.size[1]
    area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
    size = round((area * ratio)**0.5), round((area / ratio)**0.5)
    return image.resize(size, Image.LANCZOS)
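
The `Prompt` module above scores each cutout embedding against a prompt embedding with `norm(a - b).div(2).arcsin().pow(2).mul(2)`. For unit vectors separated by an angle θ, the chord length ‖a − b‖ equals 2·sin(θ/2), so the expression reduces to θ²/2, i.e. half the squared great-circle distance. The following is a minimal pure-Python sketch of that identity for two plain vectors. The helper names `normalize` and `spherical_dist_loss` are illustrative and not part of the notebook, and the clamp before `asin` is an added safeguard that the original expression does not have.

```python
import math

def normalize(v):
    """Scale a vector to unit length (stand-in for F.normalize)."""
    length = math.sqrt(sum(x * x for x in v))
    return [x / length for x in v]

def spherical_dist_loss(a, b):
    """Mirror of the distance expression in Prompt.forward for two vectors."""
    a, b = normalize(a), normalize(b)
    chord = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
    # Assumption: clamp to 1.0 so rounding error cannot push asin out of range.
    return 2 * math.asin(min(chord / 2, 1.0)) ** 2

theta = math.pi / 2  # angle between two orthogonal vectors
print(spherical_dist_loss([1.0, 0.0], [0.0, 3.0]), theta ** 2 / 2)
```

Both printed values should agree, which illustrates why the loss grows with the squared angle between the image and prompt embeddings.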

In [None]:
#@title ## Configure Arguments for the Neural Network
if initial_image != "":
    print(
        "WARNING: You have specified an initial image. Note that the image resolution "
        "will be inherited from this image, not whatever width and height you specified. "
        "If the initial image resolution is too high, this can result in out of memory errors."
    )
elif width * height > 160000:
    print(
        "WARNING: The width and height you have specified may be too high, in which case "
        "you will encounter out of memory errors either at the image generation stage or the "
        "video synthesis stage. If so, try reducing the resolution."
    )

if seed == -1:
    seed = None

def parse_key_frames(string, prompt_parser=None):
    import re
    pattern = r'((?P<frame>[0-9]+):[\s]*[\(](?P<param>[\S\s]*?)[\)])'
    frames = dict()
    for match_object in re.finditer(pattern, string):
        frame = int(match_object.groupdict()['frame'])
        param = match_object.groupdict()['param']
        if prompt_parser:
            frames[frame] = prompt_parser(param)
        else:
            frames[frame] = param

    if frames == {} and len(string) != 0:
        raise RuntimeError('Key Frame string not correctly formatted')
    return frames

def get_inbetweens(key_frames, integer=False):
    key_frame_series = pd.Series([np.nan for a in range(max_frames)])
    for i, value in key_frames.items():
        key_frame_series[i] = value
    key_frame_series = key_frame_series.astype(float)
    key_frame_series = key_frame_series.interpolate(limit_direction='both')
    if integer:
        return key_frame_series.astype(int)
    return key_frame_series

def split_key_frame_text_prompts(frames):
    prompt_dict = dict()
    for i, parameters in frames.items():
        prompts = parameters.split('|')
        for prompt in prompts:
            string, value = prompt.split(':')
            string = string.strip()
            value = float(value.strip())
            if string in prompt_dict:
                prompt_dict[string][i] = value
            else:
                prompt_dict[string] = {i: value}
    prompt_series_dict = dict()
    for prompt, values in prompt_dict.items():
        value_string = (
            ', '.join([f'{value}: ({values[value]})' for value in values])
        )
        prompt_series = get_inbetweens(parse_key_frames(value_string))
        prompt_series_dict[prompt] = prompt_series
    prompt_list = []
    for i in range(max_frames):
        prompt_list.append(
            ' | '.join(
                [f'{prompt}: {prompt_series_dict[prompt][i]}'
                 for prompt in prompt_series_dict]
            )
        )
    return prompt_list

if key_frames:
    try:
        text_prompts_series = split_key_frame_text_prompts(
            parse_key_frames(text_prompts)
        )
    except RuntimeError as e:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `text_prompts` correctly for key frames.\n"
            "Attempting to interpret `text_prompts` as "
            f'"0: ({text_prompts}:1)"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        text_prompts = f"0: ({text_prompts}:1)"
        text_prompts_series = split_key_frame_text_prompts(
            parse_key_frames(text_prompts)
        )

    try:
        target_images_series = split_key_frame_text_prompts(
            parse_key_frames(target_images)
        )
    except RuntimeError as e:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `target_images` correctly for key frames.\n"
            "Attempting to interpret `target_images` as "
            f'"0: ({target_images}:1)"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        target_images = f"0: ({target_images}:1)"
        target_images_series = split_key_frame_text_prompts(
            parse_key_frames(target_images)
        )

    try:
        angle_series = get_inbetweens(parse_key_frames(angle))
    except RuntimeError as e:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `angle` correctly for key frames.\n"
            "Attempting to interpret `angle` as "
            f'"0: ({angle})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        angle = f"0: ({angle})"
        angle_series = get_inbetweens(parse_key_frames(angle))

    try:
        zoom_series = get_inbetweens(parse_key_frames(zoom))
    except RuntimeError as e:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `zoom` correctly for key frames.\n"
            "Attempting to interpret `zoom` as "
            f'"0: ({zoom})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        zoom = f"0: ({zoom})"
        zoom_series = get_inbetweens(parse_key_frames(zoom))

    try:
        translation_x_series = get_inbetweens(parse_key_frames(translation_x))
    except RuntimeError as e:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `translation_x` correctly for key frames.\n"
            "Attempting to interpret `translation_x` as "
            f'"0: ({translation_x})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        translation_x = f"0: ({translation_x})"
        translation_x_series = get_inbetweens(parse_key_frames(translation_x))

    try:
        translation_y_series = get_inbetweens(parse_key_frames(translation_y))
    except RuntimeError as e:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `translation_y` correctly for key frames.\n"
            "Attempting to interpret `translation_y` as "
            f'"0: ({translation_y})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        translation_y = f"0: ({translation_y})"
        translation_y_series = get_inbetweens(parse_key_frames(translation_y))

    try:
        iterations_per_frame_series = get_inbetweens(
            parse_key_frames(iterations_per_frame), integer=True
        )
    except RuntimeError as e:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `iterations_per_frame` correctly for key frames.\n"
            "Attempting to interpret `iterations_per_frame` as "
            f'"0: ({iterations_per_frame})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        iterations_per_frame = f"0: ({iterations_per_frame})"
        
        iterations_per_frame_series = get_inbetweens(
            parse_key_frames(iterations_per_frame), integer=True
        )
else:
    text_prompts = [phrase.strip() for phrase in text_prompts.split("|")]
    if text_prompts == ['']:
        text_prompts = []
    if target_images == "None" or not target_images:
        target_images = []
    else:
        target_images = target_images.split("|")
        target_images = [image.strip() for image in target_images]

    angle = float(angle)
    zoom = float(zoom)
    translation_x = float(translation_x)
    translation_y = float(translation_y)
    iterations_per_frame = int(iterations_per_frame)

args = argparse.Namespace(
    prompts=text_prompts,
    image_prompts=target_images,
    noise_prompt_seeds=[],
    noise_prompt_weights=[],
    size=[width, height],
    init_weight=0.,
    clip_model='ViT-B/32',
    vqgan_config=join(models_dir, f'{vqgan_model}.yaml'),
    vqgan_checkpoint=join(models_dir, f'{vqgan_model}.ckpt'),
    step_size=0.1,
    cutn=64,
    cut_pow=1.,
    display_freq=interval,
    seed=seed,
)
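
The cell above reads key-frame strings of the form `frame: (value)`, for example `0: (1), 4: (3)`, and fills in the frames in between. The following is a rough pure-Python sketch of that idea. `linear_inbetweens` is a hypothetical stand-in for the pandas-based `get_inbetweens`, written on the assumption that pandas' default linear interpolation with `limit_direction='both'` holds the first and last key-frame values constant outside the key-framed range.

```python
import re

def parse_key_frames(string):
    """Extract {frame: value} pairs from text like '0: (1), 4: (3)'."""
    pattern = r'(?P<frame>[0-9]+):\s*\((?P<param>[\S\s]*?)\)'
    return {int(m['frame']): m['param'] for m in re.finditer(pattern, string)}

def linear_inbetweens(key_frames, max_frames):
    """Hypothetical stand-in for get_inbetweens: linear between key frames,
    constant before the first and after the last key frame."""
    known = sorted((frame, float(value)) for frame, value in key_frames.items())
    values = []
    for i in range(max_frames):
        if i <= known[0][0]:
            values.append(known[0][1])
        elif i >= known[-1][0]:
            values.append(known[-1][1])
        else:
            # Find the pair of key frames that brackets frame i
            for (f0, v0), (f1, v1) in zip(known, known[1:]):
                if f0 <= i <= f1:
                    values.append(v0 + (v1 - v0) * (i - f0) / (f1 - f0))
                    break
    return values

zoom_per_frame = linear_inbetweens(parse_key_frames('0: (1), 4: (3)'), max_frames=6)
print(zoom_per_frame)
```

With these inputs the zoom ramps evenly from 1 to 3 over frames 0 to 4 and then stays at 3 for the remaining frame.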

In [None]:
#@title ## Video Creation Helpers
#@markdown `ffmpeg` will need to be installed

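import ffmpeg
import os
from base64 import b64encode
from os.path import isfile, join

def display_video_in_output(video_file):
    # Embed the video as a base64 data URL so it plays inline in the notebook
    with open(video_file, 'rb') as video:
        video_data = video.read()
    data_url = "data:video/mp4;base64," + b64encode(video_data).decode()
    # display.HTML(...) only builds the object; inside a function it has to be
    # passed to display.display() to actually render
    display.display(display.HTML(f'<video width=400 controls><source src="{data_url}" type="video/mp4"></video>'))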

def create_video_from_dir(directory, video_file, input_height, input_width, framerate):
    if isfile(video_file):
        os.remove(video_file)

    print(f'creating {video_file} from {directory} as a {framerate}fps {input_width}x{input_height} mp4.')

    # ffmpeg expects the frame size as WIDTHxHEIGHT
    (
    ffmpeg
        .input(f'{directory}/%04d.png', pattern_type='sequence', s=f'{input_width}x{input_height}', framerate=framerate)
        .output(video_file, preset='fast')
        .run()
    )
    display_video_in_output(video_file)

latest_video = None

# Image Synthesis
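
In the cell below, each new frame starts from the previous frame warped by a single matrix, `np.matmul(rot_mat, trans_mat)`, so the translation is applied first and the rotation and zoom about the image centre second. The following is a minimal sketch of that composition using plain lists instead of NumPy and OpenCV. The helper names are illustrative, and `rotation_zoom_matrix` follows the layout documented for `cv2.getRotationMatrix2D` as I understand it.

```python
import math

def rotation_zoom_matrix(center, angle_deg, scale):
    """3x3 matrix laid out like cv2.getRotationMatrix2D, with an extra
    [0, 0, 1] row so it can be multiplied with other 3x3 matrices."""
    cx, cy = center
    a = scale * math.cos(math.radians(angle_deg))
    b = scale * math.sin(math.radians(angle_deg))
    return [[a, b, (1 - a) * cx - b * cy],
            [-b, a, b * cx + (1 - a) * cy],
            [0.0, 0.0, 1.0]]

def translation_matrix(tx, ty):
    """3x3 matrix that shifts points by (tx, ty)."""
    return [[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]]

def matmul3(m, n):
    """Product of two 3x3 matrices stored as nested lists."""
    return [[sum(m[r][k] * n[k][c] for k in range(3)) for c in range(3)]
            for r in range(3)]

def apply_to_point(m, point):
    """Map an (x, y) coordinate through a 3x3 affine matrix."""
    x, y = point
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])

center = (200, 150)
# Same order as the notebook: rotation/zoom matrix times translation matrix
combined = matmul3(rotation_zoom_matrix(center, 0, 1.1), translation_matrix(5, 0))
print(apply_to_point(combined, center))
```

Because the rotation and zoom are taken about the centre, the centre pixel itself only moves when a translation is supplied.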

In [None]:
#@title ## Fire up the Neural Network

# Free memory held by previous runs
!nvidia-smi -caa
for var in ['device', 'model', 'perceptor', 'z']:
    # Drop the reference if it exists; ignore names that were never defined
    globals().pop(var, None)

import gc
gc.collect()

try:
    torch.cuda.empty_cache()
except Exception:
    pass

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
if not key_frames:
    if text_prompts:
        print('Using text prompts:', text_prompts)
    if target_images:
        print('Using image prompts:', target_images)

if args.seed is None:
    seed = torch.seed()
else:
    seed = args.seed

torch.manual_seed(seed)
print('Using seed:', seed)

# copy the initial image into the steps folder
if (init_frame > 1) and (len(initial_image) > 0):
    copy(initial_image, join(working_dir, 'steps', os.path.split(initial_image)[1]))
 
model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
 
cut_size = perceptor.visual.input_resolution
e_dim = model.quantize.e_dim
f = 2**(model.decoder.num_resolutions - 1)
make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
n_toks = model.quantize.n_e
toksX, toksY = args.size[0] // f, args.size[1] // f
sideX, sideY = toksX * f, toksY * f
z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
stop_on_next_loop = False  # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete

def read_image_workaround(path):
    """OpenCV reads images as BGR, Pillow saves them as RGB. Work around
    this incompatibility to avoid colour inversions."""
    im_tmp = cv2.imread(path)
    return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)

for i in range(init_frame, max_frames):
    if stop_on_next_loop:
        break
    if key_frames:
        text_prompts = text_prompts_series[i]
        text_prompts = [phrase.strip() for phrase in text_prompts.split("|")]
        if text_prompts == ['']:
            text_prompts = []
        args.prompts = text_prompts

        target_images = target_images_series[i]

        if target_images == "None" or not target_images:
            target_images = []
        else:
            target_images = target_images.split("|")
            target_images = [image.strip() for image in target_images]
        args.image_prompts = target_images

        angle = angle_series[i]
        zoom = zoom_series[i]
        translation_x = translation_x_series[i]
        translation_y = translation_y_series[i]
        iterations_per_frame = iterations_per_frame_series[i]
        print(
            f'text_prompts: {text_prompts}',
            f'angle: {angle}',
            f'zoom: {zoom}',
            f'translation_x: {translation_x}',
            f'translation_y: {translation_y}',
            f'iterations_per_frame: {iterations_per_frame}'
        )
    try:
        if i == 0 and len(initial_image.strip()) > 0:
            img_0 = read_image_workaround(initial_image)
            z, *_ = model.encode(TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1)
        elif i == 0 and not os.path.isfile(f'{working_dir}/steps/{i:04d}.png'):
            one_hot = F.one_hot(
                torch.randint(n_toks, [toksY * toksX], device=device), n_toks
            ).float()
            z = one_hot @ model.quantize.embedding.weight
            z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        else:
            if save_all_iterations:
                img_0 = read_image_workaround(
                    f'{working_dir}/steps/{i:04d}_{iterations_per_frame}.png')
            else:
                img_0 = read_image_workaround(f'{working_dir}/steps/{i:04d}.png')

            center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2)
            trans_mat = np.float32(
                [[1, 0, translation_x],
                [0, 1, translation_y]]
            )
            rot_mat = cv2.getRotationMatrix2D( center, angle, zoom )

            trans_mat = np.vstack([trans_mat, [0,0,1]])
            rot_mat = np.vstack([rot_mat, [0,0,1]])
            transformation_matrix = np.matmul(rot_mat, trans_mat)

            img_0 = cv2.warpPerspective(
                img_0,
                transformation_matrix,
                (img_0.shape[1], img_0.shape[0]),
                borderMode=cv2.BORDER_WRAP
            )
            z, *_ = model.encode(TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1)
        i += 1

        z_orig = z.clone()
        z.requires_grad_(True)
        opt = optim.Adam([z], lr=args.step_size)

        normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                        std=[0.26862954, 0.26130258, 0.27577711])

        pMs = []

        for prompt in args.prompts:
            txt, weight, stop = parse_prompt(prompt)
            embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))

        for prompt in args.image_prompts:
            path, weight, stop = parse_prompt(prompt)
            img = resize_image(Image.open(path).convert('RGB'), (sideX, sideY))
            batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
            embed = perceptor.encode_image(normalize(batch)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))

        # Use a distinct loop variable so the global `seed` written to the image metadata is not overwritten
        for noise_seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
            gen = torch.Generator().manual_seed(noise_seed)
            embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
            pMs.append(Prompt(embed, weight).to(device))

        def synth(z):
            z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
            return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

        def add_xmp_data(filename):
            imagen = ImgTag(filename=filename)
            imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'creator', 'VQGAN+CLIP', {"prop_array_is_ordered":True, "prop_value_is_array":True})
            if args.prompts:
                imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'title', " | ".join(args.prompts), {"prop_array_is_ordered":True, "prop_value_is_array":True})
            else:
                imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'title', 'None', {"prop_array_is_ordered":True, "prop_value_is_array":True})
            imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'i', str(i), {"prop_array_is_ordered":True, "prop_value_is_array":True})
            imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'model', model_name, {"prop_array_is_ordered":True, "prop_value_is_array":True})
            imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'seed',str(seed) , {"prop_array_is_ordered":True, "prop_value_is_array":True})
            imagen.close()

        def add_stegano_data(filename):
            data = {
                "title": " | ".join(args.prompts) if args.prompts else None,
                "notebook": "VQGAN+CLIP",
                "i": i,
                "model": model_name,
                "seed": str(seed),
            }
            lsb.hide(filename, json.dumps(data)).save(filename)

        @torch.no_grad()
        def checkin(i, losses):
            losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
            tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
            out = synth(z)
            TF.to_pil_image(out[0].cpu()).save('progress.png')
            add_stegano_data('progress.png')
            add_xmp_data('progress.png')
            display.display(display.Image('progress.png'))

        def save_output(i, img, suffix=None):
            filename = f"{i:04}{'_' + suffix if suffix else ''}.png"
            steps_dir = join(working_dir, 'steps')
            out_file = join(steps_dir, filename)
            imageio.imwrite(out_file, np.array(img))
            add_stegano_data(out_file)
            add_xmp_data(out_file)
            project.add_artifact('steps', i, out_file)

        def ascend_txt(i, save=True, suffix=None):
            out = synth(z)
            iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

            result = []

            if args.init_weight:
                result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)

            for prompt in pMs:
                result.append(prompt(iii))
            img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
            img = np.transpose(img, (1, 2, 0))
            if save:
                save_output(i, img, suffix=suffix)
            return result

        def train(i, save=True, suffix=None):
            opt.zero_grad()
            lossAll = ascend_txt(i, save=save, suffix=suffix)
            if i % args.display_freq == 0 and save:
                checkin(i, lossAll)
            loss = sum(lossAll)
            loss.backward()
            opt.step()
            with torch.no_grad():
                z.copy_(z.maximum(z_min).minimum(z_max))

        with tqdm() as pbar:
            if iterations_per_frame == 0:
                save_output(i, img_0)
            j = 1
            while True:
                suffix = (str(j) if save_all_iterations else None)
                if j >= iterations_per_frame:
                    train(i, save=True, suffix=suffix)
                    break
                if save_all_iterations:
                    train(i, save=True, suffix=suffix)
                else:
                    train(i, save=False, suffix=suffix)
                j += 1
                pbar.update()
    except KeyboardInterrupt:
        # Leave the loop cleanly before the next frame starts
        stop_on_next_loop = True

## Create preview video from generated frames
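
The video helper collects frames that match `%04d.png` in the `steps` folder. The following is a small sketch of the naming rule used by `save_output`, reproduced here as a hypothetical standalone function `frame_filename`. As far as I can tell, the suffixed files written when `save_all_iterations` is enabled do not match that pattern, so only unsuffixed frames would be picked up.

```python
def frame_filename(i, suffix=None):
    """Four-digit, zero-padded frame number with an optional '_suffix'."""
    return f"{i:04}{'_' + suffix if suffix else ''}.png"

print(frame_filename(7))              # final frame written for step 7
print(frame_filename(7, suffix='3'))  # third iteration of step 7
```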

In [None]:
if generate_videos:
    steps_frames_dir = join(working_dir, 'steps')
    steps_video = join(working_dir, 'steps.mp4')

    create_video_from_dir(steps_frames_dir, steps_video, height, width, 12)

    project.add_output('steps', steps_video)
    latest_video = steps_video

# Crop Frames
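
The cell below trims `trim_width` pixels in total from the width and `trim_height` pixels in total from the height, taking half from each side so the frame stays centred. The following is a minimal sketch of that arithmetic; `crop_box` is a hypothetical helper, not part of the notebook.

```python
def crop_box(width, height, trim_width, trim_height):
    """Return (left, top, right, bottom), removing half of each trim per side."""
    left = trim_width / 2
    top = trim_height / 2
    return (left, top, width - left, height - top)

box = crop_box(400, 300, trim_width=40, trim_height=20)
print(box)
print(box[2] - box[0], box[3] - box[1])  # size of the cropped frame
```

The cropped frame ends up `width - trim_width` by `height - trim_height` pixels, which matches the size passed to the video helper for the cropped video.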

In [None]:
from PIL import Image
from os import listdir, remove
from os.path import isfile, join
from pathlib import Path

if crop_frames:
    steps_dir = join(working_dir, "steps")
    cropped_steps_dir = join(working_dir, "cropped_steps")
    Path(cropped_steps_dir).mkdir(parents=True, exist_ok=True)

    for source_frame_filename in sorted(listdir(steps_dir)):
        step_filename = join(steps_dir, source_frame_filename)
        cropped_filename = join(cropped_steps_dir, source_frame_filename)
        frame_number = int(source_frame_filename.split('.')[0])

        print(f'Cropping frame {frame_number} ({step_filename} => {cropped_filename}).')

        if isfile(cropped_filename):
            remove(cropped_filename)

        original = Image.open(step_filename)
        width, height = original.size

        left = trim_width / 2
        top = trim_height / 2
        right = width - (trim_width / 2)
        bottom = height - (trim_height / 2)

        cropped_frame = original.crop((left, top, right, bottom))

        cropped_frame.save(cropped_filename)
        project.add_artifact("cropped_steps", frame_number, cropped_filename)

## Create video from cropped frames

In [None]:
if generate_videos:
    cropped_frames_dir = join(working_dir, 'cropped_steps')
    cropped_video = join(working_dir, 'cropped_steps.mp4')

    create_video_from_dir(cropped_frames_dir, cropped_video, height - trim_height, width - trim_width, 12)

    project.add_output('cropped_steps', cropped_video)
    latest_video = cropped_video

# Increase Resolution

In [None]:
import subprocess
import shutil
from os import listdir, remove
from os.path import isfile, join
from pathlib import Path

if upscale_frames:
    zoom_factor = superscale_resolution.rstrip("x")

    cropped_steps_dir = join(working_dir, "cropped_steps")
    zoomed_steps_dir = join(working_dir, "zoomed_steps")
    Path(zoomed_steps_dir).mkdir(parents=True, exist_ok=True)

    for cropped_frame in sorted(listdir(cropped_steps_dir)):
        cropped_frame_filename = join(cropped_steps_dir, cropped_frame) 
        zoomed_frame_filename = join(zoomed_steps_dir, cropped_frame)
        frame_number = int(cropped_frame.split('.')[0])

        cmd = [
            'python3',
            f'{working_dir}/SRCNN/run.py',
            '--zoom_factor',
            zoom_factor,
            '--model',
            f"{working_dir}/models/model_{superscale_resolution}.pth",  # 2x, 3x and 4x are available from the repo above
            '--image',
            cropped_frame,
            '--cuda'
        ]
        print(f'Upscaling frame {frame_number} ({cropped_frame_filename})')

        if isfile(zoomed_frame_filename):
            remove(zoomed_frame_filename)

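        # Capture output so that stdout/stderr are available if the upscaler fails
        process = subprocess.Popen(cmd, cwd=cropped_steps_dir, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)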
        stdout, stderr = process.communicate()
        if process.returncode != 0:
            print(stdout)
            print(stderr)
            raise RuntimeError(stderr)

        shutil.move(join(cropped_steps_dir, f"zoomed_{cropped_frame}"), zoomed_frame_filename)
        project.add_artifact("zoomed", frame_number, zoomed_frame_filename)

## Create video from zoomed frames

In [None]:
if generate_videos:
    zoomed_frames_dir = join(working_dir, 'zoomed_steps')
    zoomed_video = join(working_dir, 'zoomed_steps.mp4')

    zoom_factor = superscale_resolution.rstrip("x")

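    # The zoomed frames are upscaled from the cropped frames, so start from the cropped size
    create_video_from_dir(zoomed_frames_dir, zoomed_video, (height - trim_height) * int(zoom_factor), (width - trim_width) * int(zoom_factor), 12)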

    project.add_output('zoomed_steps', zoomed_video)
    latest_video = zoomed_video

# Super-SloMo for smoothing movement

This step might run out of memory if you run it right after the steps above. If it does, restart the notebook, upload a saved copy of the video from the previous step (or fetch it from Google Drive), and set the variable `latest_video` to the path of that video before running the cell below again.
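
The cell below reads the video path from `latest_video` and writes a `-slomo.mkv` and a `-slomo.mp4` file next to it. The following is a small sketch of setting the path by hand after a restart and deriving those output names. The path shown and the helper `slomo_output_paths` are hypothetical, used only for illustration.

```python
from os.path import splitext

def slomo_output_paths(video_path):
    """Return the (mkv, mp4) output names derived from the input video path."""
    # splitext strips only the final extension, so dots in folder names survive
    stem, _ = splitext(video_path)
    return f'{stem}-slomo.mkv', f'{stem}-slomo.mp4'

latest_video = '/content/my.project/steps.mp4'  # hypothetical path, replace with your own
mkv_path, mp4_path = slomo_output_paths(latest_video)
print(mkv_path)
print(mp4_path)
```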

In [None]:
# import these in case this cell is run without the above cells
import os
import subprocess

if super_slomo:
    SLOW_MOTION_FACTOR = 2 #@param {type:"slider", min:0, max:10, step:1}
    TARGET_FPS = 14 #@param {type:"slider", min:0, max:30, step:1}

    # splitext strips only the final extension, so dots elsewhere in the path are preserved
    latest_video_stem = os.path.splitext(latest_video)[0]

    cmd1 = [
        'python',
        'Super-SloMo/video_to_slomo.py',
        '--checkpoint',
        pretrained_model,
        '--video',
        latest_video,
        '--sf',
        str(SLOW_MOTION_FACTOR),
        '--fps',
        str(TARGET_FPS),
        '--output',
        f'{latest_video_stem}-slomo.mkv',
    ]

    process = subprocess.Popen(cmd1, cwd=f'/content', stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        raise RuntimeError(stderr)

    cmd2 = [
        'ffmpeg',
        '-i',
        f'{latest_video_stem}-slomo.mkv',
        '-pix_fmt',
        'yuv420p',
        '-crf',
        '17',
        '-preset',
        'veryslow',
        f'{latest_video_stem}-slomo.mp4',
    ]

    process = subprocess.Popen(cmd2, cwd=f'/content', stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        print(stderr)
        print(
            "You may be able to avoid this error by backing up the frames, "
            "restarting the notebook, and running only the video synthesis cells, "
            "or by decreasing the resolution of the image generation steps. "
            "If you restart the notebook, you will have to define `latest_video` manually "
            "by adding `latest_video = 'PATH_TO_THE_VIDEO'` to the beginning of this cell. "
            "If these steps do not work, please post the traceback on GitHub."
        )
        raise RuntimeError(stderr)