Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for DAGRUN and DAGRUN_RO #8

Merged
merged 11 commits into from
Jun 8, 2020
91 changes: 44 additions & 47 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { Model } from './model';
import * as util from 'util';
import { Script } from './script';
import { Stats } from './stats';
import { Dag, DagCommandInterface } from './dag';

export class Client {
private _sendCommand: any;
Expand All @@ -24,22 +25,12 @@ export class Client {
}

public tensorset(keName: string, t: Tensor): Promise<any> {
const args: any[] = [keName, t.dtype];
t.shape.forEach((value) => args.push(value.toString()));
if (t.data != null) {
if (t.data instanceof Buffer) {
args.push('BLOB');
args.push(t.data);
} else {
args.push('VALUES');
t.data.forEach((value) => args.push(value.toString()));
}
}
const args: any[] = t.tensorSetFlatArgs(keName);
return this._sendCommand('ai.tensorset', args);
}

public tensorget(keName: string): Promise<any> {
const args: any[] = [keName, 'META', 'VALUES'];
public tensorget(keyName: string): Promise<any> {
const args: any[] = Tensor.tensorGetFlatArgs(keyName);
return this._sendCommand('ai.tensorget', args)
.then((reply: any[]) => {
return Tensor.NewTensorFromTensorGetReply(reply);
Expand All @@ -49,30 +40,13 @@ export class Client {
});
}

public modelset(keName: string, m: Model): Promise<any> {
const args: any[] = [keName, m.backend.toString(), m.device];
if (m.tag !== undefined) {
args.push('TAG');
args.push(m.tag.toString());
}
if (m.inputs.length > 0) {
args.push('INPUTS');
m.inputs.forEach((value) => args.push(value));
}
if (m.outputs.length > 0) {
args.push('OUTPUTS');
m.outputs.forEach((value) => args.push(value));
}
args.push('BLOB');
args.push(m.blob);
public modelset(keyName: string, m: Model): Promise<any> {
const args: any[] = m.modelSetFlatArgs(keyName);
return this._sendCommand('ai.modelset', args);
}

public modelrun(modelName: string, inputs: string[], outputs: string[]): Promise<any> {
const args: any[] = [modelName, 'INPUTS'];
inputs.forEach((value) => args.push(value));
args.push('OUTPUTS');
outputs.forEach((value) => args.push(value));
const args: any[] = Model.modelRunFlatArgs(modelName, inputs, outputs);
return this._sendCommand('ai.modelrun', args);
}

Expand All @@ -82,7 +56,7 @@ export class Client {
}

public modelget(modelName: string): Promise<any> {
const args: any[] = [modelName, 'META', 'BLOB'];
const args: any[] = Model.modelGetFlatArgs(modelName);
return this._sendCommand('ai.modelget', args)
.then((reply: any[]) => {
return Model.NewModelFromModelGetReply(reply);
Expand All @@ -92,22 +66,13 @@ export class Client {
});
}

public scriptset(keName: string, s: Script): Promise<any> {
const args: any[] = [keName, s.device];
if (s.tag !== undefined) {
args.push('TAG');
args.push(s.tag);
}
args.push('SOURCE');
args.push(s.script);
public scriptset(keyName: string, s: Script): Promise<any> {
const args: any[] = s.scriptSetFlatArgs(keyName);
return this._sendCommand('ai.scriptset', args);
}

public scriptrun(scriptName: string, functionName: string, inputs: string[], outputs: string[]): Promise<any> {
const args: any[] = [scriptName, functionName, 'INPUTS'];
inputs.forEach((value) => args.push(value));
args.push('OUTPUTS');
outputs.forEach((value) => args.push(value));
const args: any[] = Script.scriptRunFlatArgs(scriptName, functionName, inputs, outputs);
return this._sendCommand('ai.scriptrun', args);
}

Expand All @@ -117,7 +82,7 @@ export class Client {
}

public scriptget(scriptName: string): Promise<any> {
const args: any[] = [scriptName, 'META', 'SOURCE'];
const args: any[] = Script.scriptGetFlatArgs(scriptName);
return this._sendCommand('ai.scriptget', args)
.then((reply: any[]) => {
return Script.NewScriptFromScriptGetReply(reply);
Expand Down Expand Up @@ -151,6 +116,38 @@ export class Client {
});
}

/**
*
* @param loadKeys
* @param persistKeys
* @param dag
*/
public dagrun(loadKeys: string[] | null, persistKeys: string[] | null, dag: Dag): Promise<any> {
filipecosta90 marked this conversation as resolved.
Show resolved Hide resolved
const args: any[] = dag.dagRunFlatArgs(loadKeys, persistKeys);
return this._sendCommand('ai.dagrun', args)
.then((reply: any[]) => {
return dag.ProcessDagReply(reply);
})
.catch((error: any) => {
throw error;
});
}

/**
*
* @param loadKeys
* @param dag
*/
public dagrun_ro(loadKeys: string[] | null, dag: Dag): Promise<any> {
const args: any[] = dag.dagRunFlatArgs(loadKeys, null);
return this._sendCommand('ai.dagrun_ro', args)
.then((reply: any[]) => {
return dag.ProcessDagReply(reply);
})
.catch((error: any) => {
throw error;
});
}
/**
* Loads the DL/ML backend specified by the backend identifier from path.
*
Expand Down
85 changes: 85 additions & 0 deletions src/dag.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import { Model } from './model';
import { Script } from './script';
import { Tensor } from './tensor';
import { Dtype, DTypeMap } from './dtype';

export interface DagCommandInterface {
tensorset(keName: string, t: Tensor);

tensorget(keyName: string);

tensorget(keyName: string);

modelrun(modelName: string, inputs: string[], outputs: string[]);

scriptrun(scriptName: string, functionName: string, inputs: string[], outputs: string[]);
}

/**
* Direct mapping to RedisAI DAGs
*/
export class Dag implements DagCommandInterface {
private _commands: any[][];
private _tensorgetflag: boolean[];

constructor() {
this._commands = [];
this._tensorgetflag = [];
}

public tensorset(keName: string, t: Tensor) {
filipecosta90 marked this conversation as resolved.
Show resolved Hide resolved
const args: any[] = ['AI.TENSORSET'];
t.tensorSetFlatArgs(keName).forEach((arg) => args.push(arg));
this._commands.push(args);
this._tensorgetflag.push(false);
}

public tensorget(keyName: string) {
const args: any[] = ['AI.TENSORGET'];
Tensor.tensorGetFlatArgs(keyName).forEach((arg) => args.push(arg));
this._commands.push(args);
this._tensorgetflag.push(true);
}

public modelrun(modelName: string, inputs: string[], outputs: string[]) {
const args: any[] = ['AI.MODELRUN'];
Model.modelRunFlatArgs(modelName, inputs, outputs).forEach((arg) => args.push(arg));
this._commands.push(args);
this._tensorgetflag.push(false);
}

public scriptrun(scriptName: string, functionName: string, inputs: string[], outputs: string[]) {
const args: any[] = ['AI.SCRIPTRUN'];
Script.scriptRunFlatArgs(scriptName, functionName, inputs, outputs).forEach((arg) => args.push(arg));
this._commands.push(args);
this._tensorgetflag.push(false);
}

public dagRunFlatArgs(loadKeys: string[] | null, persistKeys: string[] | null): any[] {
const args: any[] = [];
if (loadKeys != null && loadKeys.length > 0) {
args.push('LOAD');
args.push(loadKeys.length);
loadKeys.forEach((value) => args.push(value));
}
if (persistKeys != null && persistKeys.length > 0) {
args.push('PERSIST');
args.push(persistKeys.length);
persistKeys.forEach((value) => args.push(value));
}
this._commands.forEach((value) => {
args.push('|>');
value.forEach((arg) => args.push(arg));
});
return args;
}

public ProcessDagReply(reply: any[]): any[] {
for (let i = 0; i < reply.length; i++) {
if (this._tensorgetflag[i] === true) {
reply[i] = Tensor.NewTensorFromTensorGetReply(reply[i]);
}
}
return reply;
}
}
3 changes: 2 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ import { Backend, BackendMap } from './backend';
import { Tensor } from './tensor';
import { Model } from './model';
import { Script } from './script';
import { Dag } from './dag';
import { Client } from './client';
import { Stats } from './stats';
import { Helpers } from './helpers';

export { DTypeMap, Dtype, BackendMap, Backend, Model, Script, Tensor, Client, Stats, Helpers };
export { DTypeMap, Dtype, BackendMap, Backend, Model, Script, Tensor, Dag, Client, Stats, Helpers };
32 changes: 32 additions & 0 deletions src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,36 @@ export class Model {
}
return model;
}

static modelGetFlatArgs(keyName: string): any[] {
const args: any[] = [keyName, 'META', 'BLOB'];
return args;
}

static modelRunFlatArgs(modelName: string, inputs: string[], outputs: string[]): any[] {
const args: any[] = [modelName, 'INPUTS'];
filipecosta90 marked this conversation as resolved.
Show resolved Hide resolved
inputs.forEach((value) => args.push(value));
args.push('OUTPUTS');
outputs.forEach((value) => args.push(value));
return args;
}

modelSetFlatArgs(keyName: string) {
const args: any[] = [keyName, this.backend.toString(), this.device];
filipecosta90 marked this conversation as resolved.
Show resolved Hide resolved
if (this.tag !== undefined) {
args.push('TAG');
args.push(this.tag.toString());
}
if (this.inputs.length > 0) {
args.push('INPUTS');
this.inputs.forEach((value) => args.push(value));
}
if (this.outputs.length > 0) {
args.push('OUTPUTS');
this.outputs.forEach((value) => args.push(value));
}
args.push('BLOB');
args.push(this.blob);
return args;
}
}
24 changes: 24 additions & 0 deletions src/script.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,28 @@ export class Script {
}
return script;
}

scriptSetFlatArgs(keyName: string) {
const args: any[] = [keyName, this.device];
if (this.tag !== undefined) {
args.push('TAG');
args.push(this.tag);
}
args.push('SOURCE');
args.push(this.script);
return args;
}

static scriptRunFlatArgs(scriptName: string, functionName: string, inputs: string[], outputs: string[]) {
const args: any[] = [scriptName, functionName, 'INPUTS'];
filipecosta90 marked this conversation as resolved.
Show resolved Hide resolved
inputs.forEach((value) => args.push(value));
args.push('OUTPUTS');
outputs.forEach((value) => args.push(value));
return args;
}

static scriptGetFlatArgs(scriptName: string) {
const args: any[] = [scriptName, 'META', 'SOURCE'];
filipecosta90 marked this conversation as resolved.
Show resolved Hide resolved
return args;
}
}
20 changes: 20 additions & 0 deletions src/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ export class Tensor {
set data(value: Buffer | number[] | null) {
this._data = value;
}
//
tensorSetFlatArgs(keName: string): any[] {
const args: any[] = [keName, this.dtype];
this.shape.forEach((value) => args.push(value.toString()));
if (this.data != null) {
if (this.data instanceof Buffer) {
args.push('BLOB');
args.push(this.data);
} else {
args.push('VALUES');
this.data.forEach((value) => args.push(value.toString()));
}
}
return args;
}

static NewTensorFromTensorGetReply(reply: any[]) {
let dt = null;
Expand Down Expand Up @@ -84,4 +99,9 @@ export class Tensor {
}
return new Tensor(dt, shape, values);
}

static tensorGetFlatArgs(keyName: string): any[] {
filipecosta90 marked this conversation as resolved.
Show resolved Hide resolved
const args: any[] = [keyName, 'META', 'VALUES'];
return args;
}
}
Loading