
Prediction using tflite_flutter takes too long (8 seconds) while same model in Kotlin predicts in 200ms?? #66

Closed
farazk86 opened this issue Jan 21, 2021 · 2 comments

Comments

@farazk86

Hi @am15h,

I'm at my wit's end and can't figure out how else to optimize my model. In Flutter, prediction takes about 8 to 9 seconds, which is very long. I thought something was wrong with the model, but when I tried the same model in Kotlin, it returned a result in under 200 ms.

I'm only timing the interpreter.run() call itself, using Stopwatch() to measure it:

timer.start();
_interpreter.run(inputIds, predictions);
print('inference done in ' + timer.elapsedMilliseconds.toString());
timer.reset();

I'm initializing the model like:

var interpreterOptions = InterpreterOptions()..threads = NUM_LITE_THREADS;
_interpreter = await Interpreter.fromAsset(
  modelFile,
  options: interpreterOptions,
);

I'm not using NNAPI, as it does not improve inference speed, and I can't use the GPU delegate, as it fails to initialize the model.

My input is of shape [1, 32] and of type int8. My output is of shape [1, 32, 50527] and of type float32.

I thought this was an error in my model, but then I ran the same model in Kotlin using:

tflite.runForMultipleInputsOutputs(arrayOf(inputIds), outputs)

and I get the same prediction in under 200 ms. The Kotlin interpreter is initialized on the CPU, just like the Flutter one:

private suspend fun loadModel(): Interpreter = withContext(Dispatchers.IO) {
    val assetFileDescriptor = getApplication<Application>().assets.openFd(MODEL_PATH)
    assetFileDescriptor.use {
        val fileChannel = FileInputStream(assetFileDescriptor.fileDescriptor).channel
        val modelBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, it.startOffset, it.declaredLength)

        val opts = Interpreter.Options()
        opts.setNumThreads(NUM_LITE_THREADS)
        return@use Interpreter(modelBuffer, opts)
    }
}

Is there any reason why the model performs so poorly in Flutter? What can I change to fix this? Any thoughts on this will be very helpful.

Thank you

@am15h
Owner

am15h commented Jan 22, 2021

@farazk86 Are you using multidimensional Dart lists for inputIds and predictions? If yes, can you try using TensorBuffer from tflite_flutter_helper instead? I will try to investigate the cause of such terrible performance with Dart lists.
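For context on why nested lists are costly here: the [1, 32, 50527] float32 output holds about 1.6 million elements, and copying them into nested per-element structures dominates the run time, whereas a TensorBuffer-style approach keeps the whole tensor in one contiguous direct buffer. A minimal Java sketch of that flat layout (the FlatOutput class and flatIndex helper are hypothetical illustrations, not part of tflite_flutter_helper):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class FlatOutput {
    // Output tensor shape from this issue: [1, 32, 50527], float32.
    static final int BATCH = 1, SEQ = 32, VOCAB = 50527;

    // Row-major flat index into the [BATCH, SEQ, VOCAB] tensor.
    static int flatIndex(int b, int t, int v) {
        return (b * SEQ + t) * VOCAB + v;
    }

    public static void main(String[] args) {
        // One contiguous direct buffer sized for the whole tensor
        // (4 bytes per float32 element): 1 * 32 * 50527 * 4 = 6,467,456 bytes.
        ByteBuffer out = ByteBuffer
                .allocateDirect(BATCH * SEQ * VOCAB * 4)
                .order(ByteOrder.nativeOrder());

        // The interpreter can fill this buffer in a single bulk write,
        // instead of walking ~1.6 million nested list elements one by one.
        System.out.println("capacity=" + out.capacity());
    }
}
```

Reading a logit back is then a single indexed access, e.g. `out.getFloat(4 * flatIndex(0, t, v))`, with no nested-list traversal.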

@farazk86
Author

> @farazk86 Are you using multidimensional dart lists for inputIds and predictions, if yes, can you try using TensorBuffer from tflite_flutter_helper instead? I will try to investigate the cause of such terrible performance with dart lists.

Thanks for the reply. Unfortunately, using TensorBuffer did not provide any considerable speedup; it reduced inference time by only a couple of seconds.

I ended up using Flutter's MethodChannel (invokeMethod) to run inference in Java instead. This reduced the inference time to about 300 ms :)
