In [None]:
# Import the necessary libraries
import time, warnings
from tabulate import tabulate
import torch, torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR, ExponentialLR, CosineAnnealingLR

# Import the VAE model and functions
import network
from datasets import create_datasets
from evaluation import evaluate
from sampling import sample, plot_reconstruction
from plotting import plot_loss_lr, plot_loss_components

In [None]:
# Hyperparameters
batch_size = 128  # Number of images per update of the network
num_epochs = 50  # One epoch means seeing every image of the training dataset
embedding_dim = 512  # Passed to network.VQVAE as embedding_dim (presumably the codebook vector size)
num_embeddings = 512  # Passed to network.VQVAE as num_embeddings (presumably the codebook size)
input_channels = 3  # CIFAR-10 images have 3 color channels; must match the dataset loaded in the next cell
learning_rate = 5e-5  # Determines how drastically the parameters of the network change
output_frequency = 150  # Determines how often the training progress will be logged (in batches)

# Select the device that will be used for training: GPU, if available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print('=========================================')

# Automatic mixed precision via torch.cuda.amp only applies on CUDA; disable it explicitly
# on CPU instead of relying on the library's warn-and-disable fallback.
use_amp = device.type == 'cuda'

# Put the neural network on the selected device
model = network.VQVAE(input_channels, num_embeddings, embedding_dim)
model.to(device)

# Optimizer selection: one of 'adam', 'adamw', 'rmsprop', 'sgd'
optimizer_option = 'adamw'

optimizer_classes = {
  'adam': torch.optim.Adam,
  'adamw': torch.optim.AdamW,
  'rmsprop': torch.optim.RMSprop,
  'sgd': torch.optim.SGD,
}
if optimizer_option not in optimizer_classes:
  # Fail loudly on typos instead of silently training with a different optimizer
  raise ValueError(f'Unknown optimizer_option {optimizer_option!r}; expected one of {sorted(optimizer_classes)}')
optimizer = optimizer_classes[optimizer_option](model.parameters(), lr=learning_rate)

# Learning rate scheduler: one of 'step', 'exponential', 'cosine', or None / 'none' for a constant LR.
# The training loop calls scheduler.step() once per epoch.
lr_schedule_option = 'cosine'

if lr_schedule_option in (None, 'none'):
  scheduler = None
elif lr_schedule_option == 'step':
  scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
elif lr_schedule_option == 'exponential':
  scheduler = ExponentialLR(optimizer, gamma=0.9)
elif lr_schedule_option == 'cosine':
  # Tie the annealing period to the number of epochs so the schedule spans the whole run
  scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)
else:
  raise ValueError(f'Unknown lr_schedule_option {lr_schedule_option!r}')

# Scaler for AMP (a no-op pass-through when AMP is disabled)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

In [None]:
# Create and visualize the datasets for MNIST / CIFAR10
dataset_name = 'MNIST'
train_loader, test_loader = create_datasets(dataset_name, batch_size)

# Fail fast if the data's channel count does not match the model built in the setup cell
# (input_channels is configured for 3-channel CIFAR-10; MNIST is natively 1-channel).
# Assumes batches are (N, C, H, W) image tensors -- TODO confirm against create_datasets.
sample_images, _ = next(iter(train_loader))
if sample_images.shape[1] != input_channels:
  raise ValueError(
    f'{dataset_name} batches have {sample_images.shape[1]} channel(s) but the model was built '
    f'with input_channels={input_channels}; adjust input_channels or dataset_name'
  )

In [None]:
# Suppress user warnings (note: this filter applies kernel-wide, not just to this cell)
warnings.filterwarnings("ignore", category=UserWarning)

# We want to plot loss, its components and learning rate at the end of training.
# train/test losses and learning rates are recorded once per epoch; the loss components
# are recorded once per logging step (every `output_frequency` batches).
train_losses = []
test_losses = []
learning_rates = []
reconstruction_losses = []
codebook_losses = []

# Keep autocast in sync with the GradScaler: mixed precision is only active when the scaler is.
use_amp = scaler.is_enabled()

# Training loop
for epoch in range(num_epochs):
  print('-----------------------------------------------------------------------------------------------------------------------------')
  model.train()
  losses = []

  # Start of time measurement
  epoch_start_time = time.time()

  for batch_idx, (images, _) in enumerate(train_loader):
    images = images.to(device)
    optimizer.zero_grad()

    with torch.cuda.amp.autocast(enabled=use_amp):
      # Forward pass through the VQ-VAE model, unpacked as (reconstruction, codebook loss)
      decoded, codebook_loss = model(images)

      # Compute reconstruction loss and the full training objective
      recon_loss = F.mse_loss(decoded, images)
      total_loss = recon_loss + codebook_loss

    # Backward pass on the FULL objective; backpropagating only the codebook loss
    # would leave the reconstruction term untrained.
    scaler.scale(total_loss).backward()

    # Optimization step
    scaler.step(optimizer)
    scaler.update()

    # Save the total loss value for the current batch
    curr_loss = total_loss.detach()
    losses.append(curr_loss)

    # Log training loss and current learning rate
    if batch_idx % output_frequency == 0:
      reconstruction_losses.append(recon_loss.detach().clone())
      codebook_losses.append(codebook_loss.detach().clone())
      # Read the LR from the optimizer so logging also works when scheduler is None
      current_lr = optimizer.param_groups[0]['lr']
      log = [['Epoch:', f'{epoch + 1:3d}/{num_epochs:3d}', 'Batch:', f'{batch_idx + 1:3d}',
              'Train Loss:', f'{curr_loss:.6f}', 'LR:', current_lr]]
      print(tabulate(log, tablefmt="plain"))

  # Record the LR used during this epoch, then step the scheduler (if any) for the next one
  learning_rates.append(optimizer.param_groups[0]['lr'])
  if scheduler is not None:
    scheduler.step()

  # After the epoch, evaluate the mean loss on the test dataset.
  # NOTE(review): the third argument is `total_loss`, the last training batch's loss *tensor*;
  # if evaluate() expects a loss function here this is a bug -- confirm its signature.
  mean_loss = evaluate(model, test_loader, total_loss, device)
  test_losses.append(mean_loss)

  # Save the average training loss
  average_loss = torch.stack(losses).mean().item()
  train_losses.append(average_loss)

  # End of time measurement
  elapsed_time = time.time() - epoch_start_time

  # Log as a horizontal table
  headers = ["Epoch", "Mean Test Loss", "Time"]
  summary_row = [[epoch + 1, f"{mean_loss:.4f}", f"{elapsed_time:.2f}s"]]
  print(tabulate(summary_row, headers=headers, tablefmt="fancy_grid"))

  # Generate reconstructed test images after each epoch
  plot_reconstruction(model, test_loader, device, num_samples=5)
  print('-----------------------------------------------------------------------------------------------------------------------------')

# Plot Train Loss & Test Loss & LR
plot_loss_lr(num_epochs, train_losses, test_losses, learning_rates)

# Plot each loss component
plot_loss_components(reconstruction_losses, codebook_losses)

# Sample some VAE-generated images
# NOTE(review): `latent_dim` is not defined anywhere in this notebook; update the call before enabling it.
#sample(model, device, latent_dim, num_samples=50)