In [14]:
# generic libraries
import os
import gc 
import time
import torch
import torchvision
import random
import numpy as np
import matplotlib.pyplot as plt

# library for image augmentation
from albumentations import Compose, Resize

# library for unet model
from denoising_diffusion_pytorch import Unet

# library for VQGAN encode/decoder
from river.model import Model
from river.lutils.configuration import Configuration

# library for stochastic interpolators with follmer processes
import ProbForecastFollmerProcess as pffp

In [2]:
# plotting style, compute device and data location
plt.style.use('ggplot')
# prefer the GPU whenever CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# folder used for stored data
data_folder = "./store/multi_modal_jump_diffusion"
# reporting which device was selected
print('Computing on ' + str(device))

Computing on cuda


In [3]:
# defining class for video dataset
class VideoDataset(torch.utils.data.Dataset):
    """Dataset that decodes one video file per item into a float tensor.

    Args:
        mp4_file_list: sequence of paths to the video files.
        augmentations: optional callable applied to every frame as
            ``augmentations(image = frame)["image"]`` (albumentations-style),
            where ``frame`` is a channels-last numpy array.
    """

    def __init__(self, mp4_file_list, augmentations = None):
        super().__init__()
        # list of paths to mp4 files
        self.mp4_file_list = mp4_file_list
        # (callable) optional albumentations transforms
        self.augmentations = augmentations

    def __len__(self):
        """Return the number of videos in the underlying list."""
        return len(self.mp4_file_list)

    def __getitem__(self, idx):
        """Return the video at ``idx`` with shape (num_frames, num_channels, height, width).

        Pixel values are divided by 255 and the optional augmentations are
        applied independently to every frame before the channel axis is moved.
        """
        # get file name
        mp4_file = self.mp4_file_list[idx]
        # read video; audio and metadata are discarded
        video, _, _ = torchvision.io.read_video(mp4_file, pts_unit = 'sec') # shape: (num_frames, height, width, num_channels)
        # normalize pixels in [0, 1]
        video = video / 255
        # albumentations works on single numpy images, so the
        # augmentations are applied frame by frame and the results
        # are stacked back along the time dimension
        if self.augmentations is not None:
            transformed_frames = [
                torch.from_numpy(self.augmentations(image = frame.numpy())["image"])
                for frame in video
            ]
            video = torch.stack(transformed_frames, dim = 0)
        # moving channels in front of the spatial axes:
        # (num_frames, height, width, num_channels) -> (num_frames, num_channels, height, width)
        # the spatial axes must keep their order, otherwise every frame is transposed
        video = video.permute(0, 3, 1, 2)
        return video

In [4]:
# data paths
# NOTE: the CLEVREV / CLEVRER / CLEVEREV spellings below are kept exactly
# as they are because other cells reference these variable names
def _list_video_folders(split_folder):
    """Return the sorted sub-directory paths found directly inside ``split_folder``."""
    # os.listdir returns entries in arbitrary order, sorting makes indices reproducible
    candidate_paths = (os.path.join(split_folder, entry) for entry in sorted(os.listdir(split_folder)))
    # stray files (e.g. hidden OS files) are skipped since they cannot be listed as folders
    return [path for path in candidate_paths if os.path.isdir(path)]


def _list_mp4_files(video_folders):
    """Return the sorted ``.mp4`` file paths found directly inside each of ``video_folders``."""
    return [
        os.path.join(video_folder, file_name)
        for video_folder in video_folders
        for file_name in sorted(os.listdir(video_folder))
        # only keep actual mp4 files so that non-video entries never reach the video reader
        if file_name.lower().endswith(".mp4")
    ]


# defining root to CLEVREV data
CLEVREV_root = "store/video_generation/data/CLEVREV"
# defining base folder of CLEVREV train, validation and test data
CLEVREV_train_folder = os.path.join(CLEVREV_root, "video_train")
CLEVREV_validation_folder = os.path.join(CLEVREV_root, "video_validation")
CLEVREV_test_folder = os.path.join(CLEVREV_root, "video_test")
# defining video folders of CLEVREV train, validation and test data
CLEVREV_train_video_folders = _list_video_folders(CLEVREV_train_folder)
CLEVREV_validation_video_folders = _list_video_folders(CLEVREV_validation_folder)
CLEVREV_test_video_folders = _list_video_folders(CLEVREV_test_folder)
# defining video files for each of the CLEVREV train, validation and test data
CLEVRER_train_video_paths = _list_mp4_files(CLEVREV_train_video_folders)
CLEVRER_validation_video_paths = _list_mp4_files(CLEVREV_validation_video_folders)
CLEVRER_test_video_paths = _list_mp4_files(CLEVREV_test_video_folders)

In [5]:
# dataset and transforms
# target frame size in pixels, identical for height and width
target_frame_height = 128
target_frame_width = target_frame_height
# every split gets its own resize-only augmentation pipeline
train_augmentations, validation_augmentations, test_augmentations = (
    Compose([Resize(target_frame_height, target_frame_width)]) for _ in range(3)
)
# building the dictionary with one dataset per split
CLEVREV_datasets = {
    "train": VideoDataset(CLEVRER_train_video_paths, augmentations = train_augmentations),
    "validation": VideoDataset(CLEVRER_validation_video_paths, augmentations = validation_augmentations),
    "test": VideoDataset(CLEVRER_test_video_paths, augmentations = test_augmentations)
}
# exposing each split under its own name as well
CLEVREV_train_dataset = CLEVREV_datasets["train"]
CLEVREV_validation_dataset = CLEVREV_datasets["validation"]
CLEVREV_test_dataset = CLEVREV_datasets["test"]

In [6]:
# model's weights and configuration
# folder holding the CLEVREV weights
CLEVREV_weights_root = "store/video_generation/weights"
# paths to the VQVAE and VQGAN CLEVREV weights, both located in the weights folder
VQVAE_CLEVREV_weights_path, VQGAN_CLEVREV_weights_path = (
    os.path.join(CLEVREV_weights_root, weights_file)
    for weights_file in ("vqvae.pth", "model.pth")
)
# path to the CLEVREV configuration file
CLEVREV_config_path = "store/video_generation/configs/clevrer.yaml"

In [7]:
# model and configurations
# parsing the configuration file
CLEVEREV_configuration = Configuration(CLEVREV_config_path)
# building the model from the "model" entry of the configuration
# and moving it to the selected device
CLEVREV_model = Model(CLEVEREV_configuration["model"]).to(device)
# grabbing the VQGAN encoder and decoder from the model's `ae.backbone`
CLEVREV_VQGAN_encoder, CLEVREV_VQGAN_decoder = (
    CLEVREV_model.ae.backbone.encoder,
    CLEVREV_model.ae.backbone.decoder,
)

In [8]:
# defining the VQGAN latent video dataset
class VQGANLatentVideoDataset(torch.utils.data.Dataset):
    """Dataset that returns the encoder output of every item of a base dataset.

    Args:
        base_dataset: indexable dataset whose items are tensors
            (an instance of VideoDataset in this notebook).
        encoder: callable mapping a video tensor to its latent representation
            (an instance of river.model.encoder.Encoder in this notebook).
        device: device on which the encoder is run.
    """

    def __init__(self, base_dataset, encoder, device):
        super().__init__()
        self.base_dataset = base_dataset # instance of VideoDataset class
        self.encoder = encoder # instance of river.model.encoder.Encoder class
        self.device = device # device on which the encoder is run

    def __len__(self):
        """Return the number of items of the base dataset."""
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return the encoded item ``idx`` as a gradient-free tensor stored on the cpu."""
        # getting video and moving it to device
        video = self.base_dataset[idx].to(self.device)
        # encoding latent video without tracking gradients: we don't need the
        # gradients wrt the encoder, so no computation graph should be built
        # on the device in the first place
        with torch.no_grad():
            latent_video = self.encoder(video)
        # detaching (no-op under no_grad, kept for safety) and moving
        # the latent to cpu for saving up gpu memory
        latent_video = latent_video.detach().cpu()
        # deleting unused video and cleaning up memory and gpu cache
        del video
        gc.collect()
        torch.cuda.empty_cache()
        return latent_video

In [9]:
# wrapping every split of CLEVREV_datasets into its latent counterpart,
# all of them sharing the same VQGAN encoder and device
CLEVREV_latent_datasets = {
    split_name: VQGANLatentVideoDataset(split_dataset, CLEVREV_VQGAN_encoder, device)
    for split_name, split_dataset in CLEVREV_datasets.items()
}