In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to the project's .py files
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch

# from train_orig_multiple_codebooks import VQVAE, get_batch
from train_orig_multiple_codebooks import VQVAE, get_batch
from config import generate_params
# from dataset_generator import generate_random_image
from utils import seed_everything

from torchvision.transforms import ToTensor, ToPILImage
from torchvision.utils import make_grid

# Converters between PIL images and tensors, shared by the cells below.
to_pil = ToPILImage()
to_tensor = ToTensor()

# Fixed seed for the whole notebook.
# NOTE(review): seed_everything is a project helper; presumably it seeds the
# Python / NumPy / torch RNGs - confirm in utils.
seed = 111
seed_everything(seed)

In [3]:
# Checkpoint location: <snapshot_dir>/<train_name>/snapshot.tar
snapshot_dir = 'snapshots'
train_name = 'overfit_1codebook'
snapshot_path = os.path.join(snapshot_dir, train_name, 'snapshot.tar')

# Constructor arguments for the VQVAE. They have to describe the same
# architecture as the checkpoint, otherwise load_state_dict() below will
# reject the stored weights.
model_params = {
    'input_channels': 3,
    'n_hid': 64,
    'n_init': 32,
    'num_codebooks': 1,
    'codebook_size': 2 ** 14,
    'embedding_dim': 32,
}

# Deserialize onto the CPU regardless of the device the snapshot was saved from.
snapshot = torch.load(snapshot_path, map_location='cpu')

# Rebuild the network, restore its weights and switch it to inference mode.
model = VQVAE(**model_params)
model.load_state_dict(snapshot['model'])
model.eval()

# Freeze every parameter: this notebook only runs forward passes.
for param in model.parameters():
    param.requires_grad_(False)

In [4]:
# Draw one batch of four images and reconstruct it twice: once with the model
# in train mode and once in eval mode, to compare the two visually.
images = get_batch(batch_size=4, transform=to_tensor, **generate_params)

mode_runs = (
    (model.train, 'tmp_train_mode.png'),
    (model.eval, 'tmp_eval_mode.png'),
)
for switch_mode, out_path in mode_runs:
    switch_mode()
    x_hat, _, _ = model(images)

    # Stack inputs and reconstructions along the batch axis; with four images
    # per row the inputs land in the top row and the reconstructions below.
    # NOTE(review): x_hat is written out as returned by the model, with no
    # range mapping or clamping - confirm that is intended.
    grid = to_pil(make_grid(torch.cat((images, x_hat)), nrow=4))
    grid.save(out_path)

In [5]:
# Repeat the train-mode reconstruction, this time passing the batch through
# the reconstruction loss' inmap() before the model and unmap() afterwards.
# NOTE(review): inmap/unmap are project helpers; presumably they convert
# between image space and the space the loss operates in - confirm.
model.train()

images_in = model.recon_loss.inmap(images)
x_hat_in, _, _ = model(images_in)

# Inputs (top row) and reconstructions (bottom row), both still in the
# inmap()-ed representation.
grid = to_pil(make_grid(torch.cat((images_in, x_hat_in)), nrow=4))
grid.save('tmp_train_mode_in.png')

# The same two tensors after unmap().
images_out = model.recon_loss.unmap(images_in)
x_hat_out = model.recon_loss.unmap(x_hat_in)

grid = to_pil(make_grid(torch.cat((images_out, x_hat_out)), nrow=4))
grid.save('tmp_train_mode_out.png')