In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import torch
import cv2
import imageio
import numpy as np
from PIL import Image 

In [2]:
def pil_list_to_torch(images):
    """Stack a sequence of image frames into one float tensor of shape (N, C, H, W) in [0, 1].

    Despite the name, this notebook passes ``np.ndarray`` frames. PIL images
    are accepted too, because every item goes through ``np.asarray`` first.

    Args:
        images: non-empty iterable of frames, each ``(H, W, C)`` or ``(H, W)``
            (grayscale frames get a singleton channel axis). All frames must
            share one shape. Values are assumed to be on a 0-255 scale,
            e.g. uint8, since the result is divided by 255.

    Returns:
        ``torch.float32`` tensor of shape ``(N, C, H, W)``, contiguous.

    Raises:
        ValueError: if ``images`` is empty, or if the frames differ in shape
            (raised by ``np.stack``).
    """
    frames = [np.asarray(frame) for frame in images]
    if not frames:
        raise ValueError('pil_list_to_torch() needs at least one frame')
    # np.stack returns a fresh, writable array, so torch.from_numpy does not
    # warn about read-only buffers (np.asarray of a PIL image can be read-only).
    batch = np.stack(frames)  # (N, H, W) or (N, H, W, C)
    if batch.ndim == 3:
        batch = batch[..., np.newaxis]  # grayscale -> single channel
    tensor = torch.from_numpy(batch).permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW
    return tensor.float() / 255.0

def torch_to_np(tensor):
    """Convert an (N, C, H, W) float tensor in [0, 1] to an (N, H, W, C) uint8 array.

    The tensor is detached and moved to the CPU first, so it works on GPU
    tensors and on tensors that require grad. It is upcast to float32, so
    float16 model outputs are scaled without half-precision error.

    Args:
        tensor: 4-D tensor laid out as (N, C, H, W); values are assumed to be
            in [0, 1], and anything outside is clipped.

    Returns:
        ``np.uint8`` array of shape (N, H, W, C).
    """
    images = tensor.detach().cpu().float().permute(0, 2, 3, 1).numpy()
    # Round (instead of truncating) so the mapping inverts `x / 255` without
    # a downward bias, e.g. 0.999 -> 255 rather than 254.
    images = np.clip(np.rint(images * 255), 0, 255)
    return images.astype(np.uint8)

def resize_gif(in_path, out_path, img_h=512, img_w=512, duration=60):
    """Resize every frame of a GIF and write the result to a new file.

    Args:
        in_path: path of the GIF to read.
        out_path: path to write; reading finishes (and the reader is closed)
            before writing starts, so ``out_path`` may equal ``in_path``.
        img_h, img_w: target frame height and width in pixels.
        duration: per-frame duration passed straight to ``imageio.mimsave``.
            NOTE(review): the unit depends on the imageio version/plugin
            (seconds in the legacy v2 Pillow plugin, milliseconds in the v3
            Pillow plugin). Confirm it for the installed version.
    """
    # The `with` block releases the file handle; previously the reader was never closed.
    with imageio.get_reader(in_path) as gif_file:
        # cv2.resize takes (width, height); frames are materialised inside the
        # block because the reader yields them lazily.
        images = [cv2.resize(frame, (img_w, img_h)) for frame in gif_file]
    imageio.mimsave(out_path, images, duration=duration)

In [3]:
# Number of generated frames inserted between each pair of original frames.
# The interleaving cell only handles 1, 2 or 3, so fail fast here before the model is loaded.
interpolation_frames = 3
assert interpolation_frames in (1, 2, 3), f'interpolation_frames must be 1, 2 or 3, got {interpolation_frames}'

# The physical GPU is selected via CUDA_VISIBLE_DEVICES in the first cell.
device = torch.device('cuda')
precision = torch.float16
model_path = '../models/interpolation-models/film_net_fp16.pt'

gif_path = './samples/generation/8908903743535013658.gif'
out_folder = './samples/interpolated'
os.makedirs(out_folder, exist_ok=True)

# splitext strips only the final extension, so dotted names such as 'a.b.gif' keep 'a.b'.
out_name = os.path.splitext(os.path.basename(gif_path))[0]
out_name

'8908903743535013658'

In [4]:
# Deserialize the TorchScript model onto the CPU, switch it to inference mode,
# then move its weights to the configured device and dtype.
# Module.eval() and Module.to() both modify the module in place and return it.
model = torch.jit.load(model_path, map_location='cpu')
model.eval()
model.to(device=device, dtype=precision);

In [5]:
# Read every frame of the source GIF into memory. The `with` block closes the
# reader once all frames are materialised.
with imageio.get_reader(gif_path) as gif_file:
    images = [np.array(frame) for frame in gif_file]
assert images, f'no frames read from {gif_path}'

# (N, C, H, W) float tensor in [0, 1] for the interpolation model.
tensor = pil_list_to_torch(images)

In [6]:
# Pair consecutive frames: prev_tensor[i] is frame i and next_tensor[i] is frame i + 1.
prev_tensor = tensor[:-1].to(device=device, dtype=precision)
next_tensor = tensor[1:].to(device=device, dtype=precision)

# Interpolation time 0.5 as a (1, 1) tensor that shares prev_tensor's dtype and device.
dt = prev_tensor.new_full((1, 1), 0.5)

In [7]:
# Run the interpolation model on all consecutive pairs as one batch.
# `mid` is generated at dt = 0.5. When more than one in-between frame is
# requested, the model is applied once more to each half: (prev, mid) and
# (mid, next). Assuming dt is the temporal position between the two inputs,
# that yields frames at about t = 0.25 and t = 0.75.
# NOTE(review): with interpolation_frames == 2 only those two frames are used,
# so the spacing is uneven (0.25 / 0.5 / 0.25).
# Intermediate results stay on the device; each is copied to the CPU once.
with torch.no_grad():
    mid_on_device = model(prev_tensor, next_tensor, dt)
    if interpolation_frames > 1:
        prev_on_device = model(prev_tensor, mid_on_device, dt)
        next_on_device = model(mid_on_device, next_tensor, dt)

dt_images_mid = mid_on_device.cpu()
generated_mid = torch_to_np(dt_images_mid)

if interpolation_frames > 1:
    dt_images_prev = prev_on_device.cpu()
    dt_images_next = next_on_device.cpu()
    generated_prev = torch_to_np(dt_images_prev)
    generated_next = torch_to_np(dt_images_next)

  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
 does not have profile information (Triggered internally at ../third_party/nvfuser/csrc/graph_fuser.cpp:104.)
  return forward_call(*args, **kwargs)


In [8]:
# For each consecutive pair of original frames, collect the generated
# in-between frames in temporal order.
if interpolation_frames == 1:
    in_between = [(mid_img,) for mid_img in generated_mid]
elif interpolation_frames == 2:
    in_between = list(zip(generated_prev, generated_next))
elif interpolation_frames == 3:
    in_between = list(zip(generated_prev, generated_mid, generated_next))
else:
    # Previously an unsupported value silently produced a one-frame GIF.
    raise ValueError(f'interpolation_frames must be 1, 2 or 3, got {interpolation_frames}')

# Interleave: orig[0], gen..., orig[1], gen..., orig[2], ...
extended_images = [images[0]]
for generated_frames, orig_img in zip(in_between, images[1:]):
    extended_images.extend(generated_frames)
    extended_images.append(orig_img)

assert len(extended_images) == 1 + (len(images) - 1) * (interpolation_frames + 1)

In [9]:
# Write the interleaved sequence. duration=60 matches resize_gif's default.
# NOTE(review): imageio's `duration` unit depends on the version/plugin
# (seconds in the legacy v2 Pillow plugin, milliseconds in v3). Confirm it.
interpolated_gif_path = f'{out_folder}/{out_name}.gif'
imageio.mimsave(interpolated_gif_path, extended_images, duration=60)

In [10]:
img_h = 384
img_w = 384

# Re-encode the interpolated GIF at img_h x img_w once per frame duration.
# Each variant is written as <out_name>_<duration>.gif next to the source.
source_gif = f'{out_folder}/{out_name}.gif'
for duration in (60, 90, 120):
    target_gif = f'{out_folder}/{out_name}_{duration}.gif'
    resize_gif(source_gif, target_gif, img_h=img_h, img_w=img_w, duration=duration)