Skip to content

Commit

Permalink
Improved feature generator
Browse files Browse the repository at this point in the history
-fixed imports
-batch size can now be controlled by an argument in the feature generator
  • Loading branch information
zimmerrol committed Jun 28, 2018
1 parent 94c0a7a commit 66ba674
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
7 changes: 4 additions & 3 deletions generate_features.py
Expand Up @@ -56,7 +56,8 @@ def data_generator(filenames, image_directory, batch_size=64):
@click.option("--encoder", "-e", default="VGG19", required=False, type=click.STRING)
@click.option("--layer-name", "-l", default="block5_conv4", required=False, type=click.STRING)
@click.option("--output-folder", "-o", default=".", required=False, type=click.Path(exists=True, file_okay=False, dir_okay=True))
def cmd(data_path, encoder, layer_name, output_folder):
@click.option("--batch-size", "-b", default=64, required=False, type=click.INT)
def cmd(data_path, encoder, layer_name, output_folder, batch_size):
# create data directory if it does not exist
os.makedirs(data_path, exist_ok=True)

Expand All @@ -73,14 +74,14 @@ def cmd(data_path, encoder, layer_name, output_folder):

with h5py.File(os.path.join(output_folder, "image.features.train.{0}.{1}.h5".format(encoder, layer_name)), "w") as h5:
index = 0
for batch in encode_features(model, filenames_train, os.path.join(data_path, "train2017")):
for batch in encode_features(model, filenames_train, os.path.join(data_path, "train2017"), batch_size=batch_size):
for item in batch:
h5.create_dataset(str(index), data=item, compression="lzf")
index += 1

with h5py.File(os.path.join(output_folder, "image.features.val.{0}.{1}.h5".format(encoder, layer_name)), "w") as h5:
index = 0
for batch in encode_features(model, filenames_val, os.path.join(data_path, "val2017")):
for batch in encode_features(model, filenames_val, os.path.join(data_path, "val2017"), batch_size=batch_size):
for item in batch:
h5.create_dataset(str(index), data=item, compression="lzf")
index += 1
Expand Down
Empty file added utility/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions utility/coco.py
Expand Up @@ -37,8 +37,8 @@

import json
import os
import download
from cache import cache
import utility.download as download
from utility.cache import cache

########################################################################

Expand Down
3 changes: 1 addition & 2 deletions utility/utility.py
@@ -1,6 +1,5 @@
import numpy as np
import hickle
import coco
import utility.coco as coco
import h5py

def load_validation_data(maximum_caption_length):
Expand Down

0 comments on commit 66ba674

Please sign in to comment.