diff --git a/src/background_workers.c b/src/background_workers.c index d8cd7788e..183fa162f 100644 --- a/src/background_workers.c +++ b/src/background_workers.c @@ -391,19 +391,17 @@ void *RedisAI_Run_ThreadMain(void *arg) { // Run is over, now iterate over the run info structs in the batch // and see if any error was generated - int dagError = 0; + bool first_dag_error = false; for (long long i = 0; i < array_len(batch_rinfo); i++) { RedisAI_RunInfo *rinfo = batch_rinfo[i]; - // We lock on the DAG because error could be set from - // other threads operating on the same DAG (TODO: use atomic) - dagError = __atomic_load_n(rinfo->dagError, __ATOMIC_RELAXED); - // We record that there was an error for later on - run_error = dagError; - + run_error = __atomic_load_n(rinfo->dagError, __ATOMIC_RELAXED); + if (i == 0 && run_error == 1) { + first_dag_error = true; + } // If there was an error and the reference count for the dag // has gone to zero and the client is still around, we unblock - if (dagError) { + if (run_error) { RedisAI_RunInfo *orig = rinfo->orig_copy; long long dagRefCount = RAI_DagRunInfoFreeShallowCopy(rinfo); if (dagRefCount == 0) { @@ -415,12 +413,17 @@ void *RedisAI_Run_ThreadMain(void *arg) { __atomic_add_fetch(rinfo->dagCompleteOpCount, 1, __ATOMIC_RELAXED); } } + if (first_dag_error) { + run_queue_len = queueLength(run_queue_info->run_queue); + continue; + } } // We initialize variables where we'll store the fact hat, after the current // run, all ops for the device or all ops in the dag could be complete. This // way we can avoid placing the op back on the queue if there's nothing left // to do. + RedisModule_Assert(run_error == 0); int device_complete_after_run = RedisAI_DagDeviceComplete(batch_rinfo[0]); int dag_complete_after_run = RedisAI_DagComplete(batch_rinfo[0]); diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index 70a73f415..1aea211ee 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -63,21 +63,21 @@ def test_onnx_modelrun_mnist(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("No graph was found in the protobuf.", exception.__str__()) + env.assertEqual("No graph was found in the protobuf.", str(exception)) try: con.execute_command('AI.MODELSET', 'm_1{1}', 'ONNX', 'BLOB', model_pb) except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("Invalid DEVICE", exception.__str__()) + env.assertEqual("Invalid DEVICE", str(exception)) try: con.execute_command('AI.MODELSET', 'm_2{1}', model_pb) except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("wrong number of arguments for 'AI.MODELSET' command", exception.__str__()) + env.assertEqual("wrong number of arguments for 'AI.MODELSET' command", str(exception)) con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) @@ -86,56 +86,64 @@ def test_onnx_modelrun_mnist(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("model key is empty", exception.__str__()) + env.assertEqual("model key is empty", str(exception)) try: con.execute_command('AI.MODELRUN', 'm_2{1}', 'INPUTS', 'a{1}', 'b{1}', 'c{1}') except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("model key is empty", exception.__str__()) + env.assertEqual("model key is empty", str(exception)) try: con.execute_command('AI.MODELRUN', 'm_3{1}', 'a{1}', 'b{1}', 'c{1}') except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("model key is empty", exception.__str__()) + env.assertEqual("model key is empty", str(exception)) try: con.execute_command('AI.MODELRUN', 'm_1{1}', 'OUTPUTS', 'c{1}') except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("model key is empty", exception.__str__()) + env.assertEqual("model key is empty", str(exception)) try: con.execute_command('AI.MODELRUN', 'm{1}', 'OUTPUTS', 'c{1}') except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("INPUTS not specified", exception.__str__()) + env.assertEqual("INPUTS not specified", str(exception)) try: con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'b{1}') except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("tensor key is empty", exception.__str__()) + env.assertEqual("tensor key is empty", str(exception)) try: con.execute_command('AI.MODELRUN', 'm_1{1}', 'INPUTS', 'OUTPUTS') except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("model key is empty", exception.__str__()) + env.assertEqual("model key is empty", str(exception)) try: con.execute_command('AI.MODELRUN', 'm_1{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}') except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("model key is empty", exception.__str__()) + env.assertEqual("model key is empty", str(exception)) + + # This error is caught after the model is sent to the backend, not in parsing like before. + try: + con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'a{1}', 'OUTPUTS', 'b{1}') + except Exception as e: + exception = e + env.assertEqual(type(exception), redis.exceptions.ResponseError) + env.assertEqual('Expected 1 inputs but got 2', str(exception)) con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}')