[onert-micro] how to support on-device training on model with GRU ? #13365
Comments
@BalyshevArtem PTAL
I think it is better to use a custom GRU. It should also give better latency and memory consumption, and in my opinion it is easier to support (maybe I am wrong).
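To make concrete what a single custom (fused) GRU kernel would have to compute per time step, here is a minimal pure-Python sketch of one GRU cell step. It assumes the Keras convention (`reset_after=True`, new state `h = z * h_prev + (1 - z) * candidate`); `gru_step`, `zero_params` and the `params` layout are hypothetical names for illustration, not the onert-micro kernel or its API.

```python
import math


def sigmoid(v):
    """Numerically stable logistic function."""
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def gru_step(x, h_prev, params):
    """One GRU time step (sketch, assuming Keras-style reset_after=True).

    params maps a gate name ('z', 'r', 'h') to a tuple
    (W, U, b_in, b_rec): input weights, recurrent weights, and two biases.
    """
    def pre(gate):
        W, U, b_in, b_rec = params[gate]
        x_part = [a + b for a, b in zip(matvec(W, x), b_in)]
        h_part = [a + b for a, b in zip(matvec(U, h_prev), b_rec)]
        return x_part, h_part

    xz, hz = pre('z')
    xr, hr = pre('r')
    xh, hh = pre('h')
    z = [sigmoid(a + b) for a, b in zip(xz, hz)]  # update gate
    r = [sigmoid(a + b) for a, b in zip(xr, hr)]  # reset gate
    # The reset gate scales the recurrent part of the candidate state.
    cand = [math.tanh(a + rg * b) for a, rg, b in zip(xh, r, hh)]
    return [zg * hp + (1.0 - zg) * c for zg, hp, c in zip(z, h_prev, cand)]


def zero_params(units, features):
    """Hypothetical helper: all-zero weights and biases for every gate."""
    gate = ([[0.0] * features for _ in range(units)],
            [[0.0] * units for _ in range(units)],
            [0.0] * units, [0.0] * units)
    return {'z': gate, 'r': gate, 'h': gate}


# With all-zero parameters both gates are 0.5 and the candidate is 0,
# so the new state is half of the previous state.
h = gru_step([1.0, 2.0, 3.0], [1.0, -2.0], zero_params(2, 3))
```

Keeping the whole step in one function mirrors the idea of a single fused operation: the only state carried between time steps is `h`.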
Here is a reference GRU model and a fused GRU model by #13602. The tflite model is generated by the following code:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
import numpy as np

adapt_data = np.array([[0., 7., 4., 0.5],
                       [2., 9., 6., -0.5],
                       [0., 7., 4., -0.5],
                       [2., 9., 6., 0.5]], dtype='float32')
# Assumption: the original snippet uses normalization_layer without defining it;
# a default Keras Normalization layer is assumed here.
normalization_layer = tf.keras.layers.Normalization()
#normalization_layer.adapt(adapt_data)

classes = 4
activation = 'tanh'
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(10, 4)),
    normalization_layer,
    tf.keras.layers.GRU(units=20, activation=activation, use_bias=True, bias_initializer="ones"),
    tf.keras.layers.Dense(classes, activation='softmax')
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
model.summary()

run_model = tf.function(lambda x: model(x))
# This is important, let's fix the input size.
BATCH_SIZE = 1
X = 10
Y = 4
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec([BATCH_SIZE, X, Y], model.inputs[0].dtype))

# model directory.
MODEL_DIR = "keras_model"
model.save(MODEL_DIR, save_format="tf", signatures=concrete_func)

# Convert the saved model to tflite using only builtin ops.
converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
#converter.optimizations = [tf.lite.Optimize.DEFAULT]
converted_model = converter.convert()

save_to = "GRU.tflite"
if save_to is not None:
    with open(save_to, 'wb') as tf_lite_file:
        tf_lite_file.write(converted_model)

and then apply #13625.
Let's try to train the GRU operation with the model from #13365 (comment).
Training result

Here is the training result for #13737. The model was obtained from:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
import numpy as np

classes = 4
activation = 'tanh'
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(60, 3)),
    tf.keras.layers.GRU(units=60, activation=activation),
    tf.keras.layers.Dense(classes, activation='softmax')
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
model.summary()

run_model = tf.function(lambda x: model(x))
# This is important, let's fix the input size.
BATCH_SIZE = 1
X = 60
Y = 3
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec([BATCH_SIZE, X, Y], model.inputs[0].dtype))

# model directory.
MODEL_DIR = "keras_model"
model.save(MODEL_DIR, save_format="tf", signatures=concrete_func)

# Convert the saved model to tflite using only builtin ops.
converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
#converter.optimizations = [tf.lite.Optimize.DEFAULT]
converted_model = converter.convert()

save_to = "gru_stick.tflite"
if save_to is not None:
    with open(save_to, 'wb') as tf_lite_file:
        tf_lite_file.write(converted_model)

The training data is the data for the targeted model. In this experiment, 1000 random samples from the original training data were used for training and 150 for testing. The task is classification. I used cross entropy as the loss and accuracy as the metric. Initial values:
Train only last (FullyConnected) layer:
Train last FullyConnected + GRU:
Thus, it can be seen that the GRU layer is trained and helps to achieve better results in this task.
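For reference, the loss and metric named above can be sketched in a few lines of plain Python. This assumes the model already outputs softmax probabilities (as the Dense layer in the script does) and that labels are integer class indices; `cross_entropy` and `evaluate` are hypothetical helper names, not part of the onert-micro training API.

```python
import math


def cross_entropy(probs, label, eps=1e-12):
    """Cross-entropy loss for one sample with an integer class label."""
    # Clamp to eps so a zero probability does not produce log(0).
    return -math.log(max(probs[label], eps))


def evaluate(batch_probs, labels):
    """Return (mean cross entropy, accuracy) over a batch."""
    losses = [cross_entropy(p, y) for p, y in zip(batch_probs, labels)]
    # A prediction is correct when the most probable class equals the label.
    correct = sum(1 for p, y in zip(batch_probs, labels)
                  if max(range(len(p)), key=p.__getitem__) == y)
    return sum(losses) / len(labels), correct / len(labels)


# Two samples with 4 classes: one uniform prediction, one confident mistake.
batch_probs = [[0.25, 0.25, 0.25, 0.25], [0.7, 0.1, 0.1, 0.1]]
labels = [0, 1]
mean_loss, accuracy = evaluate(batch_probs, labels)
```

Comparing these two numbers before training, after training only the last FullyConnected layer, and after training FullyConnected + GRU is the kind of comparison reported in this comment.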
The GRU operation in circle can be defined in two ways. During conversion from Keras, it may be converted into:
IMHO, onert-micro is not ready to handle training on multiple subgraphs.