forked from tensorflow/similarity
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix tensorflow#224 Add tests for architecture modules.
* Add test coverage for all model architectures * resnet18.build_resnet is now consistent with the other architectures. The function now connects the x input_layer to the model and returns the output layer of the model. * Remove the keras.application.imagenet_utils.preprocess_input() function from resnet18.build_resnet(). * Make the application of pooling and include_tup consistent for all architectures. Pooling is now first applied before checking include_top. * Add min_pixel_value param for visualization.visualize_views to ensure that we properly scale the images when plotting. * Update unsupervised_hello_world to include the prepocess scaling on all dataset. Previously it was not included in the callback and this would break the binary_accuracy in the EvalCallback().
- Loading branch information
1 parent
df0929a
commit 412ab0b
Showing
8 changed files
with
381 additions
and
209 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import re | ||
|
||
import pytest | ||
|
||
from tensorflow_similarity.architectures import resnet18 | ||
|
||
|
||
def test_include_top(): | ||
input_shape = (32, 32, 3) | ||
resnet = resnet18.ResNet18Sim(input_shape, include_top=True) | ||
|
||
# The second to last layer should use gem pooling when include_top is True | ||
assert resnet.layers[-2].name == 'gem_pool' | ||
assert resnet.layers[-2].p == 3.0 | ||
# The default is l2_norm True, so we expect the last layer to be | ||
# MetricEmbedding. | ||
assert re.match('metric_embedding', resnet.layers[-1].name) is not None | ||
|
||
|
||
def test_l2_norm_false(): | ||
input_shape = (32, 32, 3) | ||
resnet = resnet18.ResNet18Sim( | ||
input_shape, | ||
include_top=True, | ||
l2_norm=False) | ||
|
||
# The second to last layer should use gem pooling when include_top is True | ||
assert resnet.layers[-2].name == 'gem_pool' | ||
assert resnet.layers[-2].p == 3.0 | ||
# If l2_norm is False, we should return a dense layer as the last layer. | ||
assert re.match('dense', resnet.layers[-1].name) is not None | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"pooling, name", | ||
zip(['gem', 'avg', 'max'], ['gem_pool', 'avg_pool', 'max_pool']), | ||
ids=['gem', 'avg', 'max'] | ||
) | ||
def test_include_top_false(pooling, name): | ||
input_shape = (32, 32, 3) | ||
resnet = resnet18.ResNet18Sim( | ||
input_shape, | ||
include_top=False, | ||
pooling=pooling) | ||
|
||
# The second to last layer should use gem pooling when include_top is True | ||
assert resnet.layers[-1].name == name |
Oops, something went wrong.