-
Notifications
You must be signed in to change notification settings - Fork 0
/
experiments.py
176 lines (144 loc) · 7.17 KB
/
experiments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import time
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from transformations.custom_transforms import DOG, Gabor
'''
Defining specific experiments
'''
class Experiment(object):
    '''
    Maintain experiment configuration: which architecture, dataset, and
    custom transformation (DOG / Gabor) the run uses, and build the
    corresponding transform pipelines and datasets.
    '''
    def __init__(self, args):
        '''
        Store experiment options parsed from the command line.
        ### Arguments:
            args: argparse.Namespace - parsed options (arch, finetune,
                dataset, concat, same, DOG, options, gabor, scales,
                orientations, own, data, decay).
        '''
        self.arch = args.arch
        self.finetune = args.finetune
        # Dataset selector: None = custom ImageFolder, 0 = CIFAR10, 1 = CIFAR100.
        self.experiment_dataset = args.dataset
        self.concat = args.concat
        self.same = args.same
        self.DOG = args.DOG
        self.DOG_options = args.options
        self.gabor = args.gabor
        self.scales = args.scales
        self.orientations = args.orientations
        self.own = args.own
        self.directory = args.data
        self.decay = args.decay
        self.name = self.get_name()

    def get_name(self):
        '''
        Given type of experiment, return information about experiment as string.
        ### Returns:
            str - experiment identifier, e.g. 'F0-resnet18-concat-DOG:0.1'.
        '''
        f = 'F' if self.finetune else ''
        decay = str(self.decay) if self.decay else ''
        if self.own and self.finetune:
            return f + str(self.own) + decay
        if self.own:
            # BUG FIX: original referenced bare `own` (NameError); use self.own.
            return str(self.own)
        concat_or_same = 'original'
        if self.concat:
            concat_or_same = 'concat'
        elif self.same:
            concat_or_same = 'same'
        transform = 'None'
        if self.DOG:
            transform = 'DOG:'
            if self.DOG_options:
                transform += f'({str(self.DOG_options)})'
        elif self.gabor:
            transform = 'gabor:'
            if self.scales:
                transform += f'v({str(self.scales)})'
            if self.orientations:
                transform += f'u({str(self.orientations)})'
        name = f + str(self.experiment_dataset) + '-' + str(self.arch) + '-' \
            + concat_or_same + '-' + transform + decay
        return name

    def get_transformation_type(self):
        '''
        Get custom transformation to apply, with hyperparameters passed.
        ### Returns:
            Transform object (DOG or Gabor), or None when neither was requested.
        '''
        if self.DOG:
            if not self.DOG_options:
                return DOG()
            if len(self.DOG_options) == 1:
                return DOG(sigma=float(self.DOG_options[0]))
            return DOG(sigma=float(self.DOG_options[0]), k=float(self.DOG_options[1]))
        if self.gabor:
            # Pass only the hyperparameters that were supplied; Gabor's own
            # defaults cover the rest (same behavior as the original chain).
            kwargs = {}
            if self.scales:
                kwargs['scales'] = [float(s) for s in self.scales]
            if self.orientations:
                kwargs['orientations'] = [float(u) for u in self.orientations]
            return Gabor(**kwargs)
        return None

    def get_transformation_set(self):
        '''
        Gets the transformations to apply for the experiment training, additional and validation set
        ### Returns:
            train, additional, validation - transformations to apply respectively
            (additional is None when no extra concat set is used)
        '''
        # defined normalization values from image-net.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        original_transformation = [transforms.ToTensor(), normalize]
        custom_transform = self.get_transformation_type()
        if self.same and custom_transform:
            # Only train and test on a custom transformation
            additional_transformation = [custom_transform, transforms.ToTensor()]
            return transforms.Compose(additional_transformation), None, transforms.Compose(additional_transformation)
        if custom_transform:
            # Concat with custom transform
            additional_transformation = [custom_transform, transforms.ToTensor()]
            return transforms.Compose(original_transformation), transforms.Compose(additional_transformation), transforms.Compose(original_transformation)
        # No additional data
        return transforms.Compose(original_transformation), None, transforms.Compose(original_transformation)

    # Datasets for differnet experiments
    def get_data_set(self, transform, additional_transform, validation_transform):
        '''
        Gets the training and test dataset for an experiment, applying relevant transformation and concatenations
        ### Arguments:
            transform: 'transforms.Compose' - applied to training set
            additional_transform: 'transforms.Compose' - applied to the extra (concat) training copy
            validation_transform: 'transforms.Compose' - applied to validation set
        ### Returns:
            'torchvision.dataset', 'torchvision.dataset' - pair corresponding to training and validation set
        ### Raises:
            ValueError - if the dataset selector is not None, 0 or 1.
        '''
        if self.experiment_dataset is None:
            # Custom Dataset - with directory/train and directory/val sub folders
            traindir = os.path.join(self.directory, 'train')
            valdir = os.path.join(self.directory, 'val')
            train_dataset = self.define_dataset(traindir, transform)
            test_dataset = self.define_dataset(valdir, validation_transform)
            if self.concat:
                transformed_train_dataset = self.define_dataset(traindir, additional_transform)
                train_dataset = train_dataset + transformed_train_dataset
        elif self.experiment_dataset == 0:
            train_dataset = datasets.CIFAR10(root=self.directory, train=True, download=True, transform=transform)
            test_dataset = datasets.CIFAR10(root=self.directory, train=False, download=True, transform=validation_transform)
            if self.concat:
                transformed_train_dataset = datasets.CIFAR10(root=self.directory, train=True, download=True, transform=additional_transform)
                train_dataset = train_dataset + transformed_train_dataset
        elif self.experiment_dataset == 1:
            train_dataset = datasets.CIFAR100(root=self.directory, train=True, download=True, transform=transform)
            test_dataset = datasets.CIFAR100(root=self.directory, train=False, download=True, transform=validation_transform)
            if self.concat:
                transformed_train_dataset = datasets.CIFAR100(root=self.directory, train=True, download=True, transform=additional_transform)
                train_dataset = train_dataset + transformed_train_dataset
        else:
            # BUG FIX: original fell through and crashed with UnboundLocalError.
            raise ValueError(f'Unknown dataset id: {self.experiment_dataset}')
        return train_dataset, test_dataset

    def define_dataset(self, directory, augmentations):
        '''
        Defines a custom dataset.
        ### Arguments:
            directory: str - directory where the dataset is located.
            augmentations: 'transforms.Compose' - a list of Transforms to apply.
        ### Returns:
            'torchvision.dataset' - dataset image folder with augmentations to apply
        '''
        dataset = datasets.ImageFolder(directory, augmentations)
        return dataset