diff --git a/src/execution/parsing/script_commands_parser.c b/src/execution/parsing/script_commands_parser.c index 4929fe9b3..e5d584e72 100644 --- a/src/execution/parsing/script_commands_parser.c +++ b/src/execution/parsing/script_commands_parser.c @@ -171,16 +171,15 @@ static int _ScriptExecuteCommand_ParseCommand(RedisModuleCtx *ctx, RedisModuleSt } while (argpos < argc) { - const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL); + const char *arg_string = RedisModule_StringPtrLen(argv[argpos++], NULL); // Parse timeout arg if given and store it in timeout. if (!strcasecmp(arg_string, "TIMEOUT")) { - argpos++; if (argpos >= argc) { RAI_SetError(error, RAI_ESCRIPTRUN, "ERR No value provided for TIMEOUT in AI.SCRIPTEXECUTE"); return REDISMODULE_ERR; } - if (ParseTimeout(argv[argpos], error, timeout) == REDISMODULE_ERR) + if (ParseTimeout(argv[argpos++], error, timeout) == REDISMODULE_ERR) return REDISMODULE_ERR; // No other arguments expected after timeout. break; @@ -192,7 +191,6 @@ static int _ScriptExecuteCommand_ParseCommand(RedisModuleCtx *ctx, RedisModuleSt "ERR Already Encountered KEYS scope in AI.SCRIPTEXECUTE command"); return REDISMODULE_ERR; } - argpos++; keysDone = true; if (_ScriptExecuteCommand_ParseKeys(ctx, argv, argc, &argpos, error, sctx) == REDISMODULE_ERR) { @@ -207,7 +205,6 @@ static int _ScriptExecuteCommand_ParseCommand(RedisModuleCtx *ctx, RedisModuleSt "ERR Already Encountered ARGS scope in AI.SCRIPTEXECUTE command"); return REDISMODULE_ERR; } - argpos++; argsDone = true; if (_ScriptExecuteCommand_ParseArgs(ctx, argv, argc, &argpos, error, sctx) == REDISMODULE_ERR) { @@ -223,7 +220,6 @@ static int _ScriptExecuteCommand_ParseCommand(RedisModuleCtx *ctx, RedisModuleSt "ERR Already Encountered INPUTS scope in AI.SCRIPTEXECUTE command"); return REDISMODULE_ERR; } - argpos++; inputsDone = true; if (_ScriptExecuteCommand_ParseInputs(ctx, argv, argc, &argpos, error, inputs) == REDISMODULE_ERR) { @@ -238,7 +234,6 @@ static int _ScriptExecuteCommand_ParseCommand(RedisModuleCtx *ctx, RedisModuleSt "ERR Already Encountered OUTPUTS scope in AI.SCRIPTEXECUTE command"); return REDISMODULE_ERR; } - argpos++; outputsDone = true; if (_ScriptExecuteCommand_ParseOutputs(ctx, argv, argc, &argpos, error, outputs) == REDISMODULE_ERR) { diff --git a/tests/flow/test_data/script.txt b/tests/flow/test_data/script.txt index 4478acc26..f7a0d0397 100644 --- a/tests/flow/test_data/script.txt +++ b/tests/flow/test_data/script.txt @@ -7,3 +7,8 @@ def bar_variadic(tensors: List[Tensor], keys: List[str], args: List[str]): a = tensors[0] l = tensors[1:] return a + l[0] + +def long_func(tensors: List[Tensor], keys: List[str], args: List[str]): + sum=0 + for i in range(10000000): + sum+=1 diff --git a/tests/flow/tests_commands.py b/tests/flow/tests_commands.py index f3d30fd8d..49b495518 100644 --- a/tests/flow/tests_commands.py +++ b/tests/flow/tests_commands.py @@ -296,12 +296,17 @@ def test_pytorch_scriptexecute_errors(env): check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1 , '{1}', 'INPUTS', 'OUTPUTS') - check_error_message(env, con, "Invalid arguments provided to AI.SCRIPTEXECUTE", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1, '{1}', 'ARGS') + check_error_message(env, con, "Invalid arguments provided to AI.SCRIPTEXECUTE", + 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1, '{1}', 'ARGS') - check_error_message(env, con, "Invalid argument for inputs count in AI.SCRIPTEXECUTE", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'INPUTS', 'OUTPUTS') + check_error_message(env, con, "Invalid argument for inputs count in AI.SCRIPTEXECUTE", + 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'INPUTS', 'OUTPUTS') - check_error_message(env, con, "Invalid value for TIMEOUT",'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1, '{1}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}', 'TIMEOUT', 'TIMEOUT') + check_error_message(env, con, "Invalid value for TIMEOUT", + 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}', 'TIMEOUT', 'TIMEOUT') + check_error_message(env, con, "No value provided for TIMEOUT in AI.SCRIPTEXECUTE", + 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}', 'TIMEOUT') if env.isCluster(): # cross shard diff --git a/tests/flow/tests_pytorch.py b/tests/flow/tests_pytorch.py index f1b2867e8..5cdd0be92 100644 --- a/tests/flow/tests_pytorch.py +++ b/tests/flow/tests_pytorch.py @@ -1,4 +1,5 @@ import redis +import time from includes import * from RLTest import Env @@ -389,13 +390,40 @@ def test_pytorch_scriptexecute_list_input(env): env.assertEqual(values2, values) -def test_pytorch_scriptinfo(env): +def test_pytorch_scriptexecute_with_timeout(env): if not TEST_PT: env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) return - # env.debugPrint("skipping this tests for now", force=True) - # return + con = get_connection(env, '{$}') + script = load_file_content('script.txt') + ret = con.execute_command('AI.SCRIPTSTORE', 'my_script{$}', DEVICE, + 'ENTRY_POINTS', 2, 'bar', 'long_func', 'SOURCE', script) + env.assertEqual(ret, b'OK') + + con.execute_command('AI.TENSORSET', 'a{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + con.execute_command('AI.TENSORSET', 'b{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + + def run(): + con2 = get_connection(env, '{$}') + con2.execute_command('AI.SCRIPTEXECUTE', 'my_script{$}', 'long_func', 'KEYS', 1, '{$}') + + t = threading.Thread(target=run) + t.start() + + # make sure that we have a long operation that RedisAI will run upon sending the following + # command, to assure that timeout will occur. + time.sleep(0.1) + ret = con.execute_command('AI.SCRIPTEXECUTE', 'my_script{$}', 'bar', + 'INPUTS', 2, 'a{$}', 'b{$}', 'OUTPUTS', 1, 'c{$}', 'TIMEOUT', 1) + env.assertEqual(ret, b'TIMEDOUT') + t.join() + + +def test_pytorch_scriptinfo(env): + if not TEST_PT: + env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) + return con = get_connection(env, '{1}')