<a href="https://colab.research.google.com/github/dvschultz/stylegan2-ada-pytorch/blob/main/Network_Blending_ADA_PT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Network Blending
This demo shows how to combine two StyleGAN2-ADA-PyTorch models into one by splitting their weights at a chosen resolution. The lower-resolution layers come from one model and the higher-resolution layers come from the other.
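Splitting at a resolution means that every synthesis block up to and including that resolution is taken from the lower-resolution model, and every block above it is taken from the higher-resolution model. `blend_models` below builds the lower list as `[4*2**x for x in range(int(np.log2(resolution)-1))]`. Here is a minimal sketch of that partition. It assumes a square model whose blocks double from 4×4 up to `model_res`, and `split_blocks` is a hypothetical helper that is not part of the repo.

```python
import math

def split_blocks(model_res, split_res):
    """Return (low_blocks, high_blocks): the synthesis block resolutions taken
    from the lower and the higher model when splitting at split_res.
    Hypothetical helper; assumes blocks at 4, 8, ..., model_res."""
    if split_res > model_res:
        raise ValueError("split_res must not exceed model_res")
    all_blocks = [2 ** k for k in range(2, int(math.log2(model_res)) + 1)]
    low = [r for r in all_blocks if r <= split_res]   # kept from the lower model
    high = [r for r in all_blocks if r > split_res]   # taken from the higher model
    return low, high

low, high = split_blocks(1024, 64)
print(low)   # [4, 8, 16, 32, 64]
print(high)  # [128, 256, 512, 1024]
```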

This example was created by Derrick Schultz for his Advanced StyleGAN2 class. It’s a simpler version of [Justin Pinkney’s TensorFlow version](https://github.com/justinpinkney/stylegan2/blob/master/blend_models.py).

---

If you find this notebook useful, consider signing up for my [Patreon](https://www.patreon.com/bustbright) or [YouTube channel](https://www.youtube.com/channel/UCaZuPdmZ380SFUMKHVsv_AA/join). You can also send me a one-time payment on [Venmo](https://venmo.com/Derrick-Schultz).


In [None]:
!nvidia-smi -L

In [None]:
!git clone https://github.com/dvschultz/stylegan2-ada-pytorch
%cd stylegan2-ada-pytorch
!pip install ninja opensimplex

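## Download models
The script example below blends the FFHQ model with the `bone-bone` model. The church and cat checkpoints, which are both official StyleGAN2 models, are also downloaded so you can try other pairs. FFHQ and `bone-bone` are converted to PyTorch pickles with `legacy.py` before blending. The FFHQ file, for example, is an original TensorFlow StyleGAN2 checkpoint.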

In [None]:
!wget http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-f.pkl
!wget http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl

In [None]:
!gdown --id 15GpzB-wTwGIZC_Wu0ruaEJi7-giRWOOo -O /content/bone-bone.pkl
!wget http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-f.pkl

In [None]:
!python legacy.py --source=/content/stylegan2-ada-pytorch/stylegan2-ffhq-config-f.pkl --dest=/content/ffhq-pt.pkl
!python legacy.py --source=/content/bone-bone.pkl --dest=/content/bone-bone-pt.pkl

## Script Example
To run the blend as a script, use the cells below.

In [None]:
!python blend_models.py --help

In [None]:
!python blend_models.py --lower_res_pkl /content/ffhq-pt.pkl --split_res 64 --higher_res_pkl /content/bone-bone-pt.pkl --output_path /content/ffhq-bonebone-split64.pkl

## Code example

To see what happens under the hood, here’s how the blending works.

In [None]:
import os
import copy
import numpy as np
import torch
import pickle
import dnnlib
import legacy

def extract_conv_names(model, model_res):
    # Return every parameter name in the model (model_res is currently unused)
    return [name for name, _ in model.named_parameters()]

def blend_models(low, high, model_res, resolution, level, blend_width=None):
    # Hard split: blocks up to and including `resolution` come from `low`,
    # blocks above it come from `high`. `level` and `blend_width` are unused.

    # Synthesis block resolutions kept from the lower model: 4, 8, ..., resolution
    resolutions = [4 * 2**x for x in range(int(np.log2(resolution) - 1))]
    print(resolutions)

    low_names = extract_conv_names(low, model_res)
    high_names = extract_conv_names(high, model_res)

    # Both models must share the same architecture (same names, same order)
    assert low_names == high_names, 'models have different parameters'

    # Start with the lower model and copy in the higher model's weights above the split
    model_out = copy.deepcopy(low)
    dict_dest = model_out.state_dict()

    for name, param in high.named_parameters():
        # Keep the lower model's mapping network and low-resolution blocks
        in_low_block = any(f'synthesis.b{res}.' in name for res in resolutions)
        if not in_low_block and 'mapping' not in name:
            dict_dest[name].data.copy_(param.data)

    model_out.load_state_dict(dict_dest)
    return model_out

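You can try the same selection rule without loading any models. This is a minimal sketch that uses plain dicts in place of `state_dict()` output. The parameter names imitate StyleGAN2-ADA-PyTorch naming (`mapping.*`, `synthesis.b{res}.*`), and `blend_state_dicts` is a hypothetical helper, not part of the repo. It uses `startswith` where the notebook uses substring checks, which gives the same result for these names.

```python
import copy

def blend_state_dicts(low, high, split_res):
    """Toy version of the blend: copy `low`, then take every synthesis block
    above `split_res` from `high`. The mapping network always stays from `low`."""
    # Block resolutions kept from the lower model: 4, 8, ..., split_res
    keep = []
    res = 4
    while res <= split_res:
        keep.append(res)
        res *= 2

    out = copy.deepcopy(low)  # never modify the lower model in place
    for name, value in high.items():
        in_kept_block = any(name.startswith(f"synthesis.b{r}.") for r in keep)
        if not in_kept_block and not name.startswith("mapping."):
            out[name] = copy.deepcopy(value)
    return out

# Names imitate StyleGAN2-ADA-PyTorch parameters; values just record the source
names = [
    "mapping.fc0.weight",
    "synthesis.b4.const",
    "synthesis.b8.conv0.weight",
    "synthesis.b16.conv0.weight",
    "synthesis.b16.torgb.weight",
]
low = {n: "low" for n in names}
high = {n: "high" for n in names}

blended = blend_state_dicts(low, high, split_res=8)
for name, source in blended.items():
    print(f"{name:30s} <- {source}")
```

With `split_res=8`, the mapping network and the 4×4 and 8×8 blocks come from the lower model, and both 16×16 entries come from the higher one.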
In [None]:
# These two pickles are not downloaded in this notebook; swap in the ones
# converted above, e.g. /content/ffhq-pt.pkl (lower) and /content/bone-bone-pt.pkl (higher)
lo_res_pkl = '/content/freagan-pt.pkl'
hi_res_pkl = '/content/ladiescrop.pkl'
model_res = 1024
level = 0           # unused by blend_models
blend_width = None  # unused by blend_models

G_kwargs = dnnlib.EasyDict()

# Lower model: keep G, D and G_ema so G and D can be saved alongside the blend
with dnnlib.util.open_url(lo_res_pkl) as f:
    lo = legacy.load_network_pkl(f, custom=False, **G_kwargs) # type: ignore
    lo_G, lo_D, lo_G_ema = lo['G'], lo['D'], lo['G_ema']

# Higher model: only its G_ema (the generator used for sampling) is needed
with dnnlib.util.open_url(hi_res_pkl) as f:
    hi = legacy.load_network_pkl(f, custom=False, **G_kwargs)['G_ema'] # type: ignore

# Write one blended pickle per split resolution
rezes = [8, 16, 32, 64, 128]
for r in rezes:
    model_out = blend_models(lo_G_ema, hi, model_res, r, level, blend_width=blend_width)

    # Save the new pkl: G and D come from the lower model, G_ema is the blend
    out = f'/content/blend-frea-ladiestransfer-{r}.pkl'
    data = {'G': lo_G, 'D': lo_D, 'G_ema': model_out}
    with open(out, 'wb') as f:
        pickle.dump(data, f)

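The loop fills `G` and `D` with the lower model’s networks instead of leaving them empty. `generate.py` samples from `G_ema`, and `legacy.load_network_pkl` appears to check that all three entries are networks, so a pickle with empty slots would likely fail to load. Below is a minimal sketch of that file layout that uses only the standard library, with placeholder strings standing in for the networks. `save_network_dict` and `missing_networks` are hypothetical helpers. To unpickle a real blended file, `torch` and the repo’s modules must also be importable.

```python
import os
import pickle
import tempfile

REQUIRED_KEYS = ("G", "D", "G_ema")

def save_network_dict(path, G, D, G_ema):
    """Write the three networks as one plain pickled dict (hypothetical helper)."""
    with open(path, "wb") as f:
        pickle.dump({"G": G, "D": D, "G_ema": G_ema}, f)

def missing_networks(path):
    """Return the required keys that are absent or None in a saved pickle."""
    with open(path, "rb") as f:
        data = pickle.load(f)
    return [key for key in REQUIRED_KEYS if data.get(key) is None]

# Toy check with placeholder strings instead of real torch modules
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "good.pkl")
    bad = os.path.join(tmp, "bad.pkl")
    save_network_dict(good, "lo_G", "lo_D", "blended_G_ema")
    save_network_dict(bad, None, None, "blended_G_ema")
    good_missing = missing_networks(good)
    bad_missing = missing_networks(bad)

print(good_missing)  # []
print(bad_missing)   # ['G', 'D']
```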


## Test Generating Images With Your New Model

In [None]:
for r in rezes:
    !python generate.py --outdir=/content/out/blended-frea-ladiestransfer2-{r}/ --trunc=0.6 --seeds=0-24 --network=/content/blend-frea-ladiestransfer-{r}.pkl

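The `!` loop above only works inside IPython. Outside a notebook, you can build the same `generate.py` calls as argument lists and run them with `subprocess`, as in the sketch below. `generate_commands` and `run_all` are hypothetical helpers. The flags and paths are exactly those used in the cell above.

```python
import subprocess
import sys

def generate_commands(splits, trunc=0.6, seeds="0-24"):
    """Build one generate.py command per split resolution, mirroring the
    notebook's output and network paths (hypothetical helper)."""
    commands = []
    for r in splits:
        commands.append([
            sys.executable, "generate.py",
            f"--outdir=/content/out/blended-frea-ladiestransfer2-{r}/",
            f"--trunc={trunc}",
            f"--seeds={seeds}",
            f"--network=/content/blend-frea-ladiestransfer-{r}.pkl",
        ])
    return commands

def run_all(commands):
    # Run from the stylegan2-ada-pytorch checkout; stop on the first failure
    for cmd in commands:
        subprocess.run(cmd, check=True)

commands = generate_commands([8, 16, 32, 64, 128])
print(" ".join(commands[0][1:]))
```

Once the blended pickles exist, call `run_all(commands)` from inside the `stylegan2-ada-pytorch` directory. Passing a list rather than a shell string avoids any quoting issues with the paths.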
In [None]:
!zip -r transferred-blends_r2.zip /content/out