# VQ-VAE JAX Training on Kaggle TPU

This notebook trains the VQ-VAE model using JAX on Kaggle TPUs.

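## Prerequisites
1. Enable the TPU accelerator in the Kaggle notebook settings.
2. Add your `momask-codes` repo as a Kaggle dataset, and make `REPO_PATH` in Cell 2 point to where it is mounted under `/kaggle/input/`.
3. Add the HumanML3D dataset and pass its location via `--data_root` in Cell 3 (or use dummy data for testing).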

In [32]:
# Cell 1: Check dependencies
# JAX is pre-installed on Kaggle TPU.
# This notebook expects flax 0.11.x (0.12.0 introduced breaking changes), but
# nothing in this cell installs or pins it. Instead we check the version and
# warn, so a mismatch is visible here rather than surfacing later as an
# obscure model error. If you need to pin it, run
#   %pip install -q "flax>=0.11,<0.12"
# in a cell of its own, then restart the kernel before running this cell.
import warnings

import jax
import flax

# Verify installation
print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"TPU devices: {len(jax.devices())}")

# Compare major.minor as strings so pre-release suffixes can't break parsing.
if flax.__version__.split(".")[:2] != ["0", "11"]:
    warnings.warn(
        f"Expected flax 0.11.x but found {flax.__version__}; "
        "flax 0.12.0 has breaking changes that may affect the JAX VQ-VAE model."
    )

JAX version: 0.8.0
Flax version: 0.12.0
JAX backend: tpu
TPU devices: 8


In [None]:
# Cell 2: Setup paths
import sys
import os

# Path of the uploaded momask-codes repo. Kaggle mounts each dataset at
# /kaggle/input/<dataset-name>, so change the last path component to match
# whatever you named the dataset holding the repo.
REPO_PATH = '/kaggle/input/momask3'

# Fail early with a clear message instead of a confusing ModuleNotFoundError below.
if not os.path.isdir(REPO_PATH):
    raise FileNotFoundError(
        f"Repo not found at {REPO_PATH!r}; check the dataset name under /kaggle/input"
    )

# Put the repo first on sys.path. Drop any earlier copy first so re-running
# this cell doesn't stack duplicate entries.
sys.path[:] = [entry for entry in sys.path if entry != REPO_PATH]
sys.path.insert(0, REPO_PATH)

# Verify imports work
from models.vq.model_jax import RVQVAE
print("Successfully imported RVQVAE model!")

In [None]:
# Cell 3: Configure and run training
import sys

# Location of the HumanML3D data, e.g. '/kaggle/input/humanml3d/HumanML3D'.
# Leave as None to omit --data_root and fall back to the training script's default.
DATA_ROOT = None

# Command line for the training script. Adjust these based on your dataset locations.
train_args = [
    'train_vq_jax.py',
    '--dataset_name', 't2m',
    '--batch_size', '256',
    '--max_epoch', '50',
    '--name', 'rvq_kaggle_tpu',
    '--checkpoints_dir', '/kaggle/working/checkpoints',
]
if DATA_ROOT is not None:
    train_args += ['--data_root', DATA_ROOT]

# The training entry point reads its options from sys.argv (presumably via
# argparse). Swap our arguments in only for the duration of the run, then
# restore the kernel's own argv so later cells don't inherit this fake command line.
kernel_argv = sys.argv
sys.argv = train_args
try:
    # Import after setting argv, in case the module reads it at import time.
    from train_vq_jax import main
    main()
finally:
    sys.argv = kernel_argv

In [None]:
# Cell 4 (Optional): List checkpoints saved to Kaggle output
# No copying is needed. Files under /kaggle/working are kept as notebook output,
# and the checkpoints directory used for training already lives inside it.
import shutil
import os

output_dir = '/kaggle/working'
checkpoint_dir = '/kaggle/working/checkpoints'


def list_checkpoint_files(root):
    """Return sorted (relative_path, size_in_bytes) pairs for every file under `root`.

    Subdirectories are included because training presumably writes into a
    per-run subfolder (e.g. named after `--name`) rather than directly into
    `root`. A missing `root` yields an empty list.
    """
    found = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            full_path = os.path.join(dirpath, filename)
            found.append((os.path.relpath(full_path, root), os.path.getsize(full_path)))
    return sorted(found)


if os.path.isdir(checkpoint_dir):
    print(f"Checkpoints saved in: {checkpoint_dir}")
    for rel_path, size_bytes in list_checkpoint_files(checkpoint_dir):
        print(f"  - {rel_path} ({size_bytes / 1e6:.1f} MB)")
else:
    print(f"No checkpoint directory found at {checkpoint_dir}; has training (Cell 3) run?")