Skip to content

Commit

Permalink
tf.GraphKeys #359
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceania2018 committed Aug 21, 2019
1 parent faa93bf commit 683aeed
Show file tree
Hide file tree
Showing 17 changed files with 83 additions and 109 deletions.
74 changes: 25 additions & 49 deletions README.md
@@ -1,6 +1,6 @@
![logo](docs/assets/tf.net.logo.png)

**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.

[![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community)
[![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net)
Expand Down Expand Up @@ -34,40 +34,15 @@ PM> Install-Package TensorFlow.NET
### Install tensorflow binary
### For CPU version
PM> Install-Package SciSharp.TensorFlow.Redist

### For GPU version (CUDA and cuDNN are required)
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
```

Import TF.NET.

```cs
using Tensorflow;
```

Add two constants:
```cs
// Create a Constant op
var a = tf.constant(4.0f);
var b = tf.constant(5.0f);
var c = tf.add(a, b);

using (var sess = tf.Session())
{
var o = sess.run(c);
}
```
Import TF.NET in your project.

Feed placeholder:
```cs
// Create a placeholder op
var a = tf.placeholder(tf.float32);
var b = tf.placeholder(tf.float32);
var c = tf.add(a, b);

using(var sess = tf.Session())
{
var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f));
}
using static Tensorflow.Binding;
```

Linear Regression:
Expand All @@ -91,39 +66,40 @@ var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
var init = tf.global_variables_initializer();

// Start training
with(tf.Session(), sess =>
using(tf.Session())
{
// Run the initializer
sess.run(init);

// Fit all training data
for (int epoch = 0; epoch < training_epochs; epoch++)
{
foreach (var (x, y) in zip<float>(train_X, train_Y))
sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y));
sess.run(optimizer, (X, x), (Y, y));

// Display logs per epoch step
if ((epoch + 1) % display_step == 0)
{
var c = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y));
var c = sess.run(cost, (X, train_X), (Y, train_Y));
Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}");
}

Console.WriteLine("Optimization Finished!");
var training_cost = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y));
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}");

// Testing example
var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f);
var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f);
Console.WriteLine("Testing... (Mean square loss Comparison)");

var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), new FeedItem(X, test_X), new FeedItem(Y, test_Y));
Console.WriteLine($"Testing cost={testing_cost}");

var diff = Math.Abs((float)training_cost - (float)testing_cost);
Console.WriteLine($"Absolute mean square loss difference: {diff}");
}

Console.WriteLine("Optimization Finished!");
var training_cost = sess.run(cost, (X, train_X), (Y, train_Y));
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}");

// Testing example
var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f);
var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f);
Console.WriteLine("Testing... (Mean square loss Comparison)");
var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]),
(X, test_X), (Y, test_Y));
Console.WriteLine($"Testing cost={testing_cost}");
var diff = Math.Abs((float)training_cost - (float)testing_cost);
Console.WriteLine($"Absolute mean square loss difference: {diff}");

return diff < 0.01;
});
```

Expand Down
5 changes: 5 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.graph.cs
Expand Up @@ -14,11 +14,16 @@
limitations under the License.
******************************************************************************/

using static Tensorflow.ops;

namespace Tensorflow
{
public partial class tensorflow
{
public graph_util_impl graph_util => new graph_util_impl();

public GraphKeys GraphKeys { get; } = new GraphKeys();

public Graph get_default_graph()
{
return ops.get_default_graph();
Expand Down
3 changes: 2 additions & 1 deletion src/TensorFlowNET.Core/APIs/tf.variable.cs
Expand Up @@ -15,14 +15,15 @@
******************************************************************************/

using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow
{
public partial class tensorflow
{
public VariableV1[] global_variables(string scope = null)
{
return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>)
return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>)
.ToArray();
}

Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/Framework/meta_graph.py.cs
Expand Up @@ -95,7 +95,7 @@ public static MetaGraphDef read_meta_graph_file(string filename)
break;
case KindOneofCase.BytesList:
//var proto_type = ops.get_collection_proto_type(key)
if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
{
foreach (var value in col.Value.BytesList.Value)
{
Expand Down Expand Up @@ -146,7 +146,7 @@ public static MetaGraphDef read_meta_graph_file(string filename)
}
}

var variables = graph.get_collection<VariableV1>(ops.GraphKeys.GLOBAL_VARIABLES,
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
scope: scope_to_prepend_to_names);
var var_list = new Dictionary<string, VariableV1>();
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v);
Expand Down Expand Up @@ -180,7 +180,7 @@ public static MetaGraphDef read_meta_graph_file(string filename)
var graph = ops.get_default_graph();

var var_list = new Dictionary<string, RefVariable>();
var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>;
var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>;

if (variables != null)
{
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Layers/Layer.cs
Expand Up @@ -81,7 +81,7 @@ public virtual Tensor apply(Tensor inputs, Tensor training = null)


// Update global default collections.
_add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS });
_add_elements_to_collection(_updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS });

return outputs;
}
Expand Down
8 changes: 4 additions & 4 deletions src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
Expand Up @@ -152,9 +152,9 @@ public override Tensor AddValue(Tensor val)
public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

//TODO: port this chunck of missing code:
/*
Expand Down Expand Up @@ -191,9 +191,9 @@ public override Tensor AddValue(Tensor val)
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

switch (original_result)
{
Expand Down
Expand Up @@ -195,7 +195,7 @@ private void _init_from_proto(WhileContextDef context_def, string import_scope =
// their associated TensorArrays for calling the body.
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
var body_result = body(packed_vars_for_body[0]);
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

// Store body_result to keep track of TensorArrays returned by body
var original_body_result = new[] { body_result };
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Operations/Losses/Util.cs
Expand Up @@ -2,7 +2,7 @@
{
public class Util
{
public static void add_loss(Tensor loss, string loss_collection = ops.GraphKeys.LOSSES)
public static void add_loss(Tensor loss, string loss_collection = "losses")
{
if (!string.IsNullOrEmpty(loss_collection))
ops.add_to_collection(loss_collection, loss);
Expand Down
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
Expand Up @@ -22,7 +22,7 @@ namespace Tensorflow
public class LossesImpl
{
public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null,
string loss_collection = ops.GraphKeys.LOSSES, string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS)
string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS)
{
return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate
{
Expand Down Expand Up @@ -101,7 +101,7 @@ public Tensor _num_present(Tensor losses, Tensor weights, bool per_batch = false
Tensor logits,
float weights = 1.0f,
string scope = null,
string loss_collection= ops.GraphKeys.LOSSES,
string loss_collection= "losses",
string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS)
{
return tf_with(ops.name_scope(scope,
Expand Down
10 changes: 5 additions & 5 deletions src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
Expand Up @@ -431,8 +431,8 @@ public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name

merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges);

ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);

return merges[0];
});
Expand Down Expand Up @@ -479,8 +479,8 @@ public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name

merges = _convert_flows_to_tensorarrays(orig_res_t, merges);

ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);

return merges;
});
Expand Down Expand Up @@ -596,7 +596,7 @@ public static Tensor ZerosLikeOutsideLoop(Operation op, int index)
swap_memory: swap_memory);
if (loop_context.outer_context == null)
ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context);
ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context);
var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
return_same_structure);
Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/Summaries/Summary.cs
Expand Up @@ -33,11 +33,11 @@ public Tensor histogram(string name, Tensor tensor, string[] collections = null,
{
var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary");
var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope);
collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES });
collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES });
return val;
}

public Tensor merge_all(string key = ops.GraphKeys.SUMMARIES, string scope= null, string name= null)
public Tensor merge_all(string key = "summaries", string scope= null, string name= null)
{
var summary_ops = ops.get_collection(key, scope: scope);
if (summary_ops == null)
Expand Down Expand Up @@ -67,7 +67,7 @@ public Tensor scalar(string name, Tensor tensor, string[] collections = null, st
{
var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor });
var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope);
collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES });
collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES });
return val;
}

Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/Train/Optimizer.cs
Expand Up @@ -198,7 +198,7 @@ public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, Re
if (!tf.context.executing_eagerly())
{
var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>;
var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>;
if (train_op != null && train_op.Contains(apply_updates))
train_op.Add(apply_updates);
}
Expand Down Expand Up @@ -359,7 +359,7 @@ private _OptimizableVariable _get_processor(RefVariable v)


var tmp = variables.trainable_variables();
var vars = ops.get_collection<RefVariable>(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
switch (tmp)
{
case List<RefVariable> values:
Expand All @@ -370,7 +370,7 @@ private _OptimizableVariable _get_processor(RefVariable v)
break;
}

var_list = var_list.Concat(ops.get_collection<RefVariable>(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
var var_refs = processors.Select(x => x.target()).ToArray();

Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/Variables/RefVariable.cs
Expand Up @@ -121,16 +121,16 @@ private void _init_from_proto(VariableDef variable_def, string import_scope = ""

if(collections == null)
{
collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES };
collections = new List<string> { tf.GraphKeys.GLOBAL_VARIABLES };
}

// Store the graph key so optimizers know how to only retrieve variables from
// this graph.
_graph_key = ops.get_default_graph().graph_key;

_trainable = trainable;
if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);

ops.init_scope();
var values = init_from_fn ? new object[0] : new object[] { initial_value };
Expand Down
9 changes: 5 additions & 4 deletions src/TensorFlowNET.Core/Variables/variables.py.cs
Expand Up @@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow
{
Expand All @@ -28,7 +29,7 @@ public class variables
/// <returns></returns>
public static object trainable_variables()
{
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES);
return ops.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES);
}

/// <summary>
Expand All @@ -40,11 +41,11 @@ public static VariableV1[] _all_saveable_objects(string scope = "")
{
var all = new List<VariableV1>();

var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope);
if(collection != null)
all.AddRange(collection as List<VariableV1>);

collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope);
collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope);
if (collection != null)
all.AddRange(collection as List<VariableV1>);

Expand All @@ -64,7 +65,7 @@ public static VariableV1[] _all_saveable_objects(string scope = "")
/// <returns>A list of `Variable` objects.</returns>
public static List<VariableV1> global_variables(string scope = null)
{
var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope);

return result == null ? new List<VariableV1>() : result as List<VariableV1>;
}
Expand Down

0 comments on commit 683aeed

Please sign in to comment.