In [3]:
import os
import vision
from docopt import docopt
from torchvision import transforms
from glow.builder import build
from glow.trainer import Trainer
from glow.config import JsonConfig
import cv2
import random
import torch
import numpy as np

# functions

In [12]:
def run_z(graph, z, eps_std=0.3, size=(256, 256)):
    """Decode one latent code into an image array via the reverse pass.

    Args:
        graph: model object; called as ``graph(z=..., eps_std=..., reverse=True)``
            after being switched to eval mode.
        z: a single latent code. It is wrapped in a one-element list so the
            tensor handed to the model has a leading batch axis of size 1.
        eps_std: forwarded unchanged to ``graph``. Defaults to 0.3, the value
            that used to be hard-coded here.
        size: ``dsize`` handed to ``cv2.resize``. Defaults to ``(256, 256)``,
            the value that used to be hard-coded here.

    Returns:
        The array returned by ``cv2.resize``: the first sample of the model
        output with dim 0 moved last and the last axis reversed.
    """
    graph.eval()
    # Pure inference: disable autograd so no graph is recorded for the
    # reverse pass (the result is detached right below anyway).
    with torch.no_grad():
        x = graph(z=torch.tensor([z]).cuda(), eps_std=eps_std, reverse=True)
    # Take the first sample and move dim 0 to the end
    # (presumably CHW -> HWC -- TODO confirm against the model's output layout).
    img = x[0].permute(1, 2, 0).detach().cpu().numpy()
    # Reverse the last axis (presumably RGB -> BGR for OpenCV -- TODO confirm).
    img = img[:, :, ::-1]
    img = cv2.resize(img, size)
    return img

# load data

In [6]:
# Hyper-parameters and dataset location.
hparams = JsonConfig("hparams/celeba1.json")
dataset_root = './dataset/CelebA'

# Preprocessing applied to every image: crop, resize, convert to tensor.
# The crop and resize arguments come from the hparams file.
preprocess_steps = [
    transforms.CenterCrop(hparams.Data.center_crop),
    transforms.Resize(hparams.Data.resize),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocess_steps)

# Build the model first, then instantiate the dataset, so the log output
# keeps the same order as before.
graph = build(hparams, False)["graph"]
dataset = vision.Datasets["celeba"](dataset_root, transform=transform)

[Builder]: Found 1 gpu
[Builder]: cuda:1 is not found, ignore.
[Builder]: cuda:2 is not found, ignore.
[Builder]: cuda:3 is not found, ignore.
[Checkpoint]: Load ./results/celeba/trained.pkg successfully
[Builder]: Use cuda [0] to train, use 0 to load data and get loss.
Begin to parse all image attrs
Find 202599 images, with 40 attrs


In [7]:
# get Z
# Load the per-attribute delta-Z arrays from the on-disk cache, or compute and
# cache them when the cache directory is missing or incomplete.
z_dir = './celeba_z'
# NOTE: "detla" is a misspelling of "delta", kept on purpose so that caches
# written by earlier runs of this notebook still load.
z_file_pattern = "detla_z_{}.npy"

generate_z = not os.path.exists(z_dir)
if not generate_z:
    print("Load Z from {}".format(z_dir))
    # try to load one array per attribute class
    try:
        delta_Z = [
            np.load(os.path.join(z_dir, z_file_pattern.format(i)))
            for i in range(hparams.Glow.y_classes)
        ]
    except FileNotFoundError:
        # Cache directory exists but is empty/partial: fall through and
        # regenerate instead of calling quit(), which would shut down the
        # notebook kernel.
        generate_z = True
        print("Failed to load {} Z".format(hparams.Glow.y_classes))

if generate_z:
    print("Generate Z to {}".format(z_dir))
    # exist_ok: the directory may already exist when recovering from a
    # partial cache.
    os.makedirs(z_dir, exist_ok=True)
    delta_Z = graph.generate_attr_deltaz(dataset)
    for i, z in enumerate(delta_Z):
        np.save(os.path.join(z_dir, z_file_pattern.format(i)), z)
    print("Finish generating")

Load Z from ./celeba_z


In [10]:
# Report how many images and attribute directions are available, then list
# the attribute names.
image_count = len(dataset)
print("No. of images to choose from:", image_count)
attr_count = len(delta_Z)
print("No. of attributes to choose from:", attr_count)
print('Available attibutes:', dataset.attrs)

No. of images to choose from: 202599
No. of attributes to choose from: 40
Available attibutes: ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young']


In [15]:
# Pick a base image and an attribute, then decode the base latent shifted by
# fractions of that attribute's delta-Z in both directions.
base_index = 2   # index into `dataset` of the image to manipulate
attr_index = 4   # index into `dataset.attrs` / `delta_Z`
attr_name = dataset.attrs[attr_index]
z_delta = delta_Z[attr_index]

graph.eval()
z_base = graph.generate_z(dataset[base_index]["x"])
# begin to generate new image
images = []
names = []
images.append(run_z(graph, z_base))
names.append("reconstruct_origin")
interpolate_n = 5   # number of steps on each side of the base latent
for i in range(0, interpolate_n + 1):
    # Fraction i / interpolate_n of the full attribute delta.
    d = z_delta * float(i) / float(interpolate_n)
    images.append(run_z(graph, z_base + d))
    names.append("attr_{}_{}".format(attr_name, interpolate_n + i))
    if i > 0:
        # i == 0 is the unshifted latent, so only mirror the non-zero steps.
        new_image = run_z(graph, z_base - d)
        images.append(new_image)
        names.append("attr_{}_{}".format(attr_name, interpolate_n - i))
        cv2.namedWindow(str(i))
        cv2.imshow(str(i), new_image)
# imshow only draws once the HighGUI event loop runs; without waitKey the
# windows created above never render. Block until a key is pressed in one of
# the windows, then close them all.
cv2.waitKey(0)
cv2.destroyAllWindows()
# NOTE(review): save_images is not defined in the visible cells -- confirm
# before re-enabling.
#save_images(images, names)