Skip to content

Commit

Permalink
Session-based saving of Tensorflow models, based on Karl Lessard's code
Browse files Browse the repository at this point in the history
karllessard@bdb0420

Further code discussion is at https://groups.google.com/a/tensorflow.org/forum/#!msg/jvm/gGKO-hVS4Pc/LF4rLJOdAQAJ

This commit manually merges Karl's code to mainline tensorflow-java.
However, Ops.java is unchanged here because it required no functional changes.
This commit fixes tensorflow#100

Later, model-saving may be expected to be function-based, not session-based.
Function-based model saving is discussed in tensorflow#101
  • Loading branch information
schaumba committed Aug 20, 2020
1 parent d9a4105 commit 9a1f535
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,16 @@
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_WhileParams;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.NoOp;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.train.Restore;
import org.tensorflow.op.train.Save;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.proto.util.SaverDef;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TString;


/**
Expand All @@ -67,6 +76,11 @@ public Graph() {
this.nativeHandle = nativeHandle;
}

/**
 * Creates a Graph from an existing native handle together with an already-built saver
 * definition (e.g. from a loaded SavedModel), so that {@link #saverDef()} reuses it
 * instead of building new save/restore ops.
 *
 * @param nativeHandle handle to the native graph
 * @param saverDef saver definition to reuse for this graph
 */
Graph(TF_Graph nativeHandle, SaverDef saverDef) {
  this(nativeHandle);
  this.saverDef = saverDef;
}

/**
* Release resources associated with the Graph.
*
Expand Down Expand Up @@ -287,6 +301,17 @@ public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
return addGradients(null, new Output<?>[] {y}, x, null);
}

/**
 * Returns the {@link SaverDef} describing how to save and restore the state of all
 * variables in this graph, building and caching it on the first call.
 *
 * <p>Lazily built by {@code addVariableSaver(Graph)}, which adds save/restore ops to this
 * graph under the "save" sub-scope.
 *
 * @return the saver definition for this graph
 */
public synchronized SaverDef saverDef() {
  // The original used double-checked locking, which is unsafe here because the saverDef
  // field is not volatile (unsafe publication under the Java Memory Model). A fully
  // synchronized method is correct, and this accessor is not performance-critical.
  if (saverDef == null) {
    saverDef = addVariableSaver(this);
  }
  return saverDef;
}

/**
* Used to instantiate an abstract class which overrides the buildSubgraph method to build a
* conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to
Expand Down Expand Up @@ -405,6 +430,7 @@ public Output<?>[] whileLoop(
private final Object nativeHandleLock = new Object();
private TF_Graph nativeHandle;
private int refcount = 0;
private SaverDef saverDef;

private final List<Op> initializers = new ArrayList<>();

Expand Down Expand Up @@ -726,6 +752,53 @@ private static Object[] whileLoop(
}
}

/**
 * Builds save/restore ops covering every {@code VariableV2} op in {@code graph} and wraps
 * their names in a {@link SaverDef}. All new ops are added under the "save" sub-scope of
 * the graph.
 *
 * <p>NOTE(review): only the classic "VariableV2" op type is matched here — resource
 * variables are presumably not covered by this saver; confirm against intended usage.
 *
 * @param graph the graph whose variables should be saved and restored
 * @return a saver definition whose tensor/op names refer to ops added to {@code graph}
 */
private static SaverDef addVariableSaver(Graph graph) {
  Ops tf = Ops.create(graph).withSubScope("save");

  List<String> varNames = new ArrayList<>();
  List<Operand<?>> varOutputs = new ArrayList<>();
  List<DataType<?>> varTypes = new ArrayList<>();

  // Collect the name, output tensor and dtype of every variable in the graph.
  for (Iterator<Operation> iter = graph.operations(); iter.hasNext();) {
    Operation op = iter.next();
    if (op.type().equals("VariableV2")) {
      varNames.add(op.name());
      varOutputs.add(op.output(0));
      varTypes.add(op.output(0).dataType());
    }
  }

  // FIXME Need an easier way to initialize an NdArray from a list
  String[] tmp = new String[varNames.size()];
  Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
  // Empty slice specs (same shape as the names tensor): each variable is saved whole,
  // not as partitioned slices.
  Operand<TString> varSlices = tf.zerosLike(varNamesTensor);

  // The checkpoint file prefix is fed at run time through this placeholder; its op name
  // is recorded below as the SaverDef's filename tensor.
  Placeholder<TString> saveFilename = tf.placeholder(TString.DTYPE);
  Save saveVariables = tf.train.save(
      saveFilename,
      varNamesTensor,
      varSlices,
      varOutputs
  );
  Restore restoreVariables = tf.train.restore(
      saveFilename,
      varNamesTensor,
      varSlices,
      varTypes
  );
  // Assign each restored tensor back to its variable. The raw Operand cast works around
  // the loss of each variable's static element type in List<Operand<?>>.
  List<Op> restoreOps = new ArrayList<>(varOutputs.size());
  for (int i = 0; i < varOutputs.size(); ++i) {
    restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
  }
  // Single op that, via control dependencies, restores all variables at once.
  NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp();

  return SaverDef.newBuilder()
      .setFilenameTensorName(saveFilename.op().name())
      .setSaveTensorName(saveVariables.op().name())
      .setRestoreOpName(restoreAll.op().name())
      .build();
}

static {
  // Ensure the native TensorFlow runtime is loaded before any Graph is instantiated.
  TensorFlow.init();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;

import com.google.protobuf.InvalidProtocolBufferException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.PointerScope;
Expand All @@ -30,8 +39,16 @@
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.proto.framework.SavedModel;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.proto.framework.TensorShapeProto;
import org.tensorflow.proto.framework.TensorShapeProto.Dim;
import org.tensorflow.ndarray.Shape;

/**
* SavedModelBundle represents a model loaded from storage.
Expand Down Expand Up @@ -94,6 +111,78 @@ private Loader(String exportDir) {
private RunOptions runOptions = null;
}

/**
 * Exports a graph and the current state of its variables in the SavedModel format, so that
 * the resulting directory can be reloaded with {@code SavedModelBundle.load}.
 */
public static final class Exporter {

  /**
   * Sets the tags under which the meta graph is saved. If no tag is set before
   * {@link #export(Session)} is called, "serve" is used by default.
   *
   * @param tags one or more tags identifying the exported meta graph
   * @return this object
   */
  public Exporter withTags(String... tags) {
    this.tags.addAll(Arrays.asList(tags));
    return this;
  }

  /**
   * Declares a signature named "serving_default" with method name
   * "tensorflow/serving/predict" mapping the given inputs and outputs.
   *
   * @param inputs named input operands of the signature
   * @param outputs named output operands of the signature
   * @return this object
   */
  public Exporter withSignature(Map<String, Operand<?>> inputs, Map<String, Operand<?>> outputs) {
    return withSignature("serving_default", "tensorflow/serving/predict", inputs, outputs);
  }

  /**
   * Declares a signature in the exported meta graph.
   *
   * @param signatureName key under which the signature is stored in the meta graph
   * @param methodName method name of the signature (e.g. "tensorflow/serving/predict")
   * @param inputs named input operands of the signature
   * @param outputs named output operands of the signature
   * @return this object
   */
  public Exporter withSignature(String signatureName, String methodName, Map<String, Operand<?>> inputs, Map<String, Operand<?>> outputs) {
    SignatureDef.Builder signatureDefBuilder = SignatureDef.newBuilder();
    for (Map.Entry<String, Operand<?>> inputEntry : inputs.entrySet()) {
      signatureDefBuilder.putInputs(inputEntry.getKey(), toTensorInfo(inputEntry.getValue().asOutput()));
    }
    for (Map.Entry<String, Operand<?>> outputEntry : outputs.entrySet()) {
      signatureDefBuilder.putOutputs(outputEntry.getKey(), toTensorInfo(outputEntry.getValue().asOutput()));
    }
    signatureDefBuilder.setMethodName(methodName);
    metaGraphDefBuilder.putSignatureDef(signatureName, signatureDefBuilder.build());
    return this;
  }

  /**
   * Writes the graph and the variable state of {@code session} to the export directory,
   * producing "saved_model.pb" and a "variables" subdirectory with the checkpoint files.
   *
   * @param session the session whose graph and variable values are exported
   * @throws IOException if the export directory cannot be created or written to
   */
  public void export(Session session) throws IOException {
    Graph graph = session.graph();
    if (tags.isEmpty()) {
      tags.add("serve");
    }
    // Important: the saver must be retrieved *before* the graph def, as building the saver
    // may add new save/restore ops to the graph. FIXME Better way for handling this?
    MetaGraphDef metaGraphDef = metaGraphDefBuilder
        .setSaverDef(graph.saverDef())
        .setGraphDef(graph.toGraphDef())
        .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(tags))
        .build();

    // Make sure saved model directories exist; fail loudly instead of silently writing
    // nowhere (the original ignored the mkdirs() result).
    Path variableDir = Paths.get(exportDir, "variables");
    if (!variableDir.toFile().mkdirs() && !variableDir.toFile().isDirectory()) {
      throw new IOException("Cannot create export directory " + variableDir);
    }

    // Save variable state under <exportDir>/variables/variables.*
    session.save(variableDir.resolve("variables").toString());

    // Save graph
    SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build();
    try (OutputStream file = new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) {
      savedModelDef.writeTo(file);
    }
  }

  Exporter(String exportDir) {
    this.exportDir = exportDir;
  }

  private final String exportDir;
  private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder();
  private final List<String> tags = new ArrayList<>();

  // Converts an output into the TensorInfo proto stored in a signature: dtype, shape and
  // the "opName:outputIndex" tensor name.
  private static TensorInfo toTensorInfo(Output<?> operand) {
    Shape shape = operand.shape();
    TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
    for (int i = 0; i < shape.numDimensions(); ++i) {
      tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i)));
    }
    return TensorInfo.newBuilder()
        .setDtype(DataType.forNumber(operand.dataType().nativeCode()))
        .setTensorShape(tensorShapeBuilder)
        .setName(operand.op().name() + ":" + operand.index())
        .build();
  }
}

/**
* Load a saved model from an export directory. The model that is being loaded should be created
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
Expand Down Expand Up @@ -125,6 +214,10 @@ public static Loader loader(String exportDir) {
return new Loader(exportDir);
}

/**
 * Creates an {@link Exporter} that will write a saved model into {@code exportDir}.
 *
 * @param exportDir directory where the saved model will be written
 * @return a new exporter
 */
public static Exporter exporter(String exportDir) {
  return new Exporter(exportDir);
}

/**
* Returns the <a
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto">MetaGraphDef
Expand Down Expand Up @@ -176,7 +269,7 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef
*/
private static SavedModelBundle fromHandle(
TF_Graph graphHandle, TF_Session sessionHandle, MetaGraphDef metaGraphDef) {
Graph graph = new Graph(graphHandle);
Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef());
Session session = new Session(graph, sessionHandle);
return new SavedModelBundle(graph, session, metaGraphDef);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

import java.util.ArrayList;
import java.util.List;
import org.tensorflow.proto.util.SaverDef;
import org.tensorflow.types.TString;

import static org.tensorflow.Graph.resolveOutputs;
import static org.tensorflow.internal.c_api.global.tensorflow.*;
Expand Down Expand Up @@ -444,6 +446,14 @@ public void run(Op op) {
runner().addTarget(op.op()).run();
}

/**
 * Saves the state of all variables of this session's graph into checkpoint files whose
 * names start with {@code prefix}, using the graph's {@link SaverDef}.
 *
 * @param prefix path prefix of the checkpoint files to write
 */
public void save(String prefix) {
  final SaverDef saver = graph.saverDef();
  runner()
      .feed(saver.getFilenameTensorName(), TString.scalarOf(prefix))
      .addTarget(saver.getSaveTensorName())
      .run();
}

/**
* Output tensors and metadata obtained when executing a session.
*
Expand All @@ -463,6 +473,10 @@ public static final class Run {
public RunMetadata metadata;
}

/** Returns the graph this session is executing. */
Graph graph() {
  return graph;
}

private final Graph graph;
private final Graph.Reference graphRef;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,39 @@

package org.tensorflow;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import org.junit.jupiter.api.Test;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Init;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.core.Variable;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;

/** Unit tests for {@link org.tensorflow.SavedModelBundle}. */
public class SavedModelBundleTest {

private static final float EPSILON = 1e-7f;
private static final String SAVED_MODEL_PATH;

static {
try {
SAVED_MODEL_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString();
Expand Down Expand Up @@ -72,6 +90,73 @@ public void loader() {
}
}

@Test
public void save() throws IOException {
  // End-to-end check: build a graph, run it, export it as a SavedModel, reload it, and
  // verify both the exported metadata and that the reloaded model computes the same value.
  Path testFolder = Files.createTempDirectory("tf-saved-model-export-test");
  float reducedSum;
  FloatNdArray xValue = StdArrays.ndCopyOf(new float[][] { { 0, 1, 2 }, { 3, 4, 5 } });
  Shape xyShape = Shape.of(2, 3L);
  try (Graph g = new Graph()) {
    // z = reduce_sum(x + y) over all axes, where y is a randomly-initialized variable.
    Ops tf = Ops.create(g);
    Placeholder<TFloat32> x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xyShape));
    Variable<TFloat32> y = tf.variable(tf.random.randomUniform(tf.constant(xyShape), TFloat32.DTYPE));
    ReduceSum<TFloat32> z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1));
    Init init = tf.init();

    try (Session s = new Session(g)) {
      s.run(init);
      // Record the result before export so it can be compared after reload.
      try (Tensor<TFloat32> xTensor = TFloat32.tensorOf(xValue);
          Tensor<TFloat32> zTensor = s.runner()
              .feed(x, xTensor)
              .fetch(z)
              .run()
              .get(0).expect(TFloat32.DTYPE)) {
        reducedSum = zTensor.data().getFloat();
      }
      SavedModelBundle.exporter(testFolder.toString())
          .withTags("test")
          .withSignature(Collections.singletonMap("input", x), Collections.singletonMap("reducedSum", z))
          .export(s);
    }
  }
  // The exporter should have produced the standard SavedModel layout on disk.
  assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index"))));
  assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001"))));
  assertTrue(Files.exists(testFolder.resolve("saved_model.pb")));

  // Reload the model just saved and validate its data
  try (SavedModelBundle savedModel = SavedModelBundle.load(testFolder.toString(), "test")) {
    assertNotNull(savedModel.metaGraphDef());
    assertNotNull(savedModel.metaGraphDef().getSaverDef());
    assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount());

    // The two-arg withSignature overload registers under the "serving_default" key.
    SignatureDef signature = savedModel.metaGraphDef().getSignatureDefMap().get("serving_default");
    assertNotNull(signature);
    assertEquals(1, signature.getInputsCount());
    assertEquals(1, signature.getOutputsCount());

    // Input signature must carry the placeholder's (2, 3) shape.
    TensorInfo inputInfo = signature.getInputsMap().get("input");
    assertNotNull(inputInfo);
    assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount());
    for (int i = 0; i < xyShape.numDimensions(); ++i) {
      assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize());
    }

    // The reduce-sum output is a scalar, so its shape has no dimensions.
    TensorInfo outputInfo = signature.getOutputsMap().get("reducedSum");
    assertNotNull(outputInfo);
    assertEquals(0, outputInfo.getTensorShape().getDimCount());

    // Run the saved model just loaded and make sure it returns the same result as before
    try (Tensor<TFloat32> xTensor = TFloat32.tensorOf(xValue);
        Tensor<TFloat32> zTensor = savedModel.session().runner()
            .feed(inputInfo.getName(), xTensor)
            .fetch(outputInfo.getName())
            .run()
            .get(0).expect(TFloat32.DTYPE)) {
      assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON);
    }
  }
}

private static RunOptions sillyRunOptions() {
return RunOptions.newBuilder()
.setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
Expand Down

0 comments on commit 9a1f535

Please sign in to comment.