forked from fastai/fastai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
364 lines (313 loc) · 18.1 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
"Manages data input pipeline - folders -> transform -> batch input. Includes support for classification, segmentation and bounding boxes"
from ..torch_core import *
from .image import *
from .transform import *
from ..data_block import *
from ..basic_data import *
from ..layers import *
from .learner import *
from concurrent.futures import ProcessPoolExecutor, as_completed
import PIL, warnings
# Public API of this module: the names made available by `from <this module> import *`.
__all__ = ['get_image_files', 'denormalize', 'get_annotations', 'ImageDataBunch',
'ImageItemList', 'normalize', 'normalize_funcs',
'channel_view', 'mnist_stats', 'cifar_stats', 'imagenet_stats', 'download_images',
'verify_images', 'bb_pad_collate', 'ObjectCategoryProcessor', 'ImageToImageList',
'ObjectCategoryList', 'ObjectItemList', 'SegmentationLabelList', 'SegmentationItemList', 'PointsItemList']
# Every file suffix that the stdlib `mimetypes` table maps to an `image/...` MIME type.
image_extensions = {ext for ext, mime in mimetypes.types_map.items() if mime.startswith('image/')}
def get_image_files(c:PathOrStr, check_ext:bool=True, recurse=False)->FilePathList:
    "List the files in `c` via `get_files`; restricted to `image_extensions` unless `check_ext` is False."
    # `None` means "no extension filter" for `get_files`.
    allowed = image_extensions if check_ext else None
    return get_files(c, extensions=allowed, recurse=recurse)
def get_annotations(fname, prefix=None):
    """Open a COCO style json in `fname` and return the lists of filenames (with maybe `prefix`) and labelled bboxes.

    Returns two parallel lists: the file name of every image that has at least one annotation
    (each prepended with `prefix` when given), and for each of those images `[bboxes, category_names]`.
    A box `[b0, b1, b2, b3]` read from the json is stored as `[b1, b0, b3+b1, b2+b0]`; for COCO
    `[x, y, width, height]` boxes that is `[top, left, bottom, right]`."""
    # Context manager so the file handle is closed deterministically instead of being leaked.
    with open(fname) as f: annot_dict = json.load(f)
    prefix = '' if prefix is None else prefix
    classes = {o['id']: o['name'] for o in annot_dict['categories']}
    id2images, id2bboxes, id2cats = {}, collections.defaultdict(list), collections.defaultdict(list)
    for o in annot_dict['annotations']:
        bb = o['bbox']
        id2bboxes[o['image_id']].append([bb[1],bb[0], bb[3]+bb[1], bb[2]+bb[0]])
        id2cats[o['image_id']].append(classes[o['category_id']])
    for o in annot_dict['images']:
        # Images without any annotation are left out of the result.
        if o['id'] in id2bboxes: id2images[o['id']] = prefix + o['file_name']
    ids = list(id2images.keys())
    return [id2images[k] for k in ids], [[id2bboxes[k], id2cats[k]] for k in ids]
def bb_pad_collate(samples:BatchSamples, pad_idx:int=0) -> Tuple[FloatTensor, Tuple[LongTensor, LongTensor]]:
    """Function that collects `samples` of labelled bboxes and adds padding with `pad_idx`.

    Each sample is `(image, target)` where `image.data` is the image tensor and `target.data` is a
    `(bboxes, labels)` pair. Returns `(images, (bboxes, labels))` where `bboxes` has shape
    `(batch, max_len, 4)` and `labels` has shape `(batch, max_len)`. Samples with fewer than `max_len`
    boxes are padded at the front: zeros for the boxes, `pad_idx` for the labels."""
    max_len = max([len(s[1].data[1]) for s in samples])
    bboxes = torch.zeros(len(samples), max_len, 4)
    labels = torch.zeros(len(samples), max_len).long() + pad_idx
    imgs = []
    for i,s in enumerate(samples):
        imgs.append(s[0].data[None])
        bbs, lbls = s[1].data
        n = len(lbls)
        # With no boxes, `-n:` would be `-0:` (the whole row) and assigning an empty tensor to it
        # fails with a shape mismatch; the row is already fully padded, so leave it alone.
        if n == 0: continue
        bboxes[i,-n:] = bbs
        labels[i,-n:] = lbls
    return torch.cat(imgs,0), (bboxes,labels)
def _maybe_add_crop_pad(tfms):
tfm_names = [tfm.__name__ for tfm in tfms]
return [crop_pad()] + tfms if 'crop_pad' not in tfm_names else tfms
def _prep_tfm_kwargs(tfms, kwargs):
    "Resolve `kwargs['resize_method']` (written back into `kwargs`) and maybe prepend `crop_pad` to `tfms`; return both."
    # Default is SQUISH when an explicit listy `size` is given, CROP otherwise.
    has_listy_size = 'size' in kwargs and is_listy(kwargs['size'])
    fallback = ResizeMethod.SQUISH if has_listy_size else ResizeMethod.CROP
    requested = kwargs.get('resize_method', fallback)
    method = ifnone(requested, fallback)
    if method <= 2: tfms = _maybe_add_crop_pad(tfms)
    kwargs['resize_method'] = method
    return tfms, kwargs
def normalize(x:TensorImage, mean:FloatTensor,std:FloatTensor)->TensorImage:
    "Subtract `mean` from `x` and divide by `std`, both broadcast over the last two axes of `x`."
    # Two trailing unit axes let the stats broadcast against the last two dims of `x`.
    centered = x - mean[...,None,None]
    return centered / std[...,None,None]
def denormalize(x:TensorImage, mean:FloatTensor,std:FloatTensor)->TensorImage:
    "Undo `normalize`: multiply `x` by `std` and add `mean`, both broadcast over the last two axes of `x`."
    scaled = x * std[...,None,None]
    return scaled + mean[...,None,None]
def _normalize_batch(b:Tuple[Tensor,Tensor], mean:FloatTensor, std:FloatTensor, do_y:bool=False)->Tuple[Tensor,Tensor]:
    "Normalize the inputs of batch `b` = (`x`,`y`); with `do_y`, also normalize `y` when it has 4 dimensions."
    inputs, targets = b
    # Stats have to live on the same device as the batch.
    dev = inputs.device
    mean, std = mean.to(dev), std.to(dev)
    inputs = normalize(inputs, mean, std)
    if not do_y: return inputs, targets
    if len(targets.shape) == 4: targets = normalize(targets, mean, std)
    return inputs, targets
def normalize_funcs(mean:FloatTensor, std:FloatTensor, do_y:bool=False)->Tuple[Callable,Callable]:
    "Return a (batch-normalize, denormalize) pair of callables bound to `mean` and `std`; `do_y` is passed to the former."
    stats = dict(mean=tensor(mean), std=tensor(std))
    norm = partial(_normalize_batch, do_y=do_y, **stats)
    denorm = partial(denormalize, **stats)
    return norm, denorm
# Dataset statistics; `ImageDataBunch.normalize` unpacks such a pair as `normalize_funcs(*stats)`, i.e. (mean, std).
cifar_stats = (
    [0.491, 0.482, 0.447],
    [0.247, 0.243, 0.261],
)
imagenet_stats = (
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)
mnist_stats = (
    [0.15, 0.15, 0.15],
    [0.15, 0.15, 0.15],
)
def channel_view(x:Tensor)->Tensor:
    "Swap axes 0 and 1 of `x`, then flatten everything after the new first axis into one dimension."
    n_channels = x.shape[1]
    # `contiguous` is required before `view` because `transpose` only changes strides.
    swapped = x.transpose(0,1).contiguous()
    return swapped.view(n_channels, -1)
def _get_fns(ds, path): #TODO: fix me when from_folder is finished
"List of all file names relative to `path`."
return [str(fn.relative_to(path)) for fn in ds.x.items]
class ImageDataBunch(DataBunch):
    "`DataBunch` for images, with factory methods (folder, DataFrame, csv, lists, name functions) and normalization helpers."

    @classmethod
    def create_from_ll(cls, dss:LabelLists, bs:int=64, ds_tfms:Optional[TfmList]=None,
                num_workers:int=defaults.cpus, tfms:Optional[Collection[Callable]]=None, device:torch.device=None,
                test:Optional[PathOrStr]=None, collate_fn:Callable=data_collate, size:int=None, **kwargs)->'ImageDataBunch':
        "Build from `dss`: apply `ds_tfms`, `size` and `kwargs` through `transform`, maybe add the `test` folder, then call `databunch`."
        dss = dss.transform(tfms=ds_tfms, size=size, **kwargs)
        if test is not None: dss.add_test_folder(test)
        return dss.databunch(bs=bs, tfms=tfms, num_workers=num_workers, collate_fn=collate_fn, device=device)

    @classmethod
    def from_folder(cls, path:PathOrStr, train:PathOrStr='train', valid:PathOrStr='valid',
                    valid_pct=None, classes:Collection=None, **kwargs:Any)->'ImageDataBunch':
        "Create from imagenet style dataset in `path` with `train`,`valid`,`test` subfolders (or provide `valid_pct`)."
        path=Path(path)
        il = ImageItemList.from_folder(path)
        # Either split on the sub-folder names or randomly hold out `valid_pct` of the items.
        if valid_pct is None: src = il.split_by_folder(train=train, valid=valid)
        else: src = il.random_split_by_pct(valid_pct)
        src = src.label_from_folder(classes=classes)
        return cls.create_from_ll(src, **kwargs)

    @classmethod
    def from_df(cls, path:PathOrStr, df:pd.DataFrame, folder:PathOrStr='.', sep=None, valid_pct:float=0.2,
                fn_col:IntsOrStrs=0, label_col:IntsOrStrs=1, suffix:str='',
                **kwargs:Any)->'ImageDataBunch':
        "Create from a DataFrame: file names in `fn_col`, labels in `label_col`, random `valid_pct` split."
        src = (ImageItemList.from_df(df, path=path, folder=folder, suffix=suffix, cols=fn_col)
                .random_split_by_pct(valid_pct)
                .label_from_df(sep=sep, cols=label_col))
        return cls.create_from_ll(src, **kwargs)

    @classmethod
    def from_csv(cls, path:PathOrStr, folder:PathOrStr='.', sep=None, csv_labels:PathOrStr='labels.csv', valid_pct:float=0.2,
                 fn_col:int=0, label_col:int=1, suffix:str='',
                 header:Optional[Union[int,str]]='infer', **kwargs:Any)->'ImageDataBunch':
        "Create from the csv file `path/csv_labels`, read with `header`; the rest is handled by `from_df`."
        path = Path(path)
        df = pd.read_csv(path/csv_labels, header=header)
        # `header` only concerns `pd.read_csv`. Forwarding it to `from_df` would put it in that method's
        # `**kwargs`, which are passed on to `create_from_ll` and from there to `dss.transform`.
        return cls.from_df(path, df, folder=folder, sep=sep, valid_pct=valid_pct,
                           fn_col=fn_col, label_col=label_col, suffix=suffix, **kwargs)

    @classmethod
    def from_lists(cls, path:PathOrStr, fnames:FilePathList, labels:Collection[str], valid_pct:float=0.2, **kwargs):
        "Create from the file names `fnames` and their `labels`, with a random `valid_pct` split."
        src = ImageItemList(fnames, path=path).random_split_by_pct(valid_pct).label_from_list(labels)
        return cls.create_from_ll(src, **kwargs)

    @classmethod
    def from_name_func(cls, path:PathOrStr, fnames:FilePathList, label_func:Callable, valid_pct:float=0.2, **kwargs):
        "Create from `fnames`, labelling each item with `label_func`, with a random `valid_pct` split."
        src = ImageItemList(fnames, path=path).random_split_by_pct(valid_pct)
        return cls.create_from_ll(src.label_from_func(label_func), **kwargs)

    @classmethod
    def from_name_re(cls, path:PathOrStr, fnames:FilePathList, pat:str, valid_pct:float=0.2, **kwargs):
        "Create from `fnames`, using group 1 of regex `pat` searched in each file name as its label."
        pat = re.compile(pat)
        def _get_label(fn): return pat.search(str(fn)).group(1)
        return cls.from_name_func(path, fnames, _get_label, valid_pct=valid_pct, **kwargs)

    def batch_stats(self, funcs:Collection[Callable]=None)->Tensor:
        "Grab one batch from the validation dataloader and call each reduction in `funcs` per channel (default: mean and std)."
        funcs = ifnone(funcs, [torch.mean,torch.std])
        x = self.valid_dl.one_batch()[0].cpu()
        return [func(channel_view(x), 1) for func in funcs]

    def normalize(self, stats:Collection[Tensor]=None, do_y:bool=None)->'ImageDataBunch':
        "Add normalize transform using `stats` (defaults to `DataBunch.batch_stats`) and return `self`."
        if getattr(self,'norm',False): raise Exception('Can not call normalize twice')
        if stats is None: self.stats = self.batch_stats()
        else:             self.stats = stats
        self.norm,self.denorm = normalize_funcs(*self.stats, do_y=do_y)
        self.add_tfm(self.norm)
        return self

    def labels_to_csv(self, dest:str)->None:
        "Save file names and labels of the train, valid and (if any) test sets as CSV to file name `dest`."
        fns = _get_fns(self.train_ds, self.path)
        y = [str(o) for o in self.train_ds.y]
        fns += _get_fns(self.valid_ds, self.path)
        y += [str(o) for o in self.valid_ds.y]
        if self.test_ds is not None:
            fns += _get_fns(self.test_ds, self.path)
            y += [str(o) for o in self.test_ds.y]
        df = pd.DataFrame({'name': fns, 'label': y})
        df.to_csv(dest, index=False)

    @staticmethod
    def single_from_classes(path:Union[Path, str], classes:Collection[str], tfms:TfmList=None,
                            label_cls=CategoryList, **kwargs):
        """Create an empty `ImageDataBunch` in `path` with `classes`. Typically used for inference.
        Use `label_cls` to specify the type of your labels"""
        sd = ImageItemList([], path=path).split_by_idx([])
        return sd.label_const(0, label_cls=label_cls, classes=classes).transform(tfms, **kwargs).databunch()
def download_image(url,dest, timeout=4):
    "Download `url` to `dest` with `download_url` (overwriting, no progress bar); any exception is printed, not raised."
    try:
        download_url(url, dest, overwrite=True, show_progress=False, timeout=timeout)
    except Exception as err:
        print(f"Error {url} {err}")
def _url_suffix(url, default='.jpg'):
    "Return the first `.ext` in `url` that is directly followed by `?` or the end of the string, else `default`."
    found = re.findall(r'\.\w+?(?=(?:\?|$))', url)
    return found[0] if found else default

def download_images(urls:Collection[str], dest:PathOrStr, max_pics:int=1000, max_workers:int=8, timeout=4):
    """Download images listed in text file `urls` to path `dest`, at most `max_pics`.

    `urls` is the path of a text file with one URL per line. Image number `i` is saved as
    `dest/{i:08d}{suffix}`, where `suffix` is the extension found in its URL ('.jpg' if none).
    With a truthy `max_workers` the downloads run in a process pool of that size, otherwise
    one after the other. Each download is given `timeout`."""
    # Close the urls file deterministically instead of leaking the handle.
    with open(urls) as f: urls = f.read().strip().split("\n")[:max_pics]
    dest = Path(dest)
    dest.mkdir(exist_ok=True)
    # Same naming rule for both code paths (the sequential one used to force '.jpg').
    targets = [dest/f"{i:08d}{_url_suffix(url)}" for i,url in enumerate(urls)]
    if max_workers:
        with ProcessPoolExecutor(max_workers=max_workers) as ex:
            futures = [ex.submit(download_image, url, target, timeout=timeout)
                       for url,target in zip(urls, targets)]
            # Drain the futures only to drive the progress bar; results are not used.
            for f in progress_bar(as_completed(futures), total=len(urls)): pass
    else:
        for i,url in enumerate(progress_bar(urls)):
            download_image(url, targets[i], timeout=timeout)
def verify_image(file:Path, delete:bool, max_size:Union[int,Tuple[int,int]]=None, dest:Path=None, n_channels:int=3,
                 interp=PIL.Image.BILINEAR, ext:str=None, img_format:str=None, resume:bool=False, **kwargs):
    """Check if the image in `file` exists, it can be opened and has `n_channels`.
    If `delete=True`:
    (1) removes `file` if any of the verifications fails
    (2) saves a modified version of `file` w/o EXIF data if the latter is broken
    If `max_size` is specified, image is resized to the same ratio so that both sizes are less than `max_size`,
    using `interp`. Result is stored in `dest`, `ext` forces an extension type, `img_format` and `kwargs` are passed
    to PIL.Image.save. With `resume`, an image whose resized copy already exists in `dest` is skipped.
    Note: when `max_size` is None the function returns right after opening the image, so the channel
    count is not checked in that case."""
    try:
        # deal with partially broken images as indicated by PIL warnings
        with warnings.catch_warnings():
            warnings.filterwarnings('error')
            try:
                # must use this workaround to avoid: ResourceWarning: unclosed file warning
                with open(file, 'rb') as img_file: PIL.Image.open(img_file)
            except Warning as w:
                if "Possibly corrupt EXIF data" in str(w):
                    if delete: # green light to modify files
                        print(f"{file}: Removing corrupt EXIF data")
                        warnings.simplefilter("ignore")
                        # save EXIF-cleaned up image, which happens automatically
                        PIL.Image.open(file).save(file)
                    else: # keep user's files intact
                        print(f"{file}: Not removing corrupt EXIF data, pass `delete=True` to do that")
                else: warnings.warn(w)

        # Context manager: the file is closed on every exit path, including before the
        # `except` handler below unlinks it.
        with PIL.Image.open(file) as src:
            img = src
            if max_size is None: return
            assert isinstance(dest, Path), "You should provide `dest` Path to save resized image"
            max_size = listify(max_size, 2)
            if img.height > max_size[0] or img.width > max_size[1]:
                dest_fname = dest/file.name
                if ext is not None: dest_fname=dest_fname.with_suffix(ext)
                if resume and os.path.isfile(dest_fname): return
                # Shrink while keeping the height/width ratio.
                ratio = img.height/img.width
                new_h = min(max_size[0], int(max_size[1] * ratio))
                new_w = int(new_h/ratio)
                if n_channels == 3: img = img.convert("RGB")
                img = img.resize((new_w,new_h), resample=interp)
                img.save(dest_fname, img_format, **kwargs)
            img = np.array(img)
            # A 2-d array has a single channel, otherwise the channel count is the third axis.
            img_channels = 1 if len(img.shape) == 2 else img.shape[2]
            assert img_channels == n_channels, f"Image {file} has {img_channels} instead of {n_channels}"
    except Exception as e:
        print(f'{e}')
        if delete: file.unlink()
def verify_images(path:PathOrStr, delete:bool=True, max_workers:int=4, max_size:Union[int,Tuple[int,int]]=None,
                  dest:PathOrStr='.', n_channels:int=3, interp=PIL.Image.BILINEAR, ext:str=None, img_format:str=None,
                  resume:bool=None, **kwargs):
    """Check if the image in `path` exists, can be opened and has `n_channels`.
    If `n_channels` is 3 – it'll try to convert image to RGB. If `delete`, removes it if it fails.
    If `resume` – it will skip already existent images in `dest`. If `max_size` is specified,
    image is resized to the same ratio so that both sizes are less than `max_size`, using `interp`.
    Result is stored in `dest`, `ext` forces an extension type, `img_format` and `kwargs` are
    passed to PIL.Image.save. Use `max_workers` CPUs; with `max_workers` below 2 the files are
    checked one after the other in the current process."""
    path = Path(path)
    if resume is None and dest == '.': resume=False
    dest = path/Path(dest)
    os.makedirs(dest, exist_ok=True)
    files = get_image_files(path)
    opts = dict(delete=delete, max_size=max_size, dest=dest, n_channels=n_channels,
                interp=interp, ext=ext, img_format=img_format, resume=resume, **kwargs)
    if max_workers<2:
        for f in files: verify_image(f, **opts)
        # Done: without this return the pool below ran as well, verifying every file a second
        # time (max_workers=1) or raising ValueError (max_workers=0).
        return
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(verify_image, f, **opts) for f in files]
        # Drain the futures only to drive the progress bar; results are not used.
        for f in progress_bar(as_completed(futures), total=len(files)): pass
class ImageItemList(ItemList):
    "`ItemList` of image files; remembers the size of every image it has opened in `self.sizes`."
    _bunch = ImageDataBunch

    def __post_init__(self):
        super().__post_init__()
        # index -> `.size` of the opened image, filled in by `get`
        self.sizes={}

    def open(self, fn):
        "Open file `fn` with `open_image`."
        return open_image(fn)

    def get(self, i):
        "Open item `i`, record its size in `self.sizes` and return it."
        img = self.open(super().get(i))
        self.sizes[i] = img.size
        return img

    @classmethod
    def from_folder(cls, path:PathOrStr='.', extensions:Collection[str]=None, **kwargs)->ItemList:
        "Get the list of files in `path` whose suffix is in `extensions` (defaults to `image_extensions`)."
        exts = ifnone(extensions, image_extensions)
        return super().from_folder(path=path, extensions=exts, **kwargs)

    @classmethod
    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, folder:PathOrStr='.', suffix:str='')->'ItemList':
        "Get the filenames in `cols` of `df`, with `path/folder/` added in front of them and `suffix` at the end."
        suffix = suffix or ''
        res = super().from_df(df, path=path, cols=cols)
        names = res.items.astype(str)
        with_folder = np.char.add(f'{folder}/', names)
        res.items = np.char.add(with_folder, suffix)
        res.items = np.char.add(f'{res.path}/', res.items)
        return res

    @classmethod
    def from_csv(cls, path:PathOrStr, csv_name:str, cols:IntsOrStrs=0, header:str='infer',
                 folder:PathOrStr='.', suffix:str='')->'ItemList':
        "Read `path/csv_name` with `header` and build the list with `from_df`."
        root = Path(path)
        df = pd.read_csv(root/csv_name, header=header)
        return cls.from_df(df, path=Path(path), cols=cols, folder=folder, suffix=suffix)
class ObjectCategoryProcessor(MultiCategoryProcessor):
    "Processor for `[bboxes, labels]` items: maps the labels through `c2i` and puts 'background' first in the classes."

    def process_one(self,item):
        "Keep `item[0]` as is and replace every label of `item[1]` by its `c2i` index (None when unknown)."
        first, labels = item[0], item[1]
        return [first, [self.c2i.get(lbl,None) for lbl in labels]]

    def generate_classes(self, items):
        "Classes generated by the parent from the labels of `items`, with 'background' prepended."
        label_lists = [o[1] for o in items]
        found = super().generate_classes(label_lists)
        return ['background'] + list(found)
def _get_size(xs,i):
size = xs.sizes.get(i,None)
if size is None:
# Image hasn't been accessed yet, so we don't know its size
_ = xs[i]
size =xs.sizes[i]
return size
class ObjectCategoryList(MultiCategoryList):
    "`MultiCategoryList` for labelled bounding boxes, processed by `ObjectCategoryProcessor`."
    _processor = ObjectCategoryProcessor

    def get(self, i):
        "Create the `ImageBBox` for item `i` from the size of image `i` in `self.x` and the stored item."
        # `_get_size` opens image `i` when its size is not recorded yet; reading `self.x.sizes[i]`
        # directly raised KeyError whenever the label was requested before the image.
        return ImageBBox.create(*_get_size(self.x, i), *self.items[i], classes=self.classes)

    def reconstruct(self, t, x):
        "Rebuild an item from `t` and `x` by delegating to the first item's `reconstruct`."
        return self[0].reconstruct(*t, x, classes=self.classes)
class ObjectItemList(ImageItemList):
    "`ImageItemList` whose label class is `ObjectCategoryList`."
    _label_cls = ObjectCategoryList
class SegmentationLabelList(ImageItemList):
    "`ImageItemList` of masks (opened with `open_mask`) that carries the `classes` and a `CrossEntropyFlat` loss."

    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        super().__init__(items, **kwargs)
        self.classes = classes
        self.loss_func = CrossEntropyFlat()
        # number of classes
        self.c = len(self.classes)

    def new(self, items, classes=None, **kwargs):
        "Create a list of the same class from `items`, reusing `self.classes` unless `classes` is given."
        kept = ifnone(classes, self.classes)
        return self.__class__(items, kept, **kwargs)

    def open(self, fn):
        "Open file `fn` with `open_mask`."
        return open_mask(fn)
class SegmentationItemList(ImageItemList):
    "`ImageItemList` whose label class is `SegmentationLabelList`."
    _label_cls = SegmentationLabelList
class PointsItemList(ItemList):
    "`ItemList` of points returned as `ImagePoints`, with an `MSELossFlat` loss."

    def __post_init__(self):
        super().__post_init__()
        # `c` is the number of values in the first item once flattened.
        first = self.items[0]
        self.c = len(first.reshape(-1))
        self.loss_func = MSELossFlat()

    def get(self, i):
        "Build the `ImagePoints` for item `i`, using the size of image `i` in `self.x`."
        pts = super().get(i)
        img_size = _get_size(self.x, i)
        return ImagePoints(FlowField(img_size, pts), scale=True)
class ImageToImageList(ImageItemList):
    "`ImageItemList` whose label class is `ImageItemList`."
    _label_cls = ImageItemList