<a href="https://colab.research.google.com/github/JordanLazzaro/VideoGen/blob/main/notebooks/VideoPoet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VideoPoet Implementation

![](https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgxFblHaHRJNH7Oi2_oOTosGN9XrjgjhWmnfADchMT8WR0XAo6SxiUfpUmn5R6akciiRduaKIMdgwHZzK3xW8mErarQ_ugx41ctQAMK08O9UMVevgkk-AgFI1xYFWAomd16OcOh0R-XpyZVLQXncpk2SHf-RmPzrqBbIWZc-nUG2TH6nC2R7qyHXn8eTC-u/s2680/image21.png)

[**VideoPoet**](https://research.google/blog/videopoet-a-large-language-model-for-zero-shot-video-generation/) is an autoregressive transformer decoder that models sequences of discrete, multimodal tokens produced by modality-specific tokenizers: [MAGVIT-V2](https://arxiv.org/abs/2310.05737) for images/video and [SoundStream](https://arxiv.org/abs/2107.03312) for audio. For text, the model leverages the T5 tokenizer as well as frozen T5 token embeddings. Together, these tokens make up the model's full vocabulary. VideoPoet also employs a custom Super Resolution model so that the transformer can generate lower-resolution videos for efficiency.

Implementing the video portion of this model will require the following checkpoints:

- [ ] Implement MAGVIT-V2 tokenizer
    - [ ] Dilated Causal Convolution (in time dim)
    - [ ] Blur Pool
    - [ ] LFQ (can replace with fsq?)
    - [ ] VQVAE (FSQVAE)
    - [ ] Discriminator / GAN Loss
- [ ] Implement the Transformer Decoder
- [ ] Implement the Super Resolution model
- [ ] Incorporate audio (optional/if feasible)

# Setup

In [None]:
!pip install -q wandb pytorch_lightning av imageio
!pip install -q flash-attn --no-build-isolation

In [None]:
import os
import gc
import math
import wandb
import imageio
import torch
import torchvision
import numpy as np
import pytorch_lightning as pl
from google.colab import drive
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML
from torch import nn, optim
from torch.nn import functional as F
from torchvision.datasets.video_utils import VideoClips
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Compose, Lambda, Resize, ToTensor, CenterCrop, Grayscale
import torchvision.transforms.functional as TF

from pytorch_lightning import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.tuner import Tuner
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

pl.seed_everything(42)

In [None]:
# Mount Google Drive inside the Colab runtime so files stored there are reachable.
drive.mount('/content/drive')
# Path of the source video on the mounted drive (referenced by config["data"]["paths"]).
steamboat_willie_gdrive_path = '/content/drive/My Drive/SteamboatWillie/SteamboatWillie.mp4'
# Copy the `clips` directory from Drive into the current working directory.
!cp -r /content/drive/My\ Drive/SteamboatWillie/clips .

In [None]:
def display_clip(clip):
    '''
    util method for displaying tensors as video

    Renders the clip as an inline HTML5 animation (one frame every 60 ms).

    Args:
        clip: PyTorch tensor of shape (C,T,H,W). It may live on any device
            and may require grad; a detached host copy is used for display.
    '''
    assert len(clip.shape) == 4, 'clip shape must be PyTorch Tensor of shape (C,T,H,W)'

    # (C, T, H, W) -> (T, H, W, C) so each frame can be handed to imshow.
    # `.numpy()` refuses tensors that require grad or are not on the CPU,
    # so detach and copy to host memory first.
    video_clip_np = clip.detach().cpu().permute(1, 2, 3, 0).numpy()

    fig, ax = plt.subplots()

    def update(frame_idx):
        # redraw the axes from scratch for every frame
        ax.clear()
        ax.imshow(video_clip_np[frame_idx], cmap='gray')
        ax.axis('off')

    ani = FuncAnimation(fig, update, frames=range(video_clip_np.shape[0]), interval=60)
    # close this specific figure so the notebook does not also show a static plot
    plt.close(fig)
    display(HTML(ani.to_html5_video()))

# Config

In [None]:
# Nested experiment configuration: data locations, per-model hyperparameters and
# (currently empty) training / logging sections.
# NOTE(review): this is a plain dict, while the dataset, data module and model
# classes in this notebook read their config through attribute access (e.g.
# `config.clip_length`, `config.emb_dim`) and use keys that do not appear here
# (`img_size`, `in_channels`, `batch_size`, `num_workers`, `vocab_size`, `lr`, ...).
# Presumably a sub-section is converted to an attribute-style object and
# extended before being passed in — confirm.
config = {
    "project": "SteamboatWillie VideoPoet",
    "data": {
        # source video file(s) to cut into clips
        "paths": [steamboat_willie_gdrive_path],
        # number of frames per clip
        "clip_length": 16,
        # directory holding the pre-processed clip .bin files
        "clip_dest_dir": "clips"
    },
    "model": {
        "magvit_v2": {
            "checkpoint_dir": "./checkpoints/magvit_v2"
        },
        "transformer": {
            # width of the residual stream
            "emb_dim": 512,
            # number of attention heads (must divide emb_dim)
            "nheads": 8,
            # maximum sequence length (size of the causal mask buffer)
            "ctx_size": 2048,
            # sliding-window size passed to the attention layers
            "window_size": 1024,
            # MLP hidden width as a multiple of emb_dim
            "fan_out": 4,
            "nlayers": 6,
            "dropout": 0.0,
            "use_flash_attn": True,
            "checkpoint_dir": "./checkpoints/transformer"
        },
        "super_resolution": {
            "checkpoint_dir": "./checkpoints/super_resolution"
        }
    },
    "training": {
        "magvit_v2": {},
        "transformer": {},
        "super_resolution": {}
    },
    "logging": {}
}

# Data

## Clips Dataset

In [None]:
class RandomHorizontalFlipVideo(torch.nn.Module):
    """
    Mirror a whole clip left-to-right with probability ``p``.

    The input is expected to be laid out as (C, T, H, W); the decision is
    drawn once per call, so all frames of a clip are flipped together.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        # single random draw per clip
        flip_clip = torch.rand(1) < self.p
        # width is the trailing axis, so flipping it mirrors every frame
        return x.flip(-1) if flip_clip else x


class SteamboatWillieDataset(Dataset):
    """
    Dataset of fixed-length grayscale video clips backed by memory-mapped files.

    Clips are read from ``config.clip_dest_dir`` when that directory exists;
    otherwise they are cut from ``config.paths``, pre-processed and written
    there as flat uint8 ``clip_<idx>.bin`` files first.

    Args:
        config: object exposing ``paths``, ``clip_length``, ``clip_dest_dir``,
            ``img_size`` and ``in_channels`` as attributes.
        mode: ``'train'``, ``'val'`` or ``'full'`` (no split).
        train_split: fraction of the clips assigned to the train split.
        shuffle: shuffle clips before splitting into train / val.
        augment: add a random horizontal flip (train mode only).
        seed: seed of the split permutation. Train and val datasets are built
            independently, so the permutation must be reproducible or the two
            splits would overlap.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    def __init__(self, config, mode='train', train_split=0.8, shuffle=True, augment=True, seed=42):
        super().__init__()
        if mode not in ('train', 'val', 'full'):
            raise ValueError(f"mode must be 'train', 'val' or 'full', got {mode!r}")

        self.config = config
        self.train_split = train_split
        self.mode = mode

        # applied once, when clips are extracted from the source video
        self.preprocess_transforms = Compose([
                Lambda(lambda x: x.permute(0, 3, 1, 2)),   # (T, H, W, C) to (T, C, H, W) for Grayscale
                Grayscale(num_output_channels=1),          # Convert to grayscale
                Lambda(lambda x: x.permute(1, 0, 2, 3)),   # (T, C, H, W) to (C, T, H, W) for Conv3d
                CenterCrop((480, 575)),                    # Center crop to remove vertical bars
                Resize((config.img_size, config.img_size))
        ])

        # applied on every access: scale to [0, 1] and restore the clip shape
        # (the .bin files are flat, so the memmap comes back 1-D)
        self.postprocess_transforms = Compose([
            Lambda(lambda x: x / 255.),
            Lambda(lambda x: x.view(self.config.in_channels, self.config.clip_length, self.config.img_size, self.config.img_size))
        ])

        if self.mode == 'train' and augment:
            self.postprocess_transforms.transforms.append(RandomHorizontalFlipVideo(p=0.5))

        if os.path.exists(config.clip_dest_dir):
            clip_paths = self.build_existing_clip_paths(config.clip_dest_dir)
        else:
            video_clips = VideoClips(
                config.paths,
                clip_length_in_frames=config.clip_length,
                frames_between_clips=config.clip_length
            )
            clip_paths = self.build_clip_paths(video_clips, self.preprocess_transforms, config.clip_dest_dir)

        self.clips = self.build_clip_refs(clip_paths)

        # index by the ids that are actually stored, so gaps in the file
        # numbering cannot cause a KeyError in __getitem__
        clip_ids = sorted(self.clips)

        if mode == 'full':
            self.clip_indices = clip_ids
        else:
            if shuffle:
                # a private, seeded generator yields the same permutation for the
                # train and the val instance, keeping the two splits disjoint
                generator = torch.Generator().manual_seed(seed)
                order = torch.randperm(len(clip_ids), generator=generator).tolist()
                clip_ids = [clip_ids[i] for i in order]

            train_size = int(len(clip_ids) * train_split)
            if mode == 'train':
                self.clip_indices = clip_ids[:train_size]
            else:
                self.clip_indices = clip_ids[train_size:]

    def build_clip_paths(self, video_clips, transforms, clip_dest_dir):
        """
        Build set of binary files to store processed video clips
        returns dict of clip_idx -> mmapped file path

        Each clip is passed through ``transforms`` and written as raw uint8.
        """
        clip_paths = {}

        os.makedirs(clip_dest_dir, exist_ok=True)

        for idx in tqdm(range(video_clips.num_clips()), desc='Creating clip .bin files'):
            # transform clips and write to mmap file
            clip, _, _, _ = video_clips.get_clip(idx)
            clip = transforms(clip)
            clip_np = clip.numpy().astype(np.uint8)

            mmapped_file_path = os.path.join(clip_dest_dir, f'clip_{idx}.bin')
            fp = np.memmap(mmapped_file_path, dtype='uint8', mode='w+', shape=clip_np.shape)
            fp[:] = clip_np[:]
            fp.flush()
            # drop the reference so the mapping is released before the next clip
            del fp
            clip_paths[idx] = mmapped_file_path

        return clip_paths

    def build_existing_clip_paths(self, clip_dest_dir):
        """
        returns dict of clip_idx -> mmapped file path
        from existing .bin files

        Only files named ``clip_<idx>.bin`` are considered.
        """
        clips_paths = {}
        for filename in os.listdir(clip_dest_dir):
            if filename.startswith('clip_') and filename.endswith('.bin'):
                idx = int(filename.split('_')[1].split('.')[0])
                clips_paths[idx] = os.path.join(clip_dest_dir, filename)

        return clips_paths

    def build_clip_refs(self, clip_paths):
        """
        Build mmap reference to bin files
        returns dict of clip_idx -> np.array mmapped to respective bin file
        """
        clips = {}
        for idx, path in tqdm(clip_paths.items(), desc='Building clip refs'):
            # read-only and shapeless: the array is flat until post-processing
            clips[idx] = np.memmap(path, dtype='uint8', mode='r')

        return clips

    def __len__(self):
        return len(self.clip_indices)

    def __getitem__(self, idx):
        """Return the post-processed float32 clip at position ``idx`` of this split."""
        clip = self.clips[self.clip_indices[idx]]
        return self.postprocess_transforms(torch.tensor(clip, dtype=torch.float32))

## Lightning Datamodule

In [None]:
class SteamboatWillieDataModule(pl.LightningDataModule):
    """
    Lightning data module that builds the train / val clip datasets and
    their data loaders from a single config object.

    Reads ``config.batch_size`` and ``config.num_workers``; the config is
    also forwarded unchanged to ``SteamboatWillieDataset``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = config.batch_size

    def setup(self, stage=None):
        # datasets are only needed for fitting (or when no stage is given)
        if stage is not None and stage != 'fit':
            return

        self.train_dataset = SteamboatWillieDataset(self.config, mode='train')
        self.val_dataset = SteamboatWillieDataset(self.config, mode='val')

    def _make_loader(self, dataset, shuffle):
        # shared loader construction; only the dataset and shuffling differ
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.config.num_workers,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

# MAGVIT-V2

## Model

### Components

In [None]:
class CausalConv3d(nn.Module):
    '''
    enforces causality in the time dimension (https://paperswithcode.com/method/causal-convolution)
    https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889

    The input is padded symmetrically in height / width and only at the front
    of the time axis, so the output at frame t never depends on frames > t.

    Args:
        in_channels: channels of the (B, C, T, H, W) input.
        out_channels: channels produced by the convolution.
        kernel_size: (k_t, k_h, k_w) kernel extent.
        stride: stride along the time axis (spatial stride is always 1).
        dialation: dilation along the time axis (spatial dilation is always 1).
            The misspelled keyword is kept so existing callers keep working.
    '''
    def __init__(self, in_channels, out_channels, kernel_size=(3,3,3), stride=1, dialation=1):
        super().__init__()
        k_t, k_h, k_w = kernel_size
        dilation = dialation

        # F.pad order for a 5-D input: (left, right, top, bottom, front, back)
        pad_w, pad_h = k_w // 2, k_h // 2
        pad_t = dilation * (k_t - 1) + (1 - stride)
        self.t_causal_pad = (pad_w, pad_w, pad_h, pad_h, pad_t, 0)

        # stride / dilation must go in by keyword: the fifth positional
        # argument of nn.Conv3d is `padding`, not `dilation`
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=(stride, 1, 1),
            dilation=(dilation, 1, 1),
        )

    def forward(self, x):
        x = F.pad(x, self.t_causal_pad)
        x = self.conv(x)

        return x

In [None]:
class ResBlock3d(nn.Module):
    '''
    Residual block built from two (GroupNorm -> SiLU -> CausalConv3d) stages.

    handles both same and different in/out channels: when they differ, the
    skip path is a 1x1x1 causal convolution, otherwise it is the identity.
    '''
    def __init__(self, in_channels, out_channels, kernel_size=(3,3,3)):
        super().__init__()

        # skip connection: project only when the channel count changes
        needs_projection = in_channels != out_channels
        self.identity = (
            CausalConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1, 1)
            )
            if needs_projection
            else nn.Identity()
        )

        # residual branch: the first stage maps in -> out, the second out -> out
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.GroupNorm(num_groups=2, num_channels=stage_in),
                nn.SiLU(),
                CausalConv3d(
                    in_channels=stage_in,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    dialation=1
                ),
            ])
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.block(x)
        skip = self.identity(x)
        return residual + skip

In [None]:
class ResBlockDown3d(nn.Module):
    '''
    strided conv for down-sampling + blur pooling (https://arxiv.org/abs/1904.11486)

    Placeholder: only the constructor stub exists so far; no layers and no
    forward pass are defined.
    '''
    # NOTE(review): `out_hannels` looks like a typo for `out_channels`; it is left
    # as-is here because renaming a parameter changes the call signature.
    def __init__(self, in_channels, out_hannels):
        super().__init__()

In [None]:
class Encoder(nn.Module):
    '''
    Placeholder for the tokenizer's encoder (presumably the MAGVIT-V2 encoder,
    given the notebook section it sits in — confirm).

    Not implemented yet: the constructor only initialises nn.Module and takes
    no configuration.
    '''
    def __init__(self):
        super().__init__()

In [None]:
class Decoder(nn.Module):
    '''
    Placeholder for the tokenizer's decoder (presumably the MAGVIT-V2 decoder,
    given the notebook section it sits in — confirm).

    Not implemented yet: `config` is accepted but not used or stored.
    '''
    def __init__(self, config):
        super().__init__()

In [None]:
class LFQ(nn.Module):
    '''
    Placeholder for the LFQ quantizer listed in the notebook's implementation
    checklist (presumably lookup-free quantization — confirm).

    Not implemented yet: `config` is accepted but not used or stored.
    '''
    def __init__(self, config):
        super().__init__()

In [None]:
class FSQ(nn.Module):
    '''
    Finite scalar quantizer: every channel is squashed with tanh and rounded
    to one of a fixed number of levels, with a straight-through gradient.

    Args:
        levels: number of quantization levels per channel; its length is the
            number of channels the quantizer expects.
        eps: shrinks the tanh range slightly so the outermost bins stay
            reachable after rounding.

    Buffers / attributes:
        levels: the per-channel level counts.
        basis: mixed-radix place values used to turn a code into one index.
        codebook_size: product of all levels (0-dim tensor).
    '''
    def __init__(self, levels, eps=1e-3):
        super().__init__()
        levels = list(levels)
        self.register_buffer('levels', torch.tensor(levels))

        # place value of channel k is the product of the levels before it;
        # building it from one integer list keeps the dtype integral even
        # when there is a single level
        self.register_buffer('basis', torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0))

        self.eps = eps
        self.codebook_size = torch.prod(self.levels)

    def round_ste(self, z):
        ''' round in the forward pass, identity gradient in the backward pass '''
        z_q = torch.round(z)
        return z + (z_q - z).detach()

    def quantize(self, z):
        '''
        Quantize `z` (channels last, broadcast against `levels`) and return
        the codes normalized by each channel's half width.
        '''
        # half_l is used to determine how to scale tanh; we
        # subtract 1 from the number of levels to account for 0
        # being a quantization bin and tanh being symmetric around 0
        half_l = (self.levels - 1) * (1 - self.eps) / 2

        # if a given level is even, it will result in a scale for tanh
        # which is halfway between integer values, so we offset
        # the tanh output down by 0.5 to line it with whole integers
        offset = torch.where(self.levels % 2 == 0, 0.5, 0.0)

        # if our level is even, we want to shift the tanh input to
        # ensure the 0 quantization bin is centered
        shift = torch.tan(offset / half_l)

        # once we have our shift and offset (in the case of an even level)
        # we can round to the nearest integer bin and allow for STE
        z_q = self.round_ste(torch.tanh(z + shift) * half_l - offset)

        # after quantization, we want to renormalize the quantized
        # values to be within the range expected by the model (ie. [-1, 1])
        half_width = self.levels // 2
        return z_q / half_width

    def scale_and_shift(self, z_q_normalized):
        ''' normalized codes -> non-negative per-channel digits '''
        half_width = self.levels // 2
        return (z_q_normalized * half_width) + half_width

    def scale_and_shift_inverse(self, z_q):
        ''' non-negative per-channel digits -> normalized codes '''
        half_width = self.levels // 2
        return (z_q - half_width) / half_width

    def codes_to_idxs(self, z_q):
        ''' (..., C) normalized codes -> (...) int32 codebook indices '''
        assert z_q.shape[-1] == len(self.levels)
        # round before the integer cast: the digits come out of float
        # arithmetic and a plain cast would truncate e.g. 2.9999 to 2
        digits = self.scale_and_shift(z_q).round().to(self.basis.dtype)
        return (digits * self.basis).sum(dim=-1).to(torch.int32)

    def idxs_to_codes(self, idxs):
        ''' (...) codebook indices -> (..., C) normalized codes '''
        idxs = idxs.unsqueeze(-1)
        codes_not_centered = (idxs // self.basis) % self.levels
        return self.scale_and_shift_inverse(codes_not_centered)

    def forward(self, z):
        '''
        Quantize a latent tensor.

        Accepts (B, C), (B, C, H, W) or (B, C, T, H, W); the output has the
        same shape as the input.

        Returns:
            dict with key 'z_q' holding the quantized tensor.

        Raises:
            ValueError: for any other number of dimensions.
        '''
        if z.dim() == 2:
            # (B, C): channels are already last
            z_q = self.quantize(z)
        elif z.dim() in (4, 5):
            # quantize() broadcasts `levels` over the last axis, so move the
            # channel axis there and back:
            # (B, C, ...) -> (B, ..., C) -> quantize -> (B, C, ...)
            z_q = self.quantize(z.movedim(1, -1)).movedim(-1, 1).contiguous()
        else:
            raise ValueError(
                f'FSQ expects a 2-D, 4-D or 5-D tensor, got shape {tuple(z.shape)}'
            )

        return {'z_q': z_q}

### Models

In [None]:
class VQVAE(nn.Module):
    '''
    Placeholder for the quantized auto-encoder listed in the notebook's
    implementation checklist.

    Not implemented yet: `config` is accepted but not used or stored.
    '''
    def __init__(self, config):
        super().__init__()

In [None]:
class Descriminator(nn.Module):
    '''
    bunch of downsampling resblocks + leaky relu

    Placeholder for the GAN discriminator: the layers described above are not
    implemented yet and `config` is accepted but not used or stored.
    '''
    # NOTE(review): the class name is a misspelling of "Discriminator"; it is left
    # unchanged here because renaming would break any code referring to it.
    def __init__(self, config):
        super().__init__()

In [None]:
class MAGVIT2Tokenizer(nn.Module):
    '''
    Placeholder for the top-level MAGVIT-V2 tokenizer module.

    Not implemented yet: `config` is accepted but not used or stored.
    '''
    def __init__(self, config):
        super().__init__()

## Lightning Module

## Testing

## Training

## Sweeps

# Transformer Decoder

## Model

In [None]:
class Attention(nn.Module):
    '''
    Causal multi-head self-attention with a sliding window and ALiBi biases.

    Two code paths compute the same attention pattern:
      * flash-attn (`use_flash_attn=True`), and
      * a plain PyTorch fallback that materialises the (T, T) score matrix.

    Args:
        emb_dim: model width; must be divisible by `nheads`.
        nheads: number of attention heads.
        ctx_size: maximum sequence length supported by the fallback mask.
        window_size: how many previous positions a query may attend to, in
            addition to itself; a negative value disables the window.
        dropout: dropout probability applied to the attention weights.
        use_flash_attn: select the flash-attn kernel instead of the fallback.
    '''
    def __init__(
            self,
            emb_dim,
            nheads,
            ctx_size,
            window_size,
            dropout=0.0,
            use_flash_attn=True
        ):
        super().__init__()
        assert emb_dim % nheads == 0

        self.W_Q = nn.Linear(emb_dim, emb_dim, bias=False)
        self.W_K = nn.Linear(emb_dim, emb_dim, bias=False)
        self.W_V = nn.Linear(emb_dim, emb_dim, bias=False)

        self.W_O = nn.Linear(emb_dim, emb_dim, bias=False)

        if not use_flash_attn:
            # lower-triangular causal mask, only needed by the fallback path
            self.register_buffer(
                "mask",
                torch.tril(torch.ones((ctx_size, ctx_size))).reshape(
                    1, 1, ctx_size, ctx_size
                ),
            )

        # forward() needs the model width to merge the heads again
        self.emb_dim = emb_dim
        self.head_dim = emb_dim // nheads
        self.nheads = nheads
        self.window_size = window_size

        self.dropout_p = dropout
        self.dropout = nn.Dropout(dropout)

        self.use_flash_attn = use_flash_attn

    def forward(self, x, slopes):
        '''
        Args:
            x: (B, T, D) input sequence.
            slopes: per-head ALiBi slopes, or None for no positional bias.
                The fallback path accepts any tensor with `nheads` elements
                (e.g. shape (H,) or (H, 1, 1)).

        Returns:
            (B, T, D) attention output.
        '''
        B, T, D = x.size()

        # (B, T, D) -> (B, T, H, H_D)
        Q = self.W_Q(x).reshape(B, T, self.nheads, self.head_dim)
        K = self.W_K(x).reshape(B, T, self.nheads, self.head_dim)
        V = self.W_V(x).reshape(B, T, self.nheads, self.head_dim)

        if self.use_flash_attn:
            # flash-attn takes and returns (B, T, H, H_D), so no transpose
            # is needed before merging the heads below
            out = flash_attn_func(
                Q, K, V,
                dropout_p=self.dropout_p if self.training else 0.0,
                softmax_scale=None,
                causal=True,
                window_size=(self.window_size, 0),
                alibi_slopes=None if slopes is None else slopes.to(torch.float32),
                deterministic=False
            )
        else:
            # (B, T, H, H_D) -> (B, H, T, H_D)
            Q = Q.transpose(1, 2)
            K = K.transpose(1, 2)
            V = V.transpose(1, 2)

            # (B, H, T, H_D) @ (B, H, H_D, T) -> (B, H, T, T)
            attn = (Q @ K.transpose(-2, -1)) / math.sqrt(self.head_dim)

            # distance[i, j] = i - j: how far key j lies behind query i
            positions = torch.arange(T, device=x.device)
            distance = positions[:, None] - positions[None, :]

            if slopes is not None:
                # ALiBi: penalise each score linearly in the query/key distance
                head_slopes = slopes.reshape(1, self.nheads, 1, 1).to(attn.dtype)
                attn = attn - head_slopes * distance.abs()

            # a key is visible if it is not in the future and inside the window
            visible = self.mask[:, :, :T, :T] != 0
            if self.window_size >= 0:
                visible = visible & (distance <= self.window_size)

            attn = attn.masked_fill(~visible, float('-inf'))
            attn = F.softmax(attn, dim=-1)

            attn = self.dropout(attn)

            # (B, H, T, T) @ (B, H, T, H_D) -> (B, H, T, H_D) -> (B, T, H, H_D)
            out = (attn @ V).transpose(1, 2)

        # (B, T, H, H_D) -> (B, T, D)
        out = out.reshape(B, T, self.emb_dim)

        # (B, T, D)
        out = self.W_O(out)

        return out

In [None]:
class MLP(nn.Module):
    '''
    Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is `fan_out` times wider than the embedding.
    '''
    def __init__(self, emb_dim, fan_out=4, dropout=0.0):
        super().__init__()
        hidden_dim = fan_out * emb_dim

        self.fc1 = nn.Linear(emb_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # expand, apply the non-linearity, project back, then regularise
        hidden = self.act(self.fc1(x))
        return self.dropout(self.fc2(hidden))

In [None]:
class TransformerBlock(nn.Module):
    '''
    Pre-norm transformer block: attention and MLP sub-layers, each applied to
    a layer-normalised input and added back onto the residual stream.
    '''
    def __init__(
            self,
            emb_dim,
            nheads,
            ctx_size,
            window_size,
            fan_out=4,
            dropout=0.0,
            use_flash_attn=True
        ):
        super().__init__()

        # attention sub-layer
        self.ln_1 = nn.LayerNorm(emb_dim)
        self.attn = Attention(
            emb_dim,
            nheads,
            ctx_size,
            window_size,
            dropout=dropout,
            use_flash_attn=use_flash_attn
        )

        # feed-forward sub-layer
        self.ln_2 = nn.LayerNorm(emb_dim)
        self.mlp = MLP(emb_dim, fan_out=fan_out, dropout=dropout)

    def forward(self, x, slopes):
        # `slopes` is only consumed by the attention layer
        attn_out = self.attn(self.ln_1(x), slopes)
        x = x + attn_out

        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out

In [None]:
class TransformerDecoder(nn.Module):
    '''
    Stack of transformer blocks with ALiBi position biases and a linear
    prediction head over the vocabulary.

    Args:
        config: object exposing `emb_dim`, `nheads`, `ctx_size`,
            `window_size`, `fan_out`, `dropout`, `use_flash_attn`, `nlayers`
            and `vocab_size` as attributes. An optional `in_emb_dim` gives the
            width of the incoming embeddings (defaults to `emb_dim`).
    '''
    def __init__(self, config):
        super().__init__()
        self.config = config

        # NOTE(review): forward() has always called `self.emb_up_proj`, but the
        # layer was never created. It is defined here as a bias-free linear
        # projection from the input embedding width up to `emb_dim`; confirm the
        # intended input width (or whether a token-embedding lookup was meant).
        in_emb_dim = getattr(config, 'in_emb_dim', config.emb_dim)
        self.emb_up_proj = nn.Linear(in_emb_dim, config.emb_dim, bias=False)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                config.emb_dim,
                config.nheads,
                config.ctx_size,
                config.window_size,
                config.fan_out,
                config.dropout,
                config.use_flash_attn
            ) for _ in range(config.nlayers)
        ])
        self.pred_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        self.ln_f = nn.LayerNorm(config.emb_dim)

        self.register_buffer("slopes", self.get_alibi_slope(config.nheads))
        self.window_size = config.window_size

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        '''
        Args:
            x: (B, T, in_emb_dim) input embeddings.

        Returns:
            (B, T, vocab_size) logits.
        '''
        # The positional bias is handed to every block as `self.slopes`; the
        # unused `bias` tensor formerly built here read an undefined `self.m`.
        x = self.emb_up_proj(x)

        x = self.dropout(x)

        for block in self.blocks:
            x = block(x, self.slopes)

        logits = self.pred_head(self.ln_f(x))

        return logits

    def get_relative_positions(self, seq_len: int) -> torch.Tensor:
        ''' (seq_len, seq_len) matrix of key - query offsets, clamped to <= 0 '''
        x = torch.arange(seq_len)[None, :]
        y = torch.arange(seq_len)[:, None]
        return (x - y).clamp_max_(0)


    def get_alibi_slope(self, num_heads):
        '''
        Geometric sequence of per-head ALiBi slopes: 1 / 2**(8 * (i + 1) / num_heads).

        Returns shape (H,) when flash-attn is used and (H, 1, 1) otherwise.
        '''
        base = (2 ** 8) ** (1 / num_heads)
        slopes = torch.tensor([1 / base ** (i + 1) for i in range(num_heads)])

        if not self.config.use_flash_attn:
            # trailing singleton axes let the slopes broadcast over (T, T) scores
            slopes = slopes.unsqueeze(-1).unsqueeze(-1)

        return slopes

## Lightning Module

In [None]:
class StepwiseValidationScheduleCallback(Callback):
    '''
    Callback meant to run validation every `config.validation_step_interval`
    optimizer steps instead of once per epoch.
    '''
    # NOTE(review): recent PyTorch Lightning releases appear to have removed the
    # `on_batch_end` hook (replaced by `on_train_batch_end`), in which case this
    # method is never invoked — confirm against the installed version.
    # NOTE(review): calling `trainer.validate` from inside a running fit loop is
    # probably unsupported as well; `Trainer(val_check_interval=...)` looks like
    # the intended way to schedule step-wise validation — confirm.
    def __init__(self, config):
        super().__init__()
        self.config = config

    def on_batch_end(self, trainer, pl_module):
        # global_step is zero-based, hence the +1 before the modulo check
        if (trainer.global_step + 1) % self.config.validation_step_interval == 0:
            trainer.validate(pl_module)

In [None]:
class StepwiseModelCheckpoint(ModelCheckpoint):
    '''
    ModelCheckpoint that saves every `save_step_frequency` training steps.

    The schedule is delegated to ModelCheckpoint's own `every_n_train_steps`
    option rather than to private, version-specific helpers of the parent.

    Args:
        save_step_frequency: number of training steps between checkpoints.
        dirpath: directory the checkpoints are written to.
        filename: checkpoint name template (the step number by default).
        **kwargs: forwarded to ModelCheckpoint. An explicit
            `every_n_train_steps` takes precedence over `save_step_frequency`.
            How many checkpoints are retained is governed by `save_top_k`;
            presumably `save_top_k=-1` is needed to keep all of them — confirm
            against the ModelCheckpoint documentation.
    '''
    def __init__(self, save_step_frequency, dirpath, filename="{step}", **kwargs):
        kwargs.setdefault("every_n_train_steps", save_step_frequency)
        super().__init__(dirpath=dirpath, filename=filename, **kwargs)
        self.save_step_frequency = save_step_frequency

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # `dataloader_idx` is accepted only for backward compatibility: the
        # current hook is called without it and the parent does not take it.
        # Saving on the step schedule happens inside the parent implementation.
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

In [None]:
class LitTransformerDecoder(pl.LightningModule):
    '''
    Lightning wrapper that trains a decoder with next-token cross-entropy.

    Args:
        model: module mapping a batch of inputs to (B, T, vocab) logits.
        config: object exposing `lr`, `beta1`, `beta2`, `weight_decay`,
            `use_lr_schedule` and (when the schedule is on) `training_steps`.
    '''
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.lr = config.lr

    def on_fit_start(self):
        # The logger only becomes available once a trainer is attached, so the
        # run config is pushed here rather than in __init__, where no trainer
        # is attached yet and the update could never run.
        if not self.logger:
            return

        run_config = getattr(self.logger.experiment, 'config', None)
        # only loggers whose experiment exposes an updatable config (e.g. W&B)
        if run_config is not None and hasattr(run_config, 'update'):
            run_config.update(self.config)

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, metric_name):
        ''' compute and log the token-level cross-entropy for one batch '''
        x, y = batch
        logits = self(x)

        # flatten (B, T, vocab) / (B, T) so every position is one sample
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

        self.log(metric_name, loss, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train/loss')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val/loss')

    def configure_optimizers(self):
        '''
        AdamW, optionally followed by a per-step cosine decay to zero over
        `config.training_steps`.
        '''
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            betas=(self.config.beta1, self.config.beta2),
            weight_decay=self.config.weight_decay
        )

        if self.config.use_lr_schedule:
            scheduler = {
                'scheduler': CosineAnnealingLR(optimizer, T_max=self.config.training_steps, eta_min=0),
                'interval': 'step',
                'frequency': 1,
            }
            return [optimizer], [scheduler]

        return optimizer

## Testing

## Training

## Sweeps

# Super Resolution

## Model

## Lightning Module

## Testing

## Training

## Sweeps