-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
61 lines (51 loc) · 3.25 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from constants import LearningHyperParameter
import constants as cst
class Configuration:
    """Settings object for one simulation run.

    Groups together:
      * run-mode switches (W&B logging, sweeps, testing, sampling, training,
        autoencoder training, debug);
      * the selected autoencoder and transformer, plus ``CHOSEN_MODEL``, the
        one of the two that this run targets;
      * bookkeeping slots that start as ``None`` (W&B handles/names,
        early-stopping metric, checkpoint filename);
      * ``HYPER_PARAMETERS``: a dict keyed by every ``LearningHyperParameter``
        member, in enum order; members not given a value below stay ``None``.
    """

    def __init__(self):
        # Run-mode switches.
        self.IS_WANDB = False
        self.IS_SWEEP = False
        self.IS_TESTING = True
        self.IS_SAMPLING = False
        self.IS_TRAINING = False
        self.IS_TRAINING_AE = True
        self.IS_DEBUG = False

        # Model selection. CHOSEN_MODEL is resolved once, here; flipping
        # IS_TRAINING_AE after construction does not update it.
        self.CHOSEN_AE = cst.Autoencoders.RQVAE
        self.CHOSEN_TRANSFORMER = cst.Transformers.RQTRANSFORMER
        if self.IS_TRAINING_AE:
            self.CHOSEN_MODEL = self.CHOSEN_AE
        else:
            self.CHOSEN_MODEL = self.CHOSEN_TRANSFORMER

        # Slots initialised to None here (not assigned anywhere in this class).
        self.WANDB_INSTANCE = None
        self.WANDB_RUN_NAME = None
        self.WANDB_SWEEP_NAME = None
        self.EARLY_STOPPING_METRIC = None
        self.FILENAME_CKPT = None

        # Pre-seed every hyper-parameter with None; dict.fromkeys keeps the
        # enum's iteration order, and later updates only overwrite values.
        hp = LearningHyperParameter
        params = dict.fromkeys(hp)
        self.HYPER_PARAMETERS = params

        # Optimisation / general training settings.
        params.update({
            hp.BATCH_SIZE: 16,
            hp.LEARNING_RATE: 0.001,
            hp.OPTIMIZER: cst.Optimizers.ADAM.value,
            hp.DROPOUT: 0.1,
            hp.CODEBOOK_LENGTH: 4096,
            hp.LSTM_LAYERS: 2,
            hp.INIT_KMEANS: True,
            hp.Z_SCORE: True,
            hp.SHARED_CODEBOOK: False,
            hp.CONV_SETUP: 1,
        })

        # One quantizer for a plain VQVAE or when the codebook is shared,
        # otherwise 12.
        # NOTE(review): this compares CHOSEN_MODEL, which is the transformer
        # when IS_TRAINING_AE is False; confirm whether CHOSEN_AE was meant.
        single_quantizer = (
            self.CHOSEN_MODEL == cst.Autoencoders.VQVAE
            or params[hp.SHARED_CODEBOOK]
        )
        params[hp.NUM_QUANTIZERS] = 1 if single_quantizer else 12

        # Two alternative convolutional stacks; CONV_SETUP indexes into each
        # outer list. Per setup, HIDDEN_CHANNELS has one entry more than the
        # per-layer lists (presumably input channels first; confirm).
        params.update({
            hp.EPOCHS: 100,
            hp.KERNEL_SIZES: [[3, 7, 7, 6, 5, 5], [3, 7, 7, 6, 4]],
            hp.STRIDES: [[1, 5, 5, 4, 3, 3], [1, 5, 5, 4, 2]],
            hp.PADDINGS: [[1] * 6, [1] * 5],
            hp.DILATIONS: [[1] * 6, [1] * 5],
            hp.HIDDEN_CHANNELS: [
                [1, 8, 16, 32, 64, 128, 256],
                [1, 16, 32, 64, 128, 256],
            ],
        })

        # Number of conv layers follows from the selected setup's kernel list.
        selected_kernels = params[hp.KERNEL_SIZES][params[hp.CONV_SETUP]]
        params[hp.NUM_CONVS] = len(selected_kernels)

        # Loss weighting and transformer sizing.
        params.update({
            hp.MULTI_SPECTRAL_RECON_LOSS_WEIGHT: 1e-6,
            hp.RECON_LOSS_WEIGHT: 1,
            hp.P: 0.3,
            hp.NUM_STEPS: 8,
            hp.NUM_HEADS: 8,
            hp.NUM_SPATIAL_TRANSFORMER_LAYERS: 12,
            hp.NUM_DEPTH_TRANSFORMER_LAYERS: 4,
        })