In [None]:
#set seed
import torch
torch.manual_seed(0)
import numpy as np
np.random.seed(0)

#import
from torch import optim
import torch.nn.functional as F
from torch import autograd
import os
import scipy.misc
from scipy.misc import imsave
from datetime import datetime
import torchvision
from torchvision import transforms

#import modules
from sampler import svhn_sampler
from model import Critic, Generator
from train import vf_wasserstein_distance, save_images


In [None]:
class NormalizeInverse(torchvision.transforms.Normalize):
    """
    Invert a ``Normalize(mean, std)`` transform, mapping images back to their input domain.

    ``Normalize`` computes ``(x - mean) / std``. Running ``Normalize`` again with
    ``mean' = -mean / std`` and ``std' = 1 / std`` yields ``x * std + mean``.
    A tiny epsilon is added to ``std`` so the reciprocal stays finite for a zero std.
    """

    def __init__(self, mean, std):
        original_mean, original_std = torch.as_tensor(mean), torch.as_tensor(std)
        reciprocal_std = 1 / (original_std + 1e-7)
        super().__init__(mean=-original_mean * reciprocal_std, std=reciprocal_std)

    def __call__(self, tensor):
        # Defensive copy: the caller's tensor is never modified, however the
        # parent class applies the transform.
        working_copy = tensor.clone()
        return super().__call__(working_copy)

In [None]:
# --- Sampling configuration ---
train_batch_size = 64   # number of latent vectors drawn for the sample batch
z_dim = 100             # latent size fed to the generator (assumed to match training — TODO confirm)

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --- Checkpoint location: <root>/<models dir>/<run dir> ---
root = './'
model_dir_relpath = 'models'
model_run_relpath = 'output_2020_04_27__01_51_05'
model_path = os.path.join(root, model_dir_relpath, model_run_relpath)

In [None]:
# Load the pickled Critic / Generator modules from the run directory.
# os.path.join inserts the path separator; plain string concatenation produced
# "<run>critic.pt" instead of "<run>/critic.pt".
# Mapping onto `device` keeps the models on the same device as the latent
# vectors created later (torch.randn(..., device=device)).
# NOTE: torch.load unpickles arbitrary objects — only load trusted checkpoints.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects full
# pickled modules; pass weights_only=False there (trusted files only).
# NOTE(review): the generator keeps whatever train/eval mode it was saved in;
# call generator.eval() if it has BatchNorm/Dropout and inference outputs are wanted.
critic = torch.load(os.path.join(model_path, 'critic.pt'), map_location=device)
generator = torch.load(os.path.join(model_path, 'generator.pt'), map_location=device)

# Generate images

In [None]:
# Sample images are written under the run's own directory:
# <model_path>/output/samples
# (model_output_path was previously joined onto itself before ever being defined.)
model_output_relpath = os.path.join('output', 'samples')

model_output_path = os.path.join(model_path, model_output_relpath)

In [None]:
# Draw one batch of latent vectors and decode them with the generator.
# Sampling needs no gradients, so skip building the autograd graph.
# (autograd.Variable has been a deprecated no-op wrapper since PyTorch 0.4.)
z = torch.randn(train_batch_size, z_dim, device=device)
with torch.no_grad():
    samples = generator(z)

In [None]:
# Undo a per-channel Normalize(mean=0.5, std=0.5) on each generated image
# (presumably the generator's training targets were normalized this way — confirm).
unorm = NormalizeInverse(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
samples_norm = torch.stack(list(map(unorm, samples)))

In [None]:
#create directory if not exist
samples_norm = samples_norm.data.numpy()
os.makedirs(model_output_path, exist_ok=True)

save_images(
    samples_norm,
    os.path.join(model_output_path, 'samples.png')
)

# Analyze images generated by perturbing single latent (z) dimensions

In [None]:
# Latent-traversal images are written under the run's own directory:
# <model_path>/output/latent_space_variations
# Joining onto the previous model_output_path would nest this inside output/samples.
model_output_relpath = os.path.join('output', 'latent_space_variations')

model_output_path = os.path.join(model_path, model_output_relpath)

In [None]:
# Latent-space traversal. For every latent dimension:
#   1. take one fixed base vector;
#   2. shift only that coordinate by each offset in `eps_list`;
#   3. decode each shifted vector;
#   4. save the resulting images to samples_dim_<dim>.png.
image_num = 0  # row of the seed-0 batch used as the base latent vector
eps_list = [-7.5, -5, -2.5, 0, 2.5, 5, 7.5]
assert eps_list, "eps_list must contain at least one offset"

# The base vector is row `image_num` of a batch drawn right after
# torch.manual_seed(0). It is identical for every (dim, eps) pair, so draw it once.
torch.manual_seed(0)
z_base = torch.randn(train_batch_size, z_dim, device=device)[image_num]
eps_offsets = torch.tensor(eps_list, dtype=z_base.dtype, device=device)

os.makedirs(model_output_path, exist_ok=True)

with torch.no_grad():
    for dim in range(z_dim):
        # One row per offset: copies of the base with only column `dim` shifted.
        z_variants = z_base.repeat(len(eps_list), 1)
        z_variants[:, dim] += eps_offsets

        images = []
        for z in z_variants:
            # Re-seed before every forward pass so any stochastic layers behave
            # identically for each variant. Decode one vector at a time (batch
            # of 1), as before. squeeze(0) drops only the batch dimension.
            torch.manual_seed(0)
            images.append(unorm(generator(z.unsqueeze(0)).squeeze(0)))

        # .cpu() so the NumPy conversion also works when `device` is CUDA.
        samples_norm = torch.stack(images).cpu().numpy()

        save_images(
            samples_norm,
            os.path.join(model_output_path, 'samples_dim_{}.png'.format(dim)),
        )