In [None]:
import os
import sys

import numpy as np
import torch
from PIL import Image

sys.path.append('../')

from input_pipeline import PairsDataset
from networks.generator import Generator
from networks.encoder import ResNetEncoder

# Load models

In [None]:
# Rebuild the generator and encoder in eval mode on the GPU and load the run00 weights.
# NOTE(review): constructor args presumably are (in_channels, out_channels) for
# Generator and (in_channels, latent_dim) for ResNetEncoder; A is split into an
# edge and a mask channel further down. TODO confirm against networks/.
GENERATOR_CKPT = '../models/run00_generator.pth'
ENCODER_CKPT = '../models/run00_encoder.pth'

G = Generator(2, 3).eval().cuda()
# map_location='cpu' deserializes the tensors on the host regardless of which
# device the checkpoint was saved from; load_state_dict then copies them into
# the already CUDA-resident parameters.
# torch.load unpickles the file: only load checkpoints from trusted sources.
G.load_state_dict(torch.load(GENERATOR_CKPT, map_location='cpu'))

E = ResNetEncoder(3, 8).eval().cuda()
E.load_state_dict(torch.load(ENCODER_CKPT, map_location='cpu'))

# Load an example pair from the edges2shoes training set

In [None]:
# Dataset root. Override it with the EDGES2SHOES_TRAIN environment variable; the
# default is the original author's local path. Kept as a string with a trailing
# slash in case PairsDataset joins paths by string concatenation.
DATA = os.environ.get('EDGES2SHOES_TRAIN', '/home/dan/datasets/edges2shoes/train/')
# NOTE(review): is_training=True may enable random augmentation inside
# PairsDataset. TODO confirm; if so, the pair below can differ between runs.
dataset = PairsDataset(folder=DATA, size=256, is_training=True)

SAMPLE_INDEX = 990  # which (A, B) pair to visualize
A, B = dataset[SAMPLE_INDEX]

# Show images

In [None]:
# Scale the tensors to 8-bit numpy arrays for PIL.
# Channel 0 of A is the edge map and channel 1 the mask; B is moved from
# channel-first to channel-last layout.
edges, mask = [(A[ch].numpy() * 255).astype('uint8') for ch in (0, 1)]
image = (B.permute(1, 2, 0).numpy() * 255).astype('uint8')

# The target is displayed inverted (255 - image); the generator outputs later
# in this notebook are likewise inverted with `1.0 - ...` before display.
Image.fromarray(255 - image)

In [None]:
# Preview the 8-bit edge map (channel 0 of A).
Image.fromarray(obj=edges)

In [None]:
# Preview the 8-bit mask (channel 1 of A).
Image.fromarray(obj=mask)

# Reconstruct: encode B, sample a latent, and decode it with A

In [None]:
# Reconstruct B. The encoder returns (mean, logvar); one latent is sampled as
# mean + eps * std with eps ~ N(0, I), then decoded by G conditioned on A.
A2 = A.unsqueeze(0).cuda()  # add a batch dimension of 1
B2 = B.unsqueeze(0).cuda()

with torch.no_grad():
    mean, logvar = E(B2)
    std = logvar.mul(0.5).exp()  # std = exp(logvar / 2)
    # randn_like matches std's shape, dtype and device. The latent size therefore
    # comes from the encoder instead of being hard-coded, and the noise is drawn
    # directly on the GPU instead of on the CPU and then copied.
    z = torch.randn_like(std)
    B_restored = G(A2, mean + z * std)

In [None]:
# Convert the reconstruction (batch of 1, channel-first, on GPU) into an inverted
# 8-bit channel-last image for display.
image = 1.0 - B_restored.squeeze(0).permute(1, 2, 0).cpu().numpy()
# Clip before the uint8 cast: out-of-range floats would not saturate (the cast
# result is platform-dependent). NOTE(review): the generator's output range is
# not visible here; if it is already within [0, 1] this is a no-op.
image = (np.clip(image, 0.0, 1.0) * 255).astype('uint8')
Image.fromarray(image)

# Generate: decode A with a random standard-normal latent (no encoder)

In [None]:
# Generate a new B for the same A from a latent drawn from a standard normal,
# instead of one inferred from B by the encoder.
# NOTE(review): LATENT_DIM presumably must equal the encoder's latent size
# (second argument of ResNetEncoder(3, 8)). TODO confirm against networks/.
LATENT_DIM = 8

with torch.no_grad():
    # One latent per item in A2's batch, created directly on A2's device.
    z = torch.randn(A2.size(0), LATENT_DIM, device=A2.device)
    B_generated = G(A2, z)

In [None]:
# Convert the generated sample (batch of 1, channel-first, on GPU) into an
# inverted 8-bit channel-last image for display.
image = 1.0 - B_generated.squeeze(0).permute(1, 2, 0).cpu().numpy()
# Clip before the uint8 cast: out-of-range floats would not saturate (the cast
# result is platform-dependent). NOTE(review): the generator's output range is
# not visible here; if it is already within [0, 1] this is a no-op.
image = (np.clip(image, 0.0, 1.0) * 255).astype('uint8')
Image.fromarray(image)