-
Notifications
You must be signed in to change notification settings - Fork 186
/
dataset.py
111 lines (100 loc) · 4.01 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from torchvision import transforms
from torch.utils.data import Dataset
from .data_utils import get_onehot
from .augmentation.randaugment import RandAugment
import torchvision
from PIL import Image
import numpy as np
import copy
class BasicDataset(Dataset):
    """
    BasicDataset returns a pair of image and labels (targets).
    If targets are not given, BasicDataset returns None as the label.
    This class supports strong augmentation for FixMatch-style algorithms,
    and returns both weakly and strongly augmented images for unlabeled data.
    """
    def __init__(self,
                 alg,
                 data,
                 targets=None,
                 num_classes=None,
                 transform=None,
                 is_ulb=False,
                 strong_transform=None,
                 onehot=False,
                 *args, **kwargs):
        """
        Args
            alg: name of the SSL algorithm (e.g. 'fixmatch', 'remixmatch');
                 selects what __getitem__ returns for unlabeled samples
            data: x_data
            targets: y_data (if not exist, None)
            num_classes: number of label classes (used only when onehot=True)
            transform: basic (weak) transformation of data
            is_ulb: True if this dataset holds unlabeled data; enables strong augmentation
            strong_transform: transformation for strong augmentation; if None while
                 is_ulb is True, it is derived from `transform` by prepending RandAugment(3, 5)
            onehot: If True, label is converted into onehot vector.
        """
        super().__init__()
        self.alg = alg
        self.data = data
        self.targets = targets
        self.num_classes = num_classes
        self.is_ulb = is_ulb
        self.onehot = onehot
        self.transform = transform
        if self.is_ulb:
            if strong_transform is None:
                # Default strong pipeline: the weak pipeline with RandAugment
                # inserted in front (deep copy so the weak pipeline is untouched).
                self.strong_transform = copy.deepcopy(transform)
                self.strong_transform.transforms.insert(0, RandAugment(3, 5))
            else:
                self.strong_transform = strong_transform
    def __getitem__(self, idx):
        """
        Return the idx-th sample, augmented according to `alg` and `is_ulb`.

        Labeled data (is_ulb=False): (idx, weak_augment_image, target).
        Unlabeled data: a tuple whose layout depends on the algorithm —
        weak only, weak + weak, or weak + strong view(s).
        If no transform is set: (tensor_image, target), without idx.
        """
        # set idx-th target (one-hot encoded on demand)
        if self.targets is None:
            target = None
        else:
            target_ = self.targets[idx]
            target = target_ if not self.onehot else get_onehot(self.num_classes, target_)
        # set augmented images
        img = self.data[idx]
        if self.transform is None:
            # NOTE(review): this path omits idx, unlike every branch below —
            # confirm callers of transform-less datasets expect (image, target).
            return transforms.ToTensor()(img), target
        else:
            if isinstance(img, np.ndarray):
                img = Image.fromarray(img)
            img_w = self.transform(img)
            if not self.is_ulb:
                return idx, img_w, target
            # Unlabeled branch: pick the views this algorithm consumes.
            if self.alg in ('fixmatch', 'flexmatch', 'uda'):
                # one weak + one strong view
                return idx, img_w, self.strong_transform(img)
            elif self.alg in ('pimodel', 'meanteacher', 'mixmatch'):
                # two independently weakly-augmented views
                return idx, img_w, self.transform(img)
            elif self.alg in ('pseudolabel', 'vat'):
                # single weak view
                return idx, img_w
            elif self.alg == 'remixmatch':
                # two strong views plus a rotated copy of the first,
                # with the rotation class index for the rotation-prediction task
                rotate_v_list = [0, 90, 180, 270]
                rotate_v1 = np.random.choice(rotate_v_list, 1).item()
                img_s1 = self.strong_transform(img)
                img_s1_rot = torchvision.transforms.functional.rotate(img_s1, rotate_v1)
                img_s2 = self.strong_transform(img)
                return idx, img_w, img_s1, img_s2, img_s1_rot, rotate_v_list.index(rotate_v1)
            elif self.alg == 'fullysupervised':
                # NOTE(review): only the index is returned here, and any
                # unrecognized alg falls through returning None — confirm
                # callers rely on this before tightening with a raise.
                return idx
    def __len__(self):
        return len(self.data)