In [4]:
# Copyright (c) 2021, InterDigital R&D France. All rights reserved.

# This source code is made available under the license found in the
# LICENSE.txt in the root directory of this source tree.

from __future__ import print_function

import matplotlib.pyplot as plt
%matplotlib inline

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets 

import argparse
import copy
import glob
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import yaml

from PIL import Image
from torchvision import transforms, utils, models

import sys

if os.getcwd().split('/')[-1] == 'notebooks':
    sys.path.append('..')
    os.chdir('..')
    
from datasets import *
from trainer import *
from utils.functions import *

# cuDNN / autograd configuration for this notebook.
torch.backends.cudnn.enabled = True
# Reproducible convolutions: `deterministic` restricts cuDNN to deterministic
# algorithms, and `benchmark` must be off as well. Otherwise the autotuner can
# select a different algorithm on each run, so outputs may differ run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Anomaly detection is a debugging aid that slows every autograd op. The
# visualization below runs under torch.no_grad(), so leave it disabled.
torch.autograd.set_detect_anomaly(False)
# Disable PIL's decompression-bomb pixel-count limit so very large images load.
Image.MAX_IMAGE_PIXELS = None
# Device that the trainer and the latent codes are moved to in later cells.
device = torch.device('cuda')

# Command-line style options. parse_args([]) ignores the kernel's own argv,
# so every option takes its default; edit the defaults below to reconfigure.
parser = argparse.ArgumentParser()
for flag, default, help_text in (
    ('--config', '001', 'Path to the config file.'),
    ('--attr', 'Eyeglasses', 'attribute for manipulation.'),
    ('--latent_path', './data/celebahq_dlatents_psp.npy', 'dataset path'),
    ('--label_file', './data/celebahq_anno.npy', 'label file path'),
    ('--stylegan_model_path', './pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', 'stylegan model path'),
    ('--classifier_model_path', './models/latent_classifier_epoch_20.pth', 'pretrained attribute classifier'),
    ('--log_path', './logs/', 'log file path'),
):
    parser.add_argument(flag, type=str, default=default, help=help_text)
# Keep the notebook namespace free of the loop variables.
del flag, default, help_text
opts = parser.parse_args(args=[])

# CelebA attribute names mapped to their column index in the annotation
# vectors. The index is the position of the name in this ordered tuple.
attr_dict = {
    name: index
    for index, name in enumerate((
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
        'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
        'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
        'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
        'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
        'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks',
        'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
        'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young',
    ))
}

ImportError: No module named 'fused'

In [None]:
# Initialize the trainer and load the pretrained StyleGAN generator and the
# latent attribute classifier.
log_dir = os.path.join(opts.log_path, opts.config) + '/'
config_path = os.path.join('./configs', opts.config + '.yaml')
# yaml.load() without an explicit Loader is deprecated (PyYAML 5.1) and raises
# TypeError on PyYAML >= 6. safe_load also refuses arbitrary Python object
# tags. The `with` block closes the file handle that was previously leaked.
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

trainer = Trainer(config, None, None, opts.label_file)
trainer.initialize(opts.stylegan_model_path, opts.classifier_model_path)
trainer.to(device)
print('Load model.')


### Visualization of attribute manipulation

In [None]:
# Attributes exposed as sliders in the UI cell below. Each one needs a trained
# latent transformer that trainer.load_model can find under log_dir.
attr_list = ['Male', 'Eyeglasses', 'Young', 'Smiling']
testdata_dir = './data/test/'

# Build one latent transformer per attribute. trainer.load_model presumably
# loads the network for the current trainer.attr_num into trainer.T_net, so
# each one is deep-copied before the next iteration replaces it.
T_net_dict = dict()
for attr in attr_list:
    trainer.attr_num = attr_dict[attr]
    trainer.load_model(log_dir)
    T_net_dict.update({attr: copy.deepcopy(trainer.T_net)})
    
# Visualization function
# Number of leading layers of the latent code that receive the edited values;
# the remaining layers are copied unchanged from the original latent w_0.
NUM_EDITED_LAYERS = 11


def visu_manipulation(seed, **attr_scale):
    """Render test face ``seed`` after applying latent attribute edits.

    Args:
        seed: index of the test latent. The function loads
            ``testdata_dir/latent_code_<seed:05d>.npy``.
        **attr_scale: attribute name -> edit strength. Each name must be a
            key of ``T_net_dict``. Strengths equal to 0 are skipped.

    The edits are applied one after another, in the order of ``attr_scale``.
    Only the first ``NUM_EDITED_LAYERS`` layers of the edited code are kept.
    The image is generated with fixed noise and displayed with matplotlib.
    """
    with torch.no_grad():
        latent_path = os.path.join(testdata_dir, 'latent_code_%05d.npy' % int(seed))
        w_0 = torch.tensor(np.load(latent_path)).to(device)
        w_1 = w_0
        for attr_name, scale in attr_scale.items():
            if scale == 0:
                continue
            # Pin float32 so int or numpy-float64 strengths (e.g. a direct
            # call with Male=1) don't produce an int64/float64 tensor.
            scale_t = torch.tensor([float(scale)], dtype=torch.float32, device=device)
            # The transformer operates on the flattened code, one row per sample.
            w_1 = T_net_dict[attr_name](w_1.view(w_0.size(0), -1), scale_t)
        w_1 = w_1.view(w_0.size())
        w_1 = torch.cat((w_1[:, :NUM_EDITED_LAYERS, :], w_0[:, NUM_EDITED_LAYERS:, :]), 1)
        x_1, _ = trainer.StyleGAN([w_1], input_is_latent=True, randomize_noise=False)
        img = np.clip(clip_img(x_1)[0].cpu().numpy() * 255., 0, 255).astype(np.uint8)
        img = Image.fromarray(img.transpose(1, 2, 0))  # CHW -> HWC
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(img)
        ax.axis('off')
        plt.show()

In [None]:
# User interface: a dropdown of test latent indices 0-8, plus one float slider
# per attribute in attr_list, each spanning [-1.5, 1.5] in steps of 0.3.
# The inline matplotlib backend is already enabled in the first cell. The
# trailing semicolon stops Jupyter from echoing the function repr that
# interact() returns after displaying the widgets.
attr_scale = {key: (-1.5, 1.5, 0.3) for key in attr_list}
interact(visu_manipulation, seed=list(range(9)), **attr_scale);