Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.exp.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
public static Tensor exp(Tensor x,
string name = null) => gen_math_ops.exp(x, name);

}
}
15 changes: 15 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
public static Tensor reduce_logsumexp(Tensor input_tensor,
int[] axis = null,
bool keepdims = false,
string name = null) => math_ops.reduce_logsumexp(input_tensor, axis, keepdims, name);

}
}
14 changes: 14 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.reshape.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
public static Tensor reshape(Tensor tensor,
Tensor shape,
string name = null) => gen_array_ops.reshape(tensor, shape, name);

}
}
14 changes: 14 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.tile.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
public static Tensor tile(Tensor input,
Tensor multiples,
string name = null) => gen_array_ops.tile(input, multiples, name);

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class Distribution : _BaseDistribution
/// <param name="name"> Python `str` prepended to names of ops created by this function.</param>
/// <returns>log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type `self.dtype`.</returns>

/*

public Tensor log_prob(Tensor value, string name = "log_prob")
{
return _call_log_prob(value, name);
Expand All @@ -45,18 +45,39 @@ private Tensor _call_log_prob (Tensor value, string name)
{
with(ops.name_scope(name, "moments", new { value }), scope =>
{
value = _convert_to_tensor(value, "value", _dtype);
try
{
return _log_prob(value);
}
catch (Exception e1)
{
try
{
return math_ops.log(_prob(value));
} catch (Exception e2)
{
throw new NotImplementedException();
}
}
});
return null;
}

private Tensor _log_prob(Tensor value)
{
throw new NotImplementedException();

}

private Tensor _convert_to_tensor(Tensor value, string name = null, TF_DataType preferred_dtype)
private Tensor _prob(Tensor value)
{
throw new NotImplementedException();
}
*/

public TF_DataType dtype()
{
return this._dtype;
}


/// <summary>
/// Constructs the `Distribution'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using Tensorflow;

Expand Down Expand Up @@ -80,7 +81,7 @@ public Tensor _batch_shape()

private Tensor _log_prob(Tensor x)
{
return _log_unnormalized_prob(_z(x));
return _log_unnormalized_prob(_z(x)) -_log_normalization();
}

private Tensor _log_unnormalized_prob (Tensor x)
Expand All @@ -92,5 +93,11 @@ private Tensor _z (Tensor x)
{
return (x - this._loc) / this._scale;
}

private Tensor _log_normalization()
{
Tensor t = new Tensor(Math.Log(2.0 * Math.PI));
return 0.5 * t + math_ops.log(scale());
}
}
}
5 changes: 5 additions & 0 deletions src/TensorFlowNET.Core/Operations/array_ops.py.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ public static Tensor rank(Tensor input, string name = null)
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> ones_like_impl(tensor, dtype, name, optimize);

public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
{
return gen_array_ops.reshape(tensor, shape, null);
}

private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
{
return with(ops.name_scope(name, "ones_like", new { tensor }), scope =>
Expand Down
59 changes: 59 additions & 0 deletions src/TensorFlowNET.Core/Operations/gen_math_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,58 @@ public static Tensor squared_difference(Tensor x, Tensor y, string name = null)
return _op.outputs[0];
}

/// <summary>
/// Computes square of x element-wise.
/// </summary>
/// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.</param>
/// <param name="name"> A name for the operation (optional).</param>
/// <returns> A `Tensor`. Has the same type as `x`.</returns>
public static Tensor square(Tensor x, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x });

return _op.outputs[0];
}

/// <summary>
/// Returns which elements of x are finite.
/// </summary>
/// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`.</param>
/// <param name="name"> A name for the operation (optional).</param>
/// <returns> A `Tensor` of type `bool`.</returns>
public static Tensor is_finite(Tensor x, string name = null)
{
var _op = _op_def_lib._apply_op_helper("IsFinite", name, args: new { x });

return _op.outputs[0];
}

/// <summary>
/// Computes exponential of x element-wise. \\(y = e^x\\).
/// </summary>
/// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`.</param>
/// <param name="name"> A name for the operation (optional).</param>
/// <returns> A `Tensor`. Has the same type as `x`.</returns>
public static Tensor exp(Tensor x, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Exp", name, args: new { x });

return _op.outputs[0];
}

/// <summary>
/// Computes natural logarithm of x element-wise.
/// </summary>
/// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`.</param>
/// <param name="name"> name: A name for the operation (optional).</param>
/// <returns> A `Tensor`. Has the same type as `x`.</returns>
public static Tensor log(Tensor x, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Log", name, args: new { x });

return _op.outputs[0];
}

public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= "")
{
var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate });
Expand Down Expand Up @@ -134,6 +186,13 @@ public static Tensor maximum<T1, T2>(T1 x, T2 y, string name = null)
return _op.outputs[0];
}

public static Tensor _max(Tensor input, int[] axis, bool keep_dims=false, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims });

return _op.outputs[0];
}

public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Pow", name, args: new { x, y });
Expand Down
52 changes: 51 additions & 1 deletion src/TensorFlowNET.Core/Operations/math_ops.py.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ public static Tensor square_difference(Tensor x, Tensor y, string name = null)

public static Tensor square(Tensor x, string name = null)
{
throw new NotImplementedException();
return gen_math_ops.square(x, name);
}

public static Tensor log(Tensor x, string name = null)
{
return gen_math_ops.log(x, name);
}

/// <summary>
Expand All @@ -82,6 +87,51 @@ public static Tensor reduced_shape(Tensor input_shape, Tensor axes)
return gen_data_flow_ops.dynamic_stitch(a1, a2);
}

/// <summary>
/// Computes log(sum(exp(elements across dimensions of a tensor))).
/// Reduces `input_tensor` along the dimensions given in `axis`.
/// Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
/// entry in `axis`. If `keepdims` is true, the reduced dimensions
/// are retained with length 1.

/// If `axis` has no entries, all dimensions are reduced, and a
/// tensor with a single element is returned.

/// This function is more numerically stable than log(sum(exp(input))). It avoids
/// overflows caused by taking the exp of large inputs and underflows caused by
/// taking the log of small inputs.
/// </summary>
/// <param name="input_tensor"> The tensor to reduce. Should have numeric type.</param>
/// <param name="axis"> The dimensions to reduce. If `None` (the default), reduces all
/// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param>
/// <param name="keepdims"></param>
/// <returns> The reduced tensor.</returns>
public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
{
with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope =>
{
var raw_max = reduce_max(input_tensor, axis, true);
var my_max = array_ops.stop_gradient(array_ops.where(gen_math_ops.is_finite(raw_max), raw_max, array_ops.zeros_like(raw_max)));
var result = gen_math_ops.log(
reduce_sum(
gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)),
new Tensor(axis),
keepdims));
if (!keepdims)
{
my_max = array_ops.reshape(my_max, array_ops.shape(result));
}
result = gen_math_ops.add(result, my_max);
return _may_reduce_to_scalar(keepdims, axis, result);
});
return null;
}

public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
{
return _may_reduce_to_scalar(keepdims, axis, gen_math_ops._max(input_tensor, (int[])_ReductionDims(input_tensor, axis), keepdims, name));
}

/// <summary>
/// Casts a tensor to type `int32`.
/// </summary>
Expand Down
29 changes: 24 additions & 5 deletions test/TensorFlowNET.Examples/NaiveBayesClassifier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace TensorFlowNET.Examples
/// </summary>
public class NaiveBayesClassifier : Python, IExample
{
public Normal dist { get; set; }
public void Run()
{
np.array<float>(1.0f, 1.0f);
Expand Down Expand Up @@ -72,16 +73,34 @@ public void fit(NDArray X, NDArray y)
// Create a 3x2 univariate normal distribution with the
// Known mean and variance
var dist = tf.distributions.Normal(mean, tf.sqrt(variance));

this.dist = dist;
}

public void predict (NDArray X)
public Tensor predict (NDArray X)
{
// assert self.dist is not None
// nb_classes, nb_features = map(int, self.dist.scale.shape)
if (dist == null)
{
throw new ArgumentNullException("cant not find the model (normal distribution)!");
}
int nb_classes = (int) dist.scale().shape[0];
int nb_features = (int)dist.scale().shape[1];

// Conditional probabilities log P(x|c) with shape
// (nb_samples, nb_classes)
Tensor tile = tf.tile(new Tensor(X), new Tensor(new int[] { -1, nb_classes, nb_features }));
Tensor r = tf.reshape(tile, new Tensor(new int[] { -1, nb_classes, nb_features }));
var cond_probs = tf.reduce_sum(dist.log_prob(r));
// uniform priors
var priors = np.log(np.array<double>((1.0 / nb_classes) * nb_classes));

// posterior log probability, log P(c) + log P(x|c)
var joint_likelihood = tf.add(new Tensor(priors), cond_probs);
// normalize to get (log)-probabilities

throw new NotFiniteNumberException();
var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, true);
var log_prob = joint_likelihood - norm_factor;
// exp to get the actual probabilities
return tf.exp(log_prob);
}
}
}