-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
140 additions
and
179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
#!/usr/bin/env python | ||
|
||
import os | ||
import glob | ||
import numpy as np | ||
|
||
import chainer | ||
import chainer.cuda | ||
from chainer import cuda, serializers, Variable | ||
from chainer import training | ||
import chainer.functions as F | ||
import cv2 | ||
import argparse | ||
import common.net as net | ||
import datasets | ||
from PIL import Image | ||
from utils import save_images_grid | ||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser(description='CycleGAN model testing script') | ||
parser.add_argument('--gpu', '-g', type=int, default=0, help='GPU ID (negative value indicates CPU)') | ||
parser.add_argument('--gen_class', '-c', default='Generator_ResBlock_9', help='Default generator class') | ||
parser.add_argument("--load_gen_f_model", default='', help='load generator model') | ||
parser.add_argument("--load_gen_g_model", default='', help='load generator model') | ||
parser.add_argument('--direction', '-d', type=int, default=1, help='direction: 0 for G(X), 1 for F(Y)') | ||
parser.add_argument('--input_channels', type=int, default=3, help='number of input channels') | ||
parser.add_argument('--rows', type=int, default=5, help='rows') | ||
parser.add_argument('--cols', type=int, default=5, help='cols') | ||
parser.add_argument('--eval_folder', '-e', default='evaldata', help='directory to output the evaluation result') | ||
parser.add_argument('--out', '-o', default='output' ,help='saved file name') | ||
parser.add_argument("--resize_to", type=int, default=256, help='resize the image to') | ||
parser.add_argument("--crop_to", type=int, default=256, help='crop the resized image to') | ||
parser.add_argument("--load_dataset", default='silverhair_train', help='load dataset') | ||
parser.add_argument("--recurrent", type=int, default=1, help='apply the function recursively') | ||
|
||
args = parser.parse_args() | ||
print(args) | ||
|
||
if args.gpu >= 0: | ||
chainer.cuda.get_device(args.gpu).use() | ||
|
||
if not os.path.exists(args.eval_folder): | ||
os.makedirs(args.eval_folder) | ||
|
||
gen_g = getattr(net, args.gen_class)() | ||
gen_f = getattr(net, args.gen_class)() | ||
|
||
if args.load_gen_g_model != '': | ||
serializers.load_npz(args.load_gen_g_model, gen_g) | ||
print("Generator G model loaded") | ||
|
||
if args.load_gen_f_model != '': | ||
serializers.load_npz(args.load_gen_f_model, gen_f) | ||
print("Generator F model loaded") | ||
|
||
if args.gpu >= 0: | ||
gen_g.to_gpu() | ||
gen_f.to_gpu() | ||
print("use gpu {}".format(args.gpu)) | ||
|
||
test_dataset = getattr(datasets, args.load_dataset)(flip=0, resize_to=args.resize_to, crop_to=args.crop_to) | ||
|
||
cnt = args.rows * args.cols | ||
xp = gen_g.xp | ||
|
||
input = xp.zeros((cnt, args.input_channels, args.crop_to, args.crop_to)).astype("f") | ||
|
||
for i in range(0, args.rows): | ||
for j in range(0,args.cols): | ||
x, y = test_dataset.get_example(0) | ||
if args.direction == 1: | ||
input[i*args.cols + j, :] = xp.asarray(y) | ||
else: | ||
input[i*args.cols + j, :] = xp.asarray(x) | ||
|
||
input = input | ||
save_images_grid(input,path=args.eval_folder+"/"+args.out+".0.jpg", grid_w=args.rows, grid_h=args.cols) | ||
|
||
for i in range(args.recurrent): | ||
if args.direction == 1: | ||
output = gen_f(input, volatile=True) | ||
else: | ||
output = gen_g(input, volatile=True) | ||
del input | ||
save_images_grid(output,path=args.eval_folder+"/"+args.out+"."+str(i+1)+".jpg", grid_w=args.rows, grid_h=args.cols) | ||
output.unchain_backward() | ||
input = output.data |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .xdog import XDoG | ||
from .save_images import save_images_grid | ||
from .save_images import copy_to_cpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import cv2 | ||
#from PIL import Image | ||
import numpy as np | ||
from chainer import cuda | ||
import chainer | ||
import cupy | ||
import os | ||
|
||
|
||
def copy_to_cpu(imgs): | ||
if type(imgs) == chainer.variable.Variable : | ||
imgs = imgs.data | ||
if type(imgs) == cupy.core.core.ndarray: | ||
imgs = cuda.to_cpu(imgs) | ||
return imgs | ||
|
||
def post_processing_tanh(imgs): | ||
imgs = (imgs + 1) *127.5 | ||
imgs = np.clip(imgs, 0, 255) | ||
imgs = imgs.astype(np.uint8) | ||
return imgs | ||
|
||
# Input imgs format: (batch, channels, width, height) | ||
def save_images_grid(imgs, path, grid_w=4, grid_h=4, post_processing=post_processing_tanh): | ||
imgs = copy_to_cpu(imgs) | ||
if post_processing is not None: | ||
imgs = post_processing(imgs) | ||
b, ch, w, h = imgs.shape | ||
assert b == grid_w*grid_h | ||
|
||
imgs = imgs.reshape((grid_w, grid_h, ch, w, h)) | ||
imgs = imgs.transpose(0, 1, 3, 4, 2) | ||
imgs = imgs.reshape((grid_w, grid_h, w, h, ch)).transpose(0, 2, 1, 3, 4).reshape((grid_w*w, grid_h*h, ch)) | ||
if ch==1: | ||
imgs = imgs.reshape((grid_w*w, grid_h*h)) | ||
cv2.imwrite(path, imgs) | ||
#Image.fromarray(imgs[:,:,::-1]).save(path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters