In [1]:
import os
import glob
import sys
import numpy as np
import pickle
import tensorflow as tf
import PIL
import ipywidgets
import io
import h5py
from keras.models import load_model
from keras.applications.mobilenet import preprocess_input

# Make sure this notebook runs from the project root: step out of the
# 'notebooks' / 'src' sub-directories if the kernel was started inside one.
while True:
    if os.path.basename(os.getcwd()) not in ('notebooks', 'src'):
        break
    os.chdir('..')
assert 'README.md' in os.listdir('./'), 'Can not find project root, please cd to project root before running the following code'

import src.tl_gan.generate_image as generate_image
import src.tl_gan.feature_axis as feature_axis
import src.tl_gan.feature_celeba_organize as feature_celeba_organize

Using TensorFlow backend.


In [2]:
""" load feature directions """
path_feature_direction = './asset_results/stylegan_ffhq_feature_direction_retrained'

# glob() returns matches in arbitrary (filesystem) order, so sort before taking
# the last one. This assumes the file names sort chronologically (e.g. a
# timestamp suffix) -- TODO confirm.
list_pathfile_feature_direction = sorted(
    glob.glob(os.path.join(path_feature_direction, 'feature_direction_*.pkl')))
if not list_pathfile_feature_direction:
    raise FileNotFoundError(
        'no feature_direction_*.pkl found in {}'.format(path_feature_direction))
pathfile_feature_direction = list_pathfile_feature_direction[-1]

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(pathfile_feature_direction, 'rb') as f:
    feature_direction_name = pickle.load(f)

# one column per feature axis
num_feature = feature_direction_name['direction'].shape[1]

# Names and signs come from feature_celeba_organize (reloaded so edits to it are
# picked up on re-run); the names stored in the pickle are not used.
import importlib
importlib.reload(feature_celeba_organize)
feature_name = feature_celeba_organize.feature_name_celeba_rename
feature_direction = feature_direction_name['direction'] * feature_celeba_organize.feature_reverse[None, :]

In [3]:
""" start tf session and load GAN model """

# path to model code and weight
path_pg_gan_code = './src/model/pggan'
path_model = './network-snapshot-013246.pkl'
# Presumably the pickled networks need the pggan model code to be importable
# when unpickled; the guard keeps re-runs of this cell from appending duplicates.
if path_pg_gan_code not in sys.path:
    sys.path.append(path_pg_gan_code)


""" create tf session """
yn_CPU_only = False  # set True to expose no GPU devices to TensorFlow

if yn_CPU_only:
    config = tf.ConfigProto(device_count = {'GPU': 0}, allow_soft_placement=True)
else:
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True  # grab GPU memory on demand, not all up front

sess = tf.InteractiveSession(config=config)
tf.global_variables_initializer().run()
# NOTE(review): pickle.load can execute arbitrary code; only load model files from a trusted source.
try:
    with open(path_model, 'rb') as file:
        G, D, Gs = pickle.load(file)
except FileNotFoundError:
    # report the path that was actually tried
    print('before running the code, download pre-trained model to {}'.format(path_model))
    raise

# Gs is the network used for all image generation below; its input width is the latent length
len_z = Gs.input_shapes[0][1]
# generate one sample image from a random latent vector
z_sample = np.random.randn(len_z)
x_sample = generate_image.gen_single_img(z_sample, Gs=Gs)












Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [4]:
def img_to_bytes(x_sample):
    """ Encode an image array as PNG file bytes (e.g. for ipywidgets.Image).

    Parameters:
        x_sample: array accepted by PIL.Image.fromarray (presumably the HxWxC
            uint8 output of generate_image.gen_single_img -- TODO confirm).
    Returns:
        bytes: the PNG-encoded image.
    """
    # `import PIL` alone does not load the PIL.Image submodule, so import it
    # explicitly instead of relying on another module having done so.
    from PIL import Image
    img_obj = Image.fromarray(x_sample)
    with io.BytesIO() as buffer:
        img_obj.save(buffer, format='PNG')
        return buffer.getvalue()


In [5]:
# Feature classifier: GuiCallback.update_img prints its per-feature outputs.
path_model_feature_extractor = 'model_20180927_032934.h5'
model = load_model(path_model_feature_extractor)  # feature extractor










In [6]:
# Per-feature (first, second) distances along the feature axis used by the
# 'no' (index 0) and 'yes' (index 1) buttons in GuiCallback.random_gen_feature.
# Every feature defaults to (-2.5, 2.5); the features below are overridden.
ideal_weight_override = {
    8: (-1.5, 3),    # Black Hair
    9: (-2, 2.5),    # Blond Hair
    15: (-1.5, 3.5), # Glasses
    16: (-1, 3.5),   # Goatee
    20: (-1.5, 2),   # Male
    21: (-3, 2.5),   # Mouth Open
    24: (-1, 3),     # Beard
    31: (-3.5, 2),   # Smiling
    39: (-1, 3),     # Age
}
ideal_weight = [(-2.5, 2.5)] * num_feature  # magnitude of transformations along feature axis
for idx_override, bounds_override in ideal_weight_override.items():
    ideal_weight[idx_override] = bounds_override

In [7]:
#Generate Random Face
z_sample = np.random.randn(len_z)
x_sample = generate_image.gen_single_img(Gs=Gs)
w_img = ipywidgets.widgets.Image(value=img_to_bytes(x_sample), format='png', 
                                 width=256, height=256,
                                 layout=ipywidgets.Layout(height='256px', width='256px')
                                )

class GuiCallback(object):
    """ State and event handlers behind the face-editing GUI.

    Holds the current latent vector (self.latents) and the feature axes, and
    re-renders w_img whenever the latents change. Relies on module-level names
    defined in earlier cells: z_sample, feature_direction, num_feature, len_z,
    ideal_weight, Gs, model, w_img, feature_name, img_to_bytes.
    """
    counter = 0  # NOTE(review): not referenced anywhere in this notebook
    # Features whose classifier output is printed negated (e.g. No Beard --> Beard).
    idx_feature_reversed = frozenset({24, 36, 39})

    def __init__(self):
        # Copy so the in-place updates in modify_along_feature do not also
        # mutate the global z_sample array.
        self.latents = np.array(z_sample, copy=True)
        self.feature_direction = feature_direction
        # True = locked; locked indices are passed as idx_base when disentangling.
        self.feature_lock_status = np.zeros(num_feature, dtype=bool)
        self._update_disentangled()

    def _update_disentangled(self):
        """ Recompute feature_direction_disentangled using the currently locked features as idx_base. """
        self.feature_direction_disentangled = feature_axis.disentangle_feature_axis_by_idx(
            self.feature_direction, idx_base=np.flatnonzero(self.feature_lock_status))

    def random_gen_feature(self, event, idx_feature, direction):
        """ Generate a new random face without/with the desired feature.

        Parameters:
            event: widget event (unused).
            idx_feature: index of the feature axis.
            direction: index into ideal_weight[idx_feature]; 0 for the 'no'
                bound, 1 for the 'yes' bound.
        """
        axis = self.feature_direction_disentangled[:, idx_feature]
        self.latents = np.random.randn(len_z)
        # presumably strips the latents' component along the axis -- see feature_axis
        self.latents = feature_axis.orthogonalize_one_vector(self.latents, axis)
        self.latents += axis * ideal_weight[idx_feature][direction]
        self.update_img(idx_feature)

    def random_gen(self, event):
        """ Replace the latents with a fresh standard-normal sample and re-render. """
        self.latents = np.random.randn(len_z)
        self.update_img()

    def modify_along_feature(self, event, idx_feature, step_size=0.01):
        """ Move the latents by step_size along one disentangled feature axis (in place) and re-render. """
        self.latents += self.feature_direction_disentangled[:, idx_feature] * step_size
        self.update_img()

    def set_feature_lock(self, event, idx_feature, set_to=None):
        """ Lock/unlock a feature (toggle when set_to is None) and recompute the disentangled axes. """
        if set_to is None:
            self.feature_lock_status[idx_feature] = np.logical_not(self.feature_lock_status[idx_feature])
        else:
            self.feature_lock_status[idx_feature] = set_to
        self._update_disentangled()

    def update_img(self, idx_feature=None):
        """ Render self.latents into w_img; if idx_feature is given, also print the classifier score for it. """
        img = generate_image.gen_single_img(z=self.latents, Gs=Gs)
        w_img.value = img_to_bytes(img)

        if idx_feature is None:
            return  # the score is only printed for a specific feature, so skip the classifier

        x = img[None, :]  # batch with 1 image
        x = x[:, 1::2, 1::2, :]  # keep every other pixel: downsample to 128x128 (assumes 256x256 input)
        x = preprocess_input(x)
        y = model.predict(x)
        score = y[0][idx_feature]
        if idx_feature in self.idx_feature_reversed:
            # Reverse direction of feature (e.g. No Beard --> Beard)
            score = -score
        print('{}: {:.2f}'.format(feature_name[idx_feature], score))
   

guicallback = GuiCallback()

step_size = 100 #When translating latent vector along feature axes


def _create_small_button(description, width, height):
    """ Build one of the half-width, half-height buttons ('-', '+', 'no', 'yes') of a feature group. """
    return ipywidgets.widgets.Button(
        description=description,
        layout=ipywidgets.Layout(height='{:.0f}px'.format(height/2),
                                 width='{:.0f}px'.format(width/2),
                                 margin='1px 1px 5px 1px'))


def create_button(idx_feature, width=96, height=40):
    """ Build the button group for one feature.

    The group holds:
    - a toggle with the feature name (pressed = feature locked);
    - '-'/'+' buttons that move the latents along the feature axis by step_size;
    - 'no'/'yes' buttons that generate a new face without/with the feature.

    Parameters:
        idx_feature: index into feature_name and the feature axes.
        width, height: group size in px (the small buttons use half of each).
    Returns:
        ipywidgets.VBox containing the group.
    """
    w_name_toggle = ipywidgets.widgets.ToggleButton(
        value=False, description=feature_name[idx_feature],
        tooltip='{}, Press down to lock this feature'.format(feature_name[idx_feature]),
        layout=ipywidgets.Layout(height='{:.0f}px'.format(height/2),
                                 width='{:.0f}px'.format(width),
                                 margin='2px 2px 2px 2px')
    )
    w_neg, w_pos, w_no, w_yes = [_create_small_button(description, width, height)
                                 for description in ('-', '+', 'no', 'yes')]

    # observe() without `names` fires on every trait change of the widget, not only
    # 'value', and each call used to flip the lock. Listen to 'value' only and mirror
    # the new toggle state so the lock always matches the button.
    w_name_toggle.observe(
        lambda change: guicallback.set_feature_lock(change, idx_feature, set_to=change['new']),
        names='value')
    # step_size is read at click time, so changing the global later takes effect.
    w_neg.on_click(lambda event:
                   guicallback.modify_along_feature(event, idx_feature, step_size=-1 * step_size))
    w_pos.on_click(lambda event:
                   guicallback.modify_along_feature(event, idx_feature, step_size=+1 * step_size))
    w_no.on_click(lambda event: guicallback.random_gen_feature(event, idx_feature, 0))
    w_yes.on_click(lambda event: guicallback.random_gen_feature(event, idx_feature, +1))

    button_group = ipywidgets.VBox([w_name_toggle,
                                    ipywidgets.VBox([ipywidgets.HBox([w_neg, w_pos]),
                                                     ipywidgets.HBox([w_no, w_yes])])],
                                   layout=ipywidgets.Layout(border='1px solid gray'))

    return button_group
  

# One button group per feature, in feature-index order.
list_buttons = [create_button(idx_feature) for idx_feature in range(num_feature)]


yn_button_select = True #Filtered Attributes only
def arrange_buttons(list_buttons, yn_button_select=True, ncol=4):
    """ Lay out button groups as a VBox of HBox rows.

    yn_button_select=True: rows follow feature_celeba_organize.feature_celeba_layout,
        where each row is a list of indices into list_buttons.
    yn_button_select=False: all buttons in order, ncol per row (last row may be shorter).
    """
    if not yn_button_select:
        n_full_row, n_rest = divmod(len(list_buttons), ncol)
        n_row = n_full_row + int(n_rest > 0)
        rows = [list_buttons[i_row*ncol:(i_row+1)*ncol] for i_row in range(n_row)]
    else:
        rows = [[list_buttons[idx_button] for idx_button in row_idx]
                for row_idx in feature_celeba_organize.feature_celeba_layout]
    return ipywidgets.VBox([ipywidgets.HBox(row) for row in rows])
    
# Render the starting face, then assemble the GUI: image on the left, controls on the right.
guicallback.update_img()

layout_button_random = ipywidgets.Layout(height='40px',
                                         width='168px',
                                         margin='1px 1px 5px 1px')
w_button_random = ipywidgets.widgets.Button(description='random face', button_style='success',
                                            layout=layout_button_random)
w_button_random.on_click(guicallback.random_gen)

# NOTE(review): the module-level yn_button_select flag is not used here; the call
# hard-codes yn_button_select=False -- confirm which layout is intended.
w_controls = ipywidgets.VBox([w_button_random,
                              arrange_buttons(list_buttons, yn_button_select=False)])
w_box = ipywidgets.HBox([w_img, w_controls],
                        layout=ipywidgets.Layout(height='900px', width='700px'))

print('press +/- to adjust feature, toggle feature name to lock the feature')
display(w_box)


press +/- to adjust feature, toggle feature name to lock the feature


HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\x08\x02\x00\x…