In [1]:
# generic libraries
import os
import gc 
import time
import torch
import torchvision
import random
import numpy as np
import matplotlib.pyplot as plt

# library for image augmentation
from albumentations import Compose, Resize

# library for unet model
from denoising_diffusion_pytorch import Unet

# library for VQGAN encoder/decoder
from river.model import Model
from river.lutils.configuration import Configuration

# library for stochastic interpolators with Föllmer processes
import ProbForecastFollmerProcess as pffp

In [2]:
# plotting style and compute device
# every figure of the notebook uses the ggplot look
plt.style.use('ggplot')
# run on the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# NOTE(review): this folder belongs to a different experiment than the
# "store/video_generation" paths used by the other cells — confirm it is still needed
data_folder = "./store/multi_modal_jump_diffusion"
# reporting the selected device
print(f'Computing on {device!s}')

Computing on cuda


In [3]:
# data paths
# NOTE: the dataset is CLEVRER; the mixed "CLEVREV"/"CLEVRER" spellings of the names below
# are kept exactly as they were, because other cells reference these names


def _list_subfolders(split_folder):
    """Return the sub-directories of `split_folder` as joined paths, sorted by name.

    `os.listdir` yields entries in arbitrary order, so the listing is sorted to make the
    resulting dataset order reproducible. Entries that are not directories (for example
    stray metadata files) are skipped, since listing them afterwards would raise.
    """
    entry_paths = (os.path.join(split_folder, entry) for entry in sorted(os.listdir(split_folder)))
    return [entry_path for entry_path in entry_paths if os.path.isdir(entry_path)]


def _list_mp4_files(video_folders):
    """Return the paths of the .mp4 files found directly inside each folder of `video_folders`.

    Folders are visited in the given order and the files of each folder are sorted by
    name; anything that does not end in ".mp4" (case-insensitive) is ignored.
    """
    return [
        os.path.join(video_folder, file_name)
        for video_folder in video_folders
        for file_name in sorted(os.listdir(video_folder))
        if file_name.lower().endswith(".mp4")
    ]


# defining root to CLEVRER data
CLEVREV_root = "store/video_generation/data/CLEVREV"
# defining base folder of CLEVRER train, validation and test data
CLEVREV_train_folder = os.path.join(CLEVREV_root, "video_train")
CLEVREV_validation_folder = os.path.join(CLEVREV_root, "video_validation")
CLEVREV_test_folder = os.path.join(CLEVREV_root, "video_test")
# defining video folders of CLEVRER train, validation and test data
CLEVREV_train_video_folders = _list_subfolders(CLEVREV_train_folder)
CLEVREV_validation_video_folders = _list_subfolders(CLEVREV_validation_folder)
CLEVREV_test_video_folders = _list_subfolders(CLEVREV_test_folder)
# defining video files for each of the CLEVRER train, validation and test data
CLEVRER_train_video_paths = _list_mp4_files(CLEVREV_train_video_folders)
CLEVRER_validation_video_paths = _list_mp4_files(CLEVREV_validation_video_folders)
CLEVRER_test_video_paths = _list_mp4_files(CLEVREV_test_video_folders)

In [4]:
# dataset and transforms
# target frame height and width (frames are resized to a square)
target_frame_height = 128
target_frame_width = target_frame_height
# one independent resize-only augmentation pipeline per split (train, validation, test)
train_augmentations, validation_augmentations, test_augmentations = (
    Compose([Resize(target_frame_height, target_frame_width)]) for _ in range(3)
)
# wrapping the video paths of each split into a video dataset with its own augmentations
CLEVREV_train_dataset = pffp.data.VideoDataset(
    CLEVRER_train_video_paths,
    augmentations=train_augmentations,
)
CLEVREV_validation_dataset = pffp.data.VideoDataset(
    CLEVRER_validation_video_paths,
    augmentations=validation_augmentations,
)
CLEVREV_test_dataset = pffp.data.VideoDataset(
    CLEVRER_test_video_paths,
    augmentations=test_augmentations,
)

In [5]:
# model's weights and configuration
# root folder holding the weights files
CLEVREV_weights_root = "store/video_generation/weights"
# paths to the VQVAE and VQGAN weights files
VQVAE_CLEVREV_weights_path = os.path.join(CLEVREV_weights_root, "vqvae.pth")
VQGAN_CLEVREV_weights_path = os.path.join(CLEVREV_weights_root, "model.pth")
# NOTE(review): neither weights path is loaded in this cell — confirm that the weights
# are loaded elsewhere before the encoder/decoder below are relied upon
# path to the yaml configuration file
CLEVREV_config_path = "store/video_generation/configs/clevrer.yaml"
# parsing the configuration file
CLEVEREV_configuration = Configuration(CLEVREV_config_path)
# building the model from the "model" section of the configuration, then moving it to the device
CLEVREV_model = Model(CLEVEREV_configuration["model"])
CLEVREV_model = CLEVREV_model.to(device)
# picking the encoder and the decoder out of the autoencoder backbone
CLEVREV_VQGAN_encoder, CLEVREV_VQGAN_decoder = (
    CLEVREV_model.ae.backbone.encoder,
    CLEVREV_model.ae.backbone.decoder,
)
# pairing each frame dataset (train, validation, test — in that order) with the encoder
(
    CLEVREV_latent_train_dataset,
    CLEVREV_latent_validation_dataset,
    CLEVREV_latent_test_dataset,
) = (
    pffp.data.VQGANLatentVideoDataset(frames_dataset, CLEVREV_VQGAN_encoder, device)
    for frames_dataset in (CLEVREV_train_dataset, CLEVREV_validation_dataset, CLEVREV_test_dataset)
)

In [6]:
# building a lagged dataset with random context out of a tensor holding a video
def get_video_dataset_with_random_context(video_tensor):
    """Wrap one video tensor into a `pffp.data.LaggedDatasetWithRandomContext`.

    The tensor is split by `pffp.utils.pair_lagged_observations_with_random_context`
    into target, past and random-context frames. These are handed to the dataset
    constructor (past first, target second, context third) together with the
    notebook-level `device`.

    Args:
        video_tensor: tensor holding a video (exact layout is not visible here — TODO confirm).

    Returns:
        The constructed `pffp.data.LaggedDatasetWithRandomContext`.
    """
    frames_target, frames_past, frames_context = pffp.utils.pair_lagged_observations_with_random_context(video_tensor)
    return pffp.data.LaggedDatasetWithRandomContext(frames_past, frames_target, frames_context, device)