In [1]:
import os
import glob
import sys
import numpy as np
import pickle
from sklearn.cluster import KMeans
import tensorflow as tf
import PIL
import ipywidgets
import io
import h5py
from keras.models import load_model
from keras.applications.mobilenet import preprocess_input

# --- Run from the project root ---
# When launched from notebooks/ or src/, step up one directory at a time until
# the current directory is no longer one of those.
_cwd_name = os.path.basename(os.getcwd())
while _cwd_name in ('notebooks', 'src'):
    os.chdir('..')
    _cwd_name = os.path.basename(os.getcwd())
# The project root is identified by its README.md.
_root_listing = os.listdir('./')
assert ('README.md' in _root_listing), 'Can not find project root, please cd to project root before running the following code'

import src.tl_gan.generate_image as generate_image
import src.tl_gan.feature_axis as feature_axis
import src.tl_gan.feature_celeba_organize as feature_celeba_organize

# --- Load feature directions ---
# Reads the feature_direction_*.pkl file (a dict with 'direction' and 'name')
# and applies the curated CelebA naming and per-feature sign flips.
path_feature_direction = './asset_results/stylegan_ffhq_feature_direction_retrained'

# glob() returns matches in arbitrary filesystem order, so sort before taking the
# last one. This makes the choice deterministic (presumably the file names sort
# chronologically, so this is the newest file -- confirm).
list_pathfile_feature_direction = sorted(
    glob.glob(os.path.join(path_feature_direction, 'feature_direction_*.pkl')))
if not list_pathfile_feature_direction:
    raise FileNotFoundError(
        'no feature_direction_*.pkl found in {}'.format(path_feature_direction))
pathfile_feature_direction = list_pathfile_feature_direction[-1]

# NOTE: pickle.load can execute arbitrary code; only load files produced by this project.
with open(pathfile_feature_direction, 'rb') as f:
    feature_direction_name = pickle.load(f)

num_feature = feature_direction_name['direction'].shape[1]

import importlib
# Pick up edits to feature_celeba_organize without restarting the kernel.
importlib.reload(feature_celeba_organize)
# Replace the raw pickle names/directions with the curated ones: renamed features,
# and directions multiplied column-wise by the per-feature reverse factors.
feature_name = feature_celeba_organize.feature_name_celeba_rename
feature_direction = feature_direction_name['direction'] * feature_celeba_organize.feature_reverse[None, :]
# --- Start tf session and load GAN model ---

# Path to model code and weights. The model code directory is put on sys.path,
# presumably so that pickle.load can import the classes referenced by the snapshot.
path_pg_gan_code = './src/model/pggan'
path_model = './network-snapshot-013246.pkl'
if path_pg_gan_code not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(path_pg_gan_code)


# --- Create tf session ---
yn_CPU_only = False  # set True to hide all GPUs from TensorFlow

if yn_CPU_only:
    config = tf.ConfigProto(device_count = {'GPU': 0}, allow_soft_placement=True)
else:
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand

sess = tf.InteractiveSession(config=config)
tf.global_variables_initializer().run()
try:
    with open(path_model, 'rb') as file:
        # Snapshot holds (G, D, Gs); only Gs is used later in this notebook.
        G, D, Gs = pickle.load(file)
except FileNotFoundError:
    # The old hint pointed to asset_model/, but path_model is read from the
    # project root; report the path that is actually expected.
    print('before running the code, download pre-trained model to {}'.format(os.path.abspath(path_model)))
    raise

def img_to_bytes(x_sample):
    """Encode an image array as PNG bytes.

    Args:
        x_sample: array accepted by ``PIL.Image.fromarray`` (presumably an HxWx3
            uint8 image from ``generate_image.gen_single_img`` -- confirm).

    Returns:
        bytes: the PNG-encoded image, as used for ``ipywidgets.Image(value=...)``.
    """
    # `import PIL` alone does not load the PIL.Image submodule; import it
    # explicitly rather than relying on another library having loaded it first.
    from PIL import Image
    with io.BytesIO() as img_byte_arr:
        Image.fromarray(x_sample).save(img_byte_arr, format='PNG')
        return img_byte_arr.getvalue()

# Keras attribute classifier used to score generated faces. Later cells read
# y[0][k] from model.predict(...) as the score of feature k (named by feature_name[k]).
model = load_model('model_20180927_032934.h5') #feature extractor

Using TensorFlow backend.













Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where









In [2]:
# Sample random latents, keep those the classifier scores confidently with (+1) or
# without (-1) each feature, and store k-means cluster centers of the kept latents
# in `centers` for the GUI below.
#
# Feature indices of interest:
#    8: Black Hair    9: Blond Hair   15: Glasses   16: Goatee
#   20: Male         21: Mouth Open   24: Beard     31: Smiling   39: Age
# Beard, Glasses not possible (per the original note).
#features = [8,9,20,21,31,39]
features = [20]
pm = [-1, 1]  # -1: without attribute -> slot 0, +1: with attribute -> slot 1

n_samples = 5000       # random latents drawn per (feature, sign)
score_threshold = 0.8  # classifier scale is from -1 to 1; keep confident samples only
len_z = Gs.input_shapes[0][1]

# centers[feature, cluster, without/with (0/1), latent_dim]. np.zeros instead of
# np.empty so that slots never filled below hold defined values, not garbage memory.
centers = np.zeros((40, 30, 2, len_z))

for a in features:
    # Feature 20 gets 30 clusters, all others 10 (GuiCallback samples the same ranges).
    n_clusters = 30 if a == 20 else 10
    for sign in pm:
        print(feature_name[a])
        print("direction:" + str(sign))
        kept_latents = []  # list + one np.stack, instead of quadratic np.append
        for i in range(n_samples):
            if i % 100 == 99:
                print(i + 1)
            z_sample = np.random.randn(len_z)
            x_sample = generate_image.gen_single_img(z=z_sample, Gs=Gs)
            x = x_sample[None, :]    # batch with 1 image
            x = x[:, 1::2, 1::2, :]  # downsample by 2 (to 128x128)
            x = preprocess_input(x)
            y = model.predict(x)
            if y[0][a] * sign >= score_threshold:
                kept_latents.append(z_sample)
        if len(kept_latents) < n_clusters:
            raise ValueError(
                '{} (direction {}): only {} confident samples, need at least {} for KMeans'
                .format(feature_name[a], sign, len(kept_latents), n_clusters))
        kmeans = KMeans(n_clusters=n_clusters).fit(np.stack(kept_latents))
        centers[a, :n_clusters, (sign + 1) // 2] = kmeans.cluster_centers_

Male
direction:-1
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
Male
direction:1
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000


'\n8: Black Hair\n9: Blond Hair\n15: Glasses\n16: Goatee\n20: Male\n21: Mouth Open\n24: Beard\n31: Smiling\n39: Age\n'

In [5]:
## Generate Random Face
# GuiCallback starts from z_sample, so render the initial image from the same
# latent. Previously gen_single_img drew its own, unrelated z.
z_sample = np.random.randn(len_z)
x_sample = generate_image.gen_single_img(z=z_sample, Gs=Gs)
w_img = ipywidgets.widgets.Image(value=img_to_bytes(x_sample), format='png', 
                                 width=256, height=256,
                                 layout=ipywidgets.Layout(height='256px', width='256px')
                                )

class GuiCallback(object):
    """Button callbacks that regenerate the face shown in ``w_img``.

    Relies on module-level names from earlier cells: ``z_sample``,
    ``feature_direction``, ``centers``, ``len_z``, ``Gs``, ``model``,
    ``feature_name``, ``w_img``, ``img_to_bytes`` and ``preprocess_input``.
    """
    counter = 0

    def __init__(self):
        self.latents = z_sample  # latent vector of the face currently shown
        self.feature_direction = feature_direction

    def random_gen_feature(self, event, idx_feature, direction):
        """Show a face near a random k-means center with (1) or without (0) a feature.

        Args:
            event: ipywidgets click event (unused).
            idx_feature: feature index into ``centers``.
            direction: 1 for "with the feature", 0 for "without".
        """
        # Feature 20 was clustered into 30 centers, all others into 10.
        # np.random.randint excludes `high`, so this draws 0..29 / 0..9. The old
        # code called random.randint, but `random` is never imported in this
        # notebook, so it raised NameError on a fresh kernel.
        n_clusters = 30 if idx_feature == 20 else 10
        rnd = np.random.randint(0, n_clusters)
        if idx_feature == 39:
            # Feature 39 is treated as sign-reversed (its score is negated in
            # update_img too), so swap yes/no.
            direction = 1 - direction
        # Jitter around the chosen center so repeated clicks give different faces.
        self.latents = centers[idx_feature][rnd][direction] + 0.2 * np.random.randn(len_z)
        self.update_img(idx_feature)

    def random_gen(self, event):
        """Show a face from a fresh standard-normal latent."""
        self.latents = np.random.randn(len_z)
        self.update_img()

    def update_img(self, idx_feature=None):
        """Render ``self.latents`` into ``w_img``.

        If ``idx_feature`` is given, also print the classifier's score for that feature.
        """
        x_sample = generate_image.gen_single_img(z=self.latents, Gs=Gs)
        w_img.value = img_to_bytes(x_sample)
        if idx_feature is None:
            return  # no score requested: skip the classifier pass

        x = x_sample[None, :]    # batch with 1 image
        x = x[:, 1::2, 1::2, :]  # downsample by 2 (to 128x128)
        x = preprocess_input(x)
        score = model.predict(x)[0][idx_feature]
        if idx_feature in {24, 36, 39}:
            # Reverse direction of feature (e.g. No Beard --> Beard)
            score = -score
        print('{}: {:.2f}'.format(feature_name[idx_feature], score))
   

# Single callback object shared by every button created below; the latent of the
# face currently displayed lives in guicallback.latents.
guicallback = GuiCallback()

def create_button(idx_feature, width=96, height=40):
    """Build the bordered button group for one feature.

    The group is a name toggle stacked above a yes/no pair. "yes" calls
    ``guicallback.random_gen_feature`` with direction +1 and "no" with 0. The
    toggle is not wired to any callback here.

    Args:
        idx_feature: feature index, used for the label and the callbacks.
        width: group width in px; each yes/no button gets half of it.
        height: reference height in px; each row is half of it.

    Returns:
        ipywidgets.VBox holding the group.
    """
    label = feature_name[idx_feature]
    row_height = '{:.0f}px'.format(height/2)
    half_width = '{:.0f}px'.format(width/2)

    toggle_layout = ipywidgets.Layout(height=row_height,
                                      width='{:.0f}px'.format(width),
                                      margin='2px 2px 2px 2px')
    w_name_toggle = ipywidgets.widgets.ToggleButton(
        value=False, description=label,
        tooltip='{}, Press down to lock this feature'.format(label),
        layout=toggle_layout)

    def make_choice_button(text):
        # yes/no buttons share the same half-width layout.
        choice_layout = ipywidgets.Layout(height=row_height,
                                          width=half_width,
                                          margin='1px 1px 5px 1px')
        return ipywidgets.widgets.Button(description=text, layout=choice_layout)

    w_no = make_choice_button('no')
    w_yes = make_choice_button('yes')

    w_no.on_click(lambda event: guicallback.random_gen_feature(event, idx_feature, 0))
    w_yes.on_click(lambda event: guicallback.random_gen_feature(event, idx_feature, +1))

    choice_row = ipywidgets.HBox([w_yes, w_no])
    return ipywidgets.VBox([w_name_toggle, ipywidgets.VBox([choice_row])],
                           layout=ipywidgets.Layout(border='1px solid gray'))
  

# One button group per feature, in feature-index order (arrange_buttons indexes into this list).
list_buttons = [create_button(feature_idx) for feature_idx in range(num_feature)]

yn_button_select = True #Filtered Attributes only (not read anywhere in this notebook)
def arrange_buttons(list_buttons, ncol=4):
    """Lay out feature button groups in rows given by ``feature_celeba_layout``.

    Each row of ``feature_celeba_organize.feature_celeba_layout`` lists indices into
    ``list_buttons``. ``ncol`` is accepted but unused; the layout table alone
    decides the rows.

    Returns:
        ipywidgets.VBox of one ipywidgets.HBox per layout row.
    """
    rows = []
    for row_indices in feature_celeba_organize.feature_celeba_layout:
        rows.append(ipywidgets.HBox([list_buttons[idx] for idx in row_indices]))
    return ipywidgets.VBox(rows)
    
    
# Render the initial face from guicallback.latents before assembling the UI.
guicallback.update_img()

random_button_layout = ipywidgets.Layout(height='40px',
                                         width='168px',
                                         margin='1px 1px 5px 1px')
w_button_random = ipywidgets.widgets.Button(description='random face',
                                            button_style='success',
                                            layout=random_button_layout)
w_button_random.on_click(guicallback.random_gen)

# Left: the face image. Right: the random-face button above the feature button grid.
control_column = ipywidgets.VBox([w_button_random, arrange_buttons(list_buttons)])
w_box = ipywidgets.HBox([w_img, control_column],
                        layout=ipywidgets.Layout(height='300px', width='700px'))

print('Select yes/no to generate faces with/without the feature')
print('Note: model not yet trained for "eyeglasses" and "beard"')
display(w_box)


Select yes/no to generate faces with/without the feature
Note: model not yet trained for "eyeglasses" and "beard"


HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\x08\x02\x00\x…