# StyleGAN Anime Sliders
This notebook demonstrates how to learn and extract controllable directions from [ThisAnimeDoesNotExist](https://thisanimedoesnotexist.ai/). It takes a pretrained StyleGAN and uses [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) to extract various labels from a number of samples. It then uses those labels to learn various attributes which are controllable with sliders.
Credit for the modeling goes to Arfa and the Tensorfork community.

Topics covered include:
1. Generating images
1. Tagging images and learning directions from the latent space
1. Style Transfer
1. How to Project Images to the Latent Space
1. Steering with CLIP (coming soon) 

This is a work in progress. Stay tuned for more updates to the colab.

By Aaron Gokaslan ([Skyli0n](https://twitter.com/SkyLi0n)) 2021


In [None]:
# Environment setup (Colab): select TensorFlow 1.x, install Python dependencies,
# download the DeepDanbooru weights and clone the StyleGAN2 code.
%tensorflow_version 1.x
import tensorflow as tf
!pip install -U moviepy
!pip install cleanlab
import proglog
# Switch proglog (moviepy's progress logger) to its notebook mode.
proglog.notebook()

# Download the code
%cd /content/
!pip install git+https://github.com/KichangKim/DeepDanbooru --no-deps

# Fetch and unpack the pretrained DeepDanbooru tagger into /content/models.
!mkdir -p /content/models
%cd /content/models
!wget -nc https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20200101-sgd-e30/deepdanbooru-v3-20200101-sgd-e30.zip -O deepdanbooru.zip
# Try the V4 version 
#!wget -nc https://github.com/KichangKim/DeepDanbooru/releases/download/v4-20200814-sgd-e30/deepdanbooru-v4-20200814-sgd-e30.zip -O deepdanbooru.zip
!unzip -n deepdanbooru.zip
%cd /content

# Shallow-clone the `estimator` branch of the StyleGAN2 fork; later cells import from it.
!git clone https://github.com/shawwn/stylegan2 -b estimator /content/stylegan2 --depth 1
%cd /content/stylegan2

# Compile and run the repository's nvcc test program to check the CUDA toolchain.
!nvcc test_nvcc.cu -o test_nvcc -run

# Report the TensorFlow version and the GPU that is visible to this runtime.
print('Tensorflow version: {}'.format(tf.__version__) )
!nvidia-smi -L
print('GPU Identified at: {}'.format(tf.test.gpu_device_name()))

## Load the StyleGAN model

In [None]:
%cd /content/stylegan2
import argparse
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import re
import sys
from io import BytesIO
import IPython.display
import numpy as np
from math import ceil
from PIL import Image, ImageDraw
import imageio

import pretrained_networks

import hashlib 

from google_drive_downloader import GoogleDriveDownloader as gdd


# Google Drive share link of the pretrained generator; the file id is the part after `id=`.
url = 'https://drive.google.com/open?id=1A-E_E32WAtTHRlOzjhhYhyyBDXLJN9_H'
model_id = url.replace('https://drive.google.com/open?id=', '')

# The pickle is stored under a name derived from the Drive file id.
network_pkl = '/content/models/model_%s.pkl' % model_id#(hashlib.md5(model_id.encode()).hexdigest())
gdd.download_file_from_google_drive(file_id=model_id, dest_path=network_pkl)


print('Loading networks from "%s"...' % network_pkl)
# Only `Gs` is used by the rest of the notebook; `_G` and `_D` are kept for completeness.
_G, _D, Gs = pretrained_networks.load_networks(network_pkl)
# Synthesis-network variables whose name starts with 'noise'; generate_images() overwrites
# them with fixed random values.
noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]

## Display Utility


In [None]:
from IPython import display as ipythondisplay
import io
import os
import base64
import moviepy
from IPython.display import HTML

def show_video(vid):
  """Embed a video file in the notebook output as an inline, looping HTML5 player.

  The clip is re-encoded to a height of 400 px into a temporary file next to
  `vid`, and that re-encoded copy is base64-embedded into the page. The
  temporary file is removed afterwards, even if encoding fails.

  Args:
    vid: path of the video file; its extension (without the dot) is used as
      the `video/<ext>` MIME subtype.
  """
  # `import moviepy` on its own may not load the `editor` submodule, so import
  # it explicitly instead of relying on another cell having done so.
  import moviepy.editor

  ext = os.path.splitext(vid)[-1][1:]
  height = 400
  tmp_video_file = os.path.join(os.path.dirname(vid), 'tmp_' + os.path.basename(vid))
  try:
    with moviepy.editor.VideoFileClip(vid, target_resolution=(height, None)) as clip:
      clip.write_videofile(tmp_video_file, preset='veryslow')
    # Embed the down-scaled copy that was just written, not the full-size source.
    with io.open(tmp_video_file, 'rb') as video_file:
      video = video_file.read()
  finally:
    if os.path.exists(tmp_video_file):
      os.remove(tmp_video_file)
  ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
              loop controls style="height: {2}px;">
              <source src="data:video/{1};base64,{0}" type="video/{1}" />
              </video>'''.format(base64.b64encode(video).decode('ascii'), ext, height)))

## Misc Utilities

In [None]:
# Useful utility functions...

def generate_images_in_w_space(dlatents, truncation_psi, randomize_noise=False, show_progress=True):
    """Render images from a sequence of W-space (dlatent) vectors.

    Each entry of `dlatents` is pulled towards the generator's `dlatent_avg`
    by `truncation_psi`, run through the synthesis network, and every
    resulting image is returned as an RGB PIL image.

    Args:
        dlatents: iterable of dlatents; each entry gets a leading batch axis added.
        truncation_psi: scalar or array, reshaped to [-1, 1, 1] before being applied.
        randomize_noise: forwarded to the synthesis network.
        show_progress: wrap the loop in `log_progress` when True.

    Returns:
        List of PIL images, in input order.
    """
    synthesis_kwargs = dnnlib.EasyDict(
        output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        randomize_noise=randomize_noise,
        truncation_psi=truncation_psi,
    )
    dlatent_avg = Gs.get_var('dlatent_avg') # [component]

    def _no_progress(args, **kwargs):
        # Same call signature as log_progress, but yields the sequence untouched.
        return args

    track = log_progress if show_progress else _no_progress

    imgs = []
    for _, dlatent in track(enumerate(dlatents), name = "Generating images"):
        psi = np.reshape(truncation_psi, [-1, 1, 1])
        truncated = (dlatent[np.newaxis] - dlatent_avg) * psi + dlatent_avg
        for frame in Gs.components.synthesis.run(truncated, **synthesis_kwargs):
            imgs.append(PIL.Image.fromarray(frame, 'RGB'))
    return imgs

def generate_images(zs, truncation_psi, randomize_noise=False, show_progress=True):
    """Render images from a sequence of Z-space latent batches.

    Args:
        zs: sequence of latent batches; each one is passed to `Gs.run` as a whole.
        truncation_psi: one value applied to every batch, or a list with one
            value per entry of `zs`.
        randomize_noise: forwarded to `Gs.run`.
        show_progress: wrap the loop in `log_progress` when True.

    Returns:
        Flat list of RGB PIL images, covering every image of every batch.
    """
    run_kwargs = dnnlib.EasyDict(
        output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        randomize_noise=randomize_noise,
    )
    if isinstance(truncation_psi, list):
        psi_per_batch = truncation_psi
    else:
        psi_per_batch = [truncation_psi] * len(zs)

    def _no_progress(args, **kwargs):
        # Same call signature as log_progress, but yields the sequence untouched.
        return args

    track = log_progress if show_progress else _no_progress

    rendered = []
    for batch_idx, z in track(enumerate(zs), size = len(zs), name = "Generating images"):
        run_kwargs.truncation_psi = psi_per_batch[batch_idx]
        # Re-seed for every batch so the noise variables always hold the same values.
        fixed_noise = np.random.RandomState(1)
        tflib.set_vars({var: fixed_noise.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
        batch_images = Gs.run(z, None, **run_kwargs) # [minibatch, height, width, channel]
        rendered.extend(PIL.Image.fromarray(image, 'RGB') for image in batch_images)
    return rendered

def generate_zs_from_seeds(seeds):
    """Return one Z latent per seed, each drawn from its own seeded RandomState.

    Every latent has a leading batch axis of size 1 followed by
    `Gs.input_shape[1:]`.
    """
    return [
        np.random.RandomState(seed).randn(1, *Gs.input_shape[1:]) # [minibatch, component]
        for seed in seeds
    ]

def generate_images_from_seeds(seeds, truncation_psi):
    """Render one image per seed: seeds -> Z latents -> generate_images().

    `truncation_psi` is a single constant or a list with one value per seed.
    """
    latents = generate_zs_from_seeds(seeds)
    return generate_images(latents, truncation_psi)

def saveImgs(imgs, location):
  """Save every image as '<location><index>.png', showing a progress bar.

  `location` is used as a plain string prefix, so include a trailing path
  separator to write into a directory.
  """
  total = len(imgs)
  for position, picture in log_progress(enumerate(imgs), size = total, name="Saving images"):
    picture.save(location + str(position) + ".png")

def imshow(a, format='png', jpeg_fallback=True):
  """Display an image (PIL image or array-like) inline in the notebook.

  Args:
    a: image data; converted to a uint8 array before encoding.
    format: Pillow format name used for encoding.
    jpeg_fallback: when displaying raises IOError and `format` is not 'jpeg',
      retry once as JPEG.

  Returns:
    Whatever `IPython.display.display` returns.
  """
  a = np.asarray(a, dtype=np.uint8)
  image = PIL.Image.fromarray(a)
  # JPEG has no alpha channel; drop it so the fallback path can encode RGBA grids.
  if format.lower() == 'jpeg' and image.mode == 'RGBA':
    image = image.convert('RGB')
  str_file = BytesIO()
  image.save(str_file, format)
  im_data = str_file.getvalue()
  try:
    disp = IPython.display.display(IPython.display.Image(im_data))
  except IOError:
    if jpeg_fallback and format != 'jpeg':
      # Format the message *before* printing; print() returns None on Python 3.
      print('Warning: image was too large to display in format "{}"; '
            'trying jpeg instead.'.format(format))
      return imshow(a, format='jpeg')
    raise
  return disp

def showarray(a, fmt='png'):
    """Encode an array as an image in format `fmt` and display it inline.

    The array is cast to uint8 before encoding.
    """
    a = np.uint8(a)
    # Encoded image data is binary, so it needs a bytes buffer (BytesIO is
    # imported at the top of the notebook; StringIO never was).
    f = BytesIO()
    PIL.Image.fromarray(a).save(f, fmt)
    IPython.display.display(IPython.display.Image(data=f.getvalue()))

        
def clamp(x, minimum, maximum):
    """Limit `x` to the range [minimum, maximum]; the lower bound wins if the bounds cross."""
    upper_limited = min(x, maximum)
    return max(minimum, upper_limited)
    
def drawLatent(image,latents,x,y,x2,y2, color=(255,0,0,100)):
  """Overlay a small bar chart of `latents` on `image` inside the box (x, y)-(x2, y2).

  A translucent white rectangle is drawn first, then one vertical line per
  latent value starting at the box's vertical centre. Line length is the value
  scaled by 10% of the box height and clamped to half the box height.

  Returns:
    A new image: `image` alpha-composited with the overlay (`image` must be a
    mode accepted by `PIL.Image.alpha_composite`).
  """
  overlay = PIL.Image.new('RGBA', image.size, (0,0,0,0))
  pen = ImageDraw.Draw(overlay)

  centre_y = (y+y2)/2
  pen.rectangle([x,y,x2,y2],fill=(255,255,255,180), outline=(0,0,0,180))

  for i, value in enumerate(latents):
    bar_x = x + (x2-x)*(float(i)/len(latents))
    bar_len = clamp((y2-y)*value*0.1, centre_y-y2, y2-centre_y)
    pen.line((bar_x,centre_y,bar_x,centre_y+bar_len),fill=color)

  return PIL.Image.alpha_composite(image,overlay)
             
  
def createImageGrid(images, scale=0.25, rows=1):
    """Paste a list of PIL images into a single grid image.

    Args:
        images: non-empty list of PIL images; the first image's size defines the cell size.
        scale: factor applied to the cell width and height.
        rows: number of grid rows; the column count is ceil(len(images) / rows).

    Returns:
        An RGBA image with a white background, filled row by row.
    """
    w,h = images[0].size
    w = int(w*scale)
    h = int(h*scale)
    height = rows*h
    cols = ceil(len(images) / rows)
    width = cols*w
    canvas = PIL.Image.new('RGBA', (width,height), 'white')
    for i,img in enumerate(images):
        # LANCZOS is the filter ANTIALIAS was an alias for; the ANTIALIAS name
        # is deprecated and absent from newer Pillow releases.
        img = img.resize((w,h), PIL.Image.LANCZOS)
        canvas.paste(img, (w*(i % cols), h*(i // cols)))
    return canvas

def convertZtoW(latent, truncation_psi=0.7, truncation_cutoff=9):
  """Map Z latents to W (dlatents) and apply the truncation trick to the first layers.

  Args:
    latent: batch of Z latents accepted by the mapping network.
    truncation_psi: interpolation factor towards `dlatent_avg` (1.0 = no truncation).
    truncation_cutoff: number of leading layers to truncate. Values larger than
      the layer count truncate every layer; None also truncates every layer;
      values <= 0 truncate nothing.

  Returns:
    The dlatent array produced by the mapping network, modified in place.
  """
  dlatent = Gs.components.mapping.run(latent, None) # [seed, layer, component]
  dlatent_avg = Gs.get_var('dlatent_avg') # [component]

  # A slice clamps to the available layers, where per-index loops would raise.
  cutoff = None if truncation_cutoff is None else max(truncation_cutoff, 0)
  dlatent[:, :cutoff] = (dlatent[:, :cutoff] - dlatent_avg) * truncation_psi + dlatent_avg

  return dlatent

def interpolate(zs, steps):
    """Linearly interpolate between each consecutive pair of entries in `zs`.

    For every pair, `steps` points are produced at fractions 0, 1/steps, ...,
    (steps-1)/steps, so each segment includes its start point but not its end
    point (the very last entry of `zs` is therefore never emitted).

    Returns:
        Flat list with (len(zs) - 1) * steps interpolated values.
    """
    fractions = [index/float(steps) for index in range(steps)]
    return [
        end*fraction + start*(1-fraction)
        for start, end in zip(zs[:-1], zs[1:])
        for fraction in fractions
    ]

# Taken from https://github.com/alexanderkuk/log-progress
def log_progress(sequence, every=1, size=None, name='Items'):
    """Yield the items of `sequence` while updating an ipywidgets progress bar.

    Args:
        sequence: iterable to wrap.
        every: refresh the widget on the first item and then every `every` items.
            May only be None when the size is known, in which case roughly 200
            updates are made in total.
        size: total number of items; when None, `len(sequence)` is tried, and if
            that fails the bar is shown without a known total.
        name: label shown in front of the counter.

    Yields:
        The records of `sequence`, unchanged and in order.
    """
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    is_iterator = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            # No len(): the total is unknown (e.g. a generator or enumerate object).
            is_iterator = True
    if size is not None:
        if every is None:
            if size <= 200:
                every = 1
            else:
                every = int(size / 200)     # every 0.5%
    else:
        assert every is not None, 'sequence is iterator, set every'

    if is_iterator:
        # Unknown total: show a full bar in 'info' style instead of a fraction.
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    box = VBox(children=[label, progress])
    display(box)

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            if index == 1 or index % every == 0:
                if is_iterator:
                    label.value = '{name}: {index} / ?'.format(
                        name=name,
                        index=index
                    )
                else:
                    progress.value = index
                    label.value = u'{name}: {index} / {size}'.format(
                        name=name,
                        index=index,
                        size=size
                    )
            yield record
    except:
        # Mark the bar as failed, then re-raise whatever interrupted the loop.
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = index
        label.value = "{name}: {index}".format(
            name=name,
            index=str(index or '?')
        )





# Learn the Attribute Labels!

## Uses the DeepDanbooru Classifier on top of the anime GAN

In [None]:
# Load the DeepDanbooru tagger (v3 weights unpacked by the setup cell) as a Keras model.
model = tf.keras.models.load_model('/content/models/model-resnet_custom_v3.h5', compile=True)
# Alternative: the v4 weights (requires downloading the v4 archive in the setup cell).
#model = tf.keras.models.load_model('/content/models/model-resnet_custom_v4.h5', compile=True)

def load_tags(tags_path):
    """Read a tag file and return its non-empty lines, stripped of surrounding whitespace."""
    with open(tags_path, 'r') as tags_stream:
        stripped_lines = map(str.strip, tags_stream)
        return [line for line in stripped_lines if line]

# Tag names as a NumPy array; later cells index it with positions into the prediction arrays.
tags = np.asarray(load_tags('/content/models/tags.txt'))  
import cv2
import io
import deepdanbooru.data
import deepdanbooru as dd 
from deepdanbooru.commands.evaluate import evaluate_image


## Generate the labels
This next part generates a bunch of images and tags them with the DeepDanbooru classifier. This is done in two separate ways. We can either map the Z or the W latents to the binary attribute labels. The W latents are the "style" latents of StyleGAN while Z is the more classical GAN latent variables. 

In [None]:
  
# Reload the tag list (same file as in the previous cell) so this cell can run on its own.
tags = np.asarray(load_tags('/content/models/tags.txt'))  
import cv2
import deepdanbooru as dd
from tqdm.auto import trange
# Side length, in pixels, that generated images are resized to before tagging.
DD_INPUT_SIZE = 512

def _prepare_for_tagging(image_orig, interpolation=cv2.INTER_LINEAR):
    """Turn one generated image into a DeepDanbooru input batch of size 1.

    The image is resized to DD_INPUT_SIZE x DD_INPUT_SIZE, passed through
    DeepDanbooru's `transform_and_pad_image`, and given a leading batch axis.
    """
    image = np.asarray(image_orig)
    image = cv2.resize(image, (DD_INPUT_SIZE, DD_INPUT_SIZE), interpolation=interpolation)
    image = np.squeeze(image)
    image = dd.image.transform_and_pad_image(image, DD_INPUT_SIZE, DD_INPUT_SIZE)
    return np.expand_dims(image, 0)


def _tag_images(images):
    """Run the tagger on a list of generated images; returns `model.predict`'s output."""
    model_input = np.concatenate([_prepare_for_tagging(image, cv2.INTER_AREA) for image in images])
    return model.predict(model_input)


def _collect_tagged_samples(n, seed, tag_batch, batch):
    """Draw `n` samples in chunks of `batch` via `tag_batch(rnd, size)`.

    Returns two lists with one entry per sample: the latents and the predictions.
    """
    rnd = np.random.RandomState(seed)
    latents = []
    labels = []
    # Full batches first, then one smaller batch for the remainder (if any).
    batch_sizes = [batch] * (n // batch)
    if n % batch != 0:
        batch_sizes.append(n % batch)
    for i in trange(len(batch_sizes)):
        batch_latents, predictions = tag_batch(rnd, batch_sizes[i])
        for latent, prediction in zip(batch_latents, predictions):
            latents.append(latent)
            labels.append(prediction)
    return latents, labels


def generate_and_tag(rnd):
    """Generate a single image from a random Z latent and tag it.

    Returns:
        (z, prediction): the latent with a leading batch axis of 1, and the
        tagger's output row for the image.
    """
    z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
    image_orig = generate_images([z], 0.5, randomize_noise=True, show_progress=False)[0]
    prediction = model.predict(_prepare_for_tagging(image_orig))[0]
    return z, prediction


def generate_and_tagb(rnd, batch):
    """Generate `batch` images from random Z latents (truncation 0.5) and tag them.

    Returns:
        (z, predictions): the Z latents and the tagger output, one row per image.
    """
    z = rnd.randn(batch, *Gs.input_shape[1:]) # [minibatch, component]
    images = generate_images([z], 0.5, randomize_noise=True, show_progress=False)
    return z, _tag_images(images)


def gen_and_tag(n, seed, batch=16):
    """Collect `n` (Z latent, tag prediction) pairs using a seeded random state.

    Returns:
        (X, y): the latents as an array with size-1 axes squeezed out, and the
        predictions as an array with one row per sample.
    """
    latents, labels = _collect_tagged_samples(n, seed, generate_and_tagb, batch)
    return np.squeeze(np.array(latents)), np.array(labels)


def generate_and_tag_w(rnd, batch):
    """Generate `batch` images via W space (no truncation) and tag them.

    Returns:
        (w, predictions): the W latents produced by convertZtoW and the tagger
        output, one row per image.
    """
    z = rnd.randn(batch, *Gs.input_shape[1:]) # [minibatch, component]
    w = convertZtoW(z, truncation_psi=1.0, truncation_cutoff=Gs.components.mapping.output_shape[1])
    images = generate_images_in_w_space(w, 1.0, randomize_noise=True, show_progress=False)
    return w, _tag_images(images)


def gen_and_tag_w(n, seed, batch=16):
    """Collect `n` (W latent, tag prediction) pairs using a seeded random state.

    Returns:
        (X, y): arrays with one entry per sample; the per-layer structure of
        each W latent is kept.
    """
    latents, labels = _collect_tagged_samples(n, seed, generate_and_tag_w, batch)
    return np.array(latents), np.array(labels)



## Logistic Regression on the Labels and the Collected Latents

This function trains an L2-regularized logistic regression that maps the W or Z latents to the binary **attributes** that were classified earlier.

In [None]:
from sklearn.linear_model import LogisticRegression, SGDClassifier, LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.metrics import accuracy_score
from cleanlab.classification import LearningWithNoisyLabels

def get_directions_dict(X_data_arr, y_data_arr, threshold=0.5, min_positives=25):
  """Learn one linear latent-space direction per sufficiently frequent tag.

  For every tag with more than `min_positives` samples scoring above
  `threshold`, a logistic-regression classifier (SGD, balanced class weights)
  is fit on the flattened latents. Its coefficient vector, reshaped to the
  shape of a single latent, is kept as that tag's direction when the training
  accuracy is at least 0.5.

  Args:
    X_data_arr: latents, one per sample along the first axis.
    y_data_arr: tagger scores, indexed [sample, tag].
    threshold: score at or above which a sample counts as positive for a tag.
    min_positives: a tag needs more than this many positives to be considered.

  Returns:
    Dict mapping tag name (looked up in the global `tags`) to its direction array.
  """
  valid_tags = (np.sum(y_data_arr > threshold, axis=0) > min_positives).nonzero()[0]
  print('Num of Valid Tags: %d' % len(valid_tags))
  directions_dict = dict()
  # NOTE(review): newer scikit-learn releases rename this loss to 'log_loss' - confirm
  # against the installed version.
  clf = SGDClassifier(loss='log', class_weight='balanced', penalty='l2', n_iter_no_change=20, early_stopping=True)

  # Loop-invariant: every classifier sees the same flattened latents.
  features = X_data_arr.reshape(X_data_arr.shape[0], -1)
  direction_shape = X_data_arr.shape[1:]

  for j in trange(len(valid_tags)):
    i = valid_tags[j]
    labels = y_data_arr[:, i] >= threshold
    # Require at least two samples of *each* class; minlength=2 keeps the check
    # meaningful when one class is missing entirely.
    if np.min(np.bincount(labels, minlength=2)) >= 2:
      clf = clf.fit(features, labels)
      score = clf.score(features, labels)
      if score >= 0.5:
        # Copy so the stored direction never aliases the classifier's own buffer.
        directions_dict[tags[i]] = clf.coef_.reshape(direction_shape).copy()
  return directions_dict

## Learn Z Embedding for Attributes!

In [None]:
# Generates 2500 samples. Reducing it will make this step faster
# (1234 is the random seed, so the dataset is reproducible).
X_data_arr, y_data_arr = gen_and_tag(2500, 1234)
# Cache both arrays as X_data.npy / y_data.npy in the current working directory.
np.save('X_data', X_data_arr)
np.save('y_data', y_data_arr)

In [None]:
# Reload the cached Z dataset, so the generation cell above can be skipped on later runs.
X_data_arr = np.load('X_data.npy')
y_data_arr = np.load('y_data.npy')

In [None]:
# Per tag: number of samples whose score is strictly above 0.5, printed most frequent first.
attribute_counter = {
    tags[t]: d
    for t, d in enumerate((y_data_arr > 0.5).sum(axis=0))
}

print(sorted(attribute_counter.items(), key=lambda x:x[1], reverse=True))

# Per tag: sum of the raw scores over all samples, printed largest first.
continuous_attribute_counter = {
    tags[t]: d
    for t, d in enumerate(y_data_arr.sum(axis=0))
}

print(sorted(continuous_attribute_counter.items(), key=lambda x:x[1], reverse=True))

In [None]:
# This does the actual machine learning: one linear direction per valid tag,
# stored in a dict keyed by tag name.
directions_dict = get_directions_dict(X_data_arr, y_data_arr)

In [None]:
import matplotlib.pylab as plt
%matplotlib inline

def move_and_show_row(latent_vector, direction, coeffs, mask, fig, ax, key):
    """Draw one row of images: `latent_vector` shifted along `direction` by each coefficient.

    Args:
        latent_vector: starting Z latent.
        direction: direction to move along (same shape as the latent, or broadcastable).
        coeffs: step sizes, one image per entry.
        mask: per-component weights applied to the direction.
        fig: unused; kept for call compatibility.
        ax: sequence of matplotlib axes, at least len(coeffs) long.
        key: attribute name shown in each title.
    """
    for i, coeff in enumerate(coeffs):
        new_latent_vector = latent_vector + coeff*direction*mask
        image = generate_images([new_latent_vector], 1.0, randomize_noise=False, show_progress=False)[0]
        ax[i].imshow(image)
        # Format key and coefficient in one step: a '%' inside `key` must not be
        # interpreted as a format specifier.
        ax[i].set_title('{} Coeff: {:0.1f}'.format(key, coeff))
    for x in ax:
        x.axis('off')

def move_and_show_dir(latent_vector, direction, coeffs, mask):
    """Show a row of images: `latent_vector` shifted along `direction` by each coefficient.

    Args:
        latent_vector: starting Z latent.
        direction: direction to move along (same shape as the latent, or broadcastable).
        coeffs: step sizes, one image per entry.
        mask: per-component weights applied to the direction.
    """
    # squeeze=False always returns a 2-D axes array, so a single coefficient works too.
    fig, axes = plt.subplots(1, len(coeffs), figsize=(15, 10), dpi=80, squeeze=False)
    ax = axes[0]
    for i, coeff in enumerate(coeffs):
        new_latent_vector = latent_vector + coeff*direction*mask
        image = generate_images([new_latent_vector], 1.0, randomize_noise=False, show_progress=False)[0]
        ax[i].imshow(image)
        ax[i].set_title('Coeff: %0.1f' % coeff)
    for x in ax:
        x.axis('off')
    plt.show()

In [None]:

def normalize(v):
    """Divide `v` by `np.linalg.norm(v, ord=2)`; a zero norm is replaced by the dtype's machine epsilon."""
    magnitude = np.linalg.norm(v, ord=2)
    divisor = np.finfo(v.dtype).eps if magnitude == 0 else magnitude
    return v / divisor

# Random starting point in Z space (unseeded, so it differs on every run).
latent_vector = np.random.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
# NOTE(review): unpacking into np.zeros only works while Gs.input_shape[1:] has one entry.
mask = np.zeros(*Gs.input_shape[1:], dtype=float)
mask[:] = 1 # This is a mask of how much of the latents you want to actually change.
coeffs = [-5.0, -2.5, -1.0, 0, 1.0, 2.5, 5.0] # How strongly you want to apply the learned attribute. 
#Extreme values will distort image quality
print(len(directions_dict))
# One figure per learned attribute; the printed index is the latent component
# with the largest coefficient in that direction.
for key in directions_dict:
    print(key, np.argmax(directions_dict[key]))
    move_and_show_dir(latent_vector, normalize(directions_dict[key]), coeffs, mask)#, fig, ax, key)


In [None]:
import ipywidgets as widgets
# Interactive Z-space editor: pick an attribute and a strength, then press
# "Generate!" to move the current latent along that attribute's direction.
categories = list(directions_dict.keys())
category_picker = widgets.Dropdown(options=categories)
slider = widgets.IntSlider(value=2, max=4)  # index into `coeffs`
coeffs = [-5, -1.5, 0, 1.5, 5]

out = widgets.Output()

latent_vector = np.random.randn(1, *Gs.input_shape[1:]) # [minibatch, component]

button = widgets.Button(description="Generate!")
button_clear = widgets.Button(description='Clear')

def run_button(*args):
  """Shift the global latent along the selected direction and redraw the image.

  Steps accumulate: every click moves the latent further from its current position.
  """
  global latent_vector
  coeff = coeffs[slider.value]
  direction = normalize(directions_dict[category_picker.value])
  latent_vector = latent_vector + coeff*direction*mask
  with out:
    image = generate_images([latent_vector], 0.5, randomize_noise=True, show_progress=False)[0]
    IPython.display.clear_output(True)
    imshow(image)

def clear_output_area(*args):
  """Remove the currently shown image from the output area."""
  out.clear_output()

button.on_click(run_button)
button_clear.on_click(clear_output_area)

display(category_picker)
display(slider)
display(button, button_clear)
# The output widget has to be displayed, otherwise images drawn inside
# `with out:` are never visible.
display(out)


## Learn W Embedding for Attributes!

In [None]:
# Collect 10000 (W latent, tag prediction) pairs with seed 1234; this overwrites
# the Z dataset held in X_data_arr / y_data_arr.
X_data_arr, y_data_arr = gen_and_tag_w(10000, 1234)
print(X_data_arr.shape)

# Cache both arrays as X_data_w.npy / y_data_w.npy in the current working directory.
np.save('X_data_w', X_data_arr)
np.save('y_data_w', y_data_arr)


In [None]:
# Reload the cached W dataset, so the generation cell above can be skipped on later runs.
X_data_arr = np.load('X_data_w.npy')
y_data_arr = np.load('y_data_w.npy')

In [None]:
# Learn one direction per valid tag in W space; this replaces the Z-space directions_dict.
directions_dict = get_directions_dict(X_data_arr, y_data_arr)
print(len(directions_dict))

In [None]:
# Per tag: number of samples whose score is strictly above 0.5, printed most frequent first.
attribute_counter = {
    tags[t]: d
    for t, d in enumerate((y_data_arr > 0.5).sum(axis=0))
}

print(sorted(attribute_counter.items(), key=lambda x:x[1], reverse=True))

# Per tag: sum of the raw scores over all samples, printed largest first.
continuous_attribute_counter = {
    tags[t]: d
    for t, d in enumerate(y_data_arr.sum(axis=0))
}

print(sorted(continuous_attribute_counter.items(), key=lambda x:x[1], reverse=True))

When learning the W embedding, the layers used will affect how well a particular attribute transfers. For certain attributes like hair and eye color, only the highest level attributes will be useful. Therefore, you may want to consider masking the last 128 attributes (```mask[-128:]```), for instance. For other attributes like body shape, pose, and scenery, you will only get good results from modifying lower level attributes.


In [None]:
import matplotlib.pylab as plt
%matplotlib inline

def normalize(v):
    """Scale `v` by len(v) / norm(v), replacing a zero norm with machine epsilon.

    This redefines the `normalize` from the Z-space section: the result is
    additionally multiplied by `len(v)`, the length of the first axis.

    NOTE(review): for a 2-D `v`, `np.linalg.norm(v, ord=2)` is the matrix
    2-norm (largest singular value), not the Frobenius norm - confirm that this
    is the intended scaling for W-space directions.
    """
    norm=np.linalg.norm(v, ord=2)
    if norm==0:
        norm=np.finfo(v.dtype).eps
    return v/norm * len(v)

# Random starting point: sample Z, map it to W without truncation, keep the first (only) sample.
latent_vector = np.random.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
latent_vector = convertZtoW(latent_vector, truncation_psi=1.0)[0]
# The mask has the same shape as the W latent; entries set to 0 leave that part unchanged.
mask = np.zeros(latent_vector.shape, dtype=float)

mask[:] = 1 # Which layers you want to actually use. 
print(latent_vector.shape)

def move_and_show_dir_w(latent_vector, direction, coeffs, mask):
    """Show a row of images: a W latent shifted along `direction` by each coefficient.

    Args:
        latent_vector: starting W latent (a single sample, without batch axis).
        direction: raw learned direction; it is passed through `normalize` here.
        coeffs: step sizes, one image per entry.
        mask: weights applied to the direction (same shape as the latent).
    """
    direction = normalize(direction)
    # squeeze=False always returns a 2-D axes array, so a single coefficient works too.
    fig, axes = plt.subplots(1, len(coeffs), figsize=(15, 10), dpi=80, squeeze=False)
    ax = axes[0]
    for i, coeff in enumerate(coeffs):
        new_latent_vector = latent_vector + coeff*direction*mask
        image = generate_images_in_w_space(np.expand_dims(new_latent_vector,0), 1.0, randomize_noise=True, show_progress=False)[0]
        ax[i].imshow(image)
        ax[i].set_title('Coeff: %0.1f' % coeff)
    for x in ax:
        x.axis('off')
    plt.show()

# Attributes that have a learned direction, ordered by summed tagger score (largest first).
# Relies on continuous_attribute_counter from the counting cell above.
keys = [k for k,v in sorted(continuous_attribute_counter.items(), key=lambda x:x[1], reverse=True) if k in directions_dict]
for key in keys:
    print(key)
    move_and_show_dir_w(latent_vector, directions_dict[key], [-5.0, -2.5, -1.0, 0, 1.0, 2.5, 5.0], mask)

In [None]:
import ipywidgets as widgets
# Interactive W-space editor: pick an attribute and a strength, then press
# "Generate Me!" to move the current latent; "Truncate" pulls it towards the average.
categories = list(directions_dict.keys())
category_picker = widgets.Dropdown(options=categories)
slider = widgets.IntSlider(value=2, max=4)  # index into `coeffs`
coeffs = [-5, -1.5, 0, 1.5, 5]

# Random starting point: sample Z, map it to W without truncation, keep the only sample.
latent_vector = np.random.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
latent_vector = convertZtoW(latent_vector, truncation_psi=1.0)[0]

out = widgets.Output()

mask = np.zeros(latent_vector.shape, dtype=float)
mask[:] = 1
button = widgets.Button(description="Generate Me!")
button_clear = widgets.Button(description='Clear')
button_trunc = widgets.Button(description='Truncate')

def _render_current_latent():
  """Synthesize the image for the global latent and show it in the output widget."""
  with out:
    image = generate_images_in_w_space(np.expand_dims(latent_vector,0), 1.0, randomize_noise=True, show_progress=False)[0]
    IPython.display.clear_output(True)
    imshow(image)

def run_button(*args):
  """Shift the global latent along the selected direction and redraw.

  Steps accumulate: every click moves the latent further from its current position.
  """
  global latent_vector
  coeff = coeffs[slider.value]
  direction = normalize(directions_dict[category_picker.value])
  latent_vector = latent_vector + coeff*direction*mask
  _render_current_latent()

def truncate(*args):
  """Pull the global latent towards `dlatent_avg` with psi = 0.75 and redraw."""
  global latent_vector
  truncation_psi = 0.75
  dlatent_avg = Gs.get_var('dlatent_avg') # [component]
  latent_vector = (latent_vector - dlatent_avg) * np.reshape(truncation_psi, [1, 1]) + dlatent_avg
  _render_current_latent()

def clear_output_area(*args):
  """Remove the currently shown image from the output area."""
  out.clear_output()

button.on_click(run_button)
button_clear.on_click(clear_output_area)
button_trunc.on_click(truncate)
display(category_picker, slider, button, button_clear, button_trunc)
out

# Figures Interpolation
This creates various interpolation plots of figures in the latent space. You can use it to go between two seeds you like (it interpolates in the Z space).

In [None]:
# generate some random seeds (size**2 of them, to fill a size x size grid)
size = 5
seeds = np.random.randint(10000000, size=size ** 2)
# Print the seeds so that interesting images can be regenerated later.
print(seeds)
truncation_psi = 1.0

# show the seeds: full-resolution images (scale 1) arranged in `size` rows
imshow(createImageGrid(generate_images_from_seeds(seeds, truncation_psi), 1 , size))

# Random Walk Video Generation

This generates a nice grid video of multiple values walking around the latent space. 


In [None]:
# Import the submodule explicitly so this cell does not rely on some other
# library having loaded scipy.ndimage already.
import scipy.ndimage

grid_size = [3,3]          # [columns, rows] of the video grid
image_shrink = 1
image_zoom = 1             # integer up-scaling factor applied to each frame
duration_sec = 10
smoothing_sec = 2.0 # 1.0
mp4_fps = 20
mp4_codec = 'libx264'
mp4_bitrate = '4M'
random_seed = np.random.randint(0, 999)#405
mp4_file = 'random_grid_%s.mp4' % random_seed
minibatch_size = 16
truncation_psi= 1.0

num_frames = int(np.rint(duration_sec * mp4_fps))
random_state = np.random.RandomState(random_seed)

# Generate latent vectors
shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:] # [frame, image, component]
all_latents = random_state.randn(*shape).astype(np.float32)
# Smooth along the frame axis only (sigma 0 on every other axis); 'wrap' makes the walk loop.
all_latents = scipy.ndimage.gaussian_filter(all_latents, [smoothing_sec * mp4_fps] + [0] * len(Gs.input_shape), mode='wrap')
# Rescale so the latents have unit root-mean-square again after smoothing.
all_latents /= np.sqrt(np.mean(np.square(all_latents)))


def create_image_grid(images, grid_size=None):
    """Tile a batch of equally sized images into one grid array, row by row.

    Args:
        images: array of shape [num, height, width, channels] or
            [num, height, width].
        grid_size: optional (columns, rows). When None, a near-square grid
            large enough for all images is chosen.

    Returns:
        Array of shape [rows * height, columns * width] plus the channel axis
        if the input had one; cells without an image stay zero.
    """
    assert images.ndim == 3 or images.ndim == 4
    num, img_h, img_w = images.shape[:3]
    # [channels] for 4-D input, [] for 3-D input.
    trailing_dims = list(images.shape[3:])

    if grid_size is not None:
        grid_w, grid_h = tuple(grid_size)
    else:
        grid_w = max(int(np.ceil(np.sqrt(num))), 1)
        grid_h = max((num - 1) // grid_w + 1, 1)

    grid = np.zeros([grid_h * img_h, grid_w * img_w] + trailing_dims, dtype=images.dtype)
    for idx in range(num):
        x = (idx % grid_w) * img_w
        y = (idx // grid_w) * img_h
        grid[y : y + img_h, x : x + img_w] = images[idx]
    return grid

# Frame generation func for moviepy.
def make_frame(t):
    """Return the image grid for time `t` (in seconds) of the random-walk video.

    The time is mapped to the nearest frame index (clipped to the valid range),
    that frame's latents are run through the generator, and the images are
    tiled with `create_image_grid`.
    """
    nearest_frame = np.round(t * mp4_fps)
    frame_idx = int(np.clip(nearest_frame, 0, num_frames - 1))
    images = Gs.run(
        all_latents[frame_idx], None,
        truncation_psi=truncation_psi,
        randomize_noise=False,
        output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        minibatch_size=minibatch_size)

    grid = create_image_grid(images, grid_size)
    if image_zoom > 1:
        grid = scipy.ndimage.zoom(grid, [image_zoom, image_zoom, 1], order=0)
    # grayscale => RGB
    return grid.repeat(3, 2) if grid.shape[2] == 1 else grid

# Generate video: moviepy calls make_frame(t) for every frame time and encodes the result.
import moviepy.editor
video_clip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
video_clip.write_videofile(mp4_file, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)

In [None]:
# In order to download files, you can use the snippet below - this often fails for me, though, so I prefer the 'Files' browser in the sidepanel.

from google.colab import files
#files.download(mp4_file) 
# Embed the video that was just written inline in the notebook.
show_video(mp4_file)


This makes a video where the latent codes are kept constant but the random noise is still applied to the images. This video lets you see the effect of random noise on the output. Like FFHQ, it mostly affects tiny attributes like individual hair strands.

In [None]:
# Draw fresh latents from the same random_state as the random-walk cell.
# Only all_latents[0] is used by make_frame below, so every frame shares one set of latents.
shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:] # [frame, image, component]
all_latents = random_state.randn(*shape).astype(np.float32)

def make_frame(t):
    """Return the image grid for the noise-only video.

    The latents are the same for every frame (`all_latents[0]`); only the
    generator's random noise changes between calls, so `t` is not used.
    """
    latents = all_latents[0]
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.run(latents, None, truncation_psi=truncation_psi,
                          randomize_noise=True, output_transform=fmt, 
                          minibatch_size=minibatch_size)

    grid = create_image_grid(images, grid_size)
    if image_zoom > 1:
        grid = scipy.ndimage.zoom(grid, [image_zoom, image_zoom, 1], order=0)
    if grid.shape[2] == 1:
        grid = grid.repeat(3, 2) # grayscale => RGB
    return grid

# Generate video. NOTE: mp4_file is unchanged, so this overwrites the random-walk video.
import moviepy.editor
video_clip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
video_clip.write_videofile(mp4_file, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)

In [None]:
# In order to download files, you can use the snippet below - this often fails for me, though, so I prefer the 'Files' browser in the sidepanel.

from google.colab import files
#files.download(mp4_file) 
# Embed the video that was just written inline in the notebook.
show_video(mp4_file)

In [None]:
# generating a MP4 movie
import moviepy.editor
import os

zs = generate_zs_from_seeds([421645,6149575,3487643,3766864 ,3857159,5360657,3720613])

number_of_steps = 10
# truncation_psi = 1.0; the noise stays fixed because randomize_noise defaults to False.
imgs = generate_images(interpolate(zs,number_of_steps), 1.0)

# Example of reading a generated set of images, and storing as MP4.
# exist_ok makes the cell re-runnable when the directory is already there.
os.makedirs('out', exist_ok=True)
movieName = 'out/mov.mp4'

with imageio.get_writer(movieName, mode='I') as writer:
    for image in log_progress(list(imgs), name = "Creating animation"):
        writer.append_data(np.array(image))
show_video(movieName)

In [None]:
# Simple (Z) interpolation between the latents of two fixed seeds
zs = generate_zs_from_seeds([401528 , 614808 ])

latent1 = zs[0]
latent2 = zs[1]

number_of_steps = 25
truncation_psi = 1.0

# 25 images going from latent1 towards latent2 (the end point itself is not included).
imgs = generate_images(interpolate([latent1,latent2],number_of_steps), truncation_psi)
number_of_images = len(imgs)
# Show them at full resolution in a grid with 5 rows.
imshow(createImageGrid(imgs, 1 , 5))

# Style Transfer

## Coarse Style Transfer

This shows a style transfer of the coarse (low-numbered) features between a source image and several target images. You can see how these low-level features control attributes such as body shape, pose, and scenery.

In [None]:
# Import the submodule explicitly so this cell does not rely on some other
# library having loaded scipy.ndimage already.
import scipy.ndimage

duration_sec = 10.0
smoothing_sec = 1.0
mp4_fps = 20
truncation_psi = 1.0

num_frames = int(np.rint(duration_sec * mp4_fps))
#random_seed = 500
random_seed = np.random.randint(0, 999)#405
random_state = np.random.RandomState(int(random_seed))


# Edge length of a generated image; used to lay out the canvas below.
h = w = Gs.output_shape[-1]
#src_seeds = [601]
dst_seeds = [501, 702, 707]
num_styles = Gs.components.mapping.output_shape[1]
# One list of layer indices per destination column: the coarse (lower) half of
# the layers is taken from the moving source image.
style_ranges = [list(range(0, num_styles//2))] * len(dst_seeds)
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
synthesis_kwargs = dict(output_transform=fmt, truncation_psi=truncation_psi, minibatch_size=16)

# Source: one smoothly varying latent per video frame.
shape = [num_frames] + Gs.input_shape[1:] # [frame, component]
src_latents = random_state.randn(*shape).astype(np.float32)
# NOTE(review): a scalar sigma smooths along every axis, including the component
# axis (the random-walk cell smooths over frames only) - confirm this is intended.
src_latents = scipy.ndimage.gaussian_filter(src_latents,
                                            smoothing_sec * mp4_fps,
                                            mode='wrap')
src_latents /= np.sqrt(np.mean(np.square(src_latents)))

# Destinations: one fixed latent per seed (np.stack needs a sequence, not a generator).
dst_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds])


src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]
src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)


# Canvas layout: top row holds the destination images (first cell left blank),
# bottom row holds the source frame followed by the mixed results.
canvas = PIL.Image.new('RGB', (w * (len(dst_seeds) + 1), h * 2), 'white')

for col, dst_image in enumerate(list(dst_images)):
    canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), ((col + 1) * w, 0))

def make_frame(t):
    """Return the style-mixing canvas for time `t` (in seconds) as an array.

    The source image of the nearest frame is pasted into the bottom-left cell.
    For every destination column, the layers selected by `style_ranges[col]`
    are replaced with the source frame's dlatents and the result is
    synthesized and pasted below the destination image. The shared global
    `canvas` is updated in place.
    """
    frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
    canvas.paste(PIL.Image.fromarray(src_images[frame_idx], 'RGB'), (0, h))

    for col in range(len(dst_images)):
        layers = style_ranges[col]
        # np.stack makes a fresh copy with a batch axis, so dst_dlatents stays untouched.
        mixed = np.stack([dst_dlatents[col]])
        mixed[:, layers] = src_dlatents[frame_idx, layers]
        rendered = Gs.components.synthesis.run(mixed, randomize_noise=False, **synthesis_kwargs)
        for row, image in enumerate(rendered):
            canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * h, (row + 1) * w))
    return np.array(canvas)
    
# Generate video: moviepy calls make_frame(t) for every frame time and encodes the result.
import moviepy.editor
mp4_file = 'output.mp4'
mp4_codec = 'libx264'
mp4_bitrate = '2M'#8M

video_clip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
video_clip.write_videofile(mp4_file, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)

In [None]:
# Show the file written by the previous cell.
# NOTE(review): show_video is not defined in this section; presumably a helper from an earlier cell.
show_video(mp4_file)

## Fine Style Transfer

This code snippet shows fine attribute transfer (defined as the upper half of the W latent layers). These tend to affect finer details, such as colors and textures, rather than the overall pose or shape.

In [None]:
# Import the submodule explicitly: `import scipy` alone does not guarantee
# that scipy.ndimage has been loaded.
import scipy.ndimage

duration_sec = 20.0
smoothing_sec = 1.0
mp4_fps = 20

num_frames = int(np.rint(duration_sec * mp4_fps))
random_seed = 404
random_state = np.random.RandomState(random_seed)


# Image size comes from the generator output, as in the other cells
# (Gs.input_shape describes the latent vector, not the image).
h = w = Gs.output_shape[-1]
style_num = Gs.components.mapping.output_shape[1]
# Fine styles: the upper half of the style layers, for the single destination.
style_ranges = [range(style_num//2, style_num)]

fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
synthesis_kwargs = dict(output_transform=fmt, truncation_psi=0.4, minibatch_size=16)

# One random source latent per video frame.
shape = [num_frames] + Gs.input_shape[1:] # [frame, component]
src_latents = random_state.randn(*shape).astype(np.float32)
# Smooth along the frame axis only. A scalar sigma would be applied to every
# axis and would also blur neighbouring latent components into each other.
frame_sigma = [smoothing_sec * mp4_fps] + [0] * (len(shape) - 1)
src_latents = scipy.ndimage.gaussian_filter(src_latents, frame_sigma, mode='wrap')
# Rescale to unit RMS after smoothing.
src_latents /= np.sqrt(np.mean(np.square(src_latents)))

# A single destination latent, drawn from the same random state.
dst_latents = np.stack([random_state.randn(Gs.input_shape[1])])


src_dlatents = Gs.components.mapping.run(src_latents, None) # [frame, layer, component]
dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]


def make_frame(t):
    """Return the style-mixed image for time `t` (in seconds).

    The single destination's dlatents are copied, the layers listed in
    `style_ranges[0]` are overwritten with the source dlatents of the matching
    frame, and the first synthesized image is returned.
    """
    # Nearest frame index, clamped to the precomputed frames.
    nearest = np.round(t * mp4_fps)
    frame_idx = int(np.clip(nearest, 0, num_frames - 1))
    layers = style_ranges[0]
    mixed_dlatents = np.stack([dst_dlatents[0]])
    mixed_dlatents[:, layers] = src_dlatents[frame_idx, layers]
    frames = Gs.components.synthesis.run(mixed_dlatents, randomize_noise=False, **synthesis_kwargs)
    return frames[0]
    
# Generate video.
# The output name embeds random_seed; mp4_file is reused by the next cell.
import moviepy.editor
mp4_file = 'fine_%s.mp4' % (random_seed)
mp4_codec = 'libx264'
mp4_bitrate = '8M'

# make_frame is handed to VideoClip as the frame callback for a clip of
# duration_sec seconds; the clip is then written to mp4_file at mp4_fps.
video_clip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
video_clip.write_videofile(mp4_file, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)

In [None]:
# Show the file written by the previous cell.
# NOTE(review): show_video is not defined in this section; presumably a helper from an earlier cell.
show_video(mp4_file)

## Coarse and Fine Style Transfer

In [None]:
# Import the submodule explicitly: `import scipy` alone does not guarantee
# that scipy.ndimage has been loaded.
import scipy.ndimage

duration_sec = 10.0
smoothing_sec = 1.0
mp4_fps = 20
truncation_psi = 0.4

num_frames = int(np.rint(duration_sec * mp4_fps))
# A fresh seed is drawn on every run; replace with a constant (e.g. 405) to reproduce a video.
random_seed = np.random.randint(0, 999)#405
random_state = np.random.RandomState(int(random_seed))


h = w = Gs.output_shape[-1]
#src_seeds = [601]
dst_seeds = [501, 702, 707]
style_num = Gs.components.mapping.output_shape[1]
# One list of layer indices per destination seed; here every layer is taken from the source.
style_ranges = [list(range(0, style_num))] * len(dst_seeds) # ([0] * (style_num // 2) + [range(style_num//2,style_num)]) * len(dst_seeds)

fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
synthesis_kwargs = dict(output_transform=fmt, truncation_psi=truncation_psi, minibatch_size=16)

# One random source latent per video frame.
shape = [num_frames] + Gs.input_shape[1:] # [frame, component]
src_latents = random_state.randn(*shape).astype(np.float32)
# Smooth along the frame axis only. A scalar sigma would be applied to every
# axis and would also blur neighbouring latent components into each other.
frame_sigma = [smoothing_sec * mp4_fps] + [0] * (len(shape) - 1)
src_latents = scipy.ndimage.gaussian_filter(src_latents, frame_sigma, mode='wrap')
# Rescale to unit RMS after smoothing.
src_latents /= np.sqrt(np.mean(np.square(src_latents)))

# np.stack needs a real sequence; passing a generator is deprecated and is
# rejected by newer NumPy releases.
dst_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds])


src_dlatents = Gs.components.mapping.run(src_latents, None) # [frame, layer, component]
dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]
src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)


# Two rows: destination images on top, mixed images below; the first column is
# reserved for the animated source image.
canvas = PIL.Image.new('RGB', (w * (len(dst_seeds) + 1), h * 2), 'white')

for col, dst_image in enumerate(dst_images):
    # Horizontal offsets advance by the image width.
    canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), ((col + 1) * w, 0))

def make_frame(t):
    """Render the video frame for time `t` (in seconds) as an array.

    The source image of the matching frame is pasted into the first column of
    the second row of the shared `canvas`. For every destination seed, a copy
    of its dlatents has the layers in `style_ranges[col]` overwritten with the
    source dlatents of that frame; the synthesized result is pasted in the
    second row of that destination's column.

    Returns:
        The whole canvas converted with `np.array`.
    """
    # Nearest frame index, clamped to the precomputed frames.
    frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
    canvas.paste(PIL.Image.fromarray(src_images[frame_idx], 'RGB'), (0, h))

    for col in range(len(dst_images)):
        # np.stack builds a new array, so dst_dlatents itself is not modified.
        col_dlatents = np.stack([dst_dlatents[col]])
        col_dlatents[:, style_ranges[col]] = src_dlatents[frame_idx, style_ranges[col]]
        col_images = Gs.components.synthesis.run(col_dlatents, randomize_noise=False, **synthesis_kwargs)
        for row, image in enumerate(col_images):
            # x offsets advance by image width, y offsets by image height.
            canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))
    return np.array(canvas)
    
# Generate video.
# Output settings for the encoded file; mp4_file is reused by the next cell.
import moviepy.editor
mp4_file = 'output_all.mp4'
mp4_codec = 'libx264'
mp4_bitrate = '2M'#8M

# make_frame is handed to VideoClip as the frame callback for a clip of
# duration_sec seconds; the clip is then written to mp4_file at mp4_fps.
video_clip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
video_clip.write_videofile(mp4_file, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)

In [None]:
# Show the file written by the previous cell.
# NOTE(review): show_video is not defined in this section; presumably a helper from an earlier cell.
show_video(mp4_file)

# Style Mixing

In [None]:
import PIL
# NOTE: from here on the bare name `Image` is IPython.display.Image (used to
# show saved PNG files), not PIL.Image.
from IPython.display import Image, display
truncation_psi=0.7
# fmt is passed to the network runs as output_transform; going by its
# arguments it converts images to uint8 and NCHW to NHWC.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
# Shared keyword arguments; the figure function below reads this module-level name.
synthesis_kwargs = dict(output_transform=fmt, truncation_psi=truncation_psi, minibatch_size=32)
# The last dimension of the generator output is used as both width and height.
h = w = Gs.output_shape[-1]

def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges):
    """Save and display a style-mixing grid.

    The first row shows the images of `src_seeds`, the first column the images
    of `dst_seeds`. Cell (row, col) is the destination image of that row with
    the layers in `style_ranges[row]` replaced by those of the source in that
    column.

    Args:
        png: Path the figure is saved to; the saved file is then displayed.
        Gs: Generator network providing `components.mapping` and `components.synthesis`.
        w, h: Width and height in pixels of one grid cell.
        src_seeds: Seeds for the column (source) latents.
        dst_seeds: Seeds for the row (destination) latents.
        style_ranges: One collection of layer indices per destination seed.

    Reads the module-level `truncation_psi` and `synthesis_kwargs`.
    """
    # np.stack needs a real sequence; passing a generator is deprecated and is
    # rejected by newer NumPy releases.
    src_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds])
    dst_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds])
    src_dlatents = Gs.components.mapping.run(src_latents, None, truncation_psi=truncation_psi) # [seed, layer, component]
    dst_dlatents = Gs.components.mapping.run(dst_latents, None, truncation_psi=truncation_psi) # [seed, layer, component]
    src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
    dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)

    canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white')
    for col, src_image in enumerate(src_images):
        canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))
    for row, dst_image in enumerate(dst_images):
        canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))
        # One copy of this destination per source column (np.stack builds a new array).
        row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))

        row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]
        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
        for col, image in enumerate(row_images):
            canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))
    canvas.save(png)
    display(Image(png, width=1024))
    

style_num = Gs.components.mapping.output_shape[1]
# One range per destination seed: the first two rows use the first third of
# the layers, the next two the middle third, and the last row the final third.
draw_style_mixing_figure(
    'fig3.png',
    Gs,
    w=w,
    h=h,
    src_seeds=[634, 504, 687, 606, 220],
    dst_seeds=[406, 204, 1898, 1733, 614],
    style_ranges=(
        [range(0, style_num // 3)] * 2
        + [range(style_num // 3, 2 * style_num // 3)] * 2
        + [range(2 * style_num // 3, style_num)]
    ),
)


# MultiRes Image Visualization

In [None]:
import PIL
# NOTE: the bare name `Image` is IPython.display.Image here, not PIL.Image.
from IPython.display import Image, display

# fmt is passed to the network runs as output_transform; going by its
# arguments it converts images to uint8 and NCHW to NHWC.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
truncation_psi=0.4
# Shared keyword arguments; the figure function below reads this module-level name.
synthesis_kwargs = dict(output_transform=fmt, truncation_psi=truncation_psi, minibatch_size=16)
# The last dimension of the generator output is used as both width and height.
h = w = Gs.output_shape[-1]

def draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, seed):
    """Save and display a grid of generated images at several scales.

    Each entry of `lods` adds one column. A column with level `lod` holds
    `rows * 2**lod` images, each cropped to the box (cx, cy, cw, ch) and
    downscaled by a factor of `2**lod`.

    Args:
        png: Path the figure is saved to; the saved file is then displayed.
        Gs: Generator network; called through `Gs.run`.
        cx, cy: Top-left corner of the crop box in pixels.
        cw, ch: Width and height of the crop box in pixels.
        rows: Number of images in a column with level 0.
        lods: Downscale level for each column.
        seed: Seed for the random latents.

    Reads the module-level `synthesis_kwargs`.
    """
    print(png)
    latents = np.random.RandomState(seed).randn(sum(rows * 2**lod for lod in lods), Gs.input_shape[1])
    images = Gs.run(latents, None, **synthesis_kwargs) # [seed, y, x, rgb]

    canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white')
    image_iter = iter(images)
    col_x = 0  # left edge of the current column
    for lod in lods:
        tile_w, tile_h = cw // 2**lod, ch // 2**lod
        for row in range(rows * 2**lod):
            image = PIL.Image.fromarray(next(image_iter), 'RGB')
            image = image.crop((cx, cy, cx + cw, cy + ch))
            # LANCZOS is the same filter as the old ANTIALIAS alias, which
            # newer Pillow releases no longer provide.
            image = image.resize((tile_w, tile_h), PIL.Image.LANCZOS)
            canvas.paste(image, (col_x, row * ch // 2**lod))
        col_x += tile_w
    canvas.save(png)
    display(Image(png, width=1024))

# Crop box covers the whole image (cx=cy=0, cw=w, ch=h); six columns with levels 0, 1, 2, 2, 3, 3.
draw_uncurated_result_figure('fig2.png', Gs, cx=0, cy=0, cw=w, ch=h, rows=3, lods=[0,1,2,2,3,3], seed=23)

In [None]:
# more complex example, interpolating in W instead of Z space.
zs = generate_zs_from_seeds([421645,6149575,3487643,3766864 ,3857159,5360657,3720613 ])

# It seems my truncation_psi is slightly less efficient in W space - I probably introduced an error somewhere...

dls = []
for z in zs:
  dls.append(convertZtoW(z ,truncation_psi=1.0))

number_of_steps = 100

imgs = generate_images_in_w_space(interpolate(dls,number_of_steps), 1.0)

%mkdir out
movieName = 'out/mov.mp4'

with imageio.get_writer(movieName, mode='I') as writer:
    for image in log_progress(list(imgs), name = "Creating animation"):
        writer.append_data(np.array(image))

In [None]:
# Show the file written by the previous cell.
# NOTE(review): show_video is not defined in this section; presumably a helper from an earlier cell.
show_video(movieName)

In [None]:
#from IPython.display import Image, display

def draw_truncation_trick_figure(png, Gs, w=512, h=512, seeds=[91, 81, 388], psis=[1, 0.7, 0.5, 0, -0.5, -1]):
    """Save and display a grid illustrating the truncation trick.

    One row per seed, one column per psi. Each cell is synthesized from
    `dlatent_avg + psi * (dlatent - dlatent_avg)`.

    Args:
        png: Path the figure is saved to; the saved file is then displayed.
        Gs: Generator network providing `components.mapping`,
            `components.synthesis` and the `dlatent_avg` variable.
        w, h: Width and height in pixels of one grid cell.
        seeds: Seeds for the row latents (not modified).
        psis: Scaling factors, one per column (not modified).

    Reads the module-level `synthesis_kwargs`.
    """
    # np.stack needs a real sequence; passing a generator is deprecated and is
    # rejected by newer NumPy releases.
    latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds])
    dlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component]
    dlatent_avg = Gs.get_var('dlatent_avg') # [component]

    canvas = PIL.Image.new('RGB', (w * len(psis), h * len(seeds)), 'white')
    for row, dlatent in enumerate(dlatents):
        # Broadcast the psis over a leading axis: one scaled dlatent per column.
        row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(psis, [-1, 1, 1]) + dlatent_avg
        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
        for col, image in enumerate(row_images):
            canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * w, row * h))

    canvas.save(png)
    IPython.display.display(IPython.display.Image(png, width=1024))
h = w = Gs.output_shape[-1]
# The last two psis are separate values; without the comma `-0.7 -1` is the
# single number -1.7.
draw_truncation_trick_figure('output.png', Gs, w=w, h=h, seeds=[901, 888, 777], psis=[2.0, 1, 0.7, 0.5, 0, -0.5, -0.7, -1])

# (Bonus) Projecting images onto the generatable manifold

StyleGAN2 comes with a projector that finds the closest generatable image based on any input image. This allows you to get a feeling for the diversity of the portrait manifold.

In [None]:
!mkdir projection
!mkdir projection/imgs
!mkdir projection/out

# Now upload a single image to 'stylegan2/projection/imgs' (use the Files side panel). Image should be color PNG, with a size of 1024x1024.

In [None]:
# Convert uploaded images to TFRecords
import dataset_tool
# You have to upload images for this to work
# NOTE(review): going by the paths, the arguments are the output records directory
# and the input image directory; the third positional argument is presumably a
# shuffle flag — confirm against dataset_tool.
dataset_tool.create_from_images("./projection/records/", "./projection/imgs/", True)
# NOTE(review): removes one specific generated record file; presumably the
# highest-resolution level, so that the dataset resolution matches the
# generator — confirm.
!rm 'projection/records/-r10.tfrecords'

# Run the projector
import run_projector
import projector
import training.dataset
import training.misc
import os 
import cv2

def project_real_images(dataset_name, data_dir, num_images, num_snapshots):
    """Run the projector on the first `num_images` images of a TFRecord dataset.

    Args:
        dataset_name: TFRecord directory name, passed as `tfrecord_dir`.
        data_dir: Directory passed as `data_dir` to the dataset loader.
        num_images: Number of images to read and project, one at a time.
        num_snapshots: Forwarded to `run_projector.project_image`.

    Uses the module-level generator `Gs`; the dataset shape must equal
    `Gs.output_shape[1:]`.
    """
    image_projector = projector.Projector()
    image_projector.set_network(Gs)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = training.dataset.load_dataset(
        data_dir=data_dir,
        tfrecord_dir=dataset_name,
        max_label_size=0,
        verbose=True,
        repeat=False,
        shuffle_mb=0,
    )
    print(dataset_obj.shape)
    print(Gs.output_shape)
    assert dataset_obj.shape == Gs.output_shape[1:]

    idx = 0
    while idx < num_images:
        print('Projecting image %d/%d ...' % (idx, num_images))
        batch, _unused_labels = dataset_obj.get_minibatch_np(1)
        # Rescale pixel values from [0, 255] to [-1, 1] before projecting.
        batch = training.misc.adjust_dynamic_range(batch, [0, 255], [-1, 1])
        out_prefix = dnnlib.make_run_dir_path('projection/out/image%04d-' % idx)
        run_projector.project_image(image_projector, targets=batch, png_prefix=out_prefix, num_snapshots=num_snapshots)
        idx += 1

# dataset_name="records", data_dir="./projection", num_images=1, num_snapshots=100
project_real_images("records","./projection",1,100)

In [None]:
# Create video

import glob

# Snapshot images written during projection, in sorted filename order.
imgs = sorted(glob.glob("projection/out/*step*.png"))

target_imgs = sorted(glob.glob("projection/out/*target*.png"))
assert len(target_imgs) == 1, "Expected exactly one target image, found %d" % len(target_imgs)
target_img = imageio.imread(target_imgs[0])
# Use PIL.Image explicitly: earlier cells rebind the bare name `Image` to
# IPython.display.Image, which has no fromarray(). The target is the same for
# every frame, so convert it once.
target_pil = PIL.Image.fromarray(target_img)

movieName = "projection/movie.mp4"
with imageio.get_writer(movieName, mode='I') as writer:
    for filename in log_progress(imgs, name = "Creating animation"):
        image = imageio.imread(filename)

        # Concatenate images with original target image
        # Array shape is (height, width, ...), so height comes first.
        h, w = image.shape[0:2]
        canvas = PIL.Image.new('RGBA', (w * 2, h), 'white')
        canvas.paste(target_pil, (0, 0))
        canvas.paste(PIL.Image.fromarray(image), (w, 0))

        writer.append_data(np.array(canvas))

In [None]:
# Now you can download the video (find it in the Files side panel under 'stylegan2/projection')

# To cleanup
# Each `*.*` pattern only matches names containing a dot; the directories
# themselves are kept, and projection/movie.mp4 is not touched.
!rm projection/out/*.*
!rm projection/records/*.*
!rm projection/imgs/*.*