<a href="https://colab.research.google.com/github/adamdavidcole/stylegan2-ada-pytorch-adam/blob/main/network_blending_gui.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Network Blending
This demo will show how to combine two separate StyleGAN2-ADA-PyTorch models into one by splitting their weights at a specified layer.

This example was created by Derrick Schultz for his Advanced StyleGAN2 class. It’s a simpler version of [Justin Pinkney’s TensorFlow version](https://github.com/justinpinkney/stylegan2/blob/master/blend_models.py).

---

If you find this notebook useful, consider signing up for my [Patreon](https://www.patreon.com/bustbright) or [YouTube channel](https://www.youtube.com/channel/UCaZuPdmZ380SFUMKHVsv_AA/join). You can also send me a one-time payment on [Venmo](https://venmo.com/Derrick-Schultz).


In [None]:
# List the GPU(s) attached to this runtime (one line per device).
!nvidia-smi -L

In [None]:
# Install einops, ninja and gdown into the runtime.
# NOTE(review): versions are unpinned, and gdown/ninja are installed again in later cells.
!pip install einops ninja gdown

In [None]:
# Connect Google Drive at /content/drive.
# (NOTE: only run this if you want to save the results in GDrive after the runtime ends)
# Later cells check for /content/drive to decide where the repository lives, and
# several default paths in this notebook point into /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
!pip install gdown --upgrade

if os.path.isdir("/content/drive/MyDrive/stylegan2-ada-pytorch-adam"):
    %cd "/content/drive/MyDrive/stylegan2-ada-pytorch-adam"
elif os.path.isdir("/content/drive/"):
    #install script
    %cd "/content/drive/MyDrive/"
    !git clone https://github.com/adamdavidcole/stylegan2-ada-pytorch-adam.git
    %cd stylegan2-ada-pytorch-adam
    
    # !gdown --id 1-5xZkD8ajXw1DdopTkH_rAoCsD72LhKU -O /content/drive/MyDrive/colab-sg2-ada-pytorch/stylegan2-ada-pytorch/pretrained/wikiart.pkl
else:
    !git clone https://github.com/adamdavidcole/stylegan2-ada-pytorch-adam.git
    %cd stylegan2-ada-pytorch-adam
    # !mkdir downloads
    # !mkdir datasets
    # !mkdir pretrained

!mkdir input_images
!mkdir input_images/raw
!mkdir input_images/aligned

In [None]:
# Set a placeholder git identity for this runtime.
!git config --global user.name "test"
!git config --global user.email "test@test.com"
# Fetch the latest commits from the remote without touching the working tree.
!git fetch origin
# !git pull
# !git stash
# Overwrite the tracked *.py files in the working tree with the versions from
# origin/main. Local edits to those files are discarded; other files are untouched.
!git checkout origin/main -- "*.py" 
# !git checkout origin/main -- "*.ipynb"
# !git checkout origin/main -- "ffhq_dataset/*" 


In [None]:
# Install ninja and opensimplex (ninja was already requested by an earlier install cell).
!pip install ninja opensimplex

In [None]:
# Run the repo's legacy.py on the Drive copy of stylegan2-ffhq-slim.pkl.txt and
# write the result to stylegan2-ffhq-slim.pkl in the same folder.
# NOTE(review): presumably this converts an older pickle format into one the rest
# of the notebook can load — confirm against legacy.py. Both paths require Drive
# to be mounted and the source file to be present.
!python legacy.py \
        --source=/content/drive/MyDrive/stylegan2-ada-pytorch-adam/pretrained/stylegan2-ffhq-slim.pkl.txt \
        --dest=/content/drive/MyDrive/stylegan2-ada-pytorch-adam/pretrained/stylegan2-ffhq-slim.pkl

## Download Pretrained Models

In [None]:
# Google Drive share links have the form
#   https://drive.google.com/file/d/<FILE_ID>/view?usp=sharing
# and gdown is given the <FILE_ID> part.
# All downloads go to ./pretrained, relative to the current working directory.
if not os.path.isdir('pretrained'):
  !mkdir pretrained

# Butterfly checkpoints (the numbers in the file names look like training
# snapshot indices — TODO confirm)
!gdown --id 105VsQSTdthX4lSvHUW6YM0_MiaHcOruJ -O pretrained/butterflys_000016.pkl
!gdown --id 107MmrDtr0GX8rDXDJzi59-7tnlz0QhjC -O pretrained/butterflys_000032.pkl
!gdown --id 10QXjGYAo9sn-UYKKpJqyDX1ON14OscQ2 -O pretrained/butterflys_000048.pkl
!gdown --id 15NC-plFvfs59NLT0-t3SIucpJcEmvOmq -O pretrained/butterflys_000677.pkl

# "Slim" FFHQ and ukiyo-e checkpoints.
# NOTE(review): the model selection cell looks these two up under the absolute
# Drive path, so they are only found when the working directory is the Drive checkout.
!gdown --id 1BkRsnE0YygA2ufbfDOV4-fOgTMjSr94K -O pretrained/stylegan2-ffhq-slim.pkl
!gdown --id 1BjYGiOUKk8SC35a2e5QrJ1QtvaxJ0QD7 -O pretrained/ukiyoe-256-slim.pkl

# NVIDIA's FFHQ 256x256 transfer-learning source network, saved as ffhq_256.pkl.
!wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl -O pretrained/ffhq_256.pkl

In [None]:
#@title Helper Functions & Setup
#common functions 
import pickle, torch, PIL, copy, cv2, math
import numpy as np
import ipywidgets as widgets
from IPython.display import display
from google.colab import files
from io import BytesIO
from PIL import Image, ImageEnhance

from IPython.display import Image as DisplayImage, clear_output

# Torch device used by the helper functions below when they move tensors
# (get_w_from_path, synthesize_img_from_w_np). CUDA is hard-coded; there is no
# CPU fallback, so a GPU runtime is required.
device = torch.device('cuda')

def get_model(path):
  """Load the ``G_ema`` generator from a network pickle and place it on CUDA.

  Args:
    path: Location passed to ``dnnlib.util.open_url``.

  Returns:
    The ``'G_ema'`` entry of the loaded pickle, with ``requires_grad_(False)``
    applied and moved to the CUDA device.
  """
  cuda = torch.device('cuda')
  with dnnlib.util.open_url(path) as handle:
    networks = legacy.load_network_pkl(handle)
    generator = networks['G_ema'].requires_grad_(False)
    generator = generator.to(cuda)
  return generator

#tensor to PIL image 
def t2i(t):
  """Convert a batched image tensor into a PIL image.

  Only the first item of the batch is used. The tensor is indexed as
  (batch, channel, height, width) and presumably holds values in [-1, 1]
  (TODO confirm against the generator output).

  Values are mapped with ``x * 127.5 + 128``, clamped to [0, 255] and cast to
  uint8. This is the same conversion the render cells in this notebook use; the
  previous ``+ 127`` offset combined with truncation could never produce 255.

  Args:
    t: Image tensor with a leading batch dimension.

  Returns:
    A ``PIL.Image`` built from the first image in the batch.
  """
  pixels = (t.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
  return PIL.Image.fromarray(pixels[0].cpu().numpy())

#stack an array of PIL images horizontally
def add_imgs(images):
  """Lay out PIL images side by side, left to right, on one new RGB image.

  The result is as wide as all inputs together and as tall as the tallest
  input. Every image is pasted with its top edge at y = 0.

  Args:
    images: Sequence of objects with a ``size`` of (width, height).

  Returns:
    The combined ``PIL.Image``.
  """
  sizes = [im.size for im in images]
  canvas_width = sum(width for width, _ in sizes)
  canvas_height = max(height for _, height in sizes)

  canvas = PIL.Image.new('RGB', (canvas_width, canvas_height))

  cursor = 0
  for im, (width, _) in zip(images, sizes):
    canvas.paste(im, (cursor, 0))
    cursor += width
  return canvas


def apply_mask(matrix, mask, fill_value):
    """Replace the entries of ``matrix`` selected by ``mask`` with ``fill_value``.

    Args:
        matrix: Array-like input data.
        mask: Boolean array; true marks the entries to replace.
        fill_value: Value written at the masked positions.

    Returns:
        The array returned by ``numpy.ma``'s ``filled()`` for that masked array.
    """
    masked_view = np.ma.array(matrix, mask=mask, fill_value=fill_value)
    result = masked_view.filled()
    return result
 
def apply_threshold(matrix, low_value, high_value):
    """Saturate ``matrix`` to the range [low_value, high_value].

    Entries below ``low_value`` become ``low_value``; afterwards entries above
    ``high_value`` become ``high_value``.

    Args:
        matrix: Array to saturate.
        low_value: Lower bound.
        high_value: Upper bound.

    Returns:
        The saturated array as produced by ``apply_mask``.
    """
    # Raise everything under the floor first ...
    floored = apply_mask(matrix, matrix < low_value, low_value)
    # ... then cap whatever is still above the ceiling.
    return apply_mask(floored, floored > high_value, high_value)

# A simple color correction script to brighten overly dark images
def simplest_cb(img, percent):
    """Apply a per-channel percentile colour balance to a 3-channel image.

    For every channel, ``percent / 2`` percent of the pixels are saturated at
    each end of the sorted value range, then the channel is rescaled to
    [0, 255] with ``cv2.normalize`` (``NORM_MINMAX``).

    Args:
        img: Array of shape (height, width, 3).
        percent: Total percentage of pixels to saturate; must satisfy
            ``0 < percent < 100``.

    Returns:
        The balanced image, merged back from the processed channels.

    Raises:
        AssertionError: If ``img`` does not have 3 channels or ``percent`` is
            outside (0, 100).
    """
    assert img.shape[2] == 3
    assert percent > 0 and percent < 100

    # Half of the saturated fraction is taken from each tail.
    half_percent = percent / 200.0

    out_channels = []
    for channel in cv2.split(img):
        assert len(channel.shape) == 2

        # Sorted 1-D view of the channel, used to look up percentile values.
        flat = np.sort(channel.reshape(-1))
        assert len(flat.shape) == 1
        n_cols = flat.shape[0]

        # Clamp the low index at 0. Without the clamp a small image or a small
        # `percent` gives floor(...) - 1 == -1, which indexes the LAST (largest)
        # element and would flatten the whole channel to its maximum.
        low_idx = max(math.floor(n_cols * half_percent) - 1, 0)
        high_idx = math.ceil(n_cols * (1.0 - half_percent)) - 1
        low_val = flat[low_idx]
        high_val = flat[high_idx]

        # saturate below the low percentile and above the high percentile
        thresholded = apply_threshold(channel, low_val, high_val)
        # scale the channel
        normalized = cv2.normalize(thresholded, thresholded.copy(), 0, 255, cv2.NORM_MINMAX)
        out_channels.append(normalized)

    return cv2.merge(out_channels)
 
def normalize(inf, thresh):
    """Colour-balance an image with ``simplest_cb`` and return it as a PIL image.

    Args:
        inf: Input image; converted with ``np.array`` before processing.
        thresh: Percentage passed to ``simplest_cb`` as ``percent``.

    Returns:
        A ``PIL.Image`` built from the balanced array.
    """
    balanced = simplest_cb(np.array(inf), thresh)
    return PIL.Image.fromarray(balanced)

def get_w_from_path(w_path):
  """Load a projected latent from disk and return it as a batched tensor.

  The file at ``w_path`` is read with ``np.load``. For an ``.npz`` archive the
  array stored under the key ``'w'`` is used (the key the render cells in this
  notebook read); a plain ``.npy`` array is used as is. The first entry along
  axis 0 is taken, converted to a tensor on the module-level ``device`` and
  given a new leading batch dimension.

  Previously the function ignored ``w_path`` and read the global
  ``projected_w_path`` instead, so it failed or loaded the wrong file depending
  on which cells had been run.

  Args:
    w_path: Path to the ``.npz`` or ``.npy`` file.

  Returns:
    ``torch.Tensor`` on ``device`` with shape ``(1, *array[0].shape)``.
  """
  loaded = np.load(w_path)
  if hasattr(loaded, 'files'):
    # .npz archive: pull out the 'w' entry and release the file handle.
    try:
      w_np = loaded['w']
    finally:
      loaded.close()
  else:
    w_np = loaded
  projected_w_np = w_np[0]
  return torch.tensor(projected_w_np).to(device).unsqueeze(0)

def synthesize_tensor_from_w(G, w):
  """Run ``G.synthesis`` on ``w`` with constant noise and forced fp32.

  Args:
    G: Generator exposing a ``synthesis`` callable.
    w: Latent input passed to ``G.synthesis`` unchanged.

  Returns:
    Whatever ``G.synthesis`` returns.
  """
  synthesis_options = {'noise_mode': 'const', 'force_fp32': True}
  return G.synthesis(w, **synthesis_options)

def synthesize_img_from_w(G, w):
  """Synthesize from latent ``w`` and convert the result to a PIL image.

  Args:
    G: Generator passed to ``synthesize_tensor_from_w``.
    w: Latent passed to ``synthesize_tensor_from_w``.

  Returns:
    The image returned by ``t2i`` for the synthesized tensor.
  """
  return t2i(synthesize_tensor_from_w(G, w))

def synthesize_tensor_from_w_path(G, w_path):
  """Load a latent with ``get_w_from_path`` and synthesize from it.

  Args:
    G: Generator passed to ``synthesize_tensor_from_w``.
    w_path: Path handed to ``get_w_from_path``.

  Returns:
    The tensor returned by ``synthesize_tensor_from_w``.
  """
  return synthesize_tensor_from_w(G, get_w_from_path(w_path))

def synthesize_img_from_w_path(G, w_path):
  """Load a latent from ``w_path``, synthesize from it and return a PIL image.

  This function used to be defined twice in a row with identical bodies; the
  second definition silently replaced the first. Only one definition is kept.

  Args:
    G: Generator passed to ``synthesize_tensor_from_w_path``.
    w_path: Path of the latent file.

  Returns:
    The image returned by ``t2i`` for the synthesized tensor.
  """
  tensor = synthesize_tensor_from_w_path(G, w_path)
  return t2i(tensor)

def synthesize_img_from_w_np(G, w_np):
  """Synthesize a PIL image from a latent given as a NumPy array.

  ``w_np`` is converted to a tensor on the module-level ``device`` and given a
  leading batch dimension before synthesis.

  The previous version passed the result of ``synthesize_img_from_w`` (already
  a PIL image) to ``t2i`` a second time, which fails because ``t2i`` expects a
  tensor. The image is now returned directly.

  Args:
    G: Generator passed to ``synthesize_img_from_w``.
    w_np: Latent array without a batch dimension.

  Returns:
    The image returned by ``synthesize_img_from_w``.
  """
  w = torch.tensor(w_np).to(device).unsqueeze(0)
  return synthesize_img_from_w(G, w)



class color:
    """ANSI escape sequences for styling printed text.

    Prefix a string with one of the codes and append ``END`` to reset,
    e.g. ``color.BOLD + 'text' + color.END``.
    """

    # Foreground colours
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'

    # Text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Reset all styling
    END = '\033[0m'

In [None]:
#@title Blend Functions

import os
import copy
import numpy as np
import torch
import pickle
import dnnlib
import legacy

def extract_conv_names(model, model_res):
    """List the names of all parameters of ``model``.

    Args:
        model: Object exposing ``named_parameters()`` yielding (name, value) pairs.
        model_res: Accepted for call compatibility; not used.

    Returns:
        List of parameter names in ``named_parameters()`` order.
    """
    return [param_name for param_name, _ in model.named_parameters()]

def blend_models(low, high, model_res, resolution, level, blend_width=None, blend_mask=None):
    """Combine the weights of two generators into a new model.

    A deep copy of ``low`` is the starting point; ``low`` and ``high`` are not
    modified. Parameters whose name contains ``'mapping'`` always stay as in
    ``low``.

    * ``blend_mask is None`` (hard switch): parameters of the synthesis blocks
      ``synthesis.b4`` ... ``synthesis.b<resolution>`` stay as in ``low``; every
      other non-mapping parameter is copied from ``high``.
    * ``blend_mask`` given (per-block mix): each parameter of block number
      ``i`` (0 -> b4, 1 -> b8, ...) becomes
      ``low * (1 - blend_mask[i]) + high * blend_mask[i]``. Parameters outside
      those blocks stay as in ``low``.

    Args:
        low: Model providing the base weights.
        high: Model providing the weights that are copied or mixed in. Must have
            exactly the same parameter names as ``low``.
        model_res: Only forwarded to ``extract_conv_names``.
        resolution: Highest block resolution included in the block list.
        level: Unused; kept for call compatibility.
        blend_width: Unused; kept for call compatibility.
        blend_mask: Optional sequence of mix factors, one per block resolution
            from 4 up to ``resolution``.

    Returns:
        A new model with the blended weights.

    Raises:
        AssertionError: If the parameter names of ``low`` and ``high`` differ.
        IndexError: If ``blend_mask`` has no entry for a block that has parameters.
    """
    # Block resolutions 4, 8, 16, ... up to and including `resolution`.
    resolutions = [4 * 2 ** x for x in range(int(np.log2(resolution) - 1))]
    # The trailing dot keeps e.g. 'synthesis.b4' from also matching 'synthesis.b4096'.
    block_prefixes = [f'synthesis.b{res}.' for res in resolutions]

    low_names = extract_conv_names(low, model_res)
    high_names = extract_conv_names(high, model_res)

    # Compare the complete lists. zip() stops at the shorter list, so the old
    # pairwise check let models with a different number of parameters through.
    assert low_names == high_names

    # start with lower model and add weights above
    model_out = copy.deepcopy(low)
    dict_dest = model_out.state_dict()

    for name, param in high.named_parameters():
        if 'mapping' in name:
            continue

        # Index of the synthesis block this parameter belongs to, if any.
        block_idx = next(
            (idx for idx, prefix in enumerate(block_prefixes) if prefix in name),
            None,
        )

        if blend_mask is None:
            # Hard switch: only parameters outside the listed blocks come from `high`.
            if block_idx is None:
                dict_dest[name].data.copy_(param.data)
        elif block_idx is not None:
            mask_val = blend_mask[block_idx]
            next_data = dict_dest[name].data * (1 - mask_val) + param.data * mask_val
            dict_dest[name].data.copy_(next_data)

    model_out.load_state_dict(dict_dest)

    return model_out

In [None]:
#@title Select Models {run: "auto"}
#@markdown Select a pretrained model for the source and destination or paste links to your own
#@markdown <br/>(Note: destination must be fine-tuned from source and both must be StyleGAN2 pkl format)
source_model = "FFHQ_256" #@param ["FFHQ_256", "FFHQ_256_slim"] {allow-input: true}
destination_model = "Butteflys_0048" #@param ["Butteflys_0016", "Butteflys_0032", "Butteflys_0048", "Butteflys_0677", "Ukiyoe_256_slim"] {allow-input: true}


# Maps the names offered in the form above to pickle locations.
# NOTE(review): the "Butteflys_*" spelling is used both here and in the form
# options, so the two must be changed together if it is ever corrected.
model_keys = {
    # Relative paths: resolved against the current working directory.
    "Butteflys_0016": "pretrained/butterflys_000016.pkl",
    "Butteflys_0032": "pretrained/butterflys_000032.pkl",
    "Butteflys_0048": "pretrained/butterflys_000048.pkl",
    "Butteflys_0677": "pretrained/butterflys_000677.pkl",
    "FFHQ_256": "pretrained/ffhq_256.pkl", 

    # Absolute Google Drive paths: these two require Drive to be mounted.
    "FFHQ_256_slim": "/content/drive/MyDrive/stylegan2-ada-pytorch-adam/pretrained/stylegan2-ffhq-slim.pkl",
    "Ukiyoe_256_slim": "/content/drive/MyDrive/stylegan2-ada-pytorch-adam/pretrained/ukiyoe-256-slim.pkl"
}

# A value that is not a known key is used as typed (a custom path or link).
lo_res_pkl = model_keys[source_model] if source_model in model_keys  else source_model
hi_res_pkl = model_keys[destination_model] if destination_model in model_keys else destination_model
# Defaults shared by the blend cells below. `model_res` is passed to
# blend_models as the model resolution; `level` and `blend_width` are passed
# through but blend_models does not use them.
model_res = 256
level = 0
blend_width=None

# Extra keyword arguments for legacy.load_network_pkl; empty here.
G_kwargs = dnnlib.EasyDict()

# Source ("low") model: keep the whole pickle plus its G, D and G_ema entries.
with dnnlib.util.open_url(lo_res_pkl) as f:
    # G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
    lo = legacy.load_network_pkl(f, custom=False, **G_kwargs) # type: ignore
    lo_G, lo_D, lo_G_ema = lo['G'], lo['D'], lo['G_ema']

# Destination ("high") model: only the G_ema entry is kept, so `hi` is a
# network rather than a dict.
with dnnlib.util.open_url(hi_res_pkl) as f:
    # G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
    hi = legacy.load_network_pkl(f, custom=False, **G_kwargs)['G_ema'] # type: ignore
    # (G and D of the destination model are not loaded.)


### Project Face

In [None]:
# Run the repo's align_images.py with its default arguments.
# NOTE(review): presumably this reads from input_images/raw and writes aligned
# faces to input_images/aligned — confirm against the script.
!python align_images.py 

In [None]:
resume_network=

!python train.py --outdir=results --data=/content/drive/MyDrive/stylegan2-ada-pytorch-adam/input_images/raw/pokemon_256.zip \
  --gpus=1 --cfg=paper256 --mirror=1 --snap=1 --aug=noaug --metrics=none --resume=$lo_res_pkl


In [None]:
import time 
# ts stores the time in seconds; it makes every projection run write to its own folder
ts = int(time.time())

# Output folder name: projections/<timestamp>_<source pickle file name>/
network_name = os.path.basename(lo_res_pkl)
projection_outdir = f"projections/{ts}_{network_name}/"

num_steps = 1001 #@param {type: "slider", min: 1, max: 10000, step: 1}
uploaded_file_path = "/content/drive/MyDrive/stylegan2-ada-pytorch-adam/input_images/aligned/mona_real_01.png" #@param {type: "string"} 

# Project the target image with the source network. The later cells load a
# projected_w.npz from a folder like this one.
# NOTE(review): the interpolated values are not quoted, so a path containing
# spaces would break the command.
!python projector.py --outdir=$projection_outdir --target=$uploaded_file_path --num-steps=$num_steps --save-video=false \
  --network=$lo_res_pkl

### Network Blend Basic

In [None]:
#@title Select Blend Layer {run: "auto"}
device = "cuda"

#@markdown **Select source vector**
projected_w_path = "/content/drive/MyDrive/stylegan2-ada-pytorch-adam/projections/p1655515028/projected_w.npz" #@param {type: "string"}
use_projected_w = False #@param {type:"boolean"}
seed=5601 #@param {type: "slider", min: 0, max: 10000, step: 1}

#@markdown ---

switch_layer = 8 #@param [4, 8, 16, 32, 64, 128]  {type:"raw"}
blend_width = 0 #@param {type: "slider", min: 0, max: 5, step: 0.01}

# Hard switch: blocks up to `switch_layer` come from the source model, the rest
# from the destination model.
model_out = blend_models(lo_G_ema, hi, model_res, switch_layer, level, blend_width=blend_width)

G1, G2, G_blend = lo_G_ema.to(device), hi.to(device), model_out.to(device)

# Latent to render: a saved projection, or a seeded random z mapped by the
# source model.
if not use_projected_w:
  label = torch.zeros([1, G1.c_dim], device=device)
  z = torch.from_numpy(np.random.RandomState(seed).randn(1, G1.z_dim)).to(device)

  w = G1.mapping(z, None, truncation_psi=0.8, truncation_cutoff=8)
else:
  w_np = np.load(projected_w_path)['w']
  w = torch.tensor(w_np).to(device)

# Render the same latent with the source, destination and blended generators.
_rendered = {}
for _key, _gen in (('g1', G1), ('g2', G2), ('g3', G_blend)):
  _raw = _gen.synthesis(w, noise_mode='const', force_fp32=True)
  _u8 = (_raw.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
  _rendered[_key] = (_u8, PIL.Image.fromarray(_u8[0].cpu().numpy(), 'RGB'))

g1_img, g1_imgfile = _rendered['g1']
g2_img, g2_imgfile = _rendered['g2']
g3_img, g3_imgfile = _rendered['g3']

# Left to right: source | blend | destination
display(add_imgs([g1_imgfile, g3_imgfile, g2_imgfile]))

### Fine Tune Blend

In [None]:
#@title Select Blend Layer {run: "auto"}
device = "cuda"

#@markdown **Select source vector**
projected_w_path = "/content/drive/MyDrive/stylegan2-ada-pytorch-adam/projections/1655519505_ffhq_256.pkl/projected_w.npz" #@param {type: "string"}
use_projected_w = True #@param {type:"boolean"}
seed=2294 #@param {type: "slider", min: 0, max: 10000, step: 1}

#@markdown ---

blend_4 = 0 #@param {type: "slider", min: 0, max: 1, step: 0.01}
blend_8 = 0.2 #@param {type: "slider", min: 0, max: 1,  step: 0.01}
blend_16 = 0.47 #@param {type: "slider", min: 0, max: 1, step: 0.01}
blend_32 = 1 #@param {type: "slider", min: 0, max: 1, step: 0.01}
blend_64 = 1 #@param {type: "slider", min: 0, max: 1, step: 0.01}
blend_128 = 1 #@param {type: "slider", min: 0, max: 1, step: 0.01}
blend_256 = 1 #@param {type: "slider", min: 0, max: 1, step: 0.01}

# One mix factor per synthesis block, ordered from 4x4 up to 256x256.
blend_mask = [blend_4, blend_8, blend_16, blend_32, blend_64, blend_128, blend_256]
print(blend_mask)

# Per-block mix over the full model resolution.
model_out = blend_models(lo_G_ema, hi, model_res, model_res, level, blend_width=blend_width, blend_mask=blend_mask)

G1, G2, G_blend = lo_G_ema.to(device), hi.to(device), model_out.to(device)

# Latent to render: a saved projection, or a seeded random z mapped by the
# source model.
if not use_projected_w:
  label = torch.zeros([1, G1.c_dim], device=device)
  z = torch.from_numpy(np.random.RandomState(seed).randn(1, G1.z_dim)).to(device)

  w = G1.mapping(z, None, truncation_psi=0.8, truncation_cutoff=8)
else:
  w_np = np.load(projected_w_path)['w']
  w = torch.tensor(w_np).to(device)

# Render the same latent with the source, destination and blended generators.
_rendered = {}
for _key, _gen in (('g1', G1), ('g2', G2), ('g3', G_blend)):
  _raw = _gen.synthesis(w, noise_mode='const', force_fp32=True)
  _u8 = (_raw.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
  _rendered[_key] = (_u8, PIL.Image.fromarray(_u8[0].cpu().numpy(), 'RGB'))

g1_img, g1_imgfile = _rendered['g1']
g2_img, g2_imgfile = _rendered['g2']
g3_img, g3_imgfile = _rendered['g3']

# Left to right: source | blend | destination
display(add_imgs([g1_imgfile, g3_imgfile, g2_imgfile]))

### Experimental Overblending
Choose blend values outside the 0-1 range (the sliders below allow -10 to 10).

In [None]:
#@title Select Blend Layer {run: "auto"}
device = "cuda"

#@markdown **Select source vector**
projected_w_path = "/content/drive/MyDrive/stylegan2-ada-pytorch-adam/out/256_styglegan2_ada_3/projected_w.npz" #@param {type: "string"}
use_projected_w = False #@param {type:"boolean"}
seed=2973 #@param {type: "slider", min: 0, max: 10000, step: 1}

#@markdown ---

blend_4 = -0.78 #@param {type: "slider", min: -10, max: 10, step: 0.01}
blend_8 = 0.86 #@param {type: "slider", min: -10, max: 10, step: 0.01}
blend_16 = 0.79 #@param {type: "slider", min: -10, max: 10, step: 0.01}
blend_32 = -0.19 #@param {type: "slider", min: -10, max: 10, step: 0.01}
blend_64 = 0.16 #@param {type: "slider", min: -10, max: 10, step: 0.01}
blend_128 = 0.01 #@param {type: "slider", min: -10, max: 10, step: 0.01}
blend_256 = 0.19 #@param {type: "slider", min: -10, max: 10, step: 0.01}

# One mix factor per synthesis block, ordered from 4x4 up to 256x256. Factors
# below 0 or above 1 push the weights beyond either model.
blend_mask = [blend_4, blend_8, blend_16, blend_32, blend_64, blend_128, blend_256]
print(blend_mask)

# Per-block mix over the full model resolution.
model_out = blend_models(lo_G_ema, hi, model_res, model_res, level, blend_width=blend_width, blend_mask=blend_mask)

G1, G2, G_blend = lo_G_ema.to(device), hi.to(device), model_out.to(device)

# Latent to render: a saved projection, or a seeded random z mapped by the
# source model.
if not use_projected_w:
  label = torch.zeros([1, G1.c_dim], device=device)
  z = torch.from_numpy(np.random.RandomState(seed).randn(1, G1.z_dim)).to(device)

  w = G1.mapping(z, None, truncation_psi=0.8, truncation_cutoff=8)
else:
  w_np = np.load(projected_w_path)['w']
  w = torch.tensor(w_np).to(device)

# Render the same latent with the source, destination and blended generators.
_rendered = {}
for _key, _gen in (('g1', G1), ('g2', G2), ('g3', G_blend)):
  _raw = _gen.synthesis(w, noise_mode='const', force_fp32=True)
  _u8 = (_raw.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
  _rendered[_key] = (_u8, PIL.Image.fromarray(_u8[0].cpu().numpy(), 'RGB'))

g1_img, g1_imgfile = _rendered['g1']
g2_img, g2_imgfile = _rendered['g2']
g3_img, g3_imgfile = _rendered['g3']

# Left to right: source | blend | destination
display(add_imgs([g1_imgfile, g3_imgfile, g2_imgfile]))