-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Execute Redis commands in Torch Script #489
Changes from 24 commits
85cb8c6
4f0b31b
b307435
db1ea2f
13dfd20
6f09302
4755a2d
411f68f
e1c0717
2b5fb26
740ceb7
6be0a54
19ef6a3
152092a
3891559
d672e6a
652c44f
ae398b9
5b38e19
d655511
8fdf44d
c727c40
8ace67a
ff81047
69511df
f89a033
e50e5cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,7 +166,7 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString ** | |
RedisModuleString **runkey, char const **func_name, | ||
long long *timeout, int *variadic) { | ||
|
||
if (argc < 5) { | ||
if (argc < 3) { | ||
RAI_SetError(error, RAI_ESCRIPTRUN, | ||
"ERR wrong number of arguments for 'AI.SCRIPTRUN' command"); | ||
return REDISMODULE_ERR; | ||
|
@@ -183,41 +183,51 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString ** | |
*runkey = argv[argpos]; | ||
|
||
const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL); | ||
if (!strcasecmp(arg_string, "TIMEOUT") || !strcasecmp(arg_string, "INPUTS")) { | ||
if (!strcasecmp(arg_string, "TIMEOUT") || !strcasecmp(arg_string, "INPUTS") || | ||
!strcasecmp(arg_string, "OUTPUTS")) { | ||
RAI_SetError(error, RAI_ESCRIPTRUN, "ERR function name not specified"); | ||
return REDISMODULE_ERR; | ||
} | ||
*func_name = arg_string; | ||
arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL); | ||
|
||
// Parse timeout arg if given and store it in timeout | ||
if (!strcasecmp(arg_string, "TIMEOUT")) { | ||
if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR) | ||
return REDISMODULE_ERR; | ||
arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL); | ||
} | ||
if (strcasecmp(arg_string, "INPUTS") != 0) { | ||
RAI_SetError(error, RAI_ESCRIPTRUN, "ERR INPUTS not specified"); | ||
return REDISMODULE_ERR; | ||
} | ||
|
||
bool is_input = true, is_output = false; | ||
bool is_input = false; | ||
bool is_output = false; | ||
bool timeout_set = false; | ||
size_t ninputs = 0, noutputs = 0; | ||
int varidic_start_pos = -1; | ||
|
||
while (++argpos < argc) { | ||
arg_string = RedisModule_StringPtrLen(argv[argpos], NULL); | ||
|
||
// Parse timeout arg if given and store it in timeout | ||
if (!strcasecmp(arg_string, "TIMEOUT") && !timeout_set) { | ||
if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR) | ||
return REDISMODULE_ERR; | ||
timeout_set = true; | ||
continue; | ||
} | ||
|
||
if (!strcasecmp(arg_string, "INPUTS") && !is_input) { | ||
is_input = true; | ||
is_output = false; | ||
continue; | ||
} | ||
if (!strcasecmp(arg_string, "OUTPUTS") && !is_output) { | ||
is_input = false; | ||
is_output = true; | ||
} else if (!strcasecmp(arg_string, "$")) { | ||
continue; | ||
} | ||
if (!strcasecmp(arg_string, "$")) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if (!strcasecmp(arg_string, "$") && is_input) |
||
if (varidic_start_pos > -1) { | ||
RAI_SetError(error, RAI_ESCRIPTRUN, | ||
"ERR Already encountered a variable size list of tensors"); | ||
return REDISMODULE_ERR; | ||
} | ||
varidic_start_pos = ninputs; | ||
} else { | ||
continue; | ||
} | ||
// Parse argument name | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this in a block? |
||
RAI_HoldString(NULL, argv[argpos]); | ||
if (is_input) { | ||
ninputs++; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
add_library(torch_c STATIC torch_c.cpp) | ||
add_library(torch_c STATIC torch_c.cpp torch_extensions/torch_redis.cpp) | ||
target_link_libraries(torch_c "${TORCH_LIBRARIES}") | ||
set_property(TARGET torch_c PROPERTY CXX_STANDARD 14) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#include <string> | ||
#include "torch_redis.h" | ||
#include "../../redismodule.h" | ||
|
||
torch::IValue IValueFromRedisReply(RedisModuleCallReply *reply){ | ||
|
||
int reply_type = RedisModule_CallReplyType(reply); | ||
switch(reply_type) { | ||
case REDISMODULE_REPLY_NULL: { | ||
return torch::IValue(); | ||
} | ||
case REDISMODULE_REPLY_STRING: { | ||
size_t len; | ||
const char *replyStr = RedisModule_CallReplyStringPtr(reply, &len); | ||
std::string str = replyStr; | ||
return torch::IValue(str.substr(0,len)); | ||
} | ||
case REDISMODULE_REPLY_INTEGER: { | ||
int intValue = RedisModule_CallReplyInteger(reply); | ||
return torch::IValue(intValue); | ||
} | ||
case REDISMODULE_REPLY_ARRAY: { | ||
c10::impl::GenericList vec = c10::impl::GenericList(c10::AnyType::create()); | ||
size_t len = RedisModule_CallReplyLength(reply); | ||
for (auto i = 0; i < len; ++i) { | ||
RedisModuleCallReply *subReply = RedisModule_CallReplyArrayElement(reply, i); | ||
torch::IValue value = IValueFromRedisReply(subReply); | ||
vec.push_back(value); | ||
} | ||
return torch::IValue(vec); | ||
} | ||
case REDISMODULE_REPLY_ERROR: { | ||
size_t len; | ||
const char *replyStr = RedisModule_CallReplyStringPtr(reply, &len); | ||
throw std::runtime_error(replyStr); | ||
break; | ||
} | ||
default:{ | ||
throw(std::runtime_error("Unsupported redis type")); | ||
} | ||
} | ||
} | ||
|
||
torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args ) { | ||
RedisModuleCtx* ctx = RedisModule_GetThreadSafeContext(NULL); | ||
RedisModule_ThreadSafeContextLock(ctx); | ||
size_t len = args.size(); | ||
RedisModuleString* arguments[len]; | ||
len = 0; | ||
for (std::vector<std::string>::iterator it = args.begin(); it != args.end(); it++) { | ||
const std::string arg = *it; | ||
const char* str = arg.c_str(); | ||
arguments[len++] = RedisModule_CreateString(ctx, str, strlen(str)); | ||
} | ||
|
||
RedisModuleCallReply *reply = RedisModule_Call(ctx, fn_name.c_str(), "!v", arguments, len); | ||
RedisModule_ThreadSafeContextUnlock(ctx); | ||
torch::IValue value = IValueFromRedisReply(reply); | ||
RedisModule_FreeThreadSafeContext(ctx); | ||
RedisModule_FreeCallReply(reply); | ||
for(int i= 0; i < len; i++){ | ||
RedisModule_FreeString(NULL, arguments[i]); | ||
} | ||
return value; | ||
} | ||
|
||
torch::List<torch::IValue> asList(torch::IValue v) { | ||
return v.toList(); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#include "torch/jit.h" | ||
#include "torch/script.h" | ||
#include "torch/csrc/jit/frontend/resolver.h" | ||
|
||
namespace torch { | ||
namespace jit { | ||
namespace script { | ||
struct RedisResolver : public Resolver { | ||
|
||
std::shared_ptr<SugaredValue> resolveValue(const std::string &name, Function &m, | ||
const SourceRange &loc) override { | ||
if (strcasecmp(name.c_str(), "torch") == 0) { | ||
return std::make_shared<BuiltinModule>("aten"); | ||
} else if (strcasecmp(name.c_str(), "redis") == 0) { | ||
return std::make_shared<BuiltinModule>("redis"); | ||
} | ||
return nullptr; | ||
} | ||
|
||
TypePtr resolveType(const std::string &name, const SourceRange &loc) override { | ||
return nullptr; | ||
} | ||
}; | ||
inline std::shared_ptr<RedisResolver> redisResolver() { return std::make_shared<RedisResolver>(); } | ||
} // namespace script | ||
} // namespace jit | ||
} // namespace torch | ||
|
||
torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args); | ||
torch::List<torch::IValue> asList(torch::IValue); | ||
|
||
static auto registry = | ||
torch::RegisterOperators("redis::execute", &redisExecute).op("redis::asList", &asList); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
|
||
def redis_string_int_to_tensor(redis_value: Any): | ||
return torch.tensor(int(str(redis_value))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One thing we need to make sure is that we move the tensor to the right device. We can ask the user to do so, but it needs to be done otherwise a script running on GPU will find itself with an input on CPU and will fail. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would just add a note here (and maybe in the docs) mentioning the device placement.
and have tensors created within the script be located on the same device as the eventual script inputs. |
||
|
||
|
||
def redis_string_float_to_tensor(redis_value: Any): | ||
return torch.tensor(float(str((redis_value)))) | ||
|
||
|
||
def redis_int_to_tensor(redis_value: int): | ||
return torch.tensor(redis_value) | ||
|
||
|
||
def redis_int_list_to_tensor(redis_value: Any): | ||
values = redis.asList(redis_value) | ||
l = [torch.tensor(int(str(v))).reshape(1,1) for v in values] | ||
return torch.cat(l, dim=0) | ||
|
||
|
||
def redis_hash_to_tensor(redis_value: Any): | ||
values = redis.asList(redis_value) | ||
l = [torch.tensor(int(str(v))).reshape(1,1) for v in values] | ||
return torch.cat(l, dim=0) | ||
|
||
def test_redis_error(): | ||
redis.execute("SET", "x") | ||
|
||
def test_int_set_get(): | ||
redis.execute("SET", "x", "1") | ||
res = redis.execute("GET", "x",) | ||
redis.execute("DEL", "x") | ||
return redis_string_int_to_tensor(res) | ||
|
||
def test_int_set_incr(): | ||
redis.execute("SET", "x", "1") | ||
res = redis.execute("INCR", "x") | ||
redis.execute("DEL", "x") | ||
return redis_string_int_to_tensor(res) | ||
|
||
def test_float_set_get(): | ||
redis.execute("SET", "x", "1.1") | ||
res = redis.execute("GET", "x",) | ||
redis.execute("DEL", "x") | ||
return redis_string_float_to_tensor(res) | ||
|
||
def test_int_list(): | ||
redis.execute("RPUSH", "x", "1") | ||
redis.execute("RPUSH", "x", "2") | ||
res = redis.execute("LRANGE", "x", "0", "2") | ||
redis.execute("DEL", "x") | ||
return redis_int_list_to_tensor(res) | ||
|
||
|
||
def test_hash(): | ||
redis.execute("HSET", "x", "field1", "1", "field2", "2") | ||
res = redis.execute("HVALS", "x") | ||
redis.execute("DEL", "x") | ||
return redis_hash_to_tensor(res) | ||
|
||
|
||
def test_set_key(): | ||
redis.execute("SET", ["x{1}", "1"]) | ||
|
||
|
||
def test_del_key(): | ||
redis.execute("DEL", ["x"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that this way, INPUTS can appear after OUTPUTS (even more than once...)
Consider add another flag like "input_done"?