Merge pull request #1 from CJWBW/cog
Add Docker environment & web demo
avivga committed Oct 13, 2021
2 parents 883ec8e + 90a4087 commit 8d24a95
Showing 3 changed files with 120 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
@@ -5,7 +5,7 @@
> **Abstract:** Image translation methods typically aim to manipulate a set of labeled attributes (given as supervision at training time e.g. domain label) while leaving the unlabeled attributes intact. Current methods achieve either: (i) disentanglement, which exhibits low visual fidelity and can only be satisfied where the attributes are perfectly uncorrelated. (ii) visually-plausible translations, which are clearly not disentangled. In this work, we propose OverLORD, a single framework for disentangling labeled and unlabeled attributes as well as synthesizing high-fidelity images, which is composed of two stages; (i) Disentanglement: Learning disentangled representations with latent optimization. Differently from previous approaches, we do not rely on adversarial training or any architectural biases. (ii) Synthesis: Training feed-forward encoders for inferring the learned attributes and tuning the generator in an adversarial manner to increase the perceptual quality. When the labeled and unlabeled attributes are correlated, we model an additional representation that accounts for the correlated attributes and improves disentanglement. We highlight that our flexible framework covers multiple settings as disentangling labeled attributes, pose and appearance, localized concepts, and shape and texture. We present significantly better disentanglement with higher translation quality and greater output diversity than state-of-the-art methods.

<a href="https://arxiv.org/abs/2103.14017" target="_blank"><img src="https://img.shields.io/badge/arXiv-2103.14017-b31b1b.svg"></a>
<a href="https://arxiv.org/abs/2103.14017" target="_blank"><img src="https://img.shields.io/badge/arXiv-2103.14017-b31b1b.svg"></a> <a href="https://replicate.ai/avivga/overlord"><img src="https://img.shields.io/static/v1?label=Replicate&message=Demo and Docker Image&color=blue"></a>

## Description
A framework for high-fidelity disentanglement of labeled and unlabeled attributes. We support two general cases: (i) The labeled and unlabeled attributes are *approximately uncorrelated*. (ii) The labeled and unlabeled attributes are *correlated*. For this case, we suggest simple forms of transformations for learning *pose-independent* or *localized* correlated attributes, by which we achieve better disentanglement both quantitatively and qualitatively than state-of-the-art methods.
22 changes: 22 additions & 0 deletions cog.yaml
@@ -0,0 +1,22 @@
build:
  gpu: true
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    - "ninja-build"
  python_packages:
    - "cmake==3.21.3"
    - "torch==1.7.1"
    - "torchvision==0.8.2"
    - "numpy==1.20.1"
    - "ipython==7.21.0"
    - "Pillow==8.3.1"
    - "imageio==2.9.0"
    - "tqdm==4.62.3"
    - "tensorboard==2.6.0"
    - "scipy==1.7.1"
  run:
    - pip install dlib

predict: "predict.py:Predictor"
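
The final `predict` entry points Cog at the `Predictor` class defined in predict.py below. As a rough illustration (not part of this commit), the class could also be exercised directly from Python, assuming the pretrained `overlord-ffhq-x256-age` weights and the dlib `shape_predictor_68_face_landmarks.dat` file sit in the working directory; `face.jpg` is a placeholder input path:

    from pathlib import Path
    from predict import Predictor  # the class referenced by cog.yaml

    predictor = Predictor()
    predictor.setup()  # parse the default model arguments
    out_path = predictor.predict(image=Path("face.jpg"), target_age="20-29")
    print(out_path)    # path of the generated PNG
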
97 changes: 97 additions & 0 deletions predict.py
@@ -0,0 +1,97 @@
import sys
import os
sys.path.insert(0, "stylegan2-pytorch")
sys.path.insert(0, "stylegan-encoder")
import tempfile
from pathlib import Path
import argparse
import imageio
import dlib
import PIL.Image
import numpy as np
import cog
from network.training import Model
from ffhq_dataset.face_alignment import image_align
from ffhq_dataset.landmarks_detector import LandmarksDetector


TARGET_AGE = {
    "All": 0,
    "0-9": 1,
    "10-19": 2,
    "20-29": 3,
    "30-39": 4,
    "40-49": 5,
    "50-59": 6,
    "60-69": 7,
    "70-79": 8,
}

PREDICTOR = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
LANDMARKS_DETECTOR = LandmarksDetector("shape_predictor_68_face_landmarks.dat")


class Predictor(cog.Predictor):
    def setup(self):
        manipulate_parser = argparse.ArgumentParser()
        manipulate_parser.add_argument("-bd", "--base-dir", type=str, default=".")
        manipulate_parser.add_argument(
            "-mn", "--model-name", type=str, default="overlord-ffhq-x256-age"
        )
        manipulate_parser.add_argument("-i", "--img-path", type=str)
        manipulate_parser.add_argument(
            "-r", "--reference-img-path", type=str, required=False
        )
        manipulate_parser.add_argument("-o", "--output-img-path", type=str, default="")
        self.args, self.extras = manipulate_parser.parse_known_args()

    @cog.input(
        "image",
        type=Path,
        help="Input facial image. NOTE: the image will be aligned and resized to 256x256",
    )
    @cog.input(
        "target_age",
        type=str,
        options=list(TARGET_AGE.keys()),
        default="All",
        help="Output age",
    )
    def predict(self, image, target_age="All"):
        # Remove leftovers from a previous run.
        if os.path.isfile("aligned.png"):
            os.remove("aligned.png")
        if os.path.isfile("rgb_input.png"):
            os.remove("rgb_input.png")

        input_path = str(image)
        # Webcam input might be RGBA; convert it to RGB first.
        input_img = imageio.imread(input_path)
        if input_img.shape[-1] == 4:
            rgba_image = PIL.Image.open(input_path)
            rgb_image = rgba_image.convert("RGB")
            input_path = "rgb_input.png"
            imageio.imwrite(input_path, rgb_image)

        out_path = Path(tempfile.mkdtemp()) / "out.png"
        self.args.output_img_path = str(out_path)
        model_dir = "overlord-ffhq-x256-age"
        model = Model.load(model_dir)

        # Align and crop the face, then resize it to the model's 256x256 input size.
        align_image(input_path, "aligned.png")
        img = PIL.Image.open("aligned.png")
        img = np.array(img.resize((256, 256)))

        # The model output is a horizontal strip of 9 panels; keep either all
        # age panels (indices 1-8) or only the requested one.
        manipulated_imgs = model.manipulate_by_labels(img)
        manipulated_imgs_res = np.split(manipulated_imgs, 9, axis=1)
        if target_age == "All":
            res = np.concatenate(manipulated_imgs_res[1:], axis=0)
        else:
            res = manipulated_imgs_res[TARGET_AGE[target_age]]

        imageio.imwrite(self.args.output_img_path, res)
        return out_path


def align_image(raw_img_path, aligned_face_path):
    # Detect facial landmarks and write an aligned crop for each detected face
    # (the crop from the last detected face is the one that remains on disk).
    for i, face_landmarks in enumerate(LANDMARKS_DETECTOR.get_landmarks(raw_img_path), start=1):
        image_align(raw_img_path, aligned_face_path, face_landmarks)
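
As a side note (not part of the commit), the array handling at the end of predict() can be sketched standalone: np.split cuts the model's horizontal strip into 9 panels and, for "All", np.concatenate stacks the last 8 vertically. The 256-pixel panel size here is an assumption for illustration:

    import numpy as np

    # Hypothetical stand-in for the model output: 9 panels, 256 px each, side by side.
    strip = np.zeros((256, 256 * 9, 3), dtype=np.uint8)

    panels = np.split(strip, 9, axis=1)            # 9 arrays of shape (256, 256, 3)
    all_ages = np.concatenate(panels[1:], axis=0)  # drop panel 0, stack the rest vertically
    assert all_ages.shape == (256 * 8, 256, 3)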
