[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JuhongPark/HairFastGAN/blob/app/app/MirrorMirrorGAN_Application.ipynb)

In [None]:
# @title Model Load: Colab T4 GPU 기준 12분 가량 소요
# Env Setting
# Install the ninja build tool system-wide (presumably needed to JIT-compile
# custom C++/CUDA extensions used by the model — TODO confirm).
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
!rm ninja-linux.zip
# Remove Colab's bundled example data folder.
!rm -rf sample_data

# Python dependencies: diffusers (Stable Diffusion pipeline), googletrans
# (Korean -> English prompt translation) and the HairFastGAN requirements.
# NOTE(review): the last install duplicates the one done by install_packages()
# further below; one of the two appears redundant.
!pip install diffusers==0.11.1
!pip install transformers scipy ftfy accelerate
!pip install jax==0.4.23 jaxlib==0.4.23
!pip install googletrans==4.0.0-rc1
!pip install pillow==10.0.0 face_alignment dill==0.2.7.1 addict fpie git+https://github.com/openai/CLIP.git -q


# Import
import argparse
import os
import sys
from concurrent.futures import ProcessPoolExecutor
from copy import copy
from functools import cache, wraps
from io import BytesIO
from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import requests
import torch
import torchvision.transforms as T
from diffusers import StableDiffusionPipeline
from googletrans import Translator
from PIL import Image


# Model load
# Clone the HairFastGAN code and make it the working directory; the later
# `from hair_swap import ...` and relative paths such as
# pretrained_models/... rely on this working directory.
!git clone https://github.com/AIRI-Institute/HairFastGAN
%cd HairFastGAN

def install_packages():
    """Install HairFastGAN's Python dependencies via pip (quietly).

    Submitted to a ProcessPoolExecutor below so it runs concurrently with
    download_models().
    """
    !pip install pillow==10.0.0 face_alignment dill==0.2.7.1 addict fpie \
      git+https://github.com/openai/CLIP.git -q

def download_models():
    """Fetch pretrained weights and sample inputs from the Hugging Face repo.

    Clones huggingface.co/AIRI-Institute/HairFastGAN into a nested
    ``HairFastGAN`` folder, pulls its Git LFS files, moves ``pretrained_models``
    and ``input`` into the current working directory, then deletes the clone.
    """
    !git clone https://huggingface.co/AIRI-Institute/HairFastGAN
    # Each `!` line runs in its own shell, so the trailing `cd ..` does not
    # affect the following lines (they still run from the original directory).
    !cd HairFastGAN && git lfs pull && cd ..
    !mv HairFastGAN/pretrained_models pretrained_models
    !mv HairFastGAN/input input
    !rm -rf HairFastGAN

# Install dependencies and download the pretrained weights in parallel.
with ProcessPoolExecutor() as executor:
    setup_jobs = [
        executor.submit(install_packages),
        executor.submit(download_models),
    ]
    # A Future keeps any exception raised for its job until .result() is
    # called; without this, a failed job (e.g. the function could not be sent
    # to the worker) would go unnoticed until a later import or file access.
    # Presumably failures of the `!` shell commands themselves are not raised
    # as Python exceptions, so still check the printed output.
    for job in setup_jobs:
        job.result()

from hair_swap import HairFast, get_parser
from models.Blending import Blending

# Default HairFast model. The parser is kept in `model_args` so later cells
# can parse alternative argument sets from it.
model_args = get_parser()
hair_fast = HairFast(model_args.parse_args([]))

# Stable Diffusion v1.4 in half precision on the GPU; used to synthesize
# Shape/Color reference images from text prompts.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
).to("cuda")

# Translator used to turn Korean hairstyle descriptions into English prompts.
translator = Translator()


################# Function #################
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

def to_tuple(func):
    """Decorate a one-argument function so that list arguments arrive as tuples.

    Lists are unhashable, so a ``functools.cache``-wrapped function cannot be
    called with one directly. Converting the list to a tuple first makes the
    argument usable as a cache key. Any other argument is passed through
    unchanged.

    Args:
        func: Callable taking exactly one positional argument.

    Returns:
        A wrapper with the same single-argument call signature as ``func``.
        Its name and docstring are copied from ``func``.
    """
    @wraps(func)
    def wrapper(arg):
        if isinstance(arg, list):
            arg = tuple(arg)
        return func(arg)
    return wrapper

@to_tuple
@cache
def download_and_convert_to_pil(urls):
    """Download every URL in ``urls`` and return the images as PIL objects.

    Results are memoized per tuple of URLs: ``to_tuple`` turns a list argument
    into a hashable tuple for ``cache``. Repeated requests for the same links
    therefore do not hit the network again.

    Args:
        urls: Sequence (list or tuple) of image URLs.

    Returns:
        List of ``PIL.Image.Image`` objects, in the same order as ``urls``.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-2xx HTTP status.
        PIL.UnidentifiedImageError: If a response body is not a readable image.
    """
    pil_images = []
    for url in urls:
        # Browser-like User-Agent, presumably because some image hosts reject
        # the default requests agent. The timeout (seconds) keeps a stalled
        # host from hanging the cell indefinitely.
        response = requests.get(
            url,
            allow_redirects=True,
            headers={"User-Agent": "Mozilla/5.0"},
            timeout=30,
        )
        # Fail with the HTTP status instead of handing an error page to PIL.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        pil_images.append(img)
        print(f"Downloaded an image of size {img.size}")
    return pil_images

def display_images(images=None, **kwargs):
    """Show several images side by side in a single matplotlib row.

    Pass either a positional sequence of images (drawn without titles) or
    keyword arguments mapping a title to each image, e.g.
    ``display_images(face=img1, result=img2)``.

    Each image may be a ``torch.Tensor`` (converted with
    ``torchvision.transforms.functional.to_pil_image``), a ``str``/``Path``
    file path (opened with PIL), or anything ``plt.imshow`` accepts.

    Args:
        images: Optional sequence of images. When empty or omitted, the
            keyword arguments are used instead and their names become titles.
        **kwargs: Title -> image mapping, used when ``images`` is not given.
    """
    # Titles come from kwargs whenever no positional images were supplied.
    # Checking only "images is None" would mishandle images=[] plus kwargs by
    # iterating the kwargs keys as if they were images.
    is_titles = not images
    items = list(kwargs.items()) if is_titles else [(None, img) for img in images]
    if not items:
        # Nothing to draw; a 1x0 subplot grid would raise.
        return

    fig = plt.figure(figsize=(20, 10))
    for i, (title, img) in enumerate(items, start=1):
        if isinstance(img, torch.Tensor):
            img = T.functional.to_pil_image(img)
        elif isinstance(img, (str, Path)):
            img = Image.open(img)

        ax = fig.add_subplot(1, len(items), i)
        ax.imshow(img)
        if title:
            ax.set_title(title, fontsize=20)
        ax.axis('off')

    plt.show()

# Set to True once any input is fetched from a URL. In 'Auto' alignment mode
# the generation cell aligns the inputs whenever this flag is set.
is_any_url = False

def try_download_image(url):
    """Download a single image, reporting failures instead of raising.

    Also sets the module-level ``is_any_url`` flag so the caller knows that at
    least one input came from the web.

    Args:
        url: Image URL.

    Returns:
        The ``PIL.Image.Image`` on success, or ``False`` if the download or
        decoding failed (the error is printed).
    """
    # Without the `global` statement this assignment would only create a
    # local variable, and the module-level flag would never change.
    global is_any_url
    is_any_url = True
    try:
        return download_and_convert_to_pil([url])[0]
    except Exception as e:
        # Best-effort: report the failure and let the caller handle False.
        print(f"Can't download the image from the link {url}")
        print(e)
        return False

def convert_input(inp):
    """Resolve one user input (file name or URL) to a PIL image.

    Names starting with ``http`` are downloaded via ``try_download_image``.
    Other names are looked up under the module-level ``input_path``
    directory, and the opened image is stored in the ``path_to_imgs`` dict.

    Args:
        inp: File name relative to ``input_path``, or an image URL.

    Returns:
        The ``PIL.Image.Image`` on success, or ``False`` if the file is
        missing, cannot be opened, or the download failed. A message is
        printed in each failure case.
    """
    if inp.startswith('http'):
        return try_download_image(inp)

    path = os.path.join(input_path, inp)
    if not os.path.isfile(path):
        # Report a missing file explicitly rather than falling through to None.
        print(f"Can't find the image {inp} at {path}")
        return False
    try:
        img = Image.open(path)
    except Exception as e:
        print(f"Can't open the image {inp}")
        print(e)
        return False
    path_to_imgs[path] = img
    return img

# Registry of HairFast instances keyed by blending-checkpoint name. It is
# created only once, so re-running this cell keeps previously built variants.
if 'hair_fast_instans' not in globals():
    if 'hair_fast' not in globals():
        # No model loaded yet in this session: build one with default args.
        model_args = get_parser()
        hair_fast = HairFast(model_args.parse_args([]))
    hair_fast_instans = {'Default': hair_fast}

# Cache of opened input images keyed by file path (filled by convert_input);
# an existing cache is reused across re-runs.
path_to_imgs = globals().get('path_to_imgs', {})

# Inference options read by the generation cell.
Blending_checkpoint = "Default"  # key into hair_fast_instans: 'Default', 'Alternative_v1' or 'Alternative_v2'
Alignment_images = "Auto"        # 'On': always align; 'Auto': align if an input is not 1024x1024 or came from a URL
Poisson_Blending = "Off"         # 'On': apply poisson_image_blending to the result
Poissons_iters = 115             # passed as maxn= to poisson_image_blending
Poisson_erossion = 15            # passed as dilate_erosion= to poisson_image_blending

# Blending checkpoints available besides the default one.
_ALT_BLENDING_CHECKPOINTS = {
    'Alternative_v1': 'pretrained_models/Blending/checkpoint_old.pth',
    'Alternative_v2': 'pretrained_models/Blending/checkpoint_old2.pth',
}

# Build (once) a HairFast variant that uses the selected blending checkpoint.
if Blending_checkpoint not in hair_fast_instans:
    if Blending_checkpoint not in _ALT_BLENDING_CHECKPOINTS:
        raise ValueError(f'{Blending_checkpoint} not exist')
    new_args = model_args.parse_args(
        ['--blending_checkpoint', _ALT_BLENDING_CHECKPOINTS[Blending_checkpoint]]
    )
    # Shallow copy: the variant shares all attributes (e.g. `net`) with the
    # default instance; only `blend` is rebound, and only on the copy.
    hair_fast_ = copy(hair_fast)
    hair_fast_.blend = Blending(new_args, net=hair_fast_.net)
    hair_fast_instans[Blending_checkpoint] = hair_fast_

def get_image_from_en(prompt_en, path='img', input_path="/content/HairFastGAN/input"):
    """Generate an image from an English prompt with Stable Diffusion and save it.

    Args:
        prompt_en: English text prompt passed to the global ``pipe``.
        path: File stem; the image is written to ``<input_path>/<path>.png``.
        input_path: Directory to save into. Defaults to the HairFastGAN input
            folder that the generation cell reads from.

    Returns:
        The generated ``PIL.Image.Image`` (the first image of the pipeline
        output).
    """
    image = pipe(prompt_en).images[0]
    image.save(os.path.join(input_path, f"{path}.png"))
    return image

def get_image_from_korean(style_kr, path='img'):
    """Generate a hairstyle reference image from a Korean style description.

    The description is machine-translated to English, embedded in a fixed
    passport-photo style prompt, and rendered via ``get_image_from_en``.

    Args:
        style_kr: Hairstyle description in Korean.
        path: File stem forwarded to ``get_image_from_en``.

    Returns:
        The generated ``PIL.Image.Image``.
    """
    translation = translator.translate(style_kr, src='ko', dest='en')
    english_style = translation.text
    return get_image_from_en(
        f'A photograph of a person with a {english_style} hairstyle, featuring prominently styled hair, looking directly into the camera with a neutral expression, similar to a Korean passport photo. The face is centered and fully visible, with ample space above the head, ensuring no part of the hairstyle is cut off. The background is plain and neutral-colored.',
        path,
    )

In [None]:
# @title Hair Generative AI
# @markdown Face 사진 업로드
from google.colab import files
from PIL import Image
import io
import os
import ipywidgets as widgets
from IPython.display import display

# Folder the generation cell reads Face/Shape/Color images from.
save_dir = '/content/HairFastGAN/input'

def upload_and_save(selected_name):
    """Prompt for an image upload in Colab and save it as ``<selected_name>.png``.

    Each uploaded file is converted to RGB unless it is already RGB or RGBA,
    written to ``save_dir`` under the fixed name, and previewed. If several
    files are uploaded at once, each overwrites the previous one, so the last
    file wins.

    Args:
        selected_name: Role of the image ('Face', 'Shape' or 'Color'); used as
            the output file stem.
    """
    for file_name, file_bytes in files.upload().items():
        print(f"업로드된 {selected_name} 파일 이름: {file_name}")
        picture = Image.open(io.BytesIO(file_bytes))
        if picture.mode not in ('RGB', 'RGBA'):
            picture = picture.convert('RGB')
        picture.save(os.path.join(save_dir, f"{selected_name}.png"))
        print('')
        print('*' * 45)
        print(f"아래 이미지가 {selected_name} 이미지로 저장되었습니다.")
        display(picture)

def on_button_click(selected_name):
    """Button callback: run the upload-and-save flow for ``selected_name``."""
    return upload_and_save(selected_name)

# Upload button for the face image; the click event argument is ignored.
face_button = widgets.Button(description="Face 사진 업로드")
face_button.on_click(lambda _clicked: on_button_click('Face'))
display(face_button)

In [None]:
# @markdown Shape 사진 업로드
# Upload button for the hair-shape reference image; the click event argument is ignored.
shape_button = widgets.Button(description="Shape 사진 업로드")
shape_button.on_click(lambda _clicked: on_button_click('Shape'))
display(shape_button)

In [None]:
# @markdown Color 사진 업로드
# Upload button for the hair-color reference image; the click event argument is ignored.
color_button = widgets.Button(description="Color 사진 업로드")
color_button.on_click(lambda _clicked: on_button_click('Color'))
display(color_button)

In [None]:
# @markdown Hair 이미지 생성 - 옵션 메세지 입력 (미입력시, 위에서 업로드된 사진 사용)
# Directory holding Face.png / Shape.png / Color.png (read by convert_input).
input_path = "/content/HairFastGAN/input"
# The face image always comes from the upload cell.
Face = "Face.png"

# Optional Korean description of the target hair shape. If non-empty, a new
# Shape.png is generated from it with Stable Diffusion, overwriting any
# uploaded one; otherwise the uploaded Shape.png is used as-is.
Shape =  ""                           # @param {type:"string"}
if Shape != "":
    print(f'{Shape}의 Shape 이미지를 생성합니다.')
    get_image_from_korean(Shape, 'Shape')
# From here on, Shape holds the file name, not the description.
Shape = "Shape.png"

# Same for the hair-color reference: optional Korean description -> Color.png.
Color = ""                              # @param {type:"string"}
if Color !="":
    print(f'{Color}의 Color 이미지를 생성합니다.')
    get_image_from_korean(Color, 'Color')
Color = "Color.png"


# Load the three inputs; convert_input signals failure with False (or None).
converted_inputs = list(map(convert_input, (Face, Shape, Color)))
failed_inputs = [
    name for name, img in zip((Face, Shape, Color), converted_inputs)
    if img is None or img is False
]
if failed_inputs:
    # Stop with a clear message instead of an AttributeError on `.size` below.
    raise RuntimeError(f"Could not load input image(s): {', '.join(failed_inputs)}")

# Any input that is not exactly 1024x1024 requires alignment.
need_alignment = any(img.size != (1024, 1024) for img in converted_inputs)

# 'On' always aligns; 'Auto' aligns when an input needs it or came from a URL.
use_alignment = Alignment_images == 'On' or (
    Alignment_images == 'Auto' and (need_alignment or is_any_url)
)

print('Hair 이미지 생성을 시작합니다.', file=sys.stderr)
if use_alignment:
    # The aligned call returns the result followed by the aligned inputs,
    # which replace the originals for display.
    result_image, *converted_inputs = hair_fast_instans[Blending_checkpoint](*converted_inputs, align=True)
else:
    result_image = hair_fast_instans[Blending_checkpoint](*converted_inputs)

face_obj, shape_obj, color_obj = converted_inputs

if Poisson_Blending == 'On':
    # NOTE(review): poisson_image_blending is not defined or imported anywhere
    # visible in this notebook; enabling this option raises NameError until
    # it is brought into scope.
    print('Start poisson blending', file=sys.stderr)
    result_image, _ = poisson_image_blending(result_image, face_obj, dilate_erosion=Poisson_erossion, maxn=Poissons_iters)

print('Hair 이미지 생성을 완료했습니다.', file=sys.stderr)
display_images(face=face_obj, shape=shape_obj, color=color_obj, result=result_image)