# StarGAN

### Image-to-Image translation samples

| Hair Color | Age | Hair Color + Gender |
|:----:|:----:|:----:|
|  ![black haired female](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/sample_black_haired_female.png "black haired female")  ![blond haired female](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/sample_blond_haired_female.png "blond haired female")<br>&ensp;&ensp;**Black Hair** to **Blond Hair**&ensp;&ensp; | ![black haired young male](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/sample_young_man.png "young male")  ![aged male](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/sample_aged_man.png "aged male") <br>&ensp;&ensp;**Young** to **Aged**&ensp;&ensp; | ![black haired male](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/sample_black_haired_male.png "black haired male")  ![blond haired female](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/sample_blond_haired_female2.png "blond haired female") <br>**Black Hair/Male** to **Blond Hair/Female**|

### From one domain to multi-domains

|  Source<br>Black-Haired<br>Young Female |  Generated 1<br>Blond-Haired<br>Young Female  |  Generated 2<br>Brown-Haired<br>Young Female  |  Generated 3<br>Black-Haired<br>Young Male  |  Generated 4<br>Black-Haired<br>Aged Female |
| ---- | ---- | ---- | ---- | ---- |
|  ![black haired female](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/source_Black_Hair_Young.png "black haired young female")  |  ![blond haired female](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/generated_Blond_Hair_Young.png "blond haired young female")  |  ![brown haired female](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/generated_Brown_Hair_Young.png "brown haired young female")  |  ![black haired male](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/generated_Black_Hair_Male_Young.png "black haired young male")  |  ![black haired aged female](https://github.com/sony/nnabla-examples/raw/master/GANs/stargan/imgs/generated_Black_Hair.png "black haired aged female")  |

Here, we run [StarGAN](https://arxiv.org/abs/1711.09020), an Image-to-Image translation model based on [CycleGAN](https://arxiv.org/abs/1703.10593). Unlike CycleGAN, which translates images from one domain to another, StarGAN can translate from one domain to multiple domains with *one single model*. In this demo, we take a look at how it works.
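
To give a feel for how one model can serve several domains: in the StarGAN paper, the generator receives the image together with a target-domain label that is replicated over the spatial grid and concatenated along the channel axis. The following is a minimal NumPy sketch of that input construction only. The function name `concat_label` and the order of the 5 flags are assumptions made for illustration, not taken from the nnabla implementation.

```python
import numpy as np

def concat_label(image, label):
    """Tile a 1-D domain label over the spatial grid and append it as extra channels.

    image: array of shape (C, H, W)
    label: 1-D array of length n_domains (e.g. binary attribute flags)
    """
    image = np.asarray(image, dtype=np.float32)
    label = np.asarray(label, dtype=np.float32)
    _, h, w = image.shape
    # (n_domains,) -> (n_domains, H, W): every pixel sees the same label
    label_map = np.broadcast_to(label[:, None, None], (label.size, h, w))
    return np.concatenate([image, label_map], axis=0)

image = np.zeros((3, 128, 128), dtype=np.float32)
# Assumed flag order: Black Hair, Blond Hair, Brown Hair, Male, Young
blond_young_female = [0, 1, 0, 0, 1]
x = concat_label(image, blond_young_female)
print(x.shape)  # (8, 128, 128)
```

Changing only the label while keeping the same image and the same weights is what lets a single generator produce different target domains.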

# Preparation
Let's start by installing nnabla and cloning the [nnabla-examples repository](https://github.com/sony/nnabla-examples). If you're running on Colab, make sure the runtime type is set to GPU from the top menu (Runtime → Change runtime type), and click **Connect** at the top right of the screen before you start.

In [None]:
!pip install nnabla-ext-cuda100
!git clone https://github.com/sony/nnabla-examples.git
%run nnabla-examples/interactive-demos/colab_utils.py
%cd nnabla-examples/GANs/stargan

We need to download the pretrained weights and the required config file.

In [None]:
# get StarGAN pretrained weights.
!wget https://nnabla.org/pretrained-models/nnabla-examples/GANs/stargan/pretrained_params_on_celebA.h5

# get StarGAN config file.
!wget https://nnabla.org/pretrained-models/nnabla-examples/GANs/stargan/pretrained_conf_on_celebA.json

# Upload an image

Run the cell below to upload an image. Select just one image (if you upload several, all but the last are ignored), and make sure it contains exactly **one** face.

In [None]:
from google.colab import files

img = files.upload()

For convenience, rename the image file. Also, let's check the input image here.

In [None]:
import os
ext = os.path.splitext(list(img.keys())[-1])[-1]
os.rename(list(img.keys())[-1], "input_image{}".format(ext)) 

input_img = "input_image" + ext

from IPython.display import Image,display
display(Image(input_img))
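
For reference, `files.upload()` returns a dictionary keyed by the uploaded filenames, which is why the cell above takes the last key. The sketch below reproduces that renaming logic on a stand-in dictionary so it can be run outside Colab. The helper name `target_name` and the fake filenames are hypothetical.

```python
import os

def target_name(uploaded, stem="input_image"):
    """Return (original_name, new_name) for the last uploaded file."""
    original = list(uploaded.keys())[-1]   # dicts keep insertion order
    ext = os.path.splitext(original)[-1]   # keeps the original extension, e.g. ".PNG"
    return original, stem + ext

# Stand-in for the dict returned by files.upload(): filename -> file bytes
fake_upload = {"a.jpg": b"first", "portrait.PNG": b"second"}
print(target_name(fake_upload))  # ('portrait.PNG', 'input_image.PNG')
```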

Since the model expects input images to contain only the facial region, we need to crop the image. Here we use dlib for face detection and cropping. First, we'll download the required dlib weights.

In [None]:
# get dlib's face detection model.
!wget http://dlib.net/files/mmod_human_face_detector.dat.bz2
!bzip2 -d mmod_human_face_detector.dat.bz2

Then import the dependencies.

In [None]:
import cv2
import dlib
import numpy as np
from skimage import io, color
import matplotlib.pyplot as plt

Now we use dlib to detect the face in the image.

In [None]:
image = io.imread(input_img)
if image.ndim == 2:
    image = color.gray2rgb(image)
elif image.shape[-1] == 4:
    image = image[..., :3]

face_detector = dlib.cnn_face_detection_model_v1("mmod_human_face_detector.dat")
detected_faces = face_detector(cv2.cvtColor(image[..., ::-1].copy(), cv2.COLOR_BGR2GRAY))
detected_faces = [[d.rect.left(), d.rect.top(), d.rect.right(), d.rect.bottom()] for d in detected_faces]

assert len(detected_faces) == 1, "Warning: only one face should be contained."
detected_faces = detected_faces[0]
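
The first lines of the cell above make sure the image has exactly three channels before detection: grayscale images are expanded to RGB and an alpha channel is dropped. Here is a NumPy-only sketch of the same normalization. The helper name `to_rgb` is hypothetical, and `np.stack` stands in for `color.gray2rgb`.

```python
import numpy as np

def to_rgb(image):
    """Return an (H, W, 3) array from a grayscale, RGB or RGBA image."""
    image = np.asarray(image)
    if image.ndim == 2:
        # Grayscale: repeat the single channel three times
        return np.stack([image] * 3, axis=-1)
    if image.shape[-1] == 4:
        # RGBA: drop the alpha channel
        return image[..., :3]
    return image

gray = np.zeros((4, 5), dtype=np.uint8)
rgba = np.zeros((4, 5, 4), dtype=np.uint8)
print(to_rgb(gray).shape, to_rgb(rgba).shape)  # (4, 5, 3) (4, 5, 3)
```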

Here, with some helper functions, we extract the facial region only. These functions are taken from the [FAN example](https://github.com/sony/nnabla-examples/tree/master/facial-keypoint-detection/face-alignment), with partial modifications.

In [None]:
def transform(point, center, scale, resolution, invert=False):
    """Apply an affine transformation to a 2D point.
    Given a point, a center, a scale and a target resolution, the function
    builds an affine transformation matrix and applies it to the point.
    If invert is ``True`` it applies the inverse transformation instead.
    Arguments:
        point {list} -- the input 2D point
        center {numpy.array} -- the center around which to perform the transformations
        scale {float} -- the scale of the face/object
        resolution {float} -- the output resolution
    Keyword Arguments:
        invert {bool} -- define whether the function should apply the direct or the
        inverse transformation matrix (default: {False})
    """
    # Homogeneous coordinates; copy so the caller's list is not modified.
    point = list(point) + [1]

    h = 200.0 * scale
    t = np.eye(3)
    t[0, 0] = resolution / h
    t[1, 1] = resolution / h
    t[0, 2] = resolution * (-center[0] / h + 0.5)
    t[1, 2] = resolution * (-center[1] / h + 0.5)

    if invert:
        t = np.reshape(np.linalg.inv(np.reshape(t, [1, 3, 3])), [3, 3])

    new_point = np.reshape(np.matmul(
        np.reshape(t, [1, 3, 3]), np.reshape(point, [1, 3, 1])), [3, ])[0:2]

    return new_point.astype(int)


def crop(image, center, scale, resolution=256):
    """Crop an image around a center point and resize it.
    Arguments:
        image {numpy.array} -- an RGB image of shape (H, W, 3)
        center {numpy.array} -- the center of the object, usually the same as of the bounding box
        scale {float} -- scale of the face
    Keyword Arguments:
        resolution {int} -- the size of the output cropped image (default: {256})
    Returns:
        numpy.array -- the cropped image, resized to resolution x resolution
    """
    ul = transform([1, 1], center, scale, resolution, True)
    br = transform([resolution, resolution], center, scale, resolution, True)

    if image.ndim > 2:
        newDim = np.array([br[1] - ul[1], br[0] - ul[0],
                           image.shape[2]], dtype=np.int32)
        newImg = np.zeros(newDim, dtype=np.uint8)
    else:
        # np.int was removed in NumPy 1.24; use an explicit integer dtype.
        newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
        newImg = np.zeros(newDim, dtype=np.uint8)
    ht = image.shape[0]
    wd = image.shape[1]
    newX = np.array(
        [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
    newY = np.array(
        [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
    oldX = np.array(
        [int(max(1, ul[0] + 1)), int(min(br[0], wd))], dtype=np.int32)
    oldY = np.array(
        [int(max(1, ul[1] + 1)), int(min(br[1], ht))], dtype=np.int32)

    newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
           ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]

    newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
                        interpolation=cv2.INTER_LINEAR)
    return newImg
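
To see which region `crop` actually cuts out, the inverse of the matrix built in `transform` can be written in closed form: an output pixel `q` maps back to `center + h * (q / resolution - 0.5)` with `h = 200 * scale`. The sketch below evaluates this for the two corner points used in `crop`. The helper name `crop_window` is hypothetical, and unlike `transform` it keeps the result as floats instead of truncating to integers.

```python
import numpy as np

def crop_window(center, scale, resolution=128):
    """Return the (upper_left, bottom_right) corners of the crop window, as floats."""
    h = 200.0 * scale                      # side length of the reference square
    c = np.asarray(center, dtype=float)
    upper_left = c + h * (1.0 / resolution - 0.5)    # inverse transform of (1, 1)
    bottom_right = c + h * (resolution / resolution - 0.5)  # inverse transform of (resolution, resolution)
    return upper_left, bottom_right

ul, br = crop_window(center=(100, 120), scale=1.0, resolution=128)
print(ul.tolist(), br.tolist())  # [1.5625, 21.5625] [200.0, 220.0]
```

In other words, the window is a square of roughly `200 * scale` pixels per side centered on `center`, which is then resized to `resolution` x `resolution`.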

With the scripts above, crop the uploaded image.

In [None]:
center = [detected_faces[2] - (detected_faces[2] - detected_faces[0]) / 2.0,
          detected_faces[3] - (detected_faces[3] - detected_faces[1]) / 2.0]
#center[1] = center[1] - (detected_faces[3] - detected_faces[1]) * 0.12
scale = (detected_faces[2] - detected_faces[0] + detected_faces[3] - detected_faces[1]) / 195
inp = crop(image, center, scale, resolution=128)
plt.imshow(inp)
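
The `center` and `scale` values in the cell above depend only on the detected bounding box. Here is the same arithmetic as a small standalone function with a worked example. The helper name `bbox_to_center_scale` and the example box are made up for illustration.

```python
def bbox_to_center_scale(box):
    """Convert [left, top, right, bottom] into (center, scale) as used by crop()."""
    left, top, right, bottom = box
    width, height = right - left, bottom - top
    center = [right - width / 2.0, bottom - height / 2.0]
    # 195 is the normalization constant used in the cell above.
    scale = (width + height) / 195
    return center, scale

center_example, scale_example = bbox_to_center_scale([50, 60, 150, 180])
print(center_example, round(scale_example, 3))  # [100.0, 120.0] 1.128
```

For this 100 x 120 box, `200 * scale` is about 226 pixels, so the crop window is roughly twice the average side of the detected box and leaves some margin around the face.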

Now, save this cropped image and move it to a new directory named "source_img". This image will be used as the input to StarGAN.

In [None]:
import os
import shutil
io.imsave("cropped_image.png", inp)
source_dir = "source_img"
os.makedirs(source_dir, exist_ok=True)
shutil.move("cropped_image.png", os.path.join(source_dir, "input_image.png"))

# Now run StarGAN!
Now that we have prepared all the required files, let's run StarGAN. In this demo, you can choose among 5 attributes: `Black Hair`, `Blond Hair`, `Brown Hair`, `Male` and `Young`.

For each attribute, you will be asked whether you want to *add* that attribute to the input image. You need to type `yes` or `no`.


For instance, you will first be asked whether or not to `use 'Black_Hair'`. If you type `yes`, the model will try to change the input's hair color to black (note that the model doesn't take the original hair color into account, so hair that is already black may still be modified).


If you type `no`, that attribute will not be added. 


As for the 4th and 5th attributes, `Male` and `Young`, if you answer `no`, the model will add the *opposite* attribute, that is, `female` or `aged` respectively.
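
One way to picture these prompts is as a list of 0/1 flags, one per attribute, where a `0` for `Male` or `Young` reads as the opposite attribute. The sketch below is hypothetical: the function `answers_to_label` and the attribute spellings other than `Black_Hair` are assumptions, and the actual `generate.py` may build its label differently.

```python
# Assumed attribute names and order; only 'Black_Hair' appears verbatim in the prompt above.
ATTRIBUTES = ["Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"]

def answers_to_label(answers):
    """Map {'attribute': 'yes'/'no'} to a list of 0/1 flags in ATTRIBUTES order."""
    label = []
    for name in ATTRIBUTES:
        reply = answers.get(name, "no").strip().lower()
        if reply not in ("yes", "no"):
            raise ValueError(f"expected 'yes' or 'no' for {name}, got {reply!r}")
        label.append(1 if reply == "yes" else 0)
    return label

# Blond hair, female (Male = 0), young
print(answers_to_label({"Blond_Hair": "yes", "Young": "yes"}))  # [0, 1, 0, 0, 1]
```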


Seeing is believing, just give it a shot!

In [None]:
!python generate.py --pretrained-params pretrained_params_on_celebA.h5 --config pretrained_conf_on_celebA.json --test-image-path source_img

In [None]:
import glob
generated_img = sorted(glob.glob("tmp.results/*.png"), key=os.path.getmtime)[-1]
#print(generated_img)
display(Image(generated_img))
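
The cell above picks the most recently written PNG by sorting on modification time. Here is a small sketch of the same idea that also handles the case where no file matches, for example when generation failed. The helper name `latest_file` is hypothetical, and the temporary files only serve to make the example runnable.

```python
import glob
import os
import tempfile

def latest_file(pattern):
    """Return the most recently modified path matching pattern, or None if nothing matches."""
    paths = glob.glob(pattern)
    return max(paths, key=os.path.getmtime) if paths else None

with tempfile.TemporaryDirectory() as tmp:
    for i, name in enumerate(["old.png", "new.png"]):
        path = os.path.join(tmp, name)
        open(path, "wb").close()
        os.utime(path, (1000 + i, 1000 + i))  # set explicit, increasing mtimes
    newest = os.path.basename(latest_file(os.path.join(tmp, "*.png")))
    missing = latest_file(os.path.join(tmp, "*.jpg"))

print(newest, missing)  # new.png None
```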

# Try with your face
Now, if your machine has a webcam, Colab can access it so that you can capture your own image. As in the previous demo, after some preprocessing such as face detection and cropping, you can try StarGAN on your own face! Let's see what happens.
Executing the following cell will enable the camera. Just press `Capture` and the captured image will be saved.

In [None]:
from IPython.display import Image
try:
    filename = take_photo(cam_width=256, cam_height=256)
    print('Saved to {}'.format(filename))
    # Show the image which was just taken.
    display(Image(filename))
except Exception as err:
    # Errors will be thrown if the user does not have a webcam or if they do not
    # grant the page permission to access it.
    print(str(err))

In the following cell, the captured image will be cropped.

In [None]:
image = io.imread("photo.png")
if image.ndim == 2:
    image = color.gray2rgb(image)
elif image.shape[-1] == 4:
    image = image[..., :3]

face_detector = dlib.cnn_face_detection_model_v1("mmod_human_face_detector.dat")
detected_faces = face_detector(cv2.cvtColor(image[..., ::-1].copy(), cv2.COLOR_BGR2GRAY))
detected_faces = [[d.rect.left(), d.rect.top(), d.rect.right(), d.rect.bottom()] for d in detected_faces]

assert len(detected_faces) == 1, "Warning: only one face should be contained."
detected_faces = detected_faces[0]

center = [detected_faces[2] - (detected_faces[2] - detected_faces[0]) / 2.0,
          detected_faces[3] - (detected_faces[3] - detected_faces[1]) / 2.0]
#center[1] = center[1] - (detected_faces[3] - detected_faces[1]) * 0.12
scale = (detected_faces[2] - detected_faces[0] + detected_faces[3] - detected_faces[1]) / 195
inp = crop(image, center, scale, resolution=128)
plt.imshow(inp)

import os
import shutil
io.imsave("cropped_image.png", inp)
source_dir = "source_img"
os.makedirs(source_dir, exist_ok=True)
shutil.move("cropped_image.png", os.path.join(source_dir, "input_image.png"))

If the image shown above looks OK, just execute the following cell! It will run `StarGAN`.

In [None]:
!python generate.py --pretrained-params pretrained_params_on_celebA.h5 --config pretrained_conf_on_celebA.json --test-image-path source_img

In [None]:
generated_img = sorted(glob.glob("tmp.results/*.png"), key=os.path.getmtime)[-1]
#print(generated_img)
display(Image(generated_img))