<a href="https://colab.research.google.com/github/RubeRad/StyleGAN2-TensorFlow-2.x/blob/master/StyleGAN2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# import the normal stuff
import tensorflow        as tf
import numpy             as np
import matplotlib.pyplot as plt
import os

In [None]:
# DO THIS IF YOU ARE IN COLAB
# Clone the repository with the TensorFlow2 update to StyleGAN2
# Thanks Alberto Rosas!!
%cd /content
!git clone https://github.com/ruberad/StyleGAN2-TensorFlow-2.x.git stylegan2
%cd /content/stylegan2

In [None]:
# import StyleGAN2-specific stuff
from utils.utils_stylegan2 import convert_images_to_uint8
from stylegan2_generator   import StyleGan2Generator

In [None]:
# Check if you have access to a GPU in this (virtual) machine
!nvidia-smi -L
gpuname = tf.test.gpu_device_name()
print('GPU Identified at: "{}"'.format(gpuname))

if gpuname:
  impl = 'cuda'
  gpu = True
  print('yay fast!')
else:
  impl = 'ref'
  gpu = False
  print('aww, slow.')

In [None]:
import gdown
# This is the public URL of the online file with weights that StyleGan2 needs for generating faces
url = 'https://drive.google.com/uc?export=download&confirm=pbef&id=1afMN3e_6UuTTPDL63WHaA0Fb9EQrZceE'
# This is where we want the file to go (path on the virtual machine). A dedicated name keeps it
# from being confused with `out`, which a later cell uses for the synthesis network's output.
weights_path = 'weights/ffhq.npy'

if os.path.exists(weights_path):
  print('ffhq.npy weights file is present')
else:
  # Make sure the destination folder exists before downloading into it
  os.makedirs(os.path.dirname(weights_path), exist_ok=True)
  result = gdown.download(url, weights_path, quiet=False)
  # Some gdown versions print an error and return None instead of raising (e.g. Drive quota
  # exceeded); stop here instead of failing later with a confusing missing-weights error.
  if result is None or not os.path.exists(weights_path):
    raise RuntimeError('Could not download the FFHQ weights to ' + weights_path)

In [None]:
# Build the StyleGAN2 generator with the 'ffhq' weights (face generation, high quality) --
# presumably the weights/ffhq.npy file fetched by the download cell -- using the
# kernel implementation (impl) and device flag (gpu) chosen by the GPU-check cell.
weights_name = 'ffhq'
sg2 = StyleGan2Generator(
    weights=weights_name,
    impl=impl,
    gpu=gpu,
)

In [None]:
# Draw one reproducible latent vector z (standard-normal, shape (1, 512), float32).
# `rng` is kept as a global; its state after this draw matters to later cells.
seed = 1
rng = np.random.RandomState(seed)
z = rng.standard_normal((1, 512)).astype(np.float32)
z

In [None]:
# Feed z through the generator's mapping network; the result w is the dlatent
# that the next cell truncates toward sg2.dlatent_avg.
w = sg2.mapping_network(z)

w

In [None]:
# Truncation trick: move w part of the way toward the generator's average dlatent.
# trunc = 1 keeps w (up to rounding); trunc = 0 would put every sample on dlatent_avg.
trunc = 0.5
w_avg = sg2.dlatent_avg
w_trunc = w_avg + (w - w_avg) * trunc

# Render the truncated dlatent with the synthesis network, then convert the result
# to uint8 pixels with the repo helper so it can be shown with imshow.
out = sg2.synthesis_network(w_trunc)
img = convert_images_to_uint8(out)

In [None]:
# Plot the single truncated sample generated above, labelled with the settings that produced it
fig, ax = plt.subplots()
ax.axis('off')
ax.set_title('seed {}, truncation {}'.format(seed, trunc))
img_plot = ax.imshow(img)

In [None]:
def _image_from_seed(generator, seed, trunc=0.5, z_dim=512):
    """Return the uint8 image `generator` produces for `seed` at truncation `trunc`."""
    # A per-seed RandomState makes the same seed always give the same latent z
    z = np.random.RandomState(seed).randn(1, z_dim).astype('float32')
    w = generator.mapping_network(z)
    w_avg = generator.dlatent_avg
    # Truncation trick: pull w toward the average dlatent by the factor `trunc`
    w_trunc = w_avg + (w - w_avg) * trunc
    return convert_images_to_uint8(generator.synthesis_network(w_trunc))


def generate_and_plot_from_seeds(generator, seeds, trunc=0.5, nrows=3, ncols=3,
                                 z_dim=512, figsize=(15, 15)):
    """Generate one image per seed and show them in an nrows x ncols grid.

    Args:
        generator: object exposing ``mapping_network``, ``synthesis_network``
            and ``dlatent_avg`` (e.g. the notebook's ``StyleGan2Generator``).
        seeds: iterable of integer seeds; only the first ``nrows * ncols``
            are used, any extra seeds are ignored.
        trunc: truncation factor; 1.0 keeps w (up to rounding), smaller
            values pull w toward ``generator.dlatent_avg``.
        nrows, ncols: grid shape (defaults reproduce the original 3x3 grid).
        z_dim: length of each latent vector z.
        figsize: matplotlib figure size in inches.

    Returns:
        None; the grid is drawn on a new matplotlib figure and each seed is printed.
    """
    # squeeze=False keeps axs 2-D even for a 1-row or 1-column grid
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    # Hide every panel up front so cells left unused (fewer seeds than panels) stay blank
    for ax in axs.flat:
        ax.axis('off')
    # zip stops at the shorter of the two, which caps the seeds at the grid size
    for ax, seed in zip(axs.flat, seeds):
        print('seed:', seed)
        ax.imshow(_image_from_seed(generator, seed, trunc, z_dim))
        ax.set_title('seed {}'.format(seed))

In [None]:
# Draw 9 seeds from a dedicated RandomState instead of the shared `rng` of the single-face
# cell: that generator's state advances on every draw, so reusing it made this grid
# change each time the cell was re-run. randint's upper bound (100000) is exclusive.
seed_rng = np.random.RandomState(seed)
seeds = seed_rng.randint(1, 100000, size=9).tolist()
generate_and_plot_from_seeds(sg2, seeds)