In [None]:
# generic libraries
import os
import gc 
import time
import torch
import torchvision
import random
import numpy as np
import matplotlib.pyplot as plt

# library for image augmentation
from albumentations import Compose, Resize

# library for unet model
from denoising_diffusion_pytorch import Unet

# library for VQGAN encode/decoder
from river.model import Model
from river.lutils.configuration import Configuration

# library for stochastic interpolators with follmer processes
import ProbForecastFollmerProcess as pffp

In [None]:
# plotting style, compute device and data folder
plt.style.use('ggplot')
# use the GPU whenever torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
data_folder = "./store/multi_modal_jump_diffusion"
# report which device was selected
print('Computing on ' + str(device))

In [None]:
# seeding: a fixed seed when reproducibility is requested,
# the current wall-clock time (in whole seconds) otherwise
reproducible = True
if reproducible:
    SEED = 1024
else:
    SEED = int(time.time())
pffp.utils.ensure_reproducibility(SEED)

In [None]:
# data paths
# NOTE: the dataset name is spelled both "CLEVREV" and "CLEVRER" in the
# variable names below; the names are kept unchanged because other cells
# reference them.
def _sorted_subpaths(folder):
    """Return the full paths of all entries of `folder`, sorted by entry name.

    os.listdir returns entries in arbitrary order; sorting keeps the
    resulting lists (and therefore which video ends up at a given dataset
    index) identical across runs and machines.
    """
    return [os.path.join(folder, entry) for entry in sorted(os.listdir(folder))]

# defining root to CLEVREV data
CLEVREV_root = "store/video_generation/data/CLEVREV"
# defining base folder of CLEVREV train, validation and test data
CLEVREV_train_folder = os.path.join(CLEVREV_root, "video_train")
CLEVREV_validation_folder = os.path.join(CLEVREV_root, "video_validation")
CLEVREV_test_folder = os.path.join(CLEVREV_root, "video_test")
# defining video folders of CLEVREV train, validation and test data
CLEVREV_train_video_folders = _sorted_subpaths(CLEVREV_train_folder)
CLEVREV_validation_video_folders = _sorted_subpaths(CLEVREV_validation_folder)
CLEVREV_test_video_folders = _sorted_subpaths(CLEVREV_test_folder)
# defining video files for each of the CLEVREV train, validation and test data
# (every entry of every video folder is collected, in sorted order)
CLEVRER_train_video_paths = [video_path for video_folder in CLEVREV_train_video_folders for video_path in _sorted_subpaths(video_folder)]
CLEVRER_validation_video_paths = [video_path for video_folder in CLEVREV_validation_video_folders for video_path in _sorted_subpaths(video_folder)]
CLEVRER_test_video_paths = [video_path for video_folder in CLEVREV_test_video_folders for video_path in _sorted_subpaths(video_folder)]

In [None]:
# dataset and transforms
# target frame height and width (both set to the same value)
target_frame_width = target_frame_height = 128
# one independent resize-only augmentation pipeline per split
train_augmentations, validation_augmentations, test_augmentations = (
    Compose([Resize(target_frame_height, target_frame_width)]) for _ in range(3)
)
# video datasets for the train, validation and test splits
CLEVREV_train_dataset = pffp.data.VideoDataset(
    CLEVRER_train_video_paths,
    augmentations = train_augmentations,
)
CLEVREV_validation_dataset = pffp.data.VideoDataset(
    CLEVRER_validation_video_paths,
    augmentations = validation_augmentations,
)
CLEVREV_test_dataset = pffp.data.VideoDataset(
    CLEVRER_test_video_paths,
    augmentations = test_augmentations,
)

In [None]:
# model's weights and configuration
# root folder holding the CLEVREV weights
CLEVREV_weights_root = "store/video_generation/weights"
# paths to the VQVAE and VQGAN CLEVREV weights
# NOTE(review): neither of these two paths is used anywhere in the visible
# notebook; presumably the checkpoints are loaded through the YAML
# configuration -- confirm that the encoder is not randomly initialised.
VQVAE_CLEVREV_weights_path = os.path.join(CLEVREV_weights_root, "vqvae.pth")
VQGAN_CLEVREV_weights_path = os.path.join(CLEVREV_weights_root, "model.pth")
# path to the CLEVREV configuration file
CLEVREV_config_path = "store/video_generation/configs/clevrer.yaml"
# building the configuration and, from its "model" section, the model
CLEVEREV_configuration = Configuration(CLEVREV_config_path)
CLEVREV_model = Model(CLEVEREV_configuration["model"])
CLEVREV_model = CLEVREV_model.to(device)
# encoder and decoder of the autoencoder backbone
CLEVREV_VQGAN_encoder = CLEVREV_model.ae.backbone.encoder
CLEVREV_VQGAN_decoder = CLEVREV_model.ae.backbone.decoder
# latent datasets: each video dataset paired with the encoder and the device
CLEVREV_latent_train_dataset, CLEVREV_latent_validation_dataset, CLEVREV_latent_test_dataset = (
    pffp.data.VQGANLatentVideoDataset(split_dataset, CLEVREV_VQGAN_encoder, device)
    for split_dataset in (CLEVREV_train_dataset, CLEVREV_validation_dataset, CLEVREV_test_dataset)
)

In [None]:
# debugging setup: only the first element of the latent training set is
# used, so the model can be trained on this reduced example
train_example_video = CLEVREV_latent_train_dataset[0]
train_examples_paired_frames = pffp.utils.get_video_dataset_with_random_context(
    train_example_video,
    device,
)
# cell output: length and type of the object built above
(len(train_examples_paired_frames), type(train_examples_paired_frames))

In [None]:
# number of input channels for the drift estimator:
# three frames (target, previous and random context), four channels each
num_frames = 3
num_channels = 4
input_channels = num_frames * num_channels
latent_width = latent_height = 16

# backbone unet, built first and then moved to the selected device
backbone = Unet(
    latent_height,
    out_dim = num_channels,
    channels = num_channels,
    input_channels = input_channels,
    self_condition = True,
)
backbone = backbone.to(device)

# thin wrapper around the unet model that fixes up the shape
# of the time feature before the call (as for B_Network)
class UnetWrapper(torch.nn.Module):
    """Wrap a base model so that a 0-dimensional time tensor is repeated once per batch element."""

    def __init__(self, base_model):
        """Store `base_model`, the module every forward call is delegated to."""
        super().__init__()
        self.base_model = base_model

    def forward(self, X, t, Xc):
        """Call the base model on (X, t, Xc) and return its output.

        When `t` has no dimensions it is first repeated X.shape[0] times,
        giving a 1-d tensor; any other `t` is passed through unchanged.
        """
        batch_size = X.shape[0]
        if t.ndim == 0:
            t = t.repeat(batch_size)
        return self.base_model(X, t, Xc)

# defining wrapper around backbone model
backbone = UnetWrapper(backbone)

# defining sampling configurations
sample = {
    "g": pffp.interpolant["sigma"], 
    "num_euler_steps": 100
}

# defining state configurations 
state = {
    "spatial_dims": (latent_height, latent_width)
}

# defining data configurations
# (only the single-video debugging set is used for training; no test data)
data = {    
    "train": train_examples_paired_frames, 
    "test": None, 
}

# defining optimization configurations
optim_config = {
    'batch_size': 6, 
    'num_epochs': 10,
    'learning_rate' : 2e-4,
    'num_mc_samples': 300,
    'max_num_grad_steps': 250000  
}

# defining model
# the device is taken from the notebook-level `device` (which falls back to
# the CPU when no GPU is available) instead of a hard-coded "cuda", so the
# model is built for the same device the backbone was moved to; on a GPU
# machine str(device) is still exactly "cuda"
model = pffp.model(backbone, data, sample, state, pffp.utils.interpolant, pffp.utils.velocity, optim_config, device = str(device))

# training is forced by default (this matches the previous always-true
# condition); set to False to reuse an existing checkpoint instead
force_retrain = True
# location of the trained model checkpoint
checkpoint_path = os.path.join(CLEVREV_weights_root, "trained_pffp_model.pt")
checkpoint_found = os.path.isfile(checkpoint_path)

if force_retrain or not checkpoint_found:
    # print model and message
    print(model)
    if checkpoint_found:
        print("checkpoint found but retraining is forced, starting training")
    else:
        print("checkpoint not found, starting training")
    # training model
    model.train()
    # saving model state dictionary (overwrites any existing checkpoint)
    torch.save(model.state_dict(), checkpoint_path)
    # retrieving loss and learning rates
    losses = model.loss
    lrs = model.lrs
    # plotting results
    # defining axes and figure
    fig, axes = plt.subplots(1, 2)
    # plotting loss
    axes[0].set_xlabel("Iteration")
    axes[0].set_ylabel("Loss")
    axes[0].plot(losses)
    # plotting learning rates
    axes[1].set_xlabel("Iteration")
    axes[1].set_ylabel("Learning rate")
    axes[1].plot(lrs)
else:
    # print model and message
    print("checkpoint found, loading model")
    print(model)
    # map_location lets a checkpoint saved on GPU be restored on the current device
    state_dict = torch.load(checkpoint_path, map_location = device)
    model.load_state_dict(state_dict)

In [None]:
# autoregressive sampling configuration
sample_config = {
    "num_ar_steps": 1000
}

# running autoregressive sampling
# (the result must be bound to `output`, the name read on the lines below)
output = model.sample_autoregressive(sample_config)
starting_point = output["initial_condition"]
gt_path = output["gt_path"]
ar_samples = output["ar_path"]
print(f"{starting_point.shape=}, {gt_path.shape=}, {ar_samples.shape=}")

In [None]:
# sampling 1000 observations given the first one
# sampling configuration
sample_config = {
    "num_samples": 1000,
    "num_obs": 1
}

# running sampling
output = model.sample(sample_config)
X0 = output["current_states"]
X1 = output["next_state"]
# stored as `samples` (the name printed below), which also avoids overwriting
# the `sample` configuration dict used when the model is built
samples = output["sampled_states"]
print(f"{X0.shape=}, {X1.shape=}, {samples.shape=}")