In [None]:
import sys
sys.path.append('../../')

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import make_grid

from vqvae_cifar10 import VQVAE

In [None]:
# CIFAR-10 datasets and loaders.
# ToTensor maps pixels to [0, 1]; Normalize with mean 0.5 and std 1.0 only shifts
# them by -0.5 per channel, so tensors lie in [-0.5, 0.5]. The plotting cells add
# 0.5 back before display.
DATA_ROOT = '../../data'
VAL_SHUFFLE_SEED = 0

# Shared by both splits so train and validation are preprocessed identically.
cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0)),
])

training_data = datasets.CIFAR10(root=DATA_ROOT,
                                 train=True,
                                 download=True,
                                 transform=cifar10_transform)

validation_data = datasets.CIFAR10(root=DATA_ROOT,
                                   train=False,
                                   download=True,
                                   transform=cifar10_transform)

training_loader = DataLoader(training_data,
                             batch_size=256,
                             shuffle=True,
                             num_workers=8,
                             pin_memory=True)

# Shuffled so the visualised batch is a varied sample. The seeded generator makes
# the first batch drawn identical on every Restart & Run All.
validation_loader = DataLoader(validation_data,
                               batch_size=32,
                               shuffle=True,
                               num_workers=8,
                               pin_memory=True,
                               generator=torch.Generator().manual_seed(VAL_SHUFFLE_SEED))


In [None]:
# Trained checkpoint to evaluate. Change the version/epoch here to inspect another run.
CHECKPOINT_PATH = '../../lightning_logs/vqvae_cifar10/version_1/checkpoints/epoch=76-step=14999.ckpt'

# The keyword arguments are forwarded to load_from_checkpoint. They presumably must
# match the architecture the checkpoint was trained with (TODO: confirm against the
# training script).
model = VQVAE.load_from_checkpoint(
    CHECKPOINT_PATH,
    num_hiddens=128,
    num_residual_hiddens=32,
    num_residual_layers=2,
    num_embeddings=512,
    embedding_dim=64,
    commitment_cost=0.25)

# A freshly constructed nn.Module is in training mode. Switch to eval() so
# training-only behaviour is disabled before reconstructing images. That covers
# dropout, batch-norm statistic updates, and any codebook updates gated on
# `self.training`, if VQVAE has them. The trailing ';' suppresses the module repr.
model.eval();

## Reconstruction

In [None]:
# Take one validation batch (labels are unused) and run it through the model.
# The middle element of the model's 3-tuple output is used as the reconstruction.
valid_originals, _ = next(iter(validation_loader))

# Inference only: no_grad() avoids building an autograd graph that is never used.
with torch.no_grad():
    _, valid_recons, _ = model(valid_originals)

In [None]:
# Inspect the dimensions of the reconstructed batch (displayed as a torch.Size).
valid_recons.size()

In [None]:
def show(img, title=None, figsize=(10, 20)):
    """Display a channels-first image tensor (e.g. a make_grid result) with matplotlib.

    Args:
        img: CPU tensor laid out as (channels, height, width). Must not require
            grad, since ``.numpy()`` is called on it. Float values are expected in
            [0, 1]; values outside that range are clipped.
        title: Optional axes title. Defaults to no title.
        figsize: Matplotlib figure size in inches, as (width, height).
    """
    # Matplotlib expects (height, width, channels) ordering for RGB data.
    hwc = np.transpose(img.numpy(), (1, 2, 0))
    # imshow clips out-of-range float RGB data itself but emits a warning each time.
    # Clipping here gives the same picture silently; reconstructions may overshoot.
    # Integer images are left untouched because their valid range is [0, 255].
    if np.issubdtype(hwc.dtype, np.floating):
        hwc = np.clip(hwc, 0.0, 1.0)

    _, ax = plt.subplots(figsize=figsize)
    ax.imshow(hwc, interpolation='nearest')
    ax.axis('off')
    if title is not None:
        ax.set_title(title)

In [None]:
# Tile the reconstructed batch into one image. Adding 0.5 undoes the -0.5
# Normalize shift, assuming the model reconstructs in the same normalized space.
# detach() is the supported replacement for the legacy `.data` attribute.
show(make_grid(valid_recons.detach().cpu()) + 0.5)

In [None]:
# Tile the original validation batch into one image for comparison with the
# reconstructions. Adding 0.5 undoes the -0.5 Normalize shift.
# detach() is the supported replacement for the legacy `.data` attribute.
show(make_grid(valid_originals.detach().cpu()) + 0.5)