In [46]:
%load_ext autoreload
%autoreload 2

from text3d2video.artifacts.video_artifact import VideoArtifact
from text3d2video.utilities.video_util import clip_to_pil_frames, pil_frames_to_clip
from visual_tests.testing_utils import test_img
from moviepy.editor import VideoFileClip

# Fetch the artifact tagged 'video:latest' and keep at most its first 20 frames
# as the reference sequence for the encode/decode round trip.
video = VideoArtifact.from_wandb_artifact_tag('video:latest')
frames = video.get_frames()[:20]

def write_and_decode(frames, path='test.mp4', fps=10, codec='h264'):
    """Encode ``frames`` into a video file and read the frames back from it.

    The frames are turned into a clip with ``pil_frames_to_clip``, written to
    ``path`` with the requested codec at a fixed bitrate of "50000k", then the
    file is reopened and decoded with ``clip_to_pil_frames``.

    Args:
        frames: Sized sequence of frames accepted by ``pil_frames_to_clip``
            (presumably PIL images -- confirm against that helper).
        path: File path the encoded video is written to and read back from.
        fps: Frame rate passed to ``pil_frames_to_clip``.
        codec: Codec name passed to ``write_videofile``.

    Returns:
        Whatever ``clip_to_pil_frames`` returns for the re-opened file, asked
        for ``len(frames)`` frames.
    """
    clip = pil_frames_to_clip(frames, fps=fps)
    clip.write_videofile(path, codec=codec, verbose=False, logger=None, bitrate="50000k")

    read = VideoFileClip(path)
    try:
        # NOTE(review): assumes clip_to_pil_frames fully materialises the
        # frames before returning, so the reader can be closed right after.
        decoded = clip_to_pil_frames(read, expected_frames=len(frames))
    finally:
        # Release the reader on the file even if decoding fails; the original
        # never closed it, leaking one reader per call.
        read.close()
    return decoded

# Round-trip the reference frames through the encoder at 10 fps.
decoded = write_and_decode(frames, fps=10)

# Cell output: frame counts before and after the round trip.
(len(frames), len(decoded))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


(20, 20)

In [48]:
import torch
import torchvision.transforms.functional as TF
import torch.nn.functional as F


def compute_errors(original_frames, decoded_frames):
    """Print and return the mean and max per-frame MSE between two sequences.

    Frames are paired by position, converted with ``TF.to_tensor`` and
    compared with ``F.mse_loss``; one scalar MSE is collected per pair.

    Args:
        original_frames: Sized sequence of reference frames (anything
            ``TF.to_tensor`` accepts).
        decoded_frames: Sized sequence of frames to compare against the
            reference; must have the same length.

    Returns:
        Tuple ``(mean_mse, max_mse)`` of Python floats.

    Raises:
        ValueError: If the two sequences differ in length or are empty.
    """
    n_original = len(original_frames)
    n_decoded = len(decoded_frames)

    # Validate up front: previously a shorter decoded sequence raised a bare
    # IndexError and a longer one had its extra frames silently ignored.
    if n_original != n_decoded:
        raise ValueError(
            f"frame count mismatch: {n_original} original vs {n_decoded} decoded"
        )
    # An empty input has no per-frame errors to aggregate; fail here with a
    # clear message rather than inside torch's reduction over an empty tensor.
    if n_original == 0:
        raise ValueError("cannot compute errors for empty frame sequences")

    mses = []
    for og, dec in zip(original_frames, decoded_frames):
        mse = F.mse_loss(TF.to_tensor(og), TF.to_tensor(dec))
        mses.append(mse.item())

    # Aggregate through a tensor, as before, so the reported values keep the
    # same numeric precision.
    mses = torch.tensor(mses)
    mean_mse = mses.mean().item()
    max_mse = mses.max().item()

    print("mean", mean_mse)
    print("max", max_mse)
    return mean_mse, max_mse


# Frames are compared position by position, so the round trip must not have
# dropped or added any.
assert len(frames) == len(decoded), (
    f"frame count changed in round trip: {len(frames)} -> {len(decoded)}"
)
mean_mse, max_mse = compute_errors(frames, decoded)

# Quality gate: the worst frame must stay below the MSE threshold.
assert max_mse < 1e-4, f"max per-frame MSE {max_mse} is not below 1e-4"

mean 2.3630545911146328e-06
max 5.6512276387366e-06


In [40]:
from torch import randint
import random
import string

n_videos = 1

# Colour palette to sample from; loop-invariant, so it is built once.
all_colors = [
    "red",
    "green",
    "blue",
    "yellow",
    "purple",
    "orange",
    "black",
    "white",
]

# Frame rate used for every generated video. The original drew a random value
# with randint(1, 30, ...) and then immediately overwrote it with 30; that
# dead draw has been removed.
fps = 30

for _ in range(n_videos):
    # torch.randint's upper bound is exclusive, so this yields 1..199 frames.
    n_frames = randint(1, 200, (1,)).item()
    chars = random.choices(string.ascii_lowercase, k=n_frames)
    colors = random.choices(all_colors, k=n_frames)

    # generate random video: one test image per (character, colour) pair
    frames = [test_img(char, color=color) for char, color in zip(chars, colors)]

    # decode it
    decoded = write_and_decode(frames, fps=fps)

    # check the frame count survived the round trip and report the error
    # NOTE(review): no MSE threshold is enforced here -- the values are only
    # printed. Confirm whether a bound should be asserted.
    print(len(frames), len(decoded))
    assert len(frames) == len(decoded)
    mean_mse, max_mse = compute_errors(frames, decoded)
    print(mean_mse)


78 78
mean 3.470530600679922e-06
max 1.0506983926461544e-05
3.470530600679922e-06
