In [1]:
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
import torchvision
from torch.backends import cudnn
from data_loader import get_loader
from data_utils import FaceData
from model import G12, G21, D1, D2
import numpy as np

%load_ext autoreload
%autoreload 2

cudnn.benchmark = True 

def merge_images(sources, targets, batch_size=9, k=10):
    _, h, w = sources.shape
    row = int(np.sqrt(batch_size))
    merged = np.zeros([3, row*h, row*w*2])
    for idx, (s, t) in enumerate(zip(sources, targets)):
        i = idx // row
        j = idx % row
        merged[:, i*h:(i+1)*h, (j*2)*h:(j*2+1)*h] = s
        merged[:, i*h:(i+1)*h, (j*2+1)*h:(j*2+2)*h] = t
    return merged.transpose(1, 2, 0)

class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
    
config1 = {
'image_size' : 200, 
'g_conv_dim' : 64, 
'd_conv_dim' : 64, 
'num_classes' : 10, 
'train_iters' : 40000,
'batch_size' : 9,
'num_workers' : 2,
'lr' : 0.0002,
'beta1' : 0.5,
'beta2' : 0.999,

'mode'  :'train',
'model_path' :'./models',
'sample_path' :'./samples',
'mnist_path' :'./mnist',
'svhn_path' :'./svhn',
'log_step' : 10,
'sample_step' : 500,

'use_labels': False, 
'use_reconst_loss' : True, 
}

config = dotdict(config1)

#loading data
old = FaceData(image_paths_file='LAG/train/train.txt', young=False)
young = FaceData(image_paths_file='LAG/train/train.txt')

#loading models
g12 = G12(conv_dim=config.g_conv_dim)
state_G12 = torch.load('models/g12-36000.pkl')
g12.load_state_dict(state_G12)

d1 = D1(conv_dim=config.g_conv_dim)
state_D1 = torch.load('models/d1-36000.pkl')
d1.load_state_dict(state_D1)

g21 = G21(conv_dim=config.g_conv_dim)
state_G21 = torch.load('models/g21-36000.pkl')
g21.load_state_dict(state_G21)


plt.switch_backend('agg')
plt.ion()

#plt.figure(figsize=(15,5))
plt.figure()

num_example_imgs = 3
#for i, (img, img_label) in enumerate(young[-(num_example_imgs):]):
for i, (img, img_label) in enumerate(young[-(num_example_imgs+100):]):

    inputs = img.unsqueeze(0)
    inputs = Variable(inputs)
    
    outputs = g12(inputs)
    d1_out = d1(outputs)
    #print(d1_out)
    reverse = g21(outputs)    
    
    pred = outputs[0].data.cpu()
    rev = reverse[0].data.cpu()
    
    img, pred, rev = img.numpy(), pred.numpy(), rev.numpy() 

    #pred = np.around(pred,decimals=2)
    #pred = (pred - np.min(pred))*255/np.max(pred)
    #rev = np.around(rev,decimals=2)
    #rev = rev - np.min(rev)

    #merged = merge_images(pred, rev)

    # img
    plt.subplot(num_example_imgs, 3, i * 3 + 1)
    plt.axis('off')
    plt.imshow(img.transpose(1,2,0))
    
    # pred
    plt.subplot(num_example_imgs, 3, i * 3 + 2)
    plt.axis('off')
    #plt.imshow(merged)
    plt.imshow(pred.reshape(200,200,3))
    
    # rev
    plt.subplot(num_example_imgs, 3, i * 3 + 3)
    plt.axis('off')
    plt.imshow(rev.reshape(200,200,3))
    
    plt.savefig('randomfig.png')
    plt.show()


FileNotFoundError: [Errno 2] No such file or directory: 'models/g12-36000.pkl'