In [None]:
# generic libraries
import os
import gc 
import time
import torch
import torchvision
import random
import numpy as np
import matplotlib.pyplot as plt

# library for image augmentation
from albumentations import Compose, Resize

# library for unet model
from denoising_diffusion_pytorch import Unet

# library for VQGAN encode/decoder
from river.model import Model
from river.lutils.configuration import Configuration

# library for stochastic interpolators with follmer processes
import ProbForecastFollmerProcess as pffp

In [None]:
# configuring the plotting style and selecting the compute device
plt.style.use('ggplot')
# NOTE(review): data_folder is not referenced in the cells visible here — confirm it is still needed
data_folder = "./store/multi_modal_jump_diffusion"
# use the GPU when torch reports one as available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Computing on ' + str(device))

In [None]:
# seeding for reproducibility
# when `reproducible` is set a fixed seed is used; otherwise the seed is
# derived from the current wall-clock time (truncated to whole seconds)
reproducible = True
if reproducible:
    SEED = 1024
else:
    SEED = int(time.time())
# hand the seed to the library helper
pffp.utils.ensure_reproducibility(SEED)

In [None]:
# data paths
# defining root to CLEVRER data
CLEVREV_root = "store/video_generation/data/CLEVREV"
# defining base folder of CLEVRER train, validation and test data
CLEVREV_train_folder = os.path.join(CLEVREV_root, "video_train")
CLEVREV_validation_folder = os.path.join(CLEVREV_root, "video_validation")
CLEVREV_test_folder = os.path.join(CLEVREV_root, "video_test")


def _list_video_folders(split_folder):
    """Return the sub-directories of `split_folder` as full paths, sorted by name.

    `os.listdir` returns entries in arbitrary order, so the entries are sorted
    to keep the video ordering stable across runs and machines. Entries that
    are not directories are skipped, since listing their content would raise.
    """
    candidate_paths = (os.path.join(split_folder, entry) for entry in sorted(os.listdir(split_folder)))
    return [path for path in candidate_paths if os.path.isdir(path)]


def _list_mp4_files(video_folders):
    """Return the full paths of the .mp4 files found in each folder of `video_folders`.

    Folders are visited in the given order and the files of each folder are
    sorted by name. Files without an .mp4 extension (case-insensitive) are ignored.
    """
    return [
        os.path.join(video_folder, file_name)
        for video_folder in video_folders
        for file_name in sorted(os.listdir(video_folder))
        if file_name.lower().endswith(".mp4")
    ]


# defining video folders of CLEVRER train, validation and test data
CLEVREV_train_video_folders = _list_video_folders(CLEVREV_train_folder)
CLEVREV_validation_video_folders = _list_video_folders(CLEVREV_validation_folder)
CLEVREV_test_video_folders = _list_video_folders(CLEVREV_test_folder)
# defining video files for each of the CLEVRER train, validation and test data
CLEVRER_train_video_paths = _list_mp4_files(CLEVREV_train_video_folders)
CLEVRER_validation_video_paths = _list_mp4_files(CLEVREV_validation_video_folders)
CLEVRER_test_video_paths = _list_mp4_files(CLEVREV_test_video_folders)

In [None]:
# dataset and transforms
# target frame height and width (square frames)
target_frame_height = 128
target_frame_width = target_frame_height
# one independent resize-only augmentation pipeline per split
train_augmentations, validation_augmentations, test_augmentations = (
    Compose([Resize(target_frame_height, target_frame_width)]) for _ in range(3)
)
# video datasets, one per split, each paired with its own augmentation pipeline
CLEVREV_train_dataset = pffp.data.VideoDataset(
    CLEVRER_train_video_paths,
    augmentations = train_augmentations,
)
CLEVREV_validation_dataset = pffp.data.VideoDataset(
    CLEVRER_validation_video_paths,
    augmentations = validation_augmentations,
)
CLEVREV_test_dataset = pffp.data.VideoDataset(
    CLEVRER_test_video_paths,
    augmentations = test_augmentations,
)

In [None]:
# model's weights and configuration
# root folder of the weights
CLEVREV_weights_root = "store/video_generation/weights"
# paths to the VQVAE and VQGAN weights
# NOTE(review): neither path is used in this cell — presumably the weights are
# loaded through the configuration or in another cell; confirm
VQVAE_CLEVREV_weights_path = os.path.join(CLEVREV_weights_root, "vqvae.pth")
VQGAN_CLEVREV_weights_path = os.path.join(CLEVREV_weights_root, "model.pth")
# path to the configuration file
CLEVREV_config_path = "store/video_generation/configs/clevrer.yaml"
# reading the configuration and building the model on the selected device
CLEVEREV_configuration = Configuration(CLEVREV_config_path)
CLEVREV_model = Model(CLEVEREV_configuration["model"]).to(device)
# the encoder and the decoder both live under the autoencoder backbone
vqgan_backbone = CLEVREV_model.ae.backbone
# encoder, switched to inference mode
CLEVREV_VQGAN_encoder = vqgan_backbone.encoder
CLEVREV_VQGAN_encoder.eval()
# decoder, switched to inference mode
CLEVREV_VQGAN_decoder = vqgan_backbone.decoder
CLEVREV_VQGAN_decoder.eval()
# latent datasets: every video dataset is wrapped together with the encoder and the device
(
    CLEVREV_latent_train_dataset,
    CLEVREV_latent_validation_dataset,
    CLEVREV_latent_test_dataset,
) = (
    pffp.data.VQGANLatentVideoDataset(video_dataset, CLEVREV_VQGAN_encoder, device)
    for video_dataset in (CLEVREV_train_dataset, CLEVREV_validation_dataset, CLEVREV_test_dataset)
)
# reporting the size of each split
print(f"Number of train videos {len(CLEVREV_latent_train_dataset)}")
print(f"Number of validation videos {len(CLEVREV_latent_validation_dataset)}")
print(f"Number of test videos {len(CLEVREV_latent_test_dataset)}")

In [None]:
# dataset exposing the paired frames of all the videos as one flat sequence,
# so that training can draw samples across every video of a split
class ConcatenatedVideoDataset(torch.utils.data.Dataset):
    """Flat, frame-level view over a dataset whose items are whole videos.

    A flat index is split into a video index (integer division by the number
    of frames per video) and a frame index (the remainder).

    Args:
        base_datasets: indexable collection of videos; only `len(...)` and
            item access are required of it and of its first element.
        device: forwarded unchanged to the pffp helper when an item is built.
    """

    def __init__(self, base_datasets, device):
        super().__init__()
        # each element of the underlying dataset is a full video
        self.base_datasets = base_datasets
        self.device = device
        # total number of videos
        self.num_videos = len(base_datasets)
        # usable items per video: one less than the length of the first video;
        # every video is assumed to have that same length
        first_video = base_datasets[0]
        self.num_frames_per_video = len(first_video) - 1

    def __len__(self):
        """Return the number of videos times the number of frames per video."""
        return self.num_frames_per_video * self.num_videos

    def __getitem__(self, idx):
        """Return the item at flat index `idx`."""
        frames_per_video = self.num_frames_per_video
        # the quotient selects the video ...
        video = self.base_datasets[idx // frames_per_video]
        # ... which is wrapped by the pffp helper into a dataset with random context ...
        paired_frames = pffp.utils.get_video_dataset_with_random_context(
            pffp.data.LaggedDatasetWithRandomContext,
            video,
            self.device,
        )
        # ... and the remainder selects the frame inside that video
        return paired_frames[idx % frames_per_video]

In [None]:
# number of input channels for the drift estimator
# four latent channels for each of the target, previous and random context frames
num_latent_channels = 4
num_frames = 3
input_channels = num_frames * num_latent_channels
# spatial size of the (square) latent frames
latent_height = 16
latent_width = latent_height

# keyword options of the backbone U-Net
unet_options = dict(
    out_dim = num_latent_channels,
    channels = num_latent_channels,
    input_channels = input_channels,
    self_condition = True,
)
# building the backbone and moving it to the selected device
backbone = Unet(latent_height, **unet_options).to(device)

# thin wrapper around the unet model that fixes the shape
# of the time feature before the backbone is called
class UnetWrapper(torch.nn.Module):
    """Call `base_model(X, t, Xc)` after expanding a scalar time to one value per sample."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, X, t, Xc):
        """Forward `X`, `t` and `Xc` to the wrapped model.

        A zero-dimensional `t` is repeated along a new axis so that it holds
        one entry per element of the first axis of `X`; any other `t` is
        passed through untouched.
        """
        batch_size = X.shape[0]
        is_scalar_time = t.ndim == 0
        if is_scalar_time:
            t = t.repeat(batch_size)
        return self.base_model(X, t, Xc)

# defining wrapper around backbone model
backbone = UnetWrapper(backbone)

# defining sampling configurations
sample = {
    "g": pffp.interpolant["sigma"], 
    "num_euler_steps": 100
}

# defining state configurations 
state = {
    "input_dims": (num_latent_channels, latent_height, latent_width)
}

# defining data configurations
# the "test" entry is fed with the validation split
data = {    
    "train": ConcatenatedVideoDataset(CLEVREV_latent_train_dataset, device), 
    "test": ConcatenatedVideoDataset(CLEVREV_latent_validation_dataset, device), 
}

# defining optimization configurations
optim_config = {
    'batch_size': 6, 
    'num_epochs': 5,
    'learning_rate' : 2e-4,
    'num_mc_samples': 300,
    'max_num_grad_steps': 250000,
    'constant_lr': True  
}

# defining model
# the device is derived from the detected `device` instead of being
# hard-coded to "cuda", so that the model lives on the same device as the
# backbone and the datasets and the notebook also runs on CPU-only machines;
# it is passed as a string ("cuda" when a GPU is available)
model = pffp.model(
    backbone, 
    data, 
    sample, 
    state, 
    pffp.utils.interpolant, 
    pffp.utils.velocity, 
    optim_config, 
    device = str(device), 
    verbose = 2, 
    debug = False,
    random_ar_context = True
)

# path of the trained model checkpoint
pffp_checkpoint_path = os.path.join(CLEVREV_weights_root, "trained_pffp_model.pt")

if not os.path.isfile(pffp_checkpoint_path):
    # print model and message
    print(model)
    print("checkpoint not found, starting training")
    # making sure the destination folder exists before training starts,
    # so that the final save cannot fail after the whole training has run
    os.makedirs(CLEVREV_weights_root, exist_ok=True)
    # training model
    model.train()
    # saving model state dictionary
    torch.save(model.state_dict(), pffp_checkpoint_path)
    # retrieving loss and learning rates
    losses = model.loss_values
    lrs = model.lrs_values
    # plotting results
    # defining axes and figure
    fig, axes = plt.subplots(1, 2)
    # plotting loss
    axes[0].set_xlabel("Iteration")
    axes[0].set_ylabel("Loss")
    axes[0].plot(losses)
    # plotting learning rates
    axes[1].set_xlabel("Iteration")
    axes[1].set_ylabel("Learning rate")
    axes[1].plot(lrs)
    # keeping the labels of the two panels from overlapping
    fig.tight_layout()
else:
    # print model and message
    print("checkpoint found, loading model")
    print(model)
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU can also be loaded on a CPU-only machine
    state_dict = torch.load(pffp_checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)