
[onert-micro] how to support on-device training on model with GRU ? #13365

Open
chunseoklee opened this issue Jul 9, 2024 · 6 comments
Labels
FEATURE_REQUEST A formal request for a new or advanced feature. type/discussion We need discussion. Discussion itself can help. Even without conclusions!

Comments

@chunseoklee
Contributor

GRU operation in circle can be defined in two ways. During conversion from Keras, it may be converted into:

  • multiple subgraphs built from primitive operations, or
  • a single "Custom" GRU operation as in onert-micro.

IMHO, onert-micro is not ready to handle training on multi-subgraph models.
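
(If it helps the discussion: one way to check which of the two forms the converter actually emitted is to dump the operator list of the converted model. tf.lite.experimental.Analyzer is available in recent TF releases; the model path below is just a placeholder.)

import tensorflow as tf

# Prints the subgraph(s) and operator list of a converted model. A decomposed
# GRU shows up as many primitive ops (often a WHILE over extra subgraphs),
# while a fused one appears as a single custom operator.
tf.lite.experimental.Analyzer.analyze(model_path="GRU.tflite")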

@chunseoklee chunseoklee added FEATURE_REQUEST A formal request for a new or advanced feature. type/discussion We need discussion. Discussion itself can help. Even without conclusions! labels Jul 9, 2024
@chunseoklee
Contributor Author

@BalyshevArtem PTAL

@BalyshevArtem
Contributor

  • Single "Custom" GRU operation as in (onert-micro)

I think it is better to use the custom GRU. It will also have better latency and memory consumption. And in my opinion it is easier to support (maybe I am wrong).
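
For context, a fused GRU kernel only has to implement (and, for training, backpropagate through) one small recurrence per time step. A minimal numpy sketch of one step, assuming the Keras gate ordering [z, r, h] and the default reset_after=True convention, with the input and recurrent biases folded into one bias for brevity (all names here are illustrative):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, W, U, b):
    # x_t: (batch, input_dim), h_prev: (batch, units)
    # W: (input_dim, 3*units), U: (units, 3*units), b: (3*units,)
    units = h_prev.shape[-1]
    xw = x_t @ W + b   # input contribution to all three gates
    hu = h_prev @ U    # recurrent contribution to all three gates
    z = sigmoid(xw[:, :units] + hu[:, :units])                # update gate
    r = sigmoid(xw[:, units:2*units] + hu[:, units:2*units])  # reset gate
    h_cand = np.tanh(xw[:, 2*units:] + r * hu[:, 2*units:])   # candidate state
    return z * h_prev + (1.0 - z) * h_cand                    # new hidden state

Iterating this over all time steps with a single hidden-state tensor is what keeps the fused op's memory footprint small compared with materializing the looped or unrolled subgraphs.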

@hseok-oh hseok-oh changed the title [onert-miro] how to support on-device training on model with GRU ? [onert-micro] how to support on-device training on model with GRU ? Jul 10, 2024
@chunseoklee
Contributor Author

chunseoklee commented Aug 7, 2024

Here are a reference GRU model and a fused GRU model produced by #13602:

gru_fused.zip

The tflite model is generated by the following code:

  import tensorflow as tf
  from tensorflow import keras
  from tensorflow.keras import regularizers
  import numpy as np

  adapt_data = np.array([[0., 7., 4. , 0.5],
                         [2., 9., 6. , -0.5],
                         [0., 7., 4. , -0.5],
                         [2., 9., 6. , 0.5]], dtype='float32')
  normalization_layer = tf.keras.layers.Normalization(axis=-1)
  normalization_layer.adapt(adapt_data)
  classes = 4
  activation = 'tanh'
  model = tf.keras.models.Sequential([
      tf.keras.Input(shape=(10,4)),
      normalization_layer,
      tf.keras.layers.GRU(units=20, activation=activation, use_bias=True, bias_initializer="ones"),
      tf.keras.layers.Dense(classes, activation='softmax')
  ])

  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))

  model.summary()

  run_model = tf.function(lambda x: model(x))

  # This is important, let's fix the input size.
  BATCH_SIZE = 1
  X = 10
  Y = 4
  concrete_func = run_model.get_concrete_function(
      tf.TensorSpec([BATCH_SIZE, X,Y], model.inputs[0].dtype))

  # model directory.
  MODEL_DIR = "keras_model"
  model.save(MODEL_DIR, save_format="tf", signatures=concrete_func)

  converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
  converter.experimental_new_converter = True
  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                         ]
  #converter.optimizations = [tf.lite.Optimize.DEFAULT]
  converted_model = converter.convert()
  save_to = "GRU.tflite"
  if save_to is not None:
      with open(save_to, 'wb') as tf_lite_file:
          tf_lite_file.write(converted_model)

and then apply #13625.
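
As a quick sanity check, the converted model can be run with the TFLite interpreter; the random input below just matches the [BATCH_SIZE, X, Y] = [1, 10, 4] shape fixed above:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="GRU.tflite")
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

x = np.random.rand(1, 10, 4).astype(np.float32)  # [BATCH_SIZE, X, Y]
interpreter.set_tensor(inp['index'], x)
interpreter.invoke()
print(interpreter.get_tensor(out['index']))  # shape (1, 4): softmax probabilities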

@chunseoklee
Contributor Author

Let's try to train the GRU operation with the model at #13365 (comment).

@chunseoklee
Contributor Author

chunseoklee commented Aug 13, 2024

@BalyshevArtem
Contributor

BalyshevArtem commented Aug 21, 2024

Training result

Here is the training result for #13737.

Model obtained from:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
import numpy as np

classes = 4
activation = 'tanh'
model = tf.keras.models.Sequential([
  tf.keras.Input(shape=(60,3)),
  tf.keras.layers.GRU(units=60, activation=activation),
  tf.keras.layers.Dense(classes, activation='softmax')
])

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
model.summary()
run_model = tf.function(lambda x: model(x))

# This is important, let's fix the input size.
BATCH_SIZE = 1
X = 60
Y = 3
concrete_func = run_model.get_concrete_function(
  tf.TensorSpec([BATCH_SIZE, X,Y], model.inputs[0].dtype))

# model directory.
MODEL_DIR = "keras_model"
model.save(MODEL_DIR, save_format="tf", signatures=concrete_func)

converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                     ]
#converter.optimizations = [tf.lite.Optimize.DEFAULT]
converted_model = converter.convert()
save_to = "gru_stick.tflite"
if save_to is not None:
  with open(save_to, 'wb') as tf_lite_file:
      tf_lite_file.write(converted_model)

The training data is the data for the targeted model. In this experiment, 1000 random samples from the original training data were used for training and 150 for testing. The task is a classification task; I used cross entropy as the loss and accuracy as the metric.
To verify that the GRU layer is actually learning, we first train only the last FullyConnected layer of the initial model, and then train both the FullyConnected layer and the GRU layer.
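
For reference, the same two-stage procedure in Keras would look roughly like this (a minimal sketch reusing the model defined above; the random arrays are only stand-ins for the real dataset):

import numpy as np
import tensorflow as tf

# Stand-in data matching the shapes in this experiment
# (1000 train / 150 test samples, 4 classes).
x_train = np.random.rand(1000, 60, 3).astype(np.float32)
y_train = np.random.randint(0, 4, size=(1000,))
x_test = np.random.rand(150, 60, 3).astype(np.float32)
y_test = np.random.randint(0, 4, size=(150,))

# Stage 1: freeze everything except the last Dense (FullyConnected) layer.
for layer in model.layers:
    layer.trainable = False
model.layers[-1].trainable = True
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

# Stage 2: also unfreeze the GRU layer and continue training.
model.layers[0].trainable = True  # the GRU layer in the Sequential model
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))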

Initial values:

Test Average ACCURACY = 0.34
Test Average CROSS ENTROPY = 2.54871

Train only last (FullyConnected) layer:

Test Average ACCURACY = 0.61
Test Average CROSS ENTROPY = 0.898501

Train last FullyConnected + GRU:

Test Average ACCURACY = 0.72
Test Average CROSS ENTROPY = 0.751191

Thus, the GRU layer is indeed being trained and helps achieve better results on this task.
