<a href="https://colab.research.google.com/github/TA-aiacademy/course_3.0/blob/v2-5_gan/08_v2-5_GAN/Part4/01_Stylegan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# StyleGAN

### 本章節內容大綱
* [StyleGAN](#StyleGAN)
* [StyleGAN in Anime dataset (by Gwern)](#StyleGAN-in-Anime-dataset-(by-Gwern))

這個章節要來 demo 2019 年初才釋出 weight 的 StyleGAN。本次的教學使用從 https://github.com/NVlabs/stylegan clone 下來的 weight 以及 code，再做一些修改：由於該模型是用 TF 1.x 版本訓練的，助教在這邊修改了一些版本相關的細節。學員如果要在本機端直接 clone 原始的 GitHub repo，記得使用 TF 1.x 版本執行；或是將這份教材複製到本機端用 TF 2.x 執行也是可以的。

StyleGAN Generator的架構，可以理解為前面幾層是在勾勒輪廓，後面是在畫精細的細節。

<img src="https://hackmd.io/_uploads/HJp5gETga.jpg" width=500  />



StyleGAN 承襲了 ProgressiveGAN 的 Discriminator 架構（逐步降採樣的卷積網路）。注意它和 PatchGAN 不同：Discriminator 最後輸出的是整張圖片的單一真偽分數，而不是每個 patch 各自的分數。

<img src="https://hackmd.io/_uploads/SkMpg4Tx6.png" width=500  />

In [None]:
# Download and extract the course assets for this part (GAN_part4.zip).
# These lines are IPython shell escapes (`!`), so this cell only runs inside Jupyter/Colab.
!wget -q https://github.com/TA-aiacademy/course_3.0/releases/download/v2.5_gan/GAN_part4.zip
!unzip -q GAN_part4.zip

In [None]:
import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import matplotlib.pyplot as plt

import imageio
import glob
from IPython.display import display, Image
import cv2

### 讀入產生高畫質人臉圖片的權重

In [None]:
# Local path of the pickled FFHQ 1024x1024 StyleGAN snapshot (presumably
# extracted from GAN_part4.zip above). The variable keeps its original name.
url = 'cache/2019stylegan-ffhq-1024x1024_mod.pkl'

# Initialise TensorFlow through dnnlib.tflib before the networks are unpickled.
tflib.init_tf()

# The pickle holds three networks; only Gs is used by the cells below
# (_G and _D are kept but unused here).
# NOTE: unpickling can execute arbitrary code -- only load trusted snapshots.
with open(url, 'rb') as f:
    networks = pickle.load(f)
_G, _D, Gs = networks

### 產生隨機的圖片

In [None]:
# Sample one latent vector from a standard normal distribution with a fixed
# seed, so the generated image is reproducible. Shape: (1, Gs.input_shape[1]).
rnd = np.random.RandomState(420)
latents = rnd.standard_normal(size=(1, Gs.input_shape[1]))

In [None]:
# Convert the generator output to uint8 images with channels last (NHWC) so
# matplotlib can display them directly.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.run(
    latents,
    None,
    truncation_psi=0.7,
    randomize_noise=True,
    output_transform=fmt,
)

# Show the last image of the generated batch.
plt.figure(figsize=(10, 10))
plt.imshow(images[-1])
plt.show()

### 試著一次改變vector中的一個element來看看有什麼變化吧

In [None]:
# Baseline latent for the sweep below: one row filled with ones,
# one entry per generator input dimension.
latents = np.ones(shape=(1, Gs.input_shape[1]), dtype=np.float64)

In [None]:
# Render the all-ones baseline latent as a uint8 NHWC image.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.run(
    latents,
    None,
    truncation_psi=0.7,
    randomize_noise=True,
    output_transform=fmt,
)

# Display the last image of the batch.
plt.figure(figsize=(10, 10))
plt.imshow(images[-1])
plt.show()

In [None]:
# Sweep ONE element of the all-ones latent vector and save one frame per value.
# The original note says the vector has 512 dimensions. The (0-based) element
# index is parsed from the trailing number of save_path, so './exp_img/lat_18'
# sweeps element 18; change the suffix to sweep another element.
save_path = './exp_img/lat_18'

os.makedirs(save_path, exist_ok=True)

ind = int(save_path.split('_')[-1])

# Work on a copy: writing into the shared `latents` array would leave the
# swept element at its last sweep value, so re-running this cell with another
# index would no longer start from the all-ones baseline.
sweep_latents = latents.copy()

# Output conversion is loop-invariant, so build it once.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)

# Values from -15 to 15.5 in steps of 0.5.
# NOTE(review): randomize_noise=True draws new noise for every frame, so some
# frame-to-frame change is not caused by the swept element.
for i in np.arange(-15, 16, 0.5):
    sweep_latents[0][ind] = i
    images = Gs.run(sweep_latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

    plt.figure(figsize=(5, 5))
    plt.imshow(images[-1])

    plt.savefig(os.path.join(save_path, 'image_{:03f}.png'.format(i)))

    plt.show()
    plt.close()  # release the figure so open figures don't pile up over the sweep

下面的 gif 可以方便我們觀察：改動不同的 element，會對 output 造成什麼影響。

In [None]:
# Assemble the frames saved by the sweep in save_path into an animated GIF.
anim_file = save_path + '/anim.gif'

# Order frames by modification time, i.e. the order the sweep wrote them.
filenames = glob.glob(save_path + '/image*.png')
filenames.sort(key=os.path.getmtime)
if not filenames:
    # Without this guard the final append below would raise NameError, or
    # silently reuse a `filename` left over from another cell.
    raise FileNotFoundError('no image*.png frames found in {}'.format(save_path))

with imageio.get_writer(anim_file, mode='I') as writer:
    # Keep a frame only when round(2*sqrt(i)) moves past the last kept value.
    # Early frames are sampled densely and later ones progressively more
    # sparsely (roughly 2*sqrt(N) frames in total).
    # NOTE(review): this thins out the end of the sweep; use uniform sampling
    # if every sweep value should get equal screen time.
    last = -1
    for i, filename in enumerate(filenames):
        frame = 2*(i**0.5)
        if round(frame) <= round(last):
            continue
        last = frame
        writer.append_data(imageio.imread(filename))
    # Always append the final frame so the GIF ends on the last sweep value.
    writer.append_data(imageio.imread(filenames[-1]))

display(Image(filename=anim_file))

In [None]:
# Show the GIF for a sweep of latent element 0 (the first element).
# The file must already exist, e.g. by running the sweep above with
# save_path = './exp_img/lat_0' (or, presumably, shipped in GAN_part4.zip).
display(Image(filename='./exp_img/lat_0/anim.gif'))

In [None]:
# Show the GIF for a sweep of latent element 255 (the 256th element, 0-based index 255).
# The file must already exist under './exp_img/lat_255'.
display(Image(filename='./exp_img/lat_255/anim.gif'))

In [None]:
# Show the GIF for a sweep of latent element 511, the last of the 512 elements.
# The file must already exist under './exp_img/lat_511'.
display(Image(filename='./exp_img/lat_511/anim.gif'))

# 風格混合(style mixing)

StyleGAN 不只像是一般的 GAN 能隨機生成一張逼真的圖片，因為它一層層疊加的結構，讓它有辦法可以做各種不同細緻度的風格轉換。

In [None]:
def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges):
    """Render a StyleGAN style-mixing grid and save it to ``png``.

    Layout: the top row holds the images for ``src_seeds`` (from column 1 on)
    and the left column holds the images for ``dst_seeds`` (from row 1 on).
    The cell at (row, col) is destination image ``row`` with the dlatent
    layers in ``style_ranges[row]`` replaced by those of source image ``col``.

    Args:
        png: output file path, passed to PIL's ``Image.save``.
        Gs: generator network; its ``components.mapping`` and
            ``components.synthesis`` are run.
        w, h: grid pitch in pixels. Images are pasted at multiples of
            (w, h), so these should equal the generated image size.
        src_seeds, dst_seeds: seeds for ``np.random.RandomState``. Each seed
            yields one standard-normal latent of length ``Gs.input_shape[1]``.
        style_ranges: one index sequence per destination seed (indexed by
            row), selecting the dlatent layers copied from the sources.

    Note:
        Reads the module-level ``synthesis_kwargs`` dict, which must be
        defined before this function is called.
    """
    print(png)
    latent_dim = Gs.input_shape[1]
    # np.stack needs a sequence such as a list. Passing a generator was
    # deprecated in NumPy 1.16 and raises TypeError in newer NumPy versions.
    src_latents = np.stack([np.random.RandomState(seed).randn(latent_dim) for seed in src_seeds])
    dst_latents = np.stack([np.random.RandomState(seed).randn(latent_dim) for seed in dst_seeds])
    src_dlatents = Gs.components.mapping.run(src_latents, None)  # [seed, layer, component]
    dst_dlatents = Gs.components.mapping.run(dst_latents, None)  # [seed, layer, component]
    src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
    dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)

    canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white')
    # Top row: the unmixed source images.
    for col, src_image in enumerate(src_images):
        canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))
    for row, dst_image in enumerate(dst_images):
        # Left column: the unmixed destination image for this row.
        canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))
        # One copy of this destination's dlatents per source, with the
        # selected layers overwritten by each source's dlatents.
        row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))
        row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]
        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
        for col, image in enumerate(row_images):
            canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))
    canvas.save(png)

# Keyword arguments forwarded to every synthesis run inside
# draw_style_mixing_figure: uint8 NHWC output, minibatch_size=8.
synthesis_kwargs = dict(
    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
    minibatch_size=8,
)

# Directory that receives the rendered grids.
result_dir = 'results'
if not os.path.exists(result_dir):
    os.makedirs(result_dir)

# Two destination rows per block of four dlatent layers: rows 0-1 swap layers
# 0-3, rows 2-3 layers 4-7, rows 4-5 layers 8-11, rows 6-7 layers 12-15.
draw_style_mixing_figure(
    os.path.join(result_dir, 'style-mixing-human8.png'),
    Gs,
    w=1024,
    h=1024,
    src_seeds=[639, 701, 687, 615, 2268],
    dst_seeds=[888, 829, 1898, 1733, 1614, 845, 1450, 2266],
    style_ranges=[range(lo, lo + 4) for lo in (0, 4, 8, 12) for _ in range(2)],
)

### 轉換結果
下圖的第一個 row 是隨機產生的來源圖片(source image)，第一個 column 是隨機產生的目標圖片(destination image)，透過將目標圖片的部分層的「中層潛在向量」(intermediate latent vector)替換成來源圖片的向量層，就可以達到變換風格的效果。下面的例子是以兩個 row 為一單位，每單位分別是變換第 0-3 層的向量、4-7 層...到第 15 層，每四個層去取代的結果，可以發現前幾層的改變幅度很大，會把整個臉型跟面向都改成另一個風格，然而後面幾層可能開始只改變五官、到最後幾層只改變整個色調細節而已。

In [None]:
# Load the saved human style-mixing grid and display it. OpenCV loads BGR,
# so convert to RGB before handing the array to matplotlib.
mix_img = cv2.imread('results/style-mixing-human8.png')
plt.figure(figsize=(15, 25))
rgb_grid = cv2.cvtColor(mix_img, cv2.COLOR_BGR2RGB)
plt.imshow(rgb_grid)
plt.show()

# StyleGAN in Anime dataset (by Gwern)

看到 Nvidia 釋出如此強大的模型，各路大神也紛紛來試玩看看，而這位 Gwern 用爬蟲抓了一堆動漫的角色圖，前處理後丟進模型訓練，下面是他釋出的pre-train weight，有興趣的學員也可以玩玩看，連結如下:https://www.gwern.net/Faces#

In [None]:
# Local path of Gwern's pickled Danbooru2018-portraits StyleGAN snapshot
# (presumably extracted from GAN_part4.zip). This rebinds Gs for the cells below.
url = 'cache/2019-04-30-stylegan-danbooru2018-portraits-02095-066083_mod.pkl'

# Initialise TensorFlow through dnnlib.tflib before unpickling.
tflib.init_tf()

# Only Gs is used below; _G and _D are kept but unused here.
# NOTE: unpickling can execute arbitrary code -- only load trusted snapshots.
with open(url, 'rb') as f:
    networks = pickle.load(f)
_G, _D, Gs = networks

In [None]:
# Baseline latent for the anime sweep: one row of zeros,
# one entry per generator input dimension.
latents = np.zeros(shape=(1, Gs.input_shape[1]), dtype=np.float64)

In [None]:
# Render the zero latent (truncation_psi=0.5) as a uint8 NHWC image.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.run(
    latents,
    None,
    truncation_psi=0.5,
    randomize_noise=True,
    output_transform=fmt,
)

# Display the last image of the batch.
plt.figure(figsize=(10, 10))
plt.imshow(images[-1])
plt.show()

In [None]:
# Sweep ONE element of the all-zeros anime latent and save one frame per value.
# The original note says the vector has 512 dimensions. The (0-based) element
# index is parsed from the trailing number of save_path, so
# './exp_img_anim/lat_255' sweeps element 255.
save_path = './exp_img_anim/lat_255'

# Turn matplotlib interactive mode off; each frame is saved and closed below
# instead of being shown.
plt.ioff()

os.makedirs(save_path, exist_ok=True)

ind = int(save_path.split('_')[-1])

# Work on a copy: writing into the shared `latents` array would leave the
# swept element at its last sweep value, so re-running this cell with another
# index would no longer start from the all-zeros baseline.
sweep_latents = latents.copy()

# Output conversion is loop-invariant, so build it once.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)

# Values from about -0.004 to 0.004 in steps of 0.0001.
for i in np.arange(-0.004, 0.0041, 0.0001):
    sweep_latents[0][ind] = i
    images = Gs.run(sweep_latents, None, truncation_psi=0.7, randomize_noise=False, output_transform=fmt)

    plt.figure(figsize=(5, 5))
    plt.imshow(images[-1])

    plt.savefig(os.path.join(save_path, 'image_{:03f}.png'.format(i)))

    plt.close()

In [None]:
# Assemble the frames saved by the anime sweep in save_path into an animated GIF.
anim_file = save_path + '/anim.gif'

# Order frames by modification time, i.e. the order the sweep wrote them.
filenames = glob.glob(save_path + '/image*.png')
filenames.sort(key=os.path.getmtime)
if not filenames:
    # Without this guard the final append below would raise NameError, or
    # silently reuse a `filename` left over from another cell (e.g. the
    # human-face GIF cell).
    raise FileNotFoundError('no image*.png frames found in {}'.format(save_path))

with imageio.get_writer(anim_file, mode='I') as writer:
    # Keep a frame only when round(4*sqrt(i)) moves past the last kept value.
    # Early frames are sampled densely and later ones progressively more
    # sparsely.
    # NOTE(review): this thins out the end of the sweep; use uniform sampling
    # if every sweep value should get equal screen time.
    last = -1
    for i, filename in enumerate(filenames):
        frame = 4*(i**0.5)
        if round(frame) <= round(last):
            continue
        last = frame
        writer.append_data(imageio.imread(filename))
    # Always append the final frame so the GIF ends on the last sweep value.
    writer.append_data(imageio.imread(filenames[-1]))

display(Image(filename=anim_file))

In [None]:
# Show the GIF for a sweep of anime latent element 0. The file must already
# exist, e.g. by running the sweep above with save_path = './exp_img_anim/lat_0'.
display(Image(filename='./exp_img_anim/lat_0/anim.gif'))

# Style mixing

這部分與上面的人臉類似，就不再贅述。

In [None]:
# Draw one seeded standard-normal latent of shape (1, Gs.input_shape[1]).
# NOTE(review): the style-mixing cell below draws its own latents from
# per-seed RNGs and does not read this vector.
rnd = np.random.RandomState(0)
latents = rnd.standard_normal(size=(1, Gs.input_shape[1]))

In [None]:
def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges):
    """Render a StyleGAN style-mixing grid and save it to ``png``.

    Layout: the top row holds the images for ``src_seeds`` (from column 1 on)
    and the left column holds the images for ``dst_seeds`` (from row 1 on).
    The cell at (row, col) is destination image ``row`` with the dlatent
    layers in ``style_ranges[row]`` replaced by those of source image ``col``.

    Args:
        png: output file path, passed to PIL's ``Image.save``.
        Gs: generator network; its ``components.mapping`` and
            ``components.synthesis`` are run.
        w, h: grid pitch in pixels. Images are pasted at multiples of
            (w, h), so these should equal the generated image size.
        src_seeds, dst_seeds: seeds for ``np.random.RandomState``. Each seed
            yields one standard-normal latent of length ``Gs.input_shape[1]``.
        style_ranges: one index sequence per destination seed (indexed by
            row), selecting the dlatent layers copied from the sources.

    Note:
        Reads the module-level ``synthesis_kwargs`` dict, which must be
        defined before this function is called.
    """
    print(png)
    latent_dim = Gs.input_shape[1]
    # np.stack needs a sequence such as a list. Passing a generator was
    # deprecated in NumPy 1.16 and raises TypeError in newer NumPy versions.
    src_latents = np.stack([np.random.RandomState(seed).randn(latent_dim) for seed in src_seeds])
    dst_latents = np.stack([np.random.RandomState(seed).randn(latent_dim) for seed in dst_seeds])
    src_dlatents = Gs.components.mapping.run(src_latents, None)  # [seed, layer, component]
    dst_dlatents = Gs.components.mapping.run(dst_latents, None)  # [seed, layer, component]
    src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
    dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)

    canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white')
    # Top row: the unmixed source images.
    for col, src_image in enumerate(src_images):
        canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))
    for row, dst_image in enumerate(dst_images):
        # Left column: the unmixed destination image for this row.
        canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))
        # One copy of this destination's dlatents per source, with the
        # selected layers overwritten by each source's dlatents.
        row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))
        row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]
        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
        for col, image in enumerate(row_images):
            canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))
    canvas.save(png)

# Keyword arguments forwarded to every synthesis run inside
# draw_style_mixing_figure: uint8 NHWC output, minibatch_size=8.
synthesis_kwargs = dict(
    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
    minibatch_size=8,
)

# Reuses result_dir from the human-face style-mixing cell. Two destination
# rows per block of four dlatent layers: 0-3, 4-7, 8-11, 12-15.
draw_style_mixing_figure(
    os.path.join(result_dir, 'style-mixing-anim_8.png'),
    Gs,
    w=512,
    h=512,
    src_seeds=[639, 701, 687, 615, 2268],
    dst_seeds=[888, 829, 1898, 1733, 1614, 845, 1450, 2266],
    style_ranges=[range(lo, lo + 4) for lo in (0, 4, 8, 12) for _ in range(2)],
)

In [None]:
# Load the saved anime style-mixing grid and display it. OpenCV loads BGR,
# so convert to RGB before handing the array to matplotlib.
mix_img = cv2.imread('results/style-mixing-anim_8.png')
plt.figure(figsize=(16, 24))
rgb_grid = cv2.cvtColor(mix_img, cv2.COLOR_BGR2RGB)
plt.imshow(rgb_grid)
plt.show()