-
Notifications
You must be signed in to change notification settings - Fork 0
/
multi_label_vgg_nn.py
27 lines (22 loc) · 1.13 KB
/
multi_label_vgg_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from preprocessing import get_attribute_dims
from classifier import get_pretrained_model, create_attributes_model, AttributeFCN
# Label files: the attribute annotations (CSV) and the label-values JSON
# (presumably the set of possible values per attribute — confirm), which
# get_attribute_dims reads to size the per-attribute outputs.
labels_file = "./data/data/attributes.csv"
label_values_file = "./data/data/label_values.json"

# Training and validation image folders.
# NOTE(review): validation images are read from the "test" folder; if a
# separate held-out test set is wanted, it should not double as validation.
TRAIN_IMAGES_FOLDER = "./data/data/train"
VALID_IMAGES_FOLDER = "./data/data/test"

# TODO: no separate test-image folder is configured yet (a TEST_IMAGES_FOLDER
# constant was stubbed out here with an empty path).
def main():
    """Build and train the VGG16-backed multi-label attribute models.

    Reads the per-attribute output dimensions from ``label_values_file``,
    loads a pretrained VGG16 convolutional model, and passes both (plus the
    label file and train/validation image folders) to
    ``create_attributes_model`` with ``is_train=True`` for 10 epochs on CPU.

    Returns:
        The value returned by ``create_attributes_model`` (presumably the
        trained per-attribute models — TODO confirm in classifier.py).
    """
    # Per-attribute output sizes (presumably class counts per attribute).
    target_dims = get_attribute_dims(label_values_file)
    print(target_dims)

    # Only the first return value (the conv model) is used; the other two
    # are discarded. pop_last_pool_layer=True presumably strips VGG16's final
    # pooling layer — NOTE(review): confirm in get_pretrained_model.
    pretrained_conv_model, _, _ = get_pretrained_model(
        "vgg16", pop_last_pool_layer=True
    )

    # 512 is passed as the AttributeFCN input width; presumably it matches
    # the channel count of VGG16's last conv block — confirm.
    attribute_models = create_attributes_model(
        AttributeFCN,
        512,
        pretrained_conv_model,
        target_dims,
        "weights/vgg16-fcn-266-2/",
        labels_file,
        TRAIN_IMAGES_FOLDER,
        VALID_IMAGES_FOLDER,
        num_epochs=10,
        is_train=True,
        use_gpu=False,
    )
    return attribute_models


if __name__ == "__main__":
    main()