From 31f23565ee923aa60501cbb6b4bb821ebd20d082 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 20 Jan 2020 16:28:51 -0600 Subject: [PATCH 1/3] mask setter --- src/NumSharp.Core/Selection/NDArray.Indexing.Masking.cs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/NumSharp.Core/Selection/NDArray.Indexing.Masking.cs b/src/NumSharp.Core/Selection/NDArray.Indexing.Masking.cs index d0b4f9ae..3b7f3eff 100644 --- a/src/NumSharp.Core/Selection/NDArray.Indexing.Masking.cs +++ b/src/NumSharp.Core/Selection/NDArray.Indexing.Masking.cs @@ -23,7 +23,11 @@ public NDArray this[NDArray mask] get => FetchIndices(this, np.nonzero(mask), null, true); set { - throw new NotImplementedException("Setter is not implemented yet"); + for (int i = 0; i < mask.size; i++) + { + if (mask.GetBoolean(i)) + this[i] = value; + } } } } From e1b9e27301677558bdbe0508ac8890ed0c82f015 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 20 Jan 2020 17:00:28 -0600 Subject: [PATCH 2/3] add MaskSetter2D test --- .../NumSharp.UnitTest/Selection/NDArray.Indexing.Test.cs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/NumSharp.UnitTest/Selection/NDArray.Indexing.Test.cs b/test/NumSharp.UnitTest/Selection/NDArray.Indexing.Test.cs index 0bf41579..344684a8 100644 --- a/test/NumSharp.UnitTest/Selection/NDArray.Indexing.Test.cs +++ b/test/NumSharp.UnitTest/Selection/NDArray.Indexing.Test.cs @@ -74,6 +74,15 @@ public void MaskSetter() nd.Should().BeOfValues(-2, 2, -2, 4, -2, 6); } + [TestMethod] + public void MaskSetter2D() + { + var nd = np.arange(15).reshape(5, 3); + var mask = new NDArray(new bool[] { true, false, true, false, true }).MakeGeneric(); + nd[mask] = 99; + nd.Should().BeOfValues(99, 99, 99, 3, 4, 5, 99, 99, 99, 9, 10, 11, 99, 99, 99); + } + [TestMethod] public void Compare() { From 04964ef284e2238b1a02d8ffa2eebb55fab311b2 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 20 Jan 2020 17:09:32 -0600 Subject: [PATCH 3/3] only work for 1d mask. --- .../Selection/NDArray.Indexing.Masking.cs | 13 ++++++++++--- .../Selection/NDArray.Indexing.Test.cs | 9 +++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/NumSharp.Core/Selection/NDArray.Indexing.Masking.cs b/src/NumSharp.Core/Selection/NDArray.Indexing.Masking.cs index 3b7f3eff..c62eddfc 100644 --- a/src/NumSharp.Core/Selection/NDArray.Indexing.Masking.cs +++ b/src/NumSharp.Core/Selection/NDArray.Indexing.Masking.cs @@ -23,10 +23,17 @@ public NDArray this[NDArray mask] get => FetchIndices(this, np.nonzero(mask), null, true); set { - for (int i = 0; i < mask.size; i++) + if(mask.ndim == 1) { - if (mask.GetBoolean(i)) - this[i] = value; + for (int i = 0; i < mask.size; i++) + { + if (mask.GetBoolean(i)) + this[i] = value; + } + } + else + { + throw new NotImplementedException("Setter is not implemented yet"); } } } diff --git a/test/NumSharp.UnitTest/Selection/NDArray.Indexing.Test.cs b/test/NumSharp.UnitTest/Selection/NDArray.Indexing.Test.cs index 344684a8..de6a6615 100644 --- a/test/NumSharp.UnitTest/Selection/NDArray.Indexing.Test.cs +++ b/test/NumSharp.UnitTest/Selection/NDArray.Indexing.Test.cs @@ -83,6 +83,15 @@ public void MaskSetter2D() nd.Should().BeOfValues(99, 99, 99, 3, 4, 5, 99, 99, 99, 9, 10, 11, 99, 99, 99); } + [Ignore("to do fix")] + [TestMethod] + public void MaskSetter3D() + { + var nd = np.arange(30).reshape(2, 3, 5); + var mask = new NDArray(new bool[] { true, true, false, false, true, true }).reshape(2, 3).MakeGeneric(); + nd[mask] = 99; + } + [TestMethod] public void Compare() {