Skip to content

Commit

Permalink
Merge pull request #552 from RedisAI/general_model_inputs_and_outputs…
Browse files Browse the repository at this point in the history
…_names

expose model inputs and outputs with respect to model definition
  • Loading branch information
DvirDukhan committed Jan 13, 2021
2 parents bd9f462 + 21b39cc commit 2fc1827
Show file tree
Hide file tree
Showing 16 changed files with 384 additions and 46 deletions.
68 changes: 68 additions & 0 deletions src/backends/onnxruntime.c
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo

RAI_Device device;
int64_t deviceid;
char **inputs_ = NULL;
char **outputs_ = NULL;

if (!parseDeviceStr(devicestr, &device, &deviceid)) {
RAI_SetError(error, RAI_EMODELCREATE, "ERR unsupported device");
Expand Down Expand Up @@ -352,6 +354,41 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
goto error;
}

size_t n_input_nodes;
status = ort->SessionGetInputCount(session, &n_input_nodes);
if (status != NULL) {
goto error;
}

size_t n_output_nodes;
status = ort->SessionGetOutputCount(session, &n_output_nodes);
if (status != NULL) {
goto error;
}

OrtAllocator *allocator;
status = ort->GetAllocatorWithDefaultOptions(&allocator);

inputs_ = array_new(char *, n_input_nodes);
for (long long i = 0; i < n_input_nodes; i++) {
char *input_name;
status = ort->SessionGetInputName(session, i, allocator, &input_name);
if (status != NULL) {
goto error;
}
inputs_ = array_append(inputs_, input_name);
}

outputs_ = array_new(char *, n_output_nodes);
for (long long i = 0; i < n_output_nodes; i++) {
char *output_name;
status = ort->SessionGetOutputName(session, i, allocator, &output_name);
if (status != NULL) {
goto error;
}
outputs_ = array_append(outputs_, output_name);
}

// Since ONNXRuntime doesn't have a re-serialization function,
// we cache the blob in order to re-serialize it.
// Not optimal for storage purposes, but again, it may be temporary
Expand All @@ -367,11 +404,29 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
ret->opts = opts;
ret->data = buffer;
ret->datalen = modellen;
ret->ninputs = n_input_nodes;
ret->noutputs = n_output_nodes;
ret->inputs = inputs_;
ret->outputs = outputs_;

return ret;

error:
RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status));
if (inputs_) {
n_input_nodes = array_len(inputs_);
for (uint32_t i = 0; i < n_input_nodes; i++) {
status = ort->AllocatorFree(allocator, inputs_[i]);
}
array_free(inputs_);
}
if (outputs_) {
n_output_nodes = array_len(outputs_);
for (uint32_t i = 0; i < n_output_nodes; i++) {
status = ort->AllocatorFree(allocator, outputs_[i]);
}
array_free(outputs_);
}
ort->ReleaseStatus(status);
return NULL;
}
Expand All @@ -381,6 +436,19 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {

RedisModule_Free(model->data);
RedisModule_Free(model->devicestr);
OrtAllocator *allocator;
OrtStatus *status = NULL;
status = ort->GetAllocatorWithDefaultOptions(&allocator);
for (uint32_t i = 0; i < model->ninputs; i++) {
status = ort->AllocatorFree(allocator, model->inputs[i]);
}
array_free(model->inputs);

for (uint32_t i = 0; i < model->noutputs; i++) {
status = ort->AllocatorFree(allocator, model->outputs[i]);
}
array_free(model->outputs);

ort->ReleaseSession(model->session);

model->model = NULL;
Expand Down
2 changes: 2 additions & 0 deletions src/backends/tensorflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
ret->session = session;
ret->backend = backend;
ret->devicestr = RedisModule_Strdup(devicestr);
ret->ninputs = ninputs;
ret->inputs = inputs_;
ret->noutputs = noutputs;
ret->outputs = outputs_;
ret->opts = opts;
ret->refCount = 1;
Expand Down
70 changes: 66 additions & 4 deletions src/backends/tflite.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ int RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)) {
RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
const char *modeldef, size_t modellen, RAI_Error *error) {
DLDeviceType dl_device;

RAI_Device device;
int64_t deviceid;
char **inputs_ = NULL;
char **outputs_ = NULL;
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Unsupported device");
return NULL;
Expand All @@ -47,6 +48,36 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
return NULL;
}

size_t ninputs = tfliteModelNumInputs(model, &error_descr);
if (error_descr) {
goto cleanup;
}

size_t noutputs = tfliteModelNumOutputs(model, &error_descr);
if (error_descr) {
goto cleanup;
}

inputs_ = array_new(char *, ninputs);
outputs_ = array_new(char *, noutputs);

for (size_t i = 0; i < ninputs; i++) {
const char *input = tfliteModelInputNameAtIndex(model, i, &error_descr);
if (error_descr) {
goto cleanup;
}
inputs_ = array_append(inputs_, RedisModule_Strdup(input));
}

for (size_t i = 0; i < noutputs; i++) {
const char *output = tfliteModelOutputNameAtIndex(model, i, &error_descr);
;
if (error_descr) {
goto cleanup;
}
outputs_ = array_append(outputs_, RedisModule_Strdup(output));
}

char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
memcpy(buffer, modeldef, modellen);

Expand All @@ -55,20 +86,51 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
ret->session = NULL;
ret->backend = backend;
ret->devicestr = RedisModule_Strdup(devicestr);
ret->inputs = NULL;
ret->outputs = NULL;
ret->ninputs = ninputs;
ret->inputs = inputs_;
ret->noutputs = noutputs;
ret->outputs = outputs_;
ret->refCount = 1;
ret->opts = opts;
ret->data = buffer;
ret->datalen = modellen;

return ret;

cleanup:
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
RedisModule_Free(error_descr);
if (inputs_) {
ninputs = array_len(inputs_);
for (size_t i = 0; i < ninputs; i++) {
RedisModule_Free(inputs_[i]);
}
array_free(inputs_);
}
if (outputs_) {
noutputs = array_len(outputs_);
for (size_t i = 0; i < noutputs; i++) {
RedisModule_Free(outputs_[i]);
}
array_free(outputs_);
}
return NULL;
}

void RAI_ModelFreeTFLite(RAI_Model *model, RAI_Error *error) {
RedisModule_Free(model->data);
RedisModule_Free(model->devicestr);
tfliteDeallocContext(model->model);
size_t ninputs = model->ninputs;
for (size_t i = 0; i < ninputs; i++) {
RedisModule_Free(model->inputs[i]);
}
array_free(model->inputs);

size_t noutputs = model->noutputs;
for (size_t i = 0; i < noutputs; i++) {
RedisModule_Free(model->outputs[i]);
}
array_free(model->outputs);

model->model = NULL;
}
Expand Down
78 changes: 70 additions & 8 deletions src/backends/torch.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
RAI_Device device = RAI_DEVICE_CPU;
int64_t deviceid = 0;

char **inputs_ = NULL;
char **outputs_ = NULL;

if (!parseDeviceStr(devicestr, &device, &deviceid)) {
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR unsupported device");
return NULL;
Expand Down Expand Up @@ -53,7 +56,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
if (opts.backends_intra_op_parallelism > 0) {
torchSetIntraOpThreads(opts.backends_intra_op_parallelism, &error_descr, RedisModule_Alloc);
}
if (error_descr != NULL) {
if (error_descr) {
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
RedisModule_Free(error_descr);
return NULL;
Expand All @@ -62,10 +65,37 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
void *model =
torchLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr, RedisModule_Alloc);

if (model == NULL) {
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
RedisModule_Free(error_descr);
return NULL;
if (error_descr) {
goto cleanup;
}

size_t ninputs = torchModelNumInputs(model, &error_descr);
if (error_descr) {
goto cleanup;
}

size_t noutputs = torchModelNumOutputs(model, &error_descr);
if (error_descr) {
goto cleanup;
}

inputs_ = array_new(char *, ninputs);
outputs_ = array_new(char *, noutputs);

for (size_t i = 0; i < ninputs; i++) {
const char *input = torchModelInputNameAtIndex(model, i, &error_descr);
if (error_descr) {
goto cleanup;
}
inputs_ = array_append(inputs_, RedisModule_Strdup(input));
}

for (size_t i = 0; i < noutputs; i++) {
const char *output = "";
if (error_descr) {
goto cleanup;
}
outputs_ = array_append(outputs_, RedisModule_Strdup(output));
}

char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
Expand All @@ -76,14 +106,34 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
ret->session = NULL;
ret->backend = backend;
ret->devicestr = RedisModule_Strdup(devicestr);
ret->inputs = NULL;
ret->outputs = NULL;
ret->ninputs = ninputs;
ret->inputs = inputs_;
ret->noutputs = noutputs;
ret->outputs = outputs_;
ret->opts = opts;
ret->refCount = 1;
ret->data = buffer;
ret->datalen = modellen;

return ret;

cleanup:
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
RedisModule_Free(error_descr);
if (inputs_) {
ninputs = array_len(inputs_);
for (size_t i = 0; i < ninputs; i++) {
RedisModule_Free(inputs_[i]);
}
array_free(inputs_);
}
if (outputs_) {
noutputs = array_len(outputs_);
for (size_t i = 0; i < noutputs; i++) {
RedisModule_Free(outputs_[i]);
}
array_free(outputs_);
}
return NULL;
}

void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
Expand All @@ -93,6 +143,18 @@ void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
if (model->data) {
RedisModule_Free(model->data);
}
size_t ninputs = model->ninputs;
for (size_t i = 0; i < ninputs; i++) {
RedisModule_Free(model->inputs[i]);
}
array_free(model->inputs);

size_t noutputs = model->noutputs;
for (size_t i = 0; i < noutputs; i++) {
RedisModule_Free(model->outputs[i]);
}
array_free(model->outputs);

torchDeallocContext(model->model);
}

Expand Down
10 changes: 4 additions & 6 deletions src/command_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,15 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **a
}
}
}
if ((*model)->inputs && (*model)->ninputs != ninputs) {
if ((*model)->ninputs != ninputs) {
RAI_SetError(error, RAI_EMODELRUN,
"Number of names given as INPUTS during MODELSET and keys given as "
"INPUTS here do not match");
"Number of keys given as INPUTS here does not match model definition");
return REDISMODULE_ERR;
}

if ((*model)->outputs && (*model)->noutputs != noutputs) {
if ((*model)->noutputs != noutputs) {
RAI_SetError(error, RAI_EMODELRUN,
"Number of names given as OUTPUTS during MODELSET and keys given as "
"OUTPUTS here do not match");
"Number of keys given as OUTPUTS here does not match model definition");
return REDISMODULE_ERR;
}
return REDISMODULE_OK;
Expand Down
Loading

0 comments on commit 2fc1827

Please sign in to comment.