
How to compile tensorflow model with input_signature? #802

Closed

mostafafarzaneh opened this issue Dec 12, 2023 · 11 comments

@mostafafarzaneh

Currently, I use tensorflow_model_server to serve my model for inference. Here is the export code that works fine:

import tensorflow as tf

model = tf.keras.models.load_model("model.hdf5")

def __decode_images(images, nch):
    o = tf.vectorized_map(lambda x: tf.image.decode_jpeg(x, nch), images)
    o = tf.image.resize(o, (128,128))
    o = tf.cast(o, dtype=tf.float16) / 255
    o = tf.reverse(o, axis=[-1])  # RGB2BGR
    return o


def __encode_images(images):
    images = tf.image.convert_image_dtype(images, tf.uint8, saturate=True)
    o = tf.vectorized_map(tf.image.encode_jpeg, images)
    return o


@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string, name='image')])
def serving(img):
    img = __decode_images(img, 3)
    o = model([img], training=False)
    o = __encode_images(o)
    return {
        'output': o,
    }

tf.saved_model.save(model, export_dir=args.output, signatures=serving)

The model decodes the request from JPEG and encodes the response back to JPEG.
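
For context, a minimal sketch of how a client might call this signature through the TF Serving REST API (the port, model name, and requests-based client here are assumptions, not from the original post):

import base64
import json

import requests

with open("1.jpg", "rb") as f:
    # DT_STRING inputs are sent base64-encoded under the "b64" key
    payload = {"instances": [{"b64": base64.b64encode(f.read()).decode()}]}

resp = requests.post("http://localhost:8501/v1/models/model:predict",
                     data=json.dumps(payload))
resp.raise_for_status()
print(resp.json())  # {'predictions': [...]} carrying the base64 JPEG output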

Now I want to compile the model to run on Inferentia instances. However, I could not find an example or documentation showing how to use tfn.trace in this situation.

@mostafafarzaneh
Author

I have added the following lines:

image_path = '1.jpg' 
with tf.io.gfile.GFile(image_path, 'rb') as f:
    image_data = f.read()
sample_input = tf.constant([image_data])  # Create a TensorFlow tensor from the image data

model_neuron = tfn.trace(serving, sample_input)

But I got the following error:

ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.

@jeffhataws
Contributor

Hi @mostafafarzaneh ,
Thank you for filing the issue. Could you try the "Export and Compile Saved Model" step here to see if that works for you? It looks like you are using a 128x128 image size, so just change images_sizes = [128, 128] in the example code to compile and save a model converted to Neuron with the correct input shape. Let us know if you still have problems.
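
For reference, a minimal sketch of that compile step (assuming the tensorflow-neuron 2.x import path; the file and output directory names are made up):

import tensorflow as tf
import tensorflow.neuron as tfn

model = tf.keras.models.load_model("model.hdf5")

# trace with a fixed-shape example input; Neuron compiles for this exact shape
image_sizes = [128, 128]
sample_input = tf.random.uniform([1, *image_sizes, 3], dtype=tf.float32)
model_neuron = tfn.trace(model, sample_input)
model_neuron.save("model_neuron")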

@mostafafarzaneh
Author

mostafafarzaneh commented Dec 13, 2023

Thanks @jeffhataws
The problem is that I'm using input_signature to accept the input as JPEG bytes rather than a raw image tensor, i.e. input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string, name='image')].
So the input to trace should be tf.string.

@mostafafarzaneh
Author

If I do not use input_signature and save the model as is, it works fine.

image_sizes = [128, 128]
sample_input_raw = tf.random.uniform([1, *image_sizes, 3], dtype=tf.float32)
model_neuron = tfn.trace(model, sample_input_raw)

The thing is, I want to send and receive JPEGs to and from the model, not raw images.

@jeffhataws
Contributor

Hi @mostafafarzaneh ,

You can use TensorFlow Hub to convert the Neuron saved model (after tfn.trace) to one that has preprocessing (done on CPU). For example, the following script can be used to do the conversion after you adapt it to the image sizes and data type that you use:

import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--import_dir", help="saved model to be imported for modification (add jpeg preprocessing)",
    required=True)
parser.add_argument("--export_dir", help="saved model to be exported",
    required=True)
args = parser.parse_args()

#tf.keras.backend.set_learning_phase(0)
#tf.keras.backend.set_floatx('float16')
#model = tf.keras.applications.resnet50.ResNet50(input_shape=(224, 224, 3), weights='imagenet')

model = hub.KerasLayer(args.import_dir,  signature='serving_default', signature_outputs_as_dict=True)

def decode_jpeg_resize(input_tensor):
    # decode jpeg
    tensor = tf.image.decode_jpeg(input_tensor, channels=3)

    # resize
    tensor = tf.cast(tensor, tf.float32)
    tensor = tf.image.resize(tensor, [224, 224])

    # normalize (ImageNet mean, scaled to the 0-255 pixel range)
    tensor -= np.array([0.485, 0.456, 0.406]).astype(np.float32) * 255.0
    #return tf.cast(tensor, tf.float32)
    # compiled model has been optimized for float16
    return tf.cast(tensor, tf.float16)


@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
def serving(input):

    #img = tf.map_fn(decode_jpeg_resize, input, dtype=tf.float32)
    # compiled model has been optimized for float16
    img = tf.map_fn(decode_jpeg_resize, input, dtype=tf.float16)

    # Predict
    predictions = model(img)
    #predictions_precast = model(img)
    #predictions = tf.cast(predictions_precast[0], tf.float32)
    return predictions

tf.saved_model.save(model, export_dir=args.export_dir, signatures=serving)
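
Assuming the script above is saved as convert_preproc.py (name made up), a typical invocation would be:

python convert_preproc.py --import_dir model_neuron --export_dir model_neuron_jpeg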

Let us know if you have problems with this.

@mostafafarzaneh
Author

Thanks @jeffhataws

I can confirm that if I compile the model with trace and then wrap it with a custom signature, I can save the model with that signature.

But I have run into another issue. This works fine for models that expect a fixed image size. However, I have another model that expects variable-size images (None, None, None, 3), while the trace function requires a fixed-size sample input. When I try to run inference on an image of a different size, I get an error.

Here is my code:

Compile:

import numpy as np
import tensorflow.neuron as tfn

IMAGE_SIZE = 512

sample_input = np.random.random((1, IMAGE_SIZE, IMAGE_SIZE, 3)).astype(np.float32)  # example input data
model_neuron = tfn.trace(model, sample_input)
model_neuron.save(args.output)

Convert:

def __decode_images(images, nch):
    o = tf.vectorized_map(lambda x: tf.image.decode_jpeg(x, nch), images)
    o = tf.cast(o, dtype=tf.float16) / 255
    o = tf.reverse(o, axis=[-1])  # RGB2BGR
    return o


def __encode_images(images):
    images = tf.image.convert_image_dtype(images, tf.uint8, saturate=True)
    o = tf.vectorized_map(tf.image.encode_jpeg, images)
    return o


@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string, name='image')])
def serving(img):
    img = __decode_images(img, 3)
    o = model(img, training=False)
    o = __encode_images(o)
    return {
        'output': o
    }

tf.saved_model.save(model, export_dir=args.output, signatures=serving)

This causes the following error at inference time:

Pridict Error:  <_InactiveRpcError of RPC that terminated with:
	status = StatusCode.INVALID_ARGUMENT
	details = "{{function_node __inference_pruned_51}} 2 root error(s) found.
  (0) INVALID_ARGUMENT: {{function_node __inference_pruned_51}} Invalid input tensor size: given Tensor<type: float shape: [1,1162,720,3]>, expected size 3145728
	 [[{{node neuron_op_801fd1d5f72893ff}}]]
	 [[neuron_op_801fd1d5f72893ff/_4]]
  (1) INVALID_ARGUMENT: {{function_node __inference_pruned_51}} Invalid input tensor size: given Tensor<type: float shape: [1,1162,720,3]>, expected size 3145728
	 [[{{node neuron_op_801fd1d5f72893ff}}]]
0 successful operations.
0 derived errors ignored.
	 [[StatefulPartitionedCall/StatefulPartitionedCall/aws_neuron_model/StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall]]"
	debug_error_string = "UNKNOWN:Error received from peer  {created_time:"2023-12-18T00:46:21.818117793+00:00", grpc_status:3, grpc_message:"{{function_node __inference_pruned_51}} 2 root error(s) found.\n  (0) INVALID_ARGUMENT: {{function_node __inference_pruned_51}} Invalid input tensor size: given Tensor<type: float shape: [1,1162,720,3]>, expected size 3145728\n\t [[{{node neuron_op_801fd1d5f72893ff}}]]\n\t [[neuron_op_801fd1d5f72893ff/_4]]\n  (1) INVALID_ARGUMENT: {{function_node __inference_pruned_51}} Invalid input tensor size: given Tensor<type: float shape: [1,1162,720,3]>, expected size 3145728\n\t [[{{node neuron_op_801fd1d5f72893ff}}]]\n0 successful operations.\n0 derived errors ignored.\n\t [[StatefulPartitionedCall/StatefulPartitionedCall/aws_neuron_model/StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall]]"}"

@mostafafarzaneh
Author

Hi @jeffhataws
Here is more information about the model signature:
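
(The listings below appear to be the output of TensorFlow's SavedModel inspection tool, presumably something like:)

saved_model_cli show --dir <model_dir> --all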

1- signature after compiling with trace:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 512, 512, 3)
        name: serving_default_input_1:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 512, 512, 1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict

Concrete Functions:
  Function Name: '__call__'
    Option #1
      Callable with:
        Argument #1
          input_1: TensorSpec(shape=(None, 512, 512, 3), dtype=tf.float32, name='input_1')

  Function Name: '_default_save_signature'
    Option #1
      Callable with:
        Argument #1
          input_1: TensorSpec(shape=(None, 512, 512, 3), dtype=tf.float32, name='input_1')

  Function Name: 'aws_neuron_function'
    Option #1
      Callable with:
        Argument #1
          args_0

  Function Name: 'call_and_return_all_conditional_losses'
    Option #1
      Callable with:
        Argument #1
          input_1: TensorSpec(shape=(None, 512, 512, 3), dtype=tf.float32, name='input_1')

2- signature after converting the compiled model to custom signature:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['image'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: serving_default_image:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict

Concrete Functions:
  Function Name: '__call__'
    Option #1
      Callable with:
        Argument #1
          input_1: TensorSpec(shape=(None, 512, 512, 3), dtype=tf.float32, name='input_1')

  Function Name: '_default_save_signature'
    Option #1
      Callable with:
        Argument #1
          input_1: TensorSpec(shape=(None, 512, 512, 3), dtype=tf.float32, name='input_1')

  Function Name: 'aws_neuron_function'
    Option #1
      Callable with:
        Argument #1
          args_0

  Function Name: 'call_and_return_all_conditional_losses'
    Option #1
      Callable with:
        Argument #1
          input_1: TensorSpec(shape=(None, 512, 512, 3), dtype=tf.float32, name='input_1')

3- But if I convert the original model to the custom signature without compiling with trace, I don't have any problems. Here is the signature:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['image'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: serving_default_image:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict

Concrete Functions:
  Function Name: '__call__'
    Option #1
      Callable with:
        Argument #1
          input: TensorSpec(shape=(None, None, None, 3), dtype=tf.float16, name='input')
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None
    Option #2
      Callable with:
        Argument #1
          inputs: TensorSpec(shape=(None, None, None, 3), dtype=tf.float16, name='inputs')
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None
    Option #3
      Callable with:
        Argument #1
          inputs: TensorSpec(shape=(None, None, None, 3), dtype=tf.float16, name='inputs')
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #4
      Callable with:
        Argument #1
          input: TensorSpec(shape=(None, None, None, 3), dtype=tf.float16, name='input')
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None

  Function Name: '_default_save_signature'
    Option #1
      Callable with:
        Argument #1
          input: TensorSpec(shape=(None, None, None, 3), dtype=tf.float16, name='input')

  Function Name: 'call_and_return_all_conditional_losses'
    Option #1
      Callable with:
        Argument #1
          inputs: TensorSpec(shape=(None, None, None, 3), dtype=tf.float16, name='inputs')
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #2
      Callable with:
        Argument #1
          input: TensorSpec(shape=(None, None, None, 3), dtype=tf.float16, name='input')
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None
    Option #3
      Callable with:
        Argument #1
          input: TensorSpec(shape=(None, None, None, 3), dtype=tf.float16, name='input')
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #4
      Callable with:
        Argument #1
          inputs: TensorSpec(shape=(None, None, None, 3), dtype=tf.float16, name='inputs')
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None

@mrnikwaws
Contributor

The short answer is that only fixed-shape tensors are supported in Neuron at this time. That is also what the error above is saying: "expected size 3145728" is exactly 1 x 512 x 512 x 3 float32 values (3,145,728 bytes), the shape the model was compiled for, while the request carried a [1, 1162, 720, 3] tensor.

The solution customers most commonly apply is to use shape "buckets" and then pad their inputs to match one of the compiled sizes. Another approach is to rescale images to the compiled resolution, then apply the reverse scaling to the outputs.

Which approach you take will depend on whether the model is sensitive to padding or scaling of the image.
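
A minimal sketch of the padding approach, assuming eagerly-known image shapes and bucket sizes chosen at compile time (all names here are illustrative, not part of the Neuron API):

import tensorflow as tf

BUCKETS = [(256, 256), (512, 512), (1024, 1024)]  # shapes you compiled with tfn.trace

def pad_to_bucket(image):
    # image: [h, w, 3] tensor with a known shape
    h, w = int(image.shape[0]), int(image.shape[1])
    for bh, bw in BUCKETS:
        if h <= bh and w <= bw:
            # zero-pad the bottom/right edges up to the bucket size
            return tf.image.pad_to_bounding_box(image, 0, 0, bh, bw), (h, w)
    raise ValueError(f"image {h}x{w} exceeds the largest compiled bucket")

# Usage: run the Neuron model compiled for the chosen bucket, then crop the
# output back to (h, w), e.g. with tf.image.crop_to_bounding_box, before encoding.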

@mostafafarzaneh
Author

Thanks @mrnikwaws
Do you happen to have plans to implement dynamic shapes soon?
I'm not an AI expert and just wanted to use Inferentia; at this stage, it would not be feasible for me to change the model and retest it.

@shebbur-aws

We have it on our roadmap and will update this issue once it's available.

@mostafafarzaneh
Author

Thanks @shebbur-aws
I might wait for that. Closing this issue.
