Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Docker environment & web demo #1

Merged
merged 1 commit into from
Oct 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

> **Abstract:** Image translation methods typically aim to manipulate a set of labeled attributes (given as supervision at training time e.g. domain label) while leaving the unlabeled attributes intact. Current methods achieve either: (i) disentanglement, which exhibits low visual fidelity and can only be satisfied where the attributes are perfectly uncorrelated. (ii) visually-plausible translations, which are clearly not disentangled. In this work, we propose OverLORD, a single framework for disentangling labeled and unlabeled attributes as well as synthesizing high-fidelity images, which is composed of two stages; (i) Disentanglement: Learning disentangled representations with latent optimization. Differently from previous approaches, we do not rely on adversarial training or any architectural biases. (ii) Synthesis: Training feed-forward encoders for inferring the learned attributes and tuning the generator in an adversarial manner to increase the perceptual quality. When the labeled and unlabeled attributes are correlated, we model an additional representation that accounts for the correlated attributes and improves disentanglement. We highlight that our flexible framework covers multiple settings as disentangling labeled attributes, pose and appearance, localized concepts, and shape and texture. We present significantly better disentanglement with higher translation quality and greater output diversity than state-of-the-art methods.

<a href="https://arxiv.org/abs/2103.14017" target="_blank"><img src="https://img.shields.io/badge/arXiv-2103.14017-b31b1b.svg"></a>
<a href="https://arxiv.org/abs/2103.14017" target="_blank"><img src="https://img.shields.io/badge/arXiv-2103.14017-b31b1b.svg"></a> <a href="https://replicate.ai/avivga/overlord"><img src="https://img.shields.io/static/v1?label=Replicate&message=Demo and Docker Image&color=blue"></a>

## Description
A framework for high-fidelity disentanglement of labeled and unlabeled attributes. We support two general cases: (i) The labeled and unlabeled attributes are *approximately uncorrelated*. (ii) The labeled and unlabeled attributes are *correlated*. For this case, we suggest simple forms of transformations for learning *pose-independent* or *localized* correlated attributes, by which we achieve better disentanglement both quantitatively and qualitatively than state-of-the-art methods.
Expand Down
22 changes: 22 additions & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
build:
gpu: true
python_version: "3.8"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
- "ninja-build"
python_packages:
- "cmake==3.21.3"
- "torch==1.7.1"
- "torchvision==0.8.2"
- "numpy==1.20.1"
- "ipython==7.21.0"
- "Pillow==8.3.1"
- "imageio==2.9.0"
- "tqdm==4.62.3"
- "tensorboard==2.6.0"
- "scipy==1.7.1"
run:
- pip install dlib

predict: "predict.py:Predictor"
97 changes: 97 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import sys
import os
sys.path.insert(0, "stylegan2-pytorch")
sys.path.insert(0, "stylegan-encoder")
import tempfile
from pathlib import Path
import argparse
import imageio
import dlib
import PIL.Image
import numpy as np
import cog
from network.training import Model
from ffhq_dataset.face_alignment import image_align
from ffhq_dataset.landmarks_detector import LandmarksDetector


TARGET_AGE = {
"All": 0,
"0-9": 1,
"10-19": 2,
"20-29": 3,
"30-39": 4,
"40-49": 5,
"50-59": 6,
"60-69": 7,
"70-79": 8,
}

PREDICTOR = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
LANDMARKS_DETECTOR = LandmarksDetector("shape_predictor_68_face_landmarks.dat")


class Predictor(cog.Predictor):
def setup(self):

manipulate_parser = argparse.ArgumentParser()
manipulate_parser.add_argument("-bd", "--base-dir", type=str, default=".")
manipulate_parser.add_argument(
"-mn", "--model-name", type=str, default="overlord-ffhq-x256-age"
)
manipulate_parser.add_argument("-i", "--img-path", type=str)
manipulate_parser.add_argument(
"-r", "--reference-img-path", type=str, required=False
)
manipulate_parser.add_argument("-o", "--output-img-path", type=str, default="")
self.args, self.extras = manipulate_parser.parse_known_args()

@cog.input(
"image",
type=Path,
help="input facial image. NOTE: image will be aligned and resized to 256*256",
)
@cog.input(
"target_age",
type=str,
options=list(TARGET_AGE.keys()),
default="All",
help="output age",
)
def predict(self, image, target_age="All"):
if os.path.isfile('aligned.png'):
os.remove('aligned.png')
if os.path.isfile('rgb_input.png'):
os.remove('rgb_input.png')

input_path = str(image)
# webcam input might be rgba, convert to rgb first
input = imageio.imread(input_path)
if input.shape[-1] == 4:
rgba_image = PIL.Image.open(input_path)
rgb_image = rgba_image.convert('RGB')
input_path = 'rgb_input.png'
imageio.imwrite(input_path, rgb_image)

out_path = Path(tempfile.mkdtemp()) / "out.png"
self.args.output_img_path = str(out_path)
model_dir = "overlord-ffhq-x256-age"
model = Model.load(model_dir)
align_image(input_path, 'aligned.png')
img = PIL.Image.open('aligned.png')
img = np.array(img.resize((256, 256)))
manipulated_imgs = model.manipulate_by_labels(img)
manipulated_imgs_res = np.split(manipulated_imgs, 9, axis=1)
if target_age == "All":
res = np.concatenate(manipulated_imgs_res[1:], axis=0)

else:
res = manipulated_imgs_res[TARGET_AGE[target_age]]

imageio.imwrite(self.args.output_img_path, res)
return out_path


def align_image(raw_img_path, aligned_face_path):
for i, face_landmarks in enumerate(LANDMARKS_DETECTOR.get_landmarks(raw_img_path), start=1):
image_align(raw_img_path, aligned_face_path, face_landmarks)