Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ShengYang1 committed Sep 14, 2020
1 parent 3f75ccb commit 3feced9
Showing 1 changed file with 71 additions and 27 deletions.
98 changes: 71 additions & 27 deletions rfcs/20200903-pluggable-graph-optimizer-for-tensorflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,26 @@ class CCustomGraphOptimizer : public CustomGraphOptimizer {

Upon initialization of TensorFlow, it uses platform independent `LoadDynamicLibrary()` to load the dynamic library. The plugin library should be installed to the default plugin directory "...python_dir.../site-packages/tensorflow-plugins". The modular tensorflow [RFC](https://github.com/tensorflow/community/blob/master/rfcs/20190305-modular-tensorflow.md) describes the process of loading plugins.

During the plugin library initialization, TensorFlow proper calls the `InitGraphPlugin` API (part of Graph C API), which is defined in plugin and plugin authors needs to implement it to create and register a new custom graph optimizer.
During the plugin library initialization, TensorFlow proper calls `InitGraphModule` to load the library. `InitGraphParams` and `InitGraphPlugin` API (part of Graph C API) are provided and plugin authors needs to implement it to register a new custom graph optimizer.

```cpp
static Status InitGraphModule(void* dso_handle) {
void* dso_symbol;
tensorflow::Env* env = tensorflow::Env::Default();

env->GetSymbolFromLibrary(dso_handle, "InitGraphParams", &dso_symbol).IgnoreError();
using InitGraphParams = void(*)(P_RegistrationParams*, TF_Status*);
auto init_param_fn = reinterpret_cast<InitGraphParams>(dso_symbol);
P_RegistrationParams params;
TF_Status* status = TF_NewStatus();
init_param_fn(&params, status);

env->GetSymbolFromLibrary(dso_handle, "InitGraphPlugin", &dso_symbol).IgnoreError();
using InitGraphPlugin = void(*)(P_RegistrationParams*, TF_Status*);
auto init_plugin_fn = reinterpret_cast<InitGraphPlugin>(dso_symbol);
init_plugin_fn(&params, status);
}
```
### Graph Optimizer Registration
Expand All @@ -150,9 +169,36 @@ class MyOptimizer : public CustomGraphOptimizer {
REGISTER_GRAPH_OPTIMIZER_AS(MyOptimizer, "MyOptimizer");
```

**Proposed C API**

The equivalent C API should provide a series of functions that operate on `TF_OptimizerBuilder`, an opaque struct obtained with the `TF_OptimizerBuilder` call.
The optimizer builder is registered with TensorFlow using the `TF_RegisterOptimizer` function.
`P_RegistrationParams` defines params which plugin authors needs to define, including optimizer name, device type name, and some flags indicating whether existing optimizers should be disabled.
These two registration functions are provided below:

```cpp
typedef struct P_RegistrationParams {
char* name;
int name_len;
char* device;
int device_len;
bool remapping;
bool auto_mixed_precision;
// ...
}
typedef struct TF_OptimizerBuilder {
P_RegistrationParams* params;
void* (*create_func)(),
void (*optimize_func)(void*, TF_GrapplerItem*, TF_Buffer*),
void (*delete_func)(void*)
} TF_OptimizerBuilder;

void TF_RegisterOptimizer(TF_OptimizerBuilder* builder, TF_Status* status);
```
**Proper**
The equivalent C API relies on customized implementations of CustomGraphOptimizer. It is defined in proper side and might look as follows:
The C API relies on customized implementations of CustomGraphOptimizer. It is defined in proper side and might look as follows:
```cpp
class CCustomGraphOptimizer : public CustomGraphOptimizer {
Expand All @@ -177,30 +223,18 @@ class CCustomGraphOptimizer : public CustomGraphOptimizer {
void (*delete_func_)(void*);
void* c_optimizer_;
}
```
**Proposed C API**

The C API should provide a series of functions that operate on `TF_OptimizerBuilder`, an opaque struct obtained with the `TF_OptimizerBuilder` call. The optimizer builder is registered with TensorFlow using the `TF_RegisterOptimizer` function. These two registration functions are provided below:

```cpp
typedef struct P_RegistrationParams {
char* name;
int name_len;
char* device;
int device_len;
bool remapping;
bool auto_mixed_precision;
// ...
void TF_RegisterOptimizer(TF_OptimizerBuilder* builder, TF_Status* status) {
::tensorflow::grappler::CustomGraphOptimizerRegistrar
MyOptimizer_registrar([]() {
return new CCustomGraphOptimizer(
builder->device,
builder->create_func,
builder->delete_func,
builder->optimize_func);
}, builder->params);
TF_SetStatus(status, TF_OK, "");
}
typedef struct TF_OptimizerBuilder {
P_RegistrationParams* params;
void* (*create_func)(),
void (*optimize_func)(void*, TF_GrapplerItem*, TF_Buffer*),
void (*delete_func)(void*)
} TF_OptimizerBuilder;

void TF_RegisterOptimizer(TF_OptimizerBuilder* builder, TF_Status* status);
```

**Plugin**
Expand All @@ -222,7 +256,7 @@ static void MyOptimizer_Optimize(void* optimizer, TF_GrapplerItem* item, TF_Buff
// Fetch GraphDef from TF_GrapplerItem and then optimize it.
}

void InitGraphPlugin(P_RegistrationParams* params) {
void InitGraphParams(P_RegistrationParams* params, TF_Status* status) {
std::string name = "MyOptimizer";
std::string device = "GPU";
params.name = name.c_str();
Expand All @@ -231,13 +265,13 @@ void InitGraphPlugin(P_RegistrationParams* params) {
params.device_len = device.size();
params.remapping = false;
params.auto_mixed_precision = true;
}

void InitGraphPlugin(P_RegistrationParams* params, TF_Status* status) {
TF_OptimizerBuilder* builder =
TF_OptimizerBuilder(params, &MyOptimizer_Create, &MyOptimizer_Optimize, &MyOptimizer_Delete);
TF_Status* status = TF_NewStatus();
TF_RegisterOptimizer(builder, status);
if (TF_GetCode(status) != TF_OK) { /* handle errors */ }
TF_DeleteStatus(status);
}
```
Expand Down Expand Up @@ -356,6 +390,16 @@ typedef struct TF_OptimizerBuilder {
void (*delete_func)(void*)
} TF_OptimizerBuilder;
void TF_RegisterOptimizer(TF_OptimizerBuilder* builder, TF_Status* status);
typedef struct P_RegistrationParams {
char* name;
int name_len;
char* device;
int device_len;
bool remapping;
bool auto_mixed_precision;
} P_RegistrationParams;
void InitGraphParams(P_RegistrationParams* params, TF_Status* status);
void InitGraphPlugin(P_RegistrationParams* params, TF_Status* status);

// TF_GrapplerItem
typedef struct TF_GrapplerItem TF_GrapplerItem;
Expand Down

0 comments on commit 3feced9

Please sign in to comment.