# Source: test_efficientnet.py, forked from tensorflow/similarity (129 lines, 96 loc, 4.01 KB).
import re
import pytest
import tensorflow as tf
from tensorflow_similarity.architectures import efficientnet
def test_build_effnet_b0_full():
    """Building EfficientNetB0 with trainable="full" leaves every layer trainable."""
    inputs = tf.keras.layers.Input((224, 224, 3))
    out = efficientnet.build_effnet(inputs, "b0", "imagenet", "full")
    # Recover the backbone model from the output tensor's Keras history.
    backbone = out._keras_history.layer

    assert backbone.name == "efficientnetb0"
    assert backbone.trainable

    tracked = list(backbone._self_tracked_trackables)
    num_trainable = sum(1 for layer in tracked if layer.trainable)
    # "full" mode: all 237 layers exist and all of them are trainable.
    assert len(tracked) == 237
    assert num_trainable == 237
def test_build_effnet_b1_frozen():
    """Building EfficientNetB1 with trainable="frozen" freezes every layer."""
    inputs = tf.keras.layers.Input((240, 240, 3))
    out = efficientnet.build_effnet(inputs, "b1", "imagenet", "frozen")
    # Recover the backbone model from the output tensor's Keras history.
    backbone = out._keras_history.layer

    assert backbone.name == "efficientnetb1"
    assert not backbone.trainable

    tracked = list(backbone._self_tracked_trackables)
    num_trainable = sum(1 for layer in tracked if layer.trainable)
    # "frozen" mode: all 339 layers exist and none of them are trainable.
    assert len(tracked) == 339
    assert num_trainable == 0
def test_build_effnet_b0_partial():
    """Building EfficientNetB0 with trainable="partial" unfreezes only the
    expected layers.

    Expected contract: only blocks 5-7 and the top are trainable, and no
    BatchNormalization layer is trainable.
    """
    input_layer = tf.keras.layers.Input((224, 224, 3))
    output = efficientnet.build_effnet(input_layer, "b0", "imagenet", "partial")
    effnet = output._keras_history.layer
    assert effnet.name == "efficientnetb0"
    assert effnet.trainable

    # Fix: the original pattern "^block[5,6,7]|^top" used a comma inside the
    # character class, which also matched a literal "," (harmless here since
    # layer names contain no commas, but not the intended pattern). Use a raw
    # string with the intended class and compile once outside the loop.
    allowed = re.compile(r"^block[567]|^top")

    total_layer_count = 0
    trainable_layer_count = 0
    excluded_layers = 0
    for layer in effnet._self_tracked_trackables:
        total_layer_count += 1
        if layer.trainable:
            trainable_layer_count += 1
            # A trainable layer outside blocks 5-7/top, or any trainable
            # BatchNormalization layer, violates the "partial" contract.
            if not allowed.search(layer.name):
                excluded_layers += 1
            if isinstance(layer, tf.keras.layers.BatchNormalization):
                excluded_layers += 1
    assert total_layer_count == 237
    assert trainable_layer_count == 93
    assert excluded_layers == 0
def test_build_effnet_unsupported_trainable():
    """An unknown value for 'trainable' raises a descriptive ValueError."""
    inputs = tf.keras.layers.Input((224, 224, 3))
    msg = "foo is not a supported option for 'trainable'."
    with pytest.raises(ValueError, match=msg):
        efficientnet.build_effnet(inputs, "b0", "imagenet", "foo")
def test_unsuported_varient():
    """An unknown EfficientNet variant name raises a descriptive ValueError."""
    msg = "Unknown efficientnet variant. Valid B0...B7"
    with pytest.raises(ValueError, match=msg):
        efficientnet.EfficientNetSim((224, 224, 3), 128, "bad_varient")
def test_include_top():
    """With include_top=True the model ends in GeM pooling + MetricEmbedding."""
    model = efficientnet.EfficientNetSim((224, 224, 3), include_top=True)

    # include_top=True places a GeM pooling layer (default p=3.0) second to last.
    pool = model.layers[-2]
    assert pool.name == 'gem_pool'
    assert pool.p == 3.0

    # l2_norm defaults to True, so the final layer is a MetricEmbedding.
    assert model.layers[-1].name.startswith('metric_embedding')
def test_l2_norm_false():
    """With l2_norm=False the model ends in GeM pooling + a plain Dense layer."""
    model = efficientnet.EfficientNetSim(
        (224, 224, 3),
        include_top=True,
        l2_norm=False)

    # include_top=True places a GeM pooling layer (default p=3.0) second to last.
    pool = model.layers[-2]
    assert pool.name == 'gem_pool'
    assert pool.p == 3.0

    # Without L2 normalization the output head is a plain Dense layer.
    assert model.layers[-1].name.startswith('dense')
@pytest.mark.parametrize(
    "pooling, name",
    zip(['gem', 'avg', 'max'], ['gem_pool', 'avg_pool', 'max_pool']),
    ids=['gem', 'avg', 'max']
)
def test_include_top_false(pooling, name):
    """With include_top=False the model ends in the requested pooling layer."""
    model = efficientnet.EfficientNetSim(
        (224, 224, 3),
        include_top=False,
        pooling=pooling)
    # No top head is attached, so the selected global pooling layer is last.
    assert model.layers[-1].name == name