Skip to content

Commit

Permalink
Make use of model_name and classifier consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
auroracramer committed Jul 9, 2019
1 parent 9ce6b2a commit e3c556b
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 37 deletions.
2 changes: 1 addition & 1 deletion birdvoxclassify/__init__.py
@@ -1,5 +1,5 @@
from .version import version as __version__
from .core import predict, get_output_path, process_file, load_model, \
from .core import predict, get_output_path, process_file, load_classifier, \
format_pred, format_pred_batch, compute_pcen, \
get_taxonomy_node, get_taxonomy_path, get_model_path, \
get_pcen_settings, batch_generator, DEFAULT_MODEL_NAME
12 changes: 6 additions & 6 deletions birdvoxclassify/cli.py
Expand Up @@ -37,7 +37,7 @@ def get_file_list(input_list):


def run(inputs, output_dir=None, output_summary_path=None,
classifier_name=DEFAULT_MODEL_NAME, batch_size=512, suffix="",
model_name=DEFAULT_MODEL_NAME, batch_size=512, suffix="",
logger_level=logging.INFO):
"""Runs classification model on input audio clips"""
# Set logger level.
Expand Down Expand Up @@ -67,7 +67,7 @@ def run(inputs, output_dir=None, output_summary_path=None,
file_list,
output_dir=output_dir,
output_summary_path=output_summary_path,
classifier_name=classifier_name,
model_name=model_name,
batch_size=batch_size,
suffix=suffix,
logger_level=logger_level)
Expand Down Expand Up @@ -98,9 +98,9 @@ def parse_args(args):
help='Directory to save individual output file(s)')

parser.add_argument(
'--classifier-name', '-c', default=DEFAULT_MODEL_NAME,
dest='classifier_name',
help='Name of bird species classifier to be used.')
'--model-name', '-c', default=DEFAULT_MODEL_NAME,
dest='model_name',
help='Name of bird species classifier model to be used.')

parser.add_argument(
'--batch-size', '-b', type=positive_int, default=512, dest='batch_size',
Expand Down Expand Up @@ -155,7 +155,7 @@ def main():
run(args.inputs,
output_dir=args.output_dir,
output_summary_path=args.output_summary_path,
classifier_name=args.classifier_name,
model_name=args.model_name,
batch_size=args.batch_size,
suffix=args.suffix,
logger_level=logger_level)
Expand Down
10 changes: 5 additions & 5 deletions birdvoxclassify/core.py
Expand Up @@ -65,7 +65,7 @@ def process_file(filepaths, output_dir=None, output_summary_path=None,

# Load the classifier.
if classifier is None:
classifier = load_model(model_name, custom_objects=custom_objects)
classifier = load_classifier(model_name, custom_objects=custom_objects)

if taxonomy is None:
taxonomy_path = get_taxonomy_path(model_name)
Expand Down Expand Up @@ -586,7 +586,7 @@ def get_model_path(model_name):
return os.path.abspath(path)


def load_model(model_name, custom_objects=None):
def load_classifier(model_name, custom_objects=None):
"""
Loads bird species classification model of the given name.
Expand All @@ -600,7 +600,7 @@ def load_model(model_name, custom_objects=None):
Returns
-------
model : keras.models.Model
classifier : keras.models.Model
Bird species classification model
"""
Expand All @@ -614,15 +614,15 @@ def load_model(model_name, custom_objects=None):
# Suppress TF and Keras warnings when importing
warnings.simplefilter("ignore")
import keras
model = keras.models.load_model(
classifier = keras.models.load_model(
model_path, custom_objects=custom_objects)
except Exception:
exc_str = 'Could not open model "{}":\n{}'
formatted_trace = traceback.format_exc()
exc_formatted_str = exc_str.format(model_path, formatted_trace)
raise BirdVoxClassifyError(exc_formatted_str)

return model
return classifier


def get_taxonomy_path(model_name):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Expand Up @@ -69,7 +69,7 @@ def test_parse_args():
args = parse_args(args)
assert args.output_dir is None
assert args.output_summary_path is None
assert args.classifier_name == MODEL_NAME
assert args.model_name == MODEL_NAME
assert args.batch_size == 512
assert args.suffix == ""
assert args.quiet is False
Expand All @@ -86,7 +86,7 @@ def test_parse_args():
args = parse_args(args)
assert args.output_dir == '/tmp/output/dir'
assert args.output_summary_path == '/tmp/summary.json'
assert args.classifier_name == MODEL_NAME
assert args.model_name == MODEL_NAME
assert args.batch_size == 16
assert args.suffix == 'suffix'
assert args.quiet is True
Expand Down
47 changes: 24 additions & 23 deletions tests/test_core.py
Expand Up @@ -34,7 +34,7 @@
def test_process_file():
test_output_dir = tempfile.mkdtemp()
test_audio_dir = tempfile.mkdtemp()
model = load_model(MODEL_NAME)
classifier = load_classifier(MODEL_NAME)
with open(TAXV1_HIERARCHICAL_PATH) as f:
taxonomy = json.load(f)
test_output_summary_path = os.path.join(test_output_dir, "summary.json")
Expand All @@ -46,23 +46,24 @@ def test_process_file():

try:
# Test with defaults
output = process_file(CHIRP_PATH, classifier_name=MODEL_NAME)
output = process_file(CHIRP_PATH, model_name=MODEL_NAME)
assert type(output) == dict
assert len(output) == 1
for k, v in output.items():
assert isinstance(k, string_types)
assert type(v) == dict

# Test with list
output = process_file([CHIRP_PATH], classifier_name=MODEL_NAME)
output = process_file([CHIRP_PATH], model_name=MODEL_NAME)
assert type(output) == dict
assert len(output) == 1
for k, v in output.items():
assert isinstance(k, string_types)
assert type(v) == dict

# Test with given classifier and taxonomy
output = process_file([CHIRP_PATH], classifier=model, taxonomy=taxonomy)
output = process_file([CHIRP_PATH], classifier=classifier,
taxonomy=taxonomy)
assert type(output) == dict
assert len(output) == 1
for k, v in output.items():
Expand All @@ -71,7 +72,7 @@ def test_process_file():

# Test output_dir
output = process_file([CHIRP_PATH], output_dir=test_output_dir,
classifier=model)
classifier=classifier)
assert type(output) == dict
assert len(output) == 1
for k, v in output.items():
Expand All @@ -83,7 +84,7 @@ def test_process_file():

# Test output dir with suffix
output = process_file([CHIRP_PATH], output_dir=test_output_dir,
classifier=model, suffix='suffix')
classifier=classifier, suffix='suffix')
assert type(output) == dict
assert len(output) == 1
for k, v in output.items():
Expand All @@ -96,7 +97,7 @@ def test_process_file():
# Test output summary file
output = process_file([CHIRP_PATH],
output_summary_path=test_output_summary_path,
classifier=model)
classifier=classifier)
assert type(output) == dict
assert len(output) == 1
for k, v in output.items():
Expand All @@ -119,15 +120,15 @@ def test_process_file():
test_audio_list = [test_a_path, test_b_path, test_c_path, test_d_path]

# Test multiple files
output = process_file(test_audio_list, classifier=model)
output = process_file(test_audio_list, classifier=classifier)
assert type(output) == dict
assert len(output) == len(test_audio_list)
for k, v in output.items():
assert isinstance(k, string_types)
assert type(v) == dict

# Test with different batch_sizes
output = process_file(test_audio_list, classifier=model, batch_size=2)
output = process_file(test_audio_list, classifier=classifier, batch_size=2)
assert type(output) == dict
assert len(output) == len(test_audio_list)
for k, v in output.items():
Expand All @@ -137,7 +138,7 @@ def test_process_file():
# Make sure we create the output dir if it doesn't exist
shutil.rmtree(test_output_dir)
output = process_file([CHIRP_PATH], output_dir=test_output_dir,
classifier=model)
classifier=classifier)
assert type(output) == dict
assert len(output) == 1
for k, v in output.items():
Expand Down Expand Up @@ -475,35 +476,35 @@ def test_compute_pcen():


def test_predict():
model = load_model(MODEL_NAME)
classifier = load_classifier(MODEL_NAME)

audio, sr = sf.read(CHIRP_PATH, dtype='float64')
pcen = compute_pcen(audio, sr)
pred = predict(pcen, model, logging.INFO)
pred = predict(pcen, classifier, logging.INFO)
assert type(pred) == list
assert pred[0].shape == (1, 1)
assert pred[1].shape == (1, 5)
assert pred[2].shape == (1, 15)

gen = batch_generator([CHIRP_PATH]*10, batch_size=10)
batch, batch_filepaths = next(gen)
pred = predict(batch, model, logging.INFO)
pred = predict(batch, classifier, logging.INFO)
assert type(pred) == list
assert pred[0].shape == (10, 1)
assert pred[1].shape == (10, 5)
assert pred[2].shape == (10, 15)

# Test invalid inputs
inv_pcen = compute_pcen(audio, sr, input_format=False)[..., np.newaxis]
pytest.raises(BirdVoxClassifyError, predict, inv_pcen, model, logging.INFO)
pytest.raises(BirdVoxClassifyError, predict, np.array([1, 2, 3, 4]), model,
pytest.raises(BirdVoxClassifyError, predict, inv_pcen, classifier, logging.INFO)
pytest.raises(BirdVoxClassifyError, predict, np.array([1, 2, 3, 4]), classifier,
logging.INFO)
pytest.raises(BirdVoxClassifyError, predict, np.zeros((1, 42, 104, 1)),
model, logging.INFO)
classifier, logging.INFO)
pytest.raises(BirdVoxClassifyError, predict, np.zeros((1, 120, 42, 1)),
model, logging.INFO)
classifier, logging.INFO)
pytest.raises(BirdVoxClassifyError, predict, np.zeros((1, 120, 104, 42)),
model, logging.INFO)
classifier, logging.INFO)


def test_get_output_path():
Expand Down Expand Up @@ -598,21 +599,21 @@ def test_get_model_path():
assert os.path.abspath(model_path) == os.path.abspath(exp_model_path)


def test_load_model():
model = load_model(MODEL_NAME)
assert type(model) == keras.models.Model
def test_load_classifier():
classifier = load_classifier(MODEL_NAME)
assert type(classifier) == keras.models.Model

# Test invalid inputs
invalid_path = get_model_path("invalid-classifier")
with open(invalid_path, "w") as f:
f.write("INVALID")

try:
pytest.raises(BirdVoxClassifyError, load_model, "invalid-classifier")
pytest.raises(BirdVoxClassifyError, load_classifier, "invalid-classifier")
finally:
os.remove(invalid_path)

pytest.raises(BirdVoxClassifyError, load_model, "/invalid/path")
pytest.raises(BirdVoxClassifyError, load_classifier, "/invalid/path")


def test_get_taxonomy_path():
Expand Down

0 comments on commit e3c556b

Please sign in to comment.