The goal here is to build an FSQ-VAE that learns latent vectors for the spatio-temporal "tubelets" described in the ViViT paper. Sequences of tubelet latent vectors are then modeled by a Transformer decoder.

![](https://i.imgur.com/9G7QTfV.png)

# Setup

In [None]:
!pip install -q wandb pytorch_lightning av imageio

In [1]:
import os
import gc
import math
import wandb
import imageio
import torch
import torchvision
import numpy as np
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML
from torch import nn, optim
from torch.nn import functional as F
from torchvision.datasets.video_utils import VideoClips
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Compose, Lambda, Resize, ToTensor, CenterCrop, Grayscale
import torchvision.transforms.functional as TF

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.tuner import Tuner
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

pl.seed_everything(42)

In [None]:
# Colab-only cell: mount Google Drive under /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Location of the source video on the mounted drive.
steamboat_willie_gdrive_path = '/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4'
# Shell magic: copy the `clips` directory from Drive into the current working directory.
!cp -r /content/drive/My\ Drive/SteamboatWillie/clips .

# Clip Dataset

In [None]:
class RandomHorizontalFlipVideo(torch.nn.Module):
    """Mirror a clip along its last dimension with probability ``p``.

    One random draw is made per call, so a flipped clip has every frame
    flipped together.

    Args:
        p: probability of flipping the clip.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        """Return ``x`` flipped along its last axis, or ``x`` unchanged.

        The original author expects ``x`` to be laid out as (C, T, H, W).
        """
        should_flip = torch.rand(1) < self.p
        return x.flip(-1) if should_flip else x


class SteamboatWillieDataset(Dataset):
    """Fixed-length grayscale clips backed by memory-mapped ``.bin`` files.

    On first use the source videos in ``config.paths`` are cut into
    non-overlapping clips of ``config.clip_length`` frames, preprocessed
    (grayscale, center crop, resize) and written as raw ``uint8`` files
    into the clip directory. Later instances reuse those files.

    Args:
        config: object providing ``paths``, ``clip_length``, ``img_size``,
            ``in_channels`` and the clip directory as ``dest_dir`` (or, as a
            fallback, ``clip_dest_dir``). An optional ``split_seed`` controls
            the train/val permutation (default 42).
        mode: ``'train'``, ``'val'`` or ``'full'``.
        train_split: fraction of clips assigned to the training split.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """

    def __init__(self, config, mode='train', train_split=0.8):
        super().__init__()
        if mode not in ('train', 'val', 'full'):
            raise ValueError(f"mode must be 'train', 'val' or 'full', got {mode!r}")

        self.config = config
        self.train_split = train_split
        self.mode = mode

        self.preprocess_transforms = Compose([
                Lambda(lambda x: x.permute(0, 3, 1, 2)),   # (T, H, W, C) to (T, C, H, W) for Grayscale
                Grayscale(num_output_channels=1),          # Convert to grayscale
                Lambda(lambda x: x.permute(1, 0, 2, 3)),   # (T, C, H, W) to (C, T, H, W) for Conv3d
                CenterCrop((480, 575)),                    # Center crop to remove vertical bars
                Resize((config.img_size, config.img_size))
        ])

        self.postprocess_transforms = Compose([
            Lambda(lambda x: x / 255.),
            Lambda(lambda x: x.view(self.config.in_channels, self.config.clip_length, self.config.img_size, self.config.img_size))
        ])

        if self.mode == 'train':
            self.postprocess_transforms.transforms.append(RandomHorizontalFlipVideo(p=0.5))

        # Prefer `dest_dir`; fall back to `clip_dest_dir`, which is the name
        # the config class in this notebook uses for the same directory.
        dest_dir = getattr(config, 'dest_dir', None)
        if dest_dir is None:
            dest_dir = config.clip_dest_dir

        # Reuse cached clips only when some actually exist; an empty or
        # missing directory triggers a rebuild instead of an empty dataset.
        clip_paths = self.build_existing_clip_paths(dest_dir) if os.path.isdir(dest_dir) else {}
        if not clip_paths:
            video_clips = VideoClips(
                config.paths,
                clip_length_in_frames=config.clip_length,
                frames_between_clips=config.clip_length
            )
            clip_paths = self.build_clip_paths(video_clips, self.preprocess_transforms, dest_dir)

        self.clips = self.build_clip_refs(clip_paths)

        # Sorted ids make the split independent of os.listdir() ordering and
        # tolerate gaps in the clip numbering.
        clip_ids = sorted(self.clips)

        if mode == 'full':
            self.clip_indices = clip_ids
        else:
            # The train and val datasets are separate instances. Each must
            # draw the SAME permutation, otherwise the two splits overlap.
            # A private, fixed-seed generator guarantees that regardless of
            # the global RNG state.
            split_rng = torch.Generator().manual_seed(getattr(config, 'split_seed', 42))
            order = torch.randperm(len(clip_ids), generator=split_rng).tolist()
            train_size = int(len(clip_ids) * train_split)

            chosen = order[:train_size] if mode == 'train' else order[train_size:]
            self.clip_indices = [clip_ids[pos] for pos in chosen]

    def build_clip_paths(self, video_clips, transforms, dest_dir):
        """
        Write every clip of ``video_clips`` to its own ``.bin`` file.

        Each clip is passed through ``transforms``, cast to ``uint8`` and
        stored as ``clip_<idx>.bin`` inside ``dest_dir`` (created if needed).

        returns dict of clip_idx -> file path
        """
        clip_paths = {}

        os.makedirs(dest_dir, exist_ok=True)

        for idx in tqdm(range(video_clips.num_clips()), desc='Creating clip .bin files'):
            # transform clips and write to mmap file
            clip, _, _, _ = video_clips.get_clip(idx)
            clip = transforms(clip)
            clip_np = clip.numpy().astype(np.uint8)

            mmapped_file_path = os.path.join(dest_dir, f'clip_{idx}.bin')
            fp = np.memmap(mmapped_file_path, dtype='uint8', mode='w+', shape=clip_np.shape)
            fp[:] = clip_np[:]
            fp.flush()
            del fp  # drop the writable mapping before the file is reopened read-only
            clip_paths[idx] = mmapped_file_path

        return clip_paths

    def build_existing_clip_paths(self, dest_dir):
        """
        Discover ``clip_<idx>.bin`` files that already exist in ``dest_dir``.

        returns dict of clip_idx -> file path
        """
        clips_paths = {}
        for filename in os.listdir(dest_dir):
            if filename.startswith('clip_') and filename.endswith('.bin'):
                idx = int(filename.split('_')[1].split('.')[0])
                clips_paths[idx] = os.path.join(dest_dir, filename)

        return clips_paths

    def build_clip_refs(self, clip_paths):
        """
        Open each ``.bin`` file as a read-only, flat ``uint8`` memmap.

        returns dict of clip_idx -> np.memmap for the respective file
        """
        clips = {}
        for idx, path in tqdm(clip_paths.items(), desc='Building clip refs'):
            clips[idx] = np.memmap(path, dtype='uint8', mode='r')

        return clips

    def __len__(self):
        return len(self.clip_indices)

    def __getitem__(self, idx):
        """Return clip ``idx`` of this split as a float tensor scaled by 1/255.

        The flat memmap is copied into a float32 tensor and reshaped to
        (in_channels, clip_length, img_size, img_size) by the postprocess
        transforms; training mode additionally applies the random flip.
        """
        clip = self.clips[self.clip_indices[idx]]
        return self.postprocess_transforms(torch.tensor(clip, dtype=torch.float32))


class SteamboatWillieDataModule(pl.LightningDataModule):
    """Lightning data module exposing the train and val clip datasets.

    Args:
        config: object providing ``batch_size`` and ``num_workers``; it is
            also handed unchanged to ``SteamboatWillieDataset``.
    """

    def __init__(self, config):
        super().__init__()
        self.batch_size = config.batch_size
        self.config = config

    def setup(self, stage=None):
        """Create the datasets for the ``'fit'`` stage (or when no stage is given)."""
        if stage not in ('fit', None):
            return

        self.train_dataset = SteamboatWillieDataset(self.config, mode='train')
        self.val_dataset = SteamboatWillieDataset(self.config, mode='val')

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
        )

    def val_dataloader(self):
        """Loader over the validation split, in fixed order."""
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
        )

# FSQ-VAE Model

In [None]:
class ResBlock3d(nn.Module):
    """Bottleneck residual block built from 3D convolutions.

    The residual branch is three (Conv3d -> BatchNorm3d -> ReLU) stages at a
    reduced width followed by a 1x1x1 Conv3d + BatchNorm3d that restores
    ``out_channels``. The skip path is the identity, or a 1x1x1 convolution
    when the channel count changes. Padding keeps the input's T/H/W sizes.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        singleton_time_dim: when True the 3x3x3 kernels become 1x3x3 with no
            temporal padding, i.e. frames are not mixed along time.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            singleton_time_dim=False
        ):
        super().__init__()
        if in_channels != out_channels:
            self.identity = nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1
            )
        else:
            self.identity = nn.Identity()

        # Bottleneck width is a quarter of the input width, floored at 1 so
        # inputs with fewer than 4 channels do not get a 0-channel branch.
        hidden_channels = max(1, in_channels // 4)

        kernel_size = (1, 3, 3) if singleton_time_dim else (3, 3, 3)
        padding = (0, 1, 1) if singleton_time_dim else (1, 1, 1)

        # The layer order (and therefore the Sequential indices used as
        # state_dict keys) matches the original hand-written layout.
        layers = []
        stage_in = in_channels
        for _ in range(3):
            layers += [
                nn.Conv3d(
                    in_channels=stage_in,
                    out_channels=hidden_channels,
                    kernel_size=kernel_size,
                    padding=padding
                ),
                nn.BatchNorm3d(hidden_channels),
                nn.ReLU(),
            ]
            stage_in = hidden_channels

        layers += [
            nn.Conv3d(
                in_channels=hidden_channels,
                out_channels=out_channels,
                kernel_size=1
            ),
            nn.BatchNorm3d(out_channels),
        ]

        self.block = nn.Sequential(*layers)
        self.res_act = nn.ReLU()

    def forward(self, x):
        """Return ``ReLU(block(x) + identity(x))``."""
        out = self.block(x) + self.identity(x)
        return self.res_act(out)




In [None]:
class Encoder(nn.Module):
    """Stack of strided 3D-conv stages, each followed by residual blocks.

    Every stage halves T, H and W with a stride-2 Conv3d (+ BatchNorm3d,
    ReLU) and then applies ``nblocks`` ``ResBlock3d`` modules. The final
    stage uses time-singleton residual blocks, and its last block projects
    to ``out_channels``.

    Args:
        in_channels: channels of the input tubelets.
        hidden_channels: width used inside the encoder.
        out_channels: channels of the produced latent.
        nlayers: number of downsampling stages.
        nblocks: residual blocks per stage.
    """

    def __init__(
            self,
            in_channels,
            hidden_channels,
            out_channels,
            nlayers,
            nblocks
        ):
        super().__init__()

        stages = []
        for layer_idx in range(nlayers):
            is_last_layer = layer_idx == nlayers - 1

            strided_conv = nn.Conv3d(
                in_channels=in_channels if layer_idx == 0 else hidden_channels,
                out_channels=hidden_channels,
                kernel_size=(3, 3, 3),
                stride=(2, 2, 2),
                padding=(1, 1, 1)
            )
            norm = nn.BatchNorm3d(hidden_channels)
            act = nn.ReLU()

            res_blocks = []
            for block_idx in range(nblocks):
                # only the very last residual block maps to the latent width
                closes_encoder = is_last_layer and block_idx == nblocks - 1
                res_blocks.append(
                    ResBlock3d(
                        in_channels=hidden_channels,
                        out_channels=out_channels if closes_encoder else hidden_channels,
                        singleton_time_dim=is_last_layer
                    )
                )

            stages.append(nn.Sequential(strided_conv, norm, act, nn.Sequential(*res_blocks)))

        self.downsample_blocks = nn.Sequential(*stages)

    def forward(self, x):
        """Encode a batch of tubelets.

        Per the original author's note, with this notebook's settings:
        ((B * n_t * n_h * n_w), C, t, p, p) -> ((B * n_t * n_h * n_w), latent_channels, 1, 4, 4)
        """
        return self.downsample_blocks(x)

In [None]:
class Decoder(torch.nn.Module):
    """Mirror of the encoder: residual blocks followed by 2x transposed convs.

    Each stage runs ``nblocks`` ``ResBlock3d`` modules and then doubles T, H
    and W with a stride-2 ConvTranspose3d. Every stage except the last adds
    BatchNorm3d + ReLU; the last one maps to ``out_channels`` and the result
    goes through a sigmoid.

    NOTE: ``self.res_blocks`` is constructed and registered but ``forward``
    never calls it. It is kept so parameter names and initialisation order
    stay the same as before.

    Args:
        in_channels: channels of the latent input.
        hidden_channels: width used inside the decoder.
        out_channels: channels of the reconstructed tubelets.
        nlayers: number of upsampling stages.
        nblocks: residual blocks per stage.
    """

    def __init__(
            self,
            in_channels,
            hidden_channels,
            out_channels,
            nlayers,
            nblocks
        ):
        super().__init__()

        # registered but unused by forward() -- see class docstring
        self.res_blocks = nn.Sequential(*[
            ResBlock3d(
                in_channels=in_channels if i==0 else hidden_channels,
                out_channels=hidden_channels,
            )
            for i in range(nblocks)
        ])

        stages = []
        for layer_idx in range(nlayers):
            is_first_layer = layer_idx == 0
            is_last_layer = layer_idx == nlayers - 1

            res_stack = nn.Sequential(*[
                ResBlock3d(
                    in_channels=in_channels if is_first_layer and block_idx == 0 else hidden_channels,
                    out_channels=hidden_channels,
                    singleton_time_dim=is_first_layer
                )
                for block_idx in range(nblocks)
            ])

            upsample = nn.ConvTranspose3d(
                in_channels=hidden_channels,
                out_channels=out_channels if is_last_layer else hidden_channels,
                kernel_size=(3, 3, 3),
                stride=(2, 2, 2),
                padding=(1, 1, 1),
                output_padding=(1, 1, 1)
            )

            # the output stage is left un-normalised; the sigmoid follows it
            if is_last_layer:
                norm, act = nn.Identity(), nn.Identity()
            else:
                norm, act = nn.BatchNorm3d(hidden_channels), nn.ReLU()

            stages.append(nn.Sequential(res_stack, upsample, norm, act))

        self.upsample_blocks = nn.Sequential(*stages)
        self.out_act = nn.Sigmoid()

    def forward(self, x):
        """Decode latents into tubelets with values in (0, 1).

        Per the original author's note, with this notebook's settings:
        ((B * n_t * n_h * n_w), latent_channels, 1, 4, 4) -> ((B * n_t * n_h * n_w), C, t, p, p)
        """
        return self.out_act(self.upsample_blocks(x))

In [None]:
class FSQ(nn.Module):
    """Finite Scalar Quantization of a latent vector.

    Each of the ``len(levels)`` latent channels is squashed with ``tanh``,
    rounded to one of ``levels[i]`` integer bins (straight-through gradient)
    and rescaled by ``levels[i] // 2``. Every code vector also maps to a
    single integer index through a mixed-radix basis.

    Args:
        levels: number of quantization bins per channel (list or tuple).
        eps: shrinks the tanh range slightly so the outermost bins are
            reached only asymptotically.
    """

    def __init__(self, levels, eps=1e-3):
        super().__init__()
        levels = list(levels)  # accept tuples as well as lists
        self.register_buffer('levels', torch.tensor(levels))
        self.register_buffer(
            'basis',
            torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
        )

        self.eps = eps
        self.codebook_size = torch.prod(self.levels)

        # self.register_buffer('implicit_codebook', self.idxs_to_code(torch.arange(self.codebook_size)))

    def round_ste(self, z):
        """Round ``z`` in the forward pass, pass the gradient through unchanged."""
        z_q = torch.round(z)
        return z + (z_q - z).detach()

    def quantize(self, z):
        """Quantize ``z`` whose LAST dimension indexes the latent channels.

        Returns a tensor of the same shape holding integer bins divided by
        ``levels // 2``.
        """
        # half_l is used to determine how to scale tanh; we
        # subtract 1 from the number of levels to account for 0
        # being a quantization bin and tanh being symmetric around 0
        half_l = (self.levels - 1) * (1 - self.eps) / 2

        # if a given level is even, it will result in a scale for tanh
        # which is halfway between integer values, so we offset
        # the tanh output down by 0.5 to line it with whole integers
        offset = torch.where(self.levels % 2 == 0, 0.5, 0.0)

        # if our level is even, we want to shift the tanh input to
        # ensure the 0 quantization bin is centered
        shift = torch.tan(offset / half_l)

        # once we have our shift and offset (in the case of an even level)
        # we can round to the nearest integer bin and allow for STE
        z_q = self.round_ste(torch.tanh(z + shift) * half_l - offset)

        # after quantization, we want to renormalize the quantized
        # values to be within the range expected by the model (ie. [-1, 1])
        half_width = self.levels // 2
        return z_q / half_width

    def scale_and_shift(self, z_q_normalized):
        """Map normalized codes to non-negative bin numbers."""
        half_width = self.levels // 2
        return (z_q_normalized * half_width) + half_width

    def scale_and_shift_inverse(self, z_q):
        """Inverse of ``scale_and_shift``: bin numbers back to normalized codes."""
        half_width = self.levels // 2
        return (z_q - half_width) / half_width

    def code_to_idxs(self, z_q):
        """Convert codes of shape (..., C) to int32 codebook indices of shape (...)."""
        z_q = self.scale_and_shift(z_q)
        # Round before the integer cast: the values come from float
        # arithmetic, and a plain cast would truncate e.g. 2.9999 down to 2.
        return (z_q * self.basis).sum(dim=-1).round().to(torch.int32)

    def codes_to_idxs(self, z_q):
        """Plural-spelling alias of ``code_to_idxs``.

        ``LitTransformerDecoder`` in this notebook calls the method under
        this name.
        """
        return self.code_to_idxs(z_q)

    def idxs_to_code(self, idxs):
        """Convert integer indices of shape (...) to codes of shape (..., C)."""
        idxs = idxs.unsqueeze(-1)
        codes_not_centered = (idxs // self.basis) % self.levels
        return self.scale_and_shift_inverse(codes_not_centered)

    def forward(self, z):
        """Quantize a latent tensor.

        Supported layouts:
          * ``(B, C)``                  -- channels last (vector sequence)
          * ``(B, C, d1, d2, ...)``     -- channels first with two or more
            trailing dims, e.g. images (B, C, H, W) or video (B, C, T, H, W)

        Returns:
            ``{'z_q': tensor}`` with the same shape as ``z``.

        Raises:
            ValueError: for any other rank (a 3-D input is ambiguous between
                channels-first and channels-last), or when the channel
                dimension does not match ``len(levels)``.
        """
        n_channels = self.levels.numel()

        if z.dim() == 2:
            channel_dim = -1
        elif z.dim() >= 4:
            channel_dim = 1
        else:
            raise ValueError(
                f'FSQ expects a (B, C) or (B, C, ...) tensor with at least 4 dims, got shape {tuple(z.shape)}'
            )

        if z.shape[channel_dim] != n_channels:
            raise ValueError(
                f'expected {n_channels} latent channels, got {z.shape[channel_dim]}'
            )

        if channel_dim == -1:
            z_q = self.quantize(z)
        else:
            # quantize() broadcasts the per-channel levels over the last
            # dim, so move channels there and back again afterwards
            z_q = self.quantize(z.movedim(1, -1)).movedim(-1, 1).contiguous()

        return {'z_q': z_q}

In [None]:
class TubeletFSQVAE(nn.Module):
    """FSQ-VAE operating on spatio-temporal tubelets of a video.

    A video (B, C, T, H, W) is cut into non-overlapping tubelets of
    ``config.t`` frames and ``config.p`` x ``config.p`` pixels. Each tubelet
    is encoded, quantized with FSQ and decoded independently.

    Args:
        config: object providing ``in_channels``, ``hidden_channels``,
            ``latent_channels``, ``nlayers``, ``nblocks``, ``levels``,
            ``t`` and ``p``.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(
            in_channels     = config.in_channels,
            hidden_channels = config.hidden_channels,
            out_channels    = config.latent_channels,
            nlayers         = config.nlayers,
            nblocks         = config.nblocks
        )
        self.decoder = Decoder(
            in_channels     = config.latent_channels,
            hidden_channels = config.hidden_channels,
            out_channels    = config.in_channels,
            nlayers         = config.nlayers,
            nblocks         = config.nblocks
        )

        self.quantizer = FSQ(config.levels)

        self.config = config

    def encode(self, inputs):
        """Split a video (B, C, T, H, W) into tubelets and encode them."""
        inputs = self.extract_tubelets(inputs, self.config.t, self.config.p)
        return self.encoder(inputs)

    def quantize(self, z):
        """Run the FSQ quantizer; returns its dict (contains ``'z_q'``)."""
        return self.quantizer(z)

    def decode(self, z_q):
        """Decode quantized latents back into tubelets."""
        return self.decoder(z_q)

    def extract_tubelets(self, x, t, p):
        '''
        extract tubelet sequence of shape ((B * n_t * n_h * n_w), C, t, p, p)
        from a video tensor of shape (B, C, T, H, W)
        where n_t = T // t, n_h = H // p, and n_w = W // p
        '''
        assert len(x.shape) == 5, 'x.shape must be (B, C, T, H, W)'

        B, C, T, H, W = x.shape

        assert T % t == 0, 't must divide T (x.shape[2]) evenly'
        assert H % p == 0, 'p must divide H (x.shape[3]) evenly'
        assert W % p == 0, 'p must divide W (x.shape[4]) evenly'

        n_t, n_h, n_w = T // t, H // p, W // p

        # (B, C, T, H, W) -> (B, C, n_t, t, n_h, p, n_w, p)
        x = x.reshape(B, C, n_t, t, n_h, p, n_w, p)
        # (B, C, n_t, t, n_h, p, n_w, p) -> (B, n_t, n_h, n_w, C, t, p, p)
        x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
        # (B, n_t, n_h, n_w, C, t, p, p) -> ((B * n_t * n_h * n_w), C, t, p, p)
        x = x.reshape(-1, C, t, p, p)

        return x

    def assemble_tubelets(self, x, B, C, T, H, W, t, p):
        '''
        reassemble tubelet sequence of shape ((B * n_t * n_h * n_w), C, t, p, p)
        into a video tensor of shape (B, C, T, H, W)

        exact inverse of extract_tubelets for the same t and p
        '''
        assert len(x.shape) == 5, 'x.shape must be ((B * n_t * n_h * n_w), C, t, p, p)'

        assert T % t == 0, 't must divide T (x.shape[2]) evenly'
        assert H % p == 0, 'p must divide H (x.shape[3]) evenly'
        assert W % p == 0, 'p must divide W (x.shape[4]) evenly'

        n_t, n_h, n_w = T // t, H // p, W // p

        # ((B * n_t * n_h * n_w), C, t, p, p) -> (B, n_t, n_h, n_w, C, t, p, p)
        x = x.reshape(B, n_t, n_h, n_w, C, t, p, p)
        # (B, n_t, n_h, n_w, C, t, p, p) -> (B, C, n_t, t, n_h, p, n_w, p)
        x = x.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous()
        # (B, C, n_t, t, n_h, p, n_w, p) -> (B, C, T, H, W)
        x = x.reshape(B, C, T, H, W)

        return x

    def latent_img_to_seq(self, latent_img):
        '''
        convert latent image from encoder of shape (B, latent_channels, T', H', W')
        (e.g. (B, latent_channels, 1, 4, 4)) into a sequence of latent vectors
        of shape ((B * T' * H' * W'), latent_channels)
        '''
        return latent_img.permute(0, 2, 3, 4, 1).contiguous().reshape(-1, self.config.latent_channels)

    def latent_seq_to_img(self, latent_seq, latent_shape=(1, 4, 4)):
        '''
        convert latent sequence of shape ((B * T' * H' * W'), latent_channels)
        into latent images of shape (B, latent_channels, T', H', W')

        latent_shape is the (T', H', W') grid of one latent image; the
        default (1, 4, 4) is the grid this method previously hard-coded
        '''
        grid = tuple(latent_shape)
        latent_img = latent_seq.reshape(-1, *grid, self.config.latent_channels)
        # channels last -> channels first
        return latent_img.movedim(-1, 1).contiguous()


    def loss(self, x_hat, x, quantized):
        """Mean-squared reconstruction error, merged with the quantizer outputs."""
        MSE = F.mse_loss(x_hat, x)

        return {
            'loss': MSE,
            **quantized
        }

    def forward(self, x):
        """Encode, quantize and decode a video batch of shape (B, C, T, H, W).

        Returns a dict with the reconstructed tubelets ``'x_hat'``, the
        encoder output ``'z'``, the quantized latents ``'z_q'`` and the
        scalar ``'loss'``. ``'x_hat'`` is in tubelet layout; use
        ``assemble_tubelets`` to get a video back.
        """
        # Cut the video once and reuse the tubelets as both encoder input
        # and reconstruction target (previously they were extracted twice).
        # ((B * n_t * n_h * n_w), C, t, p, p)
        tubelets = self.extract_tubelets(x, self.config.t, self.config.p)

        z = self.encoder(tubelets)
        quantized = self.quantize(z)
        x_hat = self.decode(quantized['z_q'])

        losses = self.loss(x_hat, tubelets, quantized)

        return {'x_hat': x_hat, 'z': z, 'z_q': quantized['z_q'], **losses}

# Transformer Model

In [None]:
class ALiBiSelfAttention(nn.Module):
    """Causal multi-head self-attention with ALiBi biases and a sliding window.

    Two execution paths produce the same kind of attention:
      * ``use_flash_attn=True`` delegates to ``flash_attn_func``.
      * ``use_flash_attn=False`` computes attention explicitly, adding the
        ALiBi bias ``slope * (j - i)`` for key ``j`` <= query ``i`` and
        hiding keys more than ``window_size`` positions behind the query.

    Args:
        emb_dim: size of the input/output embedding.
        qkv_dim: total width of the query/key/value projections
            (must be divisible by ``nheads``).
        nheads: number of attention heads.
        ctx_size: maximum sequence length (sizes the fallback causal mask).
        window_size: how many past positions a query may attend to; a
            negative value disables the window.
        dropout: attention dropout probability.
        use_flash_attn: select the flash-attention path.
    """

    def __init__(self, emb_dim, qkv_dim, nheads, ctx_size, window_size, dropout=0.0, use_flash_attn=True):
        super().__init__()
        assert qkv_dim % nheads == 0
        # TODO: sweep over attn_emb_dim values

        # project latent embedding to higher qkv dim to increase attention capacity
        self.W_Q = nn.Linear(emb_dim, qkv_dim, bias=False)
        self.W_K = nn.Linear(emb_dim, qkv_dim, bias=False)
        self.W_V = nn.Linear(emb_dim, qkv_dim, bias=False)

        # project latent embedding back to emb_dim after attention has been computed
        self.W_O = nn.Linear(qkv_dim, emb_dim, bias=False)

        if not use_flash_attn:
            # lower-triangular causal mask, only needed by the explicit path
            self.register_buffer(
                "mask",
                torch.tril(torch.ones((ctx_size, ctx_size))).reshape(
                    1, 1, ctx_size, ctx_size
                ),
            )

        self.qkv_dim = qkv_dim
        self.head_dim = qkv_dim // nheads
        self.nheads = nheads
        self.window_size = window_size

        self.dropout_p = dropout
        self.dropout = nn.Dropout(dropout)

        self.use_flash_attn = use_flash_attn

    def forward(self, x, slopes):
        """Attend over ``x`` of shape (B, T, emb_dim).

        Args:
            x: input sequence, (B, T, emb_dim).
            slopes: per-head ALiBi slopes, ``nheads`` values.

        Returns:
            Tensor of shape (B, T, emb_dim).

        Raises:
            ValueError: on the explicit path when T exceeds ``ctx_size``.
        """
        B, T, D = x.size()

        # (B, T, D) -> (B, T, qkv_dim) -> (B, T, H, H_D)
        Q = self.W_Q(x).reshape(B, T, self.nheads, self.head_dim)
        K = self.W_K(x).reshape(B, T, self.nheads, self.head_dim)
        V = self.W_V(x).reshape(B, T, self.nheads, self.head_dim)

        if self.use_flash_attn:
            # NOTE(review): flash_attn_func is not imported in the visible
            # part of this notebook; it has to be in scope (presumably
            # `from flash_attn import flash_attn_func`) -- confirm.
            # It takes and returns tensors laid out as (B, T, H, H_D), so
            # the result must NOT be transposed before merging the heads.
            out = flash_attn_func(
                Q, K, V,
                dropout_p=self.dropout_p if self.training else 0.0,
                softmax_scale=None,
                causal=True,
                window_size=(self.window_size, 0),
                alibi_slopes=slopes.to(torch.float32),
                deterministic=False
            )
        else:
            if T > self.mask.size(-1):
                raise ValueError(
                    f'sequence length {T} exceeds ctx_size {self.mask.size(-1)}'
                )

            # (B, T, H, H_D) -> (B, H, T, H_D)
            Q = Q.transpose(1, 2)
            K = K.transpose(1, 2)
            V = V.transpose(1, 2)

            # (B, H, T, H_D) @ (B, H, H_D, T) -> (B, H, T, T)
            attn = (Q @ K.transpose(-2, -1)) / (1.0 * math.sqrt(self.head_dim))

            # rel[i, j] = j - i: zero on the diagonal, negative for past keys
            positions = torch.arange(T, device=x.device)
            rel = positions[None, :] - positions[:, None]

            # ALiBi: penalise each past key linearly in its distance, with
            # one slope per head -> (1, H, T, T)
            bias = slopes.to(attn.dtype).reshape(1, -1, 1, 1) * rel.clamp(max=0).to(attn.dtype)
            attn = attn + bias

            # causal mask, optionally narrowed to the sliding window
            visible = self.mask[:, :, :T, :T] != 0
            if self.window_size >= 0:
                visible = visible & (rel >= -self.window_size)

            attn = attn.masked_fill(~visible, float('-inf'))
            attn = F.softmax(attn, dim=-1)

            attn = self.dropout(attn)

            # (B, H, T, T) @ (B, H, T, H_D) -> (B, H, T, H_D) -> (B, T, H, H_D)
            out = (attn @ V).transpose(1, 2)

        # both paths arrive here as (B, T, H, H_D); merge heads -> (B, T, qkv_dim)
        out = out.reshape(B, T, self.qkv_dim)

        # (B, T, qkv_dim) -> (B, T, D)
        out = self.W_O(out)

        return out


class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Args:
        emb_dim: input and output width.
        dropout: dropout probability applied to the output.
        fan_out: multiplier giving the hidden width ``fan_out * emb_dim``.
    """

    def __init__(self, emb_dim, dropout=0.0, fan_out=100):
        super().__init__()
        # TODO: sweep over fan_out
        hidden_dim = fan_out * emb_dim
        self.fc1 = nn.Linear(emb_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward network along the last dimension of ``x``."""
        hidden = self.act(self.fc1(x))
        return self.dropout(self.fc2(hidden))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        emb_dim: embedding width of the residual stream.
        qkv_dim: total query/key/value width inside attention.
        nheads: number of attention heads.
        ctx_size: maximum sequence length.
        window_size: sliding-window size passed to attention.
        dropout: dropout probability for attention and MLP.
        fan_out: hidden-width multiplier of the MLP.
        use_flash_attn: whether attention uses the flash-attention path.
    """

    def __init__(self, emb_dim, qkv_dim, nheads, ctx_size, window_size, dropout=0.0, fan_out=100, use_flash_attn=True):
        super().__init__()
        self.ln_1 = nn.LayerNorm(emb_dim)
        self.attn = ALiBiSelfAttention(
            emb_dim,
            qkv_dim,
            nheads,
            ctx_size,
            window_size,
            dropout,
            use_flash_attn,
        )
        self.ln_2 = nn.LayerNorm(emb_dim)
        self.mlp = MLP(emb_dim, dropout, fan_out)

    def forward(self, x, slopes):
        """Run one block on ``x``; ``slopes`` are forwarded to attention."""
        attended = self.attn(self.ln_1(x), slopes)
        x = x + attended

        transformed = self.mlp(self.ln_2(x))
        return x + transformed


class TransformerDecoder(nn.Module):
    """Decoder-only transformer over sequences of latent vectors.

    The input is used directly as the residual stream (there is no
    embedding layer in this class), passed through ``config.ntlayers``
    blocks, layer-normalised and projected to ``config.vocab_size`` logits.

    Args:
        config: object providing ``emb_dim``, ``qkv_dim``, ``nheads``,
            ``ctx_size``, ``window_size``, ``dropout``, ``fan_out``,
            ``use_flash_attn``, ``ntlayers`` and ``vocab_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(
                config.emb_dim,
                config.qkv_dim,
                config.nheads,
                config.ctx_size,
                config.window_size,
                config.dropout,
                config.fan_out,
                config.use_flash_attn
            ) for _ in range(config.ntlayers)
        ])
        self.pred_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        self.ln_f = nn.LayerNorm(config.emb_dim)

        # per-head ALiBi slopes, shape (nheads,)
        self.register_buffer("m", self.get_alibi_slope(config.nheads))
        self.window_size = config.window_size

        self.use_flash_attn = config.use_flash_attn

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Map a sequence (B, T, emb_dim) to logits (B, T, vocab_size).

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(f'expected input of shape (B, T, D), got {tuple(x.shape)}')

        # The non-flash branch used to build a (window_size x window_size)
        # bias matrix here on every call. The result was never used, and
        # multiplying the (nheads,) slopes with that matrix fails to
        # broadcast. The blocks only ever receive the slopes themselves.
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x, self.m)

        logits = self.pred_head(self.ln_f(x))

        return logits

    def get_relative_positions(self, seq_len: int) -> torch.Tensor:
        """Return a (seq_len, seq_len) matrix with entry [i, j] = min(j - i, 0)."""
        x = torch.arange(seq_len)[None, :]
        y = torch.arange(seq_len)[:, None]
        return (x - y).clamp_max_(0)


    def get_alibi_slope(self, num_heads):
        """Return the geometric slope sequence ``r, r**2, ..., r**num_heads``.

        ``r = 2 ** (-8 / num_heads)``; e.g. for 8 heads: 1/2, 1/4, ..., 1/256.
        """
        x = (2 ** 8) ** (1 / num_heads)
        return (
            torch.tensor([1 / x ** (i + 1) for i in range(num_heads)])
            # .unsqueeze(-1)
            # .unsqueeze(-1)
        )

In [None]:
def extract_tubelets(x, t=8, p=16):
    '''
    Cut a video tensor (B, C, T, H, W) into non-overlapping tubelets.

    Each tubelet spans t frames and a p x p pixel patch. The result has
    shape ((B * n_t * n_h * n_w), C, t, p, p) with n_t = T // t,
    n_h = H // p and n_w = W // p, ordered by batch, then time block,
    then patch row, then patch column.
    '''
    assert len(x.shape) == 5, 'vid.shape must be (B, C, T, H, W)'

    batch, channels, frames, height, width = x.shape

    assert frames % t == 0, 't must divide T (vid.shape[2]) evenly'
    assert height % p == 0, 'p must divide H (vid.shape[3]) evenly'
    assert width % p == 0, 'p must divide W (vid.shape[4]) evenly'

    # split T, H and W each into (number of blocks, block size)
    blocked = x.reshape(batch, channels, frames // t, t, height // p, p, width // p, p)

    # bring the three block counts next to the batch dim:
    # (B, C, n_t, t, n_h, p, n_w, p) -> (B, n_t, n_h, n_w, C, t, p, p)
    grouped = blocked.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()

    # fold batch and block counts into a single leading dimension
    return grouped.reshape(-1, channels, t, p, p)

# Lightning Module

In [None]:
class LitTubeletFSQVAE(pl.LightningModule):
    """Lightning wrapper that trains a tubelet FSQ-VAE on reconstruction loss.

    Args:
        model: module whose forward returns a dict with ``'loss'`` and
            ``'x_hat'`` and which provides ``assemble_tubelets``.
        config: object providing ``lr``, ``beta1``, ``beta2``,
            ``weight_decay``, ``use_lr_schedule`` and ``max_epochs``.
    """

    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.lr = config.lr

        # NOTE(review): the module is normally not attached to a Trainer yet
        # while __init__ runs, so this branch is presumably never taken --
        # confirm whether the config should be sent to the logger elsewhere.
        if self.logger:
            self.logger.experiment.config.update(config)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Log and return the reconstruction loss for one training batch."""
        x = batch
        out = self(x)

        self.log('train/loss', out['loss'], prog_bar=True)

        return out['loss']

    def validation_step(self, batch, batch_idx):
        """Log the validation loss; for the first batch also log sample clips."""
        x = batch
        out = self(x)

        self.log('val/loss', out['loss'], prog_bar=True)

        if batch_idx == 0:
            # ((B * n_t * n_h * n_w), C, t, p, p) -> (B, C, T, H, W)
            t, p = out['x_hat'].shape[-3], out['x_hat'].shape[-2]
            out['x_hat'] = self.model.assemble_tubelets(out['x_hat'], *x.shape, t, p)
            self.log_val_clips(x, out)

        return out['loss']

    def configure_optimizers(self):
        """AdamW, optionally with cosine annealing over ``config.max_epochs``."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            betas=(self.config.beta1, self.config.beta2),
            weight_decay=self.config.weight_decay
        )

        if self.config.use_lr_schedule:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config.max_epochs)
            return [optimizer], [scheduler]

        return optimizer

    @staticmethod
    def _clip_to_uint8_frames(clip):
        """Convert one (C, T, H, W) clip tensor to uint8 frames for a GIF.

        The clip is min-max normalised to [0, 255]. A constant clip (zero
        value range) becomes all-zero frames instead of dividing by zero.
        Single-channel clips are returned as (T, H, W), others as
        (T, H, W, C).
        """
        # detach/float make this safe for tensors that still carry
        # gradients or use a dtype numpy cannot represent
        frames = clip.detach().float().permute(1, 2, 3, 0).cpu().numpy()

        lowest = frames.min()
        value_range = frames.max() - lowest
        if value_range > 0:
            frames = (frames - lowest) / value_range
        else:
            frames = np.zeros_like(frames)

        frames = (frames * 255).astype(np.uint8)

        # grayscale videos need to be of shape (T, H, W)
        if frames.shape[-1] == 1:
            frames = frames.squeeze(-1)

        return frames

    def log_val_clips(self, x, out, num_clips=2):
        """Log up to ``num_clips`` original/reconstructed clip pairs as GIFs.

        Args:
            x: original clips, (B, C, T, H, W).
            out: model output dict whose ``'x_hat'`` is already assembled
                to (B, C, T, H, W).
            num_clips: maximum number of clips to log.

        Does nothing when no logger is attached.
        """
        if self.logger is None:
            return

        n_clips = min(x.size(0), num_clips)

        for i in range(n_clips):
            original_clip_np = self._clip_to_uint8_frames(x[i])
            reconstructed_clip_np = self._clip_to_uint8_frames(out['x_hat'][i])

            # create GIFs for the original and reconstructed clips
            original_gif_path = f'/tmp/original_clip_{i}.gif'
            reconstructed_gif_path = f'/tmp/reconstructed_clip_{i}.gif'
            imageio.mimsave(original_gif_path, original_clip_np, fps=10)
            imageio.mimsave(reconstructed_gif_path, reconstructed_clip_np, fps=10)

            # log the GIFs to wandb
            self.logger.experiment.log({
                f"val/original_clip_{i}": wandb.Video(original_gif_path, fps=10, format="gif", caption="Original"),
                f"val/reconstructed_clip_{i}": wandb.Video(reconstructed_gif_path, fps=10, format="gif", caption="Reconstructed")
            })

In [None]:
class LitTransformerDecoder(pl.LightningModule):
    """Lightning wrapper that trains the transformer on next-code prediction.

    Batches are ``(x, y)`` pairs; ``y`` holds FSQ code vectors which are
    converted to codebook indices and used as cross-entropy targets for the
    logits produced from ``x``.

    Args:
        model: module mapping ``x`` to logits whose last dim is the vocabulary.
        config: object providing ``lr``, ``levels``, ``beta1``, ``beta2``,
            ``weight_decay``, ``use_lr_schedule`` and ``max_epochs``.
    """

    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.lr = config.lr

        self.fsq = FSQ(config.levels) # need codes -> idxs

        if self.logger:
            self.logger.experiment.config.update(config)

    def forward(self, x):
        return self.model(x)

    def _loss_for_batch(self, batch):
        """Return ``(loss, x)`` for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)

        # FSQ names this method `code_to_idxs` (singular "code"); the
        # previous call used a plural spelling that FSQ does not define.
        y_idxs = self.fsq.code_to_idxs(y).to(torch.long)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y_idxs.reshape(-1))

        return loss, x

    def training_step(self, batch, batch_idx):
        """Log and return the cross-entropy loss for one training batch."""
        loss, _ = self._loss_for_batch(batch)

        self.log('train/loss', loss, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss; the first batch is passed to ``log_val_clips``."""
        loss, x = self._loss_for_batch(batch)

        self.log('val/loss', loss, prog_bar=True)

        if batch_idx == 0:
            self.log_val_clips(x)

        return loss

    def configure_optimizers(self):
        """AdamW, optionally with cosine annealing over ``config.max_epochs``."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            betas=(self.config.beta1, self.config.beta2),
            weight_decay=self.config.weight_decay
        )

        if self.config.use_lr_schedule:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config.max_epochs)
            return [optimizer], [scheduler]

        return optimizer

    def log_val_clips(self, lat_seq):
        """Placeholder hook for logging validation samples; currently a no-op."""
        pass

# Config

In [None]:
def get_num_downsample_layers(img_size, target_size=2):
    """
    Count the stride-2 layers needed to shrink img_size to target_size.

    Each strided layer halves the spatial size, so the answer is
    ceil(log2(img_size / target_size)).

    Raises ValueError when img_size is smaller than target_size.
    """
    too_small = img_size < target_size
    if too_small:
        raise ValueError(f"Image size must be at least {target_size}x{target_size}.")

    # how many times larger the input is than the desired output
    shrink_factor = img_size / target_size
    return math.ceil(math.log2(shrink_factor))

class VideoGenConfig:
    """
    Flat container of hyperparameters for the tubelet FSQ-VAE and the latent
    transformer.

    ``levels`` and ``vocab_size`` are derived from ``level`` and
    ``latent_channels``; ``update`` keeps them in sync when either changes.

    NOTE(review): ``LatentSequenceDataset`` reads ``config.chunk_size`` and
    ``config.latents_per_clip``, neither of which is defined here — confirm
    where those values are meant to come from.
    """

    def __init__(self):
        self.project_name = 'Steamboat Willie Latent Model'#'Tubelet FSQ-VAE Steamboat Willie'
        # dataset properties
        self.paths = ['/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4']
        self.clip_dest_dir = './clips/'
        self.latent_seqs_dest_dir = './latent_seqs/'
        # model checkpoints
        self.checkpoint_path = "./checkpoints"
        self.fsq_vae_checkpoint_path = "./fsq-vae-model.ckpt"
        self.save_top_k = 1
        # FSQ-VAE training
        self.train_split = 0.8
        self.batch_size = 10
        self.max_epochs = 1000
        self.num_workers = 2
        # FSQ-VAE optimizer
        self.lr = 5e-4
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.weight_decay = 0.0 # 1e-2
        self.use_wd_schedule = False
        self.use_lr_schedule = True
        # FSQ-VAE input properties
        self.clip_length = 16
        self.img_size = 256
        self.tubelet_size = 16
        self.in_channels = 1
        # quantization
        self.quant_mode = 'fsq' # 'vq'
        self.latent_channels = 5
        self.codebook_size = 512
        self.commit_loss_beta = 0.25
        self.track_codebook = True
        self.use_ema = True
        self.ema_gamma = 0.99
        self.level = 11
        self.levels = [self.level for _ in range(self.latent_channels)]
        # encoder/decoder
        self.start_channels = 32
        self.nlayers = get_num_downsample_layers(self.tubelet_size, 4)
        self.nblocks = 4
        self.hidden_channels = 256
        # self.nlayers = 3
        # tubelet
        self.t = 4
        self.p = 16
        # transformer
        self.emb_dim=5
        self.qkv_dim=2048
        self.nheads=8
        self.ntlayers=12
        # plain int: a trailing comma here would silently store the tuple (16383,)
        self.ctx_size=16383
        self.window_size=8192#16383,
        self.vocab_size=math.prod(self.levels)
        self.dropout=0.0
        self.fan_out=410
        self.use_flash_attn=True
        self.t_batch_size=1
        self.t_num_workers=2
        self.grad_acc_steps=2

    def update(self, updates):
        """
        Overwrite existing attributes from a mapping (anything with ``.items()``).

        Keys that are not already attributes are ignored. Changing ``level`` or
        ``latent_channels`` also rebuilds the derived ``levels`` list and the
        ``vocab_size`` computed from it, so the two never disagree.
        """
        for key, value in updates.items():
            if hasattr(self, key):
                setattr(self, key, value)
                if key == 'level' or key == 'latent_channels':
                    self.levels = [self.level for _ in range(self.latent_channels)]
                    # vocab_size is the product of the levels, so it must follow them
                    self.vocab_size = math.prod(self.levels)

    def to_dict(self):
        """Return a shallow copy of all attributes as a plain dict."""
        return dict(self.__dict__)

# Component Testing

In [None]:
# Default config, a freshly constructed tubelet FSQ-VAE, and the clip dataset in 'full' mode.
config = VideoGenConfig()
model = TubeletFSQVAE(config)
dataset = SteamboatWillieDataset(config, mode='full')

In [None]:
# Load the saved checkpoint; its 'state_dict' entry is consumed in the next cell.
checkpoint = torch.load('fsq-vae-model.ckpt')

In [None]:
# Drop every 'model.' occurrence from the checkpoint keys so they match the bare model's names.
new_state_dict = {
    name.replace('model.', ''): tensor
    for name, tensor in checkpoint['state_dict'].items()
}

model.load_state_dict(new_state_dict)

In [None]:
# Run every clip through the model on the GPU and keep, per clip index:
# the original clip, its quantized latents, and the reassembled reconstruction.
clips = {}
q_latents = {}
reconstructions = {}
model.eval()
model = model.to('cuda:0')
with torch.no_grad():
    for i, clip in tqdm(enumerate(dataset)):
        # collect original clips to display
        clips[i] = clip
        clip = clip.to('cuda:0')
        
        # forward pass on a batch of one (unsqueeze adds the leading batch axis)
        out = model(clip.unsqueeze(0))

        # collect latents
        q_latents[i] = out['z_q'].cpu().numpy()
        
        # collect reconstructions: stitch the decoded tubelets back into a clip,
        # passing the batched clip shape plus two trailing dims of x_hat
        x_hat = model.assemble_tubelets(out['x_hat'], *clip.unsqueeze(0).shape, out['x_hat'].shape[-3], out['x_hat'].shape[-2])
        reconstructions[i] = x_hat.cpu().numpy()

In [None]:
# Inspect the shape of the first stored reconstruction.
reconstructions[0].shape

In [None]:
# Animate the first reconstruction; the second [0] presumably drops the batch axis of size one.
display_clip(torch.tensor(reconstructions[0][0]))

In [None]:
# Pull one batch; train_loader is expected to have been defined by an earlier cell.
batch = next(iter(train_loader))

In [None]:
# Encode the batch on the GPU without tracking gradients and report the latent shape.
model.eval()
with torch.no_grad():
    batch = batch.to('cuda:0')
    model = model.to('cuda:0')
    
    latents = model.encode(batch)
    print(latents.shape)

In [None]:
# Move dim 1 to the end of the 5-D latent tensor so the last axis can be flattened against next.
latents = latents.permute(0, 2, 3, 4, 1).contiguous()
print(latents.shape)

In [None]:
# Collapse every leading axis, leaving a 2-D (num_vectors, last_dim) table.
latents = latents.reshape(-1, latents.shape[-1])
print(latents.shape)

In [None]:
# Split the batch into tubelets.
# NOTE(review): the extract_tubelets defined under "Display Clips" returns a
# (tubelets, reassembled) pair, in which case `.shape` below would fail —
# confirm which definition this cell expects.
tublets = extract_tubelets(train_batch)
print(tublets.shape)

In [None]:
# Instantiate encoder, quantizer and decoder separately to check shapes stage by stage.
enc=Encoder(in_channels=config.in_channels,
            hidden_channels=config.hidden_channels,
            nlayers=config.nlayers,
            nblocks=config.nblocks,
            out_channels=config.latent_channels)

# NOTE(review): the quantizer is given latent_channels*4 levels while the encoder
# emits latent_channels channels — confirm this mismatch is intentional.
fsq=FSQ(levels=[config.level for _ in range(config.latent_channels*4)])

dec=Decoder(in_channels=config.latent_channels,
            hidden_channels=config.hidden_channels,
            nlayers=config.nlayers,
            nblocks=config.nblocks,
            out_channels=config.in_channels)

# move all three modules and the tubelet tensor onto the first GPU
enc, fsq, dec, tublets = enc.to('cuda:0'), fsq.to('cuda:0'), dec.to('cuda:0'), tublets.to('cuda:0')

In [None]:
# Stage 1: encode the tubelets and check the output shape.
tubelets_enc = enc(tublets)
print(tubelets_enc.shape)

In [None]:
# Stage 2: quantize; the result is indexed by key, with the quantized tensor under 'z_q'.
tubelets_fsq = fsq(tubelets_enc)
print(tubelets_fsq['z_q'].shape)

In [None]:
# Stage 3: decode the quantized latents and check the output shape.
tubelets_dec = dec(tubelets_fsq['z_q'])
print(tubelets_dec.shape)

# Display Clips

In [None]:
def extract_tubelets(x, t=8, p=16):
    '''
    Cut a video tensor of shape (B, C, T, H, W) into non-overlapping tubelets
    and, as a round-trip check, stitch them back together.

    Returns a pair:
      - tubelets of shape ((B * n_t * n_h * n_w), C, t, p, p)
      - the video rebuilt from those tubelets, shape (B, C, T, H, W)
    where n_t = T // t, n_h = H // p, and n_w = W // p
    '''
    assert x.ndim == 5, 'vid.shape must be (B, C, T, H, W)'

    B, C, T, H, W = x.shape

    assert T % t == 0, 't must divide T (vid.shape[2]) evenly'
    assert H % p == 0, 'p must divide H (vid.shape[3]) evenly'
    assert W % p == 0, 'p must divide W (vid.shape[4]) evenly'

    # number of tubelets along time, height and width
    n_t = T // t
    n_h = H // p
    n_w = W // p

    # forward: split each of T, H, W into (count, size), bring the counts next
    # to the batch axis, then merge them into it
    tubelets = (
        x.reshape(B, C, n_t, t, n_h, p, n_w, p)
        .permute(0, 2, 4, 6, 1, 3, 5, 7)
        .contiguous()
        .reshape(-1, C, t, p, p)
    )

    # inverse: undo the three steps above in reverse order
    restored = (
        tubelets.reshape(B, n_t, n_h, n_w, C, t, p, p)
        .permute(0, 4, 1, 5, 2, 6, 3, 7)
        .contiguous()
        .reshape(B, C, T, H, W)
    )

    return tubelets, restored

In [None]:
def display_clip(clip):
    """
    Render a clip inline as an HTML5 video, one frame every 50 ms.

    Assumes ``clip`` is a CPU tensor laid out as (C, T, H, W) — the first axis
    is moved last and the result is indexed frame by frame.
    """
    # (dim0, dim1, dim2, dim3) -> (dim1, dim2, dim3, dim0), so axis 0 indexes frames
    frames = clip.permute(1, 2, 3, 0).numpy()
    fig, ax = plt.subplots()

    def draw_frame(frame_idx):
        # redraw from scratch so earlier frames do not accumulate on the axes
        ax.clear()
        ax.imshow(frames[frame_idx], cmap='gray')
        ax.axis('off')

    ani = FuncAnimation(fig, draw_frame, frames=range(frames.shape[0]), interval=50)
    # close the static figure so only the video is shown
    plt.close()
    display(HTML(ani.to_html5_video()))

In [None]:
# Round-trip check: tubelets plus the video reassembled from them.
tublets, tublets_ = extract_tubelets(train_batch)
print(tublets.shape)
print(tublets_.shape)

In [None]:
# Play the first reassembled clip to confirm the round trip visually.
display_clip(tublets_[0])

# Initial Tubelet Shape Exploration

In [None]:
# (B, C, T, H, W) -> (B, C, n_t, t, n_h, p, n_w, p)
# vid and the size variables are expected to come from an earlier cell.
tubelets = vid.reshape(B, C, n_t, t, n_h, p, n_w, p)
print(tubelets.shape)

In [None]:
# (B, C, n_t, t, n_h, p, n_w, p) -> (B, n_t, n_h, n_w, C, t, p, p)
# i.e. gather the three tubelet-count axes right after the batch axis
tubelets = tubelets.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
print(tubelets.shape)

In [None]:
# (B, n_t, n_h, n_w, C, t, p, p) -> ((B * n_t * n_h * n_w), C, t, p, p)
# every tubelet becomes its own batch element
tubelets = tubelets.reshape(-1, C, t, p, p)
print(tubelets.shape)

In [None]:
# First strided 3-D conv: stride 2 on all three axes, C -> 32 channels.
conv3d_1 = nn.Conv3d(in_channels=C, out_channels=32, kernel_size=3, stride=2, padding=1)
out1 = conv3d_1(tubelets)
print(out1.shape)

In [None]:
# Second strided 3-D conv: stride 2 on all three axes, 32 -> 32 channels.
conv3d_2 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
out2 = conv3d_2(out1)
print(out2.shape)

In [None]:
# Third strided 3-D conv: stride 2 on all three axes, 32 -> 32 channels.
conv3d_3 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
out3 = conv3d_3(out2)
print(out3.shape)

In [None]:
# Fourth conv strides only the two spatial axes (stride 1 on the first of the three).
conv3d_4 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=(1, 2, 2), padding=1)
out4 = conv3d_4(out3)
print(out4.shape)

In [None]:
# Flatten the conv output into rows of 32 values.
# NOTE(review): this keeps one row per tubelet only if the trailing three dims
# of out4 are all 1 — confirm against the printed shape above.
out5 = out4.reshape(-1, 32)
print(out5.shape)

In [None]:
# separate out batch dimension post VAE encoding since dim=0 is of shape (B * n_t * n_h * n_w)
# result: (B, rows per clip, 32)
out6 = out5.reshape(out5.shape[0] // (n_t * n_h * n_w), -1, 32)
print(out6.shape)

In [None]:
# Merge the batch axis back in so each row is again a single 32-value vector.
out7 = out6.reshape(-1, 32)
print(out7.shape)

In [None]:
# Append three singleton axes: (N, 32) -> (N, 32, 1, 1, 1), the 5-D layout ConvTranspose3d expects.
out7 = out7.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
print(out7.shape)

In [None]:
# First transposed conv: counterpart of conv3d_4, upsampling only the two spatial axes.
decoder1 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=(1, 2, 2), padding=1, output_padding=(0, 1, 1))
out8 = decoder1(out7)
print(out8.shape)

In [None]:
# Second transposed conv: stride 2 on all three axes.
decoder2 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=(1, 1, 1))
out9 = decoder2(out8)
print(out9.shape)

In [None]:
# Third transposed conv: stride 2 on all three axes.
# Apply the layer created here (previously decoder2 was reused and decoder3 went unused).
decoder3 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=(1, 1, 1))
out10 = decoder3(out9)
print(out10.shape)

In [None]:
# Fourth transposed conv: stride 2 on all three axes.
# Given its own name and actually applied (previously it rebound decoder3 and reused decoder2).
decoder4 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=(1, 1, 1))
out11 = decoder4(out10)
print(out11.shape)

# FSQ-VAE Training

In [None]:
# Train (or resume training) the tubelet FSQ-VAE with W&B logging.
resume_training = True

config = VideoGenConfig()

model = TubeletFSQVAE(config)
lit_model = LitTubeletFSQVAE(model, config)

steamboat_willie_data = SteamboatWillieDataModule(config)

if resume_training:
    # re-attach to the previous W&B run and download its checkpoint artifact
    run = wandb.init(project=config.project_name, config=config.to_dict(), resume=True)
    artifact = run.use_artifact('CKPT NAME HERE', type='model')
    artifact_dir = artifact.download()
else:
    wandb.init(project=config.project_name, config=config.to_dict())

wandb_logger = WandbLogger(project=config.project_name, log_model=True)
wandb_logger.watch(lit_model, log="all")

lr_monitor = LearningRateMonitor(logging_interval='step')

checkpoint_callback = ModelCheckpoint(
    dirpath=config.checkpoint_path,
    # The placeholder must name the metric exactly as it is monitored ('val/loss');
    # a '{val_loss}' placeholder matches no logged metric and is always rendered as 0.00.
    # Metric-name auto-insertion is turned off so the '/' does not end up in the file name.
    filename='model-epoch={epoch:02d}-val_loss={val/loss:.2f}',
    auto_insert_metric_name=False,
    every_n_epochs=5,
    save_top_k=config.save_top_k,
    monitor='val/loss',
    mode='min',
    save_last=True
)

# Define the EarlyStopping callback
early_stop_callback = EarlyStopping(
    monitor='val/loss',
    min_delta=0.0000000,
    patience=100,
    verbose=True,
    check_finite=True
)

trainer = pl.Trainer(
    max_epochs=config.max_epochs,
    devices=1,
    accelerator="gpu",
    precision="16-mixed",
    logger=wandb_logger,
    callbacks=[
        lr_monitor,
        early_stop_callback,
        checkpoint_callback
    ],
    log_every_n_steps=1,
    # gradient_clip_val=1.0,
    # overfit_batches=1,
)

# tuner = Tuner(trainer)
# lr_result = tuner.lr_find(lit_model, datamodule=steamboat_willie_data, max_lr=100)
# lr_result.plot(show=True, suggest=True)
if resume_training:
    # continue from the checkpoint file inside the downloaded artifact
    trainer.fit(lit_model, steamboat_willie_data, ckpt_path=f'{artifact_dir}/model.ckpt')
else:
    trainer.fit(lit_model, steamboat_willie_data)

wandb.finish()

# FSQ-VAE Sweeps

In [None]:
# Register a W&B random-search sweep over the learning rate, minimising 'val/loss'.
config = VideoGenConfig()

sweep_config = {
    'method': 'random',
    'metric': {
        'name': 'val/loss',
        'goal': 'minimize'
    },
    'parameters': {
        # learning rate sampled between 1e-4 and 2e-4 using the 'log_uniform_values' distribution
        'lr': {
            'min': 1e-4,
            'max': 2e-4,
            'distribution': 'log_uniform_values'
        },

    }
}

# the returned id is what wandb.agent uses to pull trials for this sweep
sweep_id = wandb.sweep(sweep_config, project=config.project_name)

In [None]:
def train():
    """
    Run one sweep trial: build a config overridden by the sweep's sampled
    values, then fit the FSQ-VAE for at most 20 epochs with early stopping
    on 'val/loss'.
    """
    with wandb.init():
        # start from defaults, then apply whatever the sweep sampled
        config = VideoGenConfig()
        config.update(wandb.config)

        lit_model = LitTubeletFSQVAE(TubeletFSQVAE(config), config)
        steamboat_willie_data = SteamboatWillieDataModule(config)

        stop_on_plateau = EarlyStopping(
            monitor='val/loss',
            min_delta=0.00,
            patience=10,
            verbose=True,
            check_finite=True,
        )

        trainer = pl.Trainer(
            max_epochs=20,
            devices=1,
            accelerator="gpu",
            precision="16-mixed",
            logger=WandbLogger(project=config.project_name, log_model=False),
            callbacks=[stop_on_plateau],
        )

        trainer.fit(lit_model, steamboat_willie_data)

In [None]:
# Launch the sweep agent; with no `count` given, no trial limit is set here.
# wandb.agent(sweep_id, train, count=25)
wandb.agent(sweep_id, train)

# Latent Sequence Dataset

In [None]:
class LatentSequenceDataset(Dataset):
    '''
    Sliding-window dataset over the flat sequence of quantized tubelet latents.

    The latents of one split live in ``<latent_seqs_dest_dir>/<mode>.bin`` as a
    float32 table with ``config.latent_channels`` columns. Item ``i`` is the
    window of ``config.chunk_size`` rows starting at row
    ``i * VECTORS_PER_TUBELET``, returned as an (input, target) pair where the
    target is the input shifted by one row.

    Reads from ``config``: ``latent_seqs_dest_dir``, ``latent_channels``,
    ``chunk_size`` and, only when the latent files must be built,
    ``fsq_vae_checkpoint_path`` and ``latents_per_clip``.
    '''

    # Latent vectors contributed by each tubelet (the build step notes a
    # 1 x 4 x 4 latent grid per tubelet); window starts are multiples of this
    # so every chunk begins on a tubelet boundary.
    VECTORS_PER_TUBELET = 16
    VALID_MODES = ('train', 'val')

    def __init__(self, config, mode='train', train_split=0.8):
        '''
        We want to ensure that the transformer is only tasked with learning
        relationships at the tubelet level for faster convergence, so we must
        ensure chunks sampled are aligned with tubelet boundaries

        config: project config object (see class docstring for fields read)
        mode: 'train' or 'val'; anything else raises ValueError
        train_split: accepted for signature compatibility, not used here
        '''
        super().__init__()

        # Fail here rather than with an AttributeError on the first len()/index.
        if mode not in self.VALID_MODES:
            raise ValueError(f'mode must be one of {self.VALID_MODES}, got {mode!r}')

        self.config = config
        self.mode = mode

        if not os.path.exists(config.latent_seqs_dest_dir):
            print('Building latent sequences...')
            fsq_vae = TubeletFSQVAE(config)
            # load onto the CPU first so this also works on hosts without CUDA;
            # build_latent_datasets moves the model to the GPU when one exists
            checkpoint = torch.load(config.fsq_vae_checkpoint_path, map_location='cpu')
            fsq_vae.load_state_dict(self._strip_wrapper_prefix(checkpoint['state_dict']))

            clip_train_dataset = SteamboatWillieDataset(config, mode='train')
            clip_val_dataset = SteamboatWillieDataset(config, mode='val')

            clip_lat_seq_shape = (config.latents_per_clip, config.latent_channels)

            self.build_latent_datasets(fsq_vae, clip_train_dataset, clip_val_dataset, clip_lat_seq_shape, config.latent_seqs_dest_dir)

    @staticmethod
    def _strip_wrapper_prefix(state_dict, prefix='model.'):
        '''Return a copy of state_dict with a leading ``prefix`` removed from each key (inner occurrences are kept).'''
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def build_latent_datasets(self, enc_model, clip_train_dataset, clip_val_dataset, clip_lat_seq_shape, dest_dir):
        """
        Build one binary file per split ('train.bin', 'val.bin') holding the
        encoded/quantized latent sequences of every clip, back to back.

        clip_lat_seq_shape is (latents per clip, latent dim) and fixes the size
        of each file up front. Nothing is returned; the files under dest_dir
        are the result.
        """
        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            enc_model.cuda()

        enc_model.eval()
        for split, dataset in {'train': clip_train_dataset, 'val': clip_val_dataset}.items():
            mmapped_file_path = os.path.join(dest_dir, f'{split}.bin')
            # (number of clips x latents per clip, latent dim)
            shape = (len(dataset) * clip_lat_seq_shape[0], clip_lat_seq_shape[1])
            fp = np.memmap(mmapped_file_path, dtype='float32', mode='w+', shape=shape)

            idx = 0
            for clip in tqdm(dataset, desc='Creating latent sequence .bin files'):
                if use_cuda:
                    clip = clip.cuda()
                with torch.no_grad():
                    quantized = enc_model.quantize(enc_model.encode(clip.unsqueeze(0)))

                # (B, 5, 1, 4, 4) -> (B, 1, 4, 4, 5) -> ((B*1*4*4), 5)
                latent_seq_np = enc_model.latent_img_to_seq(quantized['z_q']).cpu().numpy()
                fp[idx:idx + latent_seq_np.shape[0]] = latent_seq_np[:]
                idx += latent_seq_np.shape[0]

            fp.flush()

    def _open_latents(self):
        '''Memory-map this split's file read-only as a (num_vectors, latent_channels) table.'''
        data_file_path = os.path.join(self.config.latent_seqs_dest_dir, f'{self.mode}.bin')
        # a memmap opened without a shape is 1-D, so the column count is restored here
        return np.memmap(data_file_path, dtype='float32', mode='r').reshape(-1, self.config.latent_channels)

    def _count_chunks(self, total_vectors):
        '''Number of tubelet-aligned windows of chunk_size rows that fit in total_vectors rows.'''
        slack = total_vectors - self.config.chunk_size
        if slack < 0:
            return 0
        # valid starts are 0, 16, 32, ... up to and including the last one <= slack
        return slack // self.VECTORS_PER_TUBELET + 1

    def __len__(self):
        '''
        we need to ensure our idxs are aligned with tubelet boundaries (which contain
        16 vectors each)
        '''
        return self._count_chunks(self._open_latents().shape[0])

    def __getitem__(self, idx):
        '''
        we need to ensure our idxs are aligned with tubelet boundaries (which contain
        16 vectors each)

        Negative indices count from the end; out-of-range indices raise IndexError.
        Returns (rows[:-1], rows[1:]) of the selected window as tensors.
        '''
        dataset = self._open_latents()
        num_chunks = self._count_chunks(dataset.shape[0])

        if idx < 0:
            idx += num_chunks
        if not 0 <= idx < num_chunks:
            raise IndexError(f'index out of range for dataset with {num_chunks} chunks')

        aligned_start_idx = idx * self.VECTORS_PER_TUBELET

        latent_seq = dataset[aligned_start_idx : aligned_start_idx + self.config.chunk_size]

        return torch.tensor(latent_seq[:-1]), torch.tensor(latent_seq[1:])


class LatentSequenceDataModule(pl.LightningDataModule):
    """
    Lightning data module serving the 'train' and 'val' latent-sequence datasets.

    Uses the transformer-specific loader settings from the config
    (``t_batch_size`` and ``t_num_workers``); ``num_workers`` is the fallback
    for configs that do not define ``t_num_workers``.
    """

    def __init__(self, config):
        super().__init__()
        self.t_batch_size = config.t_batch_size
        # t_num_workers sits next to t_batch_size in the config's transformer section
        self.t_num_workers = getattr(config, 't_num_workers', config.num_workers)
        self.config = config

    def setup(self, stage=None):
        """Create both datasets for the 'fit' and 'validate' stages (or when no stage is given)."""
        if stage in ('fit', 'validate', None):
            self.train_dataset = LatentSequenceDataset(self.config, mode='train')
            self.val_dataset = LatentSequenceDataset(self.config, mode='val')

    def train_dataloader(self):
        """Shuffled loader over the training windows."""
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.t_batch_size, shuffle=True, num_workers=self.t_num_workers)

    def val_dataloader(self):
        """Unshuffled loader over the validation windows."""
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.t_batch_size, shuffle=False, num_workers=self.t_num_workers)

In [None]:
# Instantiate both splits; the first construction also builds the latent files
# if the destination directory does not exist yet.
config = VideoGenConfig()
train_dataset = LatentSequenceDataset(config, mode='train')
val_dataset = LatentSequenceDataset(config, mode='val')

In [None]:
# Despite the names, each of these is a single (input, target) pair, not a batch.
train_batch = train_dataset[0]
val_batch = val_dataset[0]

# Transformer Testing

In [None]:
# verify clip sequence by decoding each latent sequence back to clips and display
config = VideoGenConfig()
# Only 'train' and 'val' latent files are built and only those modes are
# recognised by LatentSequenceDataset; 'full' left the dataset without a mode.
latent_dataset = LatentSequenceDataset(config, mode='train')

In [None]:
# Each item is an (input, target) pair of tensors, so inspect both shapes
# (calling .shape on the pair itself raises AttributeError).
tuple(part.shape for part in latent_dataset[0])

In [None]:
# Rebuild the FSQ-VAE and load its trained weights from the Lightning checkpoint.
fsq_vae = TubeletFSQVAE(config)
# Load onto the CPU so this also works without CUDA; the next cell moves the
# model to the GPU when one is available.
checkpoint = torch.load(config.fsq_vae_checkpoint_path, map_location='cpu')
new_state_dict = {}
for key, value in checkpoint['state_dict'].items():
    # strip only the leading wrapper prefix; 'model.' further inside a key is part of the real name
    new_key = key[len('model.'):] if key.startswith('model.') else key
    new_state_dict[new_key] = value

fsq_vae.load_state_dict(new_state_dict)

In [None]:
# Decode one latent chunk back into a clip and play it.
with torch.no_grad():
    fsq_vae.eval()
    # NOTE(review): confirm LatentSequenceDataset.__getitem__ handles the negative index as intended here.
    x, y = latent_dataset[-2]
    # the input omits the chunk's final vector; append it from the target to recover the full chunk
    x = torch.cat((x, y[-1].unsqueeze(0)), dim=0)
    print(x.shape)
    if torch.cuda.is_available():
        fsq_vae = fsq_vae.cuda()
        x = x.cuda()
    # The reshape must happen on every device, not only under CUDA, otherwise
    # decode() receives the flat sequence on CPU-only hosts.
    # (N*16, 5) -> (N, 1, 4, 4, 5) -> (N, 5, 1, 4, 4)
    x = x.reshape(-1, 1, 4, 4, 5).permute(0, 4, 1, 2, 3).contiguous()
    print(x.shape)
    clip_tubelets = fsq_vae.decode(x)
    print(clip_tubelets.shape)
    clip = fsq_vae.assemble_tubelets(
        clip_tubelets,
        B=1,
        C=1,
        T=config.clip_length*4,
        H=config.img_size,
        W=config.img_size,
        t=config.t,
        p=config.p
    )
    print(clip.shape)
    x = x.cpu()
    clip = clip.cpu()

display_clip(clip.squeeze(0))

In [None]:
# Small transformer for a shape smoke test; the vocabulary size is taken from the FSQ-VAE quantizer.
# NOTE(review): the training cell constructs TransformerDecoder(config) instead of
# passing keyword arguments — confirm which constructor signature is current.
transformer = TransformerDecoder(
    emb_dim=5,
    qkv_dim=128,
    nheads=8,
    nlayers=5,
    ctx_size=16383,
    vocab_size=fsq_vae.quantizer.codebook_size
)

In [None]:
# Forward one sequence through the transformer and print the logits shape.
x, y = latent_dataset[-2]
if torch.cuda.is_available():
    # run in half precision on the GPU
    transformer = transformer.to(torch.float16).cuda()
    x = x.to(torch.float16).cuda()
# unsqueeze adds the batch axis
logits = transformer(x.unsqueeze(0))
print(logits.shape)

# Transformer Decoder Training

In [None]:
# Train (or resume training) the transformer decoder on latent sequences with W&B logging.
resume_training = False

config = VideoGenConfig()
model = TransformerDecoder(config)
lit_model = LitTransformerDecoder(model, config)
latent_seq_data = LatentSequenceDataModule(config)

if resume_training:
    # re-attach to the previous W&B run and download its checkpoint artifact
    # (the artifact name still has to be filled in before resuming)
    run = wandb.init(project=config.project_name, config=config.to_dict(), resume=True)
    artifact = run.use_artifact('', type='model')
    artifact_dir = artifact.download()
else:
    wandb.init(project=config.project_name, config=config.to_dict())

wandb_logger = WandbLogger(project=config.project_name, log_model=True)
wandb_logger.watch(lit_model, log="all")

lr_monitor = LearningRateMonitor(logging_interval='step')

checkpoint_callback = ModelCheckpoint(
    dirpath=config.checkpoint_path,
    # The placeholder must name the metric exactly as it is logged ('val/loss');
    # a '{val_loss}' placeholder matches no logged metric and is always rendered as 0.00.
    # Metric-name auto-insertion is turned off so the '/' does not end up in the file name.
    filename='model-epoch={epoch:02d}-val_loss={val/loss:.2f}',
    auto_insert_metric_name=False,
    every_n_epochs=5,
    save_top_k=config.save_top_k,
    monitor='val/loss',
    mode='min'
)

# Define the EarlyStopping callback
early_stop_callback = EarlyStopping(
    monitor='val/loss',
    min_delta=0.0000000,
    patience=100,
    verbose=True,
    check_finite=True
)

trainer = pl.Trainer(
    max_epochs=config.max_epochs,
    devices=1,
    accelerator="gpu",
    precision="16-mixed",
    logger=wandb_logger,
    callbacks=[
        lr_monitor,
        early_stop_callback,
        # checkpoint_callback
    ],
    log_every_n_steps=1,
    accumulate_grad_batches=config.grad_acc_steps,
    # gradient_clip_val=1.0,
    # overfit_batches=1,
)

# tuner = Tuner(trainer)
# lr_result = tuner.lr_find(lit_model, datamodule=latent_seq_data, max_lr=10)
# lr_result.plot(show=True, suggest=True)
if resume_training:
    # resume from the downloaded artifact; an empty ckpt_path would ignore it
    trainer.fit(lit_model, latent_seq_data, ckpt_path=f'{artifact_dir}/model.ckpt')
else:
    trainer.fit(lit_model, latent_seq_data)

wandb.finish()