Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/background_workers.c
Original file line number Diff line number Diff line change
Expand Up @@ -391,19 +391,17 @@ void *RedisAI_Run_ThreadMain(void *arg) {

// Run is over, now iterate over the run info structs in the batch
// and see if any error was generated
int dagError = 0;
bool first_dag_error = false;
for (long long i = 0; i < array_len(batch_rinfo); i++) {
RedisAI_RunInfo *rinfo = batch_rinfo[i];
// We lock on the DAG because error could be set from
// other threads operating on the same DAG (TODO: use atomic)
dagError = __atomic_load_n(rinfo->dagError, __ATOMIC_RELAXED);

// We record that there was an error for later on
run_error = dagError;

run_error = __atomic_load_n(rinfo->dagError, __ATOMIC_RELAXED);
if (i == 0 && run_error == 1) {
first_dag_error = true;
}
// If there was an error and the reference count for the dag
// has gone to zero and the client is still around, we unblock
if (dagError) {
if (run_error) {
RedisAI_RunInfo *orig = rinfo->orig_copy;
long long dagRefCount = RAI_DagRunInfoFreeShallowCopy(rinfo);
if (dagRefCount == 0) {
Expand All @@ -415,12 +413,17 @@ void *RedisAI_Run_ThreadMain(void *arg) {
__atomic_add_fetch(rinfo->dagCompleteOpCount, 1, __ATOMIC_RELAXED);
}
}
if (first_dag_error) {
run_queue_len = queueLength(run_queue_info->run_queue);
continue;
}
}

// We initialize variables where we'll store the fact that, after the current
// run, all ops for the device or all ops in the dag could be complete. This
// way we can avoid placing the op back on the queue if there's nothing left
// to do.
RedisModule_Assert(run_error == 0);
int device_complete_after_run = RedisAI_DagDeviceComplete(batch_rinfo[0]);
int dag_complete_after_run = RedisAI_DagComplete(batch_rinfo[0]);

Expand Down
30 changes: 19 additions & 11 deletions tests/flow/tests_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,21 @@ def test_onnx_modelrun_mnist(env):
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("No graph was found in the protobuf.", exception.__str__())
env.assertEqual("No graph was found in the protobuf.", str(exception))

try:
con.execute_command('AI.MODELSET', 'm_1{1}', 'ONNX', 'BLOB', model_pb)
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("Invalid DEVICE", exception.__str__())
env.assertEqual("Invalid DEVICE", str(exception))

try:
con.execute_command('AI.MODELSET', 'm_2{1}', model_pb)
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("wrong number of arguments for 'AI.MODELSET' command", exception.__str__())
env.assertEqual("wrong number of arguments for 'AI.MODELSET' command", str(exception))

con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)

Expand All @@ -86,56 +86,64 @@ def test_onnx_modelrun_mnist(env):
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("model key is empty", exception.__str__())
env.assertEqual("model key is empty", str(exception))

try:
con.execute_command('AI.MODELRUN', 'm_2{1}', 'INPUTS', 'a{1}', 'b{1}', 'c{1}')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("model key is empty", exception.__str__())
env.assertEqual("model key is empty", str(exception))

try:
con.execute_command('AI.MODELRUN', 'm_3{1}', 'a{1}', 'b{1}', 'c{1}')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("model key is empty", exception.__str__())
env.assertEqual("model key is empty", str(exception))

try:
con.execute_command('AI.MODELRUN', 'm_1{1}', 'OUTPUTS', 'c{1}')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("model key is empty", exception.__str__())
env.assertEqual("model key is empty", str(exception))

try:
con.execute_command('AI.MODELRUN', 'm{1}', 'OUTPUTS', 'c{1}')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("INPUTS not specified", exception.__str__())
env.assertEqual("INPUTS not specified", str(exception))

try:
con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'b{1}')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("tensor key is empty", exception.__str__())
env.assertEqual("tensor key is empty", str(exception))

try:
con.execute_command('AI.MODELRUN', 'm_1{1}', 'INPUTS', 'OUTPUTS')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("model key is empty", exception.__str__())
env.assertEqual("model key is empty", str(exception))

try:
con.execute_command('AI.MODELRUN', 'm_1{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("model key is empty", exception.__str__())
env.assertEqual("model key is empty", str(exception))

# This error is caught after the model is sent to the backend, not in parsing like before.
try:
con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'a{1}', 'OUTPUTS', 'b{1}')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual('Expected 1 inputs but got 2', str(exception))

con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}')

Expand Down