diff --git a/src/TensorFlowNET.Core/APIs/tf.exp.cs b/src/TensorFlowNET.Core/APIs/tf.exp.cs
new file mode 100644
index 000000000..885cd96d2
--- /dev/null
+++ b/src/TensorFlowNET.Core/APIs/tf.exp.cs
@@ -0,0 +1,13 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public static partial class tf
+ {
+ public static Tensor exp(Tensor x,
+ string name = null) => gen_math_ops.exp(x, name);
+
+ }
+}
diff --git a/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs b/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs
new file mode 100644
index 000000000..32140389b
--- /dev/null
+++ b/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public static partial class tf
+ {
+ public static Tensor reduce_logsumexp(Tensor input_tensor,
+ int[] axis = null,
+ bool keepdims = false,
+ string name = null) => math_ops.reduce_logsumexp(input_tensor, axis, keepdims, name);
+
+ }
+}
diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs
new file mode 100644
index 000000000..41861968e
--- /dev/null
+++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public static partial class tf
+ {
+ public static Tensor reshape(Tensor tensor,
+ Tensor shape,
+ string name = null) => gen_array_ops.reshape(tensor, shape, name);
+
+ }
+}
diff --git a/src/TensorFlowNET.Core/APIs/tf.tile.cs b/src/TensorFlowNET.Core/APIs/tf.tile.cs
new file mode 100644
index 000000000..16bd48fdf
--- /dev/null
+++ b/src/TensorFlowNET.Core/APIs/tf.tile.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public static partial class tf
+ {
+ public static Tensor tile(Tensor input,
+ Tensor multiples,
+ string name = null) => gen_array_ops.tile(input, multiples, name);
+
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs
index 688169632..74f1fe3ee 100644
--- a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs
+++ b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs
@@ -35,7 +35,7 @@ public class Distribution : _BaseDistribution
/// Python `str` prepended to names of ops created by this function.
/// log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type `self.dtype`.
- /*
+
public Tensor log_prob(Tensor value, string name = "log_prob")
{
return _call_log_prob(value, name);
@@ -45,18 +45,39 @@ private Tensor _call_log_prob (Tensor value, string name)
{
with(ops.name_scope(name, "moments", new { value }), scope =>
{
- value = _convert_to_tensor(value, "value", _dtype);
+ try
+ {
+ return _log_prob(value);
+ }
+ catch (Exception e1)
+ {
+ try
+ {
+ return math_ops.log(_prob(value));
+ } catch (Exception e2)
+ {
+ throw new NotImplementedException();
+ }
+ }
});
+ return null;
+ }
+ private Tensor _log_prob(Tensor value)
+ {
throw new NotImplementedException();
-
}
- private Tensor _convert_to_tensor(Tensor value, string name = null, TF_DataType preferred_dtype)
+ private Tensor _prob(Tensor value)
{
throw new NotImplementedException();
}
- */
+
+ public TF_DataType dtype()
+ {
+ return this._dtype;
+ }
+
///
/// Constructs the `Distribution'
diff --git a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs
index 6c77450a0..e82b2ddd7 100644
--- a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs
+++ b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs
@@ -1,3 +1,4 @@
+using System;
using System.Collections.Generic;
using Tensorflow;
@@ -80,7 +81,7 @@ public Tensor _batch_shape()
private Tensor _log_prob(Tensor x)
{
- return _log_unnormalized_prob(_z(x));
+ return _log_unnormalized_prob(_z(x)) -_log_normalization();
}
private Tensor _log_unnormalized_prob (Tensor x)
@@ -92,5 +93,11 @@ private Tensor _z (Tensor x)
{
return (x - this._loc) / this._scale;
}
+
+ private Tensor _log_normalization()
+ {
+ Tensor t = new Tensor(Math.Log(2.0 * Math.PI));
+ return 0.5 * t + math_ops.log(scale());
+ }
}
}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index ffdc0d8f5..c5641f066 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -66,6 +66,11 @@ public static Tensor rank(Tensor input, string name = null)
public static Tensor ones_like(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> ones_like_impl(tensor, dtype, name, optimize);
+ public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
+ {
+ return gen_array_ops.reshape(tensor, shape, null);
+ }
+
private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true)
{
return with(ops.name_scope(name, "ones_like", new { tensor }), scope =>
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index e3b8022ac..9d1e8788c 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -48,6 +48,58 @@ public static Tensor squared_difference(Tensor x, Tensor y, string name = null)
return _op.outputs[0];
}
+ ///
+ /// Computes square of x element-wise.
+ ///
+ /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
+ /// A name for the operation (optional).
+ /// A `Tensor`. Has the same type as `x`.
+ public static Tensor square(Tensor x, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x });
+
+ return _op.outputs[0];
+ }
+
+ ///
+ /// Returns which elements of x are finite.
+ ///
+ /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`.
+ /// A name for the operation (optional).
+ /// A `Tensor` of type `bool`.
+ public static Tensor is_finite(Tensor x, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("IsFinite", name, args: new { x });
+
+ return _op.outputs[0];
+ }
+
+ ///
+ /// Computes exponential of x element-wise. \\(y = e^x\\).
+ ///
+ /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`.
+ /// A name for the operation (optional).
+ /// A `Tensor`. Has the same type as `x`.
+ public static Tensor exp(Tensor x, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("Exp", name, args: new { x });
+
+ return _op.outputs[0];
+ }
+
+ ///
+ /// Computes natural logarithm of x element-wise.
+ ///
+ /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`.
+ /// name: A name for the operation (optional).
+ /// A `Tensor`. Has the same type as `x`.
+ public static Tensor log(Tensor x, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("Log", name, args: new { x });
+
+ return _op.outputs[0];
+ }
+
public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= "")
{
var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate });
@@ -134,6 +186,13 @@ public static Tensor maximum(T1 x, T2 y, string name = null)
return _op.outputs[0];
}
+ public static Tensor _max(Tensor input, int[] axis, bool keep_dims=false, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims });
+
+ return _op.outputs[0];
+ }
+
public static Tensor pow(Tx x, Ty y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Pow", name, args: new { x, y });
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs
index d6b175097..e35a094bd 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs
@@ -57,7 +57,12 @@ public static Tensor square_difference(Tensor x, Tensor y, string name = null)
public static Tensor square(Tensor x, string name = null)
{
- throw new NotImplementedException();
+ return gen_math_ops.square(x, name);
+ }
+
+ public static Tensor log(Tensor x, string name = null)
+ {
+ return gen_math_ops.log(x, name);
}
///
@@ -82,6 +87,51 @@ public static Tensor reduced_shape(Tensor input_shape, Tensor axes)
return gen_data_flow_ops.dynamic_stitch(a1, a2);
}
+ ///
+ /// Computes log(sum(exp(elements across dimensions of a tensor))).
+ /// Reduces `input_tensor` along the dimensions given in `axis`.
+ /// Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ /// entry in `axis`. If `keepdims` is true, the reduced dimensions
+ /// are retained with length 1.
+
+ /// If `axis` has no entries, all dimensions are reduced, and a
+ /// tensor with a single element is returned.
+
+ /// This function is more numerically stable than log(sum(exp(input))). It avoids
+ /// overflows caused by taking the exp of large inputs and underflows caused by
+ /// taking the log of small inputs.
+ ///
+ /// The tensor to reduce. Should have numeric type.
+ /// The dimensions to reduce. If `None` (the default), reduces all
+ /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.
+ ///
+ /// The reduced tensor.
+ public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ {
+ with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope =>
+ {
+ var raw_max = reduce_max(input_tensor, axis, true);
+ var my_max = array_ops.stop_gradient(array_ops.where(gen_math_ops.is_finite(raw_max), raw_max, array_ops.zeros_like(raw_max)));
+ var result = gen_math_ops.log(
+ reduce_sum(
+ gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)),
+ new Tensor(axis),
+ keepdims));
+ if (!keepdims)
+ {
+ my_max = array_ops.reshape(my_max, array_ops.shape(result));
+ }
+ result = gen_math_ops.add(result, my_max);
+ return _may_reduce_to_scalar(keepdims, axis, result);
+ });
+ return null;
+ }
+
+ public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ {
+ return _may_reduce_to_scalar(keepdims, axis, gen_math_ops._max(input_tensor, (int[])_ReductionDims(input_tensor, axis), keepdims, name));
+ }
+
///
/// Casts a tensor to type `int32`.
///
diff --git a/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs b/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs
index 131ab42cb..d33005191 100644
--- a/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs
+++ b/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs
@@ -12,6 +12,7 @@ namespace TensorFlowNET.Examples
///
public class NaiveBayesClassifier : Python, IExample
{
+ public Normal dist { get; set; }
public void Run()
{
np.array(1.0f, 1.0f);
@@ -72,16 +73,34 @@ public void fit(NDArray X, NDArray y)
// Create a 3x2 univariate normal distribution with the
// Known mean and variance
var dist = tf.distributions.Normal(mean, tf.sqrt(variance));
-
+ this.dist = dist;
}
- public void predict (NDArray X)
+ public Tensor predict (NDArray X)
{
- // assert self.dist is not None
- // nb_classes, nb_features = map(int, self.dist.scale.shape)
+ if (dist == null)
+ {
+ throw new ArgumentNullException("cant not find the model (normal distribution)!");
+ }
+ int nb_classes = (int) dist.scale().shape[0];
+ int nb_features = (int)dist.scale().shape[1];
+
+ // Conditional probabilities log P(x|c) with shape
+ // (nb_samples, nb_classes)
+ Tensor tile = tf.tile(new Tensor(X), new Tensor(new int[] { -1, nb_classes, nb_features }));
+ Tensor r = tf.reshape(tile, new Tensor(new int[] { -1, nb_classes, nb_features }));
+ var cond_probs = tf.reduce_sum(dist.log_prob(r));
+ // uniform priors
+ var priors = np.log(np.array((1.0 / nb_classes) * nb_classes));
+ // posterior log probability, log P(c) + log P(x|c)
+ var joint_likelihood = tf.add(new Tensor(priors), cond_probs);
+ // normalize to get (log)-probabilities
- throw new NotFiniteNumberException();
+ var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, true);
+ var log_prob = joint_likelihood - norm_factor;
+ // exp to get the actual probabilities
+ return tf.exp(log_prob);
}
}
}