/
extract_features.py
159 lines (132 loc) · 6.39 KB
/
extract_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
# is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
CLI for extracting image features.
"""
import argparse
import os
import pickle
import logging
from contextlib import ExitStack
from typing import List, Tuple
import mxnet as mx
import numpy as np
from . import arguments
from . import encoder
from . import utils
from .. import constants as C
from ..log import setup_main_logger
from ..utils import check_condition, determine_context
# Temporary logger, the real one (logging to a file probably, will be created
# in the main function)
logger = logging.getLogger(__name__)
def batching(iterable, n=1):
    """Yield consecutive chunks of *iterable*, each of length at most *n*.

    The final chunk is shorter when ``len(iterable)`` is not a multiple of
    ``n``. Works on any sliceable sequence (lists, strings, ...).
    """
    size = len(iterable)
    for start in range(0, size, n):
        # Slicing clamps at the end of the sequence automatically.
        yield iterable[start:start + n]
def get_pretrained_net(args: argparse.Namespace, context: mx.Context) -> Tuple[mx.mod.Module, Tuple[int, ...]]:
    """
    Build, bind and initialize an MXNet module wrapping the pretrained image CNN.

    :param args: Parsed CLI arguments; reads the encoder model path/epoch,
        the layer to extract from, the batch size and the source image size.
    :param context: Device context to bind the module on.
    :return: The bound, initialized module (inference only) and the shape of
        the extracted layer's output without the batch dimension.
    """
    # Init encoder. encoded_seq_len and num_embed are only relevant when the
    # encoder is used inside a full model, so dummy values are fine here.
    image_cnn_encoder_config = encoder.ImageLoadedCnnEncoderConfig(
        model_path=args.image_encoder_model_path,
        epoch=args.image_encoder_model_epoch,
        layer_name=args.image_encoder_layer,
        encoded_seq_len=0,
        num_embed=100,
        preextracted_features=False)
    image_cnn_encoder = encoder.ImageLoadedCnnEncoder(image_cnn_encoder_config)
    symbol = image_cnn_encoder.sym  # the net up to the chosen layer, before further encoding
    _, out_shapes, _ = symbol.infer_shape(source=(1,) + tuple(args.source_image_size))
    last_layer_shape = out_shapes[-1][1:]  # drop the (dummy, size-1) batch dimension
    # Create module for inference only.
    module = mx.mod.Module(symbol=symbol,
                           data_names=[C.SOURCE_NAME],
                           label_names=[],
                           context=context)
    module.bind(for_training=False,
                data_shapes=[(C.SOURCE_NAME, (args.batch_size,) + tuple(args.source_image_size))])
    # Init with pretrained net: Mixed takes parallel lists of name patterns
    # and initializers, which zip(*...) produces from the (pattern, init) pairs.
    initializers = image_cnn_encoder.get_initializers()
    init = mx.initializer.Mixed(*zip(*initializers))
    module.init_params(init)
    return module, last_layer_shape
def extract_features_forward(im, module, image_root, output_root, batch_size, source_image_size, context):
    """
    Run one batch of images through the network and return its features.

    :param im: Image paths for this batch, relative to ``image_root``.
    :param module: Bound MXNet module to run the forward pass on.
    :param image_root: Root directory the relative image paths are joined to.
    :param output_root: Directory used to build the per-image output paths.
    :param batch_size: Size the input batch tensor is allocated with.
    :param source_image_size: (channels, height, width) of the network input.
    :param context: Device context the batch tensor is allocated on.
    :return: Tuple of (features array, output file path per image).
    """
    data = mx.nd.zeros((batch_size,) + tuple(source_image_size), context)
    out_names = []
    # Load and preprocess each image into its slot of the batch tensor and
    # derive a flat output file name from its relative path.
    for slot, rel_path in enumerate(im):
        data[slot] = utils.load_preprocess_image(os.path.join(image_root, rel_path), source_image_size[1:])
        out_names.append(os.path.join(output_root, rel_path.replace("/", "_")))
    # Forward pass
    module.forward(mx.io.DataBatch([data]))
    feats = module.get_outputs()[0].asnumpy()
    # The last batch may be smaller than batch_size; drop the padding rows.
    if len(im) < batch_size:
        feats = feats[:len(im)]
    return feats, out_names
def read_list_file(inp: str) -> List[str]:
    """
    Read a text file and return its lines with trailing newlines removed.

    :param inp: Path to the input file (one entry per line).
    :return: List of lines; only the trailing newline is stripped, any other
        whitespace is preserved.
    """
    with open(inp, "r") as fd:
        return [line.rstrip("\n") for line in fd]
def main():
    """
    CLI entry point: extract CNN features for a list of images.

    Reads an image list file, runs each batch through the pretrained image
    encoder, saves one feature file per image, writes an index file with the
    feature file names, pickles the image/feature shapes, and copies the
    encoder model into the output folder.
    """
    setup_main_logger(file_logging=False, console=True)
    params = argparse.ArgumentParser(description='CLI to extract features from images.')
    arguments.add_image_extract_features_cli_args(params)
    args = params.parse_args()
    image_root = os.path.abspath(args.image_root)
    output_root = os.path.abspath(args.output_root)
    output_file = os.path.abspath(args.output)
    size_out_file = os.path.join(output_root, "image_feature_sizes.pkl")
    if os.path.exists(output_root):
        logger.info("Overwriting provided path {}.".format(output_root))
    else:
        os.makedirs(output_root)
    # read image list file
    image_list = read_list_file(args.input)
    # Fail early on an empty list: the shape pickle below needs at least one
    # processed batch (previously this surfaced as a NameError on `feats`).
    check_condition(len(image_list) > 0, "Empty image list file: {}".format(args.input))
    # Get pretrained net module (already bind)
    with ExitStack() as exit_stack:
        check_condition(len(args.device_ids) == 1, "extract_features only supports single device for now")
        context = determine_context(device_ids=args.device_ids,
                                    use_cpu=args.use_cpu,
                                    disable_device_locking=args.disable_device_locking,
                                    lock_dir=args.lock_dir,
                                    exit_stack=exit_stack)[0]
        module, _ = get_pretrained_net(args, context)
        # Extract features batch by batch, streaming feature file names to
        # the output index file as we go.
        num_batches = int(np.ceil(len(image_list) / args.batch_size))
        features_shape = None  # set after each batch; all batches share it
        with open(output_file, "w") as names_out:
            for i, im in enumerate(batching(image_list, args.batch_size)):
                logger.info("Processing batch {}/{}".format(i + 1, num_batches))
                # TODO: enable caching to reuse features and resume computation
                feats, out_names = extract_features_forward(im, module,
                                                            image_root,
                                                            output_root,
                                                            args.batch_size,
                                                            args.source_image_size,
                                                            context)
                # Save to disk (one feature file per image)
                out_file_names = utils.save_features(out_names, feats)
                # Write the basenames to the index file
                names_out.writelines(os.path.basename(name) + "\n" for name in out_file_names)
                features_shape = tuple(feats.shape[1:])
        # Save the image size and feature size
        with open(size_out_file, "wb") as sizes_out:
            pickle.dump({"image_shape": tuple(args.source_image_size),
                         "features_shape": features_shape}, sizes_out)
    # Copy image model to output_folder
    image_encoder_model_path = utils.copy_mx_model_to(args.image_encoder_model_path,
                                                      args.image_encoder_model_epoch,
                                                      output_root)
    logger.info("Files saved in {}, {} and {}.".format(output_file,
                                                       size_out_file,
                                                       image_encoder_model_path))
if __name__ == "__main__":
main()