Skip to content

Commit

Permalink
Merge pull request #1011 from BalashovK/master
Browse files Browse the repository at this point in the history
Added: complex, real, imag, angle
  • Loading branch information
Oceania2018 committed Apr 1, 2023
2 parents ccc556d + febba7b commit 79a9363
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 43 deletions.
18 changes: 15 additions & 3 deletions src/TensorFlowNET.Core/APIs/tf.math.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Copyright 2023 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -57,7 +57,7 @@ public Tensor softplus(Tensor features, string name = null)

public Tensor tanh(Tensor x, string name = null)
=> math_ops.tanh(x, name: name);

/// <summary>
/// Finds values and indices of the `k` largest entries for the last dimension.
/// </summary>
Expand Down Expand Up @@ -93,6 +93,16 @@ public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name =
bool binary_output = false)
=> math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength,
dtype: dtype, name: name, axis: axis, binary_output: binary_output);

public Tensor real(Tensor x, string name = null)
=> gen_ops.real(x, x.dtype.real_dtype(), name);
public Tensor imag(Tensor x, string name = null)
=> gen_ops.imag(x, x.dtype.real_dtype(), name);

public Tensor conj(Tensor x, string name = null)
=> gen_ops.conj(x, name);
public Tensor angle(Tensor x, string name = null)
=> gen_ops.angle(x, x.dtype.real_dtype(), name);
}

public Tensor abs(Tensor x, string name = null)
Expand Down Expand Up @@ -537,7 +547,7 @@ public Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims
public Tensor reduce_sum(Tensor input, Axis? axis = null, Axis? reduction_indices = null,
bool keepdims = false, string name = null)
{
if(keepdims)
if (keepdims)
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices), keepdims: keepdims, name: name);
else
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices));
Expand Down Expand Up @@ -585,5 +595,7 @@ public Tensor square(Tensor x, string name = null)
=> gen_math_ops.square(x, name: name);
public Tensor squared_difference(Tensor x, Tensor y, string name = null)
=> gen_math_ops.squared_difference(x: x, y: y, name: name);
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null,
string name = null) => gen_ops.complex(real, imag, dtype, name);
}
}
58 changes: 20 additions & 38 deletions src/TensorFlowNET.Core/Operations/gen_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -730,12 +730,7 @@ public static (Tensor sampled_candidates, Tensor true_expected_count, Tensor sam
/// </remarks>
public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
if (Tout.HasValue)
dict["Tout"] = Tout.Value;
var op = tf.OpDefLib._apply_op_helper("Angle", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(input).SetAttributes(new { Tout = Tout }));
}

/// <summary>
Expand Down Expand Up @@ -4976,15 +4971,14 @@ public static Tensor compare_and_bitpack(Tensor input, Tensor threshold, string
/// tf.complex(real, imag) ==&amp;gt; [[2.25 + 4.75j], [3.25 + 5.75j]]
/// </code>
/// </remarks>
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null, string name = "Complex")
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? a_Tout = null, string name = "Complex")
{
var dict = new Dictionary<string, object>();
dict["real"] = real;
dict["imag"] = imag;
if (Tout.HasValue)
dict["Tout"] = Tout.Value;
var op = tf.OpDefLib._apply_op_helper("Complex", name: name, keywords: dict);
return op.output;
TF_DataType Tin = real.GetDataType();
if (a_Tout is null)
{
a_Tout = (Tin == TF_DataType.TF_DOUBLE)? TF_DataType.TF_COMPLEX128: TF_DataType.TF_COMPLEX64;
}
return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(real, imag).SetAttributes(new { T=Tin, Tout=a_Tout }));
}

/// <summary>
Expand All @@ -5008,12 +5002,7 @@ public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null,
/// </remarks>
public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs")
{
var dict = new Dictionary<string, object>();
dict["x"] = x;
if (Tout.HasValue)
dict["Tout"] = Tout.Value;
var op = tf.OpDefLib._apply_op_helper("ComplexAbs", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(x).SetAttributes(new { Tout = Tout }));
}

/// <summary>
Expand Down Expand Up @@ -5313,10 +5302,7 @@ public static Tensor configure_distributed_t_p_u(string embedding_config = null,
/// </remarks>
public static Tensor conj(Tensor input, string name = "Conj")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
var op = tf.OpDefLib._apply_op_helper("Conj", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("Conj", name, new ExecuteOpArgs(new object[] { input }));
}

/// <summary>
Expand Down Expand Up @@ -13325,14 +13311,12 @@ public static Tensor igammac(Tensor a, Tensor x, string name = "Igammac")
/// tf.imag(input) ==&amp;gt; [4.75, 5.75]
/// </code>
/// </remarks>
public static Tensor imag(Tensor input, TF_DataType? Tout = null, string name = "Imag")
public static Tensor imag(Tensor input, TF_DataType? a_Tout = null, string name = "Imag")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
if (Tout.HasValue)
dict["Tout"] = Tout.Value;
var op = tf.OpDefLib._apply_op_helper("Imag", name: name, keywords: dict);
return op.output;
TF_DataType Tin = input.GetDataType();
return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));

// return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input }));
}

/// <summary>
Expand Down Expand Up @@ -23863,14 +23847,12 @@ public static Tensor reader_serialize_state_v2(Tensor reader_handle, string name
/// tf.real(input) ==&amp;gt; [-2.25, 3.25]
/// </code>
/// </remarks>
public static Tensor real(Tensor input, TF_DataType? Tout = null, string name = "Real")
public static Tensor real(Tensor input, TF_DataType? a_Tout = null, string name = "Real")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
if (Tout.HasValue)
dict["Tout"] = Tout.Value;
var op = tf.OpDefLib._apply_op_helper("Real", name: name, keywords: dict);
return op.output;
TF_DataType Tin = input.GetDataType();
return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));

// return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input}));
}

/// <summary>
Expand Down
6 changes: 4 additions & 2 deletions src/TensorFlowNET.Core/Operations/math_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
using System.Linq;
using Tensorflow.Framework;
using static Tensorflow.Binding;
using Tensorflow.Operations;

namespace Tensorflow
{
Expand All @@ -35,8 +36,9 @@ public static Tensor abs(Tensor x, string name = null)
name = scope;
x = ops.convert_to_tensor(x, name: "x");
if (x.dtype.is_complex())
throw new NotImplementedException("math_ops.abs for dtype.is_complex");
//return gen_math_ops.complex_abs(x, Tout: x.dtype.real_dtype, name: name);
{
return gen_ops.complex_abs(x, Tout: x.dtype.real_dtype(), name: name);
}
return gen_math_ops._abs(x, name: name);
});
}
Expand Down
202 changes: 202 additions & 0 deletions test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using Buffer = Tensorflow.Buffer;
using TensorFlowNET.Keras.UnitTest;

namespace TensorFlowNET.UnitTest.Basics
{
[TestClass]
public class ComplexTest : EagerModeTestBase
{
// Tests for Complex128

[TestMethod]
public void complex128_basic()
{
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 };
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 };

Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_real_result = tf.math.real(t_complex);
Tensor t_imag_result = tf.math.imag(t_complex);

NDArray n_real_result = t_real_result.numpy();
NDArray n_imag_result = t_imag_result.numpy();

double[] d_real_result =n_real_result.ToArray<double>();
double[] d_imag_result = n_imag_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_real_result, d_real));
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
}
[TestMethod]
public void complex128_abs()
{
tf.enable_eager_execution();

double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };

double[] d_abs = new double[] { 5.0, 13.0, 17.0, 25.0 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_abs_result = tf.abs(t_complex);

double[] d_abs_result = t_abs_result.numpy().ToArray<double>();
Assert.IsTrue(base.Equal(d_abs_result, d_abs));
}
[TestMethod]
public void complex128_conj()
{
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };

double[] d_real_expected = new double[] { -3.0, -5.0, 8.0, 7.0 };
double[] d_imag_expected = new double[] { 4.0, -12.0, 15.0, -24.0 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);

Tensor t_result = tf.math.conj(t_complex);

NDArray n_real_result = tf.math.real(t_result).numpy();
NDArray n_imag_result = tf.math.imag(t_result).numpy();

double[] d_real_result = n_real_result.ToArray<double>();
double[] d_imag_result = n_imag_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));
}
[TestMethod]
public void complex128_angle()
{
double[] d_real = new double[] { 0.0, 1.0, -1.0, 0.0 };
double[] d_imag = new double[] { 1.0, 0.0, -2.0, -3.0 };

double[] d_expected = new double[] { 1.5707963267948966, 0, -2.0344439357957027, -1.5707963267948966 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);

Tensor t_result = tf.math.angle(t_complex);

NDArray n_result = t_result.numpy();

double[] d_result = n_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_result, d_expected));
}

// Tests for Complex64
[TestMethod]
public void complex64_basic()
{
tf.init_scope();
float[] d_real = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
float[] d_imag = new float[] { -1.0f, -3.0f, 5.0f, 7.0f };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);

Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_real_result = tf.math.real(t_complex);
Tensor t_imag_result = tf.math.imag(t_complex);

// Convert the EagerTensors to NumPy arrays directly
float[] d_real_result = t_real_result.numpy().ToArray<float>();
float[] d_imag_result = t_imag_result.numpy().ToArray<float>();

Assert.IsTrue(base.Equal(d_real_result, d_real));
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
}
[TestMethod]
public void complex64_abs()
{
tf.enable_eager_execution();

float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f };

float[] d_abs = new float[] { 5.0f, 13.0f, 17.0f, 25.0f };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);

Tensor t_abs_result = tf.abs(t_complex);

NDArray n_abs_result = t_abs_result.numpy();

float[] d_abs_result = n_abs_result.ToArray<float>();
Assert.IsTrue(base.Equal(d_abs_result, d_abs));

}
[TestMethod]
public void complex64_conj()
{
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f };

float[] d_real_expected = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
float[] d_imag_expected = new float[] { 4.0f, -12.0f, 15.0f, -24.0f };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);

Tensor t_result = tf.math.conj(t_complex);

NDArray n_real_result = tf.math.real(t_result).numpy();
NDArray n_imag_result = tf.math.imag(t_result).numpy();

float[] d_real_result = n_real_result.ToArray<float>();
float[] d_imag_result = n_imag_result.ToArray<float>();

Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));

}
[TestMethod]
public void complex64_angle()
{
float[] d_real = new float[] { 0.0f, 1.0f, -1.0f, 0.0f };
float[] d_imag = new float[] { 1.0f, 0.0f, -2.0f, -3.0f };

float[] d_expected = new float[] { 1.5707964f, 0f, -2.0344439f, -1.5707964f };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);

Tensor t_result = tf.math.angle(t_complex);

NDArray n_result = t_result.numpy();

float[] d_result = n_result.ToArray<float>();

Assert.IsTrue(base.Equal(d_result, d_expected));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

<ItemGroup>
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
<ProjectReference Include="..\TensorFlowNET.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj" />
</ItemGroup>

</Project>

0 comments on commit 79a9363

Please sign in to comment.