The goal here is to build an FSQ-VAE that learns latent vectors for the spatio-temporal "tubelets" described in the ViViT paper. Sequences of tubelet latent vectors are then to be modeled by a Transformer decoder.

![](https://i.imgur.com/9G7QTfV.png)

In [None]:
!pip install -q wandb pytorch_lightning av imageio

In [1]:
import os
import gc
import math
import wandb
import imageio
import torch
import torchvision
import numpy as np
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML
from torch import nn, optim
from torch.nn import functional as F
from torchvision.datasets.video_utils import VideoClips
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Compose, Lambda, Resize, ToTensor, CenterCrop, Grayscale
import torchvision.transforms.functional as TF

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.tuner import Tuner
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

pl.seed_everything(42)

In [None]:
# Colab only: mount Google Drive so the source video and any cached clips are reachable
from google.colab import drive
drive.mount('/content/drive')
# path of the source video on the mounted drive
steamboat_willie_gdrive_path = '/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4'
!cp -r /content/drive/My\ Drive/SteamboatWillie/clips .

In [None]:
class RandomHorizontalFlipVideo(torch.nn.Module):
    """
    Mirror an entire clip along its last axis with probability ``p``.

    A single random draw is made per call, so every frame of the clip
    shares the same flip decision.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        """
        x is expected to be (C, T, H, W); only the last (width) axis is reversed.
        Returns the flipped tensor, or ``x`` itself when no flip is drawn.
        """
        should_flip = bool(torch.rand(1) < self.p)
        return x.flip(-1) if should_flip else x


class SteamboatWillieDataset(Dataset):
    """
    Fixed-length grayscale video clips backed by memory-mapped .bin files.

    On first use the source video(s) in ``config.paths`` are cut into
    non-overlapping clips of ``config.clip_length`` frames, preprocessed and
    written to ``config.dest_dir`` as raw uint8 files. Later instantiations
    reuse the files found in that directory.

    Args:
        config: object exposing ``paths``, ``dest_dir``, ``clip_length``,
            ``img_size`` and ``in_channels``.
        mode: 'train' or 'val' select one side of the shuffled split; any
            other value exposes every clip in ascending id order.
        train_split: fraction of clips assigned to the 'train' side.
        split_seed: seed of the private generator that shuffles clips before
            splitting. Instances built with the same seed see the same
            permutation, so their 'train' and 'val' subsets are disjoint.
    """
    def __init__(self, config, mode='train', train_split=0.8, split_seed=42):
        super().__init__()
        self.config = config
        self.train_split = train_split
        self.mode = mode

        self.preprocess_transforms = Compose([
                Lambda(lambda x: x.permute(0, 3, 1, 2)), # (T, H, W, C) to (T, C, H, W) for Grayscale
                Grayscale(num_output_channels=1), # Convert to grayscale
                Lambda(lambda x: x.permute(1, 0, 2, 3)), # (T, C, H, W) to (C, T, H, W) for Conv3d
                CenterCrop((480, 575)), # Center crop to remove vertical bars
                Resize((config.img_size, config.img_size))
        ])

        self.postprocess_transforms = Compose([
            Lambda(lambda x: x / 255.),
            # the memmap is read back flat, so restore the (C, T, H, W) layout here
            Lambda(lambda x: x.view(self.config.in_channels, self.config.clip_length, self.config.img_size, self.config.img_size))
        ])

        if self.mode == 'train':
            self.postprocess_transforms.transforms.append(RandomHorizontalFlipVideo(p=0.5))

        clip_paths = {}
        if os.path.exists(config.dest_dir):
            clip_paths = self.build_existing_clip_paths(config.dest_dir)

        # Also rebuild when the directory exists but holds no clip files;
        # otherwise the dataset would silently end up empty.
        if not clip_paths:
            video_clips = VideoClips(
                config.paths,
                clip_length_in_frames=config.clip_length,
                frames_between_clips=config.clip_length
            )
            clip_paths = self.build_clip_paths(video_clips, self.preprocess_transforms, config.dest_dir)

        self.clips = self.build_clip_refs(clip_paths)

        # Split over the ids that actually exist (sorted for determinism)
        # rather than assuming they are exactly 0..n-1.
        self.clip_indices = self._split_indices(sorted(self.clips), train_split, mode, split_seed)

    @staticmethod
    def _split_indices(clip_ids, train_split, mode, seed):
        """
        Return the clip ids belonging to ``mode``.

        The shuffle uses a generator seeded with ``seed`` instead of the global
        RNG: two datasets created one after the other would otherwise draw
        different permutations and their train/val subsets would overlap.
        """
        if mode not in ('train', 'val'):
            return list(clip_ids)

        generator = torch.Generator().manual_seed(seed)
        order = torch.randperm(len(clip_ids), generator=generator).tolist()
        shuffled = [clip_ids[i] for i in order]
        train_size = int(len(clip_ids) * train_split)

        if mode == 'train':
            return shuffled[:train_size]
        return shuffled[train_size:]

    def build_clip_paths(self, video_clips, transforms, dest_dir):
        """
        Build set of binary files to store processed video clips
        returns dict of clip_idx -> mmapped file path
        """
        clip_paths = {}

        os.makedirs(dest_dir, exist_ok=True)

        for idx in tqdm(range(video_clips.num_clips()), desc='Creating clip .bin files'):
            # transform clips and write to mmap file
            clip, _, _, _ = video_clips.get_clip(idx)
            clip = transforms(clip)
            clip_np = clip.numpy().astype(np.uint8)

            mmapped_file_path = os.path.join(dest_dir, f'clip_{idx}.bin')
            fp = np.memmap(mmapped_file_path, dtype='uint8', mode='w+', shape=clip_np.shape)
            fp[:] = clip_np[:]
            fp.flush()
            del fp  # drop the writable mapping once the data is flushed
            clip_paths[idx] = mmapped_file_path

        return clip_paths

    def build_existing_clip_paths(self, dest_dir):
        """
        returns dict of clip_idx -> mmapped file path
        from existing .bin files
        """
        clips_paths = {}
        for filename in os.listdir(dest_dir):
            if filename.startswith('clip_') and filename.endswith('.bin'):
                # 'clip_<idx>.bin' -> <idx>
                idx = int(filename.split('_')[1].split('.')[0])
                clips_paths[idx] = os.path.join(dest_dir, filename)

        return clips_paths

    def build_clip_refs(self, clip_paths):
        """
        Build mmap reference to bin files
        returns dict of clip_idx -> np.array mmapped to respective bin file
        """
        clips = {}
        for idx, path in tqdm(clip_paths.items(), desc='Building clip refs'):
            # no shape is given, so each reference is a flat read-only uint8 array
            clips[idx] = np.memmap(path, dtype='uint8', mode='r')

        return clips

    def __len__(self):
        return len(self.clip_indices)

    def __getitem__(self, idx):
        """Load clip ``idx`` of this split as a float32 tensor and apply the postprocess transforms."""
        clip = self.clips[self.clip_indices[idx]]
        return self.postprocess_transforms(torch.tensor(clip, dtype=torch.float32))


class SteamboatWillieDataModule(pl.LightningDataModule):
    """
    Lightning data module that wraps the train and val SteamboatWillieDataset splits.

    Args:
        config: object exposing ``batch_size`` and ``num_workers`` plus the
            fields the dataset reads; ``train_split`` is honoured when present.
    """
    def __init__(self, config):
        super().__init__()
        self.batch_size = config.batch_size
        self.config = config

    def setup(self, stage=None):
        """Create the train/val datasets for the 'fit' stage (or when no stage is given)."""
        if stage == 'fit' or stage is None:
            # Forward the configured split fraction so that changing
            # config.train_split actually changes the split; fall back to the
            # dataset's default when the config does not define it.
            train_split = getattr(self.config, 'train_split', 0.8)
            self.train_dataset = SteamboatWillieDataset(self.config, mode='train', train_split=train_split)
            self.val_dataset = SteamboatWillieDataset(self.config, mode='val', train_split=train_split)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
        )

    def val_dataloader(self):
        """Loader over the validation split, in fixed order."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
        )

In [None]:
class VideoVAEConfig:
    """
    Hyperparameters for the video VAE experiment, with defaults set in ``__init__``.

    ``levels`` is derived from ``level`` and ``latent_channels`` (one entry per
    latent channel) and is kept in sync by ``update`` unless it is overridden
    explicitly.
    """
    def __init__(self):
        self.project_name = 'FSQ-VAE Steamboat Willie'
        # dataset properties
        self.paths = ['/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4']
        self.dest_dir = './clips/'
        # model checkpoints
        self.checkpoint_path = "./checkpoints"
        self.save_top_k = 1
        # training
        self.train_split = 0.8
        self.batch_size = 32
        self.max_epochs = 500
        self.training_steps = 100000
        self.num_workers = 2
        # optimizer
        self.lr = 6e-4
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.weight_decay = 0.0 # 1e-2
        self.use_wd_schedule = False
        self.use_lr_schedule = True
        # input properties
        self.clip_length = 16
        self.img_size = 256
        self.in_channels = 1
        # quantization
        self.quant_mode = 'fsq' # 'vq'
        self.latent_channels = 10 # 8
        self.codebook_size = 512
        self.commit_loss_beta = 0.25
        self.track_codebook = True
        self.use_ema = True
        self.ema_gamma = 0.99
        self.level = 7
        self.levels = self._build_levels()
        # encoder/decoder
        self.hidden_channels = 256
        self.start_channels = 32
        self.nblocks = 5
        self.nlayers = 3

    def _build_levels(self):
        """Return one copy of ``level`` per latent channel."""
        return [self.level for _ in range(self.latent_channels)]

    def update(self, updates):
        """
        Overwrite existing settings from a mapping; unknown keys are ignored.

        Only instance attributes can be overwritten, so a key such as
        'update' cannot clobber a method. When ``level`` or ``latent_channels``
        changes and ``levels`` is not supplied, ``levels`` is rebuilt so it
        does not go stale.
        """
        applied = set()
        for key, value in updates.items():
            if key in self.__dict__:
                setattr(self, key, value)
                applied.add(key)

        if 'levels' not in applied and applied & {'level', 'latent_channels'}:
            self.levels = self._build_levels()

    def to_dict(self):
        """Return a shallow copy of all settings as a plain dict."""
        return dict(self.__dict__)

In [2]:
# Goal: take a video tensor of shape (B, C, T, H, W) and split it into patches
# of shape (t, p, p), a.k.a. "tubelets". Instead of embedding each tubelet with a
# single Conv3d whose kernel dims equal the tubelet dims, the tubelets are fed
# to a small VAE so that the resulting latents can be mapped back to pixel
# space.

# video dimensions
B, C, T, H, W = 4, 3, 64, 256, 256

# tubelet dimensions: t frames of p x p pixels
t, p = 8, 16

# every video axis must tile evenly into tubelets
assert T % t == 0
assert H % p == 0
assert W % p == 0

# number of tubelets along time, height and width
n_t = T // t
n_h = H // p
n_w = W // p

# random stand-in video for the shape experiments
vid = torch.randn(B, C, T, H, W)

In [None]:
def extract_tubelets(vid, t=8, p=16):
    '''
    Cut a video tensor of shape (B, C, T, H, W) into non-overlapping tubelets.

    Each tubelet spans t frames and a p x p spatial patch. The result has shape
    (B * n_t * n_h * n_w, C, t, p, p) with n_t = T // t, n_h = H // p and
    n_w = W // p; tubelets are ordered by batch, then time, then row, then
    column.
    '''
    assert vid.ndim == 5, 'vid.shape must be (B, C, T, H, W)'

    batch, channels, frames, height, width = vid.shape

    assert frames % t == 0, 'T (vid.shape[2]) must be divisible by t'
    assert height % p == 0, 'H (vid.shape[3]) must be divisible by p'
    assert width % p == 0, 'W (vid.shape[4]) must be divisible by p'

    steps_t = frames // t
    steps_h = height // p
    steps_w = width // p

    return (
        # split every axis into (count, size): (B, C, n_t, t, n_h, p, n_w, p)
        vid.reshape(batch, channels, steps_t, t, steps_h, p, steps_w, p)
        # bring the counts forward: (B, n_t, n_h, n_w, C, t, p, p)
        .permute(0, 2, 4, 6, 1, 3, 5, 7)
        .contiguous()
        # merge batch and counts: ((B * n_t * n_h * n_w), C, t, p, p)
        .reshape(-1, channels, t, p, p)
    )

In [None]:
# Build the data module from the default config and pull a single training batch
config = VideoVAEConfig()
data_module = SteamboatWillieDataModule(config)
data_module.prepare_data()
data_module.setup()  # stage=None, so the train and val datasets are both created

train_loader = data_module.train_dataloader()
train_batch = next(iter(train_loader))
# expected layout: (batch, in_channels, clip_length, img_size, img_size)
print(f'train_batch.shape: {train_batch.shape}')

In [None]:
def display_clip(clip):
    """
    Show a clip inline in the notebook as an HTML5 video.

    The first axis of ``clip`` is moved to the end (so a (C, T, H, W) tensor
    becomes (T, H, W, C)) and one frame is drawn per entry of the new first
    axis, 50 ms apart, using a gray colormap.
    """
    frames = clip.permute(1, 2, 3, 0).numpy()
    figure, axes = plt.subplots()

    def draw_frame(frame_idx):
        # redraw from scratch so earlier frames do not pile up on the axes
        axes.clear()
        axes.imshow(frames[frame_idx], cmap='gray')
        axes.axis('off')

    animation = FuncAnimation(figure, draw_frame, frames=range(frames.shape[0]), interval=50)
    # close the figure so only the rendered video is displayed, not a static plot
    plt.close()
    video_html = animation.to_html5_video()
    display(HTML(video_html))

In [None]:
# split the real training batch into tubelets using the defaults (t=8, p=16)
tublets = extract_tubelets(train_batch)
print(tublets.shape)

In [None]:
# sanity check: animate a single tubelet (index 1001) taken from the batch
display_clip(tublets[1001])

In [None]:
# Step 1 of the tubelet extraction, run on the random `vid` tensor:
# split each of T, H, W into (count, size) pairs
# (B, C, T, H, W) -> (B, C, n_t, t, n_h, p, n_w, p)
tubelets = vid.reshape(B, C, n_t, t, n_h, p, n_w, p)
print(tubelets.shape)

In [None]:
# Step 2: move the tubelet counts next to the batch axis; .contiguous() lays the
# data out in the new order
# (B, C, n_t, t, n_h, p, n_w, p) -> (B, n_t, n_h, n_w, C, t, p, p)
tubelets = tubelets.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
print(tubelets.shape)

In [None]:
# Step 3: fold batch and tubelet counts into one leading axis so each tubelet
# is an independent sample
# (B, n_t, n_h, n_w, C, t, p, p) -> ((B * n_t * n_h * n_w), C, t, p, p)
tubelets = tubelets.reshape(-1, C, t, p, p)
print(tubelets.shape)

In [None]:
# Encoder stage 1: stride 2 halves every axis; with t=8, p=16 this gives
# (N, C, 8, 16, 16) -> (N, 32, 4, 8, 8)
conv3d_1 = nn.Conv3d(in_channels=C, out_channels=32, kernel_size=3, stride=2, padding=1)
out1 = conv3d_1(tubelets)
print(out1.shape)

In [None]:
# Encoder stage 2: (N, 32, 4, 8, 8) -> (N, 32, 2, 4, 4)
conv3d_2 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
out2 = conv3d_2(out1)
print(out2.shape)

In [None]:
# Encoder stage 3: (N, 32, 2, 4, 4) -> (N, 32, 1, 2, 2)
conv3d_3 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
out3 = conv3d_3(out2)
print(out3.shape)

In [None]:
# Encoder stage 4: the time axis is already 1, so stride only the spatial axes
# (N, 32, 1, 2, 2) -> (N, 32, 1, 1, 1)
conv3d_4 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=(1, 2, 2), padding=1)
out4 = conv3d_4(out3)
print(out4.shape)

In [None]:
# drop the singleton (1, 1, 1) axes: one 32-dim latent vector per tubelet
out5 = out4.reshape(-1, 32)
print(out5.shape)

In [None]:
# separate out batch dimension post VAE encoding since dim=0 is of shape (B * n_t * n_h * n_w)
# result: (B, n_t * n_h * n_w, 32), i.e. a sequence of tubelet latents per video
out6 = out5.reshape(out5.shape[0] // (n_t * n_h * n_w), -1, 32)
print(out6.shape)

In [None]:
# fold the sequence back into the leading axis for decoding: (B * n_t * n_h * n_w, 32)
out7 = out6.reshape(-1, 32)
print(out7.shape)

In [None]:
# restore three trailing singleton axes so ConvTranspose3d accepts the latents: (N, 32, 1, 1, 1)
out7 = out7.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
print(out7.shape)

In [None]:
# Decoder stage 1 mirrors conv3d_4: upsample only the spatial axes
# (N, 32, 1, 1, 1) -> (N, 32, 1, 2, 2); output_padding supplies the extra row/column
decoder1 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=(1, 2, 2), padding=1, output_padding=(0, 1, 1))
out8 = decoder1(out7)
print(out8.shape)

In [None]:
# Decoder stage 2: stride 2 with output_padding 1 doubles every axis
# (N, 32, 1, 2, 2) -> (N, 32, 2, 4, 4)
decoder2 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=(1, 1, 1))
out9 = decoder2(out8)
print(out9.shape)

In [None]:
# Decoder stage 3: stride 2 with output_padding 1 doubles every axis
# (N, 32, 2, 4, 4) -> (N, 32, 4, 8, 8)
decoder3 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=(1, 1, 1))
# apply the layer created for this stage rather than reusing decoder2's weights
out10 = decoder3(out9)
print(out10.shape)

In [None]:
# Decoder stage 4: gets its own layer (decoder4) instead of shadowing decoder3
# and reusing decoder2's weights
# (N, 32, 4, 8, 8) -> (N, 32, 8, 16, 16), i.e. back to the (t, p, p) tubelet size
decoder4 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=(1, 1, 1))
out11 = decoder4(out10)
print(out11.shape)