Skip to content

Commit

Permalink
Do not reload model on each image deepfakes#39
Browse files Browse the repository at this point in the history
  • Loading branch information
Clorr committed Jan 5, 2018
1 parent 0eb8527 commit 232e1ec
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 29 deletions.
9 changes: 4 additions & 5 deletions lib/cli.py
Expand Up @@ -46,18 +46,17 @@ def process_arguments(self, arguments):

self.images_found = len(self.input_dir)

self.process_directory()
self.process(self.read_directory)
self.finalize()

def process_directory(self):
def read_directory(self):
for filename in self.input_dir:
if self.arguments.verbose:
print('Processing: {}'.format(os.path.basename(filename)))

self.process_image(filename)
yield filename
self.images_processed = self.images_processed + 1

self.finalize()

# for now, we limit this class responsability to the read of files. images and faces are processed outside this class
def process_image(self, filename):
# implement your image processing!
Expand Down
32 changes: 20 additions & 12 deletions scripts/convert.py
Expand Up @@ -31,27 +31,35 @@ def add_optional_arguments(self, parser):
help="Swap the model. Instead of A -> B, swap B -> A.")
return parser

def process_image(self, filename):
model_name = "GAN" # Original # GAN
conv_name = "GAN" # Adjust, Masked # GAN
def process(self, reader):
# Original model goes with Adjust or Masked converter
# GAN converter & model must go together
model_name = "GAN" # TODO Pass as argument
conv_name = "GAN" # TODO Pass as argument

if conv_name.startswith("GAN"):
assert model_name.startswith("GAN") is True, "GAN converter can only be used with GAN model!"
else:
assert model_name.startswith("GAN") is False, "GAN model can only be used with GAN converter!"

model = PluginLoader.get_model(model_name)(self.arguments.model_dir)
model.load(self.arguments.swap_model)

converter = PluginLoader.get_converter(conv_name)(model.converter(False))

try:
image = cv2.imread(filename)
for (idx, face) in enumerate(detect_faces(image)):
if idx > 0 and self.arguments.verbose:
print('- Found more than one face!')
self.verify_output = True
for filename in reader():
image = cv2.imread(filename)
for (idx, face) in enumerate(detect_faces(image)):
if idx > 0 and self.arguments.verbose:
print('- Found more than one face!')
self.verify_output = True

image = converter.patch_image(image, face)
self.faces_detected = self.faces_detected + 1
image = converter.patch_image(image, face)
self.faces_detected = self.faces_detected + 1

output_file = self.output_dir / Path(filename).name
cv2.imwrite(str(output_file), image)
output_file = self.output_dir / Path(filename).name
cv2.imwrite(str(output_file), image)
except Exception as e:
print('Failed to convert image: {}. Reason: {}'.format(filename, e))

Expand Down
24 changes: 13 additions & 11 deletions scripts/extract.py
Expand Up @@ -15,19 +15,21 @@ def create_parser(self, subparser, command, description):
https://github.com/deepfakes/faceswap-playground"
)

def process_image(self, filename):
extractor = PluginLoader.get_extractor("Align")()
def process(self, reader):
extractor_name = "Align" # TODO Pass as argument
extractor = PluginLoader.get_extractor(extractor_name)()

try:
image = cv2.imread(filename)
for (idx, face) in enumerate(detect_faces(image)):
if idx > 0 and self.arguments.verbose:
print('- Found more than one face!')
self.verify_output = True
for filename in reader():
image = cv2.imread(filename)
for (idx, face) in enumerate(detect_faces(image)):
if idx > 0 and self.arguments.verbose:
print('- Found more than one face!')
self.verify_output = True

resized_image = extractor.extract(image, face, 256)
output_file = self.output_dir / Path(filename).stem
cv2.imwrite(str(output_file) + str(idx) + Path(filename).suffix, resized_image)
self.faces_detected = self.faces_detected + 1
resized_image = extractor.extract(image, face, 256)
output_file = self.output_dir / Path(filename).stem
cv2.imwrite(str(output_file) + str(idx) + Path(filename).suffix, resized_image)
self.faces_detected = self.faces_detected + 1
except Exception as e:
print('Failed to extract from image: {}. Reason: {}'.format(filename, e))
2 changes: 1 addition & 1 deletion scripts/train.py
Expand Up @@ -76,7 +76,7 @@ def add_optional_arguments(self, parser):
return parser

def process(self):
variant = "GAN"
variant = "GAN" # TODO Pass as argument

print('Loading data, this may take a while...')
model = PluginLoader.get_model(variant)(self.arguments.model_dir)
Expand Down

0 comments on commit 232e1ec

Please sign in to comment.