<a href="https://colab.research.google.com/github/dvschultz/ml-art-colabs/blob/master/SWA_playground_SG2ADA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

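A modified port of @arfafax's [StyleGAN2 network-interpolation notebook](https://github.com/arfafax/StyleGAN2_experiments/blob/master/StyleGAN2%20Network%20Interpolation.ipynb), adapted to work with stylegan2-ada. It linearly blends the weights of two trained generators and renders samples from the blended network.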

In [None]:
%tensorflow_version 1.x

!git clone https://github.com/dvschultz/stylegan2-ada
%cd stylegan2-ada

In [None]:
!nvidia-smi

In [None]:
!gdown --id 12JuWv5OAqInDajEtyk8C9Xy7H5kMOw89
!gdown --id 1FbltaJujVl5V9LZoX9-wcZFTlqCbWNcz

In [None]:
import ipywidgets as widgets
#import pretrained_networks
import PIL.Image
import numpy as np
import pickle

import dnnlib
import dnnlib.tflib as tflib

tflib.init_tf()

# Source (blend weight 0) and destination (blend weight 1) generator snapshots.
src_model = './network-snapshot-000448.pkl' #floralmag
dst_model = './network-snapshot-000042.pkl' #ladiescrop

# NOTE: pickle.load can execute arbitrary code — only load snapshots you trust.
print('Loading source network from "%s"...' % src_model)
with dnnlib.util.open_url(src_model) as fp:
    _G, _D, Gs = pickle.load(fp)
print('Loading destination network from "%s"...' % dst_model)
with dnnlib.util.open_url(dst_model) as fp:
    _Gd, _Dd, Gsd = pickle.load(fp)

# bGs is the network that gets rendered; it starts as a copy of the source
# generator and has its trainable weights overwritten by weighted_average().
bGs = Gs.clone()

# Keyword arguments for bGs.components.synthesis.run():
#  - convert output to uint8 NHWC images,
#  - randomize_noise=False so the per-seed noise variables set before each
#    render are actually used (otherwise output varies on every call),
#  - one image per minibatch.
batch_size = 1
Gs_syn_kwargs = dnnlib.EasyDict(
    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
    randomize_noise=False,
    minibatch_size=batch_size,
)

# Noise-input variables of the network that is rendered (bGs). bGs is a clone
# with its own variables, so seeding Gs's noise would not affect the output.
noise_vars = [var for name, var in bGs.components.synthesis.vars.items() if name.startswith('noise')]

In [None]:
from dnnlib.tflib import tfutil
def weighted_average(src_net, dst_net, t, out_net=None):
    """Write ``t * dst + (1 - t) * src`` of the trainable weights into ``out_net``.

    Args:
        src_net: Network whose weights are used at ``t == 0``.
        dst_net: Network whose weights are used at ``t == 1``.
        t: Blend weight (the sliders use 0..1).
        out_net: Network that receives the blended values. Defaults to the
            global ``bGs`` (the original behaviour).

    Only names in ``src_net.trainables`` that also exist in ``dst_net`` with the
    same shape are blended. Every other name is reported and skipped, and
    ``out_net`` keeps its current value for it.
    """
    if out_net is None:
        out_net = bGs

    names = []
    for name, src_var in src_net.trainables.items():
        dst_var = dst_net.trainables.get(name)
        if dst_var is None:
            # Previously this checked src_net (always true) and then crashed
            # with KeyError on dst_net; now the mismatch is reported instead.
            print("Not restoring (not present):     {}".format(name))
        elif dst_var.shape != src_var.shape:
            print("Not restoring (different shape): {}".format(name))
        else:
            names.append(name)

    if not names:
        return

    # Fetch the current weights once and blend in NumPy. Building `t * var`
    # TF ops here would add new nodes to the graph on every slider move.
    src_vals, dst_vals = tfutil.run((
        [src_net.vars[name] for name in names],
        [dst_net.vars[name] for name in names],
    ))
    tfutil.set_vars({
        out_net.vars[name]: t * dst_val + (1 - t) * src_val
        for name, src_val, dst_val in zip(names, src_vals, dst_vals)
    })

In [None]:
# Interactive controls for display_sample:
#   seed       - seed for the latent vector and the noise inputs
#   scale      - multiplier applied to the latent vector z
#   truncation - truncation psi applied in W space
#   blending   - blend weight passed to weighted_average (0 = Gs, 1 = Gsd)
# continuous_update=False: values are only sent when the slider is released.
seed = widgets.IntSlider(
    description='Seed: ', value=0,
    min=0, max=100000, step=1,
    continuous_update=False,
)
scale = widgets.FloatSlider(
    description='Scale: ', value=1,
    min=0, max=5, step=0.01,
    continuous_update=False,
)
truncation = widgets.FloatSlider(
    description='Truncation: ', value=1,
    min=-2, max=2, step=0.1,
    continuous_update=False,
)
blending = widgets.FloatSlider(
    description='Blending: ', value=0,
    min=0, max=1, step=0.01,
    continuous_update=False,
)

bot_box = widgets.HBox([seed, scale, truncation, blending])
ui = widgets.VBox([bot_box])

def display_sample(seed, scale, truncation, blending):
    """Blend Gs/Gsd into bGs and display one generated image.

    Args:
        seed: Integer seed for both the latent vector z and the noise inputs.
        scale: Multiplier applied to z before the mapping network.
        truncation: Truncation psi applied manually in W space
            (1 or None disables it).
        blending: Blend weight for weighted_average (0 = Gs, 1 = Gsd).
    """
    weighted_average(Gs, Gsd, blending)

    # Seed the noise inputs of bGs itself, the network being rendered. The
    # global noise_vars may belong to Gs, whose variables bGs does not share.
    rnd = np.random.RandomState(seed)
    tflib.set_vars({
        var: rnd.randn(*var.shape.as_list())  # [height, width]
        for name, var in bGs.components.synthesis.vars.items()
        if name.startswith('noise')
    })

    # Single latent from its own RandomState(seed): [1, component].
    z = np.random.RandomState(seed).randn(1, *bGs.input_shape[1:])
    w = bGs.components.mapping.run(scale * z, None)  # [minibatch, layer, component]
    if truncation is not None and truncation != 1:
        w_avg = bGs.get_var('dlatent_avg')
        w = w_avg + (w - w_avg) * truncation

    # randomize_noise=False so the seeded noise above is actually used.
    syn_kwargs = dict(Gs_syn_kwargs, randomize_noise=False)
    images = bGs.components.synthesis.run(w, **syn_kwargs)  # uint8 NHWC
    display(PIL.Image.fromarray(images[0]))

# Re-render via display_sample whenever any slider changes; the rendered
# image lands in the `out` widget, shown beneath the slider row.
out = widgets.interactive_output(
    display_sample,
    dict(seed=seed, scale=scale, truncation=truncation, blending=blending),
)
display(ui, out)

In [None]:
%cd ../
%mkdir 17258
%cd stylegan2

In [None]:

def save_sample(seed, scale, truncation, blending, out_dir='/content/17258'):
    """Blend Gs/Gsd into bGs and save one generated image as ``<blending>.png``.

    Args:
        seed: Integer seed for both the latent vector z and the noise inputs.
        scale: Multiplier applied to z before the mapping network.
        truncation: Truncation psi applied manually in W space
            (1 or None disables it).
        blending: Blend weight for weighted_average (0 = Gs, 1 = Gsd). It is
            also used as the file name, formatted with two decimals.
        out_dir: Destination directory; created if missing. The default
            matches the previously hard-coded path.
    """
    import os

    weighted_average(Gs, Gsd, blending)

    # Seed the noise inputs of bGs itself, the network being rendered, so
    # every blend step of a sweep uses identical noise.
    rnd = np.random.RandomState(seed)
    tflib.set_vars({
        var: rnd.randn(*var.shape.as_list())  # [height, width]
        for name, var in bGs.components.synthesis.vars.items()
        if name.startswith('noise')
    })

    # Single latent from its own RandomState(seed): [1, component].
    z = np.random.RandomState(seed).randn(1, *bGs.input_shape[1:])
    w = bGs.components.mapping.run(scale * z, None)  # [minibatch, layer, component]
    if truncation is not None and truncation != 1:
        w_avg = bGs.get_var('dlatent_avg')
        w = w_avg + (w - w_avg) * truncation

    # randomize_noise=False so the seeded noise above is actually used.
    syn_kwargs = dict(Gs_syn_kwargs, randomize_noise=False)
    images = bGs.components.synthesis.run(w, **syn_kwargs)  # uint8 NHWC

    os.makedirs(out_dir, exist_ok=True)
    PIL.Image.fromarray(images[0]).save(os.path.join(out_dir, "%.2f.png" % blending))

In [None]:
# Render one seed at 101 evenly spaced blend weights, 0.00 to 1.00 inclusive.
# np.arange(0, 1.0, 0.01) skipped 1.0 (the pure destination network).
# save_sample writes to an absolute path, so no %cd is needed.
# NOTE(review): the output folder is named 17258 but the seed is 17528 — confirm which was intended.
for blend in np.linspace(0.0, 1.0, 101):
    save_sample(17528, 1.0, 0.1, blend)