<a href="https://colab.research.google.com/github/LuthandoMaqondo/magvit2-pytorch/blob/luthando-contribution/notebooks/training.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount the Drive

In [1]:
import os
import sys
import platform
import requests
import torch
import wandb
from getpass import getpass


try:
    from google.colab import drive
    IN_COLAB = True
except:
    WORKING_DIR = '.'
    IN_COLAB = False
if IN_COLAB:
    WORKING_DIR = '/content/drive/MyDrive/Colab Notebooks'
    drive.mount('/content/drive',  force_remount=True)

# START The MAGVIT-v2

In [2]:
# !pip install magvit2-pytorch

In [6]:
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

tokenizer = VideoTokenizer(
    image_size = 128,
    init_dim = 64,
    max_dim = 512,
    codebook_size = 1024,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'linear_attend_space',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
        'compress_time',
        ('consecutive_residual', 2),
        'compress_time',
        ('consecutive_residual', 2),
        'attend_time',
    )
)

dataset_folder = os.path.expanduser(f"{WORKING_DIR}/datasets/Appimate/train") if IN_COLAB else os.path.expanduser(f"~/.cache/datasets/Appimate/train")
trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = dataset_folder,     # folder of either videos or images, depending on setting below
    dataset_type = 'videos',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 4,
    grad_accum_every = 8,
    learning_rate = 2e-5,
    num_train_steps = 1_000,
    use_wandb_tracking = True,
)
with trainer.trackers(project_name = 'magvit2', run_name = 'baseline'):
    if wandb.api.api_key is None:
        key = getpass("Access key: ")
        wandb.login(key=key, relogin=True)
    trainer.train()

10 training samples found at /Users/luthandomaqondo/.cache/datasets/Appimate/train
training with dataset of 9 samples and validating with randomly splitted 1 samples


python(16413) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(16414) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
python(16419) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(16420) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
[34m[1mwandb[0m: Currently logged in as: [33mluthando957[0m. Use [1m`wandb login --relogin`[0m to force relogin
python(16430) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(16431) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011145238888841252, max=1.0…

step 0


RuntimeError: Conv3D is not supported on MPS

In [None]:
# # after a lot of training ...
# # can use the EMA of the tokenizer
# ema_tokenizer = trainer.ema_tokenizer

# # mock video
# video = torch.randn(1, 3, 17, 128, 128)

# # tokenizing video to discrete codes
# codes = ema_tokenizer.tokenize(video) # (1, 9, 16, 16) <- in this example, time downsampled by 4x and space downsampled by 8x. flatten token ids for (non)-autoregressive training

# # sanity check
# decoded_video = ema_tokenizer.decode_from_code_indices(codes)
# assert torch.allclose(
#     decoded_video,
#     ema_tokenizer(video, return_recon = True)
# )