
Execute Redis commands in Torch Script #489

Merged 27 commits on Jan 13, 2021 (changes shown from 24 commits).

Commits:
85cb8c6  wip (DvirDukhan, Nov 2, 2020)
4f0b31b  after rebase & format (DvirDukhan, Dec 3, 2020)
b307435  wip (DvirDukhan, Dec 7, 2020)
db1ea2f  wip (DvirDukhan, Dec 7, 2020)
13dfd20  Merge branch 'master' into torchscript_extensions (DvirDukhan, Dec 14, 2020)
6f09302  Merge branch 'torchscript_extensions' of https://github.com/RedisAI/R… (DvirDukhan, Dec 14, 2020)
4755a2d  wip (DvirDukhan, Dec 14, 2020)
411f68f  Merge branch 'master' into torchscript_extensions (DvirDukhan, Dec 27, 2020)
e1c0717  Merge branch 'torchscript_extensions' of https://github.com/RedisAI/R… (DvirDukhan, Dec 27, 2020)
2b5fb26  Merge branch 'master' into torchscript_extensions (DvirDukhan, Dec 29, 2020)
740ceb7  Merge branch 'torchscript_extensions' of https://github.com/RedisAI/R… (DvirDukhan, Dec 29, 2020)
6be0a54  Merge branch 'master' into torchscript_extensions (DvirDukhan, Dec 29, 2020)
19ef6a3  Merge branch 'torchscript_extensions' of https://github.com/RedisAI/R… (DvirDukhan, Dec 29, 2020)
152092a  simple get (DvirDukhan, Dec 29, 2020)
3891559  simple scalars round trip (DvirDukhan, Dec 30, 2020)
d672e6a  test pass (DvirDukhan, Dec 30, 2020)
652c44f  make format (DvirDukhan, Dec 30, 2020)
ae398b9  added key tags (DvirDukhan, Dec 30, 2020)
5b38e19  correct redis usage (DvirDukhan, Dec 30, 2020)
d655511  make format (DvirDukhan, Dec 30, 2020)
8fdf44d  Merge branch 'master' into torchscript_extensions (DvirDukhan, Dec 31, 2020)
c727c40  Merge branch 'master' into torchscript_extensions (lantiga, Jan 6, 2021)
8ace67a  Merge branch 'master' into torchscript_extensions (DvirDukhan, Jan 13, 2021)
ff81047  Merge branch 'master' into torchscript_extensions (DvirDukhan, Jan 13, 2021)
69511df  fixed PR comments (DvirDukhan, Jan 13, 2021)
f89a033  fixed typo (DvirDukhan, Jan 13, 2021)
e50e5cf  Merge branch 'master' into torchscript_extensions (DvirDukhan, Jan 13, 2021)
14 changes: 14 additions & 0 deletions src/backends/torch.c
@@ -11,6 +11,20 @@ int RAI_InitBackendTorch(int (*get_api_fn)(const char *, void *)) {
get_api_fn("RedisModule_Free", ((void **)&RedisModule_Free));
get_api_fn("RedisModule_Realloc", ((void **)&RedisModule_Realloc));
get_api_fn("RedisModule_Strdup", ((void **)&RedisModule_Strdup));
+ get_api_fn("RedisModule_CreateString", ((void **)&RedisModule_CreateString));
+ get_api_fn("RedisModule_FreeString", ((void **)&RedisModule_FreeString));
+ get_api_fn("RedisModule_Call", ((void **)&RedisModule_Call));
+ get_api_fn("RedisModule_CallReplyType", ((void **)&RedisModule_CallReplyType));
+ get_api_fn("RedisModule_CallReplyStringPtr", ((void **)&RedisModule_CallReplyStringPtr));
+ get_api_fn("RedisModule_CallReplyInteger", ((void **)&RedisModule_CallReplyInteger));
+ get_api_fn("RedisModule_CallReplyLength", ((void **)&RedisModule_CallReplyLength));
+ get_api_fn("RedisModule_CallReplyArrayElement", ((void **)&RedisModule_CallReplyArrayElement));
+ get_api_fn("RedisModule_FreeCallReply", ((void **)&RedisModule_FreeCallReply));
+ get_api_fn("RedisModule_GetThreadSafeContext", ((void **)&RedisModule_GetThreadSafeContext));
+ get_api_fn("RedisModule_ThreadSafeContextLock", ((void **)&RedisModule_ThreadSafeContextLock));
+ get_api_fn("RedisModule_ThreadSafeContextUnlock",
+     ((void **)&RedisModule_ThreadSafeContextUnlock));
+ get_api_fn("RedisModule_FreeThreadSafeContext", ((void **)&RedisModule_FreeThreadSafeContext));

return REDISMODULE_OK;
}
44 changes: 27 additions & 17 deletions src/command_parser.c
@@ -166,7 +166,7 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **
RedisModuleString **runkey, char const **func_name,
long long *timeout, int *variadic) {

- if (argc < 5) {
+ if (argc < 3) {
RAI_SetError(error, RAI_ESCRIPTRUN,
"ERR wrong number of arguments for 'AI.SCRIPTRUN' command");
return REDISMODULE_ERR;
@@ -183,41 +183,51 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **
*runkey = argv[argpos];

const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
- if (!strcasecmp(arg_string, "TIMEOUT") || !strcasecmp(arg_string, "INPUTS")) {
+ if (!strcasecmp(arg_string, "TIMEOUT") || !strcasecmp(arg_string, "INPUTS") ||
+     !strcasecmp(arg_string, "OUTPUTS")) {
RAI_SetError(error, RAI_ESCRIPTRUN, "ERR function name not specified");
return REDISMODULE_ERR;
}
*func_name = arg_string;
- arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);

- // Parse timeout arg if given and store it in timeout
- if (!strcasecmp(arg_string, "TIMEOUT")) {
-     if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
-         return REDISMODULE_ERR;
-     arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
- }
- if (strcasecmp(arg_string, "INPUTS") != 0) {
-     RAI_SetError(error, RAI_ESCRIPTRUN, "ERR INPUTS not specified");
-     return REDISMODULE_ERR;
- }

- bool is_input = true, is_output = false;
+ bool is_input = false;
+ bool is_output = false;
+ bool timeout_set = false;
size_t ninputs = 0, noutputs = 0;
int varidic_start_pos = -1;

while (++argpos < argc) {
arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);

+ // Parse timeout arg if given and store it in timeout
+ if (!strcasecmp(arg_string, "TIMEOUT") && !timeout_set) {
+     if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
+         return REDISMODULE_ERR;
+     timeout_set = true;
+     continue;
+ }

+ if (!strcasecmp(arg_string, "INPUTS") && !is_input) {
+     is_input = true;
Collaborator commented: I think that this way, INPUTS can appear after OUTPUTS (even more than once...). Consider adding another flag like "input_done"?

+     is_output = false;
+     continue;
+ }
if (!strcasecmp(arg_string, "OUTPUTS") && !is_output) {
is_input = false;
is_output = true;
- } else if (!strcasecmp(arg_string, "$")) {
+     continue;
+ }
+ if (!strcasecmp(arg_string, "$")) {
Collaborator commented: this should be
    if (!strcasecmp(arg_string, "$") && is_input)
otherwise, variadic inputs can come in the middle of OUTPUTS.

if (varidic_start_pos > -1) {
RAI_SetError(error, RAI_ESCRIPTRUN,
"ERR Already encountered a variable size list of tensors");
return REDISMODULE_ERR;
}
varidic_start_pos = ninputs;
- } else {
+     continue;
+ }
// Parse argument name
+ {
Collaborator commented: Why is this in a block? Also, I think the modified parsing logic allows avoiding INPUTS/OUTPUTS altogether. For example, the following command would not raise an error (and will treat all the keys as inputs):
    AI.SCRIPTRUN script_key func_name key1 key2 ...

RAI_HoldString(NULL, argv[argpos]);
if (is_input) {
ninputs++;
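With the relaxed arity check (argc >= 3 instead of 5) and TIMEOUT/INPUTS/OUTPUTS now parsed inside the argument loop, everything after the function name becomes optional, so a script that only touches Redis through redis.execute() can run with no tensor keys at all. A client-side sketch of the accepted forms, assuming redis-py and hypothetical key names (the functions are the ones defined in tests/flow/test_data/redis_scripts.py below):

import redis

con = redis.Redis(host="localhost", port=6379)

# Minimal form: only the script key and the function name are required.
con.execute_command("AI.SCRIPTRUN", "redis_scripts{1}", "test_set_key")

# TIMEOUT is recognized anywhere after the function name, and OUTPUTS can
# be given without INPUTS when the script reads its data from Redis itself.
con.execute_command("AI.SCRIPTRUN", "redis_scripts{1}", "test_int_set_get",
                    "TIMEOUT", "1000", "OUTPUTS", "y{1}")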
2 changes: 1 addition & 1 deletion src/libtorch_c/CMakeLists.txt
@@ -1,3 +1,3 @@
- add_library(torch_c STATIC torch_c.cpp)
+ add_library(torch_c STATIC torch_c.cpp torch_extensions/torch_redis.cpp)
target_link_libraries(torch_c "${TORCH_LIBRARIES}")
set_property(TARGET torch_c PROPERTY CXX_STANDARD 14)
53 changes: 32 additions & 21 deletions src/libtorch_c/torch_c.cpp
@@ -8,6 +8,7 @@
#include <iostream>
#include <sstream>

+ #include "torch_extensions/torch_redis.h"
namespace {

static DLDataType getDLDataType(const at::Tensor &t) {
@@ -246,6 +247,7 @@ void torchRunModule(ModuleContext *ctx, const char *fnName, int variadic, long n
torch::DeviceType output_device_type = torch::kCPU;
torch::Device output_device(output_device_type, -1);

+ if(nOutputs == 0) return;
int count = 0;
for (size_t i = 0; i < stack.size(); i++) {
if (count > nOutputs - 1) {
@@ -304,28 +306,37 @@ extern "C" DLManagedTensor *torchNewTensor(DLDataType dtype, long ndims, int64_t
return dl_tensor;
}

-extern "C" void *torchCompileScript(const char *script, DLDeviceType device, int64_t device_id,
-                                    char **error, void *(*alloc)(size_t)) {
-    ModuleContext *ctx = new ModuleContext();
-    ctx->device = device;
-    ctx->device_id = device_id;
-    try {
-        auto cu = torch::jit::compile(script);
-        auto aten_device_type = getATenDeviceType(device);
-        if (aten_device_type == at::DeviceType::CUDA && !torch::cuda::is_available()) {
-            throw std::logic_error("GPU requested but Torch couldn't find CUDA");
-        }
-        ctx->cu = cu;
-        ctx->module = nullptr;
-    } catch (std::exception &e) {
-        size_t len = strlen(e.what()) + 1;
-        *error = (char *)alloc(len * sizeof(char));
-        strcpy(*error, e.what());
-        (*error)[len - 1] = '\0';
-        delete ctx;
-        return NULL;
-    }
-    return ctx;
-}
+extern "C" void* torchCompileScript(const char* script, DLDeviceType device, int64_t device_id,
+                                    char **error, void* (*alloc)(size_t))
+{
+    ModuleContext* ctx = new ModuleContext();
+    ctx->device = device;
+    ctx->device_id = device_id;
+    try {
+        auto cu = std::make_shared<torch::jit::script::CompilationUnit>();
+        cu->define(
+            c10::nullopt,
+            script,
+            torch::jit::script::redisResolver(),
+            nullptr);
+        auto aten_device_type = getATenDeviceType(device);
+
+        if (aten_device_type == at::DeviceType::CUDA && !torch::cuda::is_available()) {
+            throw std::logic_error("GPU requested but Torch couldn't find CUDA");
+        }
+        ctx->cu = cu;
+        ctx->module = nullptr;
+    }
+    catch(std::exception& e) {
+        size_t len = strlen(e.what()) +1;
+        *error = (char*)alloc(len * sizeof(char));
+        strcpy(*error, e.what());
+        (*error)[len-1] = '\0';
+        delete ctx;
+        return NULL;
+    }
+    return ctx;
+}

extern "C" void *torchLoadModel(const char *graph, size_t graphlen, DLDeviceType device,
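The user-visible effect of compiling through a CompilationUnit with redisResolver() rather than torch::jit::compile is that script sources may now reference the bare name redis. A sketch of a source string the new path accepts (function and key names are illustrative; int(str(...)) mirrors the conversion used in the tests below):

def my_counter():
    # `redis` resolves through RedisResolver; INCR returns an integer reply
    res = redis.execute("INCR", "visits{1}")
    return torch.tensor(int(str(res)))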
69 changes: 69 additions & 0 deletions src/libtorch_c/torch_extensions/torch_redis.cpp
@@ -0,0 +1,69 @@
#include <string>
#include "torch_redis.h"
#include "../../redismodule.h"

torch::IValue IValueFromRedisReply(RedisModuleCallReply *reply){

int reply_type = RedisModule_CallReplyType(reply);
switch(reply_type) {
case REDISMODULE_REPLY_NULL: {
return torch::IValue();
}
case REDISMODULE_REPLY_STRING: {
size_t len;
const char *replyStr = RedisModule_CallReplyStringPtr(reply, &len);
std::string str = replyStr;
return torch::IValue(str.substr(0,len));
}
case REDISMODULE_REPLY_INTEGER: {
int intValue = RedisModule_CallReplyInteger(reply);
return torch::IValue(intValue);
}
case REDISMODULE_REPLY_ARRAY: {
c10::impl::GenericList vec = c10::impl::GenericList(c10::AnyType::create());
size_t len = RedisModule_CallReplyLength(reply);
for (auto i = 0; i < len; ++i) {
RedisModuleCallReply *subReply = RedisModule_CallReplyArrayElement(reply, i);
torch::IValue value = IValueFromRedisReply(subReply);
vec.push_back(value);
}
return torch::IValue(vec);
}
case REDISMODULE_REPLY_ERROR: {
size_t len;
const char *replyStr = RedisModule_CallReplyStringPtr(reply, &len);
throw std::runtime_error(replyStr);
break;
}
default:{
throw(std::runtime_error("Unsupported redis type"));
}
}
}

torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args ) {
RedisModuleCtx* ctx = RedisModule_GetThreadSafeContext(NULL);
RedisModule_ThreadSafeContextLock(ctx);
size_t len = args.size();
RedisModuleString* arguments[len];
len = 0;
for (std::vector<std::string>::iterator it = args.begin(); it != args.end(); it++) {
const std::string arg = *it;
const char* str = arg.c_str();
arguments[len++] = RedisModule_CreateString(ctx, str, strlen(str));
}

RedisModuleCallReply *reply = RedisModule_Call(ctx, fn_name.c_str(), "!v", arguments, len);
RedisModule_ThreadSafeContextUnlock(ctx);
torch::IValue value = IValueFromRedisReply(reply);
RedisModule_FreeThreadSafeContext(ctx);
RedisModule_FreeCallReply(reply);
for(int i= 0; i < len; i++){
RedisModule_FreeString(NULL, arguments[i]);
}
return value;
}

torch::List<torch::IValue> asList(torch::IValue v) {
return v.toList();
}
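Seen from inside a script, the conversion in IValueFromRedisReply means: a null reply becomes None, a bulk string becomes str, an integer becomes int, an array becomes a generic list that redis.asList makes iterable, and an error reply is raised into the script as a runtime error. A sketch with made-up key names:

def reply_kinds():
    n = redis.execute("INCR", "n{1}")            # integer reply -> int
    s = redis.execute("GET", "s{1}")             # bulk string reply -> str
    xs = redis.asList(redis.execute("LRANGE", "l{1}", "0", "-1"))  # array reply -> list
    missing = redis.execute("GET", "absent{1}")  # null reply -> None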
33 changes: 33 additions & 0 deletions src/libtorch_c/torch_extensions/torch_redis.h
@@ -0,0 +1,33 @@
#include "torch/jit.h"
#include "torch/script.h"
#include "torch/csrc/jit/frontend/resolver.h"

namespace torch {
namespace jit {
namespace script {
struct RedisResolver : public Resolver {

std::shared_ptr<SugaredValue> resolveValue(const std::string &name, Function &m,
const SourceRange &loc) override {
if (strcasecmp(name.c_str(), "torch") == 0) {
return std::make_shared<BuiltinModule>("aten");
} else if (strcasecmp(name.c_str(), "redis") == 0) {
return std::make_shared<BuiltinModule>("redis");
}
return nullptr;
}

TypePtr resolveType(const std::string &name, const SourceRange &loc) override {
return nullptr;
}
};
inline std::shared_ptr<RedisResolver> redisResolver() { return std::make_shared<RedisResolver>(); }
} // namespace script
} // namespace jit
} // namespace torch

torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args);
torch::List<torch::IValue> asList(torch::IValue);

static auto registry =
torch::RegisterOperators("redis::execute", &redisExecute).op("redis::asList", &asList);
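Because both functions are registered through torch::RegisterOperators, they become addressable under the redis namespace in any process that loads the extension library. Outside a RedisAI server the calls would fail, since the RedisModule_* pointers they use are only populated by RAI_InitBackendTorch. A sketch, with a hypothetical library path:

import torch

torch.ops.load_library("build/libtorch_c.so")  # hypothetical path to the built extension
# Roughly what `redis.execute(...)` inside a script lowers to:
reply = torch.ops.redis.execute("LRANGE", ["l{1}", "0", "-1"])
items = torch.ops.redis.asList(reply)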
3 changes: 0 additions & 3 deletions src/redisai.c
@@ -585,9 +585,6 @@ int RedisAI_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
return RedisAI_ScriptRun_IsKeysPositionRequest_ReportKeys(ctx, argv, argc);
}

- if (argc < 6)
-     return RedisModule_WrongArity(ctx);

// Convert The script run command into a DAG command that contains a single op.
return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_SCRIPTRUN, false);
}
66 changes: 66 additions & 0 deletions tests/flow/test_data/redis_scripts.py
@@ -0,0 +1,66 @@

def redis_string_int_to_tensor(redis_value: Any):
return torch.tensor(int(str(redis_value)))
Contributor commented: One thing we need to make sure is that we move the tensor to the right device. We can ask the user to do so, but it needs to be done; otherwise a script running on GPU will find itself with an input on CPU and will fail.

Contributor commented: I would just add a note here (and maybe in the docs) mentioning the device placement. The ideal solution here would be to have a SCRIPT_DEVICE global to allow a user to do

    a_tensor = torch.tensor(int(str(redis_value))).to(device=SCRIPT_DEVICE)

and have tensors created within the script be located on the same device as the eventual script inputs.



def redis_string_float_to_tensor(redis_value: Any):
return torch.tensor(float(str((redis_value))))


def redis_int_to_tensor(redis_value: int):
return torch.tensor(redis_value)


def redis_int_list_to_tensor(redis_value: Any):
values = redis.asList(redis_value)
l = [torch.tensor(int(str(v))).reshape(1,1) for v in values]
return torch.cat(l, dim=0)


def redis_hash_to_tensor(redis_value: Any):
values = redis.asList(redis_value)
l = [torch.tensor(int(str(v))).reshape(1,1) for v in values]
return torch.cat(l, dim=0)

def test_redis_error():
redis.execute("SET", "x")

def test_int_set_get():
redis.execute("SET", "x", "1")
res = redis.execute("GET", "x",)
redis.execute("DEL", "x")
return redis_string_int_to_tensor(res)

def test_int_set_incr():
redis.execute("SET", "x", "1")
res = redis.execute("INCR", "x")
redis.execute("DEL", "x")
return redis_string_int_to_tensor(res)

def test_float_set_get():
redis.execute("SET", "x", "1.1")
res = redis.execute("GET", "x",)
redis.execute("DEL", "x")
return redis_string_float_to_tensor(res)

def test_int_list():
redis.execute("RPUSH", "x", "1")
redis.execute("RPUSH", "x", "2")
res = redis.execute("LRANGE", "x", "0", "2")
redis.execute("DEL", "x")
return redis_int_list_to_tensor(res)


def test_hash():
redis.execute("HSET", "x", "field1", "1", "field2", "2")
res = redis.execute("HVALS", "x")
redis.execute("DEL", "x")
return redis_hash_to_tensor(res)


def test_set_key():
redis.execute("SET", ["x{1}", "1"])


def test_del_key():
redis.execute("DEL", ["x"])
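For context, a sketch of how these test functions might be loaded and driven end to end, assuming redis-py, a server with the RedisAI module loaded, and illustrative key names:

import redis

con = redis.Redis(host="localhost", port=6379)
with open("tests/flow/test_data/redis_scripts.py") as f:
    source = f.read()

# AI.SCRIPTSET stores the TorchScript source under a key; AI.SCRIPTRUN then
# invokes one of its functions and writes the returned tensor to y{1}.
con.execute_command("AI.SCRIPTSET", "redis_scripts{1}", "CPU", "SOURCE", source)
con.execute_command("AI.SCRIPTRUN", "redis_scripts{1}", "test_int_set_get",
                    "OUTPUTS", "y{1}")
print(con.execute_command("AI.TENSORGET", "y{1}", "VALUES"))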