The goal here is to build an FSQ-VAE that learns latent vectors for the spatio-temporal "tubelets" described in the ViViT paper. Sequences of tubelet latent vectors are then to be modeled by a Transformer decoder.

![](https://i.imgur.com/9G7QTfV.png)

In [None]:
!pip install -q wandb pytorch_lightning av imageio

In [1]:
import os
import gc
import math
import wandb
import imageio
import torch
import torchvision
import numpy as np
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML
from torch import nn, optim
from torch.nn import functional as F
from torchvision.datasets.video_utils import VideoClips
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Compose, Lambda, Resize, ToTensor, CenterCrop, Grayscale
import torchvision.transforms.functional as TF

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.tuner import Tuner
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

# Seed the random number generators through Lightning so that runs are repeatable.
pl.seed_everything(42)

In [None]:
# Mount Google Drive so the source video and any pre-built clip cache are reachable.
from google.colab import drive
drive.mount('/content/drive')
# NOTE(review): this variable is not referenced again in the visible notebook;
# VideoVAEConfig.paths hard-codes the same location.
steamboat_willie_gdrive_path = '/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4'
# Copy the clips directory from Drive into the working directory. The dataset
# reuses config.dest_dir ('./clips/') when it exists instead of decoding the video.
!cp -r /content/drive/My\ Drive/SteamboatWillie/clips .

# Dataset

In [None]:
class RandomHorizontalFlipVideo(torch.nn.Module):
    """Mirror a whole clip left-to-right with probability ``p``.

    A single random draw is made per call, so all frames of the clip are
    either flipped together or left untouched together.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        """Return ``x`` reversed along its last axis with probability ``p``.

        The dataset in this file passes clips shaped (C, T, H, W), so the last
        axis is the image width.
        """
        should_flip = torch.rand(1) < self.p
        return x.flip(-1) if should_flip else x


class SteamboatWillieDataset(Dataset):
    """Fixed-length grayscale video clips cached on disk as raw uint8 ``.bin`` files.

    On first use (``config.dest_dir`` does not exist) the source videos in
    ``config.paths`` are cut into non-overlapping clips of
    ``config.clip_length`` frames, preprocessed, and written to
    ``config.dest_dir``. Later instances memory-map the existing files.

    Args:
        config: object providing ``img_size``, ``in_channels``, ``clip_length``,
            ``dest_dir`` and ``paths``. An optional ``split_seed`` attribute
            (default 42) seeds the train/val shuffle.
        mode: ``'train'`` or ``'val'`` select one side of the split; any other
            value exposes every clip.
        train_split: fraction of clips assigned to the training side.
    """

    def __init__(self, config, mode='train', train_split=0.8):
        super().__init__()
        self.config = config
        self.train_split = train_split
        self.mode = mode

        self.preprocess_transforms = Compose([
                Lambda(lambda x: x.permute(0, 3, 1, 2)),   # (T, H, W, C) to (T, C, H, W) for Grayscale
                Grayscale(num_output_channels=1),          # Convert to grayscale
                Lambda(lambda x: x.permute(1, 0, 2, 3)),   # (T, C, H, W) to (C, T, H, W) for Conv3d
                CenterCrop((480, 575)),                    # Center crop to remove vertical bars
                Resize((config.img_size, config.img_size))
        ])

        self.postprocess_transforms = Compose([
            Lambda(lambda x: x / 255.),
            # the memmap is flat, so the clip shape is restored here
            Lambda(lambda x: x.view(self.config.in_channels, self.config.clip_length, self.config.img_size, self.config.img_size))
        ])

        if self.mode == 'train':
            # augmentation is applied to the training side only
            self.postprocess_transforms.transforms.append(RandomHorizontalFlipVideo(p=0.5))

        if os.path.exists(config.dest_dir):
            clip_paths = self.build_existing_clip_paths(config.dest_dir)
        else:
            video_clips = VideoClips(
                config.paths,
                clip_length_in_frames=config.clip_length,
                frames_between_clips=config.clip_length
            )
            clip_paths = self.build_clip_paths(video_clips, self.preprocess_transforms, config.dest_dir)

        self.clips = self.build_clip_refs(clip_paths)

        # The train and val instances must shuffle identically, otherwise their
        # index sets overlap and validation clips leak into training.
        self.clip_indices = self._split_indices(
            sorted(self.clips),
            mode,
            train_split,
            seed=getattr(config, 'split_seed', 42)
        )

    @staticmethod
    def _split_indices(clip_keys, mode, train_split, seed=42):
        """Return the clip keys belonging to ``mode``.

        The shuffle uses a private generator seeded with ``seed``, so every call
        with the same keys yields the same permutation regardless of the global
        RNG state; the 'train' and 'val' results are therefore disjoint and
        together cover all keys. Any other mode returns all keys unshuffled.
        """
        clip_keys = list(clip_keys)
        if mode not in ('train', 'val'):
            return clip_keys

        generator = torch.Generator().manual_seed(seed)
        order = torch.randperm(len(clip_keys), generator=generator).tolist()
        shuffled = [clip_keys[i] for i in order]
        train_size = int(len(shuffled) * train_split)

        if mode == 'train':
            return shuffled[:train_size]
        return shuffled[train_size:]

    def build_clip_paths(self, video_clips, transforms, dest_dir):
        """
        Build set of binary files to store processed video clips
        returns dict of clip_idx -> mmapped file path
        """
        clip_paths = {}

        os.makedirs(dest_dir, exist_ok=True)

        for idx in tqdm(range(video_clips.num_clips()), desc='Creating clip .bin files'):
            # transform clips and write to mmap file
            clip, _, _, _ = video_clips.get_clip(idx)
            clip = transforms(clip)
            clip_np = clip.numpy().astype(np.uint8)

            mmapped_file_path = os.path.join(dest_dir, f'clip_{idx}.bin')
            fp = np.memmap(mmapped_file_path, dtype='uint8', mode='w+', shape=clip_np.shape)
            fp[:] = clip_np[:]
            fp.flush()
            del fp  # drop the writable mapping before the file is reopened read-only
            clip_paths[idx] = mmapped_file_path

        return clip_paths

    def build_existing_clip_paths(self, dest_dir):
        """
        returns dict of clip_idx -> mmapped file path
        from existing .bin files
        """
        clips_paths = {}
        for filename in os.listdir(dest_dir):
            if filename.startswith('clip_') and filename.endswith('.bin'):
                # file names follow the 'clip_<idx>.bin' pattern written by build_clip_paths
                idx = int(filename.split('_')[1].split('.')[0])
                clips_paths[idx] = os.path.join(dest_dir, filename)

        return clips_paths

    def build_clip_refs(self, clip_paths):
        """
        Build mmap reference to bin files
        returns dict of clip_idx -> np.array mmapped to respective bin file
        """
        clips = {}
        for idx, path in tqdm(clip_paths.items(), desc='Building clip refs'):
            # no shape is given, so each mapping is a flat read-only uint8 array
            clips[idx] = np.memmap(path, dtype='uint8', mode='r')

        return clips

    def __len__(self):
        """Number of clips on this side of the split."""
        return len(self.clip_indices)

    def __getitem__(self, idx):
        """Return clip ``idx`` as a float32 tensor after the postprocess transforms."""
        clip = self.clips[self.clip_indices[idx]]
        return self.postprocess_transforms(torch.tensor(clip, dtype=torch.float32))


class SteamboatWillieDataModule(pl.LightningDataModule):
    """Lightning data module serving the train and validation clip datasets.

    Args:
        config: object providing ``batch_size`` and ``num_workers`` plus the
            fields required by ``SteamboatWillieDataset``. An optional
            ``train_split`` attribute (default 0.8) sets the train fraction.
    """

    def __init__(self, config):
        super().__init__()
        self.batch_size = config.batch_size
        self.config = config

    def setup(self, stage=None):
        """Build both datasets for the 'fit' and 'validate' stages (or when no stage is given)."""
        if stage in (None, 'fit', 'validate'):
            # Both sides take the same fraction so they describe one split.
            train_split = getattr(self.config, 'train_split', 0.8)
            self.train_dataset = SteamboatWillieDataset(self.config, mode='train', train_split=train_split)
            self.val_dataset = SteamboatWillieDataset(self.config, mode='val', train_split=train_split)

    def train_dataloader(self):
        """Shuffled loader over the training clips."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers
        )

    def val_dataloader(self):
        """Unshuffled loader over the validation clips."""
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers
        )

# Model

In [None]:
class ResBlock(nn.Module):
    """Bottleneck residual block whose 3x3 convolutions act on the spatial axes only.

    The main path is three (1, 3, 3) convolutions at a reduced width followed by
    a 1x1x1 convolution up to ``out_channels``, with batch norm after every
    convolution. The skip path is the identity, or a 1x1x1 convolution when the
    channel counts differ. Kernel size and padding leave the temporal and
    spatial extents unchanged.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
    """

    def __init__(
            self,
            in_channels,
            out_channels
        ):
        super().__init__()
        if in_channels != out_channels:
            # project the skip path so it can be added to the main path
            self.identity = nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1
            )
        else:
            self.identity = nn.Identity()

        # Bottleneck width. Clamped to 1 because in_channels < 4 would otherwise
        # give a zero-channel bottleneck that carries no signal.
        hidden_channels = max(1, in_channels // 4)

        # The layer order is kept fixed so state_dict keys stay 'block.<i>.*'.
        self.block = nn.Sequential(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=hidden_channels,
                kernel_size=(1, 3, 3),
                padding=(0, 1, 1)
            ),
            nn.BatchNorm3d(hidden_channels),
            nn.ReLU(),
            nn.Conv3d(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                kernel_size=(1, 3, 3),
                padding=(0, 1, 1)
            ),
            nn.BatchNorm3d(hidden_channels),
            nn.ReLU(),
            nn.Conv3d(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                kernel_size=(1, 3, 3),
                padding=(0, 1, 1)
            ),
            nn.BatchNorm3d(hidden_channels),
            nn.ReLU(),
            nn.Conv3d(
                in_channels=hidden_channels,
                out_channels=out_channels,
                kernel_size=1
            ),
            nn.BatchNorm3d(out_channels)
        )
        self.res_act = nn.ReLU()

    def forward(self, x):
        """Return ``relu(block(x) + identity(x))``."""
        out = self.block(x) + self.identity(x)
        return self.res_act(out)

In [None]:
class Encoder(nn.Module):
    """Tubelet encoder: strided 3D convolutions followed by residual blocks.

    Each of the ``nlayers`` downsampling stages is Conv3d (kernel 3, stride 2,
    padding 1) -> BatchNorm3d -> ReLU. They are followed by ``nblocks``
    ``ResBlock``s, the last of which maps to ``out_channels``. The result is
    flattened to one vector per tubelet.
    """

    def __init__(
            self,
            in_channels,
            hidden_channels,
            out_channels,
            nlayers=3,
            nblocks=2
        ):
        super().__init__()

        stages = []
        for layer_idx in range(nlayers):
            # only the first stage sees the raw input channels
            stage_in = hidden_channels if layer_idx > 0 else in_channels
            stages.append(nn.Sequential(
                nn.Conv3d(
                    in_channels=stage_in,
                    out_channels=hidden_channels,
                    kernel_size=(3, 3, 3),
                    stride=(2, 2, 2),
                    padding=(1, 1, 1)
                ),
                nn.BatchNorm3d(hidden_channels),
                nn.ReLU()
            ))
        self.downsample_blocks = nn.Sequential(*stages)

        blocks = []
        for block_idx in range(nblocks):
            # only the final block changes the channel count
            is_last = block_idx == nblocks - 1
            blocks.append(ResBlock(
                in_channels=hidden_channels,
                out_channels=out_channels if is_last else hidden_channels
            ))
        self.res_blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Encode tubelets and flatten everything but the leading dimension."""
        # ((B * n_t * n_h * n_w), C, t, p, p)
        features = self.res_blocks(self.downsample_blocks(x))
        # ((B * n_t * n_h * n_w), C, 1, 2, 2) -> ((B * n_t * n_h * n_w), (C * 1 * 2 * 2))
        return torch.flatten(features, start_dim=1)

In [None]:
class Decoder(torch.nn.Module):
    """Tubelet decoder: residual blocks followed by transposed-convolution upsampling.

    The flat latent is reshaped to (N, C, 1, 2, 2), passed through ``nblocks``
    ``ResBlock``s and then ``nlayers`` ConvTranspose3d stages (kernel 3,
    stride 2, padding 1, output_padding 1). Every stage but the last is followed
    by BatchNorm3d and ReLU; the last one maps to ``out_channels`` and is
    followed by a sigmoid.
    """

    def __init__(
            self,
            in_channels,
            hidden_channels,
            out_channels,
            nlayers=3,
            nblocks=2
        ):
        super().__init__()

        blocks = []
        for block_idx in range(nblocks):
            # only the first block sees the latent channel count
            block_in = hidden_channels if block_idx > 0 else in_channels
            blocks.append(ResBlock(
                in_channels=block_in,
                out_channels=hidden_channels,
            ))
        self.res_blocks = nn.Sequential(*blocks)

        stages = []
        for layer_idx in range(nlayers):
            is_final = layer_idx == nlayers - 1
            stages.append(nn.Sequential(
                nn.ConvTranspose3d(
                    in_channels=hidden_channels,
                    out_channels=out_channels if is_final else hidden_channels,
                    kernel_size=(3, 3, 3),
                    stride=(2, 2, 2),
                    padding=(1, 1, 1),
                    output_padding=(1, 1, 1)
                ),
                # the final stage keeps placeholder modules so every stage has three children
                nn.Identity() if is_final else nn.BatchNorm3d(hidden_channels),
                nn.Identity() if is_final else nn.ReLU()
            ))
        self.upsample_blocks = nn.Sequential(*stages)

        self.out_act = nn.Sigmoid()

    def forward(self, x):
        """Decode flat latents back into tubelets."""
        # ((B * n_t * n_h * n_w), (C * 1 * 2 * 2)) -> ((B * n_t * n_h * n_w), C, 1, 2, 2)
        latent = x.reshape(x.shape[0], -1, 1, 2, 2)
        # ((B * n_t * n_h * n_w), C, 1, 2, 2)
        upsampled = self.upsample_blocks(self.res_blocks(latent))
        return self.out_act(upsampled)

In [None]:
class FSQ(nn.Module):
    """Finite scalar quantizer: bounds each latent dimension and rounds it to a fixed grid.

    Dimension ``i`` of the last axis is quantized to ``levels[i]`` values. There
    is no learned codebook; a code's index is a mixed-radix number whose digit
    weights are stored in ``basis``.

    Args:
        levels: sequence of ints, number of quantization levels per dimension.
        eps: shrinks the tanh scale slightly below ``(levels - 1) / 2``.
    """

    def __init__(self, levels, eps=1e-3):
        super().__init__()
        levels = list(levels)  # accept tuples as well as lists
        self.register_buffer('levels', torch.tensor(levels))
        self.register_buffer(
            'basis',
            torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
        )

        self.eps = eps
        # NOTE(review): both basis (int32) and this product wrap around silently
        # when prod(levels) is very large; the index helpers below are only
        # meaningful while the codebook size fits in int32.
        self.codebook_size = torch.prod(self.levels)

    def round_ste(self, z):
        """Round ``z`` in the forward pass while passing gradients straight through."""
        z_q = torch.round(z)
        return z + (z_q - z).detach()

    def quantize(self, z):
        """Quantize ``z`` (last axis of size ``len(levels)``) to normalized grid values."""
        # half_l is used to determine how to scale tanh; we
        # subtract 1 from the number of levels to account for 0
        # being a quantization bin and tanh being symmetric around 0
        half_l = (self.levels - 1) * (1 - self.eps) / 2

        # if a given level is even, it will result in a scale for tanh
        # which is halfway between integer values, so we offset
        # the tanh output down by 0.5 to line it with whole integers
        offset = torch.where(self.levels % 2 == 0, 0.5, 0.0)

        # if our level is even, we want to shift the tanh input to
        # ensure the 0 quantization bin is centered
        shift = torch.tan(offset / half_l)

        # once we have our shift and offset (in the case of an even level)
        # we can round to the nearest integer bin and allow for STE
        z_q = self.round_ste(torch.tanh(z + shift) * half_l - offset)

        # after quantization, we want to renormalize the quantized
        # values to be within the range expected by the model (ie. [-1, 1])
        half_width = self.levels // 2
        return z_q / half_width

    def scale_and_shift(self, z_q_normalized):
        """Map normalized codes to non-negative digits (``* half_width + half_width``)."""
        half_width = self.levels // 2
        return (z_q_normalized * half_width) + half_width

    def scale_and_shift_inverse(self, z_q):
        """Inverse of ``scale_and_shift``: digits back to normalized codes."""
        half_width = self.levels // 2
        return (z_q - half_width) / half_width

    def code_to_idxs(self, z_q):
        """Convert normalized codes ``(..., d)`` to int32 codebook indices ``(...)``.

        Digits are rounded and combined in integer arithmetic: a float32 sum
        cannot represent every index above 2**24, and truncating a float that
        is marginally below an integer would yield the wrong digit.
        """
        digits = torch.round(self.scale_and_shift(z_q)).to(torch.int64)
        return (digits * self.basis).sum(dim=-1).to(torch.int32)

    def idxs_to_code(self, idxs):
        """Convert integer codebook indices ``(...)`` to normalized codes ``(..., d)``."""
        idxs = idxs.unsqueeze(-1)
        codes_not_centered = (idxs // self.basis) % self.levels
        return self.scale_and_shift_inverse(codes_not_centered)

    def forward(self, z):
        """Quantize ``z`` and return ``{'z_q': quantized}``.

        ``z`` must carry the quantized dimensions on its last axis; the model in
        this file passes flat latents of shape (B, C).
        """
        # TODO: make this work for generic tensor sizes
        # TODO: use einops to clean up
        z_q = self.quantize(z)

        return {'z_q': z_q}

In [None]:
class TubeletFSQVAE(nn.Module):
    """FSQ-VAE that encodes, quantizes and decodes each video tubelet independently.

    A (B, C, T, H, W) video is cut into non-overlapping (t, p, p) tubelets; each
    tubelet is encoded to a flat latent, quantized by ``FSQ`` and decoded back.

    Args:
        config: object providing ``in_channels``, ``hidden_channels``,
            ``latent_channels``, ``levels``, ``t``, ``p``, ``batch_size``,
            ``clip_length`` and ``img_size``.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): config.nlayers / config.nblocks are not forwarded, so the
        # Encoder/Decoder defaults apply -- confirm this is intended.
        self.encoder = Encoder(
            in_channels     = config.in_channels,
            hidden_channels = config.hidden_channels,
            out_channels    = config.latent_channels,
        )
        self.decoder = Decoder(
            in_channels     = config.latent_channels,
            hidden_channels = config.hidden_channels,
            out_channels    = config.in_channels
        )

        # TODO: change self.T = config.clip_length if possible
        self.t, self.p = config.t, config.p
        self.B, self.C, self.T, self.H, self.W = config.batch_size, config.in_channels, config.clip_length, config.img_size, config.img_size
        self.n_t, self.n_h, self.n_w = self.T // self.t, self.H // self.p, self.W // self.p

        self.quantizer = FSQ(config.levels)

        self.config = config

    def encode(self, inputs):
        """Cut a video batch into tubelets and encode each one to a flat latent."""
        inputs = self.extract_tubelets(inputs)
        return self.encoder(inputs)

    def quantize(self, z):
        """Quantize flat latents; returns the quantizer's dict (key ``'z_q'``)."""
        return self.quantizer(z)

    def decode(self, z_q):
        """Decode quantized latents back into tubelets."""
        return self.decoder(z_q)

    def extract_tubelets(self, x, t=None, p=None):
        '''
        extract tubelet sequence of shape ((B * n_t * n_h * n_w), C, t, p, p)
        from a video tensor of shape (B, C, T, H, W)
        where n_t = T // t, n_h = H // p, and n_w = W // p

        t and p default to the tubelet size this model was configured with
        (config.t, config.p), so encode() and assemble_tubelets() agree.
        '''
        t = self.t if t is None else t
        p = self.p if p is None else p

        assert len(x.shape) == 5, 'vid.shape must be (B, C, T, H, W)'

        B, C, T, H, W = x.shape

        assert T % t == 0, 't must divide T (vid.shape[2]) evenly'
        assert H % p == 0, 'p must divide H (vid.shape[3]) evenly'
        assert W % p == 0, 'p must divide W (vid.shape[4]) evenly'

        n_t, n_h, n_w = T // t, H // p, W // p

        # (B, C, T, H, W) -> (B, C, n_t, t, n_h, p, n_w, p)
        x = x.reshape(B, C, n_t, t, n_h, p, n_w, p)
        # (B, C, n_t, t, n_h, p, n_w, p) -> (B, n_t, n_h, n_w, C, t, p, p)
        x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
        # (B, n_t, n_h, n_w, C, t, p, p) -> ((B * n_t * n_h * n_w), C, t, p, p)
        x = x.reshape(-1, C, t, p, p)

        return x

    def assemble_tubelets(self, x):
        """Inverse of ``extract_tubelets``: rebuild (B, C, T, H, W) videos from tubelets.

        The batch size is inferred from ``x`` rather than taken from the config,
        so partial batches are reassembled correctly.
        """
        _, C, t, p, _ = x.shape
        # ((B * n_t * n_h * n_w), C, t, p, p) -> (B, n_t, n_h, n_w, C, t, p, p)
        x = x.reshape(-1, self.n_t, self.n_h, self.n_w, C, t, p, p)
        B = x.shape[0]
        # (B, n_t, n_h, n_w, C, t, p, p) -> (B, C, n_t, t, n_h, p, n_w, p)
        x = x.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous()
        # (B, C, n_t, t, n_h, p, n_w, p) -> (B, C, T, H, W)
        x = x.reshape(B, C, self.n_t * t, self.n_h * p, self.n_w * p)

        return x

    def loss(self, x_hat, x, quantized):
        """Return the MSE reconstruction loss under ``'loss'`` merged with the quantizer outputs."""
        MSE = F.mse_loss(x_hat, x)

        return {
            'loss': MSE,
            **quantized
        }

    def forward(self, x):
        """Reconstruct ``x`` tubelet by tubelet.

        Returns a dict with ``'x_hat'`` (reconstructed tubelets, not yet
        reassembled into videos), ``'loss'`` and the quantizer outputs.
        """
        # Extract once and reuse for both the encoder input and the loss target.
        tubelets = self.extract_tubelets(x)
        z = self.encoder(tubelets)
        quantized = self.quantize(z)
        x_hat = self.decode(quantized['z_q'])
        losses = self.loss(x_hat, tubelets, quantized)

        return {'x_hat': x_hat, **losses}

In [None]:
def extract_tubelets(x, t=8, p=16):
    '''
    Cut a video batch into non-overlapping spatio-temporal tubelets.

    Takes a video tensor of shape (B, C, T, H, W) and returns a tensor of shape
    ((B * n_t * n_h * n_w), C, t, p, p), where n_t = T // t, n_h = H // p and
    n_w = W // p. Tubelets are ordered by batch, then time, then row, then column.
    '''
    assert len(x.shape) == 5, 'vid.shape must be (B, C, T, H, W)'

    batch, channels, frames, height, width = x.shape

    assert frames % t == 0, 't must divide T (vid.shape[2]) evenly'
    assert height % p == 0, 'p must divide H (vid.shape[3]) evenly'
    assert width % p == 0, 'p must divide W (vid.shape[4]) evenly'

    grid_t = frames // t
    grid_h = height // p
    grid_w = width // p

    # (B, C, T, H, W) -> (B, C, n_t, t, n_h, p, n_w, p)
    split = x.reshape(batch, channels, grid_t, t, grid_h, p, grid_w, p)
    # (B, C, n_t, t, n_h, p, n_w, p) -> (B, n_t, n_h, n_w, C, t, p, p)
    tubelet_major = split.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
    # (B, n_t, n_h, n_w, C, t, p, p) -> ((B * n_t * n_h * n_w), C, t, p, p)
    return tubelet_major.reshape(-1, channels, t, p, p)

# Lightning Module

In [None]:
class LitTubeletFSQVAE(pl.LightningModule):
    """Lightning wrapper that trains a tubelet FSQ-VAE and logs sample reconstructions.

    Args:
        model: module whose forward returns a dict with ``'loss'`` and
            ``'x_hat'`` and which provides ``assemble_tubelets``.
        config: object providing ``lr``, ``beta1``, ``beta2``, ``weight_decay``,
            ``use_lr_schedule`` and ``max_epochs``.
    """

    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        # kept as a plain attribute so the learning-rate finder can overwrite it
        self.lr = config.lr
        # No logger access here: a module has no trainer (and hence no logger)
        # while it is being constructed, so such a branch could never run.

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Log and return the reconstruction loss for one training batch."""
        x = batch
        out = self(x)

        self.log('train/loss', out['loss'], prog_bar=True)

        return out['loss']

    def validation_step(self, batch, batch_idx):
        """Log the validation loss; for the first batch also log sample clips."""
        x = batch
        out = self(x)

        self.log('val/loss', out['loss'], prog_bar=True)

        if batch_idx == 0:
            # x_hat is a batch of tubelets; stitch it back into whole videos first
            out['x_hat'] = self.model.assemble_tubelets(out['x_hat'])
            self.log_val_clips(x, out)

        return out['loss']

    def configure_optimizers(self):
        """AdamW, optionally with cosine annealing over ``config.max_epochs``."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            betas=(self.config.beta1, self.config.beta2),
            weight_decay=self.config.weight_decay
        )

        if self.config.use_lr_schedule:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config.max_epochs)
            return [optimizer], [scheduler]

        return optimizer

    @staticmethod
    def _clip_to_uint8_frames(clip):
        """Convert one (C, T, H, W) clip to min-max normalized uint8 frames.

        Returns an array of shape (T, H, W, C), or (T, H, W) when C == 1.
        A constant clip maps to all zeros instead of dividing by zero.
        """
        # float() so half-precision tensors are normalized in float32
        frames = clip.detach().float().permute(1, 2, 3, 0).cpu().numpy()

        lowest = frames.min()
        value_range = frames.max() - lowest
        if value_range > 0:
            frames = (frames - lowest) / value_range
        else:
            frames = np.zeros_like(frames)

        frames = (frames * 255).astype(np.uint8)

        # grayscale videos need to be of shape (T, H, W)
        if frames.shape[-1] == 1:
            frames = frames.squeeze(-1)

        return frames

    def log_val_clips(self, x, out, num_clips=2):
        """Write up to ``num_clips`` original/reconstructed GIF pairs and log them to wandb.

        Does nothing when the module has no logger attached.
        """
        if self.logger is None:
            return

        n_clips = min(x.size(0), num_clips)

        for i in range(n_clips):
            # the ith original and reconstructed clip, each (C, T, H, W)
            original_clip_np = self._clip_to_uint8_frames(x[i])
            reconstructed_clip_np = self._clip_to_uint8_frames(out['x_hat'][i])

            # create GIFs for the original and reconstructed clips
            original_gif_path = f'/tmp/original_clip_{i}.gif'
            reconstructed_gif_path = f'/tmp/reconstructed_clip_{i}.gif'
            imageio.mimsave(original_gif_path, original_clip_np, fps=10)
            imageio.mimsave(reconstructed_gif_path, reconstructed_clip_np, fps=10)

            # log the GIFs to wandb
            self.logger.experiment.log({
                f"val/original_clip_{i}": wandb.Video(original_gif_path, fps=10, format="gif", caption="Original"),
                f"val/reconstructed_clip_{i}": wandb.Video(reconstructed_gif_path, fps=10, format="gif", caption="Reconstructed")
            })

# Config

In [None]:
class VideoVAEConfig:
    """Hyperparameters and paths for the tubelet FSQ-VAE experiments.

    ``levels`` is derived from ``level`` and ``latent_channels``: one entry per
    flattened latent dimension (``latent_channels * 4``).
    """

    def __init__(self):
        self.project_name = 'Tubelet FSQ-VAE Steamboat Willie'
        # dataset properties
        self.paths = ['/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4']
        self.dest_dir = './clips/'
        # model checkpoints
        self.checkpoint_path = "./checkpoints"
        self.save_top_k = 1
        # training
        self.train_split = 0.8
        self.batch_size = 16
        self.max_epochs = 500
        self.training_steps = 100000
        self.num_workers = 2
        # optimizer
        self.lr = 5e-2
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.weight_decay = 0.0 # 1e-2
        self.use_wd_schedule = False
        self.use_lr_schedule = False
        # input properties
        self.clip_length = 16
        self.img_size = 256
        self.in_channels = 1
        # quantization
        self.quant_mode = 'fsq' # 'vq'
        self.latent_channels = 64 # 8
        self.codebook_size = 512
        self.commit_loss_beta = 0.25
        self.track_codebook = True
        self.use_ema = True
        self.ema_gamma = 0.99
        self.level = 9
        self.levels = self._build_levels()
        # encoder/decoder
        self.hidden_channels = 256
        self.start_channels = 32
        self.nblocks = 5
        self.nlayers = 3
        # tubelet
        self.t = 8
        self.p = 16

    def _build_levels(self):
        """Return one ``level`` entry per flattened latent dimension."""
        return [self.level for _ in range(self.latent_channels * 4)]

    def update(self, updates):
        """Overwrite existing fields from a mapping; unknown keys are ignored.

        ``levels`` is rebuilt whenever ``level`` or ``latent_channels`` changes,
        unless ``updates`` sets ``levels`` explicitly.
        """
        fields = vars(self)  # instance fields only, so methods cannot be overwritten
        applied = set()
        for key, value in updates.items():
            if key in fields:
                setattr(self, key, value)
                applied.add(key)

        # Rebuild after all keys are applied so the result does not depend on
        # the order in which the mapping yields them.
        if 'levels' not in applied and applied & {'level', 'latent_channels'}:
            self.levels = self._build_levels()

    def to_dict(self):
        """Return a shallow copy of all fields as a plain dict."""
        return dict(vars(self))

# Component Testing

In [None]:
config = VideoVAEConfig()
data_module = SteamboatWillieDataModule(config)
data_module.prepare_data()
data_module.setup()

train_loader = data_module.train_dataloader()
train_batch = next(iter(train_loader))
print(f'train_batch.shape: {train_batch.shape}')

In [None]:
tublets = extract_tubelets(train_batch)
print(tublets.shape)

In [None]:
enc=Encoder(in_channels=1,
            hidden_channels=32,
            out_channels=5)
fsq=FSQ(levels=[5, 5, 5, 5, 5])
dec=Decoder(in_channels=5,
            hidden_channels=32,
            out_channels=1)

In [None]:
tubelets_enc = enc(tublets)
print(tubelets_enc.shape)

In [None]:
tubelets_fsq = fsq(tubelets_enc)
print(tubelets_fsq['z_q'].shape)

In [None]:
tubelets_dec = dec(tubelets_fsq['z_q'])
print(tubelets_dec.shape)

# Display Clips

In [None]:
def extract_tubelets(x, t=8, p=16):
    '''
    Cut a video batch into tubelets and immediately stitch it back together.

    Takes a video tensor of shape (B, C, T, H, W) and returns a pair:
    the tubelets of shape ((B * n_t * n_h * n_w), C, t, p, p), where
    n_t = T // t, n_h = H // p and n_w = W // p, and the video reassembled
    from those tubelets with shape (B, C, T, H, W).
    '''
    assert len(x.shape) == 5, 'vid.shape must be (B, C, T, H, W)'

    batch, channels, frames, height, width = x.shape

    assert frames % t == 0, 't must divide T (vid.shape[2]) evenly'
    assert height % p == 0, 'p must divide H (vid.shape[3]) evenly'
    assert width % p == 0, 'p must divide W (vid.shape[4]) evenly'

    grid_t = frames // t
    grid_h = height // p
    grid_w = width // p

    # forward: (B, C, T, H, W) -> (B, C, n_t, t, n_h, p, n_w, p)
    #          -> (B, n_t, n_h, n_w, C, t, p, p) -> ((B * n_t * n_h * n_w), C, t, p, p)
    tubelets = (
        x.reshape(batch, channels, grid_t, t, grid_h, p, grid_w, p)
        .permute(0, 2, 4, 6, 1, 3, 5, 7)
        .contiguous()
        .reshape(-1, channels, t, p, p)
    )

    # inverse: ((B * n_t * n_h * n_w), C, t, p, p) -> (B, n_t, n_h, n_w, C, t, p, p)
    #          -> (B, C, n_t, t, n_h, p, n_w, p) -> (B, C, T, H, W)
    reassembled = (
        tubelets.reshape(batch, grid_t, grid_h, grid_w, channels, t, p, p)
        .permute(0, 4, 1, 5, 2, 6, 3, 7)
        .contiguous()
        .reshape(batch, channels, frames, height, width)
    )

    return tubelets, reassembled

In [None]:
def display_clip(clip):
    """Render a clip as an inline HTML5 video, one grayscale frame every 50 ms.

    ``clip`` is moved from channel-first to channel-last before plotting; the
    caller in this notebook passes a (C, T, H, W) tensor.
    """
    frames = clip.permute(1, 2, 3, 0).numpy()
    figure, axis = plt.subplots()

    def draw_frame(frame_idx):
        # redraw from scratch so earlier frames do not pile up on the axes
        axis.clear()
        axis.imshow(frames[frame_idx], cmap='gray')
        axis.axis('off')

    animation = FuncAnimation(figure, draw_frame, frames=range(frames.shape[0]), interval=50)
    # close the figure so only the rendered video shows up in the cell output
    plt.close()
    video_markup = animation.to_html5_video()
    display(HTML(video_markup))

In [None]:
# Uses the two-output extract_tubelets from this section: the first result is
# the tubelet batch, the second the batch reassembled from those tubelets.
tublets, tublets_ = extract_tubelets(train_batch)
print(tublets.shape)
print(tublets_.shape)

In [None]:
# Play back the first reassembled clip as a visual check.
display_clip(tublets_[0])

# Initial Tubelet Shape Exploration

In [None]:
# (B, C, T, H, W) -> (B, C, n_t, t, n_h, p, n_w, p)
tubelets = vid.reshape(B, C, n_t, t, n_h, p, n_w, p)
print(tubelets.shape)

In [None]:
# (B, C, n_t, t, n_h, p, n_w, p) -> (B, n_t, n_h, n_w, C, t, p, p)
tubelets = tubelets.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
print(tubelets.shape)

In [None]:
# (B, n_t, n_h, n_w, C, t, p, p) -> ((B * n_t * n_h * n_w), C, t, p, p)
tubelets = tubelets.reshape(-1, C, t, p, p)
print(tubelets.shape)

In [None]:
conv3d_1 = nn.Conv3d(in_channels=C, out_channels=32, kernel_size=3, stride=2, padding=1)
out1 = conv3d_1(tubelets)
print(out1.shape)

In [None]:
conv3d_2 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
out2 = conv3d_2(out1)
print(out2.shape)

In [None]:
conv3d_3 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
out3 = conv3d_3(out2)
print(out3.shape)

In [None]:
conv3d_4 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=(1, 2, 2), padding=1)
out4 = conv3d_4(out3)
print(out4.shape)

In [None]:
out5 = out4.reshape(-1, 32)
print(out5.shape)

In [None]:
# separate out batch dimension post VAE encoding since dim=0 is of shape (B * n_t * n_h * n_w)
out6 = out5.reshape(out5.shape[0] // (n_t * n_h * n_w), -1, 32)
print(out6.shape)

In [None]:
out7 = out6.reshape(-1, 32)
print(out7.shape)

In [None]:
out7 = out7.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
print(out7.shape)

In [None]:
decoder1 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=(1, 2, 2), padding=1, output_padding=(0, 1, 1))
out8 = decoder1(out7)
print(out8.shape)

In [None]:
decoder2 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=(1, 1, 1))
out9 = decoder2(out8)
print(out9.shape)

In [None]:
decoder3 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=(1, 1, 1))
out10 = decoder2(out9)
print(out10.shape)

In [None]:
decoder3 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=(1, 1, 1))
out11 = decoder2(out10)
print(out11.shape)

# Training

In [None]:
# Full training run: build the model and data, attach wandb logging and
# callbacks, run the learning-rate finder, then fit.
config = VideoVAEConfig()

model = TubeletFSQVAE(config)
lit_model = LitTubeletFSQVAE(model, config)

steamboat_willie_data = SteamboatWillieDataModule(config)

wandb.init(project=config.project_name, config=config.to_dict())
wandb_logger = WandbLogger(project=config.project_name, log_model=True)
wandb_logger.watch(lit_model, log="all")

lr_monitor = LearningRateMonitor(logging_interval='step')

# The logged metric is named 'val/loss'. The previous '{val_loss:.2f}'
# placeholder matched no logged metric, so every checkpoint name showed 0.00.
# The metric is now referenced by its real name and the 'name=' prefixes are
# written out by hand, because auto-inserting a name containing '/' would put
# a path separator into the file name.
checkpoint_callback = ModelCheckpoint(
    dirpath=config.checkpoint_path,
    filename='model-epoch={epoch:02d}-val_loss={val/loss:.2f}',
    auto_insert_metric_name=False,
    every_n_epochs=5,
    save_top_k=config.save_top_k,
    monitor='val/loss',
    mode='min',
    save_last=True
)

# Define the EarlyStopping callback
early_stop_callback = EarlyStopping(
    monitor='val/loss',
    min_delta=0.0000000,
    patience=50,
    verbose=True,
    check_finite=True
)

trainer = pl.Trainer(
    max_epochs=config.max_epochs,
    devices=1,
    accelerator="gpu",
    precision="16-mixed",
    logger=wandb_logger,
    callbacks=[
        lr_monitor,
        early_stop_callback,
        checkpoint_callback
    ],
    log_every_n_steps=1,
    gradient_clip_val=1.0,
    # overfit_batches=1,
)

# Run the learning-rate finder before the real fit and plot its suggestion.
tuner = Tuner(trainer)
lr_result = tuner.lr_find(lit_model, datamodule=steamboat_willie_data, max_lr=1)
lr_result.plot(show=True, suggest=True)

trainer.fit(lit_model, steamboat_willie_data)
wandb.finish()

# Sweeps

In [None]:
# Only project_name is read from this config; each sweep trial builds its own.
config = VideoVAEConfig()

# Random search over the learning rate only, minimizing the validation loss
# that the Lightning module logs under 'val/loss'.
sweep_config = {
    'method': 'random',
    'metric': {
        'name': 'val/loss',
        'goal': 'minimize'
    },
    'parameters': {
        # keys here must match VideoVAEConfig field names, since
        # VideoVAEConfig.update ignores keys it does not already have
        'lr': {
            'min': 1e-4,
            'max': 2e-4,
            'distribution': 'log_uniform_values'
        },

    }
}

# Register the sweep; the returned id is what the agent is started with.
sweep_id = wandb.sweep(sweep_config, project=config.project_name)

In [None]:
def train():
    """Run one sweep trial.

    Starts a wandb run, overlays the run's hyperparameters on a fresh
    ``VideoVAEConfig``, and fits the model for at most 20 epochs with early
    stopping on ``val/loss``.
    """
    with wandb.init() as run:
        trial_config = VideoVAEConfig()
        # keys that are not existing config fields are ignored by update()
        trial_config.update(wandb.config)

        lit_module = LitTubeletFSQVAE(TubeletFSQVAE(trial_config), trial_config)
        datamodule = SteamboatWillieDataModule(trial_config)
        sweep_logger = WandbLogger(project=trial_config.project_name, log_model=False)

        stop_on_plateau = EarlyStopping(
            monitor='val/loss',
            min_delta=0.00,
            patience=10,
            verbose=True,
            check_finite=True
        )

        # The learning-rate finder is deliberately not run here: the sweep
        # itself supplies the learning rate for each trial.
        sweep_trainer = pl.Trainer(
            max_epochs=20,
            devices=1,
            accelerator="gpu",
            precision="16-mixed",
            logger=sweep_logger,
            callbacks=[stop_on_plateau]
        )

        sweep_trainer.fit(lit_module, datamodule)

In [None]:
# Start an agent that calls train() for each trial of the sweep.
# NOTE(review): no `count` is passed, so this call does not cap the number of
# trials; the commented line below is the capped variant.
# wandb.agent(sweep_id, train, count=25)
wandb.agent(sweep_id, train)