In [None]:
import os

# TF_CPP_MIN_LOG_LEVEL only silences TensorFlow's native logging if it is in the
# environment before TensorFlow is first imported (directly below, and presumably
# also transitively through the project's model/ and data_loader/ packages).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from model.generator import G
from model.id_encoder import IDEncoder
from model.attr_encoder import AttrEncoder
from model.stylegan import StyleGAN_G
from model.model import Network
from data_loader.data_loader import DataLoader
import utils
import tensorflow as tf
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

# Paths to the pretrained networks loaded by the setup cell.
id_model_path = "./pretrained/vggface2.h5"
stylegan_G_synthesis_path = "./pretrained/stylegan_G_256x256_synthesis/stylegan_G_256x256.h5"
landmarks_model_path = "./pretrained/face_utils/keypoints"
face_detection_model_path = "./pretrained/face_utils/detector"
arcface_model_path = "./pretrained/arcface_weights/weights-b"

class Args:
    """Hard-coded stand-in for command-line options, passed to the model and data loader."""

    def __init__(self):
        # Insertion order matches the attribute order callers may see via vars(args).
        settings = {
            "resolution": 256,
            "load_checkpoint": False,
            "train": True,
            "dataset_path": Path("./dataset"),
            "train_data_size": 50000,
            "batch_size": 6,
            "reals": False,
            "test_real_attr": True,
            "train_real_attr": False,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)


args = Args()
g_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)

# Pretrained StyleGAN generator. The resolution comes from `args` so the model and the
# data/loss code share one source of truth; the weights file is the 256x256 checkpoint,
# so args.resolution has to stay 256 for load_weights to match.
stylegan_G = StyleGAN_G(resolution=args.resolution, truncation_psi=0.7)
# NOTE(review): `built` is forced to True, presumably so load_weights(by_name=True) can
# run without a prior forward pass — confirm against model/stylegan.py.
stylegan_G.built = True
stylegan_G.load_weights(stylegan_G_synthesis_path, by_name=True)

# NOTE(review): cells further down call `generator.id_encoder`, `generator.landmarks`,
# `generator.attr_encoder`, `generator.reference_mapping` and `generator.stylegan_s`,
# but `generator` is only constructed on the commented-out line below, so those cells
# raise NameError on a fresh kernel. Re-enable it (or bind `generator` to the equivalent
# object held by `embedding_network`, if there is one) before running them.
# generator = G(args, id_model_path, stylegan_G, landmarks_model_path, face_detection_model_path, arcface_model_path)

# Disabled manual sampling experiment:
# z = tf.random.normal((6, 512))
# sp = tf.zeros((6, 9984))
# w = stylegan_G.model_mapping(z)
# images, style_list = stylegan_G.model_synthesis([w, sp])
# images = (images + 1) / 2

# pixel_loss_func = tf.keras.losses.MeanAbsoluteError(tf.keras.losses.Reduction.SUM)

embedding_network = Network(args=args, id_net_path=id_model_path, base_generator=stylegan_G, phase="embedding",
                            landmarks_net_path=landmarks_model_path,
                            face_detection_model_path=face_detection_model_path,
                            test_id_net_path=arcface_model_path)

In [None]:
from mpl_toolkits import mplot3d  # noqa: F401 -- needed on older matplotlib for projection='3d'
import collections

import face_alignment
import matplotlib.pyplot as plt
from skimage import io

# Single sample image used to eyeball the 3-D landmark detector.
SAMPLE_IMAGE_PATH = "./dataset/dataset_256/image/00000/00000.png"

faN = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False, device="cpu")

# NOTE: `input` shadows the builtin, but the plotting cell reads the image under this name.
input = io.imread(SAMPLE_IMAGE_PATH)

# get_landmarks returns None (not an empty list) when no face is detected; fail with a
# clear message instead of "'NoneType' object is not subscriptable".
detections = faN.get_landmarks(input)
if not detections:
    raise ValueError(f"face_alignment found no face in {SAMPLE_IMAGE_PATH}")
# Keep the last detected face's landmarks.
preds = detections[-1]

In [None]:
# (x, y, z) coordinates of landmark 1.
preds[1, :]

In [None]:
# Per-axis maxima of the landmark coordinates (x, y, z). ndarray.max is vectorized and
# propagates NaN, whereas builtin max() gives a result that depends on where a NaN sits.
tuple(preds[:, :3].max(axis=0))

In [None]:
# Per-axis minima of the landmark coordinates (x, y, z). ndarray.min is vectorized and
# propagates NaN, whereas builtin min() gives a result that depends on where a NaN sits.
tuple(preds[:, :3].min(axis=0))

In [None]:
# (x, y, z) of landmark 30 — the last point of the 'nose' range (27..30) used in the plot cell.
preds[30, :]

In [None]:
plot_style = dict(marker='o', markersize=2, linestyle='-', lw=2)

# Landmark index ranges of the 68-point layout, each with an RGBA colour.
# The class is named PredType so the loops below don't shadow it.
PredType = collections.namedtuple('prediction_type', ['slice', 'color'])
pred_types = {'face': PredType(slice(0, 17), (0.682, 0.780, 0.909, 0.5)),
              'eyebrow1': PredType(slice(17, 22), (1.0, 0.498, 0.055, 0.4)),
              'eyebrow2': PredType(slice(22, 27), (1.0, 0.498, 0.055, 0.4)),
              'nose': PredType(slice(27, 31), (0.345, 0.239, 0.443, 0.4)),
              'nostril': PredType(slice(31, 36), (0.345, 0.239, 0.443, 0.4)),
              'eye1': PredType(slice(36, 42), (0.596, 0.875, 0.541, 0.3)),
              'eye2': PredType(slice(42, 48), (0.596, 0.875, 0.541, 0.3)),
              'lips': PredType(slice(48, 60), (0.596, 0.875, 0.541, 0.3)),
              'teeth': PredType(slice(60, 68), (0.596, 0.875, 0.541, 0.4))
              }

fig = plt.figure(figsize=plt.figaspect(.5))

# Left panel: landmark (x, y) polylines drawn over the input image.
ax = fig.add_subplot(1, 2, 1)
ax.imshow(input)
for region in pred_types.values():
    ax.plot(preds[region.slice, 0],
            preds[region.slice, 1],
            color=region.color, **plot_style)
ax.axis('off')

# Right panel: 3-D point cloud with permuted axes:
# plot-x = landmark z, plot-y = landmark x, plot-z = -1.2 * landmark y (flipped, stretched).
ax = fig.add_subplot(1, 2, 2, projection='3d')
ax.scatter3D(preds[:, 2],
             preds[:, 0],
             -preds[:, 1] * 1.2,
             c='cyan',
             alpha=1.0,
             edgecolor='b')
for region in pred_types.values():
    ax.plot3D(preds[region.slice, 2],
              preds[region.slice, 0],
              -preds[region.slice, 1] * 1.2, color='blue')

# Invert plot-y (landmark x) so left/right matches the 2-D panel orientation choice.
ax.set_ylim(ax.get_ylim()[::-1])

# Label each plot axis with the landmark coordinate it actually shows
# (previously labelled 'x'/'y'/'z', which did not match the permuted data).
ax.set_xlabel('landmark z')
ax.set_ylabel('landmark x')
ax.set_zlabel('-1.2 * landmark y')

plt.show()

In [None]:
import os

# Only affects TensorFlow's native logging if set before TensorFlow is first imported;
# in a kernel where TensorFlow is already loaded this may have no effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

# Seeded so the 3x3 orthogonal matrix, and the checks in the following cells, are reproducible.
initializer = tf.keras.initializers.Orthogonal(seed=0)
values = initializer(shape=(3, 3))


In [None]:
# Wrap the orthogonal matrix in a tf.Variable and display it.
trainable = tf.Variable(initial_value=values)
trainable

In [None]:
# Gradient sanity check: for x1 = 2 * v, d(sum x1)/dv is 2 in every entry.
with tf.GradientTape(persistent=True) as g_tape:
    trainable = tf.Variable(initial_value=values)
    x1 = 2. * trainable
g_tape.gradient(x1, trainable)

In [None]:
# Rows 0 and 1 of a Variable built from the orthogonal matrix (built once instead of twice).
values_var = tf.Variable(values)
values_var[0], values_var[1]

In [None]:
# Dot product of rows 0 and 1; ~0 for an orthogonal matrix.
# (tf.transpose of a 1-D tensor is the identity, so it is omitted.)
tf.tensordot(values[0], values[1], axes=1)

In [None]:
# Reload the debugging snapshot written during experiment 11.
EXP11_WEIGHTS_DIR = "./output/exp_11/weights"
N_SNAPSHOT_IMAGES = 6  # images saved per kind; presumably the batch size at save time


def load_image_batch(prefix, count=N_SNAPSHOT_IMAGES, directory=EXP11_WEIGHTS_DIR):
    """Read ``{directory}/{prefix}{i}.png`` for i in range(count) and stack them on a new batch axis."""
    return tf.stack([plt.imread(f"{directory}/{prefix}{i}.png") for i in range(count)], axis=0)


id_img = load_image_batch("step_id_img")
attr_img = load_image_batch("step_attr_img")

# np.load returns one array; tf.concat on a single array was only an identity conversion.
id_z_matching = tf.convert_to_tensor(np.load(f"{EXP11_WEIGHTS_DIR}/id_z_matching.npy"))

In [None]:
# Second image (index 1) of the identity batch.
plt.imshow(id_img[1, ...])

In [None]:
import pickle


def _unpickle(path):
    """Load and return the single object pickled at `path`.

    NOTE(review): pickle.load can execute arbitrary code; only point this at files the
    project wrote itself (here the exp_11 training snapshot).
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


step_noise = _unpickle("./output/exp_11/weights/step_noise")
step_grads = _unpickle("./output/exp_11/weights/step_grads")

In [None]:
# First weight array of the synthesis layer 'G_synthesis/noise0'.
noise0_layer = stylegan_G.model_synthesis.get_layer('G_synthesis/noise0')
noise0_layer.get_weights()[0]

In [None]:
# Copy each saved noise tensor into the matching synthesis layer G_synthesis/noise{idx},
# printing its shape as it goes.
for layer_idx in range(len(step_noise)):
    saved_noise = step_noise[layer_idx]
    print(layer_idx, saved_noise.numpy().shape)
    target_layer = stylegan_G.model_synthesis.get_layer(f'G_synthesis/noise{layer_idx}')
    target_layer.set_weights([saved_noise.numpy()])

In [None]:
# Pull one training batch (is_cross=True) from the dataset configured in `args`.
dataloader = DataLoader(args)
batch = dataloader.get_batch(is_train=True, is_cross=True)
id_img, id_z_matching, attr_img, attr_img_indices = batch

In [None]:
# Aliased as general_utils so the top-level `utils` package imported in the setup cell
# is not shadowed.
from utils import general_utils

# Per-pixel weight mask for the L1 pixel loss. sigma is 80 at 256 px and scales
# linearly with the configured resolution.
sigma = int(80 * (args.resolution / 256))
pixel_mask = general_utils.inverse_gaussian_image(args.resolution, sigma)
# Repeat the mask once per image in the batch (assumes batches of args.batch_size).
pixel_mask = tf.broadcast_to(pixel_mask, [args.batch_size, *pixel_mask.shape])

In [None]:
# Identity branch: embedding of the identity images.
id_embedding = generator.id_encoder(id_img)
# Attribute branch: landmarks first, then the attribute embedding, of the attribute images.
src_landmarks = generator.landmarks(attr_img)
attr_embedding = generator.attr_encoder(attr_img)

In [None]:
# Shape of the identity embedding (layout depends on generator.id_encoder, not visible here).
id_embedding.shape

In [None]:
# Shape of the attribute embedding (layout depends on generator.attr_encoder, not visible here).
attr_embedding.shape

In [None]:
# Recompute one generator forward pass and its losses on the current batch.
# Uses `generator`, `pixel_mask`, `id_img`, `attr_img` and `id_z_matching` from earlier cells.

# Loss weights / mixing coefficients (same values as before, now named).
LANDMARKS_LOSS_WEIGHT = 0.01
PIXEL_LOSS_WEIGHT = 0.02
MSSIM_WEIGHT = 0.84  # MS-SSIM share of the pixel loss
L1_WEIGHT = 0.16     # masked-L1 share of the pixel loss

# Sum-reduced mean absolute error, weighted per pixel via sample_weight=pixel_mask below.
# Defined here because the setup cell only has it commented out.
pixel_loss_func = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)

id_embedding = generator.id_encoder(id_img)
src_landmarks = generator.landmarks(attr_img)

attr_embedding = generator.attr_encoder(attr_img)

# Concatenate identity and attribute codes; the [:, 0, :] slice of the mapped
# latents conditions the StyleGAN synthesis.
z_tag = tf.concat([id_embedding, attr_embedding], -1)
clatents = generator.reference_mapping(z_tag)

gen_img, style_list, _ = generator.stylegan_s(id_z_matching, clatents[:, 0, :])

# Move to roughly [0,1]
gen_img = (gen_img + 1) / 2
gen_img = tf.clip_by_value(gen_img, 0, 1)

# Identity loss: the generated image should embed like the identity image.
gen_img_id_embedding = generator.id_encoder(gen_img)
id_loss = tf.reduce_mean(tf.keras.losses.MAE(gen_img_id_embedding, tf.stop_gradient(id_embedding)))

# Landmark loss. Best-effort: if landmarks cannot be computed on the generated images
# the term is skipped, but the reason is reported instead of being silently swallowed.
try:
    dst_landmarks = generator.landmarks(gen_img)
except Exception as e:
    print(f"generator.landmarks(gen_img) failed; landmark loss set to 0: {e!r}")
    dst_landmarks = None

if dst_landmarks is None or src_landmarks is None:
    landmarks_loss = 0
else:
    landmarks_loss = LANDMARKS_LOSS_WEIGHT * tf.reduce_mean(tf.keras.losses.MSE(src_landmarks, dst_landmarks))

# Pixel loss: MS-SSIM dissimilarity mixed with mask-weighted L1 against the attribute image.
l1_loss = pixel_loss_func(attr_img, gen_img, sample_weight=pixel_mask)
mssim = tf.reduce_mean(1 - tf.image.ssim_multiscale(attr_img, gen_img, 1.0))
pixel_loss = PIXEL_LOSS_WEIGHT * (MSSIM_WEIGHT * mssim + L1_WEIGHT * l1_loss)

# Total loss
total_loss = id_loss + landmarks_loss + pixel_loss

In [None]:
# True if any generated-image landmark is NaN; None when landmark detection was
# skipped (dst_landmarks is None), instead of crashing inside np.isnan.
None if dst_landmarks is None else np.isnan(dst_landmarks).any()

In [None]:
# Second image (index 1) of the identity batch.
plt.imshow(id_img[1, ...])

In [None]:
# Second image (index 1) of the attribute batch.
plt.imshow(attr_img[1, ...])

In [None]:
# Second generated image (index 1), already clipped to [0, 1] by the loss cell.
plt.imshow(gen_img[1, ...])

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def show_landmarks(image, landmarks, bbox=None, retuire_bbox=False):
    """Show `image` with `landmarks` scattered on top, optionally outlining one box.

    Args:
        image: anything ``Axes.imshow`` accepts.
        landmarks: array indexable as ``landmarks[:, 0]`` (x) and ``landmarks[:, 1]`` (y).
        bbox: sequence of boxes; only ``bbox[0]`` is drawn, read as ``(x0, y0, x1, y1)``:
            the rectangle starts at ``(x0, y0)`` with width ``x1 - x0`` and height ``y1 - y0``.
        retuire_bbox: (sic, "require") when True, draw ``bbox[0]``.

    Raises:
        ValueError: if ``retuire_bbox`` is True but ``bbox`` is None or empty.
    """
    # Validate before creating a figure so a bad call doesn't leave an empty plot behind.
    if retuire_bbox and (bbox is None or len(bbox) == 0):
        raise ValueError("show_landmarks: retuire_bbox=True requires a non-empty `bbox`")

    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    if retuire_bbox:
        box = bbox[0]
        x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
        rect = patches.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                 linewidth=1, edgecolor='g', facecolor='none')
        ax.add_patch(rect)
    # Short pause lets interactive backends draw the figure immediately.
    plt.pause(0.001)