diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 2227e128de84..63521a67b065 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -1101,6 +1101,7 @@ def keras_op_to_relay(inexpr, keras_layer, outname, etab): for t_idx, out in enumerate(outs): name = outname + ":" + str(t_idx) etab.set_expr(name, out) + return outs def from_keras(model, shape=None, layout="NCHW"): @@ -1136,6 +1137,85 @@ def _convert_input_layer(keras_layer): input_shape = shape[input_name] if shape is not None and input_name in shape else None etab.set_expr(input_name, new_var(input_name, shape=input_shape)) + def _convert_layer(keras_layer, etab, scope=""): + inbound_nodes = ( + keras_layer.inbound_nodes + if hasattr(keras_layer, "inbound_nodes") + else keras_layer._inbound_nodes + if hasattr(keras_layer, "_inbound_nodes") + else None + ) + if inbound_nodes is None: + raise TypeError( + "Unknown layer type or unsupported Keras version : {}".format(keras_layer) + ) + outs = [] + for node_idx, node in enumerate(inbound_nodes): + # If some nodes in imported model are not relevant to the current model, + # skip such layers. + # - In Keras, model._network_nodes contains keys of all nodes relevant to the + # current model; + # - In tf.Keras, this is already done as part of tensorflow.keras.network.get_config + if ( + not is_tf_keras + and not model._node_key(keras_layer, node_idx) in model._network_nodes + ): + continue + inexpr = [] + # Since Keras allows creating multiple layers from the same name instance, + # we append node index to the expr name to make it unique. + # The one exception is InputLayer. Changing input variable names after conversion + # would confuse users, so we should keep them as far as possible. Fortunately, + # they are named uniquely to input_1, input_2, input_3... by default. + # node_indices attribute removed in tensorflow 2.3, however iterate_inbound() can + # be used + if hasattr(node, "node_indices"): + zip_node = zip( + _as_list(node.inbound_layers), + _as_list(node.node_indices), + _as_list(node.tensor_indices), + _as_list(node.input_tensors), + ) + node_attributes = zip_node + else: + node_attributes = node.iterate_inbound() + for inbound_layer, n_idx, t_idx, _ in node_attributes: + if isinstance(inbound_layer, input_layer_class): + expr_name = inbound_layer.name + _convert_input_layer(inbound_layer) + else: + expr_name = scope + inbound_layer.name + ":" + str(n_idx) + ":" + str(t_idx) + expr = etab.get_expr(expr_name) + inexpr.append(expr) + + # Handle nested layers + if hasattr(keras_layer, "layers"): + input_index = 0 + for layer in keras_layer.layers: + if isinstance(layer, input_layer_class): + # Replace input layer with inbound node + etab.set_expr(layer.name, inexpr[input_index]) + input_index += 1 + else: + # Convert child layer. Prepend scope with parent layer name. + layer_outs = _convert_layer(layer, etab, keras_layer.name + "_" + scope) + + # Get output of last child layer and mark as output of parent. + outname = keras_layer.name + ":" + str(node_idx) + for t_idx, out in enumerate(layer_outs): + name = outname + ":" + str(t_idx) + etab.set_expr(name, out) + outs.extend(layer_outs) + else: + if len(inexpr) == 1: + inexpr = inexpr[0] + outs.extend( + keras_op_to_relay( + inexpr, keras_layer, scope + keras_layer.name + ":" + str(node_idx), etab + ) + ) + return outs + is_tf_keras = _check_model_is_tf_keras() if not is_tf_keras: @@ -1174,57 +1254,8 @@ def _convert_input_layer(keras_layer): if isinstance(keras_layer, input_layer_class): _convert_input_layer(keras_layer) else: - inbound_nodes = ( - keras_layer.inbound_nodes - if hasattr(keras_layer, "inbound_nodes") - else keras_layer._inbound_nodes - if hasattr(keras_layer, "_inbound_nodes") - else None - ) - if inbound_nodes is None: - raise TypeError( - "Unknown layer type or unsupported Keras version : {}".format(keras_layer) - ) - for node_idx, node in enumerate(inbound_nodes): - # If some nodes in imported model are not relevant to the current model, - # skip such layers. - # - In Keras, model._network_nodes contains keys of all nodes relevant to the - # current model; - # - In tf.Keras, this is already done as part of tensorflow.keras.network.get_config - if ( - not is_tf_keras - and not model._node_key(keras_layer, node_idx) in model._network_nodes - ): - continue - inexpr = [] - # Since Keras allows creating multiple layers from the same name instance, - # we append node index to the expr name to make it unique. - # The one exception is InputLayer. Changing input variable names after conversion - # would confuse users, so we should keep them as far as possible. Fortunately, - # they are named uniquely to input_1, input_2, input_3... by default. - # node_indices attribute removed in tensorflow 2.3, however iterate_inbound() can - # be used - if hasattr(node, "node_indices"): - zip_node = zip( - _as_list(node.inbound_layers), - _as_list(node.node_indices), - _as_list(node.tensor_indices), - _as_list(node.input_tensors), - ) - node_attributes = zip_node - else: - node_attributes = node.iterate_inbound() - for inbound_layer, n_idx, t_idx, _ in node_attributes: - if isinstance(inbound_layer, input_layer_class): - expr_name = inbound_layer.name - _convert_input_layer(inbound_layer) - else: - expr_name = inbound_layer.name + ":" + str(n_idx) + ":" + str(t_idx) - expr = etab.get_expr(expr_name) - inexpr.append(expr) - if len(inexpr) == 1: - inexpr = inexpr[0] - keras_op_to_relay(inexpr, keras_layer, keras_layer.name + ":" + str(node_idx), etab) + _convert_layer(keras_layer, etab) + # model._output_coordinates contains out_node(oc[0]), node_index(oc[1]) and tensor_index(oc[2]) # Get all output nodes in etab using the name made from above values. # The out exprs were added to etab in keras_op_to_relay using this name. diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index e9420a55b6e8..709bebfc232c 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -579,6 +579,20 @@ def test_forward_global_pool3d(self, keras): keras_model = keras.models.Model(data, x) verify_keras_frontend(keras_model, layout="NDHWC") + def test_forward_nested_layers(self, keras): + sub_model = keras.applications.MobileNet( + include_top=False, weights="imagenet", input_shape=(224, 224, 3) + ) + keras_model = keras.Sequential( + [ + sub_model, + keras.layers.GlobalAveragePooling2D(), + keras.layers.Dense(1024, activation="relu"), + keras.layers.Dense(2, activation="sigmoid"), + ] + ) + verify_keras_frontend(keras_model) + if __name__ == "__main__": for k in [keras, tf_keras]: