<a href="https://colab.research.google.com/github/JordanLazzaro/VideoGen/blob/main/notebooks/VideoPoet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VideoPoet Implementation

![](https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgxFblHaHRJNH7Oi2_oOTosGN9XrjgjhWmnfADchMT8WR0XAo6SxiUfpUmn5R6akciiRduaKIMdgwHZzK3xW8mErarQ_ugx41ctQAMK08O9UMVevgkk-AgFI1xYFWAomd16OcOh0R-XpyZVLQXncpk2SHf-RmPzrqBbIWZc-nUG2TH6nC2R7qyHXn8eTC-u/s2680/image21.png)

[**VideoPoet**](https://research.google/blog/videopoet-a-large-language-model-for-zero-shot-video-generation/) is an autoregressive transformer decoder that models sequences of discrete, multimodal tokens produced by modality-specific tokenizers: [MAGVIT-V2](https://arxiv.org/abs/2310.05737) for images/video and [SoundStream](https://arxiv.org/abs/2107.03312) for audio. For text, the model uses the T5 tokenizer together with frozen T5 token embeddings. Together, these tokens make up the model's full vocabulary. VideoPoet also employs a custom super-resolution model, so that for efficiency the transformer only has to generate lower-resolution videos.

Implementing the video portion of this model requires the following milestones:

- [ ] Implement MAGVIT-V2 tokenizer
    - [ ] Dilated Causal Convolution (in time dim)
    - [ ] Blur Pool
    - [ ] LFQ (could be replaced with FSQ?)
    - [ ] VQVAE (FSQVAE)
    - [ ] Discriminator / GAN Loss
- [ ] Implement the Transformer Decoder
- [ ] Implement the Super Resolution model
- [ ] Incorporate audio (optional/if feasible)

# Setup

In [None]:
# Install the third-party packages imported in the next cell.
!pip install -q wandb pytorch_lightning av imageio
# flash-attn is installed without build isolation, as its install instructions recommend.
!pip install -q flash-attn --no-build-isolation

In [None]:
import os
import gc
import math
import wandb
import imageio
import torch
import torchvision
import numpy as np
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML
from torch import nn, optim, einsum
from torch.nn import functional as F
from torchvision.datasets.video_utils import VideoClips
from torch.utils.data import Dataset, DataLoader, random_split
from torch.autograd import grad as torch_grad
from torch.cuda.amp import autocast
from torchvision.transforms import Compose, Lambda, Resize, ToTensor, CenterCrop, Grayscale
import torchvision.transforms.functional as TF
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

from pytorch_lightning import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.tuner import Tuner
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

# Seed the global RNGs through Lightning so weight init and shuffling are reproducible.
pl.seed_everything(42)

In [None]:
drive.mount('/content/drive')
steamboat_willie_gdrive_path = '/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4'
!cp -r /content/drive/My\ Drive/SteamboatWillie/clips .

In [None]:
def display_clip(clip):
    '''
    util method for displaying tensors as video

    expects clip.shape = (C,T,H,W); each frame is drawn with a gray colormap.
    Works for tensors that live on the GPU or track gradients (they are
    detached and moved to the CPU first).

    Raises:
        ValueError: if `clip` is not rank 4.
    '''
    if len(clip.shape) != 4:
        raise ValueError('clip shape must be PyTorch Tensor of shape (C,T,H,W)')

    # (C, T, H, W) -> (T, H, W, C); detach/cpu so model outputs can become numpy
    frames = clip.detach().cpu().permute(1, 2, 3, 0).numpy()
    if frames.shape[-1] == 1:
        # drop the singleton channel so imshow gets a plain (H, W) image
        frames = frames[..., 0]

    fig, ax = plt.subplots()

    def update(frame_idx):
        ax.clear()
        ax.imshow(frames[frame_idx], cmap='gray')
        ax.axis('off')

    ani = FuncAnimation(fig, update, frames=range(frames.shape[0]), interval=60)
    # close *this* figure so the notebook does not also render a static copy of it
    plt.close(fig)
    display(HTML(ani.to_html5_video()))

# Config

In [None]:
# Nested settings for every stage of the pipeline; sections are wrapped in
# attribute-style `Config` objects further below.
config = {
    "project": {
        "magvit2": {
            "name": "SteamboatWillie VideoPoet",
            "wandb_project": "SteamboatWillie VideoPoet"
        }
    },
    "data": {
        "paths": ['/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4'],
        "clip_length": 16,
        "clip_dest_dir": "clips",
        # frame size and channel count read by SteamboatWillieDataset;
        # kept in sync with magvit2.input_shape (C, T, H, W)
        "img_size": 128,
        "in_channels": 1
    },
    "magvit2": {
        "fsqvae": {
            "encoder": {
                "in_channels": 1,
                "out_channels": 5,
                "init_channels": 32,
                "num_downsamples": 3,
                "nblocks": 2
            },
            "fsq": {
                # one entry per latent channel (must match encoder.out_channels)
                "levels": [9, 9, 9, 9, 9],
            },
            "decoder": {
                "in_channels": 5,
                "out_channels": 1,
                "init_channels": 256,
                "num_upsamples": 3,
                "nblocks": 2
            }
        },
        "discriminator": {
            "input_shape": (1, 16, 128, 128),
            "in_channels": 1,
            "init_channels": 32,
            "num_downsamples": 5,
            "use_grad_penalty": True,
            "grad_penalty_weight": 10,
        },
        "gan_loss_weight": 0.1,
        "recon_loss_weight": 5,
        "input_shape": (1, 16, 128, 128),
        "latent_shape": (5, 2, 16, 16),
        "checkpoint_dir": "./checkpoints/magvit_v2/"
    },
    "transformer": {
        "emb_dim": 512,
        "nheads": 8,
        "ctx_size": 2048,
        "window_size": 1024,
        "fan_out": 4,
        "nlayers": 6,
        "dropout": 0.0,
        "use_flash_attn": True,
        "checkpoint_dir": "./checkpoints/transformer/"
    },
    "super_resolution": {
        "checkpoint_dir": "./checkpoints/super_resolution/"
    },
    "training": {
        "magvit2": {
            "batch_size": 16,
            "num_workers": 4,
            "epochs": 1024,
            "lr": 1e-4,
            "save_top_k": 2,
            "check": ""

        },
        "transformer": {},
        "super_resolution": {}
    },
    "logging": {}
}

In [None]:
class Config:
    '''
    Attribute-style view over a (nested) config dict.

    Nested dicts are wrapped recursively, so ``Config(d).a.b`` mirrors
    ``d['a']['b']``. Looking up a missing regular name returns ``None`` instead
    of raising, so optional settings can be probed with ``cfg.x is None``.
    Missing dunder names still raise ``AttributeError``. Protocols such as
    ``copy.deepcopy`` and ``pickle`` probe for hooks like ``__setstate__`` and
    ``__deepcopy__``; returning ``None`` for those would make them try to call ``None``.
    '''
    def __init__(self, config_dict):
        for key, value in config_dict.items():
            if isinstance(value, dict):
                value = Config(value)
            self.__dict__[key] = value

    def __getattr__(self, name):
        # only reached when normal lookup fails, i.e. `name` is not a stored key
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return None

    def __repr__(self):
        return f'{type(self).__name__}({self.__dict__!r})'

In [None]:
# Wrap each model section so its settings are reachable as attributes.
# NOTE(review): the 'data' and 'training' sections are not wrapped here, although
# the dataset/Lightning classes read fields such as img_size, batch_size and
# training.lr — confirm how their configs are assembled before training.
magvit2_config, transformer_config, super_resolution_config = (
    Config(config[section]) for section in ('magvit2', 'transformer', 'super_resolution')
)

# Data

## Clips Dataset

In [None]:
class RandomHorizontalFlipVideo(torch.nn.Module):
    '''
    Mirror a whole clip left-to-right with probability ``p``.

    One coin flip decides for the entire clip, so all frames stay consistent.
    The clip is assumed to be laid out (C, T, H, W); only the last (width)
    axis is reversed.
    '''
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        should_flip = torch.rand(1) < self.p
        return x.flip(-1) if should_flip else x


class SteamboatWillieDataset(Dataset):
    '''
    Fixed-length grayscale clips of the source video, cached as raw uint8 .bin files.

    On first use the video(s) in ``config.paths`` are cut into non-overlapping
    clips of ``config.clip_length`` frames. Each clip is converted to grayscale,
    center-cropped and resized to ``config.img_size``, then written to
    ``config.clip_dest_dir/clip_<idx>.bin``. Later runs reuse those files through
    read-only ``np.memmap``.

    Args:
        config: attribute-style config with ``paths``, ``clip_length``,
            ``clip_dest_dir``, ``img_size`` and ``in_channels``.
        mode: ``'train'``, ``'val'`` or ``'full'``.
        train_split: fraction of clips assigned to the train split.
        shuffle: shuffle clips before splitting into train/val.
        augment: apply random horizontal flips (train mode only).
        seed: seed for the train/val shuffle. Train and val datasets are built
            separately, so they must draw the same permutation or the splits
            overlap. ``None`` falls back to the global torch RNG.

    Items are float32 tensors of shape
    (in_channels, clip_length, img_size, img_size) scaled to [0, 1].

    Raises:
        ValueError: for an unknown ``mode``.
    '''
    def __init__(self, config, mode='train', train_split=0.8, shuffle=True, augment=True, seed=42):
        super().__init__()
        if mode not in ('train', 'val', 'full'):
            raise ValueError(f"mode must be 'train', 'val' or 'full', got {mode!r}")

        self.config = config
        self.train_split = train_split
        self.mode = mode

        self.preprocess_transforms = Compose([
                Lambda(lambda x: x.permute(0, 3, 1, 2)),   # (T, H, W, C) to (T, C, H, W) for Grayscale
                Grayscale(num_output_channels=1),          # Convert to grayscale
                Lambda(lambda x: x.permute(1, 0, 2, 3)),   # (T, C, H, W) to (C, T, H, W) for Conv3d
                CenterCrop((480, 575)),                    # Center crop to remove vertical bars
                Resize((config.img_size, config.img_size))
        ])

        self.postprocess_transforms = Compose([
            Lambda(lambda x: x / 255.),
            # .bin files are flat, so the clip shape is restored from the config
            Lambda(lambda x: x.view(self.config.in_channels, self.config.clip_length, self.config.img_size, self.config.img_size))
        ])

        if self.mode == 'train' and augment:
            self.postprocess_transforms.transforms.append(RandomHorizontalFlipVideo(p=0.5))

        self.clips = self.build_clip_refs(self._load_or_build_clip_paths(config))
        # actual clip keys (not positions), so gaps in the cached file numbering are fine
        self.clip_indices = self._split_clip_keys(sorted(self.clips), mode, train_split, shuffle, seed)

    def _load_or_build_clip_paths(self, config):
        '''Reuse cached clip files in ``config.clip_dest_dir`` if any exist, otherwise create them.'''
        clip_paths = {}
        if os.path.isdir(config.clip_dest_dir):
            clip_paths = self.build_existing_clip_paths(config.clip_dest_dir)

        if not clip_paths:
            # no cache yet, or the directory exists but holds no clip files
            video_clips = VideoClips(
                config.paths,
                clip_length_in_frames=config.clip_length,
                frames_between_clips=config.clip_length
            )
            clip_paths = self.build_clip_paths(video_clips, self.preprocess_transforms, config.clip_dest_dir)

        return clip_paths

    @staticmethod
    def _split_clip_keys(keys, mode, train_split, shuffle, seed):
        '''Return the clip keys in ``mode``'s split (every key for ``'full'``).'''
        keys = list(keys)
        if mode == 'full':
            return keys

        if shuffle:
            # a fresh generator seeded identically makes the train and val datasets
            # draw the *same* permutation even though they are built one after another
            generator = None if seed is None else torch.Generator().manual_seed(seed)
            order = torch.randperm(len(keys), generator=generator).tolist()
        else:
            order = list(range(len(keys)))

        train_size = int(len(keys) * train_split)
        chosen = order[:train_size] if mode == 'train' else order[train_size:]
        return [keys[i] for i in chosen]

    def build_clip_paths(self, video_clips, transforms, clip_dest_dir):
        """
        Decode every clip from ``video_clips``, apply ``transforms`` and write it
        to ``clip_dest_dir/clip_<idx>.bin`` as raw uint8 bytes.

        returns dict of clip_idx -> mmapped file path
        """
        clip_paths = {}
        os.makedirs(clip_dest_dir, exist_ok=True)

        for idx in tqdm(range(video_clips.num_clips()), desc='Creating clip .bin files'):
            clip, _, _, _ = video_clips.get_clip(idx)
            clip_np = transforms(clip).numpy().astype(np.uint8)

            mmapped_file_path = os.path.join(clip_dest_dir, f'clip_{idx}.bin')
            fp = np.memmap(mmapped_file_path, dtype='uint8', mode='w+', shape=clip_np.shape)
            fp[:] = clip_np[:]
            fp.flush()
            del fp
            clip_paths[idx] = mmapped_file_path

        return clip_paths

    def build_existing_clip_paths(self, clip_dest_dir):
        """
        returns dict of clip_idx -> mmapped file path
        from existing clip_<idx>.bin files
        """
        clips_paths = {}
        for filename in os.listdir(clip_dest_dir):
            if filename.startswith('clip_') and filename.endswith('.bin'):
                idx = int(filename.split('_')[1].split('.')[0])
                clips_paths[idx] = os.path.join(clip_dest_dir, filename)

        return clips_paths

    def build_clip_refs(self, clip_paths):
        """
        Build read-only mmap references to the bin files
        returns dict of clip_idx -> flat uint8 np.memmap of the respective bin file
        """
        return {
            idx: np.memmap(path, dtype='uint8', mode='r')
            for idx, path in tqdm(clip_paths.items(), desc='Building clip refs')
        }

    def __len__(self):
        return len(self.clip_indices)

    def __getitem__(self, idx):
        clip = self.clips[self.clip_indices[idx]]
        return self.postprocess_transforms(torch.tensor(clip, dtype=torch.float32))

## Lightning Datamodule

In [None]:
class SteamboatWillieDataModule(pl.LightningDataModule):
    '''
    Lightning wrapper serving the train/val splits of SteamboatWillieDataset.

    ``config`` must provide the fields SteamboatWillieDataset reads, plus
    ``batch_size`` and ``num_workers``.
    '''
    def __init__(self, config):
        super().__init__()
        self.batch_size = config.batch_size
        self.config = config

    def setup(self, stage=None):
        # trainer.validate() calls setup('validate'); build the splits for that
        # stage too, otherwise val_dataloader has no dataset to serve
        if stage in (None, 'fit', 'validate'):
            self.train_dataset = SteamboatWillieDataset(self.config, mode='train')
            self.val_dataset = SteamboatWillieDataset(self.config, mode='val')

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.config.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.config.num_workers)

# MAGVIT-V2

## Model

### Components

In [None]:
class CausalConv3d(nn.Module):
    '''
    3D convolution that is causal in the time dimension (https://paperswithcode.com/method/causal-convolution)

    Time is padded only at the front, by ``dilation_t * (k_t - 1) + 1 - stride_t``
    frames, so output frame t never sees input frames after t. Height and width
    get symmetric padding that accounts for spatial dilation. With stride 1 this
    padding keeps H/W unchanged for odd kernels; with dilation 1 it reduces to
    ``k // 2``.
    ``kernel_size``, ``stride`` and ``dilation`` accept an int or a (t, h, w) triple.

    inspired by:
    https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889
    '''
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(3,3,3),
            stride=(1, 1, 1),
            dilation=(1, 1, 1),
            pad_mode='constant'
        ):
        super().__init__()
        kernel_size = self._as_triple(kernel_size)
        stride = self._as_triple(stride)
        dilation = self._as_triple(dilation)

        k_t, k_h, k_w = kernel_size
        d_t, d_h, d_w = dilation

        # spatial: half of the dilated receptive field on each side
        pad_h = (d_h * (k_h - 1) + 1) // 2
        pad_w = (d_w * (k_w - 1) + 1) // 2
        # temporal: all padding goes on the front ("past") side
        pad_t = d_t * (k_t - 1) + (1 - stride[0])

        # F.pad order for 5D input: (w_left, w_right, h_top, h_bottom, t_front, t_back)
        self.t_causal_pad = (pad_w, pad_w, pad_h, pad_h, pad_t, 0)
        self.pad_mode = pad_mode

        self.conv = nn.Conv3d(
            in_channels  = in_channels,
            out_channels = out_channels,
            kernel_size  = kernel_size,
            stride       = stride,
            dilation     = dilation
        )

    @staticmethod
    def _as_triple(value):
        '''Expand an int into a (t, h, w) triple; pass sequences through as tuples.'''
        return (value,) * 3 if isinstance(value, int) else tuple(value)

    def forward(self, x):
        x = F.pad(x, self.t_causal_pad, mode=self.pad_mode)
        x = self.conv(x)

        return x

In [None]:
class BlurPool3d(nn.Module):
    '''
    https://arxiv.org/abs/1904.11486
    inspired by:
    https://github.com/adobe/antialiased-cnns/blob/master/antialiased_cnns/blurpool.py
    https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L509
    '''
    def __init__(self, in_channels, kernel_size=3, stride=2):
        super().__init__()
        self.in_channels = in_channels
        self.stride = stride
        self.padding = (kernel_size - 1) // 2

        self.register_buffer('blur_filter', self.get_blur_filter3d(kernel_size))

    def forward(self, x):
        assert len(x.shape) == 5, 'BlurPool3d only supports rank 5 tensors'
        return F.conv3d(x, self.blur_filter, stride=(self.stride, self.stride, self.stride), padding=self.padding)

    def get_blur_filter3d(self, kernel_size):
        if kernel_size == 1:
            filter = torch.tensor([1.,])
        elif kernel_size == 2:
            filter = torch.tensor([1., 1.])
        elif kernel_size == 3:
            filter = torch.tensor([1., 2., 1.])
        elif kernel_size == 4:
            filter = torch.tensor([1., 3., 3., 1.])
        elif kernel_size == 5:
            filter = torch.tensor([1., 4., 6., 4., 1.])
        elif kernel_size == 6:
            filter = torch.tensor([1., 5., 10., 10., 5., 1.])
        elif kernel_size == 7:
            filter = torch.tensor([1., 6., 15., 20., 15., 6., 1.])

        filter = einsum('i, j, k -> i j k', filter, filter, filter)
        filter = repeat(filter, 'd h w -> oc ic d h w', oc=self.in_channels, ic=self.in_channels)

        return filter / torch.sum(filter)

In [None]:
class ResBlock3d(nn.Module):
    '''
    Pre-activation residual block built from causal 3D convs.

    Main path: two (GroupNorm -> SiLU -> CausalConv3d) stages.
    Skip path: a 1x1x1 causal conv when the channel count changes, otherwise identity.
    '''
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(3,3,3)
        ):
        super().__init__()

        def groups_for(channels):
            # channels // 2 groups, but never fewer than one
            return 1 if channels < 2 else channels // 2

        # skip path is constructed first, before the main path's convs
        if in_channels == out_channels:
            self.identity = nn.Identity()
        else:
            self.identity = CausalConv3d(
                in_channels  = in_channels,
                out_channels = out_channels,
                kernel_size  = (1, 1, 1)
            )

        unit_step = dict(stride=(1, 1, 1), dilation=(1, 1, 1))
        self.block = nn.Sequential(
            nn.GroupNorm(num_groups=groups_for(in_channels), num_channels=in_channels),
            nn.SiLU(),
            CausalConv3d(in_channels, out_channels, kernel_size, **unit_step),
            nn.GroupNorm(num_groups=groups_for(out_channels), num_channels=out_channels),
            nn.SiLU(),
            CausalConv3d(out_channels, out_channels, kernel_size, **unit_step),
        )

    def forward(self, x):
        main = self.block(x)
        skip = self.identity(x)
        return main + skip

In [None]:
class ResBlockDown3d(nn.Module):
    '''
    Downsampling residual block (blur pooling, https://arxiv.org/abs/1904.11486).

    Both paths subsample with a BlurPool3d (default stride 2):
      skip: blur-pool -> 1x1x1 conv to ``out_channels``
      main: conv -> LeakyReLU -> blur-pool -> conv -> LeakyReLU
    Convs use symmetric 'same' padding (not causal).
    '''
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(3,3,3)
        ):
        super().__init__()

        def same_conv(cin, cout, kernel):
            return nn.Conv3d(in_channels=cin, out_channels=cout, kernel_size=kernel, padding='same')

        self.identity = nn.Sequential(
            BlurPool3d(in_channels),
            same_conv(in_channels, out_channels, (1, 1, 1)),
        )

        self.block = nn.Sequential(
            same_conv(in_channels, out_channels, kernel_size),
            nn.LeakyReLU(),
            BlurPool3d(out_channels),
            same_conv(out_channels, out_channels, kernel_size),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        main = self.block(x)
        skip = self.identity(x)
        return main + skip

In [None]:
class PixelShuffle3d(nn.Module):
    '''
    3D sub-pixel shuffle (https://arxiv.org/abs/1609.05158).

    Moves channel groups into the depth/height/width axes:
    (B, C*r1*r2*r3, D, H, W) -> (B, C, D*r1, H*r2, W*r3) with ``r = (r1, r2, r3)``.
    '''
    def __init__(self, r=(2, 2, 2)):
        super().__init__()
        self.r = r

    def forward(self, x):
        factors = {'r1': self.r[0], 'r2': self.r[1], 'r3': self.r[2]}
        pattern = 'b (c1 r1 r2 r3) d h w -> b c1 (d r1) (h r2) (w r3)'
        return rearrange(x, pattern, **factors)

In [None]:
class Upsample3d(nn.Module):
    '''
    Learned upsampling: a causal 3x3x3 conv that widens the channels, then PixelShuffle3d.

    H and W are always doubled; T is doubled too when ``upsample_time`` is True.
    The conv emits ``out_channels`` times the shuffle factor (8 or 4), so the
    shuffle ends at ``out_channels``.
    '''
    def __init__(self, in_channels, out_channels, upsample_time=True):
        super().__init__()
        if upsample_time:
            shuffle_factors, expansion = (2, 2, 2), 8
        else:
            shuffle_factors, expansion = (1, 2, 2), 4

        self.conv = CausalConv3d(
            in_channels  = in_channels,
            out_channels = out_channels * expansion,
            kernel_size  = (3,3,3)
        )
        # TODO: should we add an activation between?
        self.pixel_shuffle = PixelShuffle3d(r=shuffle_factors)

    def forward(self, x):
        return self.pixel_shuffle(self.conv(x))

In [None]:
class EncoderBlock(nn.Module):
    '''
    One encoder stage: a causal conv with stride ``causal_stride`` that maps
    ``in_channels`` -> ``out_channels``, followed by ``nblocks`` ResBlock3d
    layers at ``out_channels``.
    '''
    def __init__(
            self,
            in_channels,
            out_channels,
            nblocks,
            kernel_size=(3,3,3),
            causal_stride=(2,2,2)
        ):
        super().__init__()
        downsample = CausalConv3d(
            in_channels  = in_channels,
            out_channels = out_channels,
            kernel_size  = kernel_size,
            stride       = causal_stride
        )
        residual_stack = nn.Sequential(*(
            ResBlock3d(in_channels=out_channels, out_channels=out_channels)
            for _ in range(nblocks)
        ))
        self.block = nn.Sequential(downsample, residual_stack)

    def forward(self, x):
        out = self.block(x)
        return out

In [None]:
class DecoderBlock(nn.Module):
    '''
    One decoder stage: GroupNorm -> ``nblocks`` ResBlock3d (the first maps
    ``in_channels`` -> ``out_channels``) -> Upsample3d, optionally followed by SiLU.
    '''
    def __init__(
            self,
            in_channels,
            out_channels,
            nblocks=2,
            upsample_time=True,
            silu=True
        ):
        super().__init__()
        self.group_norm = nn.GroupNorm(
            num_groups   = 1 if in_channels < 2 else in_channels // 2,
            num_channels = in_channels
        )

        res_layers = []
        block_in = in_channels
        for _ in range(nblocks):
            res_layers.append(ResBlock3d(in_channels=block_in, out_channels=out_channels))
            block_in = out_channels
        self.res_blocks = nn.Sequential(*res_layers)

        self.upsample = Upsample3d(
            in_channels   = out_channels,
            out_channels  = out_channels,
            upsample_time = upsample_time
        )
        self.silu = nn.SiLU() if silu else None

    def forward(self, x):
        x = self.res_blocks(self.group_norm(x))
        x = self.upsample(x)
        return x if self.silu is None else self.silu(x)

In [None]:
class FSQ(nn.Module):
    '''
    Finite Scalar Quantization (https://arxiv.org/abs/2309.15505).

    Each latent channel i is bounded with tanh and rounded to one of ``levels[i]``
    integer bins, using a straight-through estimator for gradients. The result
    is rescaled so codes lie in [-1, 1]. The implicit codebook has
    ``prod(levels)`` entries. ``basis`` holds the mixed-radix place values that
    turn a code vector into one integer index and back.
    '''
    def __init__(self, levels, eps=1e-3):
        super().__init__()
        self.register_buffer('levels', torch.tensor(levels))

        # place values [1, L0, L0*L1, ...]; the explicit dtype keeps the empty
        # cumprod (single level) integer so `basis` stays int64
        self.register_buffer('basis',
            torch.cat([
                torch.tensor([1]),
                torch.cumprod(torch.tensor(levels[:-1], dtype=torch.long), dim=0)
            ], dim=0)
        )

        self.eps = eps
        self.codebook_size = torch.prod(self.levels)

    def round_ste(self, z):
        '''Round in the forward pass while passing gradients straight through.'''
        z_q = torch.round(z)
        return z + (z_q - z).detach()

    def quantize(self, z):
        '''Quantize a channels-last tensor (..., len(levels)) to normalized codes in [-1, 1].'''
        # half_l is used to determine how to scale tanh; we
        # subtract 1 from the number of levels to account for 0
        # being a quantization bin and tanh being symmetric around 0
        half_l = (self.levels - 1) * (1 - self.eps) / 2

        # if a given level is even, it will result in a scale for tanh
        # which is halfway between integer values, so we offset
        # the tanh output down by 0.5 to line it with whole integers
        offset = torch.where(self.levels % 2 == 0, 0.5, 0.0)

        # if our level is even, we want to shift the tanh input to
        # ensure the 0 quantization bin is centered
        shift = torch.tan(offset / half_l)

        # once we have our shift and offset (in the case of an even level)
        # we can round to the nearest integer bin and allow for STE
        z_q = self.round_ste(torch.tanh(z + shift) * half_l - offset)

        # after quantization, we want to renormalize the quantized
        # values to be within the range expected by the model (ie. [-1, 1])
        half_width = self.levels // 2
        return z_q / half_width

    def scale_and_shift(self, z_q_normalized):
        '''Map normalized codes to digits 0 .. levels-1 (float).'''
        half_width = self.levels // 2
        return (z_q_normalized * half_width) + half_width

    def scale_and_shift_inverse(self, z_q):
        '''Map digits 0 .. levels-1 back to normalized codes.'''
        half_width = self.levels // 2
        return (z_q - half_width) / half_width

    def codes_to_idxs(self, z_q):
        '''
        Map normalized codes (..., len(levels)) to int32 codebook indices (...,).

        Digits are rounded, not truncated, so float error from the STE or the
        division cannot drop a bin by one. The sum is accumulated in int64 so
        large codebooks stay exact.
        '''
        assert z_q.shape[-1] == len(self.levels)
        digits = torch.round(self.scale_and_shift(z_q)).long()
        return (digits * self.basis).sum(dim=-1).to(torch.int32)

    def idxs_to_codes(self, idxs):
        '''Map integer indices (...,) to normalized codes (..., len(levels)).'''
        idxs = idxs.unsqueeze(-1)
        codes_not_centered = (idxs // self.basis) % self.levels
        return self.scale_and_shift_inverse(codes_not_centered)

    def forward(self, z):
        '''
        Quantize ``z`` along dim 1 (channels) for any input of rank >= 2, e.g.
        (B, C), (B, C, L), (B, C, H, W) or (B, C, T, H, W).

        Returns:
            dict with key ``'z_q'``: quantized tensor with the same shape as ``z``.

        Raises:
            ValueError: if ``z`` has rank < 2.
        '''
        if z.dim() < 2:
            raise ValueError(f'FSQ expects a (B, C, ...) tensor, got shape {tuple(z.shape)}')
        # quantize channels-last (levels broadcast over the last dim), then restore layout
        z_q = self.quantize(z.movedim(1, -1)).movedim(-1, 1).contiguous()
        return {'z_q': z_q}

In [None]:
class Encoder(nn.Module):
    '''
    Causal 3D conv encoder.

    in_conv (to ``init_channels``) -> ``num_downsamples`` EncoderBlocks, each
    doubling the channels with a (2, 2, 2) causal stride -> GroupNorm/SiLU/1x1x1
    conv to ``out_channels``.
    '''
    def __init__(
            self,
            in_channels,
            out_channels,
            init_channels,
            num_downsamples,
            nblocks
        ):
        super().__init__()
        widths = [init_channels * 2 ** level for level in range(num_downsamples + 1)]

        self.in_conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=widths[0],
            kernel_size=(3,3,3)
        )

        self.enc_blocks = nn.Sequential(*[
            EncoderBlock(
                in_channels=c_in,
                out_channels=c_out,
                nblocks=nblocks,
                kernel_size=(3,3,3),
                causal_stride=(2,2,2)
            )
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])

        bottleneck = widths[-1]
        self.out_block = nn.Sequential(
            nn.GroupNorm(num_groups=bottleneck // 2, num_channels=bottleneck),
            nn.SiLU(),
            CausalConv3d(
                in_channels=bottleneck,
                out_channels=out_channels,
                kernel_size=(1,1,1)
            ),
        )

    def forward(self, x):
        for stage in (self.in_conv, self.enc_blocks, self.out_block):
            x = stage(x)
        return x

In [None]:
class Decoder(nn.Module):
    '''
    Counterpart of Encoder.

    A causal conv maps the latent channels to ``init_channels``, followed by
    ``nblocks`` ResBlock3d. Then ``num_upsamples`` DecoderBlocks each halve the
    channel count and upsample T/H/W by 2; the last block emits ``out_channels``.

    NOTE(review): SiLU is enabled only on the final DecoderBlock, i.e. it is
    applied to the reconstruction itself — confirm this is intended.
    '''
    def __init__(
            self,
            in_channels,
            out_channels,
            init_channels,
            num_upsamples,
            nblocks,
        ):
        super().__init__()
        widths = [init_channels // 2 ** step for step in range(num_upsamples)]

        # stem conv is built before its residual blocks
        stem_conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=widths[0],
            kernel_size=(3,3,3)
        )
        stem_res = nn.Sequential(*[
            ResBlock3d(in_channels=widths[0], out_channels=widths[0])
            for _ in range(nblocks)
        ])
        self.in_block = nn.Sequential(stem_conv, stem_res)

        last = num_upsamples - 1
        self.dec_blocks = nn.Sequential(*[
            DecoderBlock(
                in_channels=widths[i],
                out_channels=out_channels if i == last else widths[i + 1],
                nblocks=nblocks,
                upsample_time=True,
                silu=(i == last)
            )
            for i in range(num_upsamples)
        ])

    def forward(self, x):
        return self.dec_blocks(self.in_block(x))

### Models

In [None]:
class FSQVAE(nn.Module):
    '''
    Encoder -> FSQ -> Decoder autoencoder.

    ``forward`` returns ``{'z': encoder output, 'z_q': quantized latents (tensor),
    'x_hat': reconstruction}``.
    '''
    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(
            in_channels     = config.encoder.in_channels,
            out_channels    = config.encoder.out_channels,
            init_channels   = config.encoder.init_channels,
            num_downsamples = config.encoder.num_downsamples,
            nblocks         = config.encoder.nblocks
        )

        self.fsq = FSQ(levels = config.fsq.levels)

        self.decoder = Decoder(
            in_channels   = config.decoder.in_channels,
            out_channels  = config.decoder.out_channels,
            init_channels = config.decoder.init_channels,
            num_upsamples = config.decoder.num_upsamples,
            nblocks       = config.decoder.nblocks
        )

    def encode(self, x):
        return self.encoder(x)

    def quantize(self, z):
        '''
        Quantize latents and return the quantized *tensor*.

        FSQ.forward wraps its result as ``{'z_q': tensor}``; it is unwrapped here
        so ``decode`` receives a tensor. A bare tensor result is passed through.
        '''
        out = self.fsq(z)
        return out['z_q'] if isinstance(out, dict) else out

    def decode(self, z_q):
        return self.decoder(z_q)

    def forward(self, x):
        z = self.encode(x)
        z_q = self.quantize(z)
        x_hat = self.decode(z_q)
        return { 'z': z, 'z_q': z_q, 'x_hat': x_hat }

In [None]:
class Discriminator(nn.Module):
    '''
    Clip discriminator built from downsampling resblocks + leaky relu.

    conv -> ``num_downsamples`` ResBlockDown3d (each doubling the channels) -> conv
    -> flatten -> MLP -> sigmoid score.

    The MLP input size is derived from ``config.input_shape`` (C, T, H, W) when it
    is set. Otherwise it falls back to the former hard-coded 4x4 map with T == 1,
    which is what (1, 16, 128, 128) produces after 5 downsamples.
    '''
    def __init__(self, config):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv3d(
                in_channels  = config.in_channels,
                out_channels = config.init_channels,
                kernel_size  = (3, 3, 3),
                padding      = 'same'
            ),
            nn.LeakyReLU()
        )

        channels = [config.init_channels * (2 ** i) for i in range(config.num_downsamples+1)]
        self.downsample = nn.Sequential(*[
            ResBlockDown3d(
                in_channels  = channels[i],
                out_channels = channels[i+1]
            )
            for i in range(config.num_downsamples)
        ])

        self.conv2 = nn.Sequential(
            nn.Conv3d(
                in_channels  = channels[-1],
                out_channels = channels[-1],
                kernel_size  = (3, 3, 3),
                padding      = 'same'
            ),
            nn.LeakyReLU()
        )

        positions = self._feature_numel(config.input_shape, config.num_downsamples)
        self.mlp = nn.Sequential(
            nn.Linear(channels[-1] * positions, channels[-1]),
            nn.LeakyReLU(),
            nn.Linear(channels[-1], 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _downsampled_size(size, num_downsamples):
        '''Length of one axis after ``num_downsamples`` BlurPool3d(kernel 3, stride 2, padding 1) stages.'''
        for _ in range(num_downsamples):
            size = (size - 1) // 2 + 1
        return size

    @classmethod
    def _feature_numel(cls, input_shape, num_downsamples):
        '''Number of T*H*W positions per channel entering the MLP ('same' convs keep sizes).'''
        if input_shape is None:
            return 4 * 4
        return math.prod(cls._downsampled_size(s, num_downsamples) for s in input_shape[1:])

    def forward(self, x):
        x = self.conv1(x)
        x = self.downsample(x)
        x = self.conv2(x)
        x = rearrange(x, 'b c d h w -> b (c d h w)')
        x = self.mlp(x)
        return x

In [None]:
class MAGVIT2(nn.Module):
    '''
    MAGVIT-v2-style video tokenizer: an FSQ autoencoder plus a discriminator for GAN training.

    ``tokenize`` maps clips to flat codebook indices; ``decode`` maps indices back to
    clips using ``config.latent_shape`` (C, T', H', W').
    '''
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.fsqvae = FSQVAE(config.fsqvae)
        self.discriminator = Discriminator(config.discriminator)

    def forward(self, x):
        return self.fsqvae(x)

    def tokenize(self, x):
        '''
        Encode clips to int32 codebook indices of shape (B, T'*H'*W'),
        flattened in (d, h, w) order.
        '''
        z_q = self.fsqvae.quantize(self.fsqvae.encode(x))
        # FSQ.forward returns {'z_q': tensor}; accept either that or a bare tensor
        if isinstance(z_q, dict):
            z_q = z_q['z_q']
        z_q = rearrange(z_q, 'b c d h w -> b (d h w) c')
        return self.fsqvae.fsq.codes_to_idxs(z_q)

    def decode(self, idxs):
        '''Inverse of ``tokenize``: (B, T'*H'*W') indices -> reconstructed clips.'''
        _, d, h, w = self.config.latent_shape
        codes = self.fsqvae.fsq.idxs_to_codes(idxs)
        codes = rearrange(codes, 'b (d h w) c -> b c d h w', d=d, h=h, w=w)
        return self.fsqvae.decode(codes)

    def discriminate(self, x):
        return self.discriminator(x)

    def reconstruction_loss(self, x_hat, x):
        return F.mse_loss(x_hat, x)

    def adversarial_loss(self, y_hat, y):
        # discriminator ends in a Sigmoid, so plain BCE on probabilities
        return F.binary_cross_entropy(y_hat, y)

    def gradient_penalty(self, x, y_hat_real):
        '''
        Mean of (||dD(x)/dx||_2 - 1)^2 over the batch.

        ``y_hat_real`` must be the discriminator output computed *from* ``x``,
        and ``x`` must require grad; otherwise autograd cannot differentiate
        through it. ``create_graph=True`` lets the penalty itself be backpropagated.

        inspired by:
        https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L99
        '''
        gradients = torch_grad(
            outputs = y_hat_real,
            inputs = x,
            grad_outputs = torch.ones_like(y_hat_real),
            create_graph = True,
            retain_graph = True,
            only_inputs = True
        )[0]

        gradients = rearrange(gradients, 'b ... -> b (...)')
        return ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

## Lightning Module

In [None]:
class LitMAGVIT2(pl.LightningModule):
    '''
    Lightning module that trains MAGVIT2 adversarially: the FSQ-VAE is the
    generator and `magvit2.discriminator` is the discriminator.

    Lightning 2.x (the training cell uses precision="16-mixed") no longer passes
    `optimizer_idx` to `training_step` for multiple optimizers, so this module
    uses manual optimization and steps both optimizers (and their schedulers)
    itself.

    Args:
        magvit2: MAGVIT2 model exposing `fsqvae`, `discriminator`,
            `discriminate`, the loss helpers and `discriminator_weight`.
        config: MAGVIT2 config; reads `training.lr`, `training.beta1`,
            `training.beta2`, `weight_decay`, `use_lr_schedule`,
            `training_steps`, `discriminator.use_grad_penalty` and
            `discriminator.grad_penalty_weight`.
    '''
    def __init__(self, magvit2, config):
        super().__init__()

        self.magvit2 = magvit2
        self.config = config
        self.lr = config.training.lr
        # generator and discriminator optimizers are stepped by hand in training_step
        self.automatic_optimization = False

    def forward(self, x):
        return self.magvit2(x)

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        '''
        One generator update followed by one discriminator update.

        `optimizer_idx` is accepted only for backward compatibility and ignored.
        '''
        x = batch
        fsqvae_optimizer, disc_optimizer = self.optimizers()

        out = self(x)
        x_hat = out['x_hat']

        # --- generator (FSQ-VAE) update ---
        rec_loss = self.magvit2.reconstruction_loss(x_hat, x)
        y_hat_fake = self.magvit2.discriminate(x_hat)
        generator_loss = self.magvit2.adversarial_loss(y_hat_fake, torch.ones_like(y_hat_fake))
        total_loss = rec_loss + self.magvit2.discriminator_weight * generator_loss

        fsqvae_optimizer.zero_grad()
        self.manual_backward(total_loss)
        fsqvae_optimizer.step()

        self.log('train/rec_loss',       rec_loss)
        self.log('train/generator_loss', generator_loss)
        self.log('train/total_loss',     total_loss)

        # --- discriminator update ---
        use_grad_penalty = bool(self.config.discriminator.use_grad_penalty)
        # the gradient penalty differentiates D(x) w.r.t. the real input,
        # so the real clip must be a leaf tensor that requires grad
        x_real = x.detach().requires_grad_(use_grad_penalty)
        y_hat_real = self.magvit2.discriminate(x_real)
        y_hat_fake = self.magvit2.discriminate(x_hat.detach())

        real_loss = self.magvit2.adversarial_loss(y_hat_real, torch.ones_like(y_hat_real))
        fake_loss = self.magvit2.adversarial_loss(y_hat_fake, torch.zeros_like(y_hat_fake))
        disc_loss = (real_loss + fake_loss) / 2

        if use_grad_penalty:
            grad_penalty = self.magvit2.gradient_penalty(x_real, y_hat_real)
            disc_loss = disc_loss + self.config.discriminator.grad_penalty_weight * grad_penalty
            self.log('train/grad_penalty', grad_penalty)

        # clears the discriminator grads left behind by the generator backward pass
        disc_optimizer.zero_grad()
        self.manual_backward(disc_loss)
        disc_optimizer.step()

        self.log('train/discriminator_loss', disc_loss)

        self._step_lr_schedulers()

    def _step_lr_schedulers(self):
        '''Advance every LR scheduler once per training step (not automatic under manual optimization).'''
        schedulers = self.lr_schedulers()
        if schedulers is None:
            return
        if not isinstance(schedulers, list):
            schedulers = [schedulers]
        for scheduler in schedulers:
            scheduler.step()

    def validation_step(self, batch, batch_idx):
        x = batch
        out = self(x)

        rec_loss = self.magvit2.reconstruction_loss(out['x_hat'], x)
        self.log('val/rec_loss', rec_loss, prog_bar=True)

        # writing GIFs is slow; log sample clips once per validation run, not per batch
        if batch_idx == 0:
            self.log_val_clips(x, out)

        return rec_loss

    def configure_optimizers(self):
        betas = (self.config.training.beta1, self.config.training.beta2)
        # NOTE(review): weight_decay / use_lr_schedule / training_steps are read from the
        # top level of the config while lr and betas come from `config.training` — confirm.
        fsqvae_optimizer = torch.optim.AdamW(
            self.magvit2.fsqvae.parameters(),
            lr=self.lr,
            betas=betas,
            weight_decay=self.config.weight_decay
        )

        disc_optimizer = torch.optim.AdamW(
            self.magvit2.discriminator.parameters(),
            lr=self.lr,
            betas=betas,
            weight_decay=self.config.weight_decay
        )

        if self.config.use_lr_schedule:
            # stepped once per training batch in _step_lr_schedulers, so T_max counts steps
            fsqvae_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                fsqvae_optimizer, T_max=self.config.training_steps)

            disc_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                disc_optimizer, T_max=self.config.training_steps)

            return [fsqvae_optimizer, disc_optimizer], [fsqvae_scheduler, disc_scheduler]

        return [fsqvae_optimizer, disc_optimizer]

    @staticmethod
    def _clip_to_uint8(clip):
        '''
        Convert a (C, T, H, W) clip tensor into a uint8 array for GIF writing:
        (T, H, W, C), or (T, H, W) for single-channel clips, min-max scaled to [0, 255].
        '''
        clip_np = clip.detach().float().permute(1, 2, 3, 0).cpu().numpy()
        lo, hi = clip_np.min(), clip_np.max()
        # a constant clip (hi == lo) would otherwise divide by zero and yield NaNs
        clip_np = (clip_np - lo) / max(hi - lo, 1e-8)
        clip_np = (clip_np * 255).astype(np.uint8)

        # grayscale videos need to be of shape (T, H, W)
        if clip_np.shape[-1] == 1:
            clip_np = clip_np.squeeze(-1)
        return clip_np

    def log_val_clips(self, x, out, num_clips=2):
        '''Write up to `num_clips` original/reconstructed clip pairs as GIFs and log them to wandb.'''
        n_clips = min(x.size(0), num_clips)

        for i in range(n_clips):
            original_clip_np = self._clip_to_uint8(x[i])
            reconstructed_clip_np = self._clip_to_uint8(out['x_hat'][i])

            # create GIFs for the original and reconstructed clips
            original_gif_path = f'/tmp/original_clip_{i}.gif'
            reconstructed_gif_path = f'/tmp/reconstructed_clip_{i}.gif'
            imageio.mimsave(original_gif_path, original_clip_np, fps=10)
            imageio.mimsave(reconstructed_gif_path, reconstructed_clip_np, fps=10)

            # log the GIFs to wandb
            self.logger.experiment.log({
                f"val/original_clip_{i}": wandb.Video(original_gif_path, fps=10, format="gif", caption="Original"),
                f"val/reconstructed_clip_{i}": wandb.Video(reconstructed_gif_path, fps=10, format="gif", caption="Reconstructed")
            })

# Testing

In [None]:
# random clip batch, presumably (B, C, T, H, W) like the 'b c d h w' layouts above
x = torch.randn(size=(5, 3, 16, 128, 128))
x = x.cuda()

In [None]:
# x has 3 channels along dim 1, so the conv is built with in_channels=3.
# NOTE(review): assumes CausalConv3d wraps a standard Conv3d, which rejects a
# 3-channel input when constructed with in_channels=1.
conv = CausalConv3d(in_channels=3, out_channels=16).cuda()
out = conv(x)

In [None]:
# shape of the causal conv output
out.size()

In [None]:
# channel-preserving residual block
res_block = ResBlock3d(out_channels=16, in_channels=16)
res_block = res_block.cuda()
out = res_block(out)

In [None]:
# shape after the 16 -> 16 residual block
out.size()

In [None]:
# residual block that widens 16 -> 32 channels
res_block2 = ResBlock3d(out_channels=32, in_channels=16)
res_block2 = res_block2.cuda()
out = res_block2(out)

In [None]:
# shape after the 16 -> 32 residual block
out.size()

In [None]:
# NOTE(review): BlurPool3d's first argument is defined outside this view; if it is a
# channel count, 1 does not match the 3 channels of x — confirm against BlurPool3d.
blur_pool = BlurPool3d(1)
blur_pool = blur_pool.cuda()
out = blur_pool(x)

In [None]:
# shape after blur pooling
out.size()

In [None]:
# downsampling residual block applied to the blur-pooled tensor
res_block_down3d = ResBlockDown3d(out_channels=32, in_channels=3)
res_block_down3d = res_block_down3d.cuda()
out = res_block_down3d(out)

In [None]:
# shape after the downsampling residual block
out.size()

In [None]:
# NOTE(review): MAGVIT2 constructs Discriminator(config.discriminator); this keyword
# form may not match the current constructor — confirm against Discriminator.__init__.
disc = Discriminator(in_channels=3)
disc = disc.cuda()
out = disc(x)

In [None]:
# shape of the discriminator output
out.size()

In [None]:
# encoder: 3 input channels -> 5 latent channels through 3 downsampling stages
enc = Encoder(
    in_channels=3,
    init_channels=32,
    out_channels=5,
    nblocks=2,
    num_downsamples=3,
)
enc = enc.cuda()
out = enc(x)

In [None]:
# latent shape produced by the encoder
out.size()

In [None]:
# 32-channel input for the upsampling tests
x = torch.randn(size=(5, 32, 16, 128, 128))
x = x.cuda()

In [None]:
# upsampling with the default time-upsampling setting
upsample = Upsample3d(out_channels=64, in_channels=32)
upsample = upsample.cuda()
out = upsample(x)

In [None]:
# shape after default upsampling
out.size()

In [None]:
# upsampling with upsample_time=False
upsample = Upsample3d(upsample_time=False, out_channels=64, in_channels=32)
upsample = upsample.cuda()
out = upsample(x)

In [None]:
# shape after upsampling with upsample_time=False
out.size()

In [None]:
# latent-sized input for the decoder test
x = torch.randn(size=(5, 5, 2, 16, 16))
x = x.cuda()

In [None]:
# decoder: 5 latent channels -> 3 output channels through 3 upsampling stages
dec = Decoder(
    in_channels=5,
    init_channels=512,
    out_channels=3,
    nblocks=2,
    num_upsamples=3,
)
dec = dec.cuda()
out = dec(x)

In [None]:
# shape of the decoded output
out.size()

## Training

In [None]:
resume_training = False

project_config  = Config(config['project'])
magvit2_config  = Config(config['magvit2'])
training_config = Config(config['training'])
data_config     = Config(config['data'])

model = MAGVIT2(magvit2_config)
lit_model = LitMAGVIT2(model, magvit2_config)
data = SteamboatWillieDataModule(data_config)

# LitMAGVIT2.validation_step logs 'val/rec_loss'; checkpointing and early stopping
# must monitor a key that is actually logged, otherwise they fail after validation.
val_metric = 'val/rec_loss'

if resume_training:
    # wandb.init has no `run` argument; an existing run is resumed by its id
    run = wandb.init(
        project=project_config.magvit2.project_name,
        id="[RUN ID]",
        config=config['magvit2'],
        resume=True
    )
    artifact = run.use_artifact('[ARTIFACT NAME]', type='model')
    artifact_dir = artifact.download()
else:
    wandb.init(
        project=project_config.magvit2.project_name,
        # the display name of a new run is set with `name`
        name="[RUN NAME]",
        # NOTE(review): assumes `config` is a plain dict (it is indexed like one above);
        # a dict has no `.to_dict()`, so it is handed to wandb directly.
        config=config
    )

wandb_logger = WandbLogger(
    project=project_config.magvit2.project_name,
    log_model=True
)
wandb_logger.watch(lit_model, log="all")

lr_monitor = LearningRateMonitor(logging_interval='step')

checkpoint_callback = ModelCheckpoint(
    dirpath=magvit2_config.checkpoint_dir,
    filename='magvit2-{epoch:05d}',
    every_n_epochs=5,
    save_top_k=training_config.save_top_k,
    monitor=val_metric,
    mode='min'
)

# stop after 100 validation checks without improvement of the reconstruction loss
early_stop_callback = EarlyStopping(
    monitor=val_metric,
    min_delta=0.0,
    patience=100,
    verbose=True,
    check_finite=True
)

# NOTE(review): max_epochs comes from `training_config.magvit2` while save_top_k is read
# from the top level of `training_config` — confirm the intended nesting.
trainer = pl.Trainer(
    max_epochs=training_config.magvit2.max_epochs,
    devices=1,
    accelerator="gpu",
    precision="16-mixed",
    logger=wandb_logger,
    callbacks=[
        lr_monitor,
        early_stop_callback,
        checkpoint_callback
    ],
    log_every_n_steps=1,
    # gradient_clip_val=1.0,
    # overfit_batches=1,
)

# tuner = Tuner(trainer)
# lr_result = tuner.lr_find(lit_model, datamodule=data, max_lr=100)
# lr_result.plot(show=True, suggest=True)
if resume_training:
    trainer.fit(lit_model, data, ckpt_path='[CHECKPOINT PATH]')
else:
    trainer.fit(lit_model, data)

wandb.finish()

## Sweeps

# Transformer Decoder

## Model

In [None]:
class Attention(nn.Module):
    '''
    Causal multi-head self-attention with ALiBi position biases and an
    optional sliding window.

    Both code paths follow flash-attn's conventions: key j is visible to
    query i when i - window_size <= j <= i (a negative window_size removes the
    left limit), and head h adds -slopes[h] * (i - j) to its attention scores.

    Args:
        emb_dim: model width D; must be divisible by `nheads`.
        nheads: number of attention heads H.
        ctx_size: maximum sequence length supported by the eager causal mask.
        window_size: number of past positions each query may attend to in
            addition to itself; negative for unlimited.
        dropout: dropout probability on the attention weights.
        use_flash_attn: use `flash_attn_func`; otherwise an equivalent eager path.
    '''
    def __init__(
            self,
            emb_dim,
            nheads,
            ctx_size,
            window_size,
            dropout=0.0,
            use_flash_attn=True
        ):
        super().__init__()
        assert emb_dim % nheads == 0

        self.W_Q = nn.Linear(emb_dim, emb_dim, bias=False)
        self.W_K = nn.Linear(emb_dim, emb_dim, bias=False)
        self.W_V = nn.Linear(emb_dim, emb_dim, bias=False)

        self.W_O = nn.Linear(emb_dim, emb_dim, bias=False)

        if not use_flash_attn:
            self.register_buffer(
                "mask",
                torch.tril(torch.ones((ctx_size, ctx_size))).reshape(
                    1, 1, ctx_size, ctx_size
                ),
            )

        # needed by forward to merge the heads back into (B, T, D)
        self.emb_dim = emb_dim
        self.head_dim = emb_dim // nheads
        self.nheads = nheads
        self.window_size = window_size

        self.dropout_p = dropout
        self.dropout = nn.Dropout(dropout)

        self.use_flash_attn = use_flash_attn

    def forward(self, x, slopes):
        '''
        Args:
            x: (B, T, D) input.
            slopes: per-head ALiBi slopes with `nheads` elements, shape (H,)
                or (H, 1, 1); None disables ALiBi.

        Returns:
            (B, T, D) output.
        '''
        B, T, D = x.size()

        # (B, T, D) -> (B, T, H, H_D)
        Q = self.W_Q(x).reshape(B, T, self.nheads, self.head_dim)
        K = self.W_K(x).reshape(B, T, self.nheads, self.head_dim)
        V = self.W_V(x).reshape(B, T, self.nheads, self.head_dim)

        if self.use_flash_attn:
            # flash-attn consumes and returns (B, T, H, H_D)
            out = flash_attn_func(
                Q, K, V,
                dropout_p=self.dropout_p if self.training else 0.0,
                softmax_scale=None,
                causal=True,
                window_size=(self.window_size, 0),
                alibi_slopes=None if slopes is None else slopes.to(torch.float32),
                deterministic=False
            )
        else:
            # (B, T, H, H_D) -> (B, H, T, H_D)
            Q = Q.transpose(1, 2)
            K = K.transpose(1, 2)
            V = V.transpose(1, 2)

            # (B, H, T, H_D) @ (B, H, H_D, T) -> (B, H, T, T)
            attn = (Q @ K.transpose(-2, -1)) / math.sqrt(self.head_dim)

            pos = torch.arange(T, device=x.device)
            # dist[i, j] = i - j: how far key j lies in the past of query i
            dist = pos[:, None] - pos[None, :]

            if slopes is not None:
                # ALiBi: linear per-head penalty on distance, as in flash-attn's alibi_slopes
                head_slopes = slopes.reshape(1, self.nheads, 1, 1).to(attn.dtype)
                attn = attn - head_slopes * dist.to(attn.dtype)

            mask = self.mask[:, :, :T, :T] == 0
            if self.window_size >= 0:
                # sliding window: drop keys more than window_size steps in the past
                mask = mask | (dist > self.window_size)
            attn = attn.masked_fill(mask, float('-inf'))
            attn = F.softmax(attn, dim=-1)

            attn = self.dropout(attn)

            # (B, H, T, T) @ (B, H, T, H_D) -> (B, H, T, H_D) -> (B, T, H, H_D)
            out = (attn @ V).transpose(1, 2)

        # (B, T, H, H_D) -> (B, T, D)
        out = out.reshape(B, T, self.emb_dim)

        # (B, T, D)
        return self.W_O(out)

In [None]:
class MLP(nn.Module):
    '''
    Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is `fan_out` times wider than `emb_dim`; the output has
    the same width as the input.
    '''
    def __init__(self, emb_dim, fan_out=4, dropout=0.0):
        super().__init__()
        hidden_dim = fan_out * emb_dim
        self.fc1 = nn.Linear(emb_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.dropout(self.fc2(hidden))

In [None]:
class TransformerBlock(nn.Module):
    '''
    Pre-LayerNorm transformer block:

        x = x + Attention(LayerNorm(x), slopes)
        x = x + MLP(LayerNorm(x))

    `slopes` are the per-head ALiBi slopes forwarded unchanged to the attention layer.
    '''
    def __init__(
            self,
            emb_dim,
            nheads,
            ctx_size,
            window_size,
            fan_out=4,
            dropout=0.0,
            use_flash_attn=True
        ):
        super().__init__()
        self.ln_1 = nn.LayerNorm(emb_dim)
        self.attn = Attention(
            emb_dim=emb_dim,
            nheads=nheads,
            ctx_size=ctx_size,
            window_size=window_size,
            dropout=dropout,
            use_flash_attn=use_flash_attn,
        )
        self.ln_2 = nn.LayerNorm(emb_dim)
        self.mlp = MLP(emb_dim=emb_dim, fan_out=fan_out, dropout=dropout)

    def forward(self, x, slopes):
        attn_out = self.attn(self.ln_1(x), slopes)
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out

In [None]:
class TransformerDecoder(nn.Module):
    '''
    Stack of TransformerBlocks followed by a final LayerNorm and a bias-free
    projection to `config.vocab_size` logits.

    Per-head ALiBi slopes are computed once, stored in the `slopes` buffer and
    handed to every block.

    Reads from `config`: emb_dim, nheads, ctx_size, window_size, fan_out,
    dropout, use_flash_attn, nlayers, vocab_size.
    '''
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList([
            TransformerBlock(
                config.emb_dim,
                config.nheads,
                config.ctx_size,
                config.window_size,
                config.fan_out,
                config.dropout,
                config.use_flash_attn
            ) for _ in range(config.nlayers)
        ])
        self.pred_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        self.ln_f = nn.LayerNorm(config.emb_dim)

        self.register_buffer("slopes", self.get_alibi_slope(config.nheads))
        self.window_size = config.window_size

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Position information comes only from the per-head slopes passed to each block.
        # (The previous non-flash branch here built an unused bias from the undefined
        # `self.m`, which raised AttributeError on every call.)

        # NOTE(review): `emb_up_proj` is never created in __init__, so this raises
        # AttributeError unless it is attached elsewhere — confirm what `x` is
        # (token ids vs. embeddings) and define the projection accordingly.
        x = self.emb_up_proj(x)

        x = self.dropout(x)

        for block in self.blocks:
            x = block(x, self.slopes)

        return self.pred_head(self.ln_f(x))

    def get_relative_positions(self, seq_len: int) -> torch.Tensor:
        '''Return a (seq_len, seq_len) tensor whose [i, j] entry is min(j - i, 0).'''
        x = torch.arange(seq_len)[None, :]
        y = torch.arange(seq_len)[:, None]
        return (x - y).clamp_max_(0)

    def get_alibi_slope(self, num_heads):
        '''
        Return the ALiBi slopes 2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads).

        Shape is (num_heads,) when `config.use_flash_attn` is set (flash-attn's
        `alibi_slopes` layout), otherwise (num_heads, 1, 1).
        '''
        x = (2 ** 8) ** (1 / num_heads)
        if self.config.use_flash_attn:
            x = torch.tensor([1 / x ** (i + 1) for i in range(num_heads)])
        else:
            x = torch.tensor([1 / x ** (i + 1) for i in range(num_heads)]).unsqueeze(-1).unsqueeze(-1)

        return x

## Lightning Module

In [None]:
class StepwiseValidationScheduleCallback(Callback):
    '''
    Call `trainer.validate` every `config.validation_step_interval` batches.

    NOTE(review): `on_batch_end` is not a hook in the Lightning 2.x Callback API
    (it was removed in 1.8), so Lightning never invokes this method. Renaming it
    to `on_train_batch_end` alone is likely insufficient, because running
    `trainer.validate()` from inside `trainer.fit()` can interfere with the
    active fit loop. Prefer `pl.Trainer(val_check_interval=...)` for step-based
    validation.
    '''
    def __init__(self, config):
        super().__init__()
        self.config = config

    def on_batch_end(self, trainer, pl_module):
        interval = self.config.validation_step_interval
        step = trainer.global_step + 1
        if step % interval:
            return
        trainer.validate(pl_module)

In [None]:
class StepwiseModelCheckpoint(ModelCheckpoint):
    '''
    ModelCheckpoint that additionally writes `<dirpath>/stepwise-step=<N>.ckpt`
    every `save_step_frequency` optimizer steps, on top of the parent's own
    checkpoints.

    Step checkpoints are written with the public `trainer.save_checkpoint` and
    are not pruned by `save_top_k`.

    Args:
        save_step_frequency: save when `trainer.global_step` is a positive
            multiple of this value; values <= 0 disable step-based saving.
        dirpath: checkpoint directory (also forwarded to ModelCheckpoint).
        filename: forwarded to ModelCheckpoint for its own checkpoints.
        **kwargs: forwarded to ModelCheckpoint.
    '''
    def __init__(self, save_step_frequency, dirpath, filename="{step}", **kwargs):
        super().__init__(dirpath=dirpath, filename=filename, **kwargs)
        self.save_step_frequency = save_step_frequency
        # with gradient accumulation several batches share one global_step;
        # remember the last step saved so each step is written at most once
        self._last_stepwise_save = None

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # Lightning 2.x calls this hook without `dataloader_idx`, and the parent
        # hook does not accept it; the parameter is kept only for old callers.
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

        global_step = trainer.global_step
        if self.save_step_frequency <= 0 or global_step == 0:
            return
        if global_step % self.save_step_frequency != 0 or global_step == self._last_stepwise_save:
            return

        filepath = os.path.join(self.dirpath, f"stepwise-step={global_step}.ckpt")
        # the previous code used ModelCheckpoint internals (`_save_model`, `_del_model`,
        # `format_checkpoint_name(..., verbatim=...)`) that do not exist in this form
        trainer.save_checkpoint(filepath)
        self._last_stepwise_save = global_step

In [None]:
class LitTransformerDecoder(pl.LightningModule):
    '''
    Lightning wrapper that trains `model` with token-level cross-entropy.

    Batches are `(x, y)` pairs: logits are computed from `x`, and `y` holds one
    target class index per logit position.

    Args:
        model: module mapping `x` to logits whose last dimension is the vocabulary.
        config: reads `lr`, `beta1`, `beta2`, `weight_decay`, `use_lr_schedule`
            and `training_steps`.
    '''
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.lr = config.lr

        # NOTE(review): no trainer is attached while __init__ runs, so `self.logger`
        # is likely None here and this config upload is skipped.
        if self.logger:
            self.logger.experiment.config.update(config)

    def forward(self, x):
        return self.model(x)

    def _token_cross_entropy(self, batch):
        '''Mean cross-entropy over every position of the flattened logits/targets.'''
        inputs, targets = batch
        logits = self(inputs)
        vocab_size = logits.size(-1)
        return F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))

    def training_step(self, batch, batch_idx):
        loss = self._token_cross_entropy(batch)
        self.log('train/loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._token_cross_entropy(batch)
        self.log('val/loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        cfg = self.config
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            betas=(cfg.beta1, cfg.beta2),
            weight_decay=cfg.weight_decay
        )

        if not cfg.use_lr_schedule:
            return optimizer

        # cosine decay stepped every optimizer step, so T_max counts steps
        scheduler = {
            'scheduler': CosineAnnealingLR(optimizer, T_max=cfg.training_steps, eta_min=0),
            'interval': 'step',
            'frequency': 1,
        }
        return [optimizer], [scheduler]

## Testing

## Training

## Sweeps

# Super Resolution

## Model

## Lightning Module

## Testing

## Training

## Sweeps