Skip to content

Commit

Permalink
Support SIMD on all target frameworks (#329)
Browse files Browse the repository at this point in the history
  • Loading branch information
aalmada committed Feb 19, 2021
1 parent d764e81 commit 9bada1c
Show file tree
Hide file tree
Showing 26 changed files with 146 additions and 198 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ public void Range_With_ToList_Must_Succeed(int start, int count)
.BeEqualTo(expected);
}

#if NET5_0

[Theory]
[MemberData(nameof(TestData.Range), MemberType = typeof(TestData))]
public void Range_SelectVector_Must_Succeed(int start, int count)
Expand All @@ -152,7 +150,23 @@ public void Range_SelectVector_Must_Succeed(int start, int count)
.BeEqualTo(expected);
}



[Theory]
[MemberData(nameof(TestData.Range), MemberType = typeof(TestData))]
public void Range_SelectVector_Sum_Must_Succeed(int start, int count)
{
// Arrange
var expected = Enumerable.Sum(Enumerable.Select(Enumerable.Range(start, count), item => item * 2));

// Act
var result = ValueEnumerable.Range(start, count).SelectVector(item => item * 2, item => item * 2).Sum();

// Assert
_ = result.Must()
.BeEqualTo(expected);
}


[Theory]
[MemberData(nameof(TestData.Range), MemberType = typeof(TestData))]
public void Range_SelectVector_ToArray_Must_Succeed(int start, int count)
Expand Down Expand Up @@ -185,7 +199,5 @@ public void Range_SelectVector_ToList_Must_Succeed(int start, int count)
.BeEnumerableOf<int>()
.BeEqualTo(expected);
}

#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ namespace NetFabric.Hyperlinq
{
public static partial class TestData
{
#if NET5_0

public static TheoryData<int[], Func<Vector<int>, Vector<int>>, Func<int, int>> SelectVector =>
new TheoryData<int[], Func<Vector<int>, Vector<int>>, Func<int, int>>
Expand All @@ -26,6 +25,5 @@ public static partial class TestData
{ new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, item => item * 2, item => item * 2 },
};

#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ namespace NetFabric.Hyperlinq.UnitTests.Projection.SelectVector
{
public class ArraySegmentTests
{
#if NET5_0

[Theory]
[MemberData(nameof(TestData.SelectVector), MemberType = typeof(TestData))]
public void SelectVector_With_ValidData_Must_Succeed(int[] source, Func<Vector<int>, Vector<int>> vectorSelector, Func<int, int> selector)
Expand Down Expand Up @@ -48,7 +46,5 @@ public void SelectVector_ToArray_With_ValidData_Must_Succeed(int[] source, Func<
.BeArrayOf<int>()
.BeEqualTo(expected);
}

#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ namespace NetFabric.Hyperlinq.UnitTests.Projection.SelectVector
{
public class ReadOnlyMemoryTests
{
#if NET5_0

[Theory]
[MemberData(nameof(TestData.SelectVector), MemberType = typeof(TestData))]
public void SelectVector_With_ValidData_Must_Succeed(int[] source, Func<Vector<int>, Vector<int>> vectorSelector, Func<int, int> selector)
Expand Down Expand Up @@ -48,7 +46,5 @@ public void SelectVector_ToArray_With_ValidData_Must_Succeed(int[] source, Func<
.BeArrayOf<int>()
.BeEqualTo(expected);
}

#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ namespace NetFabric.Hyperlinq.UnitTests.Projection.SelectVector
public class ReadOnlySpanTests
{

#if NET5_0

[Theory]
[MemberData(nameof(TestData.SelectVector), MemberType = typeof(TestData))]
public void SelectVector_With_ValidData_Must_Succeed(int[] source, Func<Vector<int>, Vector<int>> vectorSelector, Func<int, int> selector)
Expand Down Expand Up @@ -47,6 +45,5 @@ public void SelectVector_ToArray_With_ValidData_Must_Succeed(int[] source, Func<
.BeEqualTo(expected);
}

#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ namespace NetFabric.Hyperlinq.UnitTests.Quantifier.ContainsVector
{
public class ReadOnlySpanTests
{
#if NET5_0

[Theory]
[MemberData(nameof(TestData.Empty), MemberType = typeof(TestData))]
Expand Down Expand Up @@ -43,6 +42,5 @@ public void ContainsVector_With_Contains_Must_ReturnTrue(int[] source)
.BeTrue();
}

#endif
}
}
17 changes: 7 additions & 10 deletions NetFabric.Hyperlinq/Aggregation/Sum/Sum.Range.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,33 @@ public static partial class ValueEnumerable
public static int SumRange(int start, int count)
=> count * (start + start + count) / 2;

#if NET5_0
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static TResult SumRange<TResult, TVectorSelector, TSelector>(int start, int count, TVectorSelector vectorSelector, TSelector selector)
public static unsafe TResult SumRange<TResult, TVectorSelector, TSelector>(int start, int count, TVectorSelector vectorSelector, TSelector selector)
where TVectorSelector : struct, IFunction<Vector<int>, Vector<TResult>>
where TSelector : struct, IFunction<int, TResult>
where TResult : struct
{
var sum = default(TResult);

var index = 0;

if (Vector.IsHardwareAccelerated && count >= Vector<TResult>.Count) // use SIMD
if (Vector.IsHardwareAccelerated && count > Vector<TResult>.Count * 4) // use SIMD
{
Span<int> seed = stackalloc int[Vector<TResult>.Count];
var seed = stackalloc int[Vector<TResult>.Count];
if (start is 0)
{
for (; index < seed.Length; index++)
for (index = 0; index < Vector<TResult>.Count; index++)
seed[index] = index;
}
else
{
for (; index < seed.Length; index++)
for (index = 0; index < Vector<TResult>.Count; index++)
seed[index] = index + start;
}

var vector = new Vector<int>(seed);
var vector = Unsafe.AsRef<Vector<int>>(seed);
var vectorIncrement = new Vector<int>(Vector<TResult>.Count);
var vectorSum = Vector<TResult>.Zero;
for (; index <= count - Vector<TResult>.Count; index += Vector<TResult>.Count)
for (index = 0; index <= count - Vector<TResult>.Count; index += Vector<TResult>.Count)
{
vectorSum += vectorSelector.Invoke(vector);
vector += vectorIncrement;
Expand All @@ -61,7 +59,6 @@ public static int SumRange(int start, int count)

return sum;
}
#endif
}
}

55 changes: 24 additions & 31 deletions NetFabric.Hyperlinq/Aggregation/Sum/Sum.ReadOnlySpan.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace NetFabric.Hyperlinq
{
Expand Down Expand Up @@ -53,38 +54,32 @@ static TSource Sum<TSource>(this ReadOnlySpan<TSource> source)
{
var sum = default(TSource);

#if NET5_0

if (Vector.IsHardwareAccelerated && source.Length >= Vector<TSource>.Count) // use SIMD
if (Vector.IsHardwareAccelerated && source.Length > Vector<TSource>.Count * 2) // use SIMD
{
var vector = Vector<TSource>.Zero;
var vectors = MemoryMarshal.Cast<TSource, Vector<TSource>>(source);
var vectorSum = Vector<TSource>.Zero;

foreach (var vector in vectors)
vectorSum += vector;

var index = 0;
for (; index <= source.Length - Vector<TSource>.Count; index += Vector<TSource>.Count)
vector += new Vector<TSource>(source.Slice(index, Vector<TSource>.Count));
for (var index = 0; index < Vector<TSource>.Count; index++)
sum = GenericsOperator.Add(vectorSum[index], sum);

for (; index < source.Length; index++)
for (var index = source.Length - (source.Length % Vector<TSource>.Count); index < source.Length; index++)
{
var item = source[index];
sum = GenericsOperator.Add(item, sum);
}

for (index = 0; index < Vector<TSource>.Count; index++)
sum = GenericsOperator.Add(vector[index], sum);

return sum;
}

#endif

foreach (var item in source)
sum = GenericsOperator.Add(item, sum);
else
{
foreach (var item in source)
sum = GenericsOperator.Add(item, sum);
}

return sum;
}

#if NET5_0

static TResult Sum<TSource, TResult, TVectorSelector, TSelector>(this ReadOnlySpan<TSource> source, TVectorSelector vectorSelector, TSelector selector)
where TVectorSelector : struct, IFunction<Vector<TSource>, Vector<TResult>>
where TSelector : struct, IFunction<TSource, TResult>
Expand All @@ -93,22 +88,22 @@ static TSource Sum<TSource>(this ReadOnlySpan<TSource> source)
{
var sum = default(TResult);

if (Vector.IsHardwareAccelerated && source.Length >= Vector<TResult>.Count) // use SIMD
if (Vector.IsHardwareAccelerated && source.Length > Vector<TResult>.Count * 2) // use SIMD
{
var vector = Vector<TResult>.Zero;
var vectors = MemoryMarshal.Cast<TSource, Vector<TSource>>(source);
var vectorSum = Vector<TResult>.Zero;

var index = 0;
for (; index <= source.Length - Vector<TResult>.Count; index += Vector<TResult>.Count)
vector += vectorSelector.Invoke(new Vector<TSource>(source.Slice(index, Vector<TResult>.Count)));
foreach (var vector in vectors)
vectorSum += vectorSelector.Invoke(vector);

for (; index < source.Length; index++)
for (var index = 0; index < Vector<TResult>.Count; index++)
sum = GenericsOperator.Add(vectorSum[index], sum);

for (var index = source.Length - (source.Length % Vector<TSource>.Count); index < source.Length; index++)
{
var item = source[index];
sum = GenericsOperator.Add(selector.Invoke(item), sum);
}

for (index = 0; index < Vector<TSource>.Count; index++)
sum = GenericsOperator.Add(vector[index], sum);
}
else
{
Expand All @@ -118,8 +113,6 @@ static TSource Sum<TSource>(this ReadOnlySpan<TSource> source)
return sum;
}

#endif

static TSum Sum<TSource, TSum>(this ReadOnlySpan<TSource> source)
where TSum : struct
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,6 @@ public static decimal Sum(this ValueEnumerable<decimal> source)
public static decimal Sum(this ValueEnumerable<decimal?> source)
=> source.source.AsSpan().Sum();

#if NET5_0

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool ContainsVector<TSource>(this ValueEnumerable<TSource> source, TSource value)
where TSource : struct
Expand All @@ -417,7 +415,5 @@ public static bool ContainsVector<TSource>(this ValueEnumerable<TSource> source,
where TSource : struct
where TResult : struct
=> new ArraySegment<TSource>(source.source).SelectVector<TSource, TResult, TVectorSelector, TSelector>(vectorSelector, selector);

#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,6 @@ public static decimal Sum(this ValueEnumerable<decimal> source)
public static decimal Sum(this ValueEnumerable<decimal?> source)
=> source.source.Span.Sum();

#if NET5_0

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool ContainsVector<TSource>(this ValueEnumerable<TSource> source, TSource value)
where TSource : struct
Expand All @@ -429,7 +427,5 @@ public static bool ContainsVector<TSource>(this ValueEnumerable<TSource> source,
where TSource : struct
where TResult : struct
=> source.source.SelectVector<TSource, TResult, TVectorSelector, TSelector>(vectorSelector, selector);

#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,6 @@ public static decimal Sum(this ValueEnumerable<decimal> source)
public static decimal Sum(this ValueEnumerable<decimal?> source)
=> source.source.Span.Sum();

#if NET5_0

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool ContainsVector<TSource>(this ValueEnumerable<TSource> source, TSource value)
where TSource : struct
Expand All @@ -429,7 +427,5 @@ public static bool ContainsVector<TSource>(this ValueEnumerable<TSource> source,
where TSource : struct
where TResult : struct
=> source.source.SelectVector<TSource, TResult, TVectorSelector, TSelector>(vectorSelector, selector);

#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,6 @@ public static decimal Sum(this ListValueEnumerable<decimal> source)
public static decimal Sum(this ListValueEnumerable<decimal?> source)
=> source.source.AsSpan().Sum();

#if NET5_0

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool ContainsVector<TSource>(this ListValueEnumerable<TSource> source, TSource value)
where TSource : struct
Expand All @@ -353,7 +351,5 @@ public static bool ContainsVector<TSource>(this ListValueEnumerable<TSource> sou
where TSource : struct
where TResult : struct
=> source.source.AsArraySegment().SelectVector<TSource, TResult, TVectorSelector, TSelector>(vectorSelector, selector);

#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,19 @@ public static IMemoryOwner<TSource> ToArray<TSource>(this ReadOnlySpan<TSource>
return result;
}

#if NET5_0
[MethodImpl(MethodImplOptions.AggressiveInlining)]
static TResult[] ToArrayVector<TSource, TResult, TVectorSelector, TSelector>(this ReadOnlySpan<TSource> source, TVectorSelector vectorSelector, TSelector selector)
where TVectorSelector : struct, IFunction<Vector<TSource>, Vector<TResult>>
where TSelector : struct, IFunction<TSource, TResult>
where TSource : struct
where TResult : struct
{
#if NET5_0
var result = GC.AllocateUninitializedArray<TResult>(source.Length);
#else
// ReSharper disable once HeapView.ObjectAllocation.Evident
var result = new TResult[source.Length];
#endif
CopyVector<TSource, TResult, TVectorSelector, TSelector>(source, result.AsSpan(), vectorSelector, selector);
return result;
}
Expand All @@ -127,7 +131,6 @@ public static IMemoryOwner<TSource> ToArray<TSource>(this ReadOnlySpan<TSource>
CopyVector<TSource, TResult, TVectorSelector, TSelector>(source, result.Memory.Span, vectorSelector, selector);
return result;
}
#endif

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static TResult[] ToArrayRef<TSource, TResult, TSelector>(this ReadOnlySpan<TSource> source, TSelector selector)
Expand Down
2 changes: 0 additions & 2 deletions NetFabric.Hyperlinq/Conversion/ToList/ToList.ReadOnlySpan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public static List<TSource> ToList<TSource>(this ReadOnlySpan<TSource> source)
_ => source.ToArray<TSource, TResult, TSelector>(selector).AsList()
};

#if NET5_0
[MethodImpl(MethodImplOptions.AggressiveInlining)]
static List<TResult> ToListVector<TSource, TResult, TVectorSelector, TSelector>(this ReadOnlySpan<TSource> source, TVectorSelector vectorSelector, TSelector selector)
where TVectorSelector : struct, IFunction<Vector<TSource>, Vector<TResult>>
Expand All @@ -75,7 +74,6 @@ public static List<TSource> ToList<TSource>(this ReadOnlySpan<TSource> source)
0 => new List<TResult>(),
_ => source.ToArrayVector<TSource, TResult, TVectorSelector, TSelector>(vectorSelector, selector).AsList()
};
#endif

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static List<TResult> ToListRef<TSource, TResult, TSelector>(this ReadOnlySpan<TSource> source, TSelector selector)
Expand Down
Loading

0 comments on commit 9bada1c

Please sign in to comment.