/
datasets.py
123 lines (108 loc) · 4.01 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from warnings import warn

from torch.utils.data import random_split

from pl_bolts.datamodules.imagenet_dataset import UnlabeledImagenet
from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms
from pl_bolts.models.self_supervised.amdim.ssl_datasets import CIFAR10Mixed

try:
    from torchvision.datasets import STL10
except ImportError:
    warn('You want to use `torchvision` which is not installed yet,'  # pragma: no-cover
         ' install it with `pip install torchvision`.')
class AMDIMPretraining():
    """Factories for the datasets used to pretrain AMDIM.

    For pretraining we use the train transform for both train and val.

    Every factory is a ``@staticmethod``; :meth:`get_dataset` dispatches to one
    of them by name.
    """

    @staticmethod
    def _check_split(split: str) -> None:
        """Raise ``ValueError`` unless *split* is ``'train'`` or ``'val'``."""
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    @staticmethod
    def cifar10(dataset_root, split: str = 'train'):
        """Return the ``CIFAR10Mixed`` *split* with AMDIM's CIFAR-10 train transform.

        Args:
            dataset_root: passed to ``CIFAR10Mixed`` as ``root``; ``download=True``
                is always requested.
            split: ``'train'`` or ``'val'``.

        Raises:
            ValueError: if *split* is not ``'train'`` or ``'val'``.
        """
        AMDIMPretraining._check_split(split)
        return CIFAR10Mixed(
            root=dataset_root,
            split=split,
            transform=amdim_transforms.AMDIMTrainTransformsCIFAR10(),
            download=True,
        )

    @staticmethod
    def cifar10_tiny(dataset_root, split: str = 'train'):
        """Like :meth:`cifar10` but passes ``nb_labeled_per_class=50`` to ``CIFAR10Mixed``.

        NOTE(review): presumably this caps the labeled examples per class to build a
        small variant -- confirm against ``CIFAR10Mixed``.

        Raises:
            ValueError: if *split* is not ``'train'`` or ``'val'``.
        """
        AMDIMPretraining._check_split(split)
        return CIFAR10Mixed(
            root=dataset_root,
            split=split,
            transform=amdim_transforms.AMDIMTrainTransformsCIFAR10(),
            download=True,
            nb_labeled_per_class=50,
        )

    @staticmethod
    def imagenet(dataset_root, nb_classes, split: str = 'train'):
        """Return the ``UnlabeledImagenet`` *split* with AMDIM's 128px ImageNet train transform.

        Args:
            dataset_root: ImageNet root directory, passed positionally.
            nb_classes: forwarded to ``UnlabeledImagenet`` as ``nb_classes``.
            split: ``'train'`` or ``'val'``.

        Raises:
            ValueError: if *split* is not ``'train'`` or ``'val'``.
        """
        AMDIMPretraining._check_split(split)
        return UnlabeledImagenet(
            dataset_root,
            nb_classes=nb_classes,
            split=split,
            transform=amdim_transforms.AMDIMTrainTransformsImageNet128(),
        )

    @staticmethod
    def stl(dataset_root, split: str = None, *, val_size: int = 5000):
        """Randomly partition the unlabeled STL-10 split into ``(train, val)``.

        Unlike the other factories this returns a *tuple* of two datasets and
        ignores *split* (it is accepted only so :meth:`get_dataset` can call every
        factory uniformly).

        Args:
            dataset_root: passed to ``STL10`` as ``root``; ``download=True`` is requested.
            split: ignored.
            val_size: number of images held out for validation. The remaining
                ``len(dataset) - val_size`` images form the training part. The
                default reproduces the historical 95000/5000 split of the 100k
                unlabeled images.

        Returns:
            ``(train_split, val_split)`` as produced by ``random_split``.

        Raises:
            ValueError: if *val_size* is negative or larger than the dataset.
        """
        dataset = STL10(
            root=dataset_root,
            split='unlabeled',
            transform=amdim_transforms.AMDIMTrainTransformsSTL10(),
            download=True,
        )
        n_total = len(dataset)
        if not 0 <= val_size <= n_total:
            raise ValueError(f"val_size must be in [0, {n_total}], got {val_size}")
        # No ``generator`` is passed, so the partition follows torch's default RNG;
        # seed it upstream if a reproducible split is needed.
        train_split, val_split = random_split(dataset, [n_total - val_size, val_size])
        return train_split, val_split

    @staticmethod
    def get_dataset(datamodule: str, data_dir, split: str = 'train', **kwargs):
        """Build the pretraining dataset registered under *datamodule*.

        Args:
            datamodule: one of ``'tiny-cifar10'``, ``'cifar10'``, ``'stl10'`` or
                ``'imagenet2012'``.
            data_dir: forwarded to the factory as ``dataset_root``.
            split: forwarded unchanged (the ``'stl10'`` factory ignores it).
            **kwargs: extra factory arguments, e.g. ``nb_classes`` for ``'imagenet2012'``.

        Returns:
            Whatever the selected factory returns: a single dataset, or a
            ``(train, val)`` tuple for ``'stl10'``.

        Raises:
            ValueError: if *datamodule* is unknown, or the factory rejects *split*.
        """
        datasets = {
            'tiny-cifar10': AMDIMPretraining.cifar10_tiny,
            'cifar10': AMDIMPretraining.cifar10,
            'stl10': AMDIMPretraining.stl,
            'imagenet2012': AMDIMPretraining.imagenet,
        }
        try:
            factory = datasets[datamodule]
        except KeyError:
            raise ValueError(
                f"unrecognized dataset request: {datamodule!r} "
                f"(expected one of {sorted(datasets)})"
            ) from None
        return factory(dataset_root=data_dir, split=split, **kwargs)
class AMDIMPatchesPretraining():
    """Factories for the patch-based AMDIM pretraining datasets.

    For pretraining we use the train transform for both train and val.
    Each factory wraps a base dataset with a transform built from
    ``patch_size`` and ``patch_overlap``.
    """

    @staticmethod
    def _check_split(split: str) -> None:
        """Raise ``ValueError`` unless *split* is ``'train'`` or ``'val'``."""
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    @staticmethod
    def cifar10(dataset_root, patch_size, patch_overlap, split: str = 'train'):
        """Return the ``CIFAR10Mixed`` *split* wrapped in the CIFAR-10 patch transform.

        Args:
            dataset_root: passed to ``CIFAR10Mixed`` as ``root``; ``download=True``
                is requested.
            patch_size: forwarded to ``TransformsC10Patches``.
            patch_overlap: forwarded to ``TransformsC10Patches``.
            split: ``'train'`` or ``'val'``.

        Raises:
            ValueError: if *split* is not ``'train'`` or ``'val'``.
        """
        AMDIMPatchesPretraining._check_split(split)
        # NOTE(review): this transform takes ``patch_overlap=`` whereas the STL-10 and
        # ImageNet patch transforms below take ``overlap=`` -- kept as-is; verify
        # against amdim_transforms before unifying.
        train_transform = amdim_transforms.TransformsC10Patches(
            patch_size=patch_size,
            patch_overlap=patch_overlap,
        )
        return CIFAR10Mixed(
            root=dataset_root,
            split=split,
            transform=train_transform,
            download=True,
        )

    @staticmethod
    def stl(dataset_root, patch_size, patch_overlap, split: str = None, *, val_size: int = 5000):
        """Randomly partition patch-transformed unlabeled STL-10 into ``(train, val)``.

        Returns a *tuple* of two datasets and ignores *split*.

        Args:
            dataset_root: passed to ``STL10`` as ``root``; ``download=True`` is requested.
            patch_size: forwarded to ``TransformsSTL10Patches``.
            patch_overlap: forwarded to ``TransformsSTL10Patches`` as ``overlap``.
            split: ignored.
            val_size: number of images held out for validation. The remaining
                ``len(dataset) - val_size`` images form the training part. The
                default reproduces the historical 95000/5000 split of the 100k
                unlabeled images.

        Returns:
            ``(train_split, val_split)`` as produced by ``random_split``.

        Raises:
            ValueError: if *val_size* is negative or larger than the dataset.
        """
        train_transform = amdim_transforms.TransformsSTL10Patches(
            patch_size=patch_size,
            overlap=patch_overlap,
        )
        dataset = STL10(
            root=dataset_root,
            split='unlabeled',
            transform=train_transform,
            download=True,
        )
        n_total = len(dataset)
        if not 0 <= val_size <= n_total:
            raise ValueError(f"val_size must be in [0, {n_total}], got {val_size}")
        # No ``generator`` is passed, so the partition follows torch's default RNG;
        # seed it upstream if a reproducible split is needed.
        train_split, val_split = random_split(dataset, [n_total - val_size, val_size])
        return train_split, val_split

    @staticmethod
    def imagenet(dataset_root, nb_classes, patch_size, patch_overlap, split: str = 'train'):
        """Return the ``UnlabeledImagenet`` *split* wrapped in the 128px ImageNet patch transform.

        Args:
            dataset_root: ImageNet root directory, passed positionally.
            nb_classes: forwarded to ``UnlabeledImagenet`` as ``nb_classes``.
            patch_size: forwarded to ``TransformsImageNet128Patches``.
            patch_overlap: forwarded to ``TransformsImageNet128Patches`` as ``overlap``.
            split: ``'train'`` or ``'val'``.

        Raises:
            ValueError: if *split* is not ``'train'`` or ``'val'``.
        """
        AMDIMPatchesPretraining._check_split(split)
        train_transform = amdim_transforms.TransformsImageNet128Patches(
            patch_size=patch_size,
            overlap=patch_overlap,
        )
        return UnlabeledImagenet(
            dataset_root,
            nb_classes=nb_classes,
            split=split,
            transform=train_transform,
        )