In [None]:
# Colab setup: `%pip` installs into the environment of the running kernel (a bare `!pip`
# may resolve to a different interpreter). TODO: pin versions once a known-good set is recorded.
%pip install -q wandb pytorch_lightning av imageio

In [2]:
import os
import math
import wandb
import imageio
import torch
import torchvision
import numpy as np
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML
from torch import nn, optim
from torch.nn import functional as F
from torchvision.datasets.video_utils import VideoClips
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Compose, Lambda, Resize, ToTensor, CenterCrop, Grayscale
import torchvision.transforms.functional as TF

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.tuner import Tuner
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

pl.seed_everything(42)

In [None]:
# Colab-only: mount Google Drive so the source video and any previously cached clips are reachable.
from google.colab import drive
drive.mount('/content/drive')
# Location of the raw video on Drive.
# NOTE(review): this variable is not referenced again in the visible notebook
# (VideoVQVAEConfig.paths hard-codes the same path) — confirm whether it is still needed.
steamboat_willie_gdrive_path = '/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4'
# The shell command that follows copies the `clips` directory from Drive into the working directory.
!cp -r /content/drive/My\ Drive/SteamboatWillie/clips .

# Data

In [None]:
class SteamboatWillieDataset(Dataset):
    """
    Dataset of fixed-length grayscale video clips cached on disk as raw uint8 files.

    If ``config.dest_dir`` already holds ``clip_<idx>.bin`` files they are memory-mapped
    directly. Otherwise the videos in ``config.paths`` are cut into non-overlapping clips of
    ``config.clip_length`` frames, preprocessed, and written to ``config.dest_dir`` first.

    Each item is a float32 tensor viewed as
    (in_channels, clip_length, img_size, img_size) with values divided by 255.
    """

    CLIP_PREFIX = 'clip_'
    CLIP_SUFFIX = '.bin'

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Applied once, when the clip cache is built.
        self.preprocess_transforms = Compose([
            Lambda(lambda x: x.permute(0, 3, 1, 2)),  # (T, H, W, C) to (T, C, H, W) for Grayscale
            Grayscale(num_output_channels=1),         # convert to grayscale
            Lambda(lambda x: x.permute(1, 0, 2, 3)),  # (T, C, H, W) to (C, T, H, W) for Conv3d
            CenterCrop((480, 575)),                   # center crop to remove vertical bars
            Resize((config.img_size, config.img_size)),
        ])

        # Applied on every __getitem__: the memmap is flat, so scale and restore the clip shape.
        self.postprocess_transforms = Compose([
            Lambda(lambda x: x / 255.),
            Lambda(lambda x: x.view(self.config.in_channels, self.config.clip_length, self.config.img_size, self.config.img_size))
        ])

        clip_paths = {}
        if os.path.exists(config.dest_dir):
            clip_paths = self.build_existing_clip_paths(config.dest_dir)

        if not clip_paths:
            # No cache directory, or a directory without any clip files: build the cache from
            # the source videos instead of silently ending up with an empty dataset.
            video_clips = VideoClips(
                config.paths,
                clip_length_in_frames=config.clip_length,
                frames_between_clips=config.clip_length
            )
            clip_paths = self.build_clip_paths(video_clips, self.preprocess_transforms, config.dest_dir)

        self.clips = self.build_clip_refs(clip_paths)

    def build_clip_paths(self, video_clips, transforms, dest_dir):
        """
        Build set of binary files to store processed video clips.

        video_clips: object exposing num_clips() and get_clip(idx) (a torchvision VideoClips)
        transforms:  callable applied to every raw clip before it is written
        dest_dir:    directory that receives one clip_<idx>.bin per clip (created if missing)

        returns dict of clip_idx -> mmapped file path
        """
        clip_paths = {}

        os.makedirs(dest_dir, exist_ok=True)

        for idx in tqdm(range(video_clips.num_clips()), desc='Creating clip .bin files'):
            # transform clips and write to mmap file
            clip, _, _, _ = video_clips.get_clip(idx)
            clip = transforms(clip)
            clip_np = clip.numpy().astype(np.uint8)

            mmapped_file_path = os.path.join(dest_dir, f'{self.CLIP_PREFIX}{idx}{self.CLIP_SUFFIX}')
            fp = np.memmap(mmapped_file_path, dtype='uint8', mode='w+', shape=clip_np.shape)
            fp[:] = clip_np[:]
            fp.flush()
            del fp  # drop the writable mapping so the file is closed before it is re-opened read-only
            clip_paths[idx] = mmapped_file_path

        return clip_paths

    def build_existing_clip_paths(self, dest_dir):
        """
        Collect clip files that already exist in dest_dir.

        Only names of the exact form clip_<non-negative integer>.bin are accepted; anything
        else in the directory is ignored.

        returns dict of clip_idx -> mmapped file path
        """
        clips_paths = {}
        for filename in os.listdir(dest_dir):
            if not (filename.startswith(self.CLIP_PREFIX) and filename.endswith(self.CLIP_SUFFIX)):
                continue
            stem = filename[len(self.CLIP_PREFIX):-len(self.CLIP_SUFFIX)]
            if not stem.isdecimal():
                # e.g. "clip_backup.bin": not one of our cache files
                continue
            clips_paths[int(stem)] = os.path.join(dest_dir, filename)

        return clips_paths

    def build_clip_refs(self, clip_paths):
        """
        Build mmap reference to bin files.

        returns dict of clip_idx -> flat, read-only uint8 np.memmap over the respective bin file
        """
        clips = {}
        for idx, path in tqdm(clip_paths.items(), desc='Building clip refs'):
            clips[idx] = np.memmap(path, dtype='uint8', mode='r')

        return clips

    def __len__(self):
        """Number of cached clips."""
        return len(self.clips)

    def __getitem__(self, idx):
        """Return clip `idx` as a float32 tensor after the postprocess transforms."""
        clip = self.clips[idx]
        # torch.tensor copies the data, so the read-only memmap is never written through.
        return self.postprocess_transforms(torch.tensor(clip, dtype=torch.float32))


class SteamboatWillieDataModule(pl.LightningDataModule):
    """
    LightningDataModule that serves train/validation splits of SteamboatWillieDataset.

    The split sizes come from config.train_split; loaders use config.batch_size and
    config.num_workers.
    """

    def __init__(self, config):
        super().__init__()
        self.batch_size = config.batch_size
        self.config = config
        # Populated lazily so setup() works whether or not prepare_data() ran on this object.
        self.full_dataset = None
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        """Build (or load) the on-disk clip cache and keep the resulting dataset."""
        self.full_dataset = SteamboatWillieDataset(self.config)

    def setup(self, stage=None):
        """
        Create the train/val split for the 'fit' and 'validate' stages (or when stage is None).

        The split is created only once per data module so that repeated setup() calls
        (e.g. a manual call followed by the Trainer's call) keep the same partition.
        """
        if self.full_dataset is None:
            # prepare_data() is not guaranteed to have assigned state on this object.
            self.full_dataset = SteamboatWillieDataset(self.config)

        needs_split = self.train_dataset is None or self.val_dataset is None
        if stage in ('fit', 'validate', None) and needs_split:
            train_len = int(len(self.full_dataset) * self.config.train_split)
            val_len = len(self.full_dataset) - train_len
            self.train_dataset, self.val_dataset = random_split(self.full_dataset, [train_len, val_len])

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.config.num_workers)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.config.num_workers)

# Model Components

In [None]:
class ResBlock(nn.Module):
    """
    3D residual block: ReLU(block(x) + shortcut(x)).

    The main path is Conv3d(3x3x3, padding 1) -> BatchNorm3d -> ReLU -> Conv3d(1x1x1) ->
    BatchNorm3d. The shortcut is a 1x1x1 Conv3d when the channel count changes and an
    identity mapping otherwise. Temporal and spatial sizes are preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Shortcut branch: project only when the channel count has to change.
        channels_change = in_channels != out_channels
        self.identity = (
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            if channels_change
            else nn.Identity()
        )

        # Main branch, listed in execution order.
        main_path = [
            nn.Conv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(in_channels),
            nn.ReLU(),
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
            nn.BatchNorm3d(out_channels),
        ]
        self.block = nn.Sequential(*main_path)
        self.res_act = nn.ReLU()

    def forward(self, x):
        residual = self.block(x)
        shortcut = self.identity(x)
        return self.res_act(residual + shortcut)

In [None]:
class Encoder(nn.Module):
    """
    Convolutional video encoder.

    `nlayers` stride-2 Conv3d stages (each followed by BatchNorm3d and ReLU) halve the
    temporal and spatial size per stage; `nblocks` ResBlocks follow, the last of which
    maps `hidden_channels` to `out_channels`.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, nlayers, nblocks):
        super().__init__()

        def downsample_stage(stage_in_channels):
            # kernel 4 / stride 2 / padding 1 halves every (even) dimension
            return nn.Sequential(
                nn.Conv3d(
                    in_channels=stage_in_channels,
                    out_channels=hidden_channels,
                    kernel_size=(4, 4, 4),
                    stride=(2, 2, 2),
                    padding=(1, 1, 1),
                ),
                nn.BatchNorm3d(hidden_channels),
                nn.ReLU(),
            )

        stages = []
        for layer_idx in range(nlayers):
            # only the first stage sees the raw input channels
            stage_in = in_channels if layer_idx == 0 else hidden_channels
            stages.append(downsample_stage(stage_in))
        self.downsample_blocks = nn.Sequential(*stages)

        blocks = []
        final_idx = nblocks - 1
        for block_idx in range(nblocks):
            # only the final block changes the channel count
            block_out = out_channels if block_idx == final_idx else hidden_channels
            blocks.append(ResBlock(in_channels=hidden_channels, out_channels=block_out))
        self.res_blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.res_blocks(self.downsample_blocks(x))

In [None]:
class Decoder(torch.nn.Module):
    """
    Convolutional video decoder.

    `nblocks` ResBlocks (the first maps `in_channels` to `hidden_channels`) are followed by
    `nlayers - 1` stride-2 ConvTranspose3d stages with BatchNorm3d and ReLU, then a final
    stride-2 ConvTranspose3d to `out_channels` and a sigmoid. Every transposed-conv stage
    doubles the temporal and spatial size.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, nlayers, nblocks):
        super().__init__()

        # Shared geometry of every transposed convolution: doubles each dimension.
        upsample_geometry = dict(kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))

        blocks = []
        for block_idx in range(nblocks):
            # only the first block sees the latent channel count
            block_in = in_channels if block_idx == 0 else hidden_channels
            blocks.append(ResBlock(in_channels=block_in, out_channels=hidden_channels))
        self.res_blocks = nn.Sequential(*blocks)

        stages = []
        for _ in range(nlayers - 1):
            stages.append(
                nn.Sequential(
                    nn.ConvTranspose3d(
                        in_channels=hidden_channels,
                        out_channels=hidden_channels,
                        **upsample_geometry,
                    ),
                    nn.BatchNorm3d(hidden_channels),
                    nn.ReLU(),
                )
            )
        self.upsample_blocks = nn.Sequential(*stages)

        # The last upsampling step has no norm/ReLU; the sigmoid squashes the output.
        self.out_layer = nn.ConvTranspose3d(
            in_channels=hidden_channels,
            out_channels=out_channels,
            **upsample_geometry,
        )
        self.out_act = nn.Sigmoid()

    def forward(self, x):
        hidden = self.upsample_blocks(self.res_blocks(x))
        return self.out_act(self.out_layer(hidden))

In [None]:
class QuantizerEMA(nn.Module):
    """
    Vector quantizer whose codebook is updated with exponential moving averages (EMA).

    codebook_size:    number of code vectors
    latent_channels:  dimensionality of each code vector (channel dim of the input)
    ema_gamma:        EMA decay used for the per-code counts `N` and per-code sums `m`
    commit_loss_beta: weight applied to the commitment loss
    track_codebook:   when True, accumulate code usage counts and report perplexity

    forward() takes (B, C, T, H, W) latents and returns a dict with
    'q_z' (quantized latents, straight-through gradient), 'vq_loss', 'commitment_loss'
    and 'perplexity'.
    """

    def __init__(
            self,
            codebook_size,
            latent_channels,
            ema_gamma,
            commit_loss_beta,
            track_codebook
        ):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_channels = latent_channels
        self.ema_gamma = ema_gamma
        self.commit_loss_beta = commit_loss_beta
        self.track_codebook = track_codebook

        self.codebook = nn.Embedding(codebook_size, latent_channels)
        nn.init.uniform_(self.codebook.weight, -1/codebook_size, 1/codebook_size)

        # EMA state: N = running count of vectors assigned to each code,
        #            m = running sum of the vectors assigned to each code.
        self.register_buffer('N', torch.zeros(codebook_size) + 1e-6)
        self.register_buffer('m', torch.zeros(codebook_size, latent_channels))
        nn.init.uniform_(self.m, -1/codebook_size, 1/codebook_size)

        if track_codebook:
            self.register_buffer('codebook_usage', torch.zeros(codebook_size, dtype=torch.float))
            self.register_buffer('total_usage', torch.tensor(0, dtype=torch.float))

    @torch.no_grad()
    def ema_update(self, code_idxs, flat_inputs):
        """
        Update the EMA statistics from one batch and overwrite the codebook with m / N.

        code_idxs:   (BTHW,) index of the nearest code for every input vector
        flat_inputs: (BTHW, C) encoder outputs
        """
        # we don't want to track grads for ops in EMA calculation; accumulate in the
        # buffers' dtype so half-precision activations do not lose precision in the sums
        code_idxs = code_idxs.detach()
        flat_inputs = flat_inputs.detach().to(self.m.dtype)

        # number of vectors for each i which quantize to e_i
        n = torch.bincount(code_idxs, minlength=self.codebook_size).to(self.N.dtype)

        # sum of vectors for each i which quantize to code e_i
        # (scatter-add instead of a dense (BTHW, codebook_size) one-hot matmul)
        embed_sums = torch.zeros_like(self.m).index_add_(0, code_idxs, flat_inputs)

        # update EMA of code usage and sum of codes
        self.N.mul_(self.ema_gamma).add_(n, alpha=1 - self.ema_gamma)
        self.m.mul_(self.ema_gamma).add_(embed_sums, alpha=1 - self.ema_gamma)

        self.codebook.weight.data.copy_(self.m / self.N.unsqueeze(-1))

    def reset_usage_stats(self):
        """Zero the accumulated usage counters; a no-op when usage is not tracked."""
        if not self.track_codebook:
            # the usage buffers are only registered when tracking is enabled
            return
        self.codebook_usage.zero_()
        self.total_usage.zero_()

    @torch.no_grad()
    def calculate_perplexity(self, enc_idxs):
        """
        Add `enc_idxs` to the running usage counts and return the perplexity
        exp(-sum p log p) of the accumulated code distribution.
        """
        unique_indices, counts = torch.unique(enc_idxs, return_counts=True)
        self.codebook_usage.index_add_(0, unique_indices, counts.float())
        self.total_usage += torch.sum(counts)

        if self.total_usage > 0:
            probs = self.codebook_usage / self.total_usage
            perplexity = torch.exp(-torch.sum(probs * torch.log(probs + 1e-10)))
            return perplexity
        else:
            # keep the placeholder on the same device as the rest of the module
            return torch.zeros(1, device=self.codebook_usage.device)

    def forward(self, inputs):
        B, C, T, H, W = inputs.shape

        # (B, C, T, H, W) --> (B, T, H, W, C)
        inputs_permuted = inputs.permute(0, 2, 3, 4, 1).contiguous()
        # (B, T, H, W, C) --> (BTHW, C)
        flat_inputs = inputs_permuted.reshape(-1, self.latent_channels)

        # The nearest-code search is not differentiable, so skip building a graph for it.
        with torch.no_grad():
            # (BTHW, Codebook Size)
            # Σ(x-y)^2 = Σx^2 - 2xy + Σy^2
            dists = (
                torch.sum(flat_inputs ** 2, dim=1, keepdim=True) - # Σx^2
                2 * (flat_inputs @ self.codebook.weight.t()) +     # 2*xy
                torch.sum(self.codebook.weight ** 2, dim=1)        # Σy^2
            )

            # (BTHW,)
            code_idxs = torch.argmin(dists, dim=1)

        # (BTHW, C)
        codes = self.codebook(code_idxs)
        # (BTHW, C) --> (B, T, H, W, C)
        quantized_inputs = codes.reshape(B, T, H, W, C)
        # (B, T, H, W, C) --> (B, C, T, H, W)
        quantized_inputs = quantized_inputs.permute(0, 4, 1, 2, 3).contiguous()

        if self.training:
            # perform exponential moving average update for codebook
            self.ema_update(code_idxs, flat_inputs)

        # "since the volume of the embedding space is dimensionless, it can grow
        # arbitrarily if the embeddings e_i do not train as fast as the encoder
        # parameters. To make sure the encoder commits to an embedding and its
        # output does not grow, we add a commitment loss"
        commitment_loss = F.mse_loss(quantized_inputs.detach(), inputs)

        # part 3 of full loss (ie. not including reconstruction loss or embedding loss)
        vq_loss = commitment_loss * self.commit_loss_beta

        # sets the output to be the input plus the residual value between the
        # quantized latents and the inputs like a resnet for Straight Through
        # Estimation (STE)
        quantized_inputs = inputs + (quantized_inputs - inputs).detach()

        if self.track_codebook:
            perplexity = self.calculate_perplexity(code_idxs)
        else:
            # placeholder on the input's device so downstream logging never mixes devices
            perplexity = torch.zeros(1, device=inputs.device)

        return {
            'q_z':              quantized_inputs,
            'vq_loss':          vq_loss,
            'commitment_loss':  commitment_loss,
            'perplexity':       perplexity
        }

# Model

In [None]:
class VQVAE(nn.Module):
    """
    Video VQ-VAE: Encoder -> QuantizerEMA -> Decoder.

    forward() returns a dict holding the reconstruction 'x_hat', the reconstruction
    error 'MSE', the total 'loss' (MSE + vq_loss) and every entry produced by the quantizer.
    """

    def __init__(self, config):
        super().__init__()

        # Encoder and decoder mirror each other, so they share width and depth settings.
        shared_arch = dict(
            hidden_channels=config.hidden_channels,
            nlayers=config.nlayers,
            nblocks=config.nblocks,
        )
        self.encoder = Encoder(
            in_channels=config.in_channels,
            out_channels=config.latent_channels,
            **shared_arch,
        )
        self.decoder = Decoder(
            in_channels=config.latent_channels,
            out_channels=config.in_channels,
            **shared_arch,
        )
        self.quantizer = QuantizerEMA(
            codebook_size=config.codebook_size,
            latent_channels=config.latent_channels,
            ema_gamma=config.ema_gamma,
            commit_loss_beta=config.commit_loss_beta,
            track_codebook=config.track_codebook,
        )

    def encode(self, inputs):
        """Map a clip batch to continuous latents."""
        return self.encoder(inputs)

    def quantize(self, z):
        """Quantize latents; returns the quantizer's output dict."""
        return self.quantizer(z)

    def decode(self, q_z):
        """Reconstruct a clip batch from (quantized) latents."""
        return self.decoder(q_z)

    def loss(self, x_hat, x, quantized):
        """Combine the reconstruction MSE with the quantizer's 'vq_loss'."""
        reconstruction_error = F.mse_loss(x_hat, x)
        terms = {
            'MSE':  reconstruction_error,
            'loss': reconstruction_error + quantized['vq_loss'],
        }
        # quantizer entries are appended after (and take precedence over) the keys above
        terms.update(quantized)
        return terms

    def forward(self, x):
        latents = self.encode(x)
        quantized = self.quantize(latents)
        reconstruction = self.decode(quantized['q_z'])

        return {'x_hat': reconstruction, **self.loss(reconstruction, x, quantized)}

# Lightning Module

In [None]:
class LitVQVAE(pl.LightningModule):
    """
    LightningModule that trains a VQ-VAE, logs its losses and uploads sample
    reconstructions from the first validation batch.

    model:  module whose forward returns a dict with 'loss', 'MSE', 'vq_loss',
            'commitment_loss', 'perplexity' and 'x_hat'
    config: object providing lr, beta1, beta2, weight_decay, use_lr_schedule and max_epochs
    """

    _TRAIN_METRICS = ('loss', 'MSE', 'vq_loss', 'commitment_loss', 'perplexity')
    _VAL_METRICS = ('loss', 'MSE', 'vq_loss', 'commitment_loss')

    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.lr = config.lr
        # No logger is attached while the module is being constructed, so hyperparameters
        # cannot be pushed to the logger from here; log them when the run is created.

    def forward(self, x):
        return self.model(x)

    def _log_metrics(self, stage, out, names):
        """Log out[name] as '<stage>/<name>' for every requested metric."""
        for name in names:
            self.log(f'{stage}/{name}', out[name], prog_bar=True)

    def training_step(self, batch, batch_idx):
        x = batch
        out = self(x)

        self._log_metrics('train', out, self._TRAIN_METRICS)

        return out['loss']

    def validation_step(self, batch, batch_idx):
        x = batch
        out = self(x)

        self._log_metrics('val', out, self._VAL_METRICS)

        if batch_idx == 0:
            self.log_val_clips(x, out)

        return out['loss']

    def _reset_codebook_usage(self):
        """Clear the quantizer's usage counters when it tracks them."""
        quantizer = self.model.quantizer
        if getattr(quantizer, 'track_codebook', False):
            quantizer.reset_usage_stats()

    def on_train_epoch_start(self):
        # tracking perplexity per epoch: start every training epoch from empty counters
        self._reset_codebook_usage()

    def on_epoch_end(self):
        # Kept for callers/older Lightning versions that still invoke this hook name.
        self._reset_codebook_usage()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            betas=(self.config.beta1, self.config.beta2),
            weight_decay=self.config.weight_decay
        )

        if self.config.use_lr_schedule:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config.max_epochs)
            return [optimizer], [scheduler]

        return optimizer

    @staticmethod
    def _clip_to_uint8_frames(clip):
        """
        Convert one (C, T, H, W) clip tensor to a min-max normalised uint8 array of shape
        (T, H, W, C), or (T, H, W) for single-channel clips.
        """
        # float(): half-precision activations are promoted before the numpy arithmetic
        frames = clip.detach().float().permute(1, 2, 3, 0).cpu().numpy()

        lowest, highest = frames.min(), frames.max()
        value_range = highest - lowest
        if value_range > 0:
            frames = (frames - lowest) / value_range
        else:
            # constant clip: avoid 0/0 and render it as black frames
            frames = np.zeros_like(frames)

        frames = (frames * 255).astype(np.uint8)

        # grayscale videos need to be of shape (T, H, W)
        if frames.shape[-1] == 1:
            frames = frames.squeeze(-1)
        return frames

    def log_val_clips(self, x, out, num_clips=2):
        """
        Save up to `num_clips` original/reconstructed clip pairs as GIFs and log them.

        x:   batch of input clips, (B, C, T, H, W)
        out: model output dict containing 'x_hat' with the same shape as x
        """
        if self.logger is None:
            # nothing to upload to (e.g. Trainer(logger=False))
            return

        n_clips = min(x.size(0), num_clips)

        for i in range(n_clips):
            # Extract the ith original and reconstructed clip, each (C, T, H, W)
            original_clip_np = self._clip_to_uint8_frames(x[i])
            reconstructed_clip_np = self._clip_to_uint8_frames(out['x_hat'][i])

            # create GIFs for the original and reconstructed clips
            original_gif_path = f'/tmp/original_clip_{i}.gif'
            reconstructed_gif_path = f'/tmp/reconstructed_clip_{i}.gif'
            imageio.mimsave(original_gif_path, original_clip_np, fps=5)
            imageio.mimsave(reconstructed_gif_path, reconstructed_clip_np, fps=5)

            # log the GIFs to wandb
            self.logger.experiment.log({
                f"val/original_clip_{i}": wandb.Video(original_gif_path, fps=5, format="gif", caption="Original"),
                f"val/reconstructed_clip_{i}": wandb.Video(reconstructed_gif_path, fps=5, format="gif", caption="Reconstructed")
            })

# Config

In [None]:
def get_num_downsample_layers(img_size):
    """
    Number of stride-2 downsampling layers needed to shrink an
    img_size x img_size input to (at most) 2x2.

    Raises ValueError when img_size is smaller than 2.
    """
    if not img_size < 2:
        # each layer halves the side length: img_size / 2**k <= 2  <=>  k >= log2(img_size / 2)
        halvings_needed = math.log2(img_size / 2)
        return math.ceil(halvings_needed)

    raise ValueError("Image size must be at least 2x2.")

def build_channel_dims(start_channels, nlayers):
    """
    Channel count for each of `nlayers` downsample layers, starting at
    `start_channels` and doubling at every layer (as spatial dims halve).
    """
    return [start_channels * (2 ** depth) for depth in range(nlayers)]

class VideoVQVAEConfig:
    """Plain attribute bag holding every tunable setting of the video VQ-VAE experiment."""

    def __init__(self):
        # --- dataset: source video(s) and where preprocessed clips are cached ---
        self.paths = ['/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4']
        self.dest_dir = './clips/'

        # --- checkpointing ---
        self.checkpoint_path = "./checkpoints"
        self.save_top_k = 1

        # --- training loop ---
        self.train_split = 0.8
        self.batch_size = 32
        self.max_epochs = 120
        self.training_steps = 100000
        self.num_workers = 2

        # --- optimizer ---
        self.lr = 1e-4
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.weight_decay = 0.0  # 1e-2
        self.use_wd_schedule = False
        self.use_lr_schedule = False

        # --- input clips ---
        self.clip_length = 16
        self.img_size = 256
        self.in_channels = 1

        # --- latent space / quantizer ---
        self.latent_channels = 16
        self.codebook_size = 1024
        self.commit_loss_beta = 0.25
        self.track_codebook = True
        self.use_ema = True
        self.ema_gamma = 0.99

        # --- encoder / decoder architecture ---
        self.hidden_channels = 256
        self.nblocks = 2
        self.nlayers = 4

    def update(self, updates):
        """Overwrite existing attributes from `updates`; names the config lacks are skipped."""
        for name, new_value in updates.items():
            if not hasattr(self, name):
                continue
            setattr(self, name, new_value)

    def to_dict(self):
        """Return a shallow copy of the instance attributes as a dict."""
        return dict(vars(self))

# Component Testing

In [None]:
# Smoke test of the data pipeline outside of a Trainer: the hooks are invoked by hand
# in the order prepare_data() -> setup().
config = VideoVQVAEConfig()
data_module = SteamboatWillieDataModule(config)
data_module.prepare_data()  # creates the SteamboatWillieDataset and stores it on the module
data_module.setup()         # stage=None -> creates the train/val split

# `train_loader` feeds the component checks in the following cells.
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

In [None]:
# Pull a single batch; the dataset views each clip as
# (in_channels, clip_length, img_size, img_size), so a batch adds a leading batch dimension.
train_batch = next(iter(train_loader))
train_batch.shape

In [None]:
# Encoder has no default arguments, so build it from the same config the full model uses.
encoder = Encoder(
    in_channels=config.in_channels,
    hidden_channels=config.hidden_channels,
    out_channels=config.latent_channels,
    nlayers=config.nlayers,
    nblocks=config.nblocks,
)
sample_enc = encoder(train_batch)
sample_enc.shape

In [None]:
# Quantize the encoder output; latent_channels must match the encoder's output channels.
quantizer = QuantizerEMA(
    codebook_size=512,
    latent_channels=config.latent_channels,
    ema_gamma=0.99,
    commit_loss_beta=0.25,
    track_codebook=False
)
qz = quantizer(sample_enc)
# the quantized latents are returned under the 'q_z' key
qz['q_z'].shape

In [None]:
# Decoder has no default arguments either; mirror the encoder settings so the
# reconstruction comes back with the input's channel count.
decoder = Decoder(
    in_channels=config.latent_channels,
    hidden_channels=config.hidden_channels,
    out_channels=config.in_channels,
    nlayers=config.nlayers,
    nblocks=config.nblocks,
)
recon_clip = decoder(sample_enc)
recon_clip.shape

In [None]:
# End-to-end check: run one batch through the assembled model.
# `output` is the dict returned by VQVAE.forward (reconstruction, losses and quantizer entries).
config = VideoVQVAEConfig()
vqvae = VQVAE(config)
output = vqvae(train_batch)

In [None]:
# Quantized latents vs. reconstruction shapes from the forward pass above.
print(f'q_z.shape: {output["q_z"].shape}')
print(f'x_hat.shape: {output["x_hat"].shape}')

# Train

In [None]:
# Fresh config, model and data module for the actual training run.
config = VideoVQVAEConfig()
model = VQVAE(config)
lit_model = LitVQVAE(model, config)
steamboat_willie_data = SteamboatWillieDataModule(config)

# The run's hyperparameters are recorded through wandb.init.
wandb.init(project="VQ-VAE Steamboat Willie", config=config.to_dict())
wandb_logger = WandbLogger(project="VQ-VAE Steamboat Willie", log_model=False)
wandb_logger.watch(lit_model, log="all")

lr_monitor = LearningRateMonitor(logging_interval='step')

# The monitored metric is logged as 'val/loss'. A metric name containing '/' has to be
# referenced verbatim in the template with automatic name insertion disabled; the original
# '{val_loss}' placeholder does not match any logged metric.
checkpoint_callback = ModelCheckpoint(
    dirpath=config.checkpoint_path,
    filename='model-epoch={epoch:02d}-val_loss={val/loss:.2f}',
    auto_insert_metric_name=False,
    every_n_epochs=5,
    save_top_k=config.save_top_k,
    monitor='val/loss',
    mode='min',
    save_last=True
)

# Define the EarlyStopping callback
early_stop_callback = EarlyStopping(
    monitor='val/loss',
    min_delta=0.00,
    patience=3,
    verbose=True,
    check_finite=True
)

In [None]:
#wandb.finish()

In [None]:
# Single-GPU trainer with mixed 16-bit precision, logging to the wandb logger on every step.
trainer = pl.Trainer(
    max_epochs=config.max_epochs,
    devices=1,
    accelerator="gpu",
    precision="16-mixed",
    logger=wandb_logger,
    callbacks=[
        lr_monitor,
        early_stop_callback,
        # checkpoint_callback  (currently disabled: no checkpoints are written by this run)
    ],
    log_every_n_steps=1,
    # overfit_batches=1,  (debug aid: uncomment to overfit a single batch)
)

# Optional learning-rate finder, left disabled.
# tuner = Tuner(trainer)
# tuner.lr_find(lit_model, datamodule=steamboat_willie_data)

In [None]:
# Train, then close the wandb run that was opened with wandb.init.
trainer.fit(lit_model, steamboat_willie_data)
wandb.finish()

# Display Clip

In [None]:
def display_clip(clip, interval=50):
    """
    Render a clip inline as an HTML5 video.

    clip:     tensor laid out as (C, T, H, W); it may live on any device and may
              be attached to an autograd graph
    interval: delay between frames in milliseconds
    """
    # detach + cpu so model outputs (GPU tensors, tensors requiring grad) can be shown too;
    # (C, T, H, W) -> (T, H, W, C) so that indexing the first axis yields one frame
    video_clip_np = clip.detach().cpu().permute(1, 2, 3, 0).numpy()
    fig, ax = plt.subplots()

    def update(frame_idx):
        ax.clear()
        ax.imshow(video_clip_np[frame_idx], cmap='gray')
        ax.axis('off')

    ani = FuncAnimation(fig, update, frames=range(video_clip_np.shape[0]), interval=interval)
    # close this figure (and only this one) so the notebook does not also show a static plot
    plt.close(fig)
    display(HTML(ani.to_html5_video()))

In [None]:
# Show one clip from the batch fetched earlier.
# NOTE(review): index 7 and the 256x256 size are hard-coded; this presumably relies on
# batch_size > 7 and img_size == 256 in the config — confirm if either changes.
clip = train_batch[7].view(1, -1, 256, 256)
display_clip(clip)