/
pix2pix_dataset.py
117 lines (93 loc) · 4.25 KB
/
pix2pix_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
from data.base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
import util.util as util
import numpy as np
import scipy.ndimage.measurements as measurements
import torch
import skimage.transform
import skimage.feature
import scipy.misc
import os
class Pix2pixDataset(BaseDataset):
    """Base dataset of paired label maps, real images and optional instance maps.

    Subclasses supply the file lists by overriding ``get_paths``; they may also
    override ``paths_match`` (pairing rule) and ``postprocess`` (final tweak of
    each sample).
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the dataset-specific ``--no_pairing_check`` flag.

        Args:
            parser: argument parser to extend (anything exposing ``add_argument``).
            is_train: unused here; kept for interface compatibility.

        Returns:
            The same ``parser`` object, for chaining.
        """
        parser.add_argument('--no_pairing_check', action='store_true',
                            help='If specified, skip sanity check of correct label-image file pairing')
        return parser

    def initialize(self, opt):
        """Collect, sort, truncate and (optionally) sanity-check the file lists.

        Reads ``opt.no_instance``, ``opt.max_dataset_size`` and
        ``opt.no_pairing_check``. Sets ``self.opt``, ``self.label_paths``,
        ``self.image_paths``, ``self.instance_paths`` and ``self.dataset_size``.

        Raises:
            AssertionError: if pairing checks are enabled and some label/image
                pair fails ``self.paths_match``.
        """
        self.opt = opt

        label_paths, image_paths, instance_paths = self.get_paths(opt)

        # NOTE(review): util.natural_sort is assumed to sort the list in place
        # (its return value is ignored here) -- confirm against util/util.py.
        util.natural_sort(label_paths)
        util.natural_sort(image_paths)
        if not opt.no_instance:
            util.natural_sort(instance_paths)

        # Keep at most max_dataset_size entries of each list.
        label_paths = label_paths[:opt.max_dataset_size]
        image_paths = image_paths[:opt.max_dataset_size]
        instance_paths = instance_paths[:opt.max_dataset_size]

        if not opt.no_pairing_check:
            for path1, path2 in zip(label_paths, image_paths):
                assert self.paths_match(path1, path2), \
                    "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." % (path1, path2)

        self.label_paths = label_paths
        self.image_paths = image_paths
        self.instance_paths = instance_paths

        # The dataset size is the number of label maps. An earlier variant
        # rounded it down to a multiple of the GPU count during training;
        # that rounding is disabled, so the full size is used.
        self.dataset_size = len(self.label_paths)

    def get_paths(self, opt):
        """Return ``(label_paths, image_paths, instance_paths)`` as three lists.

        This base implementation is abstract and must be overridden.

        Raises:
            NotImplementedError: always. An explicit raise is used instead of
                ``assert False`` so the contract still holds under ``python -O``.
        """
        raise NotImplementedError(
            "A subclass of Pix2pixDataset must override self.get_paths(self, opt)")

    def paths_match(self, path1, path2):
        """Return True when both paths share the same file name, ignoring
        directory and extension."""
        filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
        filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
        return filename1_without_ext == filename2_without_ext

    def __getitem__(self, index):
        """Load and transform the sample at ``index``.

        Returns:
            dict with keys ``'label'``, ``'instance'``, ``'image'`` and
            ``'path'`` (the image path). ``'instance'`` is the integer ``0``
            when ``opt.no_instance`` is set.

        Raises:
            AssertionError: if the label and image file names do not match and
                ``--no_pairing_check`` was not given.
        """
        # Label image
        label_path = self.label_paths[index]
        # The context manager releases the file handle even if a transform raises.
        with Image.open(label_path) as label:
            # The same params are reused for the image and the instance map.
            # NOTE(review): presumably so all three receive identical
            # crop/flip -- confirm in data/base_dataset.py.
            params = get_params(self.opt, label.size)
            transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
            # NOTE(review): the * 255.0 assumes the transform scales pixel
            # values by 1/255 -- confirm in get_transform.
            label_tensor = transform_label(label) * 255.0
        label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc

        # Input image (real image)
        image_path = self.image_paths[index]
        # Honour --no_pairing_check here as well; otherwise the bypass offered
        # by initialize() would only postpone the failure to the first fetch.
        # A missing attribute keeps the check enabled (previous behaviour).
        if not getattr(self.opt, 'no_pairing_check', False):
            assert self.paths_match(label_path, image_path), \
                "The label_path %s and image_path %s don't match." % \
                (label_path, image_path)
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)

        # Instance map, if used
        if self.opt.no_instance:
            instance_tensor = 0
        else:
            instance_path = self.instance_paths[index]
            with Image.open(instance_path) as instance:
                if instance.mode == 'L':
                    # 8-bit maps are rescaled by 255 and cast to integer ids.
                    instance_tensor = transform_label(instance) * 255
                    instance_tensor = instance_tensor.long()
                else:
                    instance_tensor = transform_label(instance)

        input_dict = {'label': label_tensor,
                      'instance': instance_tensor,
                      'image': image_tensor,
                      'path': image_path,
                      }

        # Give subclasses a chance to modify the final output. The hook's
        # return value is ignored, so overrides must mutate input_dict in place.
        self.postprocess(input_dict)

        return input_dict

    def postprocess(self, input_dict):
        """Hook for subclasses to adjust a sample; the base version is a no-op.

        ``__getitem__`` ignores the return value, so overrides must modify
        ``input_dict`` in place.
        """
        return input_dict

    def __len__(self):
        """Return the number of samples, as set by ``initialize``."""
        return self.dataset_size