# Train a Multiscale VQVAE on the PD12M dataset

## Imports

In [None]:
import os, torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt

# Import dataset
from data.PD12M import PD12MDataModule

# Import the MultiscaleVQVAE
from model.model import ModelArgs, LitMultiscaleVQVAE


## Dataset

In [None]:
# Build the PD12M data module and take its train/validation loaders.
# NOTE(review): the loaders are requested without an explicit dm.setup();
# confirm PD12MDataModule does not rely on setup() to build its datasets.
dm = PD12MDataModule(num_workers=16, batch_size=4)
train_loader, val_loader = dm.train_dataloader(), dm.val_dataloader()

## Create our VQVAE model

In [None]:
# 1. Define the model configuration.
model_args = ModelArgs(
    codebook_size=16384,
    codebook_embed_dim=8,
    codebook_l2_norm=True,
    codebook_show_usage=True,
    commit_loss_beta=0.25,
    entropy_loss_ratio=0.0,
    encoder_ch_mult=[1, 2, 2, 4],
    decoder_ch_mult=[1, 2, 2, 4],
    z_channels=256,
    dropout_p=0.0,
)

# 2. Initialize the Lightning module.
model = LitMultiscaleVQVAE(
    model_args=model_args,
    patch_nums=(1, 2, 4, 8, 16, 32),
    learning_rate=1e-4,
)


## Train our model

In [None]:
# 3. TensorBoard logging, with runs grouped under the "multiscale_vqvae" name
#    inside the ".tensorboard" directory.
logger = TensorBoardLogger(".tensorboard", name="multiscale_vqvae")

# 4. Keep every checkpoint (save_top_k=-1) and write one every 1000 training
#    steps. Epoch and step in the filename keep the files from colliding.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="model-{epoch:02d}-{step:08d}",
    save_top_k=-1,
    every_n_train_steps=1000,
)

# 5. GPU trainer. devices="auto" uses every visible GPU; pass an int
#    (e.g. devices=1) to restrict it.
trainer = pl.Trainer(
    accelerator="gpu",
    devices="auto",
    max_epochs=10,
    log_every_n_steps=20,
    logger=logger,
    callbacks=[checkpoint_callback],
)

# 6. Fit on the training loader only.
# NOTE(review): val_loader is not passed here; add val_dataloaders=val_loader
# if validation during training is wanted.
trainer.fit(model, train_loader)

## Visualize reconstructions

In [None]:
def _tile_images(images, nrow=8, padding=2, pad_value=0.0):
    """Tile a (B, C, H, W) image batch into a single (C, H', W') grid.

    Torch-only replacement for ``torchvision.utils.make_grid``, which this
    cell used without importing. Images are laid out row-major, ``nrow`` per
    row, separated and surrounded by ``padding`` pixels filled with
    ``pad_value``. Single-channel batches are repeated to 3 channels so
    matplotlib can display them.

    Raises:
        ValueError: if ``images`` is not a non-empty 4-D tensor.
    """
    if images.dim() != 4 or images.size(0) == 0:
        raise ValueError(
            f"expected a non-empty (B, C, H, W) batch, got shape {tuple(images.shape)}"
        )
    if images.size(1) == 1:
        images = images.repeat(1, 3, 1, 1)
    n_images, n_channels, height, width = images.shape
    n_cols = min(nrow, n_images)
    n_rows = -(-n_images // n_cols)  # ceiling division
    step_y, step_x = height + padding, width + padding
    grid = images.new_full(
        (n_channels, n_rows * step_y + padding, n_cols * step_x + padding), pad_value
    )
    for index, image in enumerate(images):
        row, col = divmod(index, n_cols)
        top, left = row * step_y + padding, col * step_x + padding
        grid[:, top:top + height, left:left + width] = image
    return grid


# Inference only: make sure the weights are on the GPU and switch to eval mode.
model.to("cuda")
model.eval()

# Scales to reconstruct at.
# NOTE(review): training used patch_nums=(1, 2, 4, 8, 16, 32); 32 is omitted
# here. Confirm this is intended.
scales = (1, 2, 4, 8, 16)
with torch.no_grad():
    # Assumes each validation batch is a sequence whose first element is the
    # image tensor, e.g. (images, labels). TODO: confirm against PD12MDataModule.
    eval_batch = next(iter(val_loader))[0].to("cuda")
    # Forward pass with multi-scale encoding; one reconstruction per scale is
    # expected.
    reconstructions, _ = model.vqvae(eval_batch, v_patch_nums=scales)

reconstructions = list(reconstructions)
if len(reconstructions) != len(scales):
    # Fail loudly instead of mislabeling panels or indexing past `axes`.
    raise ValueError(
        f"expected {len(scales)} reconstructions (one per scale), got {len(reconstructions)}"
    )

# One column for the input batch, then one per scale reconstruction.
n_scales = len(scales)
fig, axes = plt.subplots(1, n_scales + 1, figsize=(4 * (n_scales + 1), 6))

panels = [("Input Images", eval_batch)] + [
    (f"Scale {scale}", recon) for scale, recon in zip(scales, reconstructions)
]
for ax, (title, images) in zip(axes, panels):
    grid = _tile_images(images.detach().cpu(), nrow=8, padding=2)
    # matplotlib wants channels-last (H, W, C). Float pixels are assumed to lie
    # in [0, 1]; rescale first if the data module normalizes differently
    # (TODO confirm).
    ax.imshow(grid.permute(1, 2, 0).numpy())
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()

## Save the trained model

In [None]:
# Persist only the inner VQVAE's weights (model.vqvae, not the Lightning
# wrapper) to ./.artifacts/vqvae.pt, creating the directory if needed.
artifacts_dir = os.path.join(".", ".artifacts")
model_save_path = os.path.join(artifacts_dir, "vqvae.pt")
os.makedirs(artifacts_dir, exist_ok=True)

vqvae_weights = model.vqvae.state_dict()
torch.save(vqvae_weights, model_save_path)
print(f"Model saved to {model_save_path}")
