<a href="https://colab.research.google.com/github/JordanLazzaro/VideoGen/blob/main/notebooks/VideoPoet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VideoPoet Implementation

![](https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgxFblHaHRJNH7Oi2_oOTosGN9XrjgjhWmnfADchMT8WR0XAo6SxiUfpUmn5R6akciiRduaKIMdgwHZzK3xW8mErarQ_ugx41ctQAMK08O9UMVevgkk-AgFI1xYFWAomd16OcOh0R-XpyZVLQXncpk2SHf-RmPzrqBbIWZc-nUG2TH6nC2R7qyHXn8eTC-u/s2680/image21.png)

[**VideoPoet**](https://research.google/blog/videopoet-a-large-language-model-for-zero-shot-video-generation/) is an autoregressive transformer decoder that models sequences of discrete, multimodal tokens produced by modality-specific tokenizers: [MAGVIT-V2](https://arxiv.org/abs/2310.05737) for images/video and [SoundStream](https://arxiv.org/abs/2107.03312) for audio. For text, the model uses the T5 tokenizer together with frozen T5 token embeddings. Together, these token sets make up the model's full vocabulary. VideoPoet also employs a custom super-resolution model, so the transformer can generate lower-resolution videos for efficiency and have them upscaled afterwards.

Implementing the video portion of this model involves the following milestones:

- [ ] Implement MAGVIT-V2 tokenizer
- [ ] Implement the Transformer Decoder
- [ ] Implement the Super Resolution model

# Setup

In [None]:
# Install the third-party packages imported below (-q keeps pip's output quiet).
!pip install -q wandb pytorch_lightning av imageio
# flash-attn is installed on its own with pip's build isolation turned off.
!pip install -q flash-attn --no-build-isolation

In [None]:
import os
import gc
import math
import wandb
import imageio
import torch
import torchvision
import numpy as np
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML
from torch import nn, optim
from torch.nn import functional as F
from torchvision.datasets.video_utils import VideoClips
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Compose, Lambda, Resize, ToTensor, CenterCrop, Grayscale
import torchvision.transforms.functional as TF

from pytorch_lightning import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.tuner import Tuner
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

pl.seed_everything(42)

In [None]:
# Mount Google Drive so files stored there are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Path of the full source video on the mounted drive.
steamboat_willie_gdrive_path = '/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4'
# Copy the `clips` directory from Drive into the current working directory.
!cp -r /content/drive/My\ Drive/SteamboatWillie/clips .

In [None]:
def display_clip(clip):
    '''
    Render a video tensor as an inline HTML5 video in the notebook.

    Args:
        clip: PyTorch tensor of shape (C,T,H,W). It may live on any device
            and may require grad; it is detached and copied to host memory
            before being converted to NumPy.

    Raises:
        AssertionError: if `clip` does not have exactly 4 dimensions.
    '''
    assert len(clip.shape) == 4, 'clip shape must be PyTorch Tensor of shape (C,T,H,W)'

    # (C,T,H,W) -> (T,H,W,C) so that each leading index is one frame.
    # detach()/cpu() are required because Tensor.numpy() refuses CUDA tensors
    # and tensors that are part of an autograd graph.
    video_clip_np = clip.detach().cpu().permute(1, 2, 3, 0).numpy()

    fig, ax = plt.subplots()

    def update(frame_idx):
        # redraw the axes from scratch for every frame
        ax.clear()
        ax.imshow(video_clip_np[frame_idx], cmap='gray')
        ax.axis('off')

    ani = FuncAnimation(fig, update, frames=range(video_clip_np.shape[0]), interval=60)
    # close the figure so only the rendered video is shown, not a static plot
    plt.close()
    display(HTML(ani.to_html5_video()))

# Data

## Clips Dataset

In [None]:
class RandomHorizontalFlipVideo(torch.nn.Module):
    """Flip a whole video clip along its last axis with probability ``p``.

    A single random draw is made per call, so every frame of the clip is
    either flipped together or left untouched.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        # x is expected to be laid out as (C, T, H, W); only the last axis is reversed
        should_flip = bool(torch.rand(1) < self.p)
        return torch.flip(x, dims=(-1,)) if should_flip else x


class SteamboatWillieDataset(Dataset):
    """
    Dataset of fixed-length video clips backed by memory-mapped .bin files.

    Clips are preprocessed once (grayscale, center crop, resize), stored as raw
    uint8 files named `clip_<idx>.bin` in `config.clip_dest_dir`, and mapped
    read-only on later runs. Each item is a float32 tensor scaled to [0, 1]
    with shape (in_channels, clip_length, img_size, img_size).

    Args:
        config: object providing `img_size`, `in_channels`, `clip_length`,
            `clip_dest_dir` and `paths` (the source videos handed to VideoClips).
        mode: 'train', 'val' or 'full'. 'train'/'val' select the two sides of
            the split; 'full' exposes every clip in index order.
        train_split: fraction of clips assigned to the 'train' side.
        shuffle: whether to permute the clips before splitting.
        augment: if True and mode == 'train', add a random horizontal flip.
        seed: seed of the private generator used for the shuffle. Datasets
            built with the same seed see the same permutation, so their
            'train' and 'val' sides never overlap, whatever the state of the
            global RNG.

    Raises:
        ValueError: if `mode` is not one of 'train', 'val', 'full'.
    """

    def __init__(self, config, mode='train', train_split=0.8, shuffle=True, augment=True, seed=42):
        super().__init__()
        # fail early instead of leaving `clip_indices` undefined
        if mode not in ('train', 'val', 'full'):
            raise ValueError(f"mode must be 'train', 'val' or 'full', got {mode!r}")

        self.config = config
        self.train_split = train_split
        self.mode = mode

        self.preprocess_transforms = Compose([
                Lambda(lambda x: x.permute(0, 3, 1, 2)),   # (T, H, W, C) to (T, C, H, W) for Grayscale
                Grayscale(num_output_channels=1),          # Convert to grayscale
                Lambda(lambda x: x.permute(1, 0, 2, 3)),   # (T, C, H, W) to (C, T, H, W) for Conv3d
                CenterCrop((480, 575)),                    # Center crop to remove vertical bars
                Resize((config.img_size, config.img_size))
        ])

        self.postprocess_transforms = Compose([
            Lambda(lambda x: x / 255.),                    # uint8 range -> [0, 1]
            # the memmaps are flat, so restore the (C, T, H, W) layout here
            Lambda(lambda x: x.view(self.config.in_channels, self.config.clip_length, self.config.img_size, self.config.img_size))
        ])

        if self.mode == 'train' and augment:
            self.postprocess_transforms.transforms.append(RandomHorizontalFlipVideo(p=0.5))

        # reuse clips already on disk; an existing but empty directory falls
        # through to extraction instead of yielding an empty dataset
        clip_paths = {}
        if os.path.isdir(config.clip_dest_dir):
            clip_paths = self.build_existing_clip_paths(config.clip_dest_dir)

        if not clip_paths:
            video_clips = VideoClips(
                config.paths,
                clip_length_in_frames=config.clip_length,
                frames_between_clips=config.clip_length
            )
            clip_paths = self.build_clip_paths(video_clips, self.preprocess_transforms, config.clip_dest_dir)

        self.clips = self.build_clip_refs(clip_paths)

        # split over the clip ids that actually exist, not range(len(...)),
        # so gaps in the on-disk numbering cannot cause a KeyError
        self.clip_indices = self._select_indices(sorted(self.clips), mode, train_split, shuffle, seed)

    @staticmethod
    def _select_indices(clip_keys, mode, train_split, shuffle, seed):
        """
        Pick the clip ids served by a dataset in the given mode.

        'full' returns every id in the given order. 'train'/'val' optionally
        permute the ids with a generator seeded by `seed` and return the first
        `int(len * train_split)` ids or the remainder respectively.
        """
        keys = list(clip_keys)
        if mode == 'full':
            return keys

        if shuffle:
            # private generator: the permutation depends only on `seed`, never
            # on how much of the global RNG stream was consumed before
            generator = torch.Generator().manual_seed(seed)
            order = torch.randperm(len(keys), generator=generator).tolist()
            keys = [keys[i] for i in order]

        train_size = int(len(keys) * train_split)
        return keys[:train_size] if mode == 'train' else keys[train_size:]

    def build_clip_paths(self, video_clips, transforms, clip_dest_dir):
        """
        Build set of binary files to store processed video clips
        returns dict of clip_idx -> mmapped file path

        `transforms` is applied to every raw clip before it is written.
        """
        clip_paths = {}

        os.makedirs(clip_dest_dir, exist_ok=True)

        for idx in tqdm(range(video_clips.num_clips()), desc='Creating clip .bin files'):
            # transform clips and write to mmap file
            clip, _, _, _ = video_clips.get_clip(idx)
            clip = transforms(clip)
            clip_np = clip.numpy().astype(np.uint8)

            mmapped_file_path = os.path.join(clip_dest_dir, f'clip_{idx}.bin')
            fp = np.memmap(mmapped_file_path, dtype='uint8', mode='w+', shape=clip_np.shape)
            fp[:] = clip_np[:]
            fp.flush()
            # drop the writable mapping once the data is on disk
            del fp
            clip_paths[idx] = mmapped_file_path

        return clip_paths

    def build_existing_clip_paths(self, clip_dest_dir):
        """
        returns dict of clip_idx -> mmapped file path
        from existing .bin files

        Only files named `clip_<integer>.bin` are considered; anything else in
        the directory is ignored.
        """
        prefix, suffix = 'clip_', '.bin'
        clips_paths = {}
        for filename in os.listdir(clip_dest_dir):
            if not (filename.startswith(prefix) and filename.endswith(suffix)):
                continue
            stem = filename[len(prefix):-len(suffix)]
            # skip stray files such as `clip_tmp.bin` instead of crashing in int()
            if not stem.isdecimal():
                continue
            clips_paths[int(stem)] = os.path.join(clip_dest_dir, filename)

        return clips_paths

    def build_clip_refs(self, clip_paths):
        """
        Build mmap reference to bin files
        returns dict of clip_idx -> np.array mmapped to respective bin file
        """
        clips = {}
        for idx, path in tqdm(clip_paths.items(), desc='Building clip refs'):
            # no shape is given, so each mapping is a flat read-only uint8 array
            clips[idx] = np.memmap(path, dtype='uint8', mode='r')

        return clips

    def __len__(self):
        return len(self.clip_indices)

    def __getitem__(self, idx):
        clip = self.clips[self.clip_indices[idx]]
        # torch.tensor copies out of the read-only mapping and converts to float32
        return self.postprocess_transforms(torch.tensor(clip, dtype=torch.float32))

# MAGVIT-V2

## Model

## Testing

## Training

## Sweeps

# Transformer Decoder

## Model

## Testing

## Training

## Sweeps

# Super Resolution

## Model

## Testing

## Training

## Sweeps