In [None]:
from __future__ import print_function
from collections import namedtuple
import os
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from src.utils import normalize, setup_dataset

%matplotlib inline

def convert(dictionary):
    """Return an attribute-access, immutable view of ``dictionary``.

    A throwaway namedtuple class named ``GenericDict`` is built whose fields are
    the dictionary's keys (in iteration order), so ``opt['nz']`` becomes ``opt.nz``.
    namedtuple requires every key to be a valid Python identifier.
    """
    field_names = list(dictionary.keys())
    record_cls = namedtuple('GenericDict', field_names)
    return record_cls(**dictionary)
# The notebook computes on the CPU (models and latents are cast to torch.FloatTensor),
# so only pin GPU 0 when CUDA exists; an unconditional set_device raises on CPU-only builds.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

In [None]:
# ----------------------------------------------------------------------
# Pick the dataset (this also selects the weights loaded from pretrained/).
# Supported values: 'celeba', 'cifar10', 'imagenet'.
# ----------------------------------------------------------------------
DATASET = 'celeba'

In [None]:
# Tensor type for the loaded models and the latent codes (CPU float32).
dtype = torch.FloatTensor

# Base options; cifar10 uses them unchanged, other datasets override a subset below.
opt = {
    'nz': 128,          # latent size: second dimension of z fed to netG
    'nout': 128,
    'ngf': 64,
    'ndf': 64,
    'noise': 'sphere',
    'ngpu': 1,
    'dataset': 'cifar10',
    'image_size': 32,
    'workers': 0,
    'batch_size': 64,   # number of latents sampled / test images reconstructed
}

if DATASET == 'cifar10':
    pass  # the defaults above already describe cifar10
elif DATASET == 'imagenet':
    opt.update({
        'dataset': 'imagenet',
        # Set $IMAGENET_ROOT to point at your copy instead of editing the notebook.
        'dataroot': os.environ.get('IMAGENET_ROOT', '/sdh/data/imagenet'),
    })
elif DATASET == 'celeba':
    opt.update({
        'nz': 64,
        'dataset': 'celeba',
        # Set $CELEBA_ROOT to point at your copy instead of editing the notebook.
        'dataroot': os.environ.get('CELEBA_ROOT', '/sdh/data/celebA/imgs1'),
        'image_size': 64,
    })
else:
    # Fail early instead of silently running with cifar10 settings and then
    # failing later on a missing pretrained/<DATASET>_*.pth checkpoint.
    raise ValueError(
        "Unknown DATASET %r; expected 'cifar10', 'imagenet' or 'celeba'" % (DATASET,))

# Image channels: 1 for mnist, 3 for every other dataset.
opt['nc'] = 1 if opt['dataset'] == 'mnist' else 3
# Freeze into a namedtuple so later cells can use attribute access (opt.nz, opt.batch_size).
opt = convert(opt)

In [None]:
# Load the pretrained generator (netG) and encoder (netE) for the chosen dataset.
E_path = 'pretrained/%s_e.pth' % DATASET
G_path = 'pretrained/%s_g.pth' % DATASET


def _keep_on_cpu(storage, location):
    """torch.load map_location hook: keep every deserialized storage on the CPU."""
    return storage


# The checkpoints are whole pickled nn.Module objects. torch.load therefore unpickles
# arbitrary code, so only load files from a trusted source.
# map_location lets checkpoints saved on a GPU load on CPU-only machines.
# .type(dtype) then casts the modules to the notebook's CPU float type,
# which is where the original code ended up anyway.
netG = torch.load(G_path, map_location=_keep_on_cpu).type(dtype)
netE = torch.load(E_path, map_location=_keep_on_cpu).type(dtype)

# Inference mode for any dropout / batch-norm layers; ';' hides the module repr.
netG.eval()
netE.eval();

# Sample

In [None]:
# Sample a batch of latent codes and decode them with the generator.
# NOTE(review): volatile=True is the pre-0.4 PyTorch way to disable autograd; newer
# releases ignore it with a warning (torch.no_grad() is the replacement there).
z = Variable(dtype(opt.batch_size, opt.nz, 1, 1).normal_(0, 1), volatile=True)
# `normalize` comes from src.utils; presumably it projects z to match opt.noise == 'sphere' — confirm.
z = normalize(z)

samples = netG(z).data.cpu()

# samples/2 + 0.5 maps outputs assumed to lie in [-1, 1] onto [0, 1].
# Clamp so ToPILImage's float->byte conversion (x * 255 cast to uint8) cannot wrap
# out-of-range pixels around.
grid = torchvision.utils.make_grid((samples / 2 + 0.5).clamp(0, 1), pad_value=1)
grid_PIL = transforms.ToPILImage()(grid)

grid_PIL

# Reconstruct

In [None]:
# Grab one batch from the test split.
# next(iter(...)) works whether setup_dataset returns a DataLoader or an iterator
# (iter() of an iterator is the iterator itself). The Python-2-style `.next()` method
# is not available on Python 3 iterators or on DataLoader objects.
dataloader = setup_dataset(opt, train=False)
d = next(iter(dataloader))

In [None]:
# Encode the test images, decode them back, and show each original next to its reconstruction.
x = Variable(d[0], volatile=True)

ex = netE(x)    # encoder output for the real images
gex = netG(ex)  # reconstructions G(E(x))

# Interleave along the batch axis: even slots hold originals, odd slots reconstructions.
t = torch.FloatTensor(x.size(0) * 2, x.size(1), x.size(2), x.size(3))
t[0::2] = x.data.cpu()
t[1::2] = gex.data.cpu()

# nrow must stay even so an (original, reconstruction) pair never splits across rows.
# Map assumed [-1, 1] values to [0, 1] and clamp so ToPILImage's byte conversion cannot wrap.
grid = torchvision.utils.make_grid((t / 2 + 0.5).clamp(0, 1), pad_value=1, nrow=16)
grid_PIL = transforms.ToPILImage()(grid)

grid_PIL