/
ssl_imagenet_datamodule.py
135 lines (111 loc) · 4.94 KB
/
ssl_imagenet_datamodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.utils.warnings import warn_missing_pkg
# Optional dependency guard: torchvision supplies the image transforms used by
# the DataModule below. Assume it is present and downgrade the flag (with a
# warning) only when the import actually fails.
_TORCHVISION_AVAILABLE = True
try:
    from torchvision import transforms as transform_lib
except ModuleNotFoundError:
    warn_missing_pkg('torchvision')  # pragma: no-cover
    _TORCHVISION_AVAILABLE = False
class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
    """LightningDataModule that serves ``UnlabeledImagenet`` splits for self-supervised learning.

    ImageNet cannot be downloaded automatically: ``data_dir`` must already contain
    ``train`` and ``val`` sub-folders, each holding a ``meta.bin`` file
    (validated by :meth:`prepare_data`).

    Args:
        data_dir: root folder containing the ``train`` and ``val`` sub-folders.
        meta_dir: optional value forwarded as ``meta_dir`` to every
            ``UnlabeledImagenet`` instance built by this module.
        num_workers: number of ``DataLoader`` worker processes.
        *args: forwarded to ``LightningDataModule.__init__``.
        **kwargs: forwarded to ``LightningDataModule.__init__``.

    Raises:
        ModuleNotFoundError: if torchvision is not installed.
    """

    name = 'imagenet'

    def __init__(
        self,
        data_dir,
        meta_dir=None,
        num_workers=16,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        if not _TORCHVISION_AVAILABLE:
            raise ModuleNotFoundError(  # pragma: no-cover
                'You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.'
            )

        self.data_dir = data_dir
        self.num_workers = num_workers
        self.meta_dir = meta_dir

    @property
    def num_classes(self):
        """Number of ImageNet classes (always 1000)."""
        return 1000

    def _verify_splits(self, data_dir, split):
        """Raise ``FileNotFoundError`` unless ``data_dir`` has a ``split`` sub-folder.

        Uses ``os.path.isdir`` so that a plain file named ``split`` (or a missing
        ``data_dir``) is reported with the same helpful message.
        """
        if not os.path.isdir(os.path.join(data_dir, split)):
            raise FileNotFoundError(
                f'a {split} Imagenet split was not found in {data_dir}, make sure the '
                f'folder contains a subfolder named {split}'
            )

    def prepare_data(self):
        """Validate the on-disk ImageNet layout under ``self.data_dir``.

        Nothing is downloaded; this only checks that the ``train`` and ``val``
        sub-folders exist and that each contains a ``meta.bin`` file.

        Raises:
            FileNotFoundError: if a split folder is missing, or a split folder
                has no ``meta.bin``.
        """
        # imagenet cannot be downloaded... must provide path to folder with the train/val splits.
        # Both folders are checked before any meta.bin check, so a missing split is reported first.
        self._verify_splits(self.data_dir, 'train')
        self._verify_splits(self.data_dir, 'val')

        # NOTE(review): meta.bin is required in both split folders even when
        # `self.meta_dir` is set; confirm whether UnlabeledImagenet reads it from
        # `meta_dir` in that case, which would make this check overly strict.
        for split in ('train', 'val'):
            split_dir = os.path.join(self.data_dir, split)
            if not os.path.isfile(os.path.join(split_dir, 'meta.bin')):
                raise FileNotFoundError(
                    '\n'
                    f'no meta.bin present in {split_dir}. '
                    'Imagenet is no longer automatically downloaded by PyTorch.\n'
                    'To get imagenet:\n'
                    '1. download yourself from http://www.image-net.org/challenges/LSVRC/2012/downloads\n'
                    '2. download the devkit (ILSVRC2012_devkit_t12.tar.gz)\n'
                    '3. generate the meta.bin file using the devkit\n'
                    '4. copy the meta.bin file into both train and val split folders\n'
                    '\n'
                    'To generate the meta.bin do the following:\n'
                    '\n'
                    'from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet\n'
                    "path = '/path/to/folder/with/ILSVRC2012_devkit_t12.tar.gz/'\n"
                    'UnlabeledImagenet.generate_meta_bins(path)\n'
                )

    def _resolve_transforms(self, custom_transforms):
        """Return ``custom_transforms`` if set, otherwise the default transforms."""
        return self._default_transforms() if custom_transforms is None else custom_transforms

    def _make_loader(self, dataset, batch_size, shuffle, drop_last):
        """Wrap ``dataset`` in a pinned-memory DataLoader using ``self.num_workers``."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=drop_last,
            pin_memory=True,
        )

    def train_dataloader(self, batch_size, num_images_per_class=-1, add_normalize=False):
        """Return a shuffled loader over ``UnlabeledImagenet(split='train')``.

        Args:
            batch_size: samples per batch; an incomplete final batch is dropped.
            num_images_per_class: forwarded as ``num_imgs_per_class``
                (presumably ``-1`` means "no limit" — confirm in UnlabeledImagenet).
            add_normalize: currently ignored; the default transforms always
                normalize. Kept for API compatibility.
        """
        dataset = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class=num_images_per_class,
            meta_dir=self.meta_dir,
            split='train',
            transform=self._resolve_transforms(self.train_transforms),
        )
        return self._make_loader(dataset, batch_size, shuffle=True, drop_last=True)

    def val_dataloader(self, batch_size, num_images_per_class=50, add_normalize=False):
        """Return an unshuffled loader over ``UnlabeledImagenet(split='val')``.

        Args:
            batch_size: samples per batch; the final partial batch is kept.
            num_images_per_class: forwarded as ``num_imgs_per_class_val_split``.
            add_normalize: currently ignored; kept for API compatibility.
        """
        dataset = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class_val_split=num_images_per_class,
            meta_dir=self.meta_dir,
            split='val',
            transform=self._resolve_transforms(self.val_transforms),
        )
        return self._make_loader(dataset, batch_size, shuffle=False, drop_last=False)

    def test_dataloader(self, batch_size, num_images_per_class=-1, add_normalize=False):
        """Return an unshuffled loader over ``UnlabeledImagenet(split='test')``.

        Args:
            batch_size: samples per batch; an incomplete final batch is dropped.
            num_images_per_class: forwarded as ``num_imgs_per_class``; defaults
                to ``-1`` to match :meth:`train_dataloader`.
            add_normalize: currently ignored; kept for API compatibility.
        """
        dataset = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class=num_images_per_class,
            meta_dir=self.meta_dir,
            split='test',
            transform=self._resolve_transforms(self.test_transforms),
        )
        # NOTE(review): drop_last=True skips up to batch_size - 1 test images;
        # kept for compatibility, but confirm this is intended for evaluation.
        return self._make_loader(dataset, batch_size, shuffle=False, drop_last=True)

    def _default_transforms(self):
        """Convert images to tensors and apply ImageNet normalization."""
        imagenet_transforms = transform_lib.Compose([
            transform_lib.ToTensor(),
            imagenet_normalization(),
        ])
        return imagenet_transforms