From 4481d19055f441e90f3b007b76449b6edc337021 Mon Sep 17 00:00:00 2001 From: Bo Peng Date: Wed, 13 Mar 2019 09:44:55 -0500 Subject: [PATCH 1/4] implemented math_ops.square --- src/TensorFlowNET.Core/Operations/gen_math_ops.cs | 7 +++++++ src/TensorFlowNET.Core/Operations/math_ops.py.cs | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index e3b8022ac..3210c7425 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -48,6 +48,13 @@ public static Tensor squared_difference(Tensor x, Tensor y, string name = null) return _op.outputs[0]; } + public static Tensor square(Tensor x, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x }); + + return _op.outputs[0]; + } + public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= "") { var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs index d6b175097..a114234ad 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs @@ -57,7 +57,7 @@ public static Tensor square_difference(Tensor x, Tensor y, string name = null) public static Tensor square(Tensor x, string name = null) { - throw new NotImplementedException(); + return gen_math_ops.square(x, name); } /// From 115d4892dcaf24cf20b86e765b5a278ba057b148 Mon Sep 17 00:00:00 2001 From: Bo Peng Date: Wed, 13 Mar 2019 10:52:37 -0500 Subject: [PATCH 2/4] implemented _log_prob in normal.py --- .../Distributions/distribution.py.cs | 31 ++++++++++++++++--- .../Operations/Distributions/normal.py.cs | 9 +++++- .../Operations/gen_math_ops.cs | 19 ++++++++++++ .../Operations/math_ops.py.cs | 5 +++ 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs index 688169632..74f1fe3ee 100644 --- a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs +++ b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs @@ -35,7 +35,7 @@ public class Distribution : _BaseDistribution /// Python `str` prepended to names of ops created by this function. /// log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type `self.dtype`. - /* + public Tensor log_prob(Tensor value, string name = "log_prob") { return _call_log_prob(value, name); @@ -45,18 +45,39 @@ private Tensor _call_log_prob (Tensor value, string name) { with(ops.name_scope(name, "moments", new { value }), scope => { - value = _convert_to_tensor(value, "value", _dtype); + try + { + return _log_prob(value); + } + catch (Exception e1) + { + try + { + return math_ops.log(_prob(value)); + } catch (Exception e2) + { + throw new NotImplementedException(); + } + } }); + return null; + } + private Tensor _log_prob(Tensor value) + { throw new NotImplementedException(); - } - private Tensor _convert_to_tensor(Tensor value, string name = null, TF_DataType preferred_dtype) + private Tensor _prob(Tensor value) { throw new NotImplementedException(); } - */ + + public TF_DataType dtype() + { + return this._dtype; + } + /// /// Constructs the `Distribution' diff --git a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs index 6c77450a0..e82b2ddd7 100644 --- a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs +++ b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using Tensorflow; @@ -80,7 +81,7 @@ public Tensor _batch_shape() private Tensor _log_prob(Tensor x) { - return _log_unnormalized_prob(_z(x)); + return _log_unnormalized_prob(_z(x)) -_log_normalization(); } private Tensor _log_unnormalized_prob (Tensor x) @@ -92,5 +93,11 @@ private Tensor _z (Tensor x) { return (x - this._loc) / this._scale; } + + private Tensor _log_normalization() + { + Tensor t = new Tensor(Math.Log(2.0 * Math.PI)); + return 0.5 * t + math_ops.log(scale()); + } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 3210c7425..2f307f6d1 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -48,6 +48,12 @@ public static Tensor squared_difference(Tensor x, Tensor y, string name = null) return _op.outputs[0]; } + /// + /// Computes square of x element-wise. + /// + /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`. + /// A name for the operation (optional). + /// A `Tensor`. Has the same type as `x`. public static Tensor square(Tensor x, string name = null) { var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x }); @@ -55,6 +61,19 @@ public static Tensor square(Tensor x, string name = null) return _op.outputs[0]; } + /// + /// Computes natural logarithm of x element-wise. + /// + /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`. + /// name: A name for the operation (optional). + /// A `Tensor`. Has the same type as `x`. + public static Tensor log(Tensor x, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Log", name, args: new { x }); + + return _op.outputs[0]; + } + public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= "") { var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs index a114234ad..0909b1870 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs @@ -60,6 +60,11 @@ public static Tensor square(Tensor x, string name = null) return gen_math_ops.square(x, name); } + public static Tensor log(Tensor x, string name = null) + { + return gen_math_ops.log(x, name); + } + /// /// Helper function for reduction ops. /// From a53908d5907553e11e6d61701d56b8f494f79294 Mon Sep 17 00:00:00 2001 From: Bo Peng Date: Wed, 13 Mar 2019 14:40:04 -0500 Subject: [PATCH 3/4] implemented math_ops: 1) reduce_logsumexp, 2) reduce_max, 3) log, 4) square --- src/TensorFlowNET.Core/APIs/tf.reshape.cs | 14 ++++++ src/TensorFlowNET.Core/APIs/tf.tile.cs | 14 ++++++ .../Operations/array_ops.py.cs | 5 +++ .../Operations/gen_math_ops.cs | 33 ++++++++++++++ .../Operations/math_ops.py.cs | 45 +++++++++++++++++++ .../NaiveBayesClassifier.cs | 29 ++++++++++-- 6 files changed, 136 insertions(+), 4 deletions(-) create mode 100644 src/TensorFlowNET.Core/APIs/tf.reshape.cs create mode 100644 src/TensorFlowNET.Core/APIs/tf.tile.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs new file mode 100644 index 000000000..41861968e --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static Tensor reshape(Tensor tensor, + Tensor shape, + string name = null) => gen_array_ops.reshape(tensor, shape, name); + + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.tile.cs b/src/TensorFlowNET.Core/APIs/tf.tile.cs new file mode 100644 index 000000000..16bd48fdf --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.tile.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static Tensor tile(Tensor input, + Tensor multiples, + string name = null) => gen_array_ops.tile(input, multiples, name); + + } +} diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index ffdc0d8f5..c5641f066 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -66,6 +66,11 @@ public static Tensor rank(Tensor input, string name = null) public static Tensor ones_like(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) => ones_like_impl(tensor, dtype, name, optimize); + public static Tensor reshape(Tensor tensor, Tensor shape, string name = null) + { + return gen_array_ops.reshape(tensor, shape, null); + } + private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) { return with(ops.name_scope(name, "ones_like", new { tensor }), scope => diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 2f307f6d1..9d1e8788c 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -61,6 +61,32 @@ public static Tensor square(Tensor x, string name = null) return _op.outputs[0]; } + /// + /// Returns which elements of x are finite. + /// + /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`. + /// A name for the operation (optional). + /// A `Tensor` of type `bool`. + public static Tensor is_finite(Tensor x, string name = null) + { + var _op = _op_def_lib._apply_op_helper("IsFinite", name, args: new { x }); + + return _op.outputs[0]; + } + + /// + /// Computes exponential of x element-wise. \\(y = e^x\\). + /// + /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`. + /// A name for the operation (optional). + /// A `Tensor`. Has the same type as `x`. + public static Tensor exp(Tensor x, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Exp", name, args: new { x }); + + return _op.outputs[0]; + } + /// /// Computes natural logarithm of x element-wise. /// @@ -160,6 +186,13 @@ public static Tensor maximum(T1 x, T2 y, string name = null) return _op.outputs[0]; } + public static Tensor _max(Tensor input, int[] axis, bool keep_dims=false, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }); + + return _op.outputs[0]; + } + public static Tensor pow(Tx x, Ty y, string name = null) { var _op = _op_def_lib._apply_op_helper("Pow", name, args: new { x, y }); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs index 0909b1870..e35a094bd 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs @@ -87,6 +87,51 @@ public static Tensor reduced_shape(Tensor input_shape, Tensor axes) return gen_data_flow_ops.dynamic_stitch(a1, a2); } + /// + /// Computes log(sum(exp(elements across dimensions of a tensor))). + /// Reduces `input_tensor` along the dimensions given in `axis`. + /// Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + /// entry in `axis`. If `keepdims` is true, the reduced dimensions + /// are retained with length 1. + + /// If `axis` has no entries, all dimensions are reduced, and a + /// tensor with a single element is returned. + + /// This function is more numerically stable than log(sum(exp(input))). It avoids + /// overflows caused by taking the exp of large inputs and underflows caused by + /// taking the log of small inputs. + /// + /// The tensor to reduce. Should have numeric type. + /// The dimensions to reduce. If `None` (the default), reduces all + /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`. + /// + /// The reduced tensor. + public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + { + with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope => + { + var raw_max = reduce_max(input_tensor, axis, true); + var my_max = array_ops.stop_gradient(array_ops.where(gen_math_ops.is_finite(raw_max), raw_max, array_ops.zeros_like(raw_max))); + var result = gen_math_ops.log( + reduce_sum( + gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)), + new Tensor(axis), + keepdims)); + if (!keepdims) + { + my_max = array_ops.reshape(my_max, array_ops.shape(result)); + } + result = gen_math_ops.add(result, my_max); + return _may_reduce_to_scalar(keepdims, axis, result); + }); + return null; + } + + public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + { + return _may_reduce_to_scalar(keepdims, axis, gen_math_ops._max(input_tensor, (int[])_ReductionDims(input_tensor, axis), keepdims, name)); + } + /// /// Casts a tensor to type `int32`. /// diff --git a/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs b/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs index 131ab42cb..a8cc8f11d 100644 --- a/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs +++ b/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs @@ -12,6 +12,7 @@ namespace TensorFlowNET.Examples /// public class NaiveBayesClassifier : Python, IExample { + public Normal dist { get; set; } public void Run() { np.array(1.0f, 1.0f); @@ -72,16 +73,36 @@ public void fit(NDArray X, NDArray y) // Create a 3x2 univariate normal distribution with the // Known mean and variance var dist = tf.distributions.Normal(mean, tf.sqrt(variance)); - + this.dist = dist; } public void predict (NDArray X) { - // assert self.dist is not None - // nb_classes, nb_features = map(int, self.dist.scale.shape) + if (dist == null) + { + throw new ArgumentNullException("cant not find the model (normal distribution)!"); + } + int nb_classes = (int) dist.scale().shape[0]; + int nb_features = (int)dist.scale().shape[1]; + // Conditional probabilities log P(x|c) with shape + // (nb_samples, nb_classes) + Tensor tile = tf.tile(new Tensor(X), new Tensor(new int[] { -1, nb_classes, nb_features })); + Tensor r = tf.reshape(tile, new Tensor(new int[] { -1, nb_classes, nb_features })); + var cond_probs = tf.reduce_sum(dist.log_prob(r)); + // uniform priors + var priors = np.log(np.array((1.0 / nb_classes) * nb_classes)); - throw new NotFiniteNumberException(); + // posterior log probability, log P(c) + log P(x|c) + var joint_likelihood = tf.add(new Tensor(priors), cond_probs); + // normalize to get (log)-probabilities + /* + var norm_factor = tf.reduce_logsumexp(joint_likelihood, axis = 1, keep_dims = True) + var log_prob = joint_likelihood - norm_factor; + // exp to get the actual probabilities + return tf.exp(log_prob) + */ + throw new NotImplementedException(); } } } From f8b618cba9d76377a47b49a83e7c0fd353a217cf Mon Sep 17 00:00:00 2001 From: Bo Peng Date: Wed, 13 Mar 2019 14:46:09 -0500 Subject: [PATCH 4/4] implemented naive bayes predict API --- src/TensorFlowNET.Core/APIs/tf.exp.cs | 13 +++++++++++++ .../APIs/tf.reduce_logsumexp.cs | 15 +++++++++++++++ .../NaiveBayesClassifier.cs | 10 ++++------ 3 files changed, 32 insertions(+), 6 deletions(-) create mode 100644 src/TensorFlowNET.Core/APIs/tf.exp.cs create mode 100644 src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.exp.cs b/src/TensorFlowNET.Core/APIs/tf.exp.cs new file mode 100644 index 000000000..885cd96d2 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.exp.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static Tensor exp(Tensor x, + string name = null) => gen_math_ops.exp(x, name); + + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs b/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs new file mode 100644 index 000000000..32140389b --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static Tensor reduce_logsumexp(Tensor input_tensor, + int[] axis = null, + bool keepdims = false, + string name = null) => math_ops.reduce_logsumexp(input_tensor, axis, keepdims, name); + + } +} diff --git a/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs b/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs index a8cc8f11d..d33005191 100644 --- a/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs +++ b/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs @@ -76,7 +76,7 @@ public void fit(NDArray X, NDArray y) this.dist = dist; } - public void predict (NDArray X) + public Tensor predict (NDArray X) { if (dist == null) { @@ -96,13 +96,11 @@ public void predict (NDArray X) // posterior log probability, log P(c) + log P(x|c) var joint_likelihood = tf.add(new Tensor(priors), cond_probs); // normalize to get (log)-probabilities - /* - var norm_factor = tf.reduce_logsumexp(joint_likelihood, axis = 1, keep_dims = True) + + var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, true); var log_prob = joint_likelihood - norm_factor; // exp to get the actual probabilities - return tf.exp(log_prob) - */ - throw new NotImplementedException(); + return tf.exp(log_prob); } } }