/
imagenet_datamodule.py
266 lines (221 loc) · 8.8 KB
/
imagenet_datamodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# type: ignore[override]
import os
from typing import Any, Callable, Optional
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from pl_bolts.datasets import UnlabeledImagenet
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
else: # pragma: no cover
warn_missing_pkg('torchvision')
class ImagenetDataModule(LightningDataModule):
    """
    .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/
        Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png
        :width: 400
        :alt: Imagenet

    Specs:
        - 1000 classes
        - Each image is (3 x varies x varies) (here we default to 3 x 224 x 224)

    Imagenet train, val and test dataloaders.

    The train set is the imagenet train.

    The val set is taken from the train set with `num_imgs_per_val_class` images per class.
    For example if `num_imgs_per_val_class=2` then there will be 2,000 images in the validation set.

    The test set is the official imagenet validation set.

    Example::

        from pl_bolts.datamodules import ImagenetDataModule

        dm = ImagenetDataModule(IMAGENET_PATH)
        model = LitModel()
        Trainer().fit(model, datamodule=dm)
    """

    name = 'imagenet'

    def __init__(
        self,
        data_dir: str,
        meta_dir: Optional[str] = None,
        num_imgs_per_val_class: int = 50,
        image_size: int = 224,
        num_workers: int = 16,
        batch_size: int = 32,
        shuffle: bool = False,
        pin_memory: bool = False,
        drop_last: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            data_dir: path to the imagenet dataset file
            meta_dir: path to meta.bin file
            num_imgs_per_val_class: how many images per class for the validation set
            image_size: final image size
            num_workers: how many data workers
            batch_size: batch_size
            shuffle: If true shuffles the data every epoch
            pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                returning them
            drop_last: If true drops the last incomplete batch

        Raises:
            ModuleNotFoundError: if `torchvision` is not installed
        """
        super().__init__(*args, **kwargs)

        if not _TORCHVISION_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                'You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.'
            )

        self.image_size = image_size
        self.dims = (3, self.image_size, self.image_size)
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.meta_dir = meta_dir
        self.num_imgs_per_val_class = num_imgs_per_val_class
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        # 1281167 is the size of the full imagenet2012 train split; the val split is
        # carved out of it, so the effective number of training samples shrinks by
        # `num_imgs_per_val_class` images for each of the 1000 classes.
        self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes

    @property
    def num_classes(self) -> int:
        """
        Return:
            1000
        """
        return 1000

    def _verify_splits(self, data_dir: str, split: str) -> None:
        """Raise ``FileNotFoundError`` unless ``data_dir`` contains a ``split`` subfolder."""
        dirs = os.listdir(data_dir)

        if split not in dirs:
            raise FileNotFoundError(
                f'a {split} Imagenet split was not found in {data_dir},'
                f' make sure the folder contains a subfolder named {split}'
            )

    def prepare_data(self) -> None:
        """
        This method already assumes you have imagenet2012 downloaded.
        It validates the data using the meta.bin.

        .. warning:: Please download imagenet on your own first.
        """
        self._verify_splits(self.data_dir, 'train')
        self._verify_splits(self.data_dir, 'val')

        for split in ['train', 'val']:
            files = os.listdir(os.path.join(self.data_dir, split))
            if 'meta.bin' not in files:
                raise FileNotFoundError(
                    """
                no meta.bin present. Imagenet is no longer automatically downloaded by PyTorch.
                To get imagenet:
                1. download yourself from http://www.image-net.org/challenges/LSVRC/2012/downloads
                2. download the devkit (ILSVRC2012_devkit_t12.tar.gz)
                3. generate the meta.bin file using the devkit
                4. copy the meta.bin file into both train and val split folders

                To generate the meta.bin do the following:

                from pl_bolts.datasets import UnlabeledImagenet
                path = '/path/to/folder/with/ILSVRC2012_devkit_t12.tar.gz/'
                UnlabeledImagenet.generate_meta_bins(path)
            """
                )

    def _make_dataloader(self, dataset: UnlabeledImagenet, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` using the module-wide loader settings.

        Shared by the train/val/test dataloaders so batch size, worker count,
        ``drop_last`` and ``pin_memory`` stay consistent across splits.
        """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )

    def train_dataloader(self) -> DataLoader:
        """
        Uses the train split of imagenet2012 and puts away a portion of it for the validation split.
        """
        transforms = self.train_transform() if self.train_transforms is None else self.train_transforms

        dataset = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class=-1,
            num_imgs_per_class_val_split=self.num_imgs_per_val_class,
            meta_dir=self.meta_dir,
            split='train',
            transform=transforms,
        )
        return self._make_dataloader(dataset, shuffle=self.shuffle)

    def val_dataloader(self) -> DataLoader:
        """
        Uses the part of the train split of imagenet2012 that was not used for training via
        `num_imgs_per_val_class`.
        """
        # BUGFIX: fall back to the deterministic eval transform (resize + center crop),
        # not the augmenting train transform — validation must not be randomly augmented,
        # and this keeps the default consistent with `test_dataloader`.
        transforms = self.val_transform() if self.val_transforms is None else self.val_transforms

        dataset = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class_val_split=self.num_imgs_per_val_class,
            meta_dir=self.meta_dir,
            split='val',
            transform=transforms,
        )
        return self._make_dataloader(dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """
        Uses the validation split of imagenet2012 for testing.
        """
        transforms = self.val_transform() if self.test_transforms is None else self.test_transforms

        dataset = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class=-1,
            meta_dir=self.meta_dir,
            split='test',
            transform=transforms,
        )
        return self._make_dataloader(dataset, shuffle=False)

    def train_transform(self) -> Callable:
        """
        The standard imagenet transforms.

        .. code-block:: python

            transform_lib.Compose([
                transform_lib.RandomResizedCrop(self.image_size),
                transform_lib.RandomHorizontalFlip(),
                transform_lib.ToTensor(),
                transform_lib.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
            ])
        """
        preprocessing = transform_lib.Compose([
            transform_lib.RandomResizedCrop(self.image_size),
            transform_lib.RandomHorizontalFlip(),
            transform_lib.ToTensor(),
            imagenet_normalization(),
        ])
        return preprocessing

    def val_transform(self) -> Callable:
        """
        The standard imagenet transforms for validation.

        .. code-block:: python

            transform_lib.Compose([
                transform_lib.Resize(self.image_size + 32),
                transform_lib.CenterCrop(self.image_size),
                transform_lib.ToTensor(),
                transform_lib.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
            ])
        """
        preprocessing = transform_lib.Compose([
            transform_lib.Resize(self.image_size + 32),
            transform_lib.CenterCrop(self.image_size),
            transform_lib.ToTensor(),
            imagenet_normalization(),
        ])
        return preprocessing