-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathopt.py
138 lines (111 loc) · 6 KB
/
opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
from data import get_mnist, get_cifar10, get_fashion,\
get_svhn, get_unlabeled_celebA, available_datasets
def get_options():
    """Parse the training command line.

    A fresh ``argparse.ArgumentParser`` is populated by the general, model
    and InfoGAN option registrars (in that order) and ``sys.argv`` is parsed.

    Returns:
        argparse.Namespace holding every registered option.
    """
    parser = argparse.ArgumentParser()
    for register in (normal_options, model_options, infogan_options):
        # Each registrar adds its options and hands the same parser back.
        parser = register(parser)
    return parser.parse_args()
def normal_options(parser):
    """Register general training, dataset and miscellaneous options.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` (options are added in place), so calls can be
        chained as in ``get_options``.
    """

    def str2bool(value):
        """Parse a command-line boolean such as ``True``/``false``/``1``/``no``."""
        # ``type=bool`` is a trap: bool("False") is True, so the flags below
        # could never be switched off from the command line.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1", "on"):
            return True
        if lowered in ("false", "f", "no", "n", "0", "off"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    # parameters in training
    parser.add_argument("--save-path", type=str, default="outputs/")
    # Single-dash long option kept for backward compatibility. A "--seed"
    # alias cannot be added here: get_traverse_options registers its own
    # "--seed" on the same parser, and argparse rejects duplicate option
    # strings by default.
    parser.add_argument("-seed", default=42, type=int, help="random seed")
    parser.add_argument("--lrD", default=2e-4, type=float, help="learning rate of Discriminator")
    parser.add_argument("--lrG", default=2e-4, type=float, help="learning rate of Generator")
    parser.add_argument("--epochs", default=10, type=int, help="total epochs in training")
    parser.add_argument("--save-epoch-interval", default=1, type=int, help="interval of epoch for saving model")
    parser.add_argument("--cuda", default=True, type=str2bool, help="using cuda")
    # ``type=tuple`` would split "0.5,0.999" into single characters; take two
    # floats instead, e.g. ``--adam-betas 0.5 0.999`` (parsed into a list).
    parser.add_argument("--adam-betas", default=(0.9, 0.999), type=float, nargs=2,
                        metavar=("BETA1", "BETA2"), help="betas of adam optimizer")
    parser.add_argument("--train-D-iter", default=1, type=int, help="times of training D in one iteration")
    # parameters of dataset
    parser.add_argument("--batch-size", default=64, type=int, help="batch size")
    parser.add_argument("--data-name", default="MNIST", choices=available_datasets, type=str, help="dataset name")
    # parser.add_argument("--img-size", default=(64,64), type=tuple, help="output image size") # currently cannot be changed
    parser.add_argument("--nrow", default=16, type=int, help="number of rows when showing batch of images")
    parser.add_argument("--data-path", default="data", type=str, help="path of dataset")
    parser.add_argument("--num-workers", default=4, type=int, help="number of workers in dataloader")
    # NOTE: choose_dataset() overwrites in_channels to match --data-name.
    parser.add_argument("-inc", "--in-channels", default=3, type=int, help="number of channels of input")
    # misc
    parser.add_argument("--test", default=False, type=str2bool, help="train one epoch for test")
    # tensorboard recording is temporarily enabled by default
    parser.add_argument("--board", default=True, type=str2bool, help="using tensorboard to record")
    return parser
def infogan_options(parser):
    """Register InfoGAN latent-code options on *parser*.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` (options are added in place).
    """
    parser.add_argument("-ncz", default=3, type=int, help="number of continuous factors")
    # ``type=list`` would turn "10" into ['1', '0']; accept one or more ints
    # instead, e.g. ``-ndlist 10 10``.
    parser.add_argument("-ndlist", default=[10], type=int, nargs="+",
                        help="number of classes of each discrete factors")
    # The dest is "lambda", a Python keyword: read it with getattr(opt, "lambda").
    parser.add_argument("--lambda", default=1, type=float, help="hyperparameters of mutual information")
    return parser
def model_options(parser):
    """Add model-architecture options to *parser* and hand the parser back."""
    latent_help = "dimension of latent variable z"
    parser.add_argument("--dim-z", type=int, default=100, help=latent_help)
    return parser
def get_traverse_options():
    """Parse command-line options for the latent-traversal script.

    Traversal-specific options are registered first, followed by the model,
    InfoGAN and general training option groups; ``sys.argv`` is then parsed.

    Returns:
        argparse.Namespace with all registered options.
    """

    def str2bool(value):
        """Parse a command-line boolean such as ``True``/``false``/``1``/``no``."""
        # ``type=bool`` would map every non-empty string, including "False",
        # to True, so --fixmode could never be turned off explicitly.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1", "on"):
            return True
        if lowered in ("false", "f", "no", "n", "0", "off"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True, help="full path of saved netG.pth")
    parser.add_argument("--didx", type=int, default=-1, help="index of desired traversal discrete variable")
    parser.add_argument("--cidx", type=int, default=-1, help="index of desired traversal continuous variable")
    # ``type=tuple`` would split "-2,2" into characters; take two floats
    # instead, e.g. ``--c-range -3 3`` (parsed into a list).
    parser.add_argument("--c-range", type=float, nargs=2, default=(-2, 2),
                        metavar=("LOW", "HIGH"), help="range of continuous variable in traversal")
    parser.add_argument("--out-name", default="test", type=str, help="name of output images")
    # Shares dest "seed" with the "-seed" option added by normal_options.
    # This action is registered first, so its default (5224) is the one applied.
    parser.add_argument("--seed", default=5224, type=int, help="random seed")
    parser.add_argument("--fixmode", default=False, type=str2bool, help="using fix mode, fix targeted variables")
    parser = model_options(parser)
    parser = infogan_options(parser)
    parser = normal_options(parser)
    return parser.parse_args()
def choose_dataset(opt):
    """Build the data loader for ``opt.data_name`` and sync dataset-dependent options.

    Side effects on *opt*:
      * ``opt.in_channels`` is always overwritten with the channel count of
        the chosen dataset.
      * ``opt.data_path`` is kept when the caller set it explicitly. When it
        is missing or still equal to the CLI default ``"data"``, it is
        replaced by the legacy machine-specific location for that dataset.

    Args:
        opt: options object exposing ``data_name``, ``batch_size`` and
            ``num_workers`` (e.g. the namespace returned by ``get_options``).

    Returns:
        Whatever the matching ``get_*`` loader from ``data`` returns.

    Raises:
        NotImplementedError: if ``opt.data_name`` is not a known dataset.
    """
    # dataset name -> (legacy default path, in_channels, loader)
    specs = {
        "MNIST": ("/home/victorchen/workspace/Venus/torch_download/", 1, get_mnist),
        "CIFAR10": ("/home/victorchen/workspace/Venus/torch_download/", 3, get_cifar10),
        "FASHION": ("/home/victorchen/workspace/Venus/torch_download/FashionMNIST", 1, get_fashion),
        "SVHN": ("/home/victorchen/workspace/Venus/torch_download/svhn", 3, get_svhn),
        "CELEBA": ("/home/victorchen/workspace/Venus/celebA", 3, get_unlabeled_celebA),
    }
    data_name = opt.data_name
    if data_name not in specs:
        raise NotImplementedError("Not implemented dataset: {}".format(data_name))
    legacy_path, in_channels, loader = specs[data_name]
    # Previously --data-path was unconditionally replaced by these
    # developer-specific paths, so the flag was silently ignored. Only fall
    # back to them when no path was chosen ("data" is the CLI default).
    if getattr(opt, "data_path", None) in (None, "data"):
        setattr(opt, "data_path", legacy_path)
    setattr(opt, "in_channels", in_channels)
    return loader(opt.data_path, opt.batch_size, opt.num_workers)
class _MetaOptions:
    """Minimal attribute bag that behaves like a parsed options namespace.

    Build instances with :meth:`kws2opts` or :meth:`dict2opts`; nested
    dictionaries become nested ``_MetaOptions`` objects.
    """

    def __str__(self):
        # Render as "key:value" pairs joined by ";" in attribute-insertion order.
        return ";".join("{}:{}".format(name, value) for name, value in vars(self).items())

    @staticmethod
    def kws2opts(**kws):
        """Convert keyword arguments into an options object (see ``dict2opts``)."""
        return _MetaOptions.dict2opts(kws)

    @staticmethod
    def dict2opts(d: dict):
        """Convert a dict, recursing into dict values, into an options object.

        Raises:
            AttributeError: if any key, at any nesting level, is not a ``str``.
        """
        def fill(target, mapping):
            for name, value in mapping.items():
                if not isinstance(name, str):
                    raise AttributeError(
                        "Not allowed key in dict with type:{}".format(type(name)))
                if not isinstance(value, dict):
                    setattr(target, name, value)
                    continue
                # Attach the child before populating it (depth-first).
                child = _MetaOptions()
                setattr(target, name, child)
                fill(child, value)
            return target

        return fill(_MetaOptions(), d)
if __name__ == "__main__":
    # Quick manual check of the options container.
    demo = _MetaOptions.kws2opts(name="test", lr=1e-3, epochs=20)
    print(demo.name)
    learning_rate, n_epochs = demo.lr, demo.epochs
    print(learning_rate, n_epochs)