diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/IgniteFunctionDistributedInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/IgniteFunctionDistributedInferenceExample.java new file mode 100644 index 0000000000000..da9d5432809e3 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/IgniteFunctionDistributedInferenceExample.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.examples.ml.inference; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.examples.ml.regression.linear.LinearRegressionLSQRTrainerExample; +import org.apache.ignite.examples.ml.util.MLSandboxDatasets; +import org.apache.ignite.examples.ml.util.SandboxMLCache; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.builder.IgniteDistributedInfModelBuilder; +import org.apache.ignite.ml.inference.parser.IgniteFunctionInfModelParser; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.reader.InMemoryInfModelReader; +import org.apache.ignite.ml.inference.reader.InfModelReader; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; + +/** + * This example is based on {@link LinearRegressionLSQRTrainerExample}, but to perform inference it uses an approach + * implemented in {@link org.apache.ignite.ml.inference} package. + */ +public class IgniteFunctionDistributedInferenceExample { + /** Run example. */ + public static void main(String... args) throws IOException, ExecutionException, InterruptedException { + System.out.println(); + System.out.println(">>> Linear regression model over cache based dataset usage example started."); + // Start ignite grid. 
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteCache dataCache = new SandboxMLCache(ignite) + .fillCacheWith(MLSandboxDatasets.MORTALITY_DATA); + + System.out.println(">>> Create new linear regression trainer object."); + LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); + + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> v.copyOfRange(1, v.size()), + (k, v) -> v.get(0) + ); + + System.out.println(">>> Linear regression model: " + mdl); + + System.out.println(">>> Preparing model reader and model parser."); + InfModelReader reader = new InMemoryInfModelReader(mdl); + InfModelParser parser = new IgniteFunctionInfModelParser<>(); + try (InfModel> infMdl = new IgniteDistributedInfModelBuilder(ignite, 4, 4) + .build(reader, parser)) { + System.out.println(">>> Inference model is ready."); + + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); + + try (QueryCursor> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry observation : observations) { + Vector val = observation.getValue(); + Vector inputs = val.copyOfRange(1, val.size()); + double groundTruth = val.get(0); + + double prediction = infMdl.predict(inputs).get(); + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } + } + } + + System.out.println(">>> ---------------------------------"); + + System.out.println(">>> Linear regression model over cache based dataset usage example completed."); + } + } +} diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowDistributedInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowDistributedInferenceExample.java 
new file mode 100644 index 0000000000000..cc22df33d7498 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowDistributedInferenceExample.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.examples.ml.inference; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import org.apache.ignite.Ignite; +import org.apache.ignite.Ignition; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.builder.IgniteDistributedInfModelBuilder; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.parser.TensorFlowSavedModelInfModelParser; +import org.apache.ignite.ml.inference.reader.FileSystemInfModelReader; +import org.apache.ignite.ml.inference.reader.InfModelReader; +import org.apache.ignite.ml.util.MnistUtils; +import org.tensorflow.Tensor; + +/** + * This example demonstrates how to: load TensorFlow model into Java, make inference in distributed environment using + * Apache Ignite services. + */ +public class TensorFlowDistributedInferenceExample { + /** Path to the directory with saved TensorFlow model. */ + private static final String MODEL_PATH = "examples/src/main/resources/ml/mnist_tf_model"; + + /** Path to the MNIST images data. */ + private static final String MNIST_IMG_PATH = "org/apache/ignite/examples/ml/util/datasets/t10k-images-idx3-ubyte"; + + /** Path to the MNIST labels data. */ + private static final String MNIST_LBL_PATH = "org/apache/ignite/examples/ml/util/datasets/t10k-labels-idx1-ubyte"; + + /** Run example. 
*/ + public static void main(String[] args) throws IOException, ExecutionException, InterruptedException { + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + File mdlRsrc = IgniteUtils.resolveIgnitePath(MODEL_PATH); + if (mdlRsrc == null) + throw new IllegalArgumentException("Resource not found [resource_path=" + MODEL_PATH + "]"); + + InfModelReader reader = new FileSystemInfModelReader(mdlRsrc.getPath()); + + InfModelParser parser = new TensorFlowSavedModelInfModelParser("serve") + + .withInput("Placeholder", doubles -> { + float[][][] reshaped = new float[1][28][28]; + for (int i = 0; i < doubles.length; i++) + reshaped[0][i / 28][i % 28] = (float)doubles[i]; + return Tensor.create(reshaped); + }) + + .withOutput(Collections.singletonList("ArgMax"), collectedTensors -> { + return collectedTensors.get("ArgMax").copyTo(new long[1])[0]; + }); + + List images = MnistUtils.mnistAsListFromResource( + MNIST_IMG_PATH, + MNIST_LBL_PATH, + new Random(0), + 10000 + ); + + long t0 = System.currentTimeMillis(); + + try (InfModel> threadedMdl = new IgniteDistributedInfModelBuilder(ignite, 4, 4) + .build(reader, parser)) { + List> futures = new ArrayList<>(images.size()); + for (MnistUtils.MnistLabeledImage image : images) + futures.add(threadedMdl.predict(image.getPixels())); + for (Future f : futures) + f.get(); + } + + long t1 = System.currentTimeMillis(); + + System.out.println("Threaded model throughput: " + images.size() / ((t1 - t0) / 1000.0) + " req/sec"); + } + } +} diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowLocalInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowLocalInferenceExample.java new file mode 100644 index 0000000000000..fc25c7e93854c --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowLocalInferenceExample.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or 
more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.inference; + +import java.io.File; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.builder.SingleInfModelBuilder; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.parser.TensorFlowSavedModelInfModelParser; +import org.apache.ignite.ml.inference.reader.FileSystemInfModelReader; +import org.apache.ignite.ml.inference.reader.InfModelReader; +import org.apache.ignite.ml.util.MnistUtils; +import org.tensorflow.Tensor; + +/** + * This example demonstrates how to: load TensorFlow model into Java, make inference using this model in one thread. + */ +public class TensorFlowLocalInferenceExample { + /** Path to the directory with saved TensorFlow model. */ + private static final String MODEL_PATH = "examples/src/main/resources/ml/mnist_tf_model"; + + /** Path to the MNIST images data. */ + private static final String MNIST_IMG_PATH = "org/apache/ignite/examples/ml/util/datasets/t10k-images-idx3-ubyte"; + + /** Path to the MNIST labels data. 
*/ + private static final String MNIST_LBL_PATH = "org/apache/ignite/examples/ml/util/datasets/t10k-labels-idx1-ubyte"; + + /** Run example. */ + public static void main(String[] args) throws IOException { + File mdlRsrc = IgniteUtils.resolveIgnitePath(MODEL_PATH); + if (mdlRsrc == null) + throw new IllegalArgumentException("Resource not found [resource_path=" + MODEL_PATH + "]"); + + InfModelReader reader = new FileSystemInfModelReader(mdlRsrc.getPath()); + + InfModelParser parser = new TensorFlowSavedModelInfModelParser("serve") + .withInput("Placeholder", doubles -> { + float[][][] reshaped = new float[1][28][28]; + for (int i = 0; i < doubles.length; i++) + reshaped[0][i / 28][i % 28] = (float)doubles[i]; + return Tensor.create(reshaped); + }) + .withOutput(Collections.singletonList("ArgMax"), collectedTensors -> { + return collectedTensors.get("ArgMax").copyTo(new long[1])[0]; + }); + + List images = MnistUtils.mnistAsListFromResource( + MNIST_IMG_PATH, + MNIST_LBL_PATH, + new Random(0), + 10000 + ); + + long t0 = System.currentTimeMillis(); + + try (InfModel locMdl = new SingleInfModelBuilder().build(reader, parser)) { + for (MnistUtils.MnistLabeledImage image : images) + locMdl.predict(image.getPixels()); + } + + long t1 = System.currentTimeMillis(); + + System.out.println("Threaded model throughput: " + 1.0 * images.size() / ((t1 - t0) / 1000) + " req/sec"); + } +} diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowThreadedInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowThreadedInferenceExample.java new file mode 100644 index 0000000000000..d756016236f4e --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowThreadedInferenceExample.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.inference; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.builder.ThreadedInfModelBuilder; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.parser.TensorFlowSavedModelInfModelParser; +import org.apache.ignite.ml.inference.reader.FileSystemInfModelReader; +import org.apache.ignite.ml.inference.reader.InfModelReader; +import org.apache.ignite.ml.util.MnistUtils; +import org.tensorflow.Tensor; + +/** + * This example demonstrates how to: load TensorFlow model into Java, make inference using this model in multiple + * threads. + */ +public class TensorFlowThreadedInferenceExample { + /** Path to the directory with saved TensorFlow model. */ + private static final String MODEL_PATH = "examples/src/main/resources/ml/mnist_tf_model"; + + /** Path to the MNIST images data. 
*/ + private static final String MNIST_IMG_PATH = "org/apache/ignite/examples/ml/util/datasets/t10k-images-idx3-ubyte"; + + /** Path to the MNIST labels data. */ + private static final String MNIST_LBL_PATH = "org/apache/ignite/examples/ml/util/datasets/t10k-labels-idx1-ubyte"; + + /** Run example. */ + public static void main(String[] args) throws IOException, ExecutionException, InterruptedException { + File mdlRsrc = IgniteUtils.resolveIgnitePath(MODEL_PATH); + if (mdlRsrc == null) + throw new IllegalArgumentException("Resource not found [resource_path=" + MODEL_PATH + "]"); + + InfModelReader reader = new FileSystemInfModelReader(mdlRsrc.getPath()); + + InfModelParser parser = new TensorFlowSavedModelInfModelParser("serve") + + .withInput("Placeholder", doubles -> { + float[][][] reshaped = new float[1][28][28]; + for (int i = 0; i < doubles.length; i++) + reshaped[0][i / 28][i % 28] = (float)doubles[i]; + return Tensor.create(reshaped); + }) + + .withOutput(Collections.singletonList("ArgMax"), collectedTensors -> { + return collectedTensors.get("ArgMax").copyTo(new long[1])[0]; + }); + + List images = MnistUtils.mnistAsListFromResource( + MNIST_IMG_PATH, + MNIST_LBL_PATH, + new Random(0), + 10000 + ); + + long t0 = System.currentTimeMillis(); + + try (InfModel> threadedMdl = new ThreadedInfModelBuilder(8) + .build(reader, parser)) { + List> futures = new ArrayList<>(images.size()); + for (MnistUtils.MnistLabeledImage image : images) + futures.add(threadedMdl.predict(image.getPixels())); + for (Future f : futures) + f.get(); + } + + long t1 = System.currentTimeMillis(); + + System.out.println("Threaded model throughput: " + 1.0 * images.size() / ((t1 - t0) / 1000) + " req/sec"); + } +} diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/package-info.java new file mode 100644 index 0000000000000..4f0c5e5c48685 --- /dev/null +++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/inference/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Model inference examples. + */ +package org.apache.ignite.examples.ml.inference; \ No newline at end of file diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/util/datasets/t10k-images-idx3-ubyte b/examples/src/main/java/org/apache/ignite/examples/ml/util/datasets/t10k-images-idx3-ubyte new file mode 100644 index 0000000000000..1170b2cae98de Binary files /dev/null and b/examples/src/main/java/org/apache/ignite/examples/ml/util/datasets/t10k-images-idx3-ubyte differ diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/util/datasets/t10k-labels-idx1-ubyte b/examples/src/main/java/org/apache/ignite/examples/ml/util/datasets/t10k-labels-idx1-ubyte new file mode 100644 index 0000000000000..d1c3a970612bb Binary files /dev/null and b/examples/src/main/java/org/apache/ignite/examples/ml/util/datasets/t10k-labels-idx1-ubyte differ diff --git a/examples/src/main/resources/ml/mnist_tf_model/saved_model.pb b/examples/src/main/resources/ml/mnist_tf_model/saved_model.pb new file mode 100644 index 0000000000000..4d36671373b89 Binary files 
/dev/null and b/examples/src/main/resources/ml/mnist_tf_model/saved_model.pb differ diff --git a/examples/src/main/resources/ml/mnist_tf_model/variables/variables.data-00000-of-00001 b/examples/src/main/resources/ml/mnist_tf_model/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000..a65398faadfc8 Binary files /dev/null and b/examples/src/main/resources/ml/mnist_tf_model/variables/variables.data-00000-of-00001 differ diff --git a/examples/src/main/resources/ml/mnist_tf_model/variables/variables.index b/examples/src/main/resources/ml/mnist_tf_model/variables/variables.index new file mode 100644 index 0000000000000..221dd2de213bc Binary files /dev/null and b/examples/src/main/resources/ml/mnist_tf_model/variables/variables.index differ diff --git a/modules/ml/pom.xml b/modules/ml/pom.xml index ad31da256645d..69c77fff09996 100644 --- a/modules/ml/pom.xml +++ b/modules/ml/pom.xml @@ -23,6 +23,7 @@ 4.0.0 3.6.1 + 1.12.0 @@ -115,6 +116,18 @@ 1.0 + + org.tensorflow + tensorflow + ${tensorflow.version} + + + + org.tensorflow + proto + ${tensorflow.version} + + org.mockito mockito-all diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/InfModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/InfModel.java new file mode 100644 index 0000000000000..c2f6b9579f1ba --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/InfModel.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/**
 * Inference model that can be used to make predictions.
 *
 * <p>Restored the generic type parameters {@code <I, O>} that were lost in this copy of the file —
 * their presence is established by the original {@code @param} javadoc tags for the input and
 * output types and by call sites such as {@code InfModel<Vector, Future<Double>>} in the examples.
 *
 * @param <I> Type of model input.
 * @param <O> Type of model output.
 */
public interface InfModel<I, O> extends AutoCloseable {
    /**
     * Make a prediction for the specified input arguments.
     *
     * @param input Input arguments.
     * @return Prediction result.
     */
    public O predict(I input);

    /**
     * {@inheritDoc}
     *
     * <p>Narrowed to throw no checked exception so callers need no try/catch around cleanup.
     */
    @Override public void close();
}
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference; + +import java.io.Serializable; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.reader.InfModelReader; + +/** + * Model descriptor that encapsulates information about model, {@link InfModelReader} and {@link InfModelParser} which + * is required to build the model. + */ +public class ModelDescriptor implements Serializable { + /** Model name. */ + private final String name; + + /** Model description. */ + private final String desc; + + /** Model signature that keeps input/output types in Protobuf. */ + private final ModelSignature signature; + + /** Model reader. */ + private final InfModelReader reader; + + /** Model parser. */ + private final InfModelParser parser; + + /** + * Constructs a new instance of model descriptor. + * + * @param name Model name. + * @param desc Model description. + * @param signature Model signature that keeps input/output types in Protobuf. + * @param reader Model reader. + * @param parser Model parser. 
+ */ + public ModelDescriptor(String name, String desc, ModelSignature signature, InfModelReader reader, + InfModelParser parser) { + this.name = name; + this.desc = desc; + this.signature = signature; + this.reader = reader; + this.parser = parser; + } + + /** */ + public String getName() { + return name; + } + + /** */ + public String getDesc() { + return desc; + } + + /** */ + public ModelSignature getSignature() { + return signature; + } + + /** */ + public InfModelReader getReader() { + return reader; + } + + /** */ + public InfModelParser getParser() { + return parser; + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/ModelSignature.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/ModelSignature.java new file mode 100644 index 0000000000000..1a5a24532a037 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/ModelSignature.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference; + +import java.io.Serializable; + +/** + * Signature that defines input/output types in Protobuf. + */ +public class ModelSignature implements Serializable { + /** Protobuf schema of all objects required in the model. 
/**
 * Signature that defines input/output types in Protobuf.
 *
 * <p>Instances are immutable and {@link Serializable}.
 */
public class ModelSignature implements Serializable {
    /** Explicit serial version UID for a Serializable class (avoids the JVM-computed default). */
    private static final long serialVersionUID = 0L;

    /** Protobuf schema of all objects required in the model. */
    private final String schema;

    /** Name of the input type (should be present in the {@link #schema}). */
    private final String inputMsg;

    /** Name of the output type (should be present in the {@link #schema}). */
    private final String outputMsg;

    /**
     * Constructs a new instance of model signature.
     *
     * @param schema Protobuf schema of all objects required in the model.
     * @param inputMsg Name of the input type (should be present in the {@link #schema}).
     * @param outputMsg Name of the output type (should be present in the {@link #schema}).
     */
    public ModelSignature(String schema, String inputMsg, String outputMsg) {
        this.schema = schema;
        this.inputMsg = inputMsg;
        this.outputMsg = outputMsg;
    }

    /** @return Protobuf schema of all objects required in the model. */
    public String getSchema() {
        return schema;
    }

    /** @return Name of the input type. */
    public String getInputMsg() {
        return inputMsg;
    }

    /** @return Name of the output type. */
    public String getOutputMsg() {
        return outputMsg;
    }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import java.io.Serializable; +import java.util.concurrent.Future; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.reader.InfModelReader; + +/** + * Builder of asynchronous inference model. Uses specified model reader (see {@link InfModelReader}) and mode parser + * (see {@link InfModelParser}) to build a model. + */ +@FunctionalInterface +public interface AsyncInfModelBuilder { + /** + * Builds asynchronous inference model using specified model reader and model parser. + * + * @param reader Model reader. + * @param parser Model parser. + * @param Type of model input. + * @param Type of model output. + * @return Inference model. + */ + public InfModel> build(InfModelReader reader, + InfModelParser parser); +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilder.java new file mode 100644 index 0000000000000..7a176e045b980 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilder.java @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import java.io.Serializable; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteQueue; +import org.apache.ignite.Ignition; +import org.apache.ignite.configuration.CollectionConfiguration; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.reader.InfModelReader; +import org.apache.ignite.services.Service; +import org.apache.ignite.services.ServiceContext; + +/** + * Builder that allows to start Apache Ignite services for distributed inference and get a facade that allows to work + * with this distributed inference infrastructure as with a single inference model (see {@link InfModel}). + * + * The common workflow is based on a request/response queues and multiple workers represented by Apache Ignite services. + * When the {@link #build(InfModelReader, InfModelParser)} method is called Apache Ignite starts the specified number of + * service instances and request/response queues. 
Each service instance reads request queue, processes inbound requests + * and writes responses to response queue. The facade returned by the {@link #build(InfModelReader, InfModelParser)} + * method operates with request/response queues. When the {@link InfModel#predict(Object)} method is called the argument + * is sent as a request to the request queue. When the response is appeared in the response queue the {@link Future} + * correspondent to the previously sent request is completed and the processing finishes. + * + * Be aware that {@link InfModel#close()} method must be called to clear allocated resources, stop services and remove + * queues. + */ +public class IgniteDistributedInfModelBuilder implements AsyncInfModelBuilder { + /** Template of the inference service name. */ + private static final String INFERENCE_SERVICE_NAME_PATTERN = "inference_service_%s"; + + /** Template of the inference request queue name. */ + private static final String INFERENCE_REQUEST_QUEUE_NAME_PATTERN = "inference_queue_req_%s"; + + /** Template of the inference response queue name. */ + private static final String INFERENCE_RESPONSE_QUEUE_NAME_PATTERN = "inference_queue_res_%s"; + + /** Default capacity for all queues used in this class (request queue, response queue, received queue). */ + private static final int QUEUE_CAPACITY = 100; + + /** Default configuration for Apache Ignite queues used in this class (request queue, response queue). */ + private static final CollectionConfiguration queueCfg = new CollectionConfiguration(); + + /** Ignite instance. */ + private final Ignite ignite; + + /** Number of service instances maintaining to make distributed inference. */ + private final int instances; + + /** Max per node number of instances. */ + private final int maxPerNode; + + /** + * Constructs a new instance of Ignite distributed inference model builder. + * + * @param ignite Ignite instance. 
+ * @param instances Number of service instances maintaining to make distributed inference. + * @param maxPerNode Max per node number of instances. + */ + public IgniteDistributedInfModelBuilder(Ignite ignite, int instances, int maxPerNode) { + this.ignite = ignite; + this.instances = instances; + this.maxPerNode = maxPerNode; + } + + /** + * Starts the specified in constructor number of service instances and request/response queues. Each service + * instance reads request queue, processes inbound requests and writes responses to response queue. The returned + * facade is represented by the {@link InfModel} operates with request/response queues, but hides these details + * behind {@link InfModel#predict(Object)} method of {@link InfModel}. + * + * Be aware that {@link InfModel#close()} method must be called to clear allocated resources, stop services and + * remove queues. + * + * @param reader Inference model reader. + * @param parser Inference model parser. + * @param Type of model input. + * @param Type of model output. + * @return Facade represented by {@link InfModel}. + */ + @Override public InfModel> build( + InfModelReader reader, InfModelParser parser) { + return new DistributedInfModel<>(ignite, UUID.randomUUID().toString(), reader, parser, instances, maxPerNode); + } + + /** + * Facade that operates with request/response queues to make distributed inference, but hides these details + * behind {@link InfModel#predict(Object)} method of {@link InfModel}. + * + * Be aware that {@link InfModel#close()} method must be called to clear allocated resources, stop services and + * remove queues. + * + * @param Type of model input. + * @param Type of model output. + */ + private static class DistributedInfModel + implements InfModel> { + /** Ignite instance. */ + private final Ignite ignite; + + /** Suffix that with correspondent templates formats service and queue names. */ + private final String suffix; + + /** Request queue. 
*/ + private final IgniteQueue reqQueue; + + /** Response queue. */ + private final IgniteQueue resQueue; + + /** Futures that represents requests that have been sent, but haven't been responded yet. */ + private final BlockingQueue> futures = new ArrayBlockingQueue<>(QUEUE_CAPACITY); + + /** Thread pool for receiver to work in. */ + private final ExecutorService receiverThreadPool = Executors.newSingleThreadExecutor(); + + /** Flag identified that model is up and running. */ + private final AtomicBoolean running = new AtomicBoolean(false); + + /** Receiver future. */ + private volatile Future receiverFut; + + /** + * Constructs a new instance of distributed inference model. + * + * @param ignite Ignite instance. + * @param suffix Suffix that with correspondent templates formats service and queue names. + * @param reader Inference model reader. + * @param parser Inference model parser. + * @param instances Number of service instances maintaining to make distributed inference. + * @param maxPerNode Max per node number of instances. + */ + DistributedInfModel(Ignite ignite, String suffix, InfModelReader reader, InfModelParser parser, + int instances, int maxPerNode) { + this.ignite = ignite; + this.suffix = suffix; + + reqQueue = ignite.queue(String.format(INFERENCE_REQUEST_QUEUE_NAME_PATTERN, suffix), QUEUE_CAPACITY, + queueCfg); + resQueue = ignite.queue(String.format(INFERENCE_RESPONSE_QUEUE_NAME_PATTERN, suffix), QUEUE_CAPACITY, + queueCfg); + + startReceiver(); + startService(reader, parser, instances, maxPerNode); + + running.set(true); + } + + /** {@inheritDoc} */ + @Override public Future predict(I input) { + if (!running.get()) + throw new IllegalStateException("Inference model is not running"); + + reqQueue.put(input); + + try { + CompletableFuture fut = new CompletableFuture<>(); + futures.put(fut); + return fut; + } + catch (InterruptedException e) { + close(); // In case of exception in the above code the model state becomes invalid and model is closed. 
+ throw new RuntimeException(e); + } + } + + /** + * Starts Apache Ignite services that represent distributed inference infrastructure. + * + * @param reader Inference model reader. + * @param parser Inference model parser. + * @param instances Number of service instances maintaining to make distributed inference. + * @param maxPerNode Max per node number of instances. + */ + private void startService(InfModelReader reader, InfModelParser parser, int instances, int maxPerNode) { + ignite.services().deployMultiple( + String.format(INFERENCE_SERVICE_NAME_PATTERN, suffix), + new IgniteDistributedInfModelService<>(reader, parser, suffix), + instances, + maxPerNode + ); + } + + /** + * Stops Apache Ignite services that represent distributed inference infrastructure. + */ + private void stopService() { + ignite.services().cancel(String.format(INFERENCE_SERVICE_NAME_PATTERN, suffix)); + } + + /** + * Starts the thread that reads the response queue and completed correspondent futures from {@link #futures} + * queue. + */ + private void startReceiver() { + receiverFut = receiverThreadPool.submit(() -> { + try { + while (!Thread.currentThread().isInterrupted()) { + O res; + try { + res = resQueue.take(); + } + catch (IllegalStateException e) { + if (!resQueue.removed()) + throw e; + continue; + } + + CompletableFuture fut = futures.remove(); + fut.complete(res); + } + } + finally { + close(); // If the model is not stopped yet we need to stop it to protect queue from new writes. + while (!futures.isEmpty()) { + CompletableFuture fut = futures.remove(); + fut.cancel(true); + } + } + }); + } + + /** + * Stops receiver thread that reads the response queue and completed correspondent futures from + * {@link #futures} queue. + */ + private void stopReceiver() { + if (receiverFut != null && !receiverFut.isDone()) + receiverFut.cancel(true); + // The receiver thread pool is not reused, so it should be closed here. 
+ receiverThreadPool.shutdown(); + } + + /** + * Remove request/response Ignite queues. + */ + private void removeQueues() { + reqQueue.close(); + resQueue.close(); + } + + /** {@inheritDoc} */ + @Override public void close() { + boolean runningBefore = running.getAndSet(false); + + if (runningBefore) { + stopService(); + stopReceiver(); + removeQueues(); + } + } + } + + /** + * Apache Ignite service that makes inference reading requests from the request queue and writing responses to the + * response queue. This service is assumed to be deployed in {@link #build(InfModelReader, InfModelParser)} method + * and cancelled in {@link InfModel#close()} method of the inference model. + * + * @param Type of model input. + * @param Type of model output. + */ + private static class IgniteDistributedInfModelService + implements Service { + /** */ + private static final long serialVersionUID = -3596084917874395597L; + + /** Inference model reader. */ + private final InfModelReader reader; + + /** Inference model parser. */ + private final InfModelParser parser; + + /** Suffix that with correspondent templates formats service and queue names. */ + private final String suffix; + + /** Request queue, is created in {@link #init(ServiceContext)} method. */ + private transient IgniteQueue reqQueue; + + /** Response queue, is created in {@link #init(ServiceContext)} method. */ + private transient IgniteQueue resQueue; + + /** Inference model, is created in {@link #init(ServiceContext)} method. */ + private transient InfModel mdl; + + /** + * Constructs a new instance of Ignite distributed inference model service. + * + * @param reader Inference model reader. + * @param parser Inference model parser. + * @param suffix Suffix that with correspondent templates formats service and queue names. 
+ */ + IgniteDistributedInfModelService(InfModelReader reader, InfModelParser parser, String suffix) { + this.reader = reader; + this.parser = parser; + this.suffix = suffix; + } + + /** {@inheritDoc} */ + @Override public void init(ServiceContext ctx) { + Ignite ignite = Ignition.localIgnite(); + + reqQueue = ignite.queue(String.format(INFERENCE_REQUEST_QUEUE_NAME_PATTERN, suffix), QUEUE_CAPACITY, + queueCfg); + resQueue = ignite.queue(String.format(INFERENCE_RESPONSE_QUEUE_NAME_PATTERN, suffix), QUEUE_CAPACITY, + queueCfg); + + mdl = parser.parse(reader.read()); + } + + /** {@inheritDoc} */ + @Override public void execute(ServiceContext ctx) { + while (!ctx.isCancelled()) { + I req; + try { + req = reqQueue.take(); + } + catch (IllegalStateException e) { + // If the queue is removed during the take() operation exception should be ignored. + if (!reqQueue.removed()) + throw e; + continue; + } + + O res = mdl.predict(req); + + try { + resQueue.put(res); + } + catch (IllegalStateException e) { + // If the queue is removed during the put() operation exception should be ignored. + if (!resQueue.removed()) + throw e; + } + } + } + + /** {@inheritDoc} */ + @Override public void cancel(ServiceContext ctx) { + // Do nothing. Queues are assumed to be closed in model close() method. + } + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilder.java new file mode 100644 index 0000000000000..f756f4524d62a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilder.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import java.io.Serializable; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.reader.InfModelReader; + +/** + * Implementation of synchronous inference model builder that builds a model processed locally in a single thread. + */ +public class SingleInfModelBuilder implements SyncInfModelBuilder { + /** {@inheritDoc} */ + @Override public InfModel build(InfModelReader reader, + InfModelParser parser) { + return parser.parse(reader.read()); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/SyncInfModelBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/SyncInfModelBuilder.java new file mode 100644 index 0000000000000..7aed8b883dc5c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/SyncInfModelBuilder.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import java.io.Serializable; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.reader.InfModelReader; + +/** + * Builder of synchronous inference model. Uses specified model reader (see {@link InfModelReader}) and mode parser (see + * {@link InfModelParser}) to build a model. + */ +@FunctionalInterface +public interface SyncInfModelBuilder { + /** + * Builds synchronous inference model using specified model reader and model parser. + * + * @param reader Model reader. + * @param parser Model parser. + * @param Type of model input. + * @param Type of model output. + * @return Inference model. + */ + public InfModel build(InfModelReader reader, + InfModelParser parser); +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilder.java new file mode 100644 index 0000000000000..ff538de4abf47 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilder.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import java.io.Serializable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.reader.InfModelReader; + +/** + * Implementation of asynchronous inference model builder that builds model processed locally utilizing specified number + * of threads. + */ +public class ThreadedInfModelBuilder implements AsyncInfModelBuilder { + /** Number of threads to be utilized for model inference. */ + private final int threads; + + /** + * Constructs a new instance of threaded inference model builder. + * + * @param threads Number of threads to be utilized for model inference. + */ + public ThreadedInfModelBuilder(int threads) { + this.threads = threads; + } + + /** {@inheritDoc} */ + @Override public InfModel> build( + InfModelReader reader, InfModelParser parser) { + return new ThreadedInfModel<>(parser.parse(reader.read()), threads); + } + + /** + * Threaded inference model that performs inference in multiply threads. + * + * @param Type of model input. + * @param Type of model output. + */ + private static class ThreadedInfModel + implements InfModel> { + /** Inference model. 
*/ + private final InfModel mdl; + + /** Thread pool. */ + private final ExecutorService threadPool; + + /** + * Constructs a new instance of threaded inference model. + * + * @param mdl Inference model. + * @param threads Thread pool. + */ + ThreadedInfModel(InfModel mdl, int threads) { + this.mdl = mdl; + this.threadPool = Executors.newFixedThreadPool(threads); + } + + /** {@inheritDoc} */ + @Override public Future predict(I input) { + return threadPool.submit(() -> mdl.predict(input)); + } + + /** {@inheritDoc} */ + @Override public void close() { + threadPool.shutdown(); + } + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/package-info.java new file mode 100644 index 0000000000000..bed2e70ec304e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Root package for model inference builders. 
+ */ +package org.apache.ignite.ml.inference.builder; \ No newline at end of file diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/package-info.java new file mode 100644 index 0000000000000..f2ce68c8bb366 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Root package for model inference functionality. + */ +package org.apache.ignite.ml.inference; \ No newline at end of file diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/IgniteFunctionInfModelParser.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/IgniteFunctionInfModelParser.java new file mode 100644 index 0000000000000..a4f13772c87c9 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/IgniteFunctionInfModelParser.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.parser; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.math.functions.IgniteFunction; + +/** + * Implementation of model parser that accepts serialized {@link IgniteFunction}. + * + * @param Type of model input. + * @param Type of model output. + */ +public class IgniteFunctionInfModelParser implements InfModelParser { + /** */ + private static final long serialVersionUID = -4624683614990816434L; + + /** {@inheritDoc} */ + @Override public InfModel parse(byte[] mdl) { + try (ByteArrayInputStream bais = new ByteArrayInputStream(mdl); + ObjectInputStream ois = new ObjectInputStream(bais)) { + @SuppressWarnings("unchecked") + IgniteFunction function = (IgniteFunction)ois.readObject(); + + return new IgniteFunctionInfoModel(function); + } + catch (IOException | ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + /** + * Inference model that wraps {@link IgniteFunction}. + */ + private class IgniteFunctionInfoModel implements InfModel { + /** Ignite function. */ + private final IgniteFunction function; + + /** + * Constructs a new instance of Ignite function. + * + * @param function Ignite function. 
+ */ + IgniteFunctionInfoModel(IgniteFunction function) { + this.function = function; + } + + /** {@inheritDoc} */ + @Override public O predict(I input) { + return function.apply(input); + } + + /** {@inheritDoc} */ + @Override public void close() { + // Do nothing. + } + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/InfModelParser.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/InfModelParser.java new file mode 100644 index 0000000000000..fa62558dbb956 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/InfModelParser.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.parser; + +import java.io.Serializable; +import org.apache.ignite.ml.inference.InfModel; + +/** + * Model parser that accepts a serialized model represented by byte array, parses it and returns {@link InfModel}. + * + * @param Type of model input. + * @param Type of model output. + */ +@FunctionalInterface +public interface InfModelParser extends Serializable { + /** + * Accepts serialized model represented by byte array, parses it and returns {@link InfModel}. 
+ * + * @param mdl Serialized model represented by byte array. + * @return Inference model. + */ + public InfModel parse(byte[] mdl); +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/TensorFlowBaseInfModelParser.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/TensorFlowBaseInfModelParser.java new file mode 100644 index 0000000000000..acc521fb1926f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/TensorFlowBaseInfModelParser.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.parser; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.ignite.ml.inference.InfModel; +import org.tensorflow.Session; +import org.tensorflow.Tensor; + +/** + * Base class for TensorFlow model parsers. Contains the logic that is common for models saved as "SavedModel" and as a + * simple graph. + * + * @param Type of model input. + * @param Type of model output. 
+ */ +public abstract class TensorFlowBaseInfModelParser implements InfModelParser { + /** */ + private static final long serialVersionUID = 5574259553625871456L; + + /** Map of input graph nodes (placeholders) and transformers that allow to transform input into tensor. */ + private final Map> inputs = new HashMap<>(); + + /** List of output graph nodes. */ + private List outputNames; + + /** Transformer that allows to transform tensors into output. */ + private OutputTransformer outputTransformer; + + /** {@inheritDoc} */ + @Override public InfModel parse(byte[] mdl) { + return new TensorFlowInfModel(parseModel(mdl)); + } + + /** + * Parses model specified in serialized form as byte array. + * + * @param mdl Inference model in serialized form as byte array. + * @return TensorFlow session that encapsulates the TensorFlow graph parsed from serialized model. + */ + public abstract Session parseModel(byte[] mdl); + + /** + * Setter that allows to specify additional input graph node and correspondent transformer that allows to transform + * input into tensor. + * + * @param name Name of the input graph node. + * @param transformer Transformer that allows to transform input into tensor. + * @return This instance. + */ + public TensorFlowBaseInfModelParser withInput(String name, InputTransformer transformer) { + if (inputs.containsKey(name)) + throw new IllegalArgumentException("Inputs already contains specified name [name=" + name + "]"); + + inputs.put(name, transformer); + + return this; + } + + /** + * Setter that allows to specify output graph nodes and correspondent transformer that allow to transform tensors + * into output. + * + * @param names List of output graph node names. + * @param transformer Transformer that allow to transform tensors into output. + * @return This instance. 
+ */ + public TensorFlowBaseInfModelParser withOutput(List names, OutputTransformer transformer) { + if (outputNames != null || outputTransformer != null) + throw new IllegalArgumentException("Outputs already specified"); + + outputNames = names; + outputTransformer = transformer; + + return this; + } + + /** + * Input transformer that accepts input and transforms it into tensor. + * + * @param Type of model input. + */ + @FunctionalInterface + public interface InputTransformer extends Serializable { + /** + * Transforms input into tensor. + * + * @param input Input data. + * @return Tensor (transformed input data). + */ + public Tensor transform(I input); + } + + /** + * Output transformer that accepts tensors and transforms them into output. + * + * @param Type of model output. + */ + @FunctionalInterface + public interface OutputTransformer extends Serializable { + /** + * Transforms tensors into output. + * + * @param output Tensors. + * @return Output (transformed tensors). + */ + public O transform(Map> output); + } + + /** + * TensorFlow inference model based on pre-loaded graph and created session. + */ + private class TensorFlowInfModel implements InfModel { + /** TensorFlow session. */ + private final Session ses; + + /** + * Constructs a new instance of TensorFlow inference model. + * + * @param ses TensorFlow session. + */ + TensorFlowInfModel(Session ses) { + this.ses = ses; + } + + /** {@inheritDoc} */ + @Override public O predict(I input) { + Session.Runner runner = ses.runner(); + + runner = feedAll(runner, input); + runner = fetchAll(runner); + + List> prediction = runner.run(); + Map> collectedPredictionTensors = indexTensors(prediction); + + return outputTransformer.transform(collectedPredictionTensors); + } + + /** + * Feeds input into graphs input nodes using input transformers (see {@link #inputs}). + * + * @param runner TensorFlow session runner. + * @param input Input. + * @return TensorFlow session runner. 
+ */ + private Session.Runner feedAll(Session.Runner runner, I input) { + for (Map.Entry> e : inputs.entrySet()) { + String opName = e.getKey(); + InputTransformer transformer = e.getValue(); + + runner = runner.feed(opName, transformer.transform(input)); + } + + return runner; + } + + /** + * Specifies graph output nodes to be fetched using {@link #outputNames}. + * + * @param runner TensorFlow session runner. + * @return TensorFlow session runner. + */ + private Session.Runner fetchAll(Session.Runner runner) { + for (String e : outputNames) + runner.fetch(e); + + return runner; + } + + /** + * Indexes tensors fetched from graph using {@link #outputNames}. + * + * @param tensors List of fetched tensors. + * @return Map of tensor name as a key and tensor as a value. + */ + private Map> indexTensors(List> tensors) { + Map> collectedTensors = new HashMap<>(); + + Iterator outputNamesIter = outputNames.iterator(); + Iterator> tensorsIter = tensors.iterator(); + + while (outputNamesIter.hasNext() && tensorsIter.hasNext()) + collectedTensors.put(outputNamesIter.next(), tensorsIter.next()); + + // We expect that output names and output tensors have the same size. + if (outputNamesIter.hasNext() || tensorsIter.hasNext()) + throw new IllegalStateException("Outputs are incorrect"); + + return collectedTensors; + } + + /** {@inheritDoc} */ + @Override public void close() { + ses.close(); + } + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/TensorFlowGraphInfModelParser.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/TensorFlowGraphInfModelParser.java new file mode 100644 index 0000000000000..7c547aedc9e61 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/TensorFlowGraphInfModelParser.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.parser; + +import org.tensorflow.Graph; +import org.tensorflow.Session; + +/** + * Implementation of TensorFlow model parser that accepts serialized graph definition. + * + * @param Type of model input. + * @param Type of model output. + */ +public class TensorFlowGraphInfModelParser extends TensorFlowBaseInfModelParser { + /** */ + private static final long serialVersionUID = -1872566748640565856L; + + /** {@inheritDoc} */ + @Override public Session parseModel(byte[] mdl) { + Graph graph = new Graph(); + graph.importGraphDef(mdl); + + return new Session(graph); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/TensorFlowSavedModelInfModelParser.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/TensorFlowSavedModelInfModelParser.java new file mode 100644 index 0000000000000..2ee9f1168f220 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/TensorFlowSavedModelInfModelParser.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.parser; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import org.apache.ignite.ml.inference.util.DirectorySerializer; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; + +/** + * Implementation of TensorFlow model parser that accepts serialized directory with "SavedModel" as an input. The + * directory is assumed to be serialized by {@link DirectorySerializer}. + * + * @param Type of model input. + * @param Type of model output. + */ +public class TensorFlowSavedModelInfModelParser extends TensorFlowBaseInfModelParser { + /** */ + private static final long serialVersionUID = 5638083440240281879L; + + /** Prefix to be used to create temporary directory for TensorFlow model files. */ + private static final String TMP_DIR_PREFIX = "tensorflow_saved_model_"; + + /** Model tags. */ + private final String[] tags; + + /** + * Constructs a new instance of TensorFlow model parser. + * + * @param tags Model tags. + */ + public TensorFlowSavedModelInfModelParser(String... 
tags) { + this.tags = tags; + } + + /** {@inheritDoc} */ + @Override public Session parseModel(byte[] mdl) { + Path dir = null; + try { + dir = Files.createTempDirectory(TMP_DIR_PREFIX); + DirectorySerializer.deserialize(dir.toAbsolutePath(), mdl); + SavedModelBundle bundle = SavedModelBundle.load(dir.toString(), tags); + return bundle.session(); + } + catch (IOException | ClassNotFoundException e) { + throw new RuntimeException(e); + } + finally { + if (dir != null) + DirectorySerializer.deleteDirectory(dir); + } + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/package-info.java new file mode 100644 index 0000000000000..ce8c27be7d035 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Root package for model inference parsers. 
+ */ +package org.apache.ignite.ml.inference.parser; \ No newline at end of file diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/FileSystemInfModelReader.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/FileSystemInfModelReader.java new file mode 100644 index 0000000000000..1ad2161db9d00 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/FileSystemInfModelReader.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.reader; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import org.apache.ignite.ml.inference.util.DirectorySerializer; + +/** + * Model reader that reads directory and serializes it using {@link DirectorySerializer}. + */ +public class FileSystemInfModelReader implements InfModelReader { + /** */ + private static final long serialVersionUID = 7370932792669930039L; + + /** Path to the directory. */ + private final String path; + + /** + * Constructs a new instance of directory model reader. + * + * @param path Path to the directory. 
+ */ + public FileSystemInfModelReader(String path) { + this.path = path; + } + + /** {@inheritDoc} */ + @Override public byte[] read() { + try { + File file = Paths.get(path).toFile(); + if (!file.exists()) + throw new IllegalArgumentException("File or directory does not exist [path=" + path + "]"); + + if (file.isDirectory()) + return DirectorySerializer.serialize(Paths.get(path)); + else + return Files.readAllBytes(Paths.get(path)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/InMemoryInfModelReader.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/InMemoryInfModelReader.java new file mode 100644 index 0000000000000..6da31aae43b38 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/InMemoryInfModelReader.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.reader; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.io.Serializable; + +/** + * Model reader that reads predefined array of bytes. 
+ */ +public class InMemoryInfModelReader implements InfModelReader { + /** */ + private static final long serialVersionUID = -5518861989758691500L; + + /** Data. */ + private final byte[] data; + + /** + * Constructs a new instance of in-memory model reader that returns specified byte array. + * + * @param data Data. + */ + public InMemoryInfModelReader(byte[] data) { + this.data = data; + } + + /** + * Constructs a new instance of in-memory model reader that returns serialized specified object. + * + * @param obj Data object. + * @param Type of data object. + */ + public InMemoryInfModelReader(T obj) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(baos)) { + oos.writeObject(obj); + oos.flush(); + + this.data = baos.toByteArray(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** {@inheritDoc} */ + @Override public byte[] read() { + return data; + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/InfModelReader.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/InfModelReader.java new file mode 100644 index 0000000000000..779a1ee8607ba --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/InfModelReader.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.reader; + +import java.io.Serializable; + +/** + * Model reader that reads model from external or internal storage and returns it in serialized form as byte array. + */ +@FunctionalInterface +public interface InfModelReader extends Serializable { + /** + * Reads model and returns it in serialized form as byte array. + * + * @return Inference model in serialized form as byte array. + */ + public byte[] read(); +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/package-info.java new file mode 100644 index 0000000000000..fba25881e78e8 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/reader/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Root package for model inference readers. + */ +package org.apache.ignite.ml.inference.reader; \ No newline at end of file diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/IgniteModelDescriptorStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/IgniteModelDescriptorStorage.java new file mode 100644 index 0000000000000..a198190027cd9 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/IgniteModelDescriptorStorage.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.storage; + +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.ml.inference.ModelDescriptor; + +/** + * Model descriptor storage based on Apache Ignite cache. + */ +public class IgniteModelDescriptorStorage implements ModelDescriptorStorage { + /** Apache Ignite cache name to keep model descriptors. */ + private static final String MODEL_DESCRIPTOR_CACHE_NAME = "MODEL_DESCRIPTOR_CACHE"; + + /** Apache Ignite cache to keep model descriptors. 
*/ + private final IgniteCache models; + + /** + * Constructs a new instance of Ignite model descriptor storage. + * + * @param ignite Ignite instance. + */ + public IgniteModelDescriptorStorage(Ignite ignite) { + models = ignite.getOrCreateCache(MODEL_DESCRIPTOR_CACHE_NAME); + } + + /** {@inheritDoc} */ + @Override public void put(String mdlId, ModelDescriptor mdl) { + models.put(mdlId, mdl); + } + + /** {@inheritDoc} */ + @Override public ModelDescriptor get(String mdlId) { + return models.get(mdlId); + } + + /** {@inheritDoc} */ + @Override public void remove(String mdlId) { + models.remove(mdlId); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/LocalModelDescriptorStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/LocalModelDescriptorStorage.java new file mode 100644 index 0000000000000..99e3dacbe321a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/LocalModelDescriptorStorage.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.inference.storage; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.ignite.ml.inference.ModelDescriptor; + +/** + * Model descriptor storage based on local hash map. + */ +public class LocalModelDescriptorStorage implements ModelDescriptorStorage { + /** Hash map model storage. */ + private final Map models = new ConcurrentHashMap<>(); + + /** {@inheritDoc} */ + @Override public void put(String name, ModelDescriptor mdl) { + models.put(name, mdl); + } + + /** {@inheritDoc} */ + @Override public ModelDescriptor get(String name) { + return models.get(name); + } + + /** {@inheritDoc} */ + @Override public void remove(String name) { + models.remove(name); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/ModelDescriptorStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/ModelDescriptorStorage.java new file mode 100644 index 0000000000000..c124efbfccc2a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/ModelDescriptorStorage.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.inference.storage; + +import org.apache.ignite.ml.inference.ModelDescriptor; + +/** + * Storage that allows to load, keep and get access to model descriptors (see {@link ModelDescriptor}). + */ +public interface ModelDescriptorStorage { + /** + * Saves the specified model descriptor with the specified model identifier. + * + * @param mdlId Model identifier. + * @param mdl Model descriptor. + */ + public void put(String mdlId, ModelDescriptor mdl); + + /** + * Returns model descriptor saved for the specified model identifier. + * + * @param mdlId Model identifier. + * @return Model descriptor. + */ + public ModelDescriptor get(String mdlId); + + /** + * Removes model descriptor for the specified model identifier. + * + * @param mdlId Model identifier. + */ + public void remove(String mdlId); +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/package-info.java new file mode 100644 index 0000000000000..168f4e4f24ce6 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Root package for model inference descriptor storages. + */ +package org.apache.ignite.ml.inference.storage; \ No newline at end of file diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/util/DirectorySerializer.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/util/DirectorySerializer.java new file mode 100644 index 0000000000000..669571895e189 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/util/DirectorySerializer.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.util; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +/** + * Utils class that helps to serialize directory content as a hash map and then deserialize it. 
+ */ +public class DirectorySerializer { + /** + * Serializes directory content. + * + * @param path Path to the directory. + * @return Serialized directory content. + * @throws IOException If directory cannot be serialized. + */ + public static byte[] serialize(Path path) throws IOException { + File file = path.toFile(); + + if (!file.isDirectory()) + throw new IllegalStateException("Path is not directory [path=\"" + path + "\"]"); + + Map data = new HashMap<>(); + serialize(data, path, file); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(baos); + oos.writeObject(data); + oos.flush(); + + return baos.toByteArray(); + } + + /** + * Deserializes directory content. + * + * @param path Path to the directory. + * @param data Serialized content. + * @throws IOException If the directory cannot be deserialized. + * @throws ClassNotFoundException If the directory cannot be deserialized. + */ + @SuppressWarnings("unchecked") + public static void deserialize(Path path, byte[] data) throws IOException, ClassNotFoundException { + ByteArrayInputStream bais = new ByteArrayInputStream(data); + ObjectInputStream ois = new ObjectInputStream(bais); + Map files = (Map)ois.readObject(); + + for (Map.Entry file : files.entrySet()) { + Path dst = path.resolve(file.getKey()); + File dstFile = dst.toFile(); + Files.createDirectories(dstFile.getParentFile().toPath()); + Files.createFile(dstFile.toPath()); + FileOutputStream fos = new FileOutputStream(dstFile); + fos.write(file.getValue()); + fos.flush(); + } + } + + /** + * Removes the specified directory. + * + * @param path Path to the directory. 
+ */ + public static void deleteDirectory(Path path) { + File file = path.toFile(); + if (file.isDirectory()) { + File[] children = file.listFiles(); + if (children != null) { + for (File child : children) + deleteDirectory(child.toPath()); + } + } + + try { + Files.delete(path); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * Serializes directory content or file. + * + * @param data Storage to keep pairs of file name and file content. + * @param basePath Base path to the serialized directory. + * @param file File to be serialized. + * @throws IOException If the file cannot be serialized. + */ + private static void serialize(Map data, Path basePath, File file) throws IOException { + if (file.isFile()) { + String relative = basePath.relativize(file.toPath()).toString(); + byte[] bytes = Files.readAllBytes(file.toPath()); + data.put(relative, bytes); + } + else { + File[] children = file.listFiles(); + if (children != null) { + for (File child : children) + serialize(data, basePath, child); + } + } + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/util/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/util/package-info.java new file mode 100644 index 0000000000000..f0b9d8790b184 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/util/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Root package for util classes used in {@link org.apache.ignite.ml.inference} package. + */ +package org.apache.ignite.ml.inference.util; \ No newline at end of file diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java index 503572d8d3b22..42cf1f0168751 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java @@ -20,6 +20,7 @@ import java.io.FileInputStream; import java.io.FileWriter; import java.io.IOException; +import java.io.InputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -89,32 +90,62 @@ public static Stream mnistAsStream(String imagesPath, String labels * @return List of MNIST samples. * @throws IOException In case of exception. */ - public static List mnistAsList(String imagesPath, String labelsPath, Random rnd, int cnt) throws IOException { + public static List mnistAsList(String imagesPath, String labelsPath, Random rnd, + int cnt) throws IOException { + return mnistAsList(new FileInputStream(imagesPath), new FileInputStream(labelsPath), rnd, cnt); + } + + /** + * Read random {@code count} samples from MNIST dataset from two resources (images and labels) into a stream of + * labeled vectors. + * + * @param imagesPath Path to the resource with images. + * @param labelsPath Path to the resource with labels. + * @param rnd Random numbers generator. + * @param cnt Count of samples to read. 
+ * @return List of MNIST samples. + * @throws IOException In case of exception. + */ + public static List mnistAsListFromResource(String imagesPath, String labelsPath, Random rnd, + int cnt) throws IOException { + return mnistAsList( + MnistUtils.class.getClassLoader().getResourceAsStream(imagesPath), + MnistUtils.class.getClassLoader().getResourceAsStream(labelsPath), + rnd, + cnt + ); + } + /** + * Read random {@code count} samples from MNIST dataset from two resources (images and labels) into a stream of + * labeled vectors. + * + * @param imageStream Stream with image data. + * @param lbStream Stream with label data. + * @param rnd Random numbers generator. + * @param cnt Count of samples to read. + * @return List of MNIST samples. + * @throws IOException In case of exception. + */ + private static List mnistAsList(InputStream imageStream, InputStream lbStream, Random rnd, + int cnt) throws IOException { List res = new ArrayList<>(); - try ( - FileInputStream isImages = new FileInputStream(imagesPath); - FileInputStream isLabels = new FileInputStream(labelsPath) - ) { - read4Bytes(isImages); // Skip magic number. - int numOfImages = read4Bytes(isImages); - int imgHeight = read4Bytes(isImages); - int imgWidth = read4Bytes(isImages); - - read4Bytes(isLabels); // Skip magic number. - read4Bytes(isLabels); // Skip number of labels. - - int numOfPixels = imgHeight * imgWidth; - - for (int imgNum = 0; imgNum < numOfImages; imgNum++) { - double[] pixels = new double[numOfPixels]; - for (int p = 0; p < numOfPixels; p++) { - int c = 128 - isImages.read(); - pixels[p] = ((double)c) / 128; - } - res.add(new MnistLabeledImage(pixels, isLabels.read())); - } + read4Bytes(imageStream); // Skip magic number. + int numOfImages = read4Bytes(imageStream); + int imgHeight = read4Bytes(imageStream); + int imgWidth = read4Bytes(imageStream); + + read4Bytes(lbStream); // Skip magic number. + read4Bytes(lbStream); // Skip number of labels. 
+ + int numOfPixels = imgHeight * imgWidth; + + for (int imgNum = 0; imgNum < numOfImages; imgNum++) { + double[] pixels = new double[numOfPixels]; + for (int p = 0; p < numOfPixels; p++) + pixels[p] = (float)(1.0 * (imageStream.read() & 0xFF) / 255); + res.add(new MnistLabeledImage(pixels, lbStream.read())); } Collections.shuffle(res, rnd); @@ -163,7 +194,7 @@ public static void asLIBSVM(String imagesPath, String labelsPath, String outPath * @param is Input stream. * @throws IOException In case of exception. */ - private static int read4Bytes(FileInputStream is) throws IOException { + private static int read4Bytes(InputStream is) throws IOException { return (is.read() << 24) | (is.read() << 16) | (is.read() << 8) | (is.read()); } diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java index e26b5b87cfbe6..f9645d80f58d4 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java @@ -23,6 +23,7 @@ import org.apache.ignite.ml.dataset.DatasetTestSuite; import org.apache.ignite.ml.environment.EnvironmentTestSuite; import org.apache.ignite.ml.genetic.GAGridTestSuite; +import org.apache.ignite.ml.inference.InferenceTestSuite; import org.apache.ignite.ml.knn.KNNTestSuite; import org.apache.ignite.ml.math.MathImplMainTestSuite; import org.apache.ignite.ml.nn.MLPTestSuite; @@ -59,6 +60,7 @@ EnvironmentTestSuite.class, StructuresTestSuite.class, CommonTestSuite.class, + InferenceTestSuite.class, BaggingTest.class }) public class IgniteMLTestSuite { diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/InferenceTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/InferenceTestSuite.java new file mode 100644 index 0000000000000..d8fc55b8d9c2e --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/InferenceTestSuite.java @@ -0,0 
+1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference; + +import org.apache.ignite.ml.inference.builder.IgniteDistributedInfModelBuilderTest; +import org.apache.ignite.ml.inference.builder.SingleInfModelBuilder; +import org.apache.ignite.ml.inference.builder.SingleInfModelBuilderTest; +import org.apache.ignite.ml.inference.builder.ThreadedInfModelBuilderTest; +import org.apache.ignite.ml.inference.util.DirectorySerializerTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite for all tests located in {@link org.apache.ignite.ml.inference} package. 
+ */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + IgniteDistributedInfModelBuilderTest.class, + SingleInfModelBuilderTest.class, + ThreadedInfModelBuilderTest.class, + DirectorySerializerTest.class +}) +public class InferenceTestSuite { +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilderTest.java new file mode 100644 index 0000000000000..292319e5b0bc6 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilderTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import org.apache.ignite.Ignite; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + +/** + * Tests for {@link IgniteDistributedInfModelBuilder} class. 
+ */ +public class IgniteDistributedInfModelBuilderTest extends GridCommonAbstractTest { + /** Number of nodes in grid */ + private static final int NODE_COUNT = 3; + + /** Ignite instance. */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** + * {@inheritDoc} + */ + @Override protected void beforeTest() { + /* Grid instance. */ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** */ + public void testBuild() throws ExecutionException, InterruptedException { + AsyncInfModelBuilder mdlBuilder = new IgniteDistributedInfModelBuilder(ignite, 1, 1); + + InfModel> infMdl = mdlBuilder.build( + InfModelBuilderTestUtil.getReader(), + InfModelBuilderTestUtil.getParser() + ); + + // TODO: IGNITE-10250: Test hangs sometimes because of Ignite queue issue. + // for (int i = 0; i < 100; i++) + // assertEquals(Integer.valueOf(i), infMdl.predict(i).get()); + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/InfModelBuilderTestUtil.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/InfModelBuilderTestUtil.java new file mode 100644 index 0000000000000..b95e7598905c1 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/InfModelBuilderTestUtil.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.parser.InfModelParser; +import org.apache.ignite.ml.inference.reader.InfModelReader; + +/** + * Util class for model builder tests. + */ +class InfModelBuilderTestUtil { + /** + * Creates dummy model reader used in tests. + * + * @return Dummy model reader used in tests. + */ + static InfModelReader getReader() { + return () -> new byte[0]; + } + + /** + * Creates dummy model parser used in tests. + * + * @return Dummy model parser used in tests. + */ + static InfModelParser getParser() { + return m -> new InfModel() { + @Override public Integer predict(Integer input) { + return input; + } + + @Override public void close() { + // Do nothing. + } + }; + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilderTest.java new file mode 100644 index 0000000000000..22596f2d4cbc5 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilderTest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import org.apache.ignite.ml.inference.InfModel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link SingleInfModelBuilder}. + */ +public class SingleInfModelBuilderTest { + /** */ + @Test + public void testBuild() { + SyncInfModelBuilder mdlBuilder = new SingleInfModelBuilder(); + + InfModel infMdl = mdlBuilder.build( + InfModelBuilderTestUtil.getReader(), + InfModelBuilderTestUtil.getParser() + ); + + for (int i = 0; i < 100; i++) + assertEquals(Integer.valueOf(i), infMdl.predict(i)); + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilderTest.java new file mode 100644 index 0000000000000..6d2f344feb311 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilderTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import org.apache.ignite.ml.inference.InfModel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link ThreadedInfModelBuilder} class. + */ +public class ThreadedInfModelBuilderTest { + /** */ + @Test + public void testBuild() throws ExecutionException, InterruptedException { + AsyncInfModelBuilder mdlBuilder = new ThreadedInfModelBuilder(10); + + InfModel> infMdl = mdlBuilder.build( + InfModelBuilderTestUtil.getReader(), + InfModelBuilderTestUtil.getParser() + ); + + for (int i = 0; i < 100; i++) + assertEquals(Integer.valueOf(i), infMdl.predict(i).get()); + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/util/DirectorySerializerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/util/DirectorySerializerTest.java new file mode 100644 index 0000000000000..d2d6b169720c8 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/util/DirectorySerializerTest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.util; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Scanner; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * Tests for {@link DirectorySerializer} class. + */ +public class DirectorySerializerTest { + /** Source directory prefix. */ + private static final String SRC_DIRECTORY_PREFIX = "directory_serializer_test_src"; + + /** Destination directory prefix. 
*/ + private static final String DST_DIRECTORY_PREFIX = "directory_serializer_test_dst"; + + /** */ + @Test + public void testSerializeDeserializeWithFile() throws IOException, ClassNotFoundException { + Path src = Files.createTempDirectory(SRC_DIRECTORY_PREFIX); + Path dst = Files.createTempDirectory(DST_DIRECTORY_PREFIX); + try { + File file = new File(src.toString() + "/test.txt"); + Files.createFile(file.toPath()); + try (FileWriter fw = new FileWriter(file)) { + fw.write("Hello, world!"); + fw.flush(); + } + + byte[] serialized = DirectorySerializer.serialize(src); + DirectorySerializer.deserialize(dst, serialized); + + File[] files = dst.toFile().listFiles(); + + assertNotNull(files); + assertEquals(1, files.length); + assertEquals("test.txt", files[0].getName()); + + Scanner scanner = new Scanner(files[0]); + assertTrue(scanner.hasNextLine()); + assertEquals("Hello, world!", scanner.nextLine()); + assertFalse(scanner.hasNextLine()); + } + finally { + DirectorySerializer.deleteDirectory(src); + DirectorySerializer.deleteDirectory(dst); + } + } + + /** */ + @Test + public void testSerializeDeserializeWithDirectory() throws IOException, ClassNotFoundException { + Path src = Files.createTempDirectory(SRC_DIRECTORY_PREFIX); + Path dst = Files.createTempDirectory(DST_DIRECTORY_PREFIX); + try { + Files.createDirectories(Paths.get(src.toString() + "/a/b/")); + File file = new File(src.toString() + "/a/b/test.txt"); + Files.createFile(file.toPath()); + try (FileWriter fw = new FileWriter(file)) { + fw.write("Hello, world!"); + fw.flush(); + } + + byte[] serialized = DirectorySerializer.serialize(src); + DirectorySerializer.deserialize(dst, serialized); + + File[] files = dst.toFile().listFiles(); + + assertNotNull(files); + assertEquals(1, files.length); + assertEquals("a", files[0].getName()); + assertTrue(files[0].isDirectory()); + + files = files[0].listFiles(); + + assertNotNull(files); + assertEquals(1, files.length); + assertEquals("b", files[0].getName()); + 
assertTrue(files[0].isDirectory()); + + files = files[0].listFiles(); + + assertNotNull(files); + assertEquals(1, files.length); + assertEquals("test.txt", files[0].getName()); + + Scanner scanner = new Scanner(files[0]); + assertTrue(scanner.hasNextLine()); + assertEquals("Hello, world!", scanner.nextLine()); + assertFalse(scanner.hasNextLine()); + } + finally { + DirectorySerializer.deleteDirectory(src); + DirectorySerializer.deleteDirectory(dst); + } + } +}