In [1]:
# --- Imports ---------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.optim as optim

# Project-local modules. NOTE(review): `datasets` is presumably the repo's own
# module (the Hugging Face package of the same name has no `load_mnist`); it
# only wins the import because the notebook's directory is first on sys.path.
from config import load_config
from datasets import load_mnist, get_observation_pixels

# --- Configuration -----------------------------------------------------------
SEED = 42  # fixed so that "Restart Kernel -> Run All" reproduces the same run
BATCH_SIZE = 100

# Seed Python's `random`, NumPy and torch before anything random happens
# (any split/shuffle inside load_mnist, model initialisation, training).
pl.seed_everything(SEED)

train_loader, test_loader, val_loader = load_mnist(BATCH_SIZE)
config = load_config("vq_vae")

In [2]:
from models import VQVAE
import pytorch_lightning as pl
from trainers import BaseTrainer

class VAETrainer(BaseTrainer):
    """Lightning wrapper that trains a ``VQVAE`` through the project's ``BaseTrainer``.

    Args:
        num_embeddings: Passed to ``VQVAE`` as its first positional argument
            (presumably the codebook size — confirm against ``models.VQVAE``).
        embedding_dim: Passed to ``VQVAE`` as its second positional argument
            (presumably the size of each codebook vector — confirm).
    """

    def __init__(self, num_embeddings, embedding_dim):
        model = VQVAE(num_embeddings, embedding_dim)
        # Zero-argument super() also creates the implicit `__class__` cell that
        # Lightning's save_hyperparameters() uses to discover the init args;
        # the explicit super(VAETrainer, self) form does not.
        super().__init__(model)
        # Store num_embeddings / embedding_dim in self.hparams so they are written
        # into checkpoints and VAETrainer.load_from_checkpoint() can rebuild the
        # module without the caller re-supplying these required arguments.
        self.save_hyperparameters()

    def forward(self, x, x_cond, y):
        """Run the wrapped model on ``x``.

        ``x_cond`` and ``y`` are accepted to match the batch layout unpacked in
        ``step`` but are ignored.

        Returns:
            Whatever ``self.model(x)`` returns; ``step`` unpacks it as
            ``(x_hat, quantized, latent, embedding_indices)``.
        """
        return self.model(x)

    def step(self, batch, batch_idx, mode='train'):
        """Run the model on one batch, log every loss term and return the total.

        Args:
            batch: ``(x, x_cond, y)`` triple from the data loader.
            batch_idx: Index of the batch; not used here (kept for the calling
                convention presumably defined by ``BaseTrainer``).
            mode: Prefix for the logged metric names, e.g. ``'train'``.

        Returns:
            The ``'loss'`` entry of the dict returned by ``self.model.loss``.
        """
        x, x_cond, y = batch
        x_hat, quantized, latent, embedding_indices = self(x, x_cond, y)
        loss = self.model.loss(latent, quantized, x_hat, x)
        # Log detached tensors rather than calling .item(): .item() forces a
        # device-to-host sync for every metric on every step and fails on entries
        # that are already plain Python numbers. Lightning accepts tensors and
        # numbers alike and performs the cross-process reduction itself when
        # sync_dist=True.
        metrics = {
            f"{mode}_{key}": val.detach() if torch.is_tensor(val) else val
            for key, val in loss.items()
        }
        self.log_dict(metrics, sync_dist=True, prog_bar=True)
        return loss['loss']

    def decode(self, z):
        """Decode ``z`` with the wrapped model's ``decode`` method."""
        return self.model.decode(z)

In [4]:
# Build the Lightning module from the `model_params` section of the config,
# fit it on the train/validation loaders, then persist a checkpoint via the
# project's SuperTrainer. `test_loader` is not used in this cell.
from trainers import SuperTrainer

model = VAETrainer(**config["model_params"])
trainer = SuperTrainer(**config["trainer_params"])

trainer.fit(model, train_loader, val_loader)
trainer.save_model_checkpoint()

