In [1]:
import projector
import numpy as np
import dnnlib
from dnnlib import tflib
import pickle
import tensorflow as tf
import PIL
import os
import tqdm

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# URL of the pretrained StyleGAN2-ADA FFHQ network pickle used throughout this notebook.
network_pkl = (
    "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/"
    "ffhq.pkl"
)

In [3]:
def _load_target_image(target_fname: str, width: int, height: int):
    """Load an image, center-crop it to a square, and resize it for the generator.

    Args:
        target_fname: Path of the image file to load.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        A ``(target_uint8, target_float)`` tuple: the RGB image as a
        ``(height, width, 3)`` uint8 array, and the same pixels as a
        ``(3, height, width)`` float32 array scaled from [0, 255] to [-1, 1].
    """
    # `import PIL` alone does not load the Image submodule; import it explicitly.
    import PIL.Image

    # The `with` closes the file handle that PIL.Image.open keeps open for lazy loading.
    with PIL.Image.open(target_fname) as src:
        w, h = src.size
        s = min(w, h)
        # Largest centered square crop.
        cropped = src.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
        target_pil = cropped.convert('RGB')
    # ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
    target_pil = target_pil.resize((width, height), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)
    target_float = target_uint8.astype(np.float32).transpose([2, 0, 1]) * (2 / 255) - 1
    return target_uint8, target_float


def project(network_pkl: str, target_fname: str, outdir: str, save_video: bool, seed: int):
    """Project a target image into the latent space of a pretrained generator.

    Loads the network pickle at `network_pkl` and center-crops and resizes the
    image at `target_fname` to the generator's output resolution. It then runs
    `projector.Projector` for `proj.num_steps` steps and writes the resulting
    dlatents to ``<outdir>/dlatents.npz`` (array key ``dlatents``).

    Args:
        network_pkl: URL or path of a network pickle containing a ``(G, D, Gs)``
            tuple. Unpickling executes code from the file; use trusted sources only.
        target_fname: Path of the target image file.
        outdir: Output directory; created if it does not exist.
        save_video: If True, also write ``<outdir>/proj.mp4`` with the target and
            the current projection side by side at every step (needs `imageio`).
        seed: Passed to `tflib.init_tf` as ``rnd.np_random_seed``.

    Returns:
        ``proj.dlatents`` after the final step. For the FFHQ network in this
        notebook, the recorded shape was (1, 18, 512).
    """
    # Load networks.
    tflib.init_tf({'rnd.np_random_seed': seed})
    print('Loading networks from "%s"...' % network_pkl)
    # NOTE: pickle.load executes code embedded in the file -- only load trusted pickles.
    with dnnlib.util.open_url(network_pkl) as fp:
        _G, _D, Gs = pickle.load(fp)

    # Load target image at the generator's output size (resize takes (width, height)).
    target_uint8, target_float = _load_target_image(
        target_fname, Gs.output_shape[3], Gs.output_shape[2])

    # Initialize projector.
    proj = projector.Projector()
    proj.set_network(Gs)
    proj.start([target_float])

    # Setup output directory.
    os.makedirs(outdir, exist_ok=True)
    writer = None
    if save_video:
        # imageio was previously used without ever being imported. Import it lazily
        # so it is only required when a video is actually requested.
        import imageio
        writer = imageio.get_writer(os.path.join(outdir, 'proj.mp4'), mode='I', fps=60,
                                    codec='libx264', bitrate='16M')

    # Run projector. The `finally` makes sure the video writer is closed, even if a step fails.
    try:
        with tqdm.trange(proj.num_steps) as t:
            for step in t:
                assert step == proj.cur_step
                if writer is not None:
                    # Each frame shows the target and the current projection side by side.
                    writer.append_data(np.concatenate([target_uint8, proj.images_uint8[0]], axis=1))
                dist, loss = proj.step()
                t.set_postfix(dist=f'{dist[0]:.4f}', loss=f'{loss:.2f}')
    finally:
        if writer is not None:
            writer.close()

    # Save results under `outdir` (this path was previously hard-coded to 'out/').
    np.savez(os.path.join(outdir, 'dlatents.npz'), dlatents=proj.dlatents)
    return proj.dlatents
        


In [4]:
# Project a single target image (no video) with a fixed seed; `project` also writes
# the dlatents to disk.
dlatent = project(
    network_pkl,
    target_fname="images/002_02.png",
    outdir="out/",
    save_video=False,
    seed=303,
)

Loading networks from "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl"...
Setting up TensorFlow plugin "fused_bias_act.cu": Loading... Done.
Setting up TensorFlow plugin "upfirdn_2d.cu": Loading... Done.
Projector: Computing W midpoint and stddev using 10000 samples...
Projector: std = 10.0058
Projector: Setting up noise inputs...
Projector: Building image output graph...
Projector: Building loss graph...
Projector: Building noise regularization graph...
Projector: Setting up optimizer...
Projector: Preparing target images...
Projector: Initializing optimization state...


100%|███████████████████████████████████████████████████████| 10/10 [00:45<00:00,  4.54s/it, dist=0.5358, loss=2360.51]


In [5]:
# Shape of the projected dlatents (the recorded run displayed (1, 18, 512)).
np.shape(dlatent)

(1, 18, 512)

In [5]:
# `project` keeps its generator local, so initialize TF (same seed as the projection run)
# and load the network again at notebook scope.
tflib.init_tf({'rnd.np_random_seed': 303})
with dnnlib.util.open_url(network_pkl) as pkl_file:
    # NOTE: unpickling executes code from the file -- trusted sources only.
    _G, _D, Gs = pickle.load(pkl_file)

In [13]:
# Read back the dlatents saved by `project`. np.load returns an NpzFile that holds an
# open file handle, and the `with` block closes it. Indexing by key reads the array
# fully into memory, so `dlat` stays valid after the file is closed.
with np.load('out/dlatents.npz') as data:
    # `project` saves the array with np.savez(..., dlatents=...), so read it by that key.
    dlat = data['dlatents']

In [26]:
# Synthesize an image from the loaded dlatents. Calling get_output_for directly would add
# new ops to the default TF graph on every re-run of this cell. Network.run caches its
# graph and applies the uint8/NHWC conversion through `output_transform`.
images_uint8 = Gs.components.synthesis.run(
    dlat, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
# First (and only) image of the batch.
img = PIL.Image.fromarray(images_uint8[0], 'RGB')

In [27]:
# Render the synthesized image inline through PIL's rich display. Image.show() can hand
# off to an external viewer, which does not appear in the notebook.
img