Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Script commands #50

Merged
merged 6 commits into from
Jul 14, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/main/java/com/redislabs/redisai/Command.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ public enum Command implements ProtocolCommand {
MODEL_RUN("AI.MODELRUN"),
MODEL_EXECUTE("AI.MODELEXECUTE"),
SCRIPT_SET("AI.SCRIPTSET"),
SCRIPT_STORE("AI.SCRIPTSTORE"),
SCRIPT_GET("AI.SCRIPTGET"),
SCRIPT_DEL("AI.SCRIPTDEL"),
SCRIPT_RUN("AI.SCRIPTRUN"),
SCRIPT_EXECUTE("AI.SCRIPTEXECUTE"),
DAGRUN("AI.DAGRUN"),
DAGRUN_RO("AI.DAGRUN_RO"),
DAGEXECUTE("AI.DAGEXECUTE"),
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/com/redislabs/redisai/Dag.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ public Dag runScript(String key, String function, String[] inputs, String[] outp
return this;
}

@Override
public Dag executeScript(
String key,
String function,
List<String> keys,
List<String> inputs,
List<String> args,
List<String> outputs) {
List<byte[]> binary =
Script.scriptExecuteFlatArgs(key, function, keys, inputs, keys, outputs, -1L, true);
sazzad16 marked this conversation as resolved.
Show resolved Hide resolved
this.commands.add(binary);
this.tensorgetflag.add(false);
return this;
}

List<byte[]> dagRunFlatArgs(String[] loadKeys, String[] persistKeys) {
List<byte[]> args = new ArrayList<>();
if (loadKeys != null && loadKeys.length > 0) {
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/com/redislabs/redisai/DagRunCommands.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.redislabs.redisai;

import java.util.List;

interface DagRunCommands<T> {
T setTensor(String key, Tensor tensor);

Expand All @@ -10,4 +12,12 @@ interface DagRunCommands<T> {
T executeModel(String key, String[] inputs, String[] outputs);

T runScript(String key, String function, String[] inputs, String[] outputs);

T executeScript(
String key,
String function,
List<String> keys,
List<String> inputs,
List<String> args,
List<String> outputs);
}
3 changes: 2 additions & 1 deletion src/main/java/com/redislabs/redisai/Device.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ public enum Device implements ProtocolCommand {
private final byte[] raw;

Device() {
raw = SafeEncoder.encode(this.name());
raw = SafeEncoder.encode(name());
}

@Override
public byte[] getRaw() {
return raw;
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/redislabs/redisai/Keyword.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public enum Keyword implements ProtocolCommand {
SOURCE,
RESETSTAT,
TAG,
ENTRY_POINTS,
BATCHSIZE,
MINBATCHSIZE,
MINBATCHTIMEOUT,
Expand All @@ -21,6 +22,7 @@ public enum Keyword implements ProtocolCommand {
LOAD,
PERSIST,
KEYS,
ARGS,
PIPE("|>");

private final byte[] raw;
Expand Down
92 changes: 82 additions & 10 deletions src/main/java/com/redislabs/redisai/RedisAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Map;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import redis.clients.jedis.BinaryClient;
import redis.clients.jedis.Client;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisClientConfig;
Expand Down Expand Up @@ -103,6 +104,22 @@ private static JedisPoolConfig initPoolConfig(int poolSize) {
return conf;
}

private Jedis getConnection() {
return pool.getResource();
}

private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) {
BinaryClient client = conn.getClient();
client.sendCommand(command, args);
return client;
}

private Client sendCommand(Jedis conn, Command command, String... args) {
Client client = conn.getClient();
client.sendCommand(command, args);
return client;
}

/**
* Direct mapping to AI.TENSORSET
*
Expand Down Expand Up @@ -320,6 +337,27 @@ public boolean setScript(String key, Script script) {
}
}

/**
* Direct mapping to AI.MODELSTORE command.
*
* <p>{@code AI.SCRIPTSTORE <key> <device> [TAG tag] ENTRY_POINTS <entry_point_count>
* <entry_point> [<entry_point>...] SOURCE "<script>"}
*
* @param key name of key to store the Script in RedisAI server
* @param script the Script Object
* @return true if Script was properly stored in RedisAI server
*/
public boolean storeScript(String key, Script script) {
try (Jedis conn = getConnection()) {
List<String> args = script.getScriptStoreCommandBytes(key);
return sendCommand(conn, Command.SCRIPT_STORE, args.toArray(new String[args.size()]))
.getStatusCodeReply()
.equals("OK");
} catch (JedisDataException ex) {
throw new RedisAIException(ex.getMessage(), ex);
}
}

/**
* Direct mapping to AI.SCRIPTGET
*
Expand Down Expand Up @@ -427,6 +465,50 @@ public boolean runScript(String key, String function, String[] inputs, String[]
}
}

public boolean executeScript(
String key,
String function,
List<String> keys,
List<String> inputs,
List<String> args,
List<String> outputs) {
return executeScript(key, function, keys, inputs, args, outputs, -1);
}

/**
* Direct mapping to AI.SCRIPTEXECUTE command.
*
* <p>{@code AI.SCRIPTEXECUTE <key> <function> [KEYS n <key> [keys...]] [INPUTS m <input> [input
* ...]] [ARGS k <arg> [arg...]] [OUTPUTS k <output> [output ...] [TIMEOUT t]]+}
*
* @param key
* @param function
* @param keys
* @param inputs
* @param args
* @param outputs
* @param timeout timeout in ms
* @return
*/
public boolean executeScript(
String key,
String function,
List<String> keys,
List<String> inputs,
List<String> args,
List<String> outputs,
long timeout) {
try (Jedis conn = getConnection()) {
List<byte[]> binary =
Script.scriptExecuteFlatArgs(key, function, keys, inputs, args, outputs, timeout, false);
return sendCommand(conn, Command.SCRIPT_EXECUTE, binary.toArray(new byte[binary.size()][]))
.getStatusCodeReply()
.equals("OK");
} catch (JedisDataException ex) {
throw new RedisAIException(ex.getMessage(), ex);
}
}

/**
* Direct mapping to AI.DAGRUN specifies a direct acyclic graph of operations to run within
* RedisAI
Expand Down Expand Up @@ -555,16 +637,6 @@ public boolean resetStat(String key) {
}
}

private Jedis getConnection() {
return pool.getResource();
}

private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) {
BinaryClient client = conn.getClient();
client.sendCommand(command, args);
return client;
}

/**
* AI.CONFIG <BACKENDSPATH <path>>
*
Expand Down
Loading