<a href="https://colab.research.google.com/github/dvschultz/ai/blob/master/SWA_playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

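Port of @arfafax’s [StyleGAN2 Network Interpolation notebook](https://github.com/arfafax/StyleGAN2_experiments/blob/master/StyleGAN2%20Network%20Interpolation.ipynb): it blends the weights of two StyleGAN2 generators and renders images from the blended network.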

In [1]:
%tensorflow_version 1.x

!git clone https://github.com/dvschultz/stylegan2
%cd stylegan2

Cloning into 'stylegan2'...
remote: Enumerating objects: 257, done.[K
remote: Total 257 (delta 0), reused 0 (delta 0), pack-reused 257[K
Receiving objects: 100% (257/257), 15.25 MiB | 2.44 MiB/s, done.
Resolving deltas: 100% (137/137), done.
/content/stylegan2


In [2]:
import PIL.Image
import numpy as np
import ipywidgets as widgets

import dnnlib
import dnnlib.tflib as tflib
import pretrained_networks

# Source (blending = 0) and destination (blending = 1) generator snapshots.
src_model = '/content/network-snapshot-002111.pkl' #floralmag
dst_model = '/content/network-snapshot-000024.pkl' #ladiescrop

# NOTE(review): load_networks presumably unpickles these .pkl snapshots --
# only load files you trust (unpickling can execute arbitrary code).
_G, _D, Gs = pretrained_networks.load_networks(src_model)
_Gd, _Dd, Gsd = pretrained_networks.load_networks(dst_model)
# bGs is the network that gets rendered; weighted_average() overwrites its
# variables with the blend of Gs and Gsd.
bGs = Gs.clone()

# Synthesis options used when rendering from bGs.
batch_size = 1
Gs_syn_kwargs = dnnlib.EasyDict()
Gs_syn_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
# Use the stored noise inputs (seeded by the render functions) rather than
# fresh random noise on every run, so a seed always renders the same image.
Gs_syn_kwargs.randomize_noise = False
Gs_syn_kwargs.minibatch_size = batch_size

# Noise inputs of the network that is actually rendered. Take them from bGs,
# not Gs: bGs is a clone and presumably owns separate copies of these
# variables, so seeding Gs's noise would not affect the rendered images.
noise_vars = [var for name, var in bGs.components.synthesis.vars.items() if name.startswith('noise')]

Setting up TensorFlow plugin "fused_bias_act.cu": Preprocessing... Compiling... Loading... Done.
Setting up TensorFlow plugin "upfirdn_2d.cu": Preprocessing... Compiling... Loading... Done.


In [0]:
from dnnlib.tflib import tfutil
def weighted_average(src_net, dst_net, t, target_net=None):
    """Write a linear blend of two networks' trainable weights into a target network.

    For every trainable variable of ``src_net`` that has a same-shaped
    counterpart in ``dst_net``, the target's variable of the same name is set
    to ``t * dst + (1 - t) * src``. Variables missing from ``dst_net`` or with
    a different shape are skipped and reported with a printed message.

    Args:
        src_net: network whose weights are used at ``t == 0``.
        dst_net: network whose weights are used at ``t == 1``.
        t: blend factor (0 gives ``src_net``, 1 gives ``dst_net``).
        target_net: network whose variables are overwritten. Defaults to the
            module-level ``bGs`` (the previously hard-coded target).
    """
    if target_net is None:
        target_net = bGs

    names = []
    for name, src_var in src_net.trainables.items():
        # Look the name up in dst_net (the dict NOT being iterated); checking
        # src_net here is always true and lets a missing layer raise KeyError.
        dst_var = dst_net.trainables.get(name)
        if dst_var is None:
            print("Not restoring (not present):     {}".format(name))
        elif dst_var.shape != src_var.shape:
            print("Not restoring (different shape): {}".format(name))
        else:
            names.append(name)

    if not names:
        return

    # Fetch the current values and blend them in numpy. Building the blend as
    # TF expressions would add new ops to the graph on every call (each slider
    # move / frame); set_vars reuses its setter ops, so the graph stays fixed.
    src_values = tfutil.run([src_net.vars[name] for name in names])
    dst_values = tfutil.run([dst_net.vars[name] for name in names])
    tfutil.set_vars({
        target_net.vars[name]: t * dst_value + (1 - t) * src_value
        for name, src_value, dst_value in zip(names, src_values, dst_values)
    })

In [9]:
# Controls for the interactive preview. With continuous_update=False, per the
# ipywidgets docs, a slider reports its value only once it is released.
# Truncation 1 leaves w untouched (see display_sample).
seed = widgets.IntSlider(
    min=0, max=100000, step=1, value=0,
    description='Seed: ', continuous_update=False,
)
scale = widgets.FloatSlider(
    min=0, max=5, step=0.01, value=1,
    description='Scale: ', continuous_update=False,
)
truncation = widgets.FloatSlider(
    min=-2, max=2, step=0.1, value=1,
    description='Truncation: ', continuous_update=False,
)
blending = widgets.FloatSlider(
    min=0, max=1, step=0.01, value=0,
    description='Blending: ', continuous_update=False,
)

# All four sliders on one row, wrapped in a vertical container.
bot_box = widgets.HBox([seed, scale, truncation, blending])
ui = widgets.VBox([bot_box])

def display_sample(seed, scale, truncation, blending):
    """Blend the generators into ``bGs``, render one image and display it inline.

    ``weighted_average`` first writes the ``blending`` mix of ``Gs`` (0) and
    ``Gsd`` (1) into ``bGs``; the image is then generated from ``bGs``.

    Args:
        seed: seeds both the latent ``z`` and the per-layer noise inputs.
        scale: multiplier applied to ``z`` before the mapping network.
        truncation: truncation psi; when it differs from 1, ``w`` is moved
            toward ``dlatent_avg`` by this factor.
        blending: blend factor between the source and destination networks.
    """
    weighted_average(Gs, Gsd, blending)

    # Seed the noise inputs of bGs -- the network actually rendered -- and turn
    # randomize_noise off so those seeded values are what synthesis uses.
    noise_rnd = np.random.RandomState(seed)
    bgs_noise_vars = [var for name, var in bGs.components.synthesis.vars.items() if name.startswith('noise')]
    tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in bgs_noise_vars})
    syn_kwargs = dict(Gs_syn_kwargs, randomize_noise=False)

    # A fresh RandomState(seed) for z, so z does not depend on how many noise
    # values were drawn above.
    z = np.random.RandomState(seed).randn(*bGs.input_shape[1:])[np.newaxis]  # [minibatch=1, component]
    w = bGs.components.mapping.run(scale * z, None)  # [minibatch, layer, component]
    if truncation != 1:
        w_avg = bGs.get_var('dlatent_avg')
        w = w_avg + (w - w_avg) * truncation
    images = bGs.components.synthesis.run(w, **syn_kwargs)
    # The batch holds a single image.
    display(PIL.Image.fromarray(images[0].astype(np.uint8)))

# Re-render via display_sample whenever a slider changes; each keyword matches
# one of display_sample's parameter names.
out = widgets.interactive_output(
    display_sample,
    dict(seed=seed, scale=scale, truncation=truncation, blending=blending),
)
display(ui, out)

VBox(children=(HBox(children=(IntSlider(value=0, continuous_update=False, description='Seed: ', max=100000), F…

Output()

In [11]:
%cd ../
%mkdir 17258
%cd stylegan2

/content
/content/stylegan2


In [0]:

def save_sample(seed, scale, truncation, blending, out_dir='/content/17258'):
    """Blend the generators into ``bGs``, render one image and save it as a PNG.

    ``weighted_average`` first writes the ``blending`` mix of ``Gs`` (0) and
    ``Gsd`` (1) into ``bGs``; the image is then generated from ``bGs`` and
    written to ``<out_dir>/<blending formatted as %.2f>.png``.

    Args:
        seed: seeds both the latent ``z`` and the per-layer noise inputs, so
            every frame of a blend sequence uses the same noise.
        scale: multiplier applied to ``z`` before the mapping network.
        truncation: truncation psi; when it differs from 1, ``w`` is moved
            toward ``dlatent_avg`` by this factor.
        blending: blend factor between the source and destination networks.
        out_dir: folder to write into (created if missing). Defaults to the
            previously hard-coded ``/content/17258``.

    Returns:
        The path of the written PNG.
    """
    import os

    weighted_average(Gs, Gsd, blending)

    # Seed the noise inputs of bGs -- the network actually rendered -- and turn
    # randomize_noise off so those seeded values are used (random noise would
    # make consecutive frames flicker).
    noise_rnd = np.random.RandomState(seed)
    bgs_noise_vars = [var for name, var in bGs.components.synthesis.vars.items() if name.startswith('noise')]
    tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in bgs_noise_vars})
    syn_kwargs = dict(Gs_syn_kwargs, randomize_noise=False)

    # A fresh RandomState(seed) for z, independent of the noise draws above.
    z = np.random.RandomState(seed).randn(*bGs.input_shape[1:])[np.newaxis]  # [minibatch=1, component]
    w = bGs.components.mapping.run(scale * z, None)  # [minibatch, layer, component]
    if truncation != 1:
        w_avg = bGs.get_var('dlatent_avg')
        w = w_avg + (w - w_avg) * truncation
    images = bGs.components.synthesis.run(w, **syn_kwargs)

    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, '%.2f.png' % blending)
    # The batch holds a single image.
    PIL.Image.fromarray(images[0].astype(np.uint8)).save(path)
    return path

In [21]:
%cd ../17258
for i in np.arange(0, 1.0, 0.01):
  save_sample(17528,1.0,0.1,i)

/content/17258


In [22]:
!zip -r test.zip /content/17258

  adding: content/17258/ (stored 0%)
  adding: content/17258/0.20.png (deflated 0%)
  adding: content/17258/0.95.png (deflated 0%)
  adding: content/17258/0.26.png (deflated 0%)
  adding: content/17258/0.51.png (deflated 0%)
  adding: content/17258/0.68.png (deflated 0%)
  adding: content/17258/0.63.png (deflated 0%)
  adding: content/17258/0.86.png (deflated 0%)
  adding: content/17258/0.22.png (deflated 0%)
  adding: content/17258/0.32.png (deflated 0%)
  adding: content/17258/0.03.png (deflated 0%)
  adding: content/17258/0.18.png (deflated 0%)
  adding: content/17258/0.49.png (deflated 0%)
  adding: content/17258/0.23.png (deflated 0%)
  adding: content/17258/0.41.png (deflated 0%)
  adding: content/17258/0.58.png (deflated 0%)
  adding: content/17258/0.94.png (deflated 0%)
  adding: content/17258/0.89.png (deflated 0%)
  adding: content/17258/0.46.png (deflated 0%)
  adding: content/17258/0.82.png (deflated 0%)
  adding: content/17258/0.47.png (deflated 0%)
  adding: content/17258