<a href="https://colab.research.google.com/github/dvschultz/ai/blob/master/Aydao_CopyCropWeights.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Make new PKLs from other (larger) PKLs

A lightly documented process for generating pkl files from other pkls. This new version can crop a model's weights as long as both models use the same layer count (that is, the same res_log2 value).
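
To make "crop" concrete, here is a minimal NumPy sketch of center-cropping a larger weight array down to a smaller target shape. `center_crop` is a hypothetical helper written for illustration, not part of stylegan2-surgery; it is similar in spirit to the slicing in `copy_and_crop_trainables_from` further down, but it crops every axis that differs, which is a simplification.

```python
import numpy as np

def center_crop(weights, target_shape):
    """Center-crop `weights` to `target_shape`, one axis at a time (hypothetical helper)."""
    if weights.ndim != len(target_shape):
        raise ValueError("source and target must have the same number of axes")
    slices = []
    for src, tgt in zip(weights.shape, target_shape):
        if tgt > src:
            raise ValueError("target is larger than source; cropping cannot grow a tensor")
        start = (src - tgt) // 2  # drop the same amount from both ends
        slices.append(slice(start, start + tgt))
    return weights[tuple(slices)]

# Made-up example: a (1, 2, 4, 4) array cropped to (1, 2, 4, 2).
source = np.arange(1 * 2 * 4 * 4).reshape(1, 2, 4, 4)
cropped = center_crop(source, (1, 2, 4, 2))
print(cropped.shape)
```

Only the middle slice along each differing axis survives, so the outer columns of the source are simply discarded.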

Big big shout out to [@AydaoGMan on Twitter](https://twitter.com/AydaoGMan) for supplying the code that I’m cutting up to put in this notebook!

## Install libraries and dependencies

In [1]:
%tensorflow_version 1.x

!git clone https://github.com/aydao/stylegan2-surgery
%cd stylegan2-surgery/

TensorFlow 1.x selected.
Cloning into 'stylegan2-surgery'...
remote: Enumerating objects: 29, done.
remote: Counting objects: 100% (29/29), done.
remote: Compressing objects: 100% (23/23), done.
remote: Total 1399 (delta 11), reused 14 (delta 6), pack-reused 1370
Receiving objects: 100% (1399/1399), 13.71 MiB | 15.72 MiB/s, done.
Resolving deltas: 100% (951/951), done.
/content/stylegan2-surgery


In [0]:
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import sys, getopt, os

import numpy as np
import dnnlib
import dnnlib.tflib as tflib
from dnnlib.tflib import tfutil
from dnnlib.tflib.autosummary import autosummary

from training import dataset
from training import misc
import pickle
import argparse

# Note well: the argument order is target, then source.
def copy_and_crop_trainables_from(target_net, source_net) -> None:
    # Report, by name, which target variables have an exact match in the source.
    # Variables reported as "different shape" are cropped below rather than copied as-is.
    for name in target_net.trainables.keys():
        if name not in source_net.trainables:
            print("Not restoring (not present):     {}".format(name))
        elif target_net.trainables[name].shape != source_net.trainables[name].shape:
            print("Not restoring (different shape): {}".format(name))
        else:
            print("Restoring: {}".format(name))
    # Pair variables by position; this assumes both networks list the same layers in the same order.
    source_trainables = source_net.trainables.keys()
    target_trainables = target_net.trainables.keys()
    names = list(zip(source_trainables, target_trainables))
            
    skip = []
    for pair in names:
        source_name, target_name = pair
        x = source_net.get_var(source_name)
        y = target_net.get_var(target_name)
        source_shape = x.shape
        target_shape = y.shape
        if source_shape != target_shape:
            update = x
            index = None
            if 'Dense' in source_name:
                # Dense variables: center-crop along axis 0 only.
                index = 0
                gap = source_shape[index] - target_shape[index]
                start = abs(gap) // 2
                end = start + target_shape[index]
                update = update[start:end,:]
            else:
                # 4-D variables: center-crop axes 2 and 3 wherever they differ.
                if source_shape[2] != target_shape[2]:
                    index = 2
                    gap = source_shape[index] - target_shape[index]
                    start = abs(gap) // 2
                    end = start + target_shape[index]
                    update = update[:,:,start:end,:]
                if source_shape[3] != target_shape[3]:
                    index = 3
                    gap = source_shape[index] - target_shape[index]
                    start = abs(gap) // 2
                    end = start + target_shape[index]
                    update = update[:,:,:,start:end]

            # Write the cropped array straight into the target and exclude it from the bulk copy below.
            target_net.set_var(target_name, update)
            skip.append(source_name)

    weights_to_copy = {target_net.vars[pair[1]]: source_net.vars[pair[0]] for pair in names if pair[0] not in skip}
    tfutil.set_vars(tfutil.run(weights_to_copy))

def copy_weights(source_pkl,target_pkl,output_pkl):

    tflib.init_tf()

    with tf.Session() as sess:
        with tf.device('/gpu:0'):

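            # Open the pickles with context managers so the file handles are closed after loading.
            with open(source_pkl, 'rb') as f:
                sourceG, sourceD, sourceGs = pickle.load(f)
            with open(target_pkl, 'rb') as f:
                targetG, targetD, targetGs = pickle.load(f)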
            
            print('Source:')
            sourceG.print_layers()
            sourceD.print_layers() 
            sourceGs.print_layers()
            
            print('Target:')
            targetG.print_layers()
            targetD.print_layers() 
            targetGs.print_layers()
            
            copy_and_crop_trainables_from(targetG, sourceG)
            copy_and_crop_trainables_from(targetD, sourceD)
            copy_and_crop_trainables_from(targetGs, sourceGs)
            
            misc.save_pkl((targetG, targetD, targetGs), os.path.join('./', output_pkl))

## Conversion

Steps to produce:
1. Create a pkl that is 1/2 the size of your current model (you can do this by starting a new model in the skyflynil repo). Grab the network-snapshot-00000.pkl and bring it into Colab.
2. Import your already trained pkl.
3. Run the command below and swap in your own paths. (The first path is the pretrained model, the second is the empty, half-scale model, and the third is where the new pkl gets written.)
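
Before running the copy, it can help to confirm that the two models actually line up. The sketch below is a rough pre-flight check, assuming you have already collected each network's trainable variable names and shapes into plain dicts (for example from `net.trainables`, the way the function above reads them). `check_croppable` and the example names and shapes are made up for illustration and are not part of stylegan2-surgery.

```python
def check_croppable(source_shapes, target_shapes):
    """Return a list of problems; an empty list means the pair looks croppable (hypothetical helper)."""
    problems = []
    # The copy pairs variables by position, so a differing count means a layer mismatch.
    if len(source_shapes) != len(target_shapes):
        problems.append("variable count differs: {} vs {}".format(
            len(source_shapes), len(target_shapes)))
    pairs = zip(source_shapes.items(), target_shapes.items())
    for (s_name, s_shape), (t_name, t_shape) in pairs:
        if len(s_shape) != len(t_shape):
            problems.append("rank differs: {} {} vs {} {}".format(
                s_name, s_shape, t_name, t_shape))
        elif any(t > s for s, t in zip(s_shape, t_shape)):
            # Cropping can only shrink an axis, never grow it.
            problems.append("target larger than source: {} {} vs {} {}".format(
                s_name, s_shape, t_name, t_shape))
    return problems

# Made-up shapes, for illustration only.
source = {"4x4/Const/const": (1, 512, 4, 4), "4x4/Conv/weight": (3, 3, 512, 512)}
target = {"4x4/Const/const": (1, 512, 4, 2), "4x4/Conv/weight": (3, 3, 512, 512)}
print(check_croppable(source, target))  # prints []
```

If the list comes back non-empty, the half-scale model was probably created with a different layer count, and the copy is likely to fail or pair up the wrong variables.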


In [4]:
!gdown --id 1W9DJy51SHWyYbdGi72NH37_ahlkOzmXi -O /content/w.pkl
!gdown --id 1ZqXxZxwDU0LO0NPCUxKbz0BmtJqguaA8 -O /content/glitch.pkl

Downloading...
From: https://drive.google.com/uc?id=1W9DJy51SHWyYbdGi72NH37_ahlkOzmXi
To: /content/w.pkl
379MB [00:04, 85.6MB/s]
Downloading...
From: https://drive.google.com/uc?id=1ZqXxZxwDU0LO0NPCUxKbz0BmtJqguaA8
To: /content/glitch.pkl
379MB [00:05, 71.4MB/s]


In [7]:
!gdown --id 1bXUQyWn5IAAz8k6pE9sKghIHSn1Ee0uq -O /content/smaller.pkl

Downloading...
From: https://drive.google.com/uc?id=1bXUQyWn5IAAz8k6pE9sKghIHSn1Ee0uq
To: /content/smaller.pkl
372MB [00:03, 98.0MB/s]


In [8]:
copy_weights('/content/w.pkl','/content/smaller.pkl','/content/copied.pkl')

Source:

G                              Params    OutputShape         WeightShape     
---                            ---       ---                 ---             
latents_in                     -         (?, 512)            -               
labels_in                      -         (?, 0)              -               
lod                            -         ()                  -               
dlatent_avg                    -         (512,)              -               
G_mapping/latents_in           -         (?, 512)            -               
G_mapping/labels_in            -         (?, 0)              -               
G_mapping/Normalize            -         (?, 512)            -               
G_mapping/Dense0               262656    (?, 512)            (512, 512)      
G_mapping/Dense1               262656    (?, 512)            (512, 512)      
G_mapping/Dense2               262656    (?, 512)            (512, 512)      
G_mapping/Dense3               262656    (?, 512)      

In [0]:
cd ../

/content/stylegan2-surgery


## Testing it out

In [9]:
pwd

'/content/stylegan2-surgery'

In [10]:
%cd ../
!git clone https://github.com/dvschultz/stylegan2
!pip install opensimplex
%cd stylegan2

/content
Cloning into 'stylegan2'...
remote: Enumerating objects: 282, done.
remote: Total 282 (delta 0), reused 0 (delta 0), pack-reused 282
Receiving objects: 100% (282/282), 15.27 MiB | 3.79 MiB/s, done.
Resolving deltas: 100% (152/152), done.
Collecting opensimplex
  Downloading https://files.pythonhosted.org/packages/b9/28/25649c7258dd530aaf9a5c8a4bffeb08e6293ed3d4b671d44b985e8c28fa/opensimplex-0.2.tar.gz
Building wheels for collected packages: opensimplex
  Building wheel for opensimplex (setup.py) ... done
  Created wheel for opensimplex: filename=opensimplex-0.2-cp36-none-any.whl size=14040 sha256=5c91bf1b5e852979197886a8eb0871205a8a78cc87275ad0d2e0963213114cd0
  Stored in directory: /root/.cache/pip/wheels/d3/25/5a/f35e0ac92237c60db97e2f49536fca6d817b0f83340253abee
Successfully built opensimplex
Installing collected packages: opensimplex
Successfully installed opensimplex-0.2
/content/stylegan2


In [11]:
!python run_generator.py generate-images --network=/content/copied.pkl --seeds=10-75 --truncation-psi=0.5

Local submit - run_dir: results/00000-generate-images
dnnlib: Running run_generator.generate_images() on localhost...
Loading networks from "/content/copied.pkl"...
Setting up TensorFlow plugin "fused_bias_act.cu": Preprocessing... Compiling... Loading... Done.
Setting up TensorFlow plugin "upfirdn_2d.cu": Preprocessing... Compiling... Loading... Done.
Generating image for seed 10 (0/66) ...
Generating image for seed 11 (1/66) ...
Generating image for seed 12 (2/66) ...
Generating image for seed 13 (3/66) ...
Generating image for seed 14 (4/66) ...
Generating image for seed 15 (5/66) ...
Generating image for seed 16 (6/66) ...
Generating image for seed 17 (7/66) ...
Generating image for seed 18 (8/66) ...
Generating image for seed 19 (9/66) ...
Generating image for seed 20 (10/66) ...
Generating image for seed 21 (11/66) ...
Generating image for seed 22 (12/66) ...
Generating image for seed 23 (13/66) ...
Generating image for seed 24 (14/66) ...
Generating image for seed 25 (15/66) ...