combine version
hytseng0509 committed Jul 3, 2018
1 parent fad7375 commit 9f352f5
Showing 10 changed files with 921 additions and 158 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@ datasets/
videos/
logs/
results/
outputs/
build/
dist/
*.png
33 changes: 33 additions & 0 deletions src/dataset_unpair.py → src/dataset.py
@@ -4,6 +4,38 @@
from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, RandomHorizontalFlip, ToTensor, Normalize
import random

class dataset_single(data.Dataset):
  def __init__(self, opts, setname, input_dim):
    self.dataroot = opts.dataroot
    images = os.listdir(os.path.join(self.dataroot, opts.phase + setname))
    self.img = [os.path.join(self.dataroot, opts.phase + setname, x) for x in images]
    self.size = len(self.img)
    self.input_dim = input_dim

    # setup image transformation
    transforms = [Resize(opts.resize_size, Image.BICUBIC)]
    transforms.append(CenterCrop(opts.crop_size))
    transforms.append(ToTensor())
    transforms.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    self.transforms = Compose(transforms)
    print('%d images'%(self.size))
    return

  def __getitem__(self, index):
    data = self.load_img(self.img[index], self.input_dim)
    return data

  def load_img(self, img_name, input_dim):
    img = Image.open(img_name).convert('RGB')
    img = self.transforms(img)
    if input_dim == 1:
      # collapse RGB to one channel with BT.601 luma weights
      img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
      img = img.unsqueeze(0)
    return img

  def __len__(self):
    return self.size

class dataset_unpair(data.Dataset):
  def __init__(self, opts):
    self.dataroot = opts.dataroot
@@ -26,6 +58,7 @@ def __init__(self, opts):
    transforms = [Resize(opts.resize_size, Image.BICUBIC)]
    if opts.phase == 'train':
      transforms.append(RandomCrop(opts.crop_size))
      #transforms.append(CenterCrop(opts.crop_size))
    else:
      transforms.append(CenterCrop(opts.crop_size))
    if not opts.no_flip:
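For context, a minimal usage sketch of the new dataset_single class (not part of this commit); the opts fields and dataset path below are hypothetical stand-ins for what the parsed TestOptions would provide:

# Usage sketch (not part of this commit): drive dataset_single with a
# hand-built namespace standing in for the parsed TestOptions.
from types import SimpleNamespace
import torch
from dataset import dataset_single

opts = SimpleNamespace(dataroot='../datasets/portrait', phase='test',
                       resize_size=256, crop_size=216)  # hypothetical values
dataset = dataset_single(opts, 'A', input_dim=3)  # reads <dataroot>/testA
loader = torch.utils.data.DataLoader(dataset, batch_size=1)
img = next(iter(loader))
print(img.shape)  # torch.Size([1, 3, 216, 216]) after resize + center crop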
53 changes: 53 additions & 0 deletions src/interpolate.py
@@ -0,0 +1,53 @@
import torch
from options import TestOptions
from dataset import dataset_single
from model import DRIT, DRIT_concat
from saver import save_imgs
import os

def main():
  # parse options
  parser = TestOptions()
  opts = parser.parse()

  # data loader
  print('\n--- load dataset ---')
  if opts.a2b:
    dataset = dataset_single(opts, 'A', opts.input_dim_a)
  else:
    dataset = dataset_single(opts, 'B', opts.input_dim_b)
  loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=opts.nThreads)

  # model
  print('\n--- load model ---')
  if opts.concat:
    model = DRIT_concat(opts)
  else:
    model = DRIT(opts)
  model.setgpu(opts.gpu)
  model.resume(opts.resume)
  model.eval()

  # directory
  result_dir = os.path.join(opts.result_dir, opts.name)
  if not os.path.exists(result_dir):
    os.mkdir(result_dir)

  # test
  print('\n--- testing ---')
  for idx1, img in enumerate(loader):
    print('{}/{}'.format(idx1, len(loader)))
    img = img.cuda()
    imgs = [img]
    names = ['input']
    with torch.no_grad():
      # hard-coded .npy inputs passed to model.interpolate
      imgs_list = model.interpolate(img, 'gg/6.npy', 'gg/4.npy', a2b=opts.a2b)
    for idx2 in range(len(imgs_list)):
      imgs.append(imgs_list[idx2])
      names.append('output_{}'.format(idx2))
    save_imgs(imgs, names, os.path.join(result_dir, '{}'.format(idx1)))

  return

if __name__ == '__main__':
  main()
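The two .npy arguments suggest that model.interpolate sweeps between two saved latent codes, but that method lives in model.py, which this commit does not show. The following is only a sketch of the likely interpolation step; the function name interpolate_codes and the steps parameter are hypothetical:

# Hypothetical sketch of latent interpolation (not the repo's implementation).
import numpy as np
import torch

def interpolate_codes(path_a, path_b, steps=5):
  # load the two saved endpoint codes
  z0 = torch.from_numpy(np.load(path_a)).float()
  z1 = torch.from_numpy(np.load(path_b)).float()
  # evenly spaced blends from z0 to z1; each blended code would then be
  # decoded together with the content features of the input image
  weights = [i / (steps - 1) for i in range(steps)]
  return [(1 - t) * z0 + t * z1 for t in weights]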
