<a href="https://colab.research.google.com/github/SisekoC/My-Notebooks/blob/main/FILM_for_video_diffusion_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#FILM for video diffusion models

In [1]:
!pip install -q mediapy --quiet
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import cv2
from typing import Generator, Iterable, List, Optional
from PIL import Image
import mediapy as media

In [2]:
model = hub.load('https://tfhub.dev/google/film/1')

In [3]:
video_file = '/content/Teddy bear dancing is AWESOME.mp4'
cap = cv2.VideoCapture(video_file)
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frames = []

for _ in range(n_frames):
  _, frame = cap.read()
  frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

cap.release()
frames = np.array(frames)
frames.shape

(215, 854, 480, 3)

In [4]:
def load_image(img):
  UINT8_MAX_F = float(np.iinfo(np.uint8).max)
  img_numpy = tf.cast(img, dtype=tf.float32).numpy()
  return img_numpy / UINT8_MAX_F

def film(img1, img2):
  time = np.array([0.5], dtype=np.float32)
  input = {
      'time': np.expand_dims(time, axis=0),
      'x0': np.expand_dims(img1, axis=0),
      'x1': np.expand_dims(img2, axis=0)
  }
  mid_frame = model(input)
  return mid_frame['image'][0].numpy()

In [5]:
img1 = load_image(frames[0])
img2 = load_image(frames[5])

mid_frame = film(img1, img2)

output_frames = [img1, mid_frame, img2]
media.show_images(output_frames, titles=['First input image', 'Generated image', 'Second input image'],
                  height=250)

0,1,2
First input image,Generated image,Second input image


In [6]:
output_frames = [img1]

for i in range(n_frames - 1):
  img1 = load_image(frames[i])
  img5 = load_image(frames[i + 1])

  img3 = film(img1, img5)
  img2 = film(img1, img3)
  img4 = film(img3, img5)

  output_frames.append(img2)
  output_frames.append(img3)
  output_frames.append(img4)
  output_frames.append(img5)

  print(f'{4 * (i+1)}/{4 * n_frames} frames processed.')

output_frames = np.array(output_frames)
output_frames.shape

4/860 frames processed.
8/860 frames processed.
12/860 frames processed.
16/860 frames processed.
20/860 frames processed.
24/860 frames processed.
28/860 frames processed.
32/860 frames processed.
36/860 frames processed.
40/860 frames processed.
44/860 frames processed.
48/860 frames processed.
52/860 frames processed.
56/860 frames processed.
60/860 frames processed.
64/860 frames processed.
68/860 frames processed.
72/860 frames processed.
76/860 frames processed.
80/860 frames processed.
84/860 frames processed.
88/860 frames processed.
92/860 frames processed.
96/860 frames processed.
100/860 frames processed.
104/860 frames processed.
108/860 frames processed.
112/860 frames processed.
116/860 frames processed.
120/860 frames processed.
124/860 frames processed.
128/860 frames processed.
132/860 frames processed.
136/860 frames processed.
140/860 frames processed.
144/860 frames processed.
148/860 frames processed.
152/860 frames processed.
156/860 frames processed.
160/860 fram

(857, 854, 480, 3)

In [7]:
output_filename = 'result.mp4'
fps = 16
resolution = (256, 256)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(output_filename, fourcc, fps, resolution)

for frame in output_frames:
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
  frame = (frame * 255).astype(np.uint8)
  writer.write(frame)

writer.release()