diff --git a/src/System.Linq/src/System.Linq.csproj b/src/System.Linq/src/System.Linq.csproj index a455e79a128a..a8414436c43a 100644 --- a/src/System.Linq/src/System.Linq.csproj +++ b/src/System.Linq/src/System.Linq.csproj @@ -1,4 +1,4 @@ - + @@ -34,8 +34,49 @@ System\Diagnostics\CodeAnalysis\ExcludeFromCodeCoverageAttribute.cs + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/System.Linq/src/System/Linq/Aggregate.cs b/src/System.Linq/src/System/Linq/Aggregate.cs new file mode 100644 index 000000000000..924a49c035cf --- /dev/null +++ b/src/System.Linq/src/System/Linq/Aggregate.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static TSource Aggregate(this IEnumerable source, Func func) + { + if (source == null) throw Error.ArgumentNull("source"); + if (func == null) throw Error.ArgumentNull("func"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + TSource result = e.Current; + while (e.MoveNext()) result = func(result, e.Current); + return result; + } + } + + public static TAccumulate Aggregate(this IEnumerable source, TAccumulate seed, Func func) + { + if (source == null) throw Error.ArgumentNull("source"); + if (func == null) throw Error.ArgumentNull("func"); + TAccumulate result = seed; + foreach (TSource element in source) result = func(result, element); + return result; + } + + public static TResult Aggregate(this IEnumerable source, TAccumulate seed, Func func, Func resultSelector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (func == null) throw Error.ArgumentNull("func"); + if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); + TAccumulate result = seed; + foreach (TSource element in source) result = func(result, element); + return resultSelector(result); + } + } +} diff --git a/src/System.Linq/src/System/Linq/AnyAll.cs b/src/System.Linq/src/System/Linq/AnyAll.cs new file mode 100644 index 000000000000..376d6f396b17 --- /dev/null +++ b/src/System.Linq/src/System/Linq/AnyAll.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static bool Any(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + return e.MoveNext(); + } + } + + public static bool Any(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + foreach (TSource element in source) + { + if (predicate(element)) return true; + } + return false; + } + + public static bool All(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + foreach (TSource element in source) + { + if (!predicate(element)) return false; + } + return true; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Average.cs b/src/System.Linq/src/System/Linq/Average.cs new file mode 100644 index 000000000000..eab1a8e4265a --- /dev/null +++ b/src/System.Linq/src/System/Linq/Average.cs @@ -0,0 +1,514 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static double Average(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + long sum = e.Current; + long count = 1; + checked + { + while (e.MoveNext()) + { + sum += e.Current; + ++count; + } + } + return (double)sum / count; + } + } + + public static double? Average(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + int? v = e.Current; + if (v.HasValue) + { + long sum = v.GetValueOrDefault(); + long count = 1; + checked + { + while (e.MoveNext()) + { + v = e.Current; + if (v.HasValue) + { + sum += v.GetValueOrDefault(); + ++count; + } + } + } + return (double)sum / count; + } + } + } + return null; + } + + public static double Average(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + long sum = e.Current; + long count = 1; + checked + { + while (e.MoveNext()) + { + sum += e.Current; + ++count; + } + } + return (double)sum / count; + } + } + + public static double? Average(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + long? v = e.Current; + if (v.HasValue) + { + long sum = v.GetValueOrDefault(); + long count = 1; + checked + { + while (e.MoveNext()) + { + v = e.Current; + if (v.HasValue) + { + sum += v.GetValueOrDefault(); + ++count; + } + } + } + return (double)sum / count; + } + } + } + return null; + } + + public static float Average(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + double sum = e.Current; + long count = 1; + while (e.MoveNext()) + { + sum += e.Current; + ++count; + } + return (float)(sum / count); + } + } + + public static float? Average(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + float? v = e.Current; + if (v.HasValue) + { + double sum = v.GetValueOrDefault(); + long count = 1; + checked + { + while (e.MoveNext()) + { + v = e.Current; + if (v.HasValue) + { + sum += v.GetValueOrDefault(); + ++count; + } + } + } + return (float)(sum / count); + } + } + } + return null; + } + + public static double Average(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + double sum = e.Current; + long count = 1; + while (e.MoveNext()) + { + // There is an opportunity to short-circuit here, in that if e.Current is + // ever NaN then the result will always be NaN. Assuming that this case is + // rare enough that not checking is the better approach generally. + sum += e.Current; + ++count; + } + return sum / count; + } + } + + public static double? Average(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + double? v = e.Current; + if (v.HasValue) + { + double sum = v.GetValueOrDefault(); + long count = 1; + checked + { + while (e.MoveNext()) + { + v = e.Current; + if (v.HasValue) + { + sum += v.GetValueOrDefault(); + ++count; + } + } + } + return sum / count; + } + } + } + return null; + } + + public static decimal Average(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + decimal sum = e.Current; + long count = 1; + while (e.MoveNext()) + { + sum += e.Current; + ++count; + } + return sum / count; + } + } + + public static decimal? Average(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + decimal? v = e.Current; + if (v.HasValue) + { + decimal sum = v.GetValueOrDefault(); + long count = 1; + while (e.MoveNext()) + { + v = e.Current; + if (v.HasValue) + { + sum += v.GetValueOrDefault(); + ++count; + } + } + return sum / count; + } + } + } + return null; + } + + public static double Average(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + long sum = selector(e.Current); + long count = 1; + checked + { + while (e.MoveNext()) + { + sum += selector(e.Current); + ++count; + } + } + return (double)sum / count; + } + } + + public static double? Average(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + int? v = selector(e.Current); + if (v.HasValue) + { + long sum = v.GetValueOrDefault(); + long count = 1; + checked + { + while (e.MoveNext()) + { + v = selector(e.Current); + if (v.HasValue) + { + sum += v.GetValueOrDefault(); + ++count; + } + } + } + return (double)sum / count; + } + } + } + return null; + } + + public static double Average(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + long sum = selector(e.Current); + long count = 1; + checked + { + while (e.MoveNext()) + { + sum += selector(e.Current); + ++count; + } + } + return (double)sum / count; + } + } + + public static double? Average(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + long? v = selector(e.Current); + if (v.HasValue) + { + long sum = v.GetValueOrDefault(); + long count = 1; + checked + { + while (e.MoveNext()) + { + v = selector(e.Current); + if (v.HasValue) + { + sum += v.GetValueOrDefault(); + ++count; + } + } + } + return (double)sum / count; + } + } + } + return null; + } + + public static float Average(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + double sum = selector(e.Current); + long count = 1; + while (e.MoveNext()) + { + sum += selector(e.Current); + ++count; + } + return (float)(sum / count); + } + } + + public static float? Average(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + float? v = selector(e.Current); + if (v.HasValue) + { + double sum = v.GetValueOrDefault(); + long count = 1; + checked + { + while (e.MoveNext()) + { + v = selector(e.Current); + if (v.HasValue) + { + sum += v.GetValueOrDefault(); + ++count; + } + } + } + return (float)(sum / count); + } + } + } + return null; + } + + public static double Average(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + double sum = selector(e.Current); + long count = 1; + while (e.MoveNext()) + { + // There is an opportunity to short-circuit here, in that if e.Current is + // ever NaN then the result will always be NaN. Assuming that this case is + // rare enough that not checking is the better approach generally. + sum += selector(e.Current); + ++count; + } + return sum / count; + } + } + + public static double? Average(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + double? v = selector(e.Current); + if (v.HasValue) + { + double sum = v.GetValueOrDefault(); + long count = 1; + checked + { + while (e.MoveNext()) + { + v = selector(e.Current); + if (v.HasValue) + { + sum += v.GetValueOrDefault(); + ++count; + } + } + } + return sum / count; + } + } + } + return null; + } + + public static decimal Average(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + decimal sum = selector(e.Current); + long count = 1; + while (e.MoveNext()) + { + sum += selector(e.Current); + ++count; + } + return sum / count; + } + } + + public static decimal? Average(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + decimal? v = selector(e.Current); + if (v.HasValue) + { + decimal sum = v.GetValueOrDefault(); + long count = 1; + while (e.MoveNext()) + { + v = selector(e.Current); + if (v.HasValue) + { + sum += v.GetValueOrDefault(); + ++count; + } + } + return sum / count; + } + } + } + return null; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Buffer.cs b/src/System.Linq/src/System/Linq/Buffer.cs new file mode 100644 index 000000000000..738ef2bc2fae --- /dev/null +++ b/src/System.Linq/src/System/Linq/Buffer.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + internal struct Buffer + { + internal TElement[] items; + internal int count; + + internal Buffer(IEnumerable source) + { + IArrayProvider iterator = source as IArrayProvider; + if (iterator != null) + { + TElement[] array = iterator.ToArray(); + items = array; + count = array.Length; + } + else + { + items = EnumerableHelpers.ToArray(source, out count); + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Cast.cs b/src/System.Linq/src/System/Linq/Cast.cs new file mode 100644 index 000000000000..daa61ddf0c27 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Cast.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable OfType(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + return OfTypeIterator(source); + } + + private static IEnumerable OfTypeIterator(IEnumerable source) + { + foreach (object obj in source) + { + if (obj is TResult) yield return (TResult)obj; + } + } + + public static IEnumerable Cast(this IEnumerable source) + { + IEnumerable typedSource = source as IEnumerable; + if (typedSource != null) return typedSource; + if (source == null) throw Error.ArgumentNull("source"); + return CastIterator(source); + } + + private static IEnumerable CastIterator(IEnumerable source) + { + foreach (object obj in source) yield return (TResult)obj; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Concatenate.cs b/src/System.Linq/src/System/Linq/Concatenate.cs new file mode 100644 index 000000000000..5de90b43d174 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Concatenate.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Concat(this IEnumerable first, IEnumerable second) + { + if (first == null) throw Error.ArgumentNull("first"); + if (second == null) throw Error.ArgumentNull("second"); + return ConcatIterator(first, second); + } + + private static IEnumerable ConcatIterator(IEnumerable first, IEnumerable second) + { + foreach (TSource element in first) yield return element; + foreach (TSource element in second) yield return element; + } + + public static IEnumerable Append(this IEnumerable source, TSource element) + { + if (source == null) throw Error.ArgumentNull("source"); + return AppendIterator(source, element); + } + + private static IEnumerable AppendIterator(IEnumerable source, TSource element) + { + foreach (TSource e1 in source) yield return e1; + yield return element; + } + + public static IEnumerable Prepend(this IEnumerable source, TSource element) + { + if (source == null) throw Error.ArgumentNull("source"); + return PrependIterator(source, element); + } + + private static IEnumerable PrependIterator(IEnumerable source, TSource element) + { + yield return element; + foreach (TSource e1 in source) yield return e1; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Contains.cs b/src/System.Linq/src/System/Linq/Contains.cs new file mode 100644 index 000000000000..7dbff25b2160 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Contains.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static bool Contains(this IEnumerable source, TSource value) + { + ICollection collection = source as ICollection; + if (collection != null) return collection.Contains(value); + return Contains(source, value, null); + } + + public static bool Contains(this IEnumerable source, TSource value, IEqualityComparer comparer) + { + if (comparer == null) comparer = EqualityComparer.Default; + if (source == null) throw Error.ArgumentNull("source"); + foreach (TSource element in source) + if (comparer.Equals(element, value)) return true; + return false; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Count.cs b/src/System.Linq/src/System/Linq/Count.cs new file mode 100644 index 000000000000..d97dfa947546 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Count.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static int Count(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + ICollection collectionoft = source as ICollection; + if (collectionoft != null) return collectionoft.Count; + ICollection collection = source as ICollection; + if (collection != null) return collection.Count; + int count = 0; + using (IEnumerator e = source.GetEnumerator()) + { + checked + { + while (e.MoveNext()) count++; + } + } + return count; + } + + public static int Count(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + int count = 0; + foreach (TSource element in source) + { + checked + { + if (predicate(element)) count++; + } + } + return count; + } + + public static long LongCount(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + long count = 0; + using (IEnumerator e = source.GetEnumerator()) + { + checked + { + while (e.MoveNext()) count++; + } + } + return count; + } + + public static long LongCount(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + long count = 0; + foreach (TSource element in source) + { + checked + { + if (predicate(element)) count++; + } + } + return count; + } + } +} diff --git a/src/System.Linq/src/System/Linq/DebugView.cs b/src/System.Linq/src/System/Linq/DebugView.cs new file mode 100644 index 000000000000..af128bc4454e --- /dev/null +++ b/src/System.Linq/src/System/Linq/DebugView.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Linq +{ + // NOTE: DO NOT DELETE THE FOLLOWING DEBUG VIEW TYPES. + // Although it might be tempting due to them not be referenced anywhere in this library, + // Visual Studio currently depends on their existence to enable the "Results" view in + // watch windows. + + /// + /// This class provides the items view for the Enumerable + /// + /// + internal sealed class SystemCore_EnumerableDebugView + { + public SystemCore_EnumerableDebugView(IEnumerable enumerable) + { + if (enumerable == null) + { + throw new ArgumentNullException("enumerable"); + } + + _enumerable = enumerable; + } + + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public T[] Items + { + get + { + T[] array = _enumerable.ToArray(); + if (array.Length == 0) + { + throw new SystemCore_EnumerableDebugViewEmptyException(); + } + return array; + } + } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private IEnumerable _enumerable; + } + + internal sealed class SystemCore_EnumerableDebugViewEmptyException : Exception + { + public string Empty + { + get + { + return SR.EmptyEnumerable; + } + } + } + + internal sealed class SystemCore_EnumerableDebugView + { + public SystemCore_EnumerableDebugView(IEnumerable enumerable) + { + if (enumerable == null) + { + throw new ArgumentNullException("enumerable"); + } + + _enumerable = enumerable; + } + + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public object[] Items + { + get + { + List tempList = new List(); + foreach (object item in _enumerable) + tempList.Add(item); + + if (tempList.Count == 0) + { + throw new SystemCore_EnumerableDebugViewEmptyException(); + } + return tempList.ToArray(); + } + } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private IEnumerable _enumerable; + } +} diff --git a/src/System.Linq/src/System/Linq/DefaultIfEmpty.cs b/src/System.Linq/src/System/Linq/DefaultIfEmpty.cs new file mode 100644 index 000000000000..abc644928415 --- /dev/null +++ b/src/System.Linq/src/System/Linq/DefaultIfEmpty.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable DefaultIfEmpty(this IEnumerable source) + { + return DefaultIfEmpty(source, default(TSource)); + } + + public static IEnumerable DefaultIfEmpty(this IEnumerable source, TSource defaultValue) + { + if (source == null) throw Error.ArgumentNull("source"); + return DefaultIfEmptyIterator(source, defaultValue); + } + + private static IEnumerable DefaultIfEmptyIterator(IEnumerable source, TSource defaultValue) + { + using (IEnumerator e = source.GetEnumerator()) + { + if (e.MoveNext()) + { + do + { + yield return e.Current; + } while (e.MoveNext()); + } + else + { + yield return defaultValue; + } + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Distinct.cs b/src/System.Linq/src/System/Linq/Distinct.cs new file mode 100644 index 000000000000..7ef74b9f4680 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Distinct.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Distinct(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + return DistinctIterator(source, null); + } + + public static IEnumerable Distinct(this IEnumerable source, IEqualityComparer comparer) + { + if (source == null) throw Error.ArgumentNull("source"); + return DistinctIterator(source, comparer); + } + + private static IEnumerable DistinctIterator(IEnumerable source, IEqualityComparer comparer) + { + Set set = new Set(comparer); + foreach (TSource element in source) + if (set.Add(element)) yield return element; + } + } +} diff --git a/src/System.Linq/src/System/Linq/ElementAt.cs b/src/System.Linq/src/System/Linq/ElementAt.cs new file mode 100644 index 000000000000..5e63180bd0b5 --- /dev/null +++ b/src/System.Linq/src/System/Linq/ElementAt.cs @@ -0,0 +1,60 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static TSource ElementAt(this IEnumerable source, int index) + { + if (source == null) throw Error.ArgumentNull("source"); + IPartition partition = source as IPartition; + if (partition != null) return partition.ElementAt(index); + IList list = source as IList; + if (list != null) return list[index]; + if (index >= 0) + { + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + if (index == 0) return e.Current; + index--; + } + } + } + throw Error.ArgumentOutOfRange("index"); + } + + public static TSource ElementAtOrDefault(this IEnumerable source, int index) + { + if (source == null) throw Error.ArgumentNull("source"); + IPartition partition = source as IPartition; + if (partition != null) return partition.ElementAtOrDefault(index); + if (index >= 0) + { + IList list = source as IList; + if (list != null) + { + if (index < list.Count) return list[index]; + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + if (index == 0) return e.Current; + index--; + } + } + } + } + return default(TSource); + } + } +} diff --git a/src/System.Linq/src/System/Linq/Enumerable.cs b/src/System.Linq/src/System/Linq/Enumerable.cs index c9fdc2d2a1a5..519143d7da0b 100644 --- a/src/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/System.Linq/src/System/Linq/Enumerable.cs @@ -3,5597 +3,20 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections; using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Threading; namespace System.Linq { - public static class Enumerable + public static partial class Enumerable { - public static IEnumerable Where(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - Iterator iterator = source as Iterator; - if (iterator != null) return iterator.Where(predicate); - TSource[] array = source as TSource[]; - if (array != null) return new WhereArrayIterator(array, predicate); - List list = source as List; - if (list != null) return new WhereListIterator(list, predicate); - return new WhereEnumerableIterator(source, predicate); - } - - public static IEnumerable Where(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - return WhereIterator(source, predicate); - } - - private static IEnumerable WhereIterator(IEnumerable source, Func predicate) - { - int index = -1; - foreach (TSource element in source) - { - checked { index++; } - if (predicate(element, index)) yield return element; - } - } - - public static IEnumerable Select(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - Iterator iterator = source as Iterator; - if (iterator != null) return iterator.Select(selector); - IList ilist = source as IList; - if (ilist != null) - { - TSource[] array = source as TSource[]; - if (array != null) return new SelectArrayIterator(array, selector); - List list = source as List; - if (list != null) return new SelectListIterator(list, selector); - return new SelectIListIterator(ilist, selector); - } - return new SelectEnumerableIterator(source, selector); - } - - public static IEnumerable Select(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - return SelectIterator(source, selector); - } - - private static IEnumerable SelectIterator(IEnumerable source, Func selector) - { - int index = -1; - foreach (TSource element in source) - { - checked { index++; } - yield return selector(element, index); - } - } - - private static Func CombinePredicates(Func predicate1, Func predicate2) - { - return x => predicate1(x) && predicate2(x); - } - - private static Func CombineSelectors(Func selector1, Func selector2) - { - return x => selector2(selector1(x)); - } - - internal abstract class Iterator : IEnumerable, IEnumerator - { - private int _threadId; - internal int state; - internal TSource current; - - public Iterator() - { - _threadId = Environment.CurrentManagedThreadId; - } - - public TSource Current - { - get { return current; } - } - - public abstract Iterator Clone(); - - public virtual void Dispose() - { - current = default(TSource); - state = -1; - } - - public IEnumerator GetEnumerator() - { - Iterator enumerator = state == 0 && _threadId == Environment.CurrentManagedThreadId ? this : Clone(); - enumerator.state = 1; - return enumerator; - } - - public abstract bool MoveNext(); - - public virtual IEnumerable Select(Func selector) - { - return new SelectEnumerableIterator(this, selector); - } - - public virtual IEnumerable Where(Func predicate) - { - return new WhereEnumerableIterator(this, predicate); - } - - object IEnumerator.Current - { - get { return Current; } - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - - void IEnumerator.Reset() - { - throw Error.NotSupported(); - } - } - - internal class WhereEnumerableIterator : Iterator - { - private readonly IEnumerable _source; - private readonly Func _predicate; - private IEnumerator _enumerator; - - public WhereEnumerableIterator(IEnumerable source, Func predicate) - { - Debug.Assert(source != null); - Debug.Assert(predicate != null); - _source = source; - _predicate = predicate; - } - - public override Iterator Clone() - { - return new WhereEnumerableIterator(_source, _predicate); - } - - public override void Dispose() - { - if (_enumerator != null) - { - _enumerator.Dispose(); - _enumerator = null; - } - base.Dispose(); - } - - public override bool MoveNext() - { - switch (state) - { - case 1: - _enumerator = _source.GetEnumerator(); - state = 2; - goto case 2; - case 2: - while (_enumerator.MoveNext()) - { - TSource item = _enumerator.Current; - if (_predicate(item)) - { - current = item; - return true; - } - } - Dispose(); - break; - } - return false; - } - - public override IEnumerable Select(Func selector) - { - return new WhereSelectEnumerableIterator(_source, _predicate, selector); - } - - public override IEnumerable Where(Func predicate) - { - return new WhereEnumerableIterator(_source, CombinePredicates(_predicate, predicate)); - } - } - - internal class WhereArrayIterator : Iterator - { - private readonly TSource[] _source; - private readonly Func _predicate; - private int _index; - - public WhereArrayIterator(TSource[] source, Func predicate) - { - Debug.Assert(source != null); - Debug.Assert(predicate != null); - _source = source; - _predicate = predicate; - } - - public override Iterator Clone() - { - return new WhereArrayIterator(_source, _predicate); - } - - public override bool MoveNext() - { - if (state == 1) - { - while (_index < _source.Length) - { - TSource item = _source[_index]; - _index++; - if (_predicate(item)) - { - current = item; - return true; - } - } - Dispose(); - } - return false; - } - - public override IEnumerable Select(Func selector) - { - return new WhereSelectArrayIterator(_source, _predicate, selector); - } - - public override IEnumerable Where(Func predicate) - { - return new WhereArrayIterator(_source, CombinePredicates(_predicate, predicate)); - } - } - - internal class WhereListIterator : Iterator - { - private readonly List _source; - private readonly Func _predicate; - private List.Enumerator _enumerator; - - public WhereListIterator(List source, Func predicate) - { - Debug.Assert(source != null); - Debug.Assert(predicate != null); - _source = source; - _predicate = predicate; - } - - public override Iterator Clone() - { - return new WhereListIterator(_source, _predicate); - } - - public override bool MoveNext() - { - switch (state) - { - case 1: - _enumerator = _source.GetEnumerator(); - state = 2; - goto case 2; - case 2: - while (_enumerator.MoveNext()) - { - TSource item = _enumerator.Current; - if (_predicate(item)) - { - current = item; - return true; - } - } - Dispose(); - break; - } - return false; - } - - public override IEnumerable Select(Func selector) - { - return new WhereSelectListIterator(_source, _predicate, selector); - } - - public override IEnumerable Where(Func predicate) - { - return new WhereListIterator(_source, CombinePredicates(_predicate, predicate)); - } - } - - internal class WhereSelectEnumerableIterator : Iterator - { - private readonly IEnumerable _source; - private readonly Func _predicate; - private readonly Func _selector; - private IEnumerator _enumerator; - - public WhereSelectEnumerableIterator(IEnumerable source, Func predicate, Func selector) - { - Debug.Assert(source != null); - Debug.Assert(predicate != null); - Debug.Assert(selector != null); - _source = source; - _predicate = predicate; - _selector = selector; - } - - public override Iterator Clone() - { - return new WhereSelectEnumerableIterator(_source, _predicate, _selector); - } - - public override void Dispose() - { - if (_enumerator != null) - { - _enumerator.Dispose(); - _enumerator = null; - } - base.Dispose(); - } - - public override bool MoveNext() - { - switch (state) - { - case 1: - _enumerator = _source.GetEnumerator(); - state = 2; - goto case 2; - case 2: - while (_enumerator.MoveNext()) - { - TSource item = _enumerator.Current; - if (_predicate(item)) - { - current = _selector(item); - return true; - } - } - Dispose(); - break; - } - return false; - } - - public override IEnumerable Select(Func selector) - { - return new WhereSelectEnumerableIterator(_source, _predicate, CombineSelectors(_selector, selector)); - } - } - - internal class WhereSelectArrayIterator : Iterator - { - private readonly TSource[] _source; - private readonly Func _predicate; - private readonly Func _selector; - private int _index; - - public WhereSelectArrayIterator(TSource[] source, Func predicate, Func selector) - { - Debug.Assert(source != null); - Debug.Assert(predicate != null); - Debug.Assert(selector != null); - _source = source; - _predicate = predicate; - _selector = selector; - } - - public override Iterator Clone() - { - return new WhereSelectArrayIterator(_source, _predicate, _selector); - } - - public override bool MoveNext() - { - if (state == 1) - { - while (_index < _source.Length) - { - TSource item = _source[_index]; - _index++; - if (_predicate(item)) - { - current = _selector(item); - return true; - } - } - Dispose(); - } - return false; - } - - public override IEnumerable Select(Func selector) - { - return new WhereSelectArrayIterator(_source, _predicate, CombineSelectors(_selector, selector)); - } - } - - internal class WhereSelectListIterator : Iterator - { - private readonly List _source; - private readonly Func _predicate; - private readonly Func _selector; - private List.Enumerator _enumerator; - - public WhereSelectListIterator(List source, Func predicate, Func selector) - { - Debug.Assert(source != null); - Debug.Assert(predicate != null); - Debug.Assert(selector != null); - _source = source; - _predicate = predicate; - _selector = selector; - } - - public override Iterator Clone() - { - return new WhereSelectListIterator(_source, _predicate, _selector); - } - - public override bool MoveNext() - { - switch (state) - { - case 1: - _enumerator = _source.GetEnumerator(); - state = 2; - goto case 2; - case 2: - while (_enumerator.MoveNext()) - { - TSource item = _enumerator.Current; - if (_predicate(item)) - { - current = _selector(item); - return true; - } - } - Dispose(); - break; - } - return false; - } - - public override IEnumerable Select(Func selector) - { - return new WhereSelectListIterator(_source, _predicate, CombineSelectors(_selector, selector)); - } - } - - internal sealed class SelectEnumerableIterator : Iterator - { - private readonly IEnumerable _source; - private readonly Func _selector; - private IEnumerator _enumerator; - - public SelectEnumerableIterator(IEnumerable source, Func selector) - { - Debug.Assert(source != null); - Debug.Assert(selector != null); - _source = source; - _selector = selector; - } - - public override Iterator Clone() - { - return new SelectEnumerableIterator(_source, _selector); - } - - public override void Dispose() - { - if (_enumerator != null) - { - _enumerator.Dispose(); - _enumerator = null; - } - base.Dispose(); - } - - public override bool MoveNext() - { - switch (state) - { - case 1: - _enumerator = _source.GetEnumerator(); - state = 2; - goto case 2; - case 2: - if (_enumerator.MoveNext()) - { - current = _selector(_enumerator.Current); - return true; - } - Dispose(); - break; - } - return false; - } - - public override IEnumerable Select(Func selector) - { - return new SelectEnumerableIterator(_source, CombineSelectors(_selector, selector)); - } - } - - - internal sealed class SelectArrayIterator : Iterator, IArrayProvider, IListProvider - { - private readonly TSource[] _source; - private readonly Func _selector; - private int _index; - - public SelectArrayIterator(TSource[] source, Func selector) - { - Debug.Assert(source != null); - Debug.Assert(selector != null); - _source = source; - _selector = selector; - } - - public override Iterator Clone() - { - return new SelectArrayIterator(_source, _selector); - } - - public override bool MoveNext() - { - if (state == 1 && _index < _source.Length) - { - current = _selector(_source[_index++]); - return true; - } - Dispose(); - return false; - } - - public override IEnumerable Select(Func selector) - { - return new SelectArrayIterator(_source, CombineSelectors(_selector, selector)); - } - - public TResult[] ToArray() - { - if (_source.Length == 0) - { - return Array.Empty(); - } - - var results = new TResult[_source.Length]; - for (int i = 0; i < results.Length; i++) - { - results[i] = _selector(_source[i]); - } - return results; - } - - public List ToList() - { - TSource[] source = _source; - var results = new List(source.Length); - for (int i = 0; i < source.Length; i++) - { - results.Add(_selector(source[i])); - } - return results; - } - } - - internal sealed class SelectListIterator : Iterator, IArrayProvider, IListProvider - { - private readonly List _source; - private readonly Func _selector; - private List.Enumerator _enumerator; - - public SelectListIterator(List source, Func selector) - { - Debug.Assert(source != null); - Debug.Assert(selector != null); - _source = source; - _selector = selector; - } - - public override Iterator Clone() - { - return new SelectListIterator(_source, _selector); - } - - public override bool MoveNext() - { - switch (state) - { - case 1: - _enumerator = _source.GetEnumerator(); - state = 2; - goto case 2; - case 2: - if (_enumerator.MoveNext()) - { - current = _selector(_enumerator.Current); - return true; - } - Dispose(); - break; - } - return false; - } - - public override IEnumerable Select(Func selector) - { - return new SelectListIterator(_source, CombineSelectors(_selector, selector)); - } - - public TResult[] ToArray() - { - int count = _source.Count; - if (count == 0) - { - return Array.Empty(); - } - - var results = new TResult[count]; - for (int i = 0; i < results.Length; i++) - { - results[i] = _selector(_source[i]); - } - return results; - } - - public List ToList() - { - int count = _source.Count; - var results = new List(count); - for (int i = 0; i < count; i++) - { - results.Add(_selector(_source[i])); - } - return results; - } - } - - internal sealed class SelectIListIterator : Iterator, IArrayProvider, IListProvider - { - private readonly IList _source; - private readonly Func _selector; - private IEnumerator _enumerator; - - public SelectIListIterator(IList source, Func selector) - { - Debug.Assert(source != null); - Debug.Assert(selector != null); - _source = source; - _selector = selector; - } - - public override Iterator Clone() - { - return new SelectIListIterator(_source, _selector); - } - - public override bool MoveNext() - { - switch (state) - { - case 1: - _enumerator = _source.GetEnumerator(); - state = 2; - goto case 2; - case 2: - if (_enumerator.MoveNext()) - { - current = _selector(_enumerator.Current); - return true; - } - Dispose(); - break; - } - return false; - } - - public override void Dispose() - { - if (_enumerator != null) - { - _enumerator.Dispose(); - _enumerator = null; - } - base.Dispose(); - } - - public override IEnumerable Select(Func selector) - { - return new SelectIListIterator(_source, CombineSelectors(_selector, selector)); - } - - public TResult[] ToArray() - { - int count = _source.Count; - if (count == 0) - { - return Array.Empty(); - } - - var results = new TResult[count]; - for (int i = 0; i < results.Length; i++) - { - results[i] = _selector(_source[i]); - } - return results; - } - - public List ToList() - { - int count = _source.Count; - var results = new List(count); - for (int i = 0; i < count; i++) - { - results.Add(_selector(_source[i])); - } - return results; - } - } - - //public static IEnumerable Where(this IEnumerable source, Func predicate) { - // if (source == null) throw Error.ArgumentNull("source"); - // if (predicate == null) throw Error.ArgumentNull("predicate"); - // return WhereIterator(source, predicate); - //} - - //static IEnumerable WhereIterator(IEnumerable source, Func predicate) { - // foreach (TSource element in source) { - // if (predicate(element)) yield return element; - // } - //} - - //public static IEnumerable Select(this IEnumerable source, Func selector) { - // if (source == null) throw Error.ArgumentNull("source"); - // if (selector == null) throw Error.ArgumentNull("selector"); - // return SelectIterator(source, selector); - //} - - //static IEnumerable SelectIterator(IEnumerable source, Func selector) { - // foreach (TSource element in source) { - // yield return selector(element); - // } - //} - - public static IEnumerable SelectMany(this IEnumerable source, Func> selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - return SelectManyIterator(source, selector); - } - - private static IEnumerable SelectManyIterator(IEnumerable source, Func> selector) - { - foreach (TSource element in source) - { - foreach (TResult subElement in selector(element)) - { - yield return subElement; - } - } - } - - public static IEnumerable SelectMany(this IEnumerable source, Func> selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - return SelectManyIterator(source, selector); - } - - private static IEnumerable SelectManyIterator(IEnumerable source, Func> selector) - { - int index = -1; - foreach (TSource element in source) - { - checked { index++; } - foreach (TResult subElement in selector(element, index)) - { - yield return subElement; - } - } - } - public static IEnumerable SelectMany(this IEnumerable source, Func> collectionSelector, Func resultSelector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (collectionSelector == null) throw Error.ArgumentNull("collectionSelector"); - if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); - return SelectManyIterator(source, collectionSelector, resultSelector); - } - - private static IEnumerable SelectManyIterator(IEnumerable source, Func> collectionSelector, Func resultSelector) - { - int index = -1; - foreach (TSource element in source) - { - checked { index++; } - foreach (TCollection subElement in collectionSelector(element, index)) - { - yield return resultSelector(element, subElement); - } - } - } - - public static IEnumerable SelectMany(this IEnumerable source, Func> collectionSelector, Func resultSelector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (collectionSelector == null) throw Error.ArgumentNull("collectionSelector"); - if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); - return SelectManyIterator(source, collectionSelector, resultSelector); - } - - private static IEnumerable SelectManyIterator(IEnumerable source, Func> collectionSelector, Func resultSelector) - { - foreach (TSource element in source) - { - foreach (TCollection subElement in collectionSelector(element)) - { - yield return resultSelector(element, subElement); - } - } - } - - public static IEnumerable Take(this IEnumerable source, int count) - { - if (source == null) throw Error.ArgumentNull("source"); - if (count <= 0) return new EmptyPartition(); - IPartition partition = source as IPartition; - if (partition != null) return partition.Take(count); - return TakeIterator(source, count); - } - - private static IEnumerable TakeIterator(IEnumerable source, int count) - { - foreach (TSource element in source) - { - yield return element; - if (--count == 0) break; - } - } - - public static IEnumerable TakeWhile(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - return TakeWhileIterator(source, predicate); - } - - private static IEnumerable TakeWhileIterator(IEnumerable source, Func predicate) - { - foreach (TSource element in source) - { - if (!predicate(element)) break; - yield return element; - } - } - - public static IEnumerable TakeWhile(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - return TakeWhileIterator(source, predicate); - } - - private static IEnumerable TakeWhileIterator(IEnumerable source, Func predicate) - { - int index = -1; - foreach (TSource element in source) - { - checked { index++; } - if (!predicate(element, index)) break; - yield return element; - } - } - - public static IEnumerable Skip(this IEnumerable source, int count) - { - if (source == null) throw Error.ArgumentNull("source"); - if (count < 0) count = 0; - IPartition partition = source as IPartition; - if (partition != null) return partition.Skip(count); - IList sourceList = source as IList; - return sourceList != null ? SkipList(sourceList, count) : SkipIterator(source, count); - } - - private static IEnumerable SkipList(IList source, int count) - { - while (count < source.Count) - { - yield return source[count++]; - } - } - - private static IEnumerable SkipIterator(IEnumerable source, int count) - { - using (IEnumerator e = source.GetEnumerator()) - { - while (count > 0 && e.MoveNext()) count--; - if (count <= 0) - { - while (e.MoveNext()) yield return e.Current; - } - } - } - - public static IEnumerable SkipWhile(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - return SkipWhileIterator(source, predicate); - } - - private static IEnumerable SkipWhileIterator(IEnumerable source, Func predicate) - { - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - TSource element = e.Current; - if (!predicate(element)) - { - yield return element; - while (e.MoveNext()) - yield return e.Current; - yield break; - } - } - } - } - - public static IEnumerable SkipWhile(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - return SkipWhileIterator(source, predicate); - } - - private static IEnumerable SkipWhileIterator(IEnumerable source, Func predicate) - { - using (IEnumerator e = source.GetEnumerator()) - { - int index = -1; - while (e.MoveNext()) - { - checked { index++; } - TSource element = e.Current; - if (!predicate(element, index)) - { - yield return element; - while (e.MoveNext()) - yield return e.Current; - yield break; - } - } - } - } - - public static IEnumerable Join(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector) - { - if (outer == null) throw Error.ArgumentNull("outer"); - if (inner == null) throw Error.ArgumentNull("inner"); - if (outerKeySelector == null) throw Error.ArgumentNull("outerKeySelector"); - if (innerKeySelector == null) throw Error.ArgumentNull("innerKeySelector"); - if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); - return JoinIterator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, null); - } - - public static IEnumerable Join(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) - { - if (outer == null) throw Error.ArgumentNull("outer"); - if (inner == null) throw Error.ArgumentNull("inner"); - if (outerKeySelector == null) throw Error.ArgumentNull("outerKeySelector"); - if (innerKeySelector == null) throw Error.ArgumentNull("innerKeySelector"); - if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); - return JoinIterator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); - } - - private static IEnumerable JoinIterator(IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) - { - Lookup lookup = Lookup.CreateForJoin(inner, innerKeySelector, comparer); - foreach (TOuter item in outer) - { - Grouping g = lookup.GetGrouping(outerKeySelector(item), false); - if (g != null) - { - for (int i = 0; i < g.count; i++) - { - yield return resultSelector(item, g.elements[i]); - } - } - } - } - - public static IEnumerable GroupJoin(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func, TResult> resultSelector) - { - if (outer == null) throw Error.ArgumentNull("outer"); - if (inner == null) throw Error.ArgumentNull("inner"); - if (outerKeySelector == null) throw Error.ArgumentNull("outerKeySelector"); - if (innerKeySelector == null) throw Error.ArgumentNull("innerKeySelector"); - if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); - return GroupJoinIterator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, null); - } - - public static IEnumerable GroupJoin(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func, TResult> resultSelector, IEqualityComparer comparer) - { - if (outer == null) throw Error.ArgumentNull("outer"); - if (inner == null) throw Error.ArgumentNull("inner"); - if (outerKeySelector == null) throw Error.ArgumentNull("outerKeySelector"); - if (innerKeySelector == null) throw Error.ArgumentNull("innerKeySelector"); - if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); - return GroupJoinIterator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); - } - - private static IEnumerable GroupJoinIterator(IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func, TResult> resultSelector, IEqualityComparer comparer) - { - using (IEnumerator e = outer.GetEnumerator()) - { - if (e.MoveNext()) - { - Lookup lookup = Lookup.CreateForJoin(inner, innerKeySelector, comparer); - do - { - TOuter item = e.Current; - yield return resultSelector(item, lookup[outerKeySelector(item)]); - } - while (e.MoveNext()); - } - } - } - - public static IOrderedEnumerable OrderBy(this IEnumerable source, Func keySelector) - { - return new OrderedEnumerable(source, keySelector, null, false); - } - - public static IOrderedEnumerable OrderBy(this IEnumerable source, Func keySelector, IComparer comparer) - { - return new OrderedEnumerable(source, keySelector, comparer, false); - } - - public static IOrderedEnumerable OrderByDescending(this IEnumerable source, Func keySelector) - { - return new OrderedEnumerable(source, keySelector, null, true); - } - - public static IOrderedEnumerable OrderByDescending(this IEnumerable source, Func keySelector, IComparer comparer) - { - return new OrderedEnumerable(source, keySelector, comparer, true); - } - - public static IOrderedEnumerable ThenBy(this IOrderedEnumerable source, Func keySelector) - { - if (source == null) throw Error.ArgumentNull("source"); - return source.CreateOrderedEnumerable(keySelector, null, false); - } - - public static IOrderedEnumerable ThenBy(this IOrderedEnumerable source, Func keySelector, IComparer comparer) - { - if (source == null) throw Error.ArgumentNull("source"); - return source.CreateOrderedEnumerable(keySelector, comparer, false); - } - - public static IOrderedEnumerable ThenByDescending(this IOrderedEnumerable source, Func keySelector) - { - if (source == null) throw Error.ArgumentNull("source"); - return source.CreateOrderedEnumerable(keySelector, null, true); - } - - public static IOrderedEnumerable ThenByDescending(this IOrderedEnumerable source, Func keySelector, IComparer comparer) - { - if (source == null) throw Error.ArgumentNull("source"); - return source.CreateOrderedEnumerable(keySelector, comparer, true); - } - - public static IEnumerable> GroupBy(this IEnumerable source, Func keySelector) - { - return new GroupedEnumerable(source, keySelector, IdentityFunction.Instance, null); - } - - public static IEnumerable> GroupBy(this IEnumerable source, Func keySelector, IEqualityComparer comparer) - { - return new GroupedEnumerable(source, keySelector, IdentityFunction.Instance, comparer); - } - - public static IEnumerable> GroupBy(this IEnumerable source, Func keySelector, Func elementSelector) - { - return new GroupedEnumerable(source, keySelector, elementSelector, null); - } - - public static IEnumerable> GroupBy(this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) - { - return new GroupedEnumerable(source, keySelector, elementSelector, comparer); - } - - public static IEnumerable GroupBy(this IEnumerable source, Func keySelector, Func, TResult> resultSelector) - { - return new GroupedEnumerable(source, keySelector, IdentityFunction.Instance, resultSelector, null); - } - - public static IEnumerable GroupBy(this IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector) - { - return new GroupedEnumerable(source, keySelector, elementSelector, resultSelector, null); - } - - public static IEnumerable GroupBy(this IEnumerable source, Func keySelector, Func, TResult> resultSelector, IEqualityComparer comparer) - { - return new GroupedEnumerable(source, keySelector, IdentityFunction.Instance, resultSelector, comparer); - } - - public static IEnumerable GroupBy(this IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer comparer) - { - return new GroupedEnumerable(source, keySelector, elementSelector, resultSelector, comparer); - } - - public static IEnumerable Concat(this IEnumerable first, IEnumerable second) - { - if (first == null) throw Error.ArgumentNull("first"); - if (second == null) throw Error.ArgumentNull("second"); - return ConcatIterator(first, second); - } - - private static IEnumerable ConcatIterator(IEnumerable first, IEnumerable second) - { - foreach (TSource element in first) yield return element; - foreach (TSource element in second) yield return element; - } - - public static IEnumerable Append(this IEnumerable source, TSource element) - { - if (source == null) throw Error.ArgumentNull("source"); - return AppendIterator(source, element); - } - - private static IEnumerable AppendIterator(IEnumerable source, TSource element) - { - foreach (TSource e1 in source) yield return e1; - yield return element; - } - - public static IEnumerable Prepend(this IEnumerable source, TSource element) - { - if (source == null) throw Error.ArgumentNull("source"); - return PrependIterator(source, element); - } - - private static IEnumerable PrependIterator(IEnumerable source, TSource element) - { - yield return element; - foreach (TSource e1 in source) yield return e1; - } - - public static IEnumerable Zip(this IEnumerable first, IEnumerable second, Func resultSelector) - { - if (first == null) throw Error.ArgumentNull("first"); - if (second == null) throw Error.ArgumentNull("second"); - if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); - return ZipIterator(first, second, resultSelector); - } - - private static IEnumerable ZipIterator(IEnumerable first, IEnumerable second, Func resultSelector) - { - using (IEnumerator e1 = first.GetEnumerator()) - using (IEnumerator e2 = second.GetEnumerator()) - while (e1.MoveNext() && e2.MoveNext()) - yield return resultSelector(e1.Current, e2.Current); - } - - - public static IEnumerable Distinct(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - return DistinctIterator(source, null); - } - - public static IEnumerable Distinct(this IEnumerable source, IEqualityComparer comparer) - { - if (source == null) throw Error.ArgumentNull("source"); - return DistinctIterator(source, comparer); - } - - private static IEnumerable DistinctIterator(IEnumerable source, IEqualityComparer comparer) - { - Set set = new Set(comparer); - foreach (TSource element in source) - if (set.Add(element)) yield return element; - } - - public static IEnumerable Union(this IEnumerable first, IEnumerable second) - { - if (first == null) throw Error.ArgumentNull("first"); - if (second == null) throw Error.ArgumentNull("second"); - return UnionIterator(first, second, null); - } - - public static IEnumerable Union(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) - { - if (first == null) throw Error.ArgumentNull("first"); - if (second == null) throw Error.ArgumentNull("second"); - return UnionIterator(first, second, comparer); - } - - private static IEnumerable UnionIterator(IEnumerable first, IEnumerable second, IEqualityComparer comparer) - { - Set set = new Set(comparer); - foreach (TSource element in first) - if (set.Add(element)) yield return element; - foreach (TSource element in second) - if (set.Add(element)) yield return element; - } - - public static IEnumerable Intersect(this IEnumerable first, IEnumerable second) - { - if (first == null) throw Error.ArgumentNull("first"); - if (second == null) throw Error.ArgumentNull("second"); - return IntersectIterator(first, second, null); - } - - public static IEnumerable Intersect(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) - { - if (first == null) throw Error.ArgumentNull("first"); - if (second == null) throw Error.ArgumentNull("second"); - return IntersectIterator(first, second, comparer); - } - - private static IEnumerable IntersectIterator(IEnumerable first, IEnumerable second, IEqualityComparer comparer) - { - Set set = new Set(comparer); - foreach (TSource element in second) set.Add(element); - foreach (TSource element in first) - if (set.Remove(element)) yield return element; - } - - public static IEnumerable Except(this IEnumerable first, IEnumerable second) - { - if (first == null) throw Error.ArgumentNull("first"); - if (second == null) throw Error.ArgumentNull("second"); - return ExceptIterator(first, second, null); - } - - public static IEnumerable Except(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) - { - if (first == null) throw Error.ArgumentNull("first"); - if (second == null) throw Error.ArgumentNull("second"); - return ExceptIterator(first, second, comparer); - } - - private static IEnumerable ExceptIterator(IEnumerable first, IEnumerable second, IEqualityComparer comparer) - { - Set set = new Set(comparer); - foreach (TSource element in second) set.Add(element); - foreach (TSource element in first) - if (set.Add(element)) yield return element; - } - - public static IEnumerable Reverse(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - return ReverseIterator(source); - } - - private static IEnumerable ReverseIterator(IEnumerable source) - { - Buffer buffer = new Buffer(source); - for (int i = buffer.count - 1; i >= 0; i--) yield return buffer.items[i]; - } - - public static bool SequenceEqual(this IEnumerable first, IEnumerable second) - { - return SequenceEqual(first, second, null); - } - - public static bool SequenceEqual(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) - { - if (comparer == null) comparer = EqualityComparer.Default; - if (first == null) throw Error.ArgumentNull("first"); - if (second == null) throw Error.ArgumentNull("second"); - - ICollection firstCol = first as ICollection; - if (firstCol != null) - { - ICollection secondCol = second as ICollection; - if (secondCol != null && firstCol.Count != secondCol.Count) return false; - } - - using (IEnumerator e1 = first.GetEnumerator()) - using (IEnumerator e2 = second.GetEnumerator()) - { - while (e1.MoveNext()) - { - if (!(e2.MoveNext() && comparer.Equals(e1.Current, e2.Current))) return false; - } - return !e2.MoveNext(); - } - } - - public static IEnumerable AsEnumerable(this IEnumerable source) - { - return source; - } - - public static TSource[] ToArray(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - IArrayProvider arrayProvider = source as IArrayProvider; - return arrayProvider != null ? arrayProvider.ToArray() : EnumerableHelpers.ToArray(source); - } - - public static List ToList(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - IListProvider listProvider = source as IListProvider; - return listProvider != null ? listProvider.ToList() : new List(source); - } - - public static Dictionary ToDictionary(this IEnumerable source, Func keySelector) - { - return ToDictionary(source, keySelector, null); - } - - public static Dictionary ToDictionary(this IEnumerable source, Func keySelector, IEqualityComparer comparer) - { - if (source == null) throw Error.ArgumentNull("source"); - if (keySelector == null) throw Error.ArgumentNull("keySelector"); - - int capacity = 0; - ICollection collection = source as ICollection; - if (collection != null) - { - capacity = collection.Count; - if (capacity == 0) - return new Dictionary(comparer); - - TSource[] array = collection as TSource[]; - if (array != null) - return ToDictionary(array, keySelector, comparer); - List list = collection as List; - if (list != null) - return ToDictionary(list, keySelector, comparer); - } - - Dictionary d = new Dictionary(capacity, comparer); - foreach (TSource element in source) d.Add(keySelector(element), element); - return d; - } - - private static Dictionary ToDictionary(TSource[] source, Func keySelector, IEqualityComparer comparer) - { - Dictionary d = new Dictionary(source.Length, comparer); - for (int i = 0; i < source.Length; i++) d.Add(keySelector(source[i]), source[i]); - return d; - } - private static Dictionary ToDictionary(List source, Func keySelector, IEqualityComparer comparer) - { - Dictionary d = new Dictionary(source.Count, comparer); - foreach (TSource element in source) d.Add(keySelector(element), element); - return d; - } - - - public static Dictionary ToDictionary(this IEnumerable source, Func keySelector, Func elementSelector) - { - return ToDictionary(source, keySelector, elementSelector, null); - } - - public static Dictionary ToDictionary(this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) - { - if (source == null) throw Error.ArgumentNull("source"); - if (keySelector == null) throw Error.ArgumentNull("keySelector"); - if (elementSelector == null) throw Error.ArgumentNull("elementSelector"); - - int capacity = 0; - ICollection collection = source as ICollection; - if (collection != null) - { - capacity = collection.Count; - if (capacity == 0) - return new Dictionary(comparer); - - TSource[] array = collection as TSource[]; - if (array != null) - return ToDictionary(array, keySelector, elementSelector, comparer); - List list = collection as List; - if (list != null) - return ToDictionary(list, keySelector, elementSelector, comparer); - } - - Dictionary d = new Dictionary(capacity, comparer); - foreach (TSource element in source) d.Add(keySelector(element), elementSelector(element)); - return d; - } - - private static Dictionary ToDictionary(TSource[] source, Func keySelector, Func elementSelector, IEqualityComparer comparer) - { - Dictionary d = new Dictionary(source.Length, comparer); - for (int i = 0; i < source.Length; i++) d.Add(keySelector(source[i]), elementSelector(source[i])); - return d; - } - private static Dictionary ToDictionary(List source, Func keySelector, Func elementSelector, IEqualityComparer comparer) - { - Dictionary d = new Dictionary(source.Count, comparer); - foreach (TSource element in source) d.Add(keySelector(element), elementSelector(element)); - return d; - } - - - public static ILookup ToLookup(this IEnumerable source, Func keySelector) - { - return Lookup.Create(source, keySelector, IdentityFunction.Instance, null); - } - - public static ILookup ToLookup(this IEnumerable source, Func keySelector, IEqualityComparer comparer) - { - return Lookup.Create(source, keySelector, IdentityFunction.Instance, comparer); - } - - public static ILookup ToLookup(this IEnumerable source, Func keySelector, Func elementSelector) - { - return Lookup.Create(source, keySelector, elementSelector, null); - } - - public static ILookup ToLookup(this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) - { - return Lookup.Create(source, keySelector, elementSelector, comparer); - } - - public static IEnumerable DefaultIfEmpty(this IEnumerable source) - { - return DefaultIfEmpty(source, default(TSource)); - } - - public static IEnumerable DefaultIfEmpty(this IEnumerable source, TSource defaultValue) - { - if (source == null) throw Error.ArgumentNull("source"); - return DefaultIfEmptyIterator(source, defaultValue); - } - - private static IEnumerable DefaultIfEmptyIterator(IEnumerable source, TSource defaultValue) - { - using (IEnumerator e = source.GetEnumerator()) - { - if (e.MoveNext()) - { - do - { - yield return e.Current; - } while (e.MoveNext()); - } - else - { - yield return defaultValue; - } - } - } - - public static IEnumerable OfType(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - return OfTypeIterator(source); - } - - private static IEnumerable OfTypeIterator(IEnumerable source) - { - foreach (object obj in source) - { - if (obj is TResult) yield return (TResult)obj; - } - } - - public static IEnumerable Cast(this IEnumerable source) - { - IEnumerable typedSource = source as IEnumerable; - if (typedSource != null) return typedSource; - if (source == null) throw Error.ArgumentNull("source"); - return CastIterator(source); - } - - private static IEnumerable CastIterator(IEnumerable source) - { - foreach (object obj in source) yield return (TResult)obj; - } - - public static TSource First(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - IPartition partition = source as IPartition; - if (partition != null) return partition.First(); - IList list = source as IList; - if (list != null) - { - if (list.Count > 0) return list[0]; - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (e.MoveNext()) return e.Current; - } - } - throw Error.NoElements(); - } - - public static TSource First(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - OrderedEnumerable ordered = source as OrderedEnumerable; - if (ordered != null) return ordered.First(predicate); - foreach (TSource element in source) - { - if (predicate(element)) return element; - } - throw Error.NoMatch(); - } - - public static TSource FirstOrDefault(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - IPartition partition = source as IPartition; - if (partition != null) return partition.FirstOrDefault(); - IList list = source as IList; - if (list != null) - { - if (list.Count > 0) return list[0]; - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (e.MoveNext()) return e.Current; - } - } - return default(TSource); - } - - public static TSource FirstOrDefault(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - OrderedEnumerable ordered = source as OrderedEnumerable; - if (ordered != null) return ordered.FirstOrDefault(predicate); - foreach (TSource element in source) - { - if (predicate(element)) return element; - } - return default(TSource); - } - - public static TSource Last(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - IPartition partition = source as IPartition; - if (partition != null) return partition.Last(); - IList list = source as IList; - if (list != null) - { - int count = list.Count; - if (count > 0) return list[count - 1]; - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (e.MoveNext()) - { - TSource result; - do - { - result = e.Current; - } while (e.MoveNext()); - return result; - } - } - } - throw Error.NoElements(); - } - - public static TSource Last(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - OrderedEnumerable ordered = source as OrderedEnumerable; - if (ordered != null) return ordered.Last(predicate); - IList list = source as IList; - if (list != null) - { - for (int i = list.Count - 1; i >= 0; --i) - { - TSource result = list[i]; - if (predicate(result)) return result; - } - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - TSource result = e.Current; - if (predicate(result)) - { - while (e.MoveNext()) - { - TSource element = e.Current; - if (predicate(element)) result = element; - } - return result; - } - } - } - } - throw Error.NoMatch(); - } - - public static TSource LastOrDefault(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - IPartition partition = source as IPartition; - if (partition != null) return partition.LastOrDefault(); - IList list = source as IList; - if (list != null) - { - int count = list.Count; - if (count > 0) return list[count - 1]; - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (e.MoveNext()) - { - TSource result; - do - { - result = e.Current; - } while (e.MoveNext()); - return result; - } - } - } - return default(TSource); - } - - public static TSource LastOrDefault(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - OrderedEnumerable ordered = source as OrderedEnumerable; - if (ordered != null) return ordered.LastOrDefault(predicate); - IList list = source as IList; - if (list != null) - { - for (int i = list.Count - 1; i >= 0; --i) - { - TSource element = list[i]; - if (predicate(element)) return element; - } - return default(TSource); - } - else - { - TSource result = default(TSource); - foreach (TSource element in source) - { - if (predicate(element)) - { - result = element; - } - } - return result; - } - } - - public static TSource Single(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - IList list = source as IList; - if (list != null) - { - switch (list.Count) - { - case 0: throw Error.NoElements(); - case 1: return list[0]; - } - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - TSource result = e.Current; - if (!e.MoveNext()) return result; - } - } - throw Error.MoreThanOneElement(); - } - - public static TSource Single(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - TSource result = e.Current; - if (predicate(result)) - { - while (e.MoveNext()) - { - if (predicate(e.Current)) throw Error.MoreThanOneMatch(); - } - return result; - } - } - } - throw Error.NoMatch(); - } - - public static TSource SingleOrDefault(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - IList list = source as IList; - if (list != null) - { - switch (list.Count) - { - case 0: return default(TSource); - case 1: return list[0]; - } - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) return default(TSource); - TSource result = e.Current; - if (!e.MoveNext()) return result; - } - } - throw Error.MoreThanOneElement(); - } - - public static TSource SingleOrDefault(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - TSource result = e.Current; - if (predicate(result)) - { - while (e.MoveNext()) - { - if (predicate(e.Current)) throw Error.MoreThanOneMatch(); - } - return result; - } - } - } - return default(TSource); - } - - public static TSource ElementAt(this IEnumerable source, int index) - { - if (source == null) throw Error.ArgumentNull("source"); - IPartition partition = source as IPartition; - if (partition != null) return partition.ElementAt(index); - IList list = source as IList; - if (list != null) return list[index]; - if (index >= 0) - { - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - if (index == 0) return e.Current; - index--; - } - } - } - throw Error.ArgumentOutOfRange("index"); - } - - public static TSource ElementAtOrDefault(this IEnumerable source, int index) - { - if (source == null) throw Error.ArgumentNull("source"); - IPartition partition = source as IPartition; - if (partition != null) return partition.ElementAtOrDefault(index); - if (index >= 0) - { - IList list = source as IList; - if (list != null) - { - if (index < list.Count) return list[index]; - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - if (index == 0) return e.Current; - index--; - } - } - } - } - return default(TSource); - } - - public static IEnumerable Range(int start, int count) - { - long max = ((long)start) + count - 1; - if (count < 0 || max > Int32.MaxValue) throw Error.ArgumentOutOfRange("count"); - if (count == 0) return new EmptyPartition(); - return new RangeIterator(start, count); - } - - private sealed class RangeIterator : Iterator, IArrayProvider, IListProvider, IPartition - { - private readonly int _start; - private readonly int _end; - - public RangeIterator(int start, int count) - { - Debug.Assert(count > 0); - _start = start; - _end = start + count; - } - - public override Iterator Clone() - { - return new RangeIterator(_start, _end - _start); - } - - public override bool MoveNext() - { - switch (state) - { - case 1: - Debug.Assert(_start != _end); - current = _start; - state = 2; - return true; - case 2: - if (++current == _end) break; - return true; - } - state = -1; - return false; - } - - public override void Dispose() - { - state = -1; // Don't reset current - } - - public int[] ToArray() - { - int[] array = new int[_end - _start]; - int cur = _start; - for (int i = 0; i != array.Length; ++i) - { - array[i] = cur; - ++cur; - } - - return array; - } - - public List ToList() - { - List list = new List(_end - _start); - for (int cur = _start; cur != _end; cur++) - { - list.Add(cur); - } - - return list; - } - - public IPartition Skip(int count) - { - if (count >= _end - _start) return new EmptyPartition(); - return new RangeIterator(_start + count, _end - _start - count); - } - - public IPartition Take(int count) - { - int curCount = _end - _start; - if (count > curCount) count = curCount; - return new RangeIterator(_start, count); - } - - public int ElementAt(int index) - { - if ((uint)index >= (uint)(_end - _start)) throw Error.ArgumentOutOfRange("index"); - return _start + index; - } - - public int ElementAtOrDefault(int index) - { - return (uint)index >= (uint)(_end - _start) ? 0 : _start + index; - } - - public int First() - { - return _start; - } - - public int FirstOrDefault() - { - return _start; - } - - public int Last() - { - return _end - 1; - } - - public int LastOrDefault() - { - return _end - 1; - } - } - - public static IEnumerable Repeat(TResult element, int count) - { - if (count < 0) throw Error.ArgumentOutOfRange("count"); - if (count == 0) return new EmptyPartition(); - return new RepeatIterator(element, count); - } - - private sealed class RepeatIterator : Iterator, IArrayProvider, IListProvider, IPartition - { - private readonly int _count; - private int _sent; - - public RepeatIterator(TResult element, int count) - { - Debug.Assert(count > 0); - current = element; - _count = count; - } - - public override Iterator Clone() - { - return new RepeatIterator(current, _count); - } - - public override void Dispose() - { - // Don't let base Dispose wipe current. - state = -1; - } - - public override bool MoveNext() - { - if (state == 1 & _sent != _count) - { - ++_sent; - return true; - } - state = -1; - return false; - } - - public TResult[] ToArray() - { - TResult[] array = new TResult[_count]; - if (current != null) - { - for (int i = 0; i != array.Length; ++i) array[i] = current; - } - - return array; - } - - public List ToList() - { - List list = new List(_count); - for (int i = 0; i != _count; ++i) list.Add(current); - - return list; - } - - public IPartition Skip(int count) - { - if (count >= _count) return new EmptyPartition(); - return new RepeatIterator(current, _count - count); - } - - public IPartition Take(int count) - { - if (count > _count) count = _count; - return new RepeatIterator(current, count); - } - - public TResult ElementAt(int index) - { - if ((uint)index >= (uint)_count) throw Error.ArgumentOutOfRange("index"); - return current; - } - - public TResult ElementAtOrDefault(int index) - { - return (uint)index >= (uint)_count ? default(TResult) : current; - } - - public TResult First() - { - return current; - } - - public TResult FirstOrDefault() - { - return current; - } - - public TResult Last() - { - return current; - } - - public TResult LastOrDefault() - { - return current; - } - } - - public static IEnumerable Empty() - { - return Array.Empty(); - } - - public static bool Any(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - return e.MoveNext(); - } - } - - public static bool Any(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - foreach (TSource element in source) - { - if (predicate(element)) return true; - } - return false; - } - - public static bool All(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - foreach (TSource element in source) - { - if (!predicate(element)) return false; - } - return true; - } - - public static int Count(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - ICollection collectionoft = source as ICollection; - if (collectionoft != null) return collectionoft.Count; - ICollection collection = source as ICollection; - if (collection != null) return collection.Count; - int count = 0; - using (IEnumerator e = source.GetEnumerator()) - { - checked - { - while (e.MoveNext()) count++; - } - } - return count; - } - - public static int Count(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - int count = 0; - foreach (TSource element in source) - { - checked - { - if (predicate(element)) count++; - } - } - return count; - } - - public static long LongCount(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - long count = 0; - using (IEnumerator e = source.GetEnumerator()) - { - checked - { - while (e.MoveNext()) count++; - } - } - return count; - } - - public static long LongCount(this IEnumerable source, Func predicate) - { - if (source == null) throw Error.ArgumentNull("source"); - if (predicate == null) throw Error.ArgumentNull("predicate"); - long count = 0; - foreach (TSource element in source) - { - checked - { - if (predicate(element)) count++; - } - } - return count; - } - - public static bool Contains(this IEnumerable source, TSource value) - { - ICollection collection = source as ICollection; - if (collection != null) return collection.Contains(value); - return Contains(source, value, null); - } - - public static bool Contains(this IEnumerable source, TSource value, IEqualityComparer comparer) - { - if (comparer == null) comparer = EqualityComparer.Default; - if (source == null) throw Error.ArgumentNull("source"); - foreach (TSource element in source) - if (comparer.Equals(element, value)) return true; - return false; - } - - public static TSource Aggregate(this IEnumerable source, Func func) - { - if (source == null) throw Error.ArgumentNull("source"); - if (func == null) throw Error.ArgumentNull("func"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - TSource result = e.Current; - while (e.MoveNext()) result = func(result, e.Current); - return result; - } - } - - public static TAccumulate Aggregate(this IEnumerable source, TAccumulate seed, Func func) - { - if (source == null) throw Error.ArgumentNull("source"); - if (func == null) throw Error.ArgumentNull("func"); - TAccumulate result = seed; - foreach (TSource element in source) result = func(result, element); - return result; - } - - public static TResult Aggregate(this IEnumerable source, TAccumulate seed, Func func, Func resultSelector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (func == null) throw Error.ArgumentNull("func"); - if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); - TAccumulate result = seed; - foreach (TSource element in source) result = func(result, element); - return resultSelector(result); - } - - public static int Sum(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - int sum = 0; - checked - { - foreach (int v in source) sum += v; - } - return sum; - } - - public static int? Sum(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - int sum = 0; - checked - { - foreach (int? v in source) - { - if (v != null) sum += v.GetValueOrDefault(); - } - } - return sum; - } - - public static long Sum(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - long sum = 0; - checked - { - foreach (long v in source) sum += v; - } - return sum; - } - - public static long? Sum(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - long sum = 0; - checked - { - foreach (long? v in source) - { - if (v != null) sum += v.GetValueOrDefault(); - } - } - return sum; - } - - public static float Sum(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - double sum = 0; - foreach (float v in source) sum += v; - return (float)sum; - } - - public static float? Sum(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - double sum = 0; - foreach (float? v in source) - { - if (v != null) sum += v.GetValueOrDefault(); - } - return (float)sum; - } - - public static double Sum(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - double sum = 0; - foreach (double v in source) sum += v; - return sum; - } - - public static double? Sum(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - double sum = 0; - foreach (double? v in source) - { - if (v != null) sum += v.GetValueOrDefault(); - } - return sum; - } - - public static decimal Sum(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - decimal sum = 0; - foreach (decimal v in source) sum += v; - return sum; - } - - public static decimal? Sum(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - decimal sum = 0; - foreach (decimal? v in source) - { - if (v != null) sum += v.GetValueOrDefault(); - } - return sum; - } - - public static int Sum(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - int sum = 0; - checked - { - foreach (TSource item in source) sum += selector(item); - } - return sum; - } - - public static int? Sum(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - int sum = 0; - checked - { - foreach (TSource item in source) - { - int? v = selector(item); - if (v != null) sum += v.GetValueOrDefault(); - } - } - return sum; - } - - public static long Sum(this IEnumerable source, Func selector) - { - if (selector == null) throw Error.ArgumentNull("selector"); - if (source == null) throw Error.ArgumentNull("source"); - long sum = 0; - checked - { - foreach (TSource item in source) sum += selector(item); - } - return sum; - } - - public static long? Sum(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - long sum = 0; - checked - { - foreach (TSource item in source) - { - long? v = selector(item); - if (v != null) sum += v.GetValueOrDefault(); - } - } - return sum; - } - - public static float Sum(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - double sum = 0; - foreach (TSource item in source) sum += selector(item); - return (float)sum; - } - - public static float? Sum(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - double sum = 0; - foreach (TSource item in source) - { - float? v = selector(item); - if (v != null) sum += v.GetValueOrDefault(); - } - return (float)sum; - } - - public static double Sum(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - double sum = 0; - foreach (TSource item in source) sum += selector(item); - return sum; - } - - public static double? Sum(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - double sum = 0; - foreach (TSource item in source) - { - double? v = selector(item); - if (v != null) sum += v.GetValueOrDefault(); - } - return sum; - } - - public static decimal Sum(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - decimal sum = 0; - foreach (TSource item in source) sum += selector(item); - return sum; - } - - public static decimal? Sum(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - decimal sum = 0; - foreach (TSource item in source) - { - decimal? v = selector(item); - if (v != null) sum += v.GetValueOrDefault(); - } - return sum; - } - - public static int Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - int value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (e.MoveNext()) - { - int x = e.Current; - if (x < value) value = x; - } - } - return value; - } - - public static int? Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - int? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - // Start off knowing that we've a non-null value (or exit here, knowing we don't) - // so we don't have to keep testing for nullity. - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (!value.HasValue); - // Keep hold of the wrapped value, and do comparisons on that, rather than - // using the lifted operation each time. - int valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - int? cur = e.Current; - int x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x < valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static long Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - long value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (e.MoveNext()) - { - long x = e.Current; - if (x < value) value = x; - } - } - return value; - } - - public static long? Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - long? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (!value.HasValue); - long valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - long? cur = e.Current; - long x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x < valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static float Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - float value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (e.MoveNext()) - { - float x = e.Current; - if (x < value) value = x; - // Normally NaN < anything is false, as is anything < NaN - // However, this leads to some irksome outcomes in Min and Max. - // If we use those semantics then Min(NaN, 5.0) is NaN, but - // Min(5.0, NaN) is 5.0! To fix this, we impose a total - // ordering where NaN is smaller than every value, including - // negative infinity. - // Not testing for NaN therefore isn't an option, but since we - // can't find a smaller value, we can short-circuit. - else if (float.IsNaN(x)) return x; - } - } - return value; - } - - public static float? Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - float? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (!value.HasValue); - float valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - float? cur = e.Current; - if (cur.HasValue) - { - float x = cur.GetValueOrDefault(); - if (x < valueVal) - { - valueVal = x; - value = cur; - } - else if (float.IsNaN(x)) return cur; - } - } - } - return value; - } - - public static double Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - double value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (e.MoveNext()) - { - double x = e.Current; - if (x < value) value = x; - else if (double.IsNaN(x)) return x; - } - } - return value; - } - - public static double? Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - double? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (!value.HasValue); - double valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - double? cur = e.Current; - if (cur.HasValue) - { - double x = cur.GetValueOrDefault(); - if (x < valueVal) - { - valueVal = x; - value = cur; - } - else if (double.IsNaN(x)) return cur; - } - } - } - return value; - } - - public static decimal Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - decimal value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (e.MoveNext()) - { - decimal x = e.Current; - if (x < value) value = x; - } - } - return value; - } - - public static decimal? Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - decimal? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (!value.HasValue); - decimal valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - decimal? cur = e.Current; - decimal x = cur.GetValueOrDefault(); - if (cur.HasValue && x < valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static TSource Min(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - Comparer comparer = Comparer.Default; - TSource value = default(TSource); - if (value == null) - { - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (value == null); - while (e.MoveNext()) - { - TSource x = e.Current; - if (x != null && comparer.Compare(x, value) < 0) value = x; - } - } - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (e.MoveNext()) - { - TSource x = e.Current; - if (comparer.Compare(x, value) < 0) value = x; - } - } - } - return value; - } - - public static int Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - int value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (e.MoveNext()) - { - int x = selector(e.Current); - if (x < value) value = x; - } - } - return value; - } - - public static int? Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - int? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - // Start off knowing that we've a non-null value (or exit here, knowing we don't) - // so we don't have to keep testing for nullity. - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (!value.HasValue); - // Keep hold of the wrapped value, and do comparisons on that, rather than - // using the lifted operation each time. - int valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - int? cur = selector(e.Current); - int x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x < valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static long Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - long value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (e.MoveNext()) - { - long x = selector(e.Current); - if (x < value) value = x; - } - } - return value; - } - - public static long? Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - long? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (!value.HasValue); - long valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - long? cur = selector(e.Current); - long x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x < valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static float Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - float value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (e.MoveNext()) - { - float x = selector(e.Current); - if (x < value) value = x; - // Normally NaN < anything is false, as is anything < NaN - // However, this leads to some irksome outcomes in Min and Max. - // If we use those semantics then Min(NaN, 5.0) is NaN, but - // Min(5.0, NaN) is 5.0! To fix this, we impose a total - // ordering where NaN is smaller than every value, including - // negative infinity. - // Not testing for NaN therefore isn't an option, but since we - // can't find a smaller value, we can short-circuit. - else if (float.IsNaN(x)) return x; - } - } - return value; - } - - public static float? Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - float? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (!value.HasValue); - float valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - float? cur = selector(e.Current); - if (cur.HasValue) - { - float x = cur.GetValueOrDefault(); - if (x < valueVal) - { - valueVal = x; - value = cur; - } - else if (float.IsNaN(x)) return cur; - } - } - } - return value; - } - - public static double Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - double value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (e.MoveNext()) - { - double x = selector(e.Current); - if (x < value) value = x; - else if (double.IsNaN(x)) return x; - } - } - return value; - } - - public static double? Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - double? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (!value.HasValue); - double valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - double? cur = selector(e.Current); - if (cur.HasValue) - { - double x = cur.GetValueOrDefault(); - if (x < valueVal) - { - valueVal = x; - value = cur; - } - else if (double.IsNaN(x)) return cur; - } - } - } - return value; - } - - public static decimal Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - decimal value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (e.MoveNext()) - { - decimal x = selector(e.Current); - if (x < value) value = x; - } - } - return value; - } - - public static decimal? Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - decimal? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (!value.HasValue); - decimal valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - decimal? cur = selector(e.Current); - decimal x = cur.GetValueOrDefault(); - if (cur.HasValue && x < valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static TResult Min(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - Comparer comparer = Comparer.Default; - TResult value = default(TResult); - if (value == null) - { - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (value == null); - while (e.MoveNext()) - { - TResult x = selector(e.Current); - if (x != null && comparer.Compare(x, value) < 0) value = x; - } - } - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (e.MoveNext()) - { - TResult x = selector(e.Current); - if (comparer.Compare(x, value) < 0) value = x; - } - } - } - return value; - } - - public static int Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - int value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (e.MoveNext()) - { - int x = e.Current; - if (x > value) value = x; - } - } - return value; - } - - public static int? Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - int? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (!value.HasValue); - int valueVal = value.GetValueOrDefault(); - if (valueVal >= 0) - { - // We can fast-path this case where we know HasValue will - // never affect the outcome, without constantly checking - // if we're in such a state. Similar fast-paths could - // be done for other cases, but as all-positive - // or mostly-positive integer values are quite common in real-world - // uses, it's only been done in this direction for int? and long?. - while (e.MoveNext()) - { - int? cur = e.Current; - int x = cur.GetValueOrDefault(); - if (x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - else - { - while (e.MoveNext()) - { - int? cur = e.Current; - int x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - } - return value; - } - - public static long Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - long value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (e.MoveNext()) - { - long x = e.Current; - if (x > value) value = x; - } - } - return value; - } - - public static long? Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - long? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (!value.HasValue); - long valueVal = value.GetValueOrDefault(); - if (valueVal >= 0) - { - while (e.MoveNext()) - { - long? cur = e.Current; - long x = cur.GetValueOrDefault(); - if (x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - else - { - while (e.MoveNext()) - { - long? cur = e.Current; - long x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - } - return value; - } - - public static double Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - double value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - // As described in a comment on Min(this IEnumerable) NaN is ordered - // less than all other values. We need to do explicit checks to ensure this, but - // once we've found a value that is not NaN we need no longer worry about it, - // so first loop until such a value is found (or not, as the case may be). - while (double.IsNaN(value)) - { - if (!e.MoveNext()) return value; - value = e.Current; - } - while (e.MoveNext()) - { - double x = e.Current; - if (x > value) value = x; - } - } - return value; - } - - public static double? Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - double? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (!value.HasValue); - double valueVal = value.GetValueOrDefault(); - while (double.IsNaN(valueVal)) - { - if (!e.MoveNext()) return value; - double? cur = e.Current; - if (cur.HasValue) valueVal = (value = cur).GetValueOrDefault(); - } - while (e.MoveNext()) - { - double? cur = e.Current; - double x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static float Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - float value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (float.IsNaN(value)) - { - if (!e.MoveNext()) return value; - value = e.Current; - } - while (e.MoveNext()) - { - float x = e.Current; - if (x > value) value = x; - } - } - return value; - } - - public static float? Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - float? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (!value.HasValue); - float valueVal = value.GetValueOrDefault(); - while (float.IsNaN(valueVal)) - { - if (!e.MoveNext()) return value; - float? cur = e.Current; - if (cur.HasValue) valueVal = (value = cur).GetValueOrDefault(); - } - while (e.MoveNext()) - { - float? cur = e.Current; - float x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static decimal Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - decimal value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (e.MoveNext()) - { - decimal x = e.Current; - if (x > value) value = x; - } - } - return value; - } - - public static decimal? Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - decimal? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (!value.HasValue); - decimal valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - decimal? cur = e.Current; - decimal x = cur.GetValueOrDefault(); - if (cur.HasValue && x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static TSource Max(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - Comparer comparer = Comparer.Default; - TSource value = default(TSource); - if (value == null) - { - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = e.Current; - } while (value == null); - while (e.MoveNext()) - { - TSource x = e.Current; - if (x != null && comparer.Compare(x, value) > 0) value = x; - } - } - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = e.Current; - while (e.MoveNext()) - { - TSource x = e.Current; - if (comparer.Compare(x, value) > 0) value = x; - } - } - } - return value; - } - - public static int Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - int value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (e.MoveNext()) - { - int x = selector(e.Current); - if (x > value) value = x; - } - } - return value; - } - - public static int? Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - int? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (!value.HasValue); - int valueVal = value.GetValueOrDefault(); - if (valueVal >= 0) - { - // We can fast-path this case where we know HasValue will - // never affect the outcome, without constantly checking - // if we're in such a state. Similar fast-paths could - // be done for other cases, but as all-positive - // or mostly-positive integer values are quite common in real-world - // uses, it's only been done in this direction for int? and long?. - while (e.MoveNext()) - { - int? cur = selector(e.Current); - int x = cur.GetValueOrDefault(); - if (x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - else - { - while (e.MoveNext()) - { - int? cur = selector(e.Current); - int x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - } - return value; - } - - public static long Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - long value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (e.MoveNext()) - { - long x = selector(e.Current); - if (x > value) value = x; - } - } - return value; - } - - public static long? Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - long? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (!value.HasValue); - long valueVal = value.GetValueOrDefault(); - if (valueVal >= 0) - { - while (e.MoveNext()) - { - long? cur = selector(e.Current); - long x = cur.GetValueOrDefault(); - if (x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - else - { - while (e.MoveNext()) - { - long? cur = selector(e.Current); - long x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - } - return value; - } - - public static float Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - float value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (float.IsNaN(value)) - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } - while (e.MoveNext()) - { - float x = selector(e.Current); - if (x > value) value = x; - } - } - return value; - } - - public static float? Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - float? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (!value.HasValue); - float valueVal = value.GetValueOrDefault(); - while (float.IsNaN(valueVal)) - { - if (!e.MoveNext()) return value; - float? cur = selector(e.Current); - if (cur.HasValue) valueVal = (value = cur).GetValueOrDefault(); - } - while (e.MoveNext()) - { - float? cur = selector(e.Current); - float x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static double Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - double value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - // As described in a comment on Min(this IEnumerable) NaN is ordered - // less than all other values. We need to do explicit checks to ensure this, but - // once we've found a value that is not NaN we need no longer worry about it, - // so first loop until such a value is found (or not, as the case may be). - while (double.IsNaN(value)) - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } - while (e.MoveNext()) - { - double x = selector(e.Current); - if (x > value) value = x; - } - } - return value; - } - - public static double? Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - double? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (!value.HasValue); - double valueVal = value.GetValueOrDefault(); - while (double.IsNaN(valueVal)) - { - if (!e.MoveNext()) return value; - double? cur = selector(e.Current); - if (cur.HasValue) valueVal = (value = cur).GetValueOrDefault(); - } - while (e.MoveNext()) - { - double? cur = selector(e.Current); - double x = cur.GetValueOrDefault(); - // Do not replace & with &&. The branch prediction cost outweighs the extra operation - // unless nulls either never happen or always happen. - if (cur.HasValue & x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static decimal Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - decimal value; - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (e.MoveNext()) - { - decimal x = selector(e.Current); - if (x > value) value = x; - } - } - return value; - } - - public static decimal? Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - decimal? value = null; - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (!value.HasValue); - decimal valueVal = value.GetValueOrDefault(); - while (e.MoveNext()) - { - decimal? cur = selector(e.Current); - decimal x = cur.GetValueOrDefault(); - if (cur.HasValue && x > valueVal) - { - valueVal = x; - value = cur; - } - } - } - return value; - } - - public static TResult Max(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - Comparer comparer = Comparer.Default; - TResult value = default(TResult); - if (value == null) - { - using (IEnumerator e = source.GetEnumerator()) - { - do - { - if (!e.MoveNext()) return value; - value = selector(e.Current); - } while (value == null); - while (e.MoveNext()) - { - TResult x = selector(e.Current); - if (x != null && comparer.Compare(x, value) > 0) value = x; - } - } - } - else - { - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - value = selector(e.Current); - while (e.MoveNext()) - { - TResult x = selector(e.Current); - if (comparer.Compare(x, value) > 0) value = x; - } - } - } - return value; - } - - public static double Average(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - long sum = e.Current; - long count = 1; - checked - { - while (e.MoveNext()) - { - sum += e.Current; - ++count; - } - } - return (double)sum / count; - } - } - - public static double? Average(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - int? v = e.Current; - if (v.HasValue) - { - long sum = v.GetValueOrDefault(); - long count = 1; - checked - { - while (e.MoveNext()) - { - v = e.Current; - if (v.HasValue) - { - sum += v.GetValueOrDefault(); - ++count; - } - } - } - return (double)sum / count; - } - } - } - return null; - } - - public static double Average(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - long sum = e.Current; - long count = 1; - checked - { - while (e.MoveNext()) - { - sum += e.Current; - ++count; - } - } - return (double)sum / count; - } - } - - public static double? Average(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - long? v = e.Current; - if (v.HasValue) - { - long sum = v.GetValueOrDefault(); - long count = 1; - checked - { - while (e.MoveNext()) - { - v = e.Current; - if (v.HasValue) - { - sum += v.GetValueOrDefault(); - ++count; - } - } - } - return (double)sum / count; - } - } - } - return null; - } - - public static float Average(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - double sum = e.Current; - long count = 1; - while (e.MoveNext()) - { - sum += e.Current; - ++count; - } - return (float)(sum / count); - } - } - - public static float? Average(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - float? v = e.Current; - if (v.HasValue) - { - double sum = v.GetValueOrDefault(); - long count = 1; - checked - { - while (e.MoveNext()) - { - v = e.Current; - if (v.HasValue) - { - sum += v.GetValueOrDefault(); - ++count; - } - } - } - return (float)(sum / count); - } - } - } - return null; - } - - public static double Average(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - double sum = e.Current; - long count = 1; - while (e.MoveNext()) - { - // There is an opportunity to short-circuit here, in that if e.Current is - // ever NaN then the result will always be NaN. Assuming that this case is - // rare enough that not checking is the better approach generally. - sum += e.Current; - ++count; - } - return sum / count; - } - } - - public static double? Average(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - double? v = e.Current; - if (v.HasValue) - { - double sum = v.GetValueOrDefault(); - long count = 1; - checked - { - while (e.MoveNext()) - { - v = e.Current; - if (v.HasValue) - { - sum += v.GetValueOrDefault(); - ++count; - } - } - } - return sum / count; - } - } - } - return null; - } - - public static decimal Average(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - decimal sum = e.Current; - long count = 1; - while (e.MoveNext()) - { - sum += e.Current; - ++count; - } - return sum / count; - } - } - - public static decimal? Average(this IEnumerable source) - { - if (source == null) throw Error.ArgumentNull("source"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - decimal? v = e.Current; - if (v.HasValue) - { - decimal sum = v.GetValueOrDefault(); - long count = 1; - while (e.MoveNext()) - { - v = e.Current; - if (v.HasValue) - { - sum += v.GetValueOrDefault(); - ++count; - } - } - return sum / count; - } - } - } - return null; - } - - public static double Average(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - long sum = selector(e.Current); - long count = 1; - checked - { - while (e.MoveNext()) - { - sum += selector(e.Current); - ++count; - } - } - return (double)sum / count; - } - } - - public static double? Average(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - int? v = selector(e.Current); - if (v.HasValue) - { - long sum = v.GetValueOrDefault(); - long count = 1; - checked - { - while (e.MoveNext()) - { - v = selector(e.Current); - if (v.HasValue) - { - sum += v.GetValueOrDefault(); - ++count; - } - } - } - return (double)sum / count; - } - } - } - return null; - } - - public static double Average(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - long sum = selector(e.Current); - long count = 1; - checked - { - while (e.MoveNext()) - { - sum += selector(e.Current); - ++count; - } - } - return (double)sum / count; - } - } - - public static double? Average(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - long? v = selector(e.Current); - if (v.HasValue) - { - long sum = v.GetValueOrDefault(); - long count = 1; - checked - { - while (e.MoveNext()) - { - v = selector(e.Current); - if (v.HasValue) - { - sum += v.GetValueOrDefault(); - ++count; - } - } - } - return (double)sum / count; - } - } - } - return null; - } - - public static float Average(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - double sum = selector(e.Current); - long count = 1; - while (e.MoveNext()) - { - sum += selector(e.Current); - ++count; - } - return (float)(sum / count); - } - } - - public static float? Average(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - float? v = selector(e.Current); - if (v.HasValue) - { - double sum = v.GetValueOrDefault(); - long count = 1; - checked - { - while (e.MoveNext()) - { - v = selector(e.Current); - if (v.HasValue) - { - sum += v.GetValueOrDefault(); - ++count; - } - } - } - return (float)(sum / count); - } - } - } - return null; - } - - public static double Average(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - double sum = selector(e.Current); - long count = 1; - while (e.MoveNext()) - { - // There is an opportunity to short-circuit here, in that if e.Current is - // ever NaN then the result will always be NaN. Assuming that this case is - // rare enough that not checking is the better approach generally. - sum += selector(e.Current); - ++count; - } - return sum / count; - } - } - - public static double? Average(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - double? v = selector(e.Current); - if (v.HasValue) - { - double sum = v.GetValueOrDefault(); - long count = 1; - checked - { - while (e.MoveNext()) - { - v = selector(e.Current); - if (v.HasValue) - { - sum += v.GetValueOrDefault(); - ++count; - } - } - } - return sum / count; - } - } - } - return null; - } - - public static decimal Average(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - decimal sum = selector(e.Current); - long count = 1; - while (e.MoveNext()) - { - sum += selector(e.Current); - ++count; - } - return sum / count; - } - } - - public static decimal? Average(this IEnumerable source, Func selector) - { - if (source == null) throw Error.ArgumentNull("source"); - if (selector == null) throw Error.ArgumentNull("selector"); - using (IEnumerator e = source.GetEnumerator()) - { - while (e.MoveNext()) - { - decimal? v = selector(e.Current); - if (v.HasValue) - { - decimal sum = v.GetValueOrDefault(); - long count = 1; - while (e.MoveNext()) - { - v = selector(e.Current); - if (v.HasValue) - { - sum += v.GetValueOrDefault(); - ++count; - } - } - return sum / count; - } - } - } - return null; - } - } - - /// - /// An iterator that can produce an array through an optimized path. - /// - internal interface IArrayProvider - { - /// - /// Produce an array of the sequence through an optimized path. - /// - /// The array. - TElement[] ToArray(); - } - - /// - /// An iterator that can produce a through an optimized path. - /// - internal interface IListProvider - { - /// - /// Produce a of the sequence through an optimized path. - /// - /// The . - List ToList(); - } - - internal class IdentityFunction - { - public static Func Instance - { - get { return x => x; } - } - } - - public interface IOrderedEnumerable : IEnumerable - { - IOrderedEnumerable CreateOrderedEnumerable(Func keySelector, IComparer comparer, bool descending); - } - - public interface IGrouping : IEnumerable - { - TKey Key { get; } - } - - public interface ILookup : IEnumerable> - { - int Count { get; } - IEnumerable this[TKey key] { get; } - bool Contains(TKey key); - } - - public class Lookup : IEnumerable>, ILookup, IArrayProvider>, IListProvider> - { - private IEqualityComparer _comparer; - private Grouping[] _groupings; - private Grouping _lastGrouping; - private int _count; - - internal static Lookup Create(IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) - { - if (source == null) throw Error.ArgumentNull("source"); - if (keySelector == null) throw Error.ArgumentNull("keySelector"); - if (elementSelector == null) throw Error.ArgumentNull("elementSelector"); - Lookup lookup = new Lookup(comparer); - foreach (TSource item in source) - { - lookup.GetGrouping(keySelector(item), true).Add(elementSelector(item)); - } - return lookup; - } - - internal static Lookup CreateForJoin(IEnumerable source, Func keySelector, IEqualityComparer comparer) - { - Lookup lookup = new Lookup(comparer); - foreach (TElement item in source) - { - TKey key = keySelector(item); - if (key != null) lookup.GetGrouping(key, true).Add(item); - } - return lookup; - } - - private Lookup(IEqualityComparer comparer) - { - if (comparer == null) comparer = EqualityComparer.Default; - _comparer = comparer; - _groupings = new Grouping[7]; - } - - public int Count - { - get { return _count; } - } - - public IEnumerable this[TKey key] - { - get - { - Grouping grouping = GetGrouping(key, false); - if (grouping != null) return grouping; - return Array.Empty(); - } - } - - public bool Contains(TKey key) - { - return GetGrouping(key, false) != null; - } - - public IEnumerator> GetEnumerator() - { - Grouping g = _lastGrouping; - if (g != null) - { - do - { - g = g.next; - yield return g; - } while (g != _lastGrouping); - } - } - - IGrouping[] IArrayProvider>.ToArray() - { - IGrouping[] array = new IGrouping[_count]; - int index = 0; - Grouping g = _lastGrouping; - if (g != null) - { - do - { - g = g.next; - array[index] = g; - ++index; - } while (g != _lastGrouping); - } - return array; - } - - List> IListProvider>.ToList() - { - List> list = new List>(_count); - Grouping g = _lastGrouping; - if (g != null) - { - do - { - g = g.next; - list.Add(g); - } while (g != _lastGrouping); - } - return list; - } - - public IEnumerable ApplyResultSelector(Func, TResult> resultSelector) - { - Grouping g = _lastGrouping; - if (g != null) - { - do - { - g = g.next; - if (g.count != g.elements.Length) { Array.Resize(ref g.elements, g.count); } - yield return resultSelector(g.key, g.elements); - } while (g != _lastGrouping); - } - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - - internal int InternalGetHashCode(TKey key) - { - // Handle comparer implementations that throw when passed null - return (key == null) ? 0 : _comparer.GetHashCode(key) & 0x7FFFFFFF; - } - - internal Grouping GetGrouping(TKey key, bool create) - { - int hashCode = InternalGetHashCode(key); - for (Grouping g = _groupings[hashCode % _groupings.Length]; g != null; g = g.hashNext) - if (g.hashCode == hashCode && _comparer.Equals(g.key, key)) return g; - if (create) - { - if (_count == _groupings.Length) Resize(); - int index = hashCode % _groupings.Length; - Grouping g = new Grouping(); - g.key = key; - g.hashCode = hashCode; - g.elements = new TElement[1]; - g.hashNext = _groupings[index]; - _groupings[index] = g; - if (_lastGrouping == null) - { - g.next = g; - } - else - { - g.next = _lastGrouping.next; - _lastGrouping.next = g; - } - _lastGrouping = g; - _count++; - return g; - } - return null; - } - - private void Resize() - { - int newSize = checked(_count * 2 + 1); - Grouping[] newGroupings = new Grouping[newSize]; - Grouping g = _lastGrouping; - do - { - g = g.next; - int index = g.hashCode % newSize; - g.hashNext = newGroupings[index]; - newGroupings[index] = g; - } while (g != _lastGrouping); - _groupings = newGroupings; - } - } - - // - // It is (unfortunately) common to databind directly to Grouping.Key. - // Because of this, we have to declare this internal type public so that we - // can mark the Key property for public reflection. - // - // To limit the damage, the toolchain makes this type appear in a hidden assembly. - // (This is also why it is no longer a nested type of Lookup<,>). - // - public class Grouping : IGrouping, IList - { - internal TKey key; - internal int hashCode; - internal TElement[] elements; - internal int count; - internal Grouping hashNext; - internal Grouping next; - - internal Grouping() - { - } - - internal void Add(TElement element) - { - if (elements.Length == count) Array.Resize(ref elements, checked(count * 2)); - elements[count] = element; - count++; - } - - public IEnumerator GetEnumerator() - { - for (int i = 0; i < count; i++) yield return elements[i]; - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - - // DDB195907: implement IGrouping<>.Key implicitly - // so that WPF binding works on this property. - public TKey Key - { - get { return key; } - } - - int ICollection.Count - { - get { return count; } - } - - bool ICollection.IsReadOnly - { - get { return true; } - } - - void ICollection.Add(TElement item) - { - throw Error.NotSupported(); - } - - void ICollection.Clear() - { - throw Error.NotSupported(); - } - - bool ICollection.Contains(TElement item) - { - return Array.IndexOf(elements, item, 0, count) >= 0; - } - - void ICollection.CopyTo(TElement[] array, int arrayIndex) - { - Array.Copy(elements, 0, array, arrayIndex, count); - } - - bool ICollection.Remove(TElement item) - { - throw Error.NotSupported(); - } - - int IList.IndexOf(TElement item) - { - return Array.IndexOf(elements, item, 0, count); - } - - void IList.Insert(int index, TElement item) - { - throw Error.NotSupported(); - } - - void IList.RemoveAt(int index) - { - throw Error.NotSupported(); - } - - TElement IList.this[int index] - { - get - { - if (index < 0 || index >= count) throw Error.ArgumentOutOfRange("index"); - return elements[index]; - } - set - { - throw Error.NotSupported(); - } - } - } - - internal class Set - { - private int[] _buckets; - private Slot[] _slots; - private int _count; - private readonly IEqualityComparer _comparer; -#if DEBUG - private bool _haveRemoved; -#endif - - public Set(IEqualityComparer comparer) - { - if (comparer == null) comparer = EqualityComparer.Default; - _comparer = comparer; - _buckets = new int[7]; - _slots = new Slot[7]; - } - - // If value is not in set, add it and return true; otherwise return false - public bool Add(TElement value) - { -#if DEBUG - Debug.Assert(!_haveRemoved, "This class is optimised for never calling Add after Remove. If your changes need to do so, undo that optimization."); -#endif - int hashCode = InternalGetHashCode(value); - for (int i = _buckets[hashCode % _buckets.Length] - 1; i >= 0; i = _slots[i].next) - { - if (_slots[i].hashCode == hashCode && _comparer.Equals(_slots[i].value, value)) return false; - } - if (_count == _slots.Length) Resize(); - int index = _count; - _count++; - int bucket = hashCode % _buckets.Length; - _slots[index].hashCode = hashCode; - _slots[index].value = value; - _slots[index].next = _buckets[bucket] - 1; - _buckets[bucket] = index + 1; - return true; - } - - // If value is in set, remove it and return true; otherwise return false - public bool Remove(TElement value) - { -#if DEBUG - _haveRemoved = true; -#endif - int hashCode = InternalGetHashCode(value); - int bucket = hashCode % _buckets.Length; - int last = -1; - for (int i = _buckets[bucket] - 1; i >= 0; last = i, i = _slots[i].next) - { - if (_slots[i].hashCode == hashCode && _comparer.Equals(_slots[i].value, value)) - { - if (last < 0) - { - _buckets[bucket] = _slots[i].next + 1; - } - else - { - _slots[last].next = _slots[i].next; - } - _slots[i].hashCode = -1; - _slots[i].value = default(TElement); - _slots[i].next = -1; - return true; - } - } - return false; - } - - private void Resize() - { - int newSize = checked(_count * 2 + 1); - int[] newBuckets = new int[newSize]; - Slot[] newSlots = new Slot[newSize]; - Array.Copy(_slots, 0, newSlots, 0, _count); - for (int i = 0; i < _count; i++) - { - int bucket = newSlots[i].hashCode % newSize; - newSlots[i].next = newBuckets[bucket] - 1; - newBuckets[bucket] = i + 1; - } - _buckets = newBuckets; - _slots = newSlots; - } - - internal int InternalGetHashCode(TElement value) - { - // Handle comparer implementations that throw when passed null - return (value == null) ? 0 : _comparer.GetHashCode(value) & 0x7FFFFFFF; - } - - internal struct Slot - { - internal int hashCode; - internal int next; - internal TElement value; - } - } - - internal class GroupedEnumerable : IEnumerable - { - private IEnumerable _source; - private Func _keySelector; - private Func _elementSelector; - private IEqualityComparer _comparer; - private Func, TResult> _resultSelector; - - public GroupedEnumerable(IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer comparer) - { - if (source == null) throw Error.ArgumentNull("source"); - if (keySelector == null) throw Error.ArgumentNull("keySelector"); - if (elementSelector == null) throw Error.ArgumentNull("elementSelector"); - if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); - _source = source; - _keySelector = keySelector; - _elementSelector = elementSelector; - _comparer = comparer; - _resultSelector = resultSelector; - } - - public IEnumerator GetEnumerator() - { - Lookup lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); - return lookup.ApplyResultSelector(_resultSelector).GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - } - - internal class GroupedEnumerable : IEnumerable>, IArrayProvider>, IListProvider> - { - private IEnumerable _source; - private Func _keySelector; - private Func _elementSelector; - private IEqualityComparer _comparer; - - public GroupedEnumerable(IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) - { - if (source == null) throw Error.ArgumentNull("source"); - if (keySelector == null) throw Error.ArgumentNull("keySelector"); - if (elementSelector == null) throw Error.ArgumentNull("elementSelector"); - _source = source; - _keySelector = keySelector; - _elementSelector = elementSelector; - _comparer = comparer; - } - - public IEnumerator> GetEnumerator() - { - return Lookup.Create(_source, _keySelector, _elementSelector, _comparer).GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - - public IGrouping[] ToArray() - { - IArrayProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); - return lookup.ToArray(); - } - - public List> ToList() - { - IListProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); - return lookup.ToList(); - } - } - - internal interface IPartition : IEnumerable, IArrayProvider - { - IPartition Skip(int count); - - IPartition Take(int count); - - TElement ElementAt(int index); - - TElement ElementAtOrDefault(int index); - - TElement First(); - - TElement FirstOrDefault(); - - TElement Last(); - - TElement LastOrDefault(); - } - - internal sealed class EmptyPartition : IPartition, IListProvider, IEnumerator - { - public EmptyPartition() - { - } - - public IEnumerator GetEnumerator() - { - return this; - } - - IEnumerator IEnumerable.GetEnumerator() - { - return this; - } - - public bool MoveNext() - { - return false; - } - - [ExcludeFromCodeCoverage] // Shouldn't be called, and as undefined can return or throw anything anyway. - public TElement Current - { - get { return default(TElement); } - } - - [ExcludeFromCodeCoverage] // Shouldn't be called, and as undefined can return or throw anything anyway. - object IEnumerator.Current - { - get { return default(TElement); } - } - - void IEnumerator.Reset() - { - throw Error.NotSupported(); - } - - void IDisposable.Dispose() - { - // Do nothing. - } - - public IPartition Skip(int count) - { - return new EmptyPartition(); - } - - public IPartition Take(int count) - { - return new EmptyPartition(); - } - - public TElement ElementAt(int index) - { - throw Error.ArgumentOutOfRange("index"); - } - - public TElement ElementAtOrDefault(int index) - { - return default(TElement); - } - - public TElement First() - { - throw Error.NoElements(); - } - - public TElement FirstOrDefault() - { - return default(TElement); - } - - public TElement Last() - { - throw Error.NoElements(); - } - - public TElement LastOrDefault() - { - return default(TElement); - } - - public TElement[] ToArray() - { - return Array.Empty(); - } - - public List ToList() - { - return new List(); - } - } - - internal sealed class OrderedPartition : IPartition - { - private readonly OrderedEnumerable _source; - private readonly int _minIndex; - private readonly int _maxIndex; - - public OrderedPartition(OrderedEnumerable source, int minIdx, int maxIdx) - { - _source = source; - _minIndex = minIdx; - _maxIndex = maxIdx; - } - - public IEnumerator GetEnumerator() - { - return _source.GetEnumerator(_minIndex, _maxIndex); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - - public IPartition Skip(int count) - { - int minIndex = _minIndex + count; - return minIndex >= _maxIndex - ? (IPartition)new EmptyPartition() - : new OrderedPartition(_source, minIndex, _maxIndex); - } - - public IPartition Take(int count) - { - int maxIndex = _minIndex + count - 1; - if (maxIndex >= _maxIndex) maxIndex = _maxIndex; - return new OrderedPartition(_source, _minIndex, maxIndex); - } - - public TElement ElementAt(int index) - { - if ((uint)index > (uint)_maxIndex - _minIndex) throw Error.ArgumentOutOfRange("index"); - return _source.ElementAt(index + _minIndex); - } - - public TElement ElementAtOrDefault(int index) - { - return (uint)index <= (uint)_maxIndex - _minIndex ? _source.ElementAtOrDefault(index + _minIndex) : default(TElement); - } - - public TElement First() - { - TElement result; - if (!_source.TryGetElementAt(_minIndex, out result)) throw Error.NoElements(); - return result; - } - - public TElement FirstOrDefault() - { - return _source.ElementAtOrDefault(_minIndex); - } - - public TElement Last() - { - return _source.Last(_minIndex, _maxIndex); - } - - public TElement LastOrDefault() - { - return _source.LastOrDefault(_minIndex, _maxIndex); - } - - public TElement[] ToArray() - { - return _source.ToArray(_minIndex, _maxIndex); - } - } - - internal abstract class OrderedEnumerable : IOrderedEnumerable, IArrayProvider, IListProvider, IPartition - { - internal IEnumerable source; - - private int[] SortedMap(Buffer buffer) - { - return GetEnumerableSorter().Sort(buffer.items, buffer.count); - } - - private int[] SortedMap(Buffer buffer, int minIdx, int maxIdx) - { - return GetEnumerableSorter().Sort(buffer.items, buffer.count, minIdx, maxIdx); - } - - public IEnumerator GetEnumerator() - { - Buffer buffer = new Buffer(source); - if (buffer.count > 0) - { - int[] map = SortedMap(buffer); - for (int i = 0; i < buffer.count; i++) yield return buffer.items[map[i]]; - } - } - - public TElement[] ToArray() - { - Buffer buffer = new Buffer(source); - - int count = buffer.count; - if (count == 0) - { - return buffer.items; - } - - TElement[] array = new TElement[count]; - int[] map = SortedMap(buffer); - for (int i = 0; i != array.Length; i++) array[i] = buffer.items[map[i]]; - - return array; - } - - public List ToList() - { - Buffer buffer = new Buffer(source); - int count = buffer.count; - List list = new List(count); - if (count > 0) - { - int[] map = SortedMap(buffer); - for (int i = 0; i != count; i++) list.Add(buffer.items[map[i]]); - } - - return list; - } - - internal IEnumerator GetEnumerator(int minIdx, int maxIdx) - { - Buffer buffer = new Buffer(source); - int count = buffer.count; - if (count > minIdx) - { - if (count <= maxIdx) maxIdx = count - 1; - if (minIdx == maxIdx) yield return GetEnumerableSorter().ElementAt(buffer.items, count, minIdx); - else - { - int[] map = SortedMap(buffer, minIdx, maxIdx); - while (minIdx <= maxIdx) - { - yield return buffer.items[map[minIdx]]; - ++minIdx; - } - } - } - } - - internal TElement[] ToArray(int minIdx, int maxIdx) - { - Buffer buffer = new Buffer(source); - int count = buffer.count; - if (count <= minIdx) return Array.Empty(); - if (count <= maxIdx) maxIdx = count - 1; - if (minIdx == maxIdx) return new TElement[] { GetEnumerableSorter().ElementAt(buffer.items, count, minIdx) }; - int[] map = SortedMap(buffer, minIdx, maxIdx); - TElement[] array = new TElement[maxIdx - minIdx + 1]; - int idx = 0; - while (minIdx <= maxIdx) - { - array[idx] = buffer.items[map[minIdx]]; - ++idx; - ++minIdx; - } - return array; - } - - private EnumerableSorter GetEnumerableSorter() - { - return GetEnumerableSorter(null); - } - - internal abstract EnumerableSorter GetEnumerableSorter(EnumerableSorter next); - - internal CachingComparer GetComparer() - { - return GetComparer(null); - } - - internal abstract CachingComparer GetComparer(CachingComparer childComparer); - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - - IOrderedEnumerable IOrderedEnumerable.CreateOrderedEnumerable(Func keySelector, IComparer comparer, bool descending) - { - OrderedEnumerable result = new OrderedEnumerable(source, keySelector, comparer, descending); - result.parent = this; - return result; - } - - public IPartition Skip(int count) - { - return new OrderedPartition(this, count, int.MaxValue); - } - - public IPartition Take(int count) - { - return new OrderedPartition(this, 0, count - 1); - } - - public bool TryGetElementAt(int index, out TElement result) - { - if (index == 0) return TryGetFirst(out result); - if (index > 0) - { - Buffer buffer = new Buffer(source); - int count = buffer.count; - if (index < count) - { - result = GetEnumerableSorter().ElementAt(buffer.items, count, index); - return true; - } - } - result = default(TElement); - return false; - } - - public TElement ElementAt(int index) - { - TElement result; - if (!TryGetElementAt(index, out result)) throw Error.ArgumentOutOfRange("index"); - return result; - } - - public TElement ElementAtOrDefault(int index) - { - TElement result; - TryGetElementAt(index, out result); - return result; - } - - private bool TryGetFirst(out TElement result) - { - CachingComparer comparer = GetComparer(); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) - { - result = default(TElement); - return false; - } - TElement value = e.Current; - comparer.SetElement(value); - while (e.MoveNext()) - { - TElement x = e.Current; - if (comparer.Compare(x, true) < 0) value = x; - } - result = value; - return true; - } - } - - public TElement FirstOrDefault() - { - TElement result; - TryGetFirst(out result); - return result; - } - - public TElement First() - { - TElement result; - if (!TryGetFirst(out result)) throw Error.NoElements(); - return result; - } - - public TElement First(Func predicate) - { - CachingComparer comparer = GetComparer(); - using (IEnumerator e = source.GetEnumerator()) - { - TElement value; - do - { - if (!e.MoveNext()) throw Error.NoMatch(); - value = e.Current; - } while (!predicate(value)); - comparer.SetElement(value); - while (e.MoveNext()) - { - TElement x = e.Current; - if (predicate(x) && comparer.Compare(x, true) < 0) value = x; - } - return value; - } - } - - public TElement FirstOrDefault(Func predicate) - { - CachingComparer comparer = GetComparer(); - using (IEnumerator e = source.GetEnumerator()) - { - TElement value; - do - { - if (!e.MoveNext()) return default(TElement); - value = e.Current; - } while (!predicate(value)); - comparer.SetElement(value); - while (e.MoveNext()) - { - TElement x = e.Current; - if (predicate(x) && comparer.Compare(x, true) < 0) value = x; - } - return value; - } - } - - public TElement Last() - { - CachingComparer comparer = GetComparer(); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) throw Error.NoElements(); - TElement value = e.Current; - comparer.SetElement(value); - while (e.MoveNext()) - { - TElement x = e.Current; - if (comparer.Compare(x, false) >= 0) value = x; - } - return value; - } - } - - public TElement LastOrDefault() - { - CachingComparer comparer = GetComparer(); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) return default(TElement); - TElement value = e.Current; - comparer.SetElement(value); - while (e.MoveNext()) - { - TElement x = e.Current; - if (comparer.Compare(x, false) >= 0) value = x; - } - return value; - } - } - - public TElement Last(int minIdx, int maxIdx) - { - Buffer buffer = new Buffer(source); - int count = buffer.count; - if (minIdx >= count) throw Error.NoElements(); - if (maxIdx < count - 1) return GetEnumerableSorter().ElementAt(buffer.items, count, maxIdx); - // If we're here, we want the same results we would have got from - // Last(), but we've already buffered our source. - return Last(buffer); - } - - public TElement LastOrDefault(int minIdx, int maxIdx) - { - Buffer buffer = new Buffer(source); - int count = buffer.count; - if (minIdx >= count) return default(TElement); - if (maxIdx < count - 1) return GetEnumerableSorter().ElementAt(buffer.items, count, maxIdx); - return Last(buffer); - } - - private TElement Last(Buffer buffer) - { - CachingComparer comparer = GetComparer(); - TElement[] items = buffer.items; - int count = buffer.count; - TElement value = items[0]; - comparer.SetElement(value); - for (int i = 1; i != count; ++i) - { - TElement x = items[i]; - if (comparer.Compare(x, false) >= 0) value = x; - } - return value; - } - - public TElement Last(Func predicate) - { - CachingComparer comparer = GetComparer(); - using (IEnumerator e = source.GetEnumerator()) - { - TElement value; - do - { - if (!e.MoveNext()) throw Error.NoMatch(); - value = e.Current; - } while (!predicate(value)); - comparer.SetElement(value); - while (e.MoveNext()) - { - TElement x = e.Current; - if (predicate(x) && comparer.Compare(x, false) >= 0) value = x; - } - return value; - } - } - - public TElement LastOrDefault(Func predicate) - { - CachingComparer comparer = GetComparer(); - using (IEnumerator e = source.GetEnumerator()) - { - TElement value; - do - { - if (!e.MoveNext()) return default(TElement); - value = e.Current; - } while (!predicate(value)); - comparer.SetElement(value); - while (e.MoveNext()) - { - TElement x = e.Current; - if (predicate(x) && comparer.Compare(x, false) > 0) value = x; - } - return value; - } - } - } - - internal sealed class OrderedEnumerable : OrderedEnumerable - { - internal OrderedEnumerable parent; - internal Func keySelector; - internal IComparer comparer; - internal bool descending; - - internal OrderedEnumerable(IEnumerable source, Func keySelector, IComparer comparer, bool descending) - { - if (source == null) throw Error.ArgumentNull("source"); - if (keySelector == null) throw Error.ArgumentNull("keySelector"); - this.source = source; - this.parent = null; - this.keySelector = keySelector; - this.comparer = comparer != null ? comparer : Comparer.Default; - this.descending = descending; - } - - internal override EnumerableSorter GetEnumerableSorter(EnumerableSorter next) - { - EnumerableSorter sorter = new EnumerableSorter(keySelector, comparer, descending, next); - if (parent != null) sorter = parent.GetEnumerableSorter(sorter); - return sorter; - } - - internal override CachingComparer GetComparer(CachingComparer childComparer) - { - CachingComparer cmp = childComparer == null - ? new CachingComparer(keySelector, comparer, descending) - : new CachingComparerWithChild(keySelector, comparer, descending, childComparer); - return parent != null ? parent.GetComparer(cmp) : cmp; - } - } - - // A comparer that chains comparisons, and pushes through the last element found to be - // lower or higher (depending on use), so as to represent the sort of comparisons - // done by OrderBy().ThenBy() combinations. - internal abstract class CachingComparer - { - internal abstract int Compare(TElement element, bool cacheLower); - internal abstract void SetElement(TElement element); - } - - internal class CachingComparer : CachingComparer - { - protected readonly Func KeySelector; - protected readonly IComparer Comparer; - protected readonly bool Descending; - protected TKey LastKey; - public CachingComparer(Func keySelector, IComparer comparer, bool descending) - { - KeySelector = keySelector; - Comparer = comparer; - Descending = descending; - } - internal override int Compare(TElement element, bool cacheLower) - { - TKey newKey = KeySelector(element); - int cmp = Descending ? Comparer.Compare(LastKey, newKey) : Comparer.Compare(newKey, LastKey); - if (cacheLower == cmp < 0) LastKey = newKey; - return cmp; - } - internal override void SetElement(TElement element) - { - LastKey = KeySelector(element); - } - } - - internal sealed class CachingComparerWithChild : CachingComparer - { - private readonly CachingComparer _child; - public CachingComparerWithChild(Func keySelector, IComparer comparer, bool descending, CachingComparer child) - : base(keySelector, comparer, descending) - { - _child = child; - } - internal override int Compare(TElement element, bool cacheLower) - { - TKey newKey = KeySelector(element); - int cmp = Descending ? Comparer.Compare(LastKey, newKey) : Comparer.Compare(newKey, LastKey); - if (cmp == 0) return _child.Compare(element, cacheLower); - if (cacheLower == cmp < 0) - { - LastKey = newKey; - _child.SetElement(element); - } - return cmp; - } - internal override void SetElement(TElement element) - { - base.SetElement(element); - _child.SetElement(element); - } - } - - internal abstract class EnumerableSorter - { - internal abstract void ComputeKeys(TElement[] elements, int count); - - internal abstract int CompareAnyKeys(int index1, int index2); - - private int[] ComputeMap(TElement[] elements, int count) - { - ComputeKeys(elements, count); - int[] map = new int[count]; - for (int i = 0; i < count; i++) map[i] = i; - return map; - } - - internal int[] Sort(TElement[] elements, int count) - { - int[] map = ComputeMap(elements, count); - QuickSort(map, 0, count - 1); - return map; - } - - internal int[] Sort(TElement[] elements, int count, int minIdx, int maxIdx) - { - int[] map = ComputeMap(elements, count); - PartialQuickSort(map, 0, count - 1, minIdx, maxIdx); - return map; - } - - internal TElement ElementAt(TElement[] elements, int count, int idx) - { - return elements[QuickSelect(ComputeMap(elements, count), count - 1, idx)]; - } - - private int CompareKeys(int index1, int index2) - { - return index1 == index2 ? 0 : CompareAnyKeys(index1, index2); - } - - private void QuickSort(int[] map, int left, int right) - { - do - { - int i = left; - int j = right; - int x = map[i + ((j - i) >> 1)]; - do - { - while (i < map.Length && CompareKeys(x, map[i]) > 0) i++; - while (j >= 0 && CompareKeys(x, map[j]) < 0) j--; - if (i > j) break; - if (i < j) - { - int temp = map[i]; - map[i] = map[j]; - map[j] = temp; - } - i++; - j--; - } while (i <= j); - if (j - left <= right - i) - { - if (left < j) QuickSort(map, left, j); - left = i; - } - else - { - if (i < right) QuickSort(map, i, right); - right = j; - } - } while (left < right); - } - - // Sorts the k elements between minIdx and maxIdx without sorting all elements - // Time complexity: O(n + k log k) best and average case. O(n^2) worse case. - private void PartialQuickSort(int[] map, int left, int right, int minIdx, int maxIdx) - { - do - { - int i = left; - int j = right; - int x = map[i + ((j - i) >> 1)]; - do - { - while (i < map.Length && CompareKeys(x, map[i]) > 0) i++; - while (j >= 0 && CompareKeys(x, map[j]) < 0) j--; - if (i > j) break; - if (i < j) - { - int temp = map[i]; - map[i] = map[j]; - map[j] = temp; - } - i++; - j--; - } while (i <= j); - if (minIdx >= i) left = i + 1; - else if (maxIdx <= j) right = j - 1; - if (j - left <= right - i) - { - if (left < j) PartialQuickSort(map, left, j, minIdx, maxIdx); - left = i; - } - else - { - if (i < right) PartialQuickSort(map, i, right, minIdx, maxIdx); - right = j; - } - } while (left < right); - } - - // Finds the element that would be at idx if the collection was sorted. - // Time complexity: O(n) best and average case. O(n^2) worse case. - private int QuickSelect(int[] map, int right, int idx) - { - int left = 0; - do - { - int i = left; - int j = right; - int x = map[i + ((j - i) >> 1)]; - do - { - while (i < map.Length && CompareKeys(x, map[i]) > 0) i++; - while (j >= 0 && CompareKeys(x, map[j]) < 0) j--; - if (i > j) break; - if (i < j) - { - int temp = map[i]; - map[i] = map[j]; - map[j] = temp; - } - i++; - j--; - } while (i <= j); - if (i <= idx) left = i + 1; - else right = j - 1; - if (j - left <= right - i) - { - if (left < j) right = j; - left = i; - } - else - { - if (i < right) left = i; - right = j; - } - } while (left < right); - return map[idx]; - } - } - - internal sealed class EnumerableSorter : EnumerableSorter - { - internal Func keySelector; - internal IComparer comparer; - internal bool descending; - internal EnumerableSorter next; - internal TKey[] keys; - - internal EnumerableSorter(Func keySelector, IComparer comparer, bool descending, EnumerableSorter next) - { - this.keySelector = keySelector; - this.comparer = comparer; - this.descending = descending; - this.next = next; - } - - internal override void ComputeKeys(TElement[] elements, int count) - { - keys = new TKey[count]; - for (int i = 0; i < count; i++) keys[i] = keySelector(elements[i]); - if (next != null) next.ComputeKeys(elements, count); - } - - internal override int CompareAnyKeys(int index1, int index2) - { - int c = comparer.Compare(keys[index1], keys[index2]); - if (c == 0) - { - if (next == null) return index1 - index2; - return next.CompareAnyKeys(index1, index2); - } - // -c will result in a negative value for int.MinValue (-int.MinValue == int.MinValue). - // Flipping keys earlier is more likely to trigger something strange in a comparer, - // particularly as it comes to the sort being stable. - return (descending != (c > 0)) ? 1 : -1; - } - } - - internal struct Buffer - { - internal TElement[] items; - internal int count; - - internal Buffer(IEnumerable source) - { - IArrayProvider iterator = source as IArrayProvider; - if (iterator != null) - { - TElement[] array = iterator.ToArray(); - items = array; - count = array.Length; - } - else - { - items = EnumerableHelpers.ToArray(source, out count); - } - } - } - - // NOTE: DO NOT DELETE THE FOLLOWING DEBUG VIEW TYPES. - // Although it might be tempting due to them not be referenced anywhere in this library, - // Visual Studio currently depends on their existence to enable the "Results" view in - // watch windows. - - /// - /// This class provides the items view for the Enumerable - /// - /// - internal sealed class SystemCore_EnumerableDebugView - { - public SystemCore_EnumerableDebugView(IEnumerable enumerable) - { - if (enumerable == null) - { - throw new ArgumentNullException("enumerable"); - } - - _enumerable = enumerable; - } - - [System.Diagnostics.DebuggerBrowsable(System.Diagnostics.DebuggerBrowsableState.RootHidden)] - public T[] Items - { - get - { - T[] array = _enumerable.ToArray(); - if (array.Length == 0) - { - throw new SystemCore_EnumerableDebugViewEmptyException(); - } - return array; - } - } - - [System.Diagnostics.DebuggerBrowsable(System.Diagnostics.DebuggerBrowsableState.Never)] - private IEnumerable _enumerable; - } - - internal sealed class SystemCore_EnumerableDebugViewEmptyException : Exception - { - public string Empty - { - get - { - return SR.EmptyEnumerable; - } - } - } - - internal sealed class SystemCore_EnumerableDebugView - { - public SystemCore_EnumerableDebugView(IEnumerable enumerable) + public static IEnumerable AsEnumerable(this IEnumerable source) { - if (enumerable == null) - { - throw new ArgumentNullException("enumerable"); - } - - _enumerable = enumerable; + return source; } - [System.Diagnostics.DebuggerBrowsable(System.Diagnostics.DebuggerBrowsableState.RootHidden)] - public object[] Items + public static IEnumerable Empty() { - get - { - List tempList = new List(); - foreach (object item in _enumerable) - tempList.Add(item); - - if (tempList.Count == 0) - { - throw new SystemCore_EnumerableDebugViewEmptyException(); - } - return tempList.ToArray(); - } + return Array.Empty(); } - - [System.Diagnostics.DebuggerBrowsable(System.Diagnostics.DebuggerBrowsableState.Never)] - private IEnumerable _enumerable; } } diff --git a/src/System.Linq/src/System/Linq/Except.cs b/src/System.Linq/src/System/Linq/Except.cs new file mode 100644 index 000000000000..7f5841a304fe --- /dev/null +++ b/src/System.Linq/src/System/Linq/Except.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Except(this IEnumerable first, IEnumerable second) + { + if (first == null) throw Error.ArgumentNull("first"); + if (second == null) throw Error.ArgumentNull("second"); + return ExceptIterator(first, second, null); + } + + public static IEnumerable Except(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + if (first == null) throw Error.ArgumentNull("first"); + if (second == null) throw Error.ArgumentNull("second"); + return ExceptIterator(first, second, comparer); + } + + private static IEnumerable ExceptIterator(IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + Set set = new Set(comparer); + foreach (TSource element in second) set.Add(element); + foreach (TSource element in first) + if (set.Add(element)) yield return element; + } + } +} diff --git a/src/System.Linq/src/System/Linq/First.cs b/src/System.Linq/src/System/Linq/First.cs new file mode 100644 index 000000000000..359cc494df96 --- /dev/null +++ b/src/System.Linq/src/System/Linq/First.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static TSource First(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + IPartition partition = source as IPartition; + if (partition != null) return partition.First(); + IList list = source as IList; + if (list != null) + { + if (list.Count > 0) return list[0]; + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (e.MoveNext()) return e.Current; + } + } + throw Error.NoElements(); + } + + public static TSource First(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + OrderedEnumerable ordered = source as OrderedEnumerable; + if (ordered != null) return ordered.First(predicate); + foreach (TSource element in source) + { + if (predicate(element)) return element; + } + throw Error.NoMatch(); + } + + public static TSource FirstOrDefault(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + IPartition partition = source as IPartition; + if (partition != null) return partition.FirstOrDefault(); + IList list = source as IList; + if (list != null) + { + if (list.Count > 0) return list[0]; + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (e.MoveNext()) return e.Current; + } + } + return default(TSource); + } + + public static TSource FirstOrDefault(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + OrderedEnumerable ordered = source as OrderedEnumerable; + if (ordered != null) return ordered.FirstOrDefault(predicate); + foreach (TSource element in source) + { + if (predicate(element)) return element; + } + return default(TSource); + } + } +} diff --git a/src/System.Linq/src/System/Linq/GroupJoin.cs b/src/System.Linq/src/System/Linq/GroupJoin.cs new file mode 100644 index 000000000000..2f25541f012e --- /dev/null +++ b/src/System.Linq/src/System/Linq/GroupJoin.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable GroupJoin(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func, TResult> resultSelector) + { + if (outer == null) throw Error.ArgumentNull("outer"); + if (inner == null) throw Error.ArgumentNull("inner"); + if (outerKeySelector == null) throw Error.ArgumentNull("outerKeySelector"); + if (innerKeySelector == null) throw Error.ArgumentNull("innerKeySelector"); + if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); + return GroupJoinIterator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, null); + } + + public static IEnumerable GroupJoin(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func, TResult> resultSelector, IEqualityComparer comparer) + { + if (outer == null) throw Error.ArgumentNull("outer"); + if (inner == null) throw Error.ArgumentNull("inner"); + if (outerKeySelector == null) throw Error.ArgumentNull("outerKeySelector"); + if (innerKeySelector == null) throw Error.ArgumentNull("innerKeySelector"); + if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); + return GroupJoinIterator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + private static IEnumerable GroupJoinIterator(IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func, TResult> resultSelector, IEqualityComparer comparer) + { + using (IEnumerator e = outer.GetEnumerator()) + { + if (e.MoveNext()) + { + Lookup lookup = Lookup.CreateForJoin(inner, innerKeySelector, comparer); + do + { + TOuter item = e.Current; + yield return resultSelector(item, lookup[outerKeySelector(item)]); + } + while (e.MoveNext()); + } + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Grouping.cs b/src/System.Linq/src/System/Linq/Grouping.cs new file mode 100644 index 000000000000..d1d8e7705278 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Grouping.cs @@ -0,0 +1,248 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable> GroupBy(this IEnumerable source, Func keySelector) + { + return new GroupedEnumerable(source, keySelector, IdentityFunction.Instance, null); + } + + public static IEnumerable> GroupBy(this IEnumerable source, Func keySelector, IEqualityComparer comparer) + { + return new GroupedEnumerable(source, keySelector, IdentityFunction.Instance, comparer); + } + + public static IEnumerable> GroupBy(this IEnumerable source, Func keySelector, Func elementSelector) + { + return new GroupedEnumerable(source, keySelector, elementSelector, null); + } + + public static IEnumerable> GroupBy(this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + { + return new GroupedEnumerable(source, keySelector, elementSelector, comparer); + } + + public static IEnumerable GroupBy(this IEnumerable source, Func keySelector, Func, TResult> resultSelector) + { + return new GroupedEnumerable(source, keySelector, IdentityFunction.Instance, resultSelector, null); + } + + public static IEnumerable GroupBy(this IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector) + { + return new GroupedEnumerable(source, keySelector, elementSelector, resultSelector, null); + } + + public static IEnumerable GroupBy(this IEnumerable source, Func keySelector, Func, TResult> resultSelector, IEqualityComparer comparer) + { + return new GroupedEnumerable(source, keySelector, IdentityFunction.Instance, resultSelector, comparer); + } + + public static IEnumerable GroupBy(this IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer comparer) + { + return new GroupedEnumerable(source, keySelector, elementSelector, resultSelector, comparer); + } + } + + internal class IdentityFunction + { + public static Func Instance + { + get { return x => x; } + } + } + + public interface IGrouping : IEnumerable + { + TKey Key { get; } + } + + // It is (unfortunately) common to databind directly to Grouping.Key. + // Because of this, we have to declare this internal type public so that we + // can mark the Key property for public reflection. + // + // To limit the damage, the toolchain makes this type appear in a hidden assembly. + // (This is also why it is no longer a nested type of Lookup<,>). + // + public class Grouping : IGrouping, IList + { + internal TKey key; + internal int hashCode; + internal TElement[] elements; + internal int count; + internal Grouping hashNext; + internal Grouping next; + + internal Grouping() + { + } + + internal void Add(TElement element) + { + if (elements.Length == count) Array.Resize(ref elements, checked(count * 2)); + elements[count] = element; + count++; + } + + public IEnumerator GetEnumerator() + { + for (int i = 0; i < count; i++) yield return elements[i]; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + // DDB195907: implement IGrouping<>.Key implicitly + // so that WPF binding works on this property. + public TKey Key + { + get { return key; } + } + + int ICollection.Count + { + get { return count; } + } + + bool ICollection.IsReadOnly + { + get { return true; } + } + + void ICollection.Add(TElement item) + { + throw Error.NotSupported(); + } + + void ICollection.Clear() + { + throw Error.NotSupported(); + } + + bool ICollection.Contains(TElement item) + { + return Array.IndexOf(elements, item, 0, count) >= 0; + } + + void ICollection.CopyTo(TElement[] array, int arrayIndex) + { + Array.Copy(elements, 0, array, arrayIndex, count); + } + + bool ICollection.Remove(TElement item) + { + throw Error.NotSupported(); + } + + int IList.IndexOf(TElement item) + { + return Array.IndexOf(elements, item, 0, count); + } + + void IList.Insert(int index, TElement item) + { + throw Error.NotSupported(); + } + + void IList.RemoveAt(int index) + { + throw Error.NotSupported(); + } + + TElement IList.this[int index] + { + get + { + if (index < 0 || index >= count) throw Error.ArgumentOutOfRange("index"); + return elements[index]; + } + set + { + throw Error.NotSupported(); + } + } + } + + internal class GroupedEnumerable : IEnumerable + { + private IEnumerable _source; + private Func _keySelector; + private Func _elementSelector; + private IEqualityComparer _comparer; + private Func, TResult> _resultSelector; + + public GroupedEnumerable(IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer comparer) + { + if (source == null) throw Error.ArgumentNull("source"); + if (keySelector == null) throw Error.ArgumentNull("keySelector"); + if (elementSelector == null) throw Error.ArgumentNull("elementSelector"); + if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); + _source = source; + _keySelector = keySelector; + _elementSelector = elementSelector; + _comparer = comparer; + _resultSelector = resultSelector; + } + + public IEnumerator GetEnumerator() + { + Lookup lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); + return lookup.ApplyResultSelector(_resultSelector).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } + + internal class GroupedEnumerable : IEnumerable>, IArrayProvider>, IListProvider> + { + private IEnumerable _source; + private Func _keySelector; + private Func _elementSelector; + private IEqualityComparer _comparer; + + public GroupedEnumerable(IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + { + if (source == null) throw Error.ArgumentNull("source"); + if (keySelector == null) throw Error.ArgumentNull("keySelector"); + if (elementSelector == null) throw Error.ArgumentNull("elementSelector"); + _source = source; + _keySelector = keySelector; + _elementSelector = elementSelector; + _comparer = comparer; + } + + public IEnumerator> GetEnumerator() + { + return Lookup.Create(_source, _keySelector, _elementSelector, _comparer).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public IGrouping[] ToArray() + { + IArrayProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); + return lookup.ToArray(); + } + + public List> ToList() + { + IListProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); + return lookup.ToList(); + } + } +} diff --git a/src/System.Linq/src/System/Linq/Intersect.cs b/src/System.Linq/src/System/Linq/Intersect.cs new file mode 100644 index 000000000000..a42090a076e6 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Intersect.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Intersect(this IEnumerable first, IEnumerable second) + { + if (first == null) throw Error.ArgumentNull("first"); + if (second == null) throw Error.ArgumentNull("second"); + return IntersectIterator(first, second, null); + } + + public static IEnumerable Intersect(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + if (first == null) throw Error.ArgumentNull("first"); + if (second == null) throw Error.ArgumentNull("second"); + return IntersectIterator(first, second, comparer); + } + + private static IEnumerable IntersectIterator(IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + Set set = new Set(comparer); + foreach (TSource element in second) set.Add(element); + foreach (TSource element in first) + if (set.Remove(element)) yield return element; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Iterator.cs b/src/System.Linq/src/System/Linq/Iterator.cs new file mode 100644 index 000000000000..f07affd5c792 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Iterator.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + internal abstract class Iterator : IEnumerable, IEnumerator + { + private int _threadId; + internal int state; + internal TSource current; + + public Iterator() + { + _threadId = Environment.CurrentManagedThreadId; + } + + public TSource Current + { + get { return current; } + } + + public abstract Iterator Clone(); + + public virtual void Dispose() + { + current = default(TSource); + state = -1; + } + + public IEnumerator GetEnumerator() + { + Iterator enumerator = state == 0 && _threadId == Environment.CurrentManagedThreadId ? this : Clone(); + enumerator.state = 1; + return enumerator; + } + + public abstract bool MoveNext(); + + public virtual IEnumerable Select(Func selector) + { + return new SelectEnumerableIterator(this, selector); + } + + public virtual IEnumerable Where(Func predicate) + { + return new WhereEnumerableIterator(this, predicate); + } + + object IEnumerator.Current + { + get { return Current; } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + void IEnumerator.Reset() + { + throw Error.NotSupported(); + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Join.cs b/src/System.Linq/src/System/Linq/Join.cs new file mode 100644 index 000000000000..7bc4d26523a2 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Join.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Join(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector) + { + if (outer == null) throw Error.ArgumentNull("outer"); + if (inner == null) throw Error.ArgumentNull("inner"); + if (outerKeySelector == null) throw Error.ArgumentNull("outerKeySelector"); + if (innerKeySelector == null) throw Error.ArgumentNull("innerKeySelector"); + if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); + return JoinIterator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, null); + } + + public static IEnumerable Join(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) + { + if (outer == null) throw Error.ArgumentNull("outer"); + if (inner == null) throw Error.ArgumentNull("inner"); + if (outerKeySelector == null) throw Error.ArgumentNull("outerKeySelector"); + if (innerKeySelector == null) throw Error.ArgumentNull("innerKeySelector"); + if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); + return JoinIterator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + private static IEnumerable JoinIterator(IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) + { + Lookup lookup = Lookup.CreateForJoin(inner, innerKeySelector, comparer); + foreach (TOuter item in outer) + { + Grouping g = lookup.GetGrouping(outerKeySelector(item), false); + if (g != null) + { + for (int i = 0; i < g.count; i++) + { + yield return resultSelector(item, g.elements[i]); + } + } + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Last.cs b/src/System.Linq/src/System/Linq/Last.cs new file mode 100644 index 000000000000..9af5556da3ba --- /dev/null +++ b/src/System.Linq/src/System/Linq/Last.cs @@ -0,0 +1,137 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static TSource Last(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + IPartition partition = source as IPartition; + if (partition != null) return partition.Last(); + IList list = source as IList; + if (list != null) + { + int count = list.Count; + if (count > 0) return list[count - 1]; + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (e.MoveNext()) + { + TSource result; + do + { + result = e.Current; + } while (e.MoveNext()); + return result; + } + } + } + throw Error.NoElements(); + } + + public static TSource Last(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + OrderedEnumerable ordered = source as OrderedEnumerable; + if (ordered != null) return ordered.Last(predicate); + IList list = source as IList; + if (list != null) + { + for (int i = list.Count - 1; i >= 0; --i) + { + TSource result = list[i]; + if (predicate(result)) return result; + } + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + TSource result = e.Current; + if (predicate(result)) + { + while (e.MoveNext()) + { + TSource element = e.Current; + if (predicate(element)) result = element; + } + return result; + } + } + } + } + throw Error.NoMatch(); + } + + public static TSource LastOrDefault(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + IPartition partition = source as IPartition; + if (partition != null) return partition.LastOrDefault(); + IList list = source as IList; + if (list != null) + { + int count = list.Count; + if (count > 0) return list[count - 1]; + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (e.MoveNext()) + { + TSource result; + do + { + result = e.Current; + } while (e.MoveNext()); + return result; + } + } + } + return default(TSource); + } + + public static TSource LastOrDefault(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + OrderedEnumerable ordered = source as OrderedEnumerable; + if (ordered != null) return ordered.LastOrDefault(predicate); + IList list = source as IList; + if (list != null) + { + for (int i = list.Count - 1; i >= 0; --i) + { + TSource element = list[i]; + if (predicate(element)) return element; + } + return default(TSource); + } + else + { + TSource result = default(TSource); + foreach (TSource element in source) + { + if (predicate(element)) + { + result = element; + } + } + return result; + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Lookup.cs b/src/System.Linq/src/System/Linq/Lookup.cs new file mode 100644 index 000000000000..973bc3bb821f --- /dev/null +++ b/src/System.Linq/src/System/Linq/Lookup.cs @@ -0,0 +1,215 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static ILookup ToLookup(this IEnumerable source, Func keySelector) + { + return Lookup.Create(source, keySelector, IdentityFunction.Instance, null); + } + + public static ILookup ToLookup(this IEnumerable source, Func keySelector, IEqualityComparer comparer) + { + return Lookup.Create(source, keySelector, IdentityFunction.Instance, comparer); + } + + public static ILookup ToLookup(this IEnumerable source, Func keySelector, Func elementSelector) + { + return Lookup.Create(source, keySelector, elementSelector, null); + } + + public static ILookup ToLookup(this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + { + return Lookup.Create(source, keySelector, elementSelector, comparer); + } + } + + public interface ILookup : IEnumerable> + { + int Count { get; } + IEnumerable this[TKey key] { get; } + bool Contains(TKey key); + } + + public class Lookup : ILookup, IArrayProvider>, IListProvider> + { + private IEqualityComparer _comparer; + private Grouping[] _groupings; + private Grouping _lastGrouping; + private int _count; + + internal static Lookup Create(IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + { + if (source == null) throw Error.ArgumentNull("source"); + if (keySelector == null) throw Error.ArgumentNull("keySelector"); + if (elementSelector == null) throw Error.ArgumentNull("elementSelector"); + Lookup lookup = new Lookup(comparer); + foreach (TSource item in source) + { + lookup.GetGrouping(keySelector(item), true).Add(elementSelector(item)); + } + return lookup; + } + + internal static Lookup CreateForJoin(IEnumerable source, Func keySelector, IEqualityComparer comparer) + { + Lookup lookup = new Lookup(comparer); + foreach (TElement item in source) + { + TKey key = keySelector(item); + if (key != null) lookup.GetGrouping(key, true).Add(item); + } + return lookup; + } + + private Lookup(IEqualityComparer comparer) + { + if (comparer == null) comparer = EqualityComparer.Default; + _comparer = comparer; + _groupings = new Grouping[7]; + } + + public int Count + { + get { return _count; } + } + + public IEnumerable this[TKey key] + { + get + { + Grouping grouping = GetGrouping(key, false); + if (grouping != null) return grouping; + return Array.Empty(); + } + } + + public bool Contains(TKey key) + { + return GetGrouping(key, false) != null; + } + + public IEnumerator> GetEnumerator() + { + Grouping g = _lastGrouping; + if (g != null) + { + do + { + g = g.next; + yield return g; + } while (g != _lastGrouping); + } + } + + IGrouping[] IArrayProvider>.ToArray() + { + IGrouping[] array = new IGrouping[_count]; + int index = 0; + Grouping g = _lastGrouping; + if (g != null) + { + do + { + g = g.next; + array[index] = g; + ++index; + } while (g != _lastGrouping); + } + return array; + } + + List> IListProvider>.ToList() + { + List> list = new List>(_count); + Grouping g = _lastGrouping; + if (g != null) + { + do + { + g = g.next; + list.Add(g); + } while (g != _lastGrouping); + } + return list; + } + + public IEnumerable ApplyResultSelector(Func, TResult> resultSelector) + { + Grouping g = _lastGrouping; + if (g != null) + { + do + { + g = g.next; + if (g.count != g.elements.Length) { Array.Resize(ref g.elements, g.count); } + yield return resultSelector(g.key, g.elements); + } while (g != _lastGrouping); + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + internal int InternalGetHashCode(TKey key) + { + // Handle comparer implementations that throw when passed null + return (key == null) ? 0 : _comparer.GetHashCode(key) & 0x7FFFFFFF; + } + + internal Grouping GetGrouping(TKey key, bool create) + { + int hashCode = InternalGetHashCode(key); + for (Grouping g = _groupings[hashCode % _groupings.Length]; g != null; g = g.hashNext) + if (g.hashCode == hashCode && _comparer.Equals(g.key, key)) return g; + if (create) + { + if (_count == _groupings.Length) Resize(); + int index = hashCode % _groupings.Length; + Grouping g = new Grouping(); + g.key = key; + g.hashCode = hashCode; + g.elements = new TElement[1]; + g.hashNext = _groupings[index]; + _groupings[index] = g; + if (_lastGrouping == null) + { + g.next = g; + } + else + { + g.next = _lastGrouping.next; + _lastGrouping.next = g; + } + _lastGrouping = g; + _count++; + return g; + } + return null; + } + + private void Resize() + { + int newSize = checked(_count * 2 + 1); + Grouping[] newGroupings = new Grouping[newSize]; + Grouping g = _lastGrouping; + do + { + g = g.next; + int index = g.hashCode % newSize; + g.hashNext = newGroupings[index]; + newGroupings[index] = g; + } while (g != _lastGrouping); + _groupings = newGroupings; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Max.cs b/src/System.Linq/src/System/Linq/Max.cs new file mode 100644 index 000000000000..7f4bf7d08eb4 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Max.cs @@ -0,0 +1,671 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static int Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + int value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (e.MoveNext()) + { + int x = e.Current; + if (x > value) value = x; + } + } + return value; + } + + public static int? Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + int? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (!value.HasValue); + int valueVal = value.GetValueOrDefault(); + if (valueVal >= 0) + { + // We can fast-path this case where we know HasValue will + // never affect the outcome, without constantly checking + // if we're in such a state. Similar fast-paths could + // be done for other cases, but as all-positive + // or mostly-positive integer values are quite common in real-world + // uses, it's only been done in this direction for int? and long?. + while (e.MoveNext()) + { + int? cur = e.Current; + int x = cur.GetValueOrDefault(); + if (x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + else + { + while (e.MoveNext()) + { + int? cur = e.Current; + int x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + } + return value; + } + + public static long Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + long value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (e.MoveNext()) + { + long x = e.Current; + if (x > value) value = x; + } + } + return value; + } + + public static long? Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + long? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (!value.HasValue); + long valueVal = value.GetValueOrDefault(); + if (valueVal >= 0) + { + while (e.MoveNext()) + { + long? cur = e.Current; + long x = cur.GetValueOrDefault(); + if (x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + else + { + while (e.MoveNext()) + { + long? cur = e.Current; + long x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + } + return value; + } + + public static double Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + double value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + // As described in a comment on Min(this IEnumerable) NaN is ordered + // less than all other values. We need to do explicit checks to ensure this, but + // once we've found a value that is not NaN we need no longer worry about it, + // so first loop until such a value is found (or not, as the case may be). + while (double.IsNaN(value)) + { + if (!e.MoveNext()) return value; + value = e.Current; + } + while (e.MoveNext()) + { + double x = e.Current; + if (x > value) value = x; + } + } + return value; + } + + public static double? Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + double? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (!value.HasValue); + double valueVal = value.GetValueOrDefault(); + while (double.IsNaN(valueVal)) + { + if (!e.MoveNext()) return value; + double? cur = e.Current; + if (cur.HasValue) valueVal = (value = cur).GetValueOrDefault(); + } + while (e.MoveNext()) + { + double? cur = e.Current; + double x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static float Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + float value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (float.IsNaN(value)) + { + if (!e.MoveNext()) return value; + value = e.Current; + } + while (e.MoveNext()) + { + float x = e.Current; + if (x > value) value = x; + } + } + return value; + } + + public static float? Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + float? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (!value.HasValue); + float valueVal = value.GetValueOrDefault(); + while (float.IsNaN(valueVal)) + { + if (!e.MoveNext()) return value; + float? cur = e.Current; + if (cur.HasValue) valueVal = (value = cur).GetValueOrDefault(); + } + while (e.MoveNext()) + { + float? cur = e.Current; + float x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static decimal Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + decimal value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (e.MoveNext()) + { + decimal x = e.Current; + if (x > value) value = x; + } + } + return value; + } + + public static decimal? Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + decimal? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (!value.HasValue); + decimal valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + decimal? cur = e.Current; + decimal x = cur.GetValueOrDefault(); + if (cur.HasValue && x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static TSource Max(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + Comparer comparer = Comparer.Default; + TSource value = default(TSource); + if (value == null) + { + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (value == null); + while (e.MoveNext()) + { + TSource x = e.Current; + if (x != null && comparer.Compare(x, value) > 0) value = x; + } + } + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (e.MoveNext()) + { + TSource x = e.Current; + if (comparer.Compare(x, value) > 0) value = x; + } + } + } + return value; + } + + public static int Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + int value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (e.MoveNext()) + { + int x = selector(e.Current); + if (x > value) value = x; + } + } + return value; + } + + public static int? Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + int? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (!value.HasValue); + int valueVal = value.GetValueOrDefault(); + if (valueVal >= 0) + { + // We can fast-path this case where we know HasValue will + // never affect the outcome, without constantly checking + // if we're in such a state. Similar fast-paths could + // be done for other cases, but as all-positive + // or mostly-positive integer values are quite common in real-world + // uses, it's only been done in this direction for int? and long?. + while (e.MoveNext()) + { + int? cur = selector(e.Current); + int x = cur.GetValueOrDefault(); + if (x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + else + { + while (e.MoveNext()) + { + int? cur = selector(e.Current); + int x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + } + return value; + } + + public static long Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + long value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (e.MoveNext()) + { + long x = selector(e.Current); + if (x > value) value = x; + } + } + return value; + } + + public static long? Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + long? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (!value.HasValue); + long valueVal = value.GetValueOrDefault(); + if (valueVal >= 0) + { + while (e.MoveNext()) + { + long? cur = selector(e.Current); + long x = cur.GetValueOrDefault(); + if (x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + else + { + while (e.MoveNext()) + { + long? cur = selector(e.Current); + long x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + } + return value; + } + + public static float Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + float value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (float.IsNaN(value)) + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } + while (e.MoveNext()) + { + float x = selector(e.Current); + if (x > value) value = x; + } + } + return value; + } + + public static float? Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + float? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (!value.HasValue); + float valueVal = value.GetValueOrDefault(); + while (float.IsNaN(valueVal)) + { + if (!e.MoveNext()) return value; + float? cur = selector(e.Current); + if (cur.HasValue) valueVal = (value = cur).GetValueOrDefault(); + } + while (e.MoveNext()) + { + float? cur = selector(e.Current); + float x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static double Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + double value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + // As described in a comment on Min(this IEnumerable) NaN is ordered + // less than all other values. We need to do explicit checks to ensure this, but + // once we've found a value that is not NaN we need no longer worry about it, + // so first loop until such a value is found (or not, as the case may be). + while (double.IsNaN(value)) + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } + while (e.MoveNext()) + { + double x = selector(e.Current); + if (x > value) value = x; + } + } + return value; + } + + public static double? Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + double? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (!value.HasValue); + double valueVal = value.GetValueOrDefault(); + while (double.IsNaN(valueVal)) + { + if (!e.MoveNext()) return value; + double? cur = selector(e.Current); + if (cur.HasValue) valueVal = (value = cur).GetValueOrDefault(); + } + while (e.MoveNext()) + { + double? cur = selector(e.Current); + double x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static decimal Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + decimal value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (e.MoveNext()) + { + decimal x = selector(e.Current); + if (x > value) value = x; + } + } + return value; + } + + public static decimal? Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + decimal? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (!value.HasValue); + decimal valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + decimal? cur = selector(e.Current); + decimal x = cur.GetValueOrDefault(); + if (cur.HasValue && x > valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static TResult Max(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + Comparer comparer = Comparer.Default; + TResult value = default(TResult); + if (value == null) + { + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (value == null); + while (e.MoveNext()) + { + TResult x = selector(e.Current); + if (x != null && comparer.Compare(x, value) > 0) value = x; + } + } + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (e.MoveNext()) + { + TResult x = selector(e.Current); + if (comparer.Compare(x, value) > 0) value = x; + } + } + } + return value; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Min.cs b/src/System.Linq/src/System/Linq/Min.cs new file mode 100644 index 000000000000..2961171ee8b8 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Min.cs @@ -0,0 +1,579 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static int Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + int value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (e.MoveNext()) + { + int x = e.Current; + if (x < value) value = x; + } + } + return value; + } + + public static int? Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + int? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + // Start off knowing that we've a non-null value (or exit here, knowing we don't) + // so we don't have to keep testing for nullity. + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (!value.HasValue); + // Keep hold of the wrapped value, and do comparisons on that, rather than + // using the lifted operation each time. + int valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + int? cur = e.Current; + int x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x < valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static long Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + long value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (e.MoveNext()) + { + long x = e.Current; + if (x < value) value = x; + } + } + return value; + } + + public static long? Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + long? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (!value.HasValue); + long valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + long? cur = e.Current; + long x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x < valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static float Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + float value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (e.MoveNext()) + { + float x = e.Current; + if (x < value) value = x; + // Normally NaN < anything is false, as is anything < NaN + // However, this leads to some irksome outcomes in Min and Max. + // If we use those semantics then Min(NaN, 5.0) is NaN, but + // Min(5.0, NaN) is 5.0! To fix this, we impose a total + // ordering where NaN is smaller than every value, including + // negative infinity. + // Not testing for NaN therefore isn't an option, but since we + // can't find a smaller value, we can short-circuit. + else if (float.IsNaN(x)) return x; + } + } + return value; + } + + public static float? Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + float? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (!value.HasValue); + float valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + float? cur = e.Current; + if (cur.HasValue) + { + float x = cur.GetValueOrDefault(); + if (x < valueVal) + { + valueVal = x; + value = cur; + } + else if (float.IsNaN(x)) return cur; + } + } + } + return value; + } + + public static double Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + double value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (e.MoveNext()) + { + double x = e.Current; + if (x < value) value = x; + else if (double.IsNaN(x)) return x; + } + } + return value; + } + + public static double? Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + double? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (!value.HasValue); + double valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + double? cur = e.Current; + if (cur.HasValue) + { + double x = cur.GetValueOrDefault(); + if (x < valueVal) + { + valueVal = x; + value = cur; + } + else if (double.IsNaN(x)) return cur; + } + } + } + return value; + } + + public static decimal Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + decimal value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (e.MoveNext()) + { + decimal x = e.Current; + if (x < value) value = x; + } + } + return value; + } + + public static decimal? Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + decimal? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (!value.HasValue); + decimal valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + decimal? cur = e.Current; + decimal x = cur.GetValueOrDefault(); + if (cur.HasValue && x < valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static TSource Min(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + Comparer comparer = Comparer.Default; + TSource value = default(TSource); + if (value == null) + { + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = e.Current; + } while (value == null); + while (e.MoveNext()) + { + TSource x = e.Current; + if (x != null && comparer.Compare(x, value) < 0) value = x; + } + } + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = e.Current; + while (e.MoveNext()) + { + TSource x = e.Current; + if (comparer.Compare(x, value) < 0) value = x; + } + } + } + return value; + } + + public static int Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + int value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (e.MoveNext()) + { + int x = selector(e.Current); + if (x < value) value = x; + } + } + return value; + } + + public static int? Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + int? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + // Start off knowing that we've a non-null value (or exit here, knowing we don't) + // so we don't have to keep testing for nullity. + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (!value.HasValue); + // Keep hold of the wrapped value, and do comparisons on that, rather than + // using the lifted operation each time. + int valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + int? cur = selector(e.Current); + int x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x < valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static long Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + long value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (e.MoveNext()) + { + long x = selector(e.Current); + if (x < value) value = x; + } + } + return value; + } + + public static long? Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + long? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (!value.HasValue); + long valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + long? cur = selector(e.Current); + long x = cur.GetValueOrDefault(); + // Do not replace & with &&. The branch prediction cost outweighs the extra operation + // unless nulls either never happen or always happen. + if (cur.HasValue & x < valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static float Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + float value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (e.MoveNext()) + { + float x = selector(e.Current); + if (x < value) value = x; + // Normally NaN < anything is false, as is anything < NaN + // However, this leads to some irksome outcomes in Min and Max. + // If we use those semantics then Min(NaN, 5.0) is NaN, but + // Min(5.0, NaN) is 5.0! To fix this, we impose a total + // ordering where NaN is smaller than every value, including + // negative infinity. + // Not testing for NaN therefore isn't an option, but since we + // can't find a smaller value, we can short-circuit. + else if (float.IsNaN(x)) return x; + } + } + return value; + } + + public static float? Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + float? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (!value.HasValue); + float valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + float? cur = selector(e.Current); + if (cur.HasValue) + { + float x = cur.GetValueOrDefault(); + if (x < valueVal) + { + valueVal = x; + value = cur; + } + else if (float.IsNaN(x)) return cur; + } + } + } + return value; + } + + public static double Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + double value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (e.MoveNext()) + { + double x = selector(e.Current); + if (x < value) value = x; + else if (double.IsNaN(x)) return x; + } + } + return value; + } + + public static double? Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + double? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (!value.HasValue); + double valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + double? cur = selector(e.Current); + if (cur.HasValue) + { + double x = cur.GetValueOrDefault(); + if (x < valueVal) + { + valueVal = x; + value = cur; + } + else if (double.IsNaN(x)) return cur; + } + } + } + return value; + } + + public static decimal Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + decimal value; + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (e.MoveNext()) + { + decimal x = selector(e.Current); + if (x < value) value = x; + } + } + return value; + } + + public static decimal? Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + decimal? value = null; + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (!value.HasValue); + decimal valueVal = value.GetValueOrDefault(); + while (e.MoveNext()) + { + decimal? cur = selector(e.Current); + decimal x = cur.GetValueOrDefault(); + if (cur.HasValue && x < valueVal) + { + valueVal = x; + value = cur; + } + } + } + return value; + } + + public static TResult Min(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + Comparer comparer = Comparer.Default; + TResult value = default(TResult); + if (value == null) + { + using (IEnumerator e = source.GetEnumerator()) + { + do + { + if (!e.MoveNext()) return value; + value = selector(e.Current); + } while (value == null); + while (e.MoveNext()) + { + TResult x = selector(e.Current); + if (x != null && comparer.Compare(x, value) < 0) value = x; + } + } + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + value = selector(e.Current); + while (e.MoveNext()) + { + TResult x = selector(e.Current); + if (comparer.Compare(x, value) < 0) value = x; + } + } + } + return value; + } + } +} diff --git a/src/System.Linq/src/System/Linq/OrderBy.cs b/src/System.Linq/src/System/Linq/OrderBy.cs new file mode 100644 index 000000000000..ed82dbc2fedd --- /dev/null +++ b/src/System.Linq/src/System/Linq/OrderBy.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IOrderedEnumerable OrderBy(this IEnumerable source, Func keySelector) + { + return new OrderedEnumerable(source, keySelector, null, false); + } + + public static IOrderedEnumerable OrderBy(this IEnumerable source, Func keySelector, IComparer comparer) + { + return new OrderedEnumerable(source, keySelector, comparer, false); + } + + public static IOrderedEnumerable OrderByDescending(this IEnumerable source, Func keySelector) + { + return new OrderedEnumerable(source, keySelector, null, true); + } + + public static IOrderedEnumerable OrderByDescending(this IEnumerable source, Func keySelector, IComparer comparer) + { + return new OrderedEnumerable(source, keySelector, comparer, true); + } + + public static IOrderedEnumerable ThenBy(this IOrderedEnumerable source, Func keySelector) + { + if (source == null) throw Error.ArgumentNull("source"); + return source.CreateOrderedEnumerable(keySelector, null, false); + } + + public static IOrderedEnumerable ThenBy(this IOrderedEnumerable source, Func keySelector, IComparer comparer) + { + if (source == null) throw Error.ArgumentNull("source"); + return source.CreateOrderedEnumerable(keySelector, comparer, false); + } + + public static IOrderedEnumerable ThenByDescending(this IOrderedEnumerable source, Func keySelector) + { + if (source == null) throw Error.ArgumentNull("source"); + return source.CreateOrderedEnumerable(keySelector, null, true); + } + + public static IOrderedEnumerable ThenByDescending(this IOrderedEnumerable source, Func keySelector, IComparer comparer) + { + if (source == null) throw Error.ArgumentNull("source"); + return source.CreateOrderedEnumerable(keySelector, comparer, true); + } + } + + public interface IOrderedEnumerable : IEnumerable + { + IOrderedEnumerable CreateOrderedEnumerable(Func keySelector, IComparer comparer, bool descending); + } +} diff --git a/src/System.Linq/src/System/Linq/OrderedEnumerable.cs b/src/System.Linq/src/System/Linq/OrderedEnumerable.cs new file mode 100644 index 000000000000..88dc047b066d --- /dev/null +++ b/src/System.Linq/src/System/Linq/OrderedEnumerable.cs @@ -0,0 +1,645 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace System.Linq +{ + internal abstract class OrderedEnumerable : IOrderedEnumerable, IArrayProvider, IListProvider, IPartition + { + internal IEnumerable source; + + private int[] SortedMap(Buffer buffer) + { + return GetEnumerableSorter().Sort(buffer.items, buffer.count); + } + + private int[] SortedMap(Buffer buffer, int minIdx, int maxIdx) + { + return GetEnumerableSorter().Sort(buffer.items, buffer.count, minIdx, maxIdx); + } + + public IEnumerator GetEnumerator() + { + Buffer buffer = new Buffer(source); + if (buffer.count > 0) + { + int[] map = SortedMap(buffer); + for (int i = 0; i < buffer.count; i++) yield return buffer.items[map[i]]; + } + } + + public TElement[] ToArray() + { + Buffer buffer = new Buffer(source); + + int count = buffer.count; + if (count == 0) + { + return buffer.items; + } + + TElement[] array = new TElement[count]; + int[] map = SortedMap(buffer); + for (int i = 0; i != array.Length; i++) array[i] = buffer.items[map[i]]; + + return array; + } + + public List ToList() + { + Buffer buffer = new Buffer(source); + int count = buffer.count; + List list = new List(count); + if (count > 0) + { + int[] map = SortedMap(buffer); + for (int i = 0; i != count; i++) list.Add(buffer.items[map[i]]); + } + + return list; + } + + internal IEnumerator GetEnumerator(int minIdx, int maxIdx) + { + Buffer buffer = new Buffer(source); + int count = buffer.count; + if (count > minIdx) + { + if (count <= maxIdx) maxIdx = count - 1; + if (minIdx == maxIdx) yield return GetEnumerableSorter().ElementAt(buffer.items, count, minIdx); + else + { + int[] map = SortedMap(buffer, minIdx, maxIdx); + while (minIdx <= maxIdx) + { + yield return buffer.items[map[minIdx]]; + ++minIdx; + } + } + } + } + + internal TElement[] ToArray(int minIdx, int maxIdx) + { + Buffer buffer = new Buffer(source); + int count = buffer.count; + if (count <= minIdx) return Array.Empty(); + if (count <= maxIdx) maxIdx = count - 1; + if (minIdx == maxIdx) return new TElement[] { GetEnumerableSorter().ElementAt(buffer.items, count, minIdx) }; + int[] map = SortedMap(buffer, minIdx, maxIdx); + TElement[] array = new TElement[maxIdx - minIdx + 1]; + int idx = 0; + while (minIdx <= maxIdx) + { + array[idx] = buffer.items[map[minIdx]]; + ++idx; + ++minIdx; + } + return array; + } + + private EnumerableSorter GetEnumerableSorter() + { + return GetEnumerableSorter(null); + } + + internal abstract EnumerableSorter GetEnumerableSorter(EnumerableSorter next); + + internal CachingComparer GetComparer() + { + return GetComparer(null); + } + + internal abstract CachingComparer GetComparer(CachingComparer childComparer); + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + IOrderedEnumerable IOrderedEnumerable.CreateOrderedEnumerable(Func keySelector, IComparer comparer, bool descending) + { + OrderedEnumerable result = new OrderedEnumerable(source, keySelector, comparer, descending); + result.parent = this; + return result; + } + + public IPartition Skip(int count) + { + return new OrderedPartition(this, count, int.MaxValue); + } + + public IPartition Take(int count) + { + return new OrderedPartition(this, 0, count - 1); + } + + public bool TryGetElementAt(int index, out TElement result) + { + if (index == 0) return TryGetFirst(out result); + if (index > 0) + { + Buffer buffer = new Buffer(source); + int count = buffer.count; + if (index < count) + { + result = GetEnumerableSorter().ElementAt(buffer.items, count, index); + return true; + } + } + result = default(TElement); + return false; + } + + public TElement ElementAt(int index) + { + TElement result; + if (!TryGetElementAt(index, out result)) throw Error.ArgumentOutOfRange("index"); + return result; + } + + public TElement ElementAtOrDefault(int index) + { + TElement result; + TryGetElementAt(index, out result); + return result; + } + + private bool TryGetFirst(out TElement result) + { + CachingComparer comparer = GetComparer(); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) + { + result = default(TElement); + return false; + } + TElement value = e.Current; + comparer.SetElement(value); + while (e.MoveNext()) + { + TElement x = e.Current; + if (comparer.Compare(x, true) < 0) value = x; + } + result = value; + return true; + } + } + + public TElement FirstOrDefault() + { + TElement result; + TryGetFirst(out result); + return result; + } + + public TElement First() + { + TElement result; + if (!TryGetFirst(out result)) throw Error.NoElements(); + return result; + } + + public TElement First(Func predicate) + { + CachingComparer comparer = GetComparer(); + using (IEnumerator e = source.GetEnumerator()) + { + TElement value; + do + { + if (!e.MoveNext()) throw Error.NoMatch(); + value = e.Current; + } while (!predicate(value)); + comparer.SetElement(value); + while (e.MoveNext()) + { + TElement x = e.Current; + if (predicate(x) && comparer.Compare(x, true) < 0) value = x; + } + return value; + } + } + + public TElement FirstOrDefault(Func predicate) + { + CachingComparer comparer = GetComparer(); + using (IEnumerator e = source.GetEnumerator()) + { + TElement value; + do + { + if (!e.MoveNext()) return default(TElement); + value = e.Current; + } while (!predicate(value)); + comparer.SetElement(value); + while (e.MoveNext()) + { + TElement x = e.Current; + if (predicate(x) && comparer.Compare(x, true) < 0) value = x; + } + return value; + } + } + + public TElement Last() + { + CachingComparer comparer = GetComparer(); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + TElement value = e.Current; + comparer.SetElement(value); + while (e.MoveNext()) + { + TElement x = e.Current; + if (comparer.Compare(x, false) >= 0) value = x; + } + return value; + } + } + + public TElement LastOrDefault() + { + CachingComparer comparer = GetComparer(); + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) return default(TElement); + TElement value = e.Current; + comparer.SetElement(value); + while (e.MoveNext()) + { + TElement x = e.Current; + if (comparer.Compare(x, false) >= 0) value = x; + } + return value; + } + } + + public TElement Last(int minIdx, int maxIdx) + { + Buffer buffer = new Buffer(source); + int count = buffer.count; + if (minIdx >= count) throw Error.NoElements(); + if (maxIdx < count - 1) return GetEnumerableSorter().ElementAt(buffer.items, count, maxIdx); + // If we're here, we want the same results we would have got from + // Last(), but we've already buffered our source. + return Last(buffer); + } + + public TElement LastOrDefault(int minIdx, int maxIdx) + { + Buffer buffer = new Buffer(source); + int count = buffer.count; + if (minIdx >= count) return default(TElement); + if (maxIdx < count - 1) return GetEnumerableSorter().ElementAt(buffer.items, count, maxIdx); + return Last(buffer); + } + + private TElement Last(Buffer buffer) + { + CachingComparer comparer = GetComparer(); + TElement[] items = buffer.items; + int count = buffer.count; + TElement value = items[0]; + comparer.SetElement(value); + for (int i = 1; i != count; ++i) + { + TElement x = items[i]; + if (comparer.Compare(x, false) >= 0) value = x; + } + return value; + } + + public TElement Last(Func predicate) + { + CachingComparer comparer = GetComparer(); + using (IEnumerator e = source.GetEnumerator()) + { + TElement value; + do + { + if (!e.MoveNext()) throw Error.NoMatch(); + value = e.Current; + } while (!predicate(value)); + comparer.SetElement(value); + while (e.MoveNext()) + { + TElement x = e.Current; + if (predicate(x) && comparer.Compare(x, false) >= 0) value = x; + } + return value; + } + } + + public TElement LastOrDefault(Func predicate) + { + CachingComparer comparer = GetComparer(); + using (IEnumerator e = source.GetEnumerator()) + { + TElement value; + do + { + if (!e.MoveNext()) return default(TElement); + value = e.Current; + } while (!predicate(value)); + comparer.SetElement(value); + while (e.MoveNext()) + { + TElement x = e.Current; + if (predicate(x) && comparer.Compare(x, false) > 0) value = x; + } + return value; + } + } + } + + internal sealed class OrderedEnumerable : OrderedEnumerable + { + internal OrderedEnumerable parent; + internal Func keySelector; + internal IComparer comparer; + internal bool descending; + + internal OrderedEnumerable(IEnumerable source, Func keySelector, IComparer comparer, bool descending) + { + if (source == null) throw Error.ArgumentNull("source"); + if (keySelector == null) throw Error.ArgumentNull("keySelector"); + this.source = source; + this.parent = null; + this.keySelector = keySelector; + this.comparer = comparer != null ? comparer : Comparer.Default; + this.descending = descending; + } + + internal override EnumerableSorter GetEnumerableSorter(EnumerableSorter next) + { + EnumerableSorter sorter = new EnumerableSorter(keySelector, comparer, descending, next); + if (parent != null) sorter = parent.GetEnumerableSorter(sorter); + return sorter; + } + + internal override CachingComparer GetComparer(CachingComparer childComparer) + { + CachingComparer cmp = childComparer == null + ? new CachingComparer(keySelector, comparer, descending) + : new CachingComparerWithChild(keySelector, comparer, descending, childComparer); + return parent != null ? parent.GetComparer(cmp) : cmp; + } + } + + // A comparer that chains comparisons, and pushes through the last element found to be + // lower or higher (depending on use), so as to represent the sort of comparisons + // done by OrderBy().ThenBy() combinations. + internal abstract class CachingComparer + { + internal abstract int Compare(TElement element, bool cacheLower); + internal abstract void SetElement(TElement element); + } + + internal class CachingComparer : CachingComparer + { + protected readonly Func KeySelector; + protected readonly IComparer Comparer; + protected readonly bool Descending; + protected TKey LastKey; + public CachingComparer(Func keySelector, IComparer comparer, bool descending) + { + KeySelector = keySelector; + Comparer = comparer; + Descending = descending; + } + internal override int Compare(TElement element, bool cacheLower) + { + TKey newKey = KeySelector(element); + int cmp = Descending ? Comparer.Compare(LastKey, newKey) : Comparer.Compare(newKey, LastKey); + if (cacheLower == cmp < 0) LastKey = newKey; + return cmp; + } + internal override void SetElement(TElement element) + { + LastKey = KeySelector(element); + } + } + + internal sealed class CachingComparerWithChild : CachingComparer + { + private readonly CachingComparer _child; + public CachingComparerWithChild(Func keySelector, IComparer comparer, bool descending, CachingComparer child) + : base(keySelector, comparer, descending) + { + _child = child; + } + internal override int Compare(TElement element, bool cacheLower) + { + TKey newKey = KeySelector(element); + int cmp = Descending ? Comparer.Compare(LastKey, newKey) : Comparer.Compare(newKey, LastKey); + if (cmp == 0) return _child.Compare(element, cacheLower); + if (cacheLower == cmp < 0) + { + LastKey = newKey; + _child.SetElement(element); + } + return cmp; + } + internal override void SetElement(TElement element) + { + base.SetElement(element); + _child.SetElement(element); + } + } + + internal abstract class EnumerableSorter + { + internal abstract void ComputeKeys(TElement[] elements, int count); + + internal abstract int CompareAnyKeys(int index1, int index2); + + private int[] ComputeMap(TElement[] elements, int count) + { + ComputeKeys(elements, count); + int[] map = new int[count]; + for (int i = 0; i < count; i++) map[i] = i; + return map; + } + + internal int[] Sort(TElement[] elements, int count) + { + int[] map = ComputeMap(elements, count); + QuickSort(map, 0, count - 1); + return map; + } + + internal int[] Sort(TElement[] elements, int count, int minIdx, int maxIdx) + { + int[] map = ComputeMap(elements, count); + PartialQuickSort(map, 0, count - 1, minIdx, maxIdx); + return map; + } + + internal TElement ElementAt(TElement[] elements, int count, int idx) + { + return elements[QuickSelect(ComputeMap(elements, count), count - 1, idx)]; + } + + private int CompareKeys(int index1, int index2) + { + return index1 == index2 ? 0 : CompareAnyKeys(index1, index2); + } + + private void QuickSort(int[] map, int left, int right) + { + do + { + int i = left; + int j = right; + int x = map[i + ((j - i) >> 1)]; + do + { + while (i < map.Length && CompareKeys(x, map[i]) > 0) i++; + while (j >= 0 && CompareKeys(x, map[j]) < 0) j--; + if (i > j) break; + if (i < j) + { + int temp = map[i]; + map[i] = map[j]; + map[j] = temp; + } + i++; + j--; + } while (i <= j); + if (j - left <= right - i) + { + if (left < j) QuickSort(map, left, j); + left = i; + } + else + { + if (i < right) QuickSort(map, i, right); + right = j; + } + } while (left < right); + } + + // Sorts the k elements between minIdx and maxIdx without sorting all elements + // Time complexity: O(n + k log k) best and average case. O(n^2) worse case. + private void PartialQuickSort(int[] map, int left, int right, int minIdx, int maxIdx) + { + do + { + int i = left; + int j = right; + int x = map[i + ((j - i) >> 1)]; + do + { + while (i < map.Length && CompareKeys(x, map[i]) > 0) i++; + while (j >= 0 && CompareKeys(x, map[j]) < 0) j--; + if (i > j) break; + if (i < j) + { + int temp = map[i]; + map[i] = map[j]; + map[j] = temp; + } + i++; + j--; + } while (i <= j); + if (minIdx >= i) left = i + 1; + else if (maxIdx <= j) right = j - 1; + if (j - left <= right - i) + { + if (left < j) PartialQuickSort(map, left, j, minIdx, maxIdx); + left = i; + } + else + { + if (i < right) PartialQuickSort(map, i, right, minIdx, maxIdx); + right = j; + } + } while (left < right); + } + + // Finds the element that would be at idx if the collection was sorted. + // Time complexity: O(n) best and average case. O(n^2) worse case. + private int QuickSelect(int[] map, int right, int idx) + { + int left = 0; + do + { + int i = left; + int j = right; + int x = map[i + ((j - i) >> 1)]; + do + { + while (i < map.Length && CompareKeys(x, map[i]) > 0) i++; + while (j >= 0 && CompareKeys(x, map[j]) < 0) j--; + if (i > j) break; + if (i < j) + { + int temp = map[i]; + map[i] = map[j]; + map[j] = temp; + } + i++; + j--; + } while (i <= j); + if (i <= idx) left = i + 1; + else right = j - 1; + if (j - left <= right - i) + { + if (left < j) right = j; + left = i; + } + else + { + if (i < right) left = i; + right = j; + } + } while (left < right); + return map[idx]; + } + } + + internal sealed class EnumerableSorter : EnumerableSorter + { + internal Func keySelector; + internal IComparer comparer; + internal bool descending; + internal EnumerableSorter next; + internal TKey[] keys; + + internal EnumerableSorter(Func keySelector, IComparer comparer, bool descending, EnumerableSorter next) + { + this.keySelector = keySelector; + this.comparer = comparer; + this.descending = descending; + this.next = next; + } + + internal override void ComputeKeys(TElement[] elements, int count) + { + keys = new TKey[count]; + for (int i = 0; i < count; i++) keys[i] = keySelector(elements[i]); + if (next != null) next.ComputeKeys(elements, count); + } + + internal override int CompareAnyKeys(int index1, int index2) + { + int c = comparer.Compare(keys[index1], keys[index2]); + if (c == 0) + { + if (next == null) return index1 - index2; + return next.CompareAnyKeys(index1, index2); + } + // -c will result in a negative value for int.MinValue (-int.MinValue == int.MinValue). + // Flipping keys earlier is more likely to trigger something strange in a comparer, + // particularly as it comes to the sort being stable. + return (descending != (c > 0)) ? 1 : -1; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Partition.cs b/src/System.Linq/src/System/Linq/Partition.cs new file mode 100644 index 000000000000..5bc1a14c5e9f --- /dev/null +++ b/src/System.Linq/src/System/Linq/Partition.cs @@ -0,0 +1,225 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; + +namespace System.Linq +{ + /// + /// An iterator that can produce an array through an optimized path. + /// + internal interface IArrayProvider + { + /// + /// Produce an array of the sequence through an optimized path. + /// + /// The array. + TElement[] ToArray(); + } + + /// + /// An iterator that can produce a through an optimized path. + /// + internal interface IListProvider + { + /// + /// Produce a of the sequence through an optimized path. + /// + /// The . + List ToList(); + } + + internal interface IPartition : IEnumerable, IArrayProvider + { + IPartition Skip(int count); + + IPartition Take(int count); + + TElement ElementAt(int index); + + TElement ElementAtOrDefault(int index); + + TElement First(); + + TElement FirstOrDefault(); + + TElement Last(); + + TElement LastOrDefault(); + } + + internal sealed class EmptyPartition : IPartition, IListProvider, IEnumerator + { + public EmptyPartition() + { + } + + public IEnumerator GetEnumerator() + { + return this; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this; + } + + public bool MoveNext() + { + return false; + } + + [ExcludeFromCodeCoverage] // Shouldn't be called, and as undefined can return or throw anything anyway. + public TElement Current + { + get { return default(TElement); } + } + + [ExcludeFromCodeCoverage] // Shouldn't be called, and as undefined can return or throw anything anyway. + object IEnumerator.Current + { + get { return default(TElement); } + } + + void IEnumerator.Reset() + { + throw Error.NotSupported(); + } + + void IDisposable.Dispose() + { + // Do nothing. + } + + public IPartition Skip(int count) + { + return new EmptyPartition(); + } + + public IPartition Take(int count) + { + return new EmptyPartition(); + } + + public TElement ElementAt(int index) + { + throw Error.ArgumentOutOfRange("index"); + } + + public TElement ElementAtOrDefault(int index) + { + return default(TElement); + } + + public TElement First() + { + throw Error.NoElements(); + } + + public TElement FirstOrDefault() + { + return default(TElement); + } + + public TElement Last() + { + throw Error.NoElements(); + } + + public TElement LastOrDefault() + { + return default(TElement); + } + + public TElement[] ToArray() + { + return Array.Empty(); + } + + public List ToList() + { + return new List(); + } + } + + internal sealed class OrderedPartition : IPartition + { + private readonly OrderedEnumerable _source; + private readonly int _minIndex; + private readonly int _maxIndex; + + public OrderedPartition(OrderedEnumerable source, int minIdx, int maxIdx) + { + _source = source; + _minIndex = minIdx; + _maxIndex = maxIdx; + } + + public IEnumerator GetEnumerator() + { + return _source.GetEnumerator(_minIndex, _maxIndex); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public IPartition Skip(int count) + { + int minIndex = _minIndex + count; + return minIndex >= _maxIndex + ? (IPartition)new EmptyPartition() + : new OrderedPartition(_source, minIndex, _maxIndex); + } + + public IPartition Take(int count) + { + int maxIndex = _minIndex + count - 1; + if (maxIndex >= _maxIndex) maxIndex = _maxIndex; + return new OrderedPartition(_source, _minIndex, maxIndex); + } + + public TElement ElementAt(int index) + { + if ((uint)index > (uint)_maxIndex - _minIndex) throw Error.ArgumentOutOfRange("index"); + return _source.ElementAt(index + _minIndex); + } + + public TElement ElementAtOrDefault(int index) + { + return (uint)index <= (uint)_maxIndex - _minIndex ? _source.ElementAtOrDefault(index + _minIndex) : default(TElement); + } + + public TElement First() + { + TElement result; + if (!_source.TryGetElementAt(_minIndex, out result)) throw Error.NoElements(); + return result; + } + + public TElement FirstOrDefault() + { + return _source.ElementAtOrDefault(_minIndex); + } + + public TElement Last() + { + return _source.Last(_minIndex, _maxIndex); + } + + public TElement LastOrDefault() + { + return _source.LastOrDefault(_minIndex, _maxIndex); + } + + public TElement[] ToArray() + { + return _source.ToArray(_minIndex, _maxIndex); + } + } +} diff --git a/src/System.Linq/src/System/Linq/Range.cs b/src/System.Linq/src/System/Linq/Range.cs new file mode 100644 index 000000000000..de01fd403f73 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Range.cs @@ -0,0 +1,129 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Range(int start, int count) + { + long max = ((long)start) + count - 1; + if (count < 0 || max > Int32.MaxValue) throw Error.ArgumentOutOfRange("count"); + if (count == 0) return new EmptyPartition(); + return new RangeIterator(start, count); + } + + private sealed class RangeIterator : Iterator, IArrayProvider, IListProvider, IPartition + { + private readonly int _start; + private readonly int _end; + + public RangeIterator(int start, int count) + { + Debug.Assert(count > 0); + _start = start; + _end = start + count; + } + + public override Iterator Clone() + { + return new RangeIterator(_start, _end - _start); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + Debug.Assert(_start != _end); + current = _start; + state = 2; + return true; + case 2: + if (++current == _end) break; + return true; + } + state = -1; + return false; + } + + public override void Dispose() + { + state = -1; // Don't reset current + } + + public int[] ToArray() + { + int[] array = new int[_end - _start]; + int cur = _start; + for (int i = 0; i != array.Length; ++i) + { + array[i] = cur; + ++cur; + } + + return array; + } + + public List ToList() + { + List list = new List(_end - _start); + for (int cur = _start; cur != _end; cur++) + { + list.Add(cur); + } + + return list; + } + + public IPartition Skip(int count) + { + if (count >= _end - _start) return new EmptyPartition(); + return new RangeIterator(_start + count, _end - _start - count); + } + + public IPartition Take(int count) + { + int curCount = _end - _start; + if (count > curCount) count = curCount; + return new RangeIterator(_start, count); + } + + public int ElementAt(int index) + { + if ((uint)index >= (uint)(_end - _start)) throw Error.ArgumentOutOfRange("index"); + return _start + index; + } + + public int ElementAtOrDefault(int index) + { + return (uint)index >= (uint)(_end - _start) ? 0 : _start + index; + } + + public int First() + { + return _start; + } + + public int FirstOrDefault() + { + return _start; + } + + public int Last() + { + return _end - 1; + } + + public int LastOrDefault() + { + return _end - 1; + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Repeat.cs b/src/System.Linq/src/System/Linq/Repeat.cs new file mode 100644 index 000000000000..ee5445442df5 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Repeat.cs @@ -0,0 +1,117 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Repeat(TResult element, int count) + { + if (count < 0) throw Error.ArgumentOutOfRange("count"); + if (count == 0) return new EmptyPartition(); + return new RepeatIterator(element, count); + } + + private sealed class RepeatIterator : Iterator, IArrayProvider, IListProvider, IPartition + { + private readonly int _count; + private int _sent; + + public RepeatIterator(TResult element, int count) + { + Debug.Assert(count > 0); + current = element; + _count = count; + } + + public override Iterator Clone() + { + return new RepeatIterator(current, _count); + } + + public override void Dispose() + { + // Don't let base Dispose wipe current. + state = -1; + } + + public override bool MoveNext() + { + if (state == 1 & _sent != _count) + { + ++_sent; + return true; + } + state = -1; + return false; + } + + public TResult[] ToArray() + { + TResult[] array = new TResult[_count]; + if (current != null) + { + for (int i = 0; i != array.Length; ++i) array[i] = current; + } + + return array; + } + + public List ToList() + { + List list = new List(_count); + for (int i = 0; i != _count; ++i) list.Add(current); + + return list; + } + + public IPartition Skip(int count) + { + if (count >= _count) return new EmptyPartition(); + return new RepeatIterator(current, _count - count); + } + + public IPartition Take(int count) + { + if (count > _count) count = _count; + return new RepeatIterator(current, count); + } + + public TResult ElementAt(int index) + { + if ((uint)index >= (uint)_count) throw Error.ArgumentOutOfRange("index"); + return current; + } + + public TResult ElementAtOrDefault(int index) + { + return (uint)index >= (uint)_count ? default(TResult) : current; + } + + public TResult First() + { + return current; + } + + public TResult FirstOrDefault() + { + return current; + } + + public TResult Last() + { + return current; + } + + public TResult LastOrDefault() + { + return current; + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Reverse.cs b/src/System.Linq/src/System/Linq/Reverse.cs new file mode 100644 index 000000000000..363b4db5c691 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Reverse.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Reverse(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + return ReverseIterator(source); + } + + private static IEnumerable ReverseIterator(IEnumerable source) + { + Buffer buffer = new Buffer(source); + for (int i = buffer.count - 1; i >= 0; i--) yield return buffer.items[i]; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Select.cs b/src/System.Linq/src/System/Linq/Select.cs new file mode 100644 index 000000000000..ca29b30c5267 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Select.cs @@ -0,0 +1,324 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Select(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + Iterator iterator = source as Iterator; + if (iterator != null) return iterator.Select(selector); + IList ilist = source as IList; + if (ilist != null) + { + TSource[] array = source as TSource[]; + if (array != null) return new SelectArrayIterator(array, selector); + List list = source as List; + if (list != null) return new SelectListIterator(list, selector); + return new SelectIListIterator(ilist, selector); + } + return new SelectEnumerableIterator(source, selector); + } + + public static IEnumerable Select(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + return SelectIterator(source, selector); + } + + private static IEnumerable SelectIterator(IEnumerable source, Func selector) + { + int index = -1; + foreach (TSource element in source) + { + checked { index++; } + yield return selector(element, index); + } + } + + private static Func CombineSelectors(Func selector1, Func selector2) + { + return x => selector2(selector1(x)); + } + + internal sealed class SelectEnumerableIterator : Iterator + { + private readonly IEnumerable _source; + private readonly Func _selector; + private IEnumerator _enumerator; + + public SelectEnumerableIterator(IEnumerable source, Func selector) + { + Debug.Assert(source != null); + Debug.Assert(selector != null); + _source = source; + _selector = selector; + } + + public override Iterator Clone() + { + return new SelectEnumerableIterator(_source, _selector); + } + + public override void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + base.Dispose(); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + if (_enumerator.MoveNext()) + { + current = _selector(_enumerator.Current); + return true; + } + Dispose(); + break; + } + return false; + } + + public override IEnumerable Select(Func selector) + { + return new SelectEnumerableIterator(_source, CombineSelectors(_selector, selector)); + } + } + + internal sealed class SelectArrayIterator : Iterator, IArrayProvider, IListProvider + { + private readonly TSource[] _source; + private readonly Func _selector; + private int _index; + + public SelectArrayIterator(TSource[] source, Func selector) + { + Debug.Assert(source != null); + Debug.Assert(selector != null); + _source = source; + _selector = selector; + } + + public override Iterator Clone() + { + return new SelectArrayIterator(_source, _selector); + } + + public override bool MoveNext() + { + if (state == 1 && _index < _source.Length) + { + current = _selector(_source[_index++]); + return true; + } + Dispose(); + return false; + } + + public override IEnumerable Select(Func selector) + { + return new SelectArrayIterator(_source, CombineSelectors(_selector, selector)); + } + + public TResult[] ToArray() + { + if (_source.Length == 0) + { + return Array.Empty(); + } + + var results = new TResult[_source.Length]; + for (int i = 0; i < results.Length; i++) + { + results[i] = _selector(_source[i]); + } + return results; + } + + public List ToList() + { + TSource[] source = _source; + var results = new List(source.Length); + for (int i = 0; i < source.Length; i++) + { + results.Add(_selector(source[i])); + } + return results; + } + } + + internal sealed class SelectListIterator : Iterator, IArrayProvider, IListProvider + { + private readonly List _source; + private readonly Func _selector; + private List.Enumerator _enumerator; + + public SelectListIterator(List source, Func selector) + { + Debug.Assert(source != null); + Debug.Assert(selector != null); + _source = source; + _selector = selector; + } + + public override Iterator Clone() + { + return new SelectListIterator(_source, _selector); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + if (_enumerator.MoveNext()) + { + current = _selector(_enumerator.Current); + return true; + } + Dispose(); + break; + } + return false; + } + + public override IEnumerable Select(Func selector) + { + return new SelectListIterator(_source, CombineSelectors(_selector, selector)); + } + + public TResult[] ToArray() + { + int count = _source.Count; + if (count == 0) + { + return Array.Empty(); + } + + var results = new TResult[count]; + for (int i = 0; i < results.Length; i++) + { + results[i] = _selector(_source[i]); + } + return results; + } + + public List ToList() + { + int count = _source.Count; + var results = new List(count); + for (int i = 0; i < count; i++) + { + results.Add(_selector(_source[i])); + } + return results; + } + } + + internal sealed class SelectIListIterator : Iterator, IArrayProvider, IListProvider + { + private readonly IList _source; + private readonly Func _selector; + private IEnumerator _enumerator; + + public SelectIListIterator(IList source, Func selector) + { + Debug.Assert(source != null); + Debug.Assert(selector != null); + _source = source; + _selector = selector; + } + + public override Iterator Clone() + { + return new SelectIListIterator(_source, _selector); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + if (_enumerator.MoveNext()) + { + current = _selector(_enumerator.Current); + return true; + } + Dispose(); + break; + } + return false; + } + + public override void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + base.Dispose(); + } + + public override IEnumerable Select(Func selector) + { + return new SelectIListIterator(_source, CombineSelectors(_selector, selector)); + } + + public TResult[] ToArray() + { + int count = _source.Count; + if (count == 0) + { + return Array.Empty(); + } + + var results = new TResult[count]; + for (int i = 0; i < results.Length; i++) + { + results[i] = _selector(_source[i]); + } + return results; + } + + public List ToList() + { + int count = _source.Count; + var results = new List(count); + for (int i = 0; i < count; i++) + { + results.Add(_selector(_source[i])); + } + return results; + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/SelectMany.cs b/src/System.Linq/src/System/Linq/SelectMany.cs new file mode 100644 index 000000000000..d62972e6e231 --- /dev/null +++ b/src/System.Linq/src/System/Linq/SelectMany.cs @@ -0,0 +1,90 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable SelectMany(this IEnumerable source, Func> selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + return SelectManyIterator(source, selector); + } + + private static IEnumerable SelectManyIterator(IEnumerable source, Func> selector) + { + foreach (TSource element in source) + { + foreach (TResult subElement in selector(element)) + { + yield return subElement; + } + } + } + + public static IEnumerable SelectMany(this IEnumerable source, Func> selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + return SelectManyIterator(source, selector); + } + + private static IEnumerable SelectManyIterator(IEnumerable source, Func> selector) + { + int index = -1; + foreach (TSource element in source) + { + checked { index++; } + foreach (TResult subElement in selector(element, index)) + { + yield return subElement; + } + } + } + + public static IEnumerable SelectMany(this IEnumerable source, Func> collectionSelector, Func resultSelector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (collectionSelector == null) throw Error.ArgumentNull("collectionSelector"); + if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); + return SelectManyIterator(source, collectionSelector, resultSelector); + } + + private static IEnumerable SelectManyIterator(IEnumerable source, Func> collectionSelector, Func resultSelector) + { + int index = -1; + foreach (TSource element in source) + { + checked { index++; } + foreach (TCollection subElement in collectionSelector(element, index)) + { + yield return resultSelector(element, subElement); + } + } + } + + public static IEnumerable SelectMany(this IEnumerable source, Func> collectionSelector, Func resultSelector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (collectionSelector == null) throw Error.ArgumentNull("collectionSelector"); + if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); + return SelectManyIterator(source, collectionSelector, resultSelector); + } + + private static IEnumerable SelectManyIterator(IEnumerable source, Func> collectionSelector, Func resultSelector) + { + foreach (TSource element in source) + { + foreach (TCollection subElement in collectionSelector(element)) + { + yield return resultSelector(element, subElement); + } + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/SequenceEqual.cs b/src/System.Linq/src/System/Linq/SequenceEqual.cs new file mode 100644 index 000000000000..03f0cfa27f98 --- /dev/null +++ b/src/System.Linq/src/System/Linq/SequenceEqual.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static bool SequenceEqual(this IEnumerable first, IEnumerable second) + { + return SequenceEqual(first, second, null); + } + + public static bool SequenceEqual(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + if (comparer == null) comparer = EqualityComparer.Default; + if (first == null) throw Error.ArgumentNull("first"); + if (second == null) throw Error.ArgumentNull("second"); + + ICollection firstCol = first as ICollection; + if (firstCol != null) + { + ICollection secondCol = second as ICollection; + if (secondCol != null && firstCol.Count != secondCol.Count) return false; + } + + using (IEnumerator e1 = first.GetEnumerator()) + using (IEnumerator e2 = second.GetEnumerator()) + { + while (e1.MoveNext()) + { + if (!(e2.MoveNext() && comparer.Equals(e1.Current, e2.Current))) return false; + } + return !e2.MoveNext(); + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Set.cs b/src/System.Linq/src/System/Linq/Set.cs new file mode 100644 index 000000000000..e1406050b6ff --- /dev/null +++ b/src/System.Linq/src/System/Linq/Set.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Linq +{ + internal class Set + { + private int[] _buckets; + private Slot[] _slots; + private int _count; + private readonly IEqualityComparer _comparer; +#if DEBUG + private bool _haveRemoved; +#endif + + public Set(IEqualityComparer comparer) + { + if (comparer == null) comparer = EqualityComparer.Default; + _comparer = comparer; + _buckets = new int[7]; + _slots = new Slot[7]; + } + + // If value is not in set, add it and return true; otherwise return false + public bool Add(TElement value) + { +#if DEBUG + Debug.Assert(!_haveRemoved, "This class is optimised for never calling Add after Remove. If your changes need to do so, undo that optimization."); +#endif + int hashCode = InternalGetHashCode(value); + for (int i = _buckets[hashCode % _buckets.Length] - 1; i >= 0; i = _slots[i].next) + { + if (_slots[i].hashCode == hashCode && _comparer.Equals(_slots[i].value, value)) return false; + } + if (_count == _slots.Length) Resize(); + int index = _count; + _count++; + int bucket = hashCode % _buckets.Length; + _slots[index].hashCode = hashCode; + _slots[index].value = value; + _slots[index].next = _buckets[bucket] - 1; + _buckets[bucket] = index + 1; + return true; + } + + // If value is in set, remove it and return true; otherwise return false + public bool Remove(TElement value) + { +#if DEBUG + _haveRemoved = true; +#endif + int hashCode = InternalGetHashCode(value); + int bucket = hashCode % _buckets.Length; + int last = -1; + for (int i = _buckets[bucket] - 1; i >= 0; last = i, i = _slots[i].next) + { + if (_slots[i].hashCode == hashCode && _comparer.Equals(_slots[i].value, value)) + { + if (last < 0) + { + _buckets[bucket] = _slots[i].next + 1; + } + else + { + _slots[last].next = _slots[i].next; + } + _slots[i].hashCode = -1; + _slots[i].value = default(TElement); + _slots[i].next = -1; + return true; + } + } + return false; + } + + private void Resize() + { + int newSize = checked(_count * 2 + 1); + int[] newBuckets = new int[newSize]; + Slot[] newSlots = new Slot[newSize]; + Array.Copy(_slots, 0, newSlots, 0, _count); + for (int i = 0; i < _count; i++) + { + int bucket = newSlots[i].hashCode % newSize; + newSlots[i].next = newBuckets[bucket] - 1; + newBuckets[bucket] = i + 1; + } + _buckets = newBuckets; + _slots = newSlots; + } + + internal int InternalGetHashCode(TElement value) + { + // Handle comparer implementations that throw when passed null + return (value == null) ? 0 : _comparer.GetHashCode(value) & 0x7FFFFFFF; + } + + internal struct Slot + { + internal int hashCode; + internal int next; + internal TElement value; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Single.cs b/src/System.Linq/src/System/Linq/Single.cs new file mode 100644 index 000000000000..1215e1db0709 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Single.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static TSource Single(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + IList list = source as IList; + if (list != null) + { + switch (list.Count) + { + case 0: throw Error.NoElements(); + case 1: return list[0]; + } + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) throw Error.NoElements(); + TSource result = e.Current; + if (!e.MoveNext()) return result; + } + } + throw Error.MoreThanOneElement(); + } + + public static TSource Single(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + TSource result = e.Current; + if (predicate(result)) + { + while (e.MoveNext()) + { + if (predicate(e.Current)) throw Error.MoreThanOneMatch(); + } + return result; + } + } + } + throw Error.NoMatch(); + } + + public static TSource SingleOrDefault(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + IList list = source as IList; + if (list != null) + { + switch (list.Count) + { + case 0: return default(TSource); + case 1: return list[0]; + } + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (!e.MoveNext()) return default(TSource); + TSource result = e.Current; + if (!e.MoveNext()) return result; + } + } + throw Error.MoreThanOneElement(); + } + + public static TSource SingleOrDefault(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + TSource result = e.Current; + if (predicate(result)) + { + while (e.MoveNext()) + { + if (predicate(e.Current)) throw Error.MoreThanOneMatch(); + } + return result; + } + } + } + return default(TSource); + } + } +} diff --git a/src/System.Linq/src/System/Linq/Skip.cs b/src/System.Linq/src/System/Linq/Skip.cs new file mode 100644 index 000000000000..f4f2deb38dfb --- /dev/null +++ b/src/System.Linq/src/System/Linq/Skip.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Skip(this IEnumerable source, int count) + { + if (source == null) throw Error.ArgumentNull("source"); + if (count < 0) count = 0; + IPartition partition = source as IPartition; + if (partition != null) return partition.Skip(count); + IList sourceList = source as IList; + return sourceList != null ? SkipList(sourceList, count) : SkipIterator(source, count); + } + + private static IEnumerable SkipList(IList source, int count) + { + while (count < source.Count) + { + yield return source[count++]; + } + } + + private static IEnumerable SkipIterator(IEnumerable source, int count) + { + using (IEnumerator e = source.GetEnumerator()) + { + while (count > 0 && e.MoveNext()) count--; + if (count <= 0) + { + while (e.MoveNext()) yield return e.Current; + } + } + } + + public static IEnumerable SkipWhile(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + return SkipWhileIterator(source, predicate); + } + + private static IEnumerable SkipWhileIterator(IEnumerable source, Func predicate) + { + using (IEnumerator e = source.GetEnumerator()) + { + while (e.MoveNext()) + { + TSource element = e.Current; + if (!predicate(element)) + { + yield return element; + while (e.MoveNext()) + yield return e.Current; + yield break; + } + } + } + } + + public static IEnumerable SkipWhile(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + return SkipWhileIterator(source, predicate); + } + + private static IEnumerable SkipWhileIterator(IEnumerable source, Func predicate) + { + using (IEnumerator e = source.GetEnumerator()) + { + int index = -1; + while (e.MoveNext()) + { + checked { index++; } + TSource element = e.Current; + if (!predicate(element, index)) + { + yield return element; + while (e.MoveNext()) + yield return e.Current; + yield break; + } + } + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Sum.cs b/src/System.Linq/src/System/Linq/Sum.cs new file mode 100644 index 000000000000..9f77cb038612 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Sum.cs @@ -0,0 +1,241 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static int Sum(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + int sum = 0; + checked + { + foreach (int v in source) sum += v; + } + return sum; + } + + public static int? Sum(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + int sum = 0; + checked + { + foreach (int? v in source) + { + if (v != null) sum += v.GetValueOrDefault(); + } + } + return sum; + } + + public static long Sum(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + long sum = 0; + checked + { + foreach (long v in source) sum += v; + } + return sum; + } + + public static long? Sum(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + long sum = 0; + checked + { + foreach (long? v in source) + { + if (v != null) sum += v.GetValueOrDefault(); + } + } + return sum; + } + + public static float Sum(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + double sum = 0; + foreach (float v in source) sum += v; + return (float)sum; + } + + public static float? Sum(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + double sum = 0; + foreach (float? v in source) + { + if (v != null) sum += v.GetValueOrDefault(); + } + return (float)sum; + } + + public static double Sum(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + double sum = 0; + foreach (double v in source) sum += v; + return sum; + } + + public static double? Sum(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + double sum = 0; + foreach (double? v in source) + { + if (v != null) sum += v.GetValueOrDefault(); + } + return sum; + } + + public static decimal Sum(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + decimal sum = 0; + foreach (decimal v in source) sum += v; + return sum; + } + + public static decimal? Sum(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + decimal sum = 0; + foreach (decimal? v in source) + { + if (v != null) sum += v.GetValueOrDefault(); + } + return sum; + } + + public static int Sum(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + int sum = 0; + checked + { + foreach (TSource item in source) sum += selector(item); + } + return sum; + } + + public static int? Sum(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + int sum = 0; + checked + { + foreach (TSource item in source) + { + int? v = selector(item); + if (v != null) sum += v.GetValueOrDefault(); + } + } + return sum; + } + + public static long Sum(this IEnumerable source, Func selector) + { + if (selector == null) throw Error.ArgumentNull("selector"); + if (source == null) throw Error.ArgumentNull("source"); + long sum = 0; + checked + { + foreach (TSource item in source) sum += selector(item); + } + return sum; + } + + public static long? Sum(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + long sum = 0; + checked + { + foreach (TSource item in source) + { + long? v = selector(item); + if (v != null) sum += v.GetValueOrDefault(); + } + } + return sum; + } + + public static float Sum(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + double sum = 0; + foreach (TSource item in source) sum += selector(item); + return (float)sum; + } + + public static float? Sum(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + double sum = 0; + foreach (TSource item in source) + { + float? v = selector(item); + if (v != null) sum += v.GetValueOrDefault(); + } + return (float)sum; + } + + public static double Sum(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + double sum = 0; + foreach (TSource item in source) sum += selector(item); + return sum; + } + + public static double? Sum(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + double sum = 0; + foreach (TSource item in source) + { + double? v = selector(item); + if (v != null) sum += v.GetValueOrDefault(); + } + return sum; + } + + public static decimal Sum(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + decimal sum = 0; + foreach (TSource item in source) sum += selector(item); + return sum; + } + + public static decimal? Sum(this IEnumerable source, Func selector) + { + if (source == null) throw Error.ArgumentNull("source"); + if (selector == null) throw Error.ArgumentNull("selector"); + decimal sum = 0; + foreach (TSource item in source) + { + decimal? v = selector(item); + if (v != null) sum += v.GetValueOrDefault(); + } + return sum; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Take.cs b/src/System.Linq/src/System/Linq/Take.cs new file mode 100644 index 000000000000..651bfea0dd1f --- /dev/null +++ b/src/System.Linq/src/System/Linq/Take.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Take(this IEnumerable source, int count) + { + if (source == null) throw Error.ArgumentNull("source"); + if (count <= 0) return new EmptyPartition(); + IPartition partition = source as IPartition; + if (partition != null) return partition.Take(count); + return TakeIterator(source, count); + } + + private static IEnumerable TakeIterator(IEnumerable source, int count) + { + foreach (TSource element in source) + { + yield return element; + if (--count == 0) break; + } + } + + public static IEnumerable TakeWhile(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + return TakeWhileIterator(source, predicate); + } + + private static IEnumerable TakeWhileIterator(IEnumerable source, Func predicate) + { + foreach (TSource element in source) + { + if (!predicate(element)) break; + yield return element; + } + } + + public static IEnumerable TakeWhile(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + return TakeWhileIterator(source, predicate); + } + + private static IEnumerable TakeWhileIterator(IEnumerable source, Func predicate) + { + int index = -1; + foreach (TSource element in source) + { + checked { index++; } + if (!predicate(element, index)) break; + yield return element; + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/ToCollection.cs b/src/System.Linq/src/System/Linq/ToCollection.cs new file mode 100644 index 000000000000..2f789ac8bf16 --- /dev/null +++ b/src/System.Linq/src/System/Linq/ToCollection.cs @@ -0,0 +1,117 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static TSource[] ToArray(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + IArrayProvider arrayProvider = source as IArrayProvider; + return arrayProvider != null ? arrayProvider.ToArray() : EnumerableHelpers.ToArray(source); + } + + public static List ToList(this IEnumerable source) + { + if (source == null) throw Error.ArgumentNull("source"); + IListProvider listProvider = source as IListProvider; + return listProvider != null ? listProvider.ToList() : new List(source); + } + + public static Dictionary ToDictionary(this IEnumerable source, Func keySelector) + { + return ToDictionary(source, keySelector, null); + } + + public static Dictionary ToDictionary(this IEnumerable source, Func keySelector, IEqualityComparer comparer) + { + if (source == null) throw Error.ArgumentNull("source"); + if (keySelector == null) throw Error.ArgumentNull("keySelector"); + + int capacity = 0; + ICollection collection = source as ICollection; + if (collection != null) + { + capacity = collection.Count; + if (capacity == 0) + return new Dictionary(comparer); + + TSource[] array = collection as TSource[]; + if (array != null) + return ToDictionary(array, keySelector, comparer); + List list = collection as List; + if (list != null) + return ToDictionary(list, keySelector, comparer); + } + + Dictionary d = new Dictionary(capacity, comparer); + foreach (TSource element in source) d.Add(keySelector(element), element); + return d; + } + + private static Dictionary ToDictionary(TSource[] source, Func keySelector, IEqualityComparer comparer) + { + Dictionary d = new Dictionary(source.Length, comparer); + for (int i = 0; i < source.Length; i++) d.Add(keySelector(source[i]), source[i]); + return d; + } + private static Dictionary ToDictionary(List source, Func keySelector, IEqualityComparer comparer) + { + Dictionary d = new Dictionary(source.Count, comparer); + foreach (TSource element in source) d.Add(keySelector(element), element); + return d; + } + + + public static Dictionary ToDictionary(this IEnumerable source, Func keySelector, Func elementSelector) + { + return ToDictionary(source, keySelector, elementSelector, null); + } + + public static Dictionary ToDictionary(this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + { + if (source == null) throw Error.ArgumentNull("source"); + if (keySelector == null) throw Error.ArgumentNull("keySelector"); + if (elementSelector == null) throw Error.ArgumentNull("elementSelector"); + + int capacity = 0; + ICollection collection = source as ICollection; + if (collection != null) + { + capacity = collection.Count; + if (capacity == 0) + return new Dictionary(comparer); + + TSource[] array = collection as TSource[]; + if (array != null) + return ToDictionary(array, keySelector, elementSelector, comparer); + List list = collection as List; + if (list != null) + return ToDictionary(list, keySelector, elementSelector, comparer); + } + + Dictionary d = new Dictionary(capacity, comparer); + foreach (TSource element in source) d.Add(keySelector(element), elementSelector(element)); + return d; + } + + private static Dictionary ToDictionary(TSource[] source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + { + Dictionary d = new Dictionary(source.Length, comparer); + for (int i = 0; i < source.Length; i++) d.Add(keySelector(source[i]), elementSelector(source[i])); + return d; + } + + private static Dictionary ToDictionary(List source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + { + Dictionary d = new Dictionary(source.Count, comparer); + foreach (TSource element in source) d.Add(keySelector(element), elementSelector(element)); + return d; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Union.cs b/src/System.Linq/src/System/Linq/Union.cs new file mode 100644 index 000000000000..2c646f7e74c1 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Union.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Union(this IEnumerable first, IEnumerable second) + { + if (first == null) throw Error.ArgumentNull("first"); + if (second == null) throw Error.ArgumentNull("second"); + return UnionIterator(first, second, null); + } + + public static IEnumerable Union(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + if (first == null) throw Error.ArgumentNull("first"); + if (second == null) throw Error.ArgumentNull("second"); + return UnionIterator(first, second, comparer); + } + + private static IEnumerable UnionIterator(IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + Set set = new Set(comparer); + foreach (TSource element in first) + if (set.Add(element)) yield return element; + foreach (TSource element in second) + if (set.Add(element)) yield return element; + } + } +} diff --git a/src/System.Linq/src/System/Linq/Where.cs b/src/System.Linq/src/System/Linq/Where.cs new file mode 100644 index 000000000000..d27ddcbaebc5 --- /dev/null +++ b/src/System.Linq/src/System/Linq/Where.cs @@ -0,0 +1,376 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Where(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + Iterator iterator = source as Iterator; + if (iterator != null) return iterator.Where(predicate); + TSource[] array = source as TSource[]; + if (array != null) return new WhereArrayIterator(array, predicate); + List list = source as List; + if (list != null) return new WhereListIterator(list, predicate); + return new WhereEnumerableIterator(source, predicate); + } + + public static IEnumerable Where(this IEnumerable source, Func predicate) + { + if (source == null) throw Error.ArgumentNull("source"); + if (predicate == null) throw Error.ArgumentNull("predicate"); + return WhereIterator(source, predicate); + } + + private static IEnumerable WhereIterator(IEnumerable source, Func predicate) + { + int index = -1; + foreach (TSource element in source) + { + checked { index++; } + if (predicate(element, index)) yield return element; + } + } + + private static Func CombinePredicates(Func predicate1, Func predicate2) + { + return x => predicate1(x) && predicate2(x); + } + + internal class WhereEnumerableIterator : Iterator + { + private readonly IEnumerable _source; + private readonly Func _predicate; + private IEnumerator _enumerator; + + public WhereEnumerableIterator(IEnumerable source, Func predicate) + { + Debug.Assert(source != null); + Debug.Assert(predicate != null); + _source = source; + _predicate = predicate; + } + + public override Iterator Clone() + { + return new WhereEnumerableIterator(_source, _predicate); + } + + public override void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + base.Dispose(); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + while (_enumerator.MoveNext()) + { + TSource item = _enumerator.Current; + if (_predicate(item)) + { + current = item; + return true; + } + } + Dispose(); + break; + } + return false; + } + + public override IEnumerable Select(Func selector) + { + return new WhereSelectEnumerableIterator(_source, _predicate, selector); + } + + public override IEnumerable Where(Func predicate) + { + return new WhereEnumerableIterator(_source, CombinePredicates(_predicate, predicate)); + } + } + + internal class WhereArrayIterator : Iterator + { + private readonly TSource[] _source; + private readonly Func _predicate; + private int _index; + + public WhereArrayIterator(TSource[] source, Func predicate) + { + Debug.Assert(source != null); + Debug.Assert(predicate != null); + _source = source; + _predicate = predicate; + } + + public override Iterator Clone() + { + return new WhereArrayIterator(_source, _predicate); + } + + public override bool MoveNext() + { + if (state == 1) + { + while (_index < _source.Length) + { + TSource item = _source[_index]; + _index++; + if (_predicate(item)) + { + current = item; + return true; + } + } + Dispose(); + } + return false; + } + + public override IEnumerable Select(Func selector) + { + return new WhereSelectArrayIterator(_source, _predicate, selector); + } + + public override IEnumerable Where(Func predicate) + { + return new WhereArrayIterator(_source, CombinePredicates(_predicate, predicate)); + } + } + + internal class WhereListIterator : Iterator + { + private readonly List _source; + private readonly Func _predicate; + private List.Enumerator _enumerator; + + public WhereListIterator(List source, Func predicate) + { + Debug.Assert(source != null); + Debug.Assert(predicate != null); + _source = source; + _predicate = predicate; + } + + public override Iterator Clone() + { + return new WhereListIterator(_source, _predicate); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + while (_enumerator.MoveNext()) + { + TSource item = _enumerator.Current; + if (_predicate(item)) + { + current = item; + return true; + } + } + Dispose(); + break; + } + return false; + } + + public override IEnumerable Select(Func selector) + { + return new WhereSelectListIterator(_source, _predicate, selector); + } + + public override IEnumerable Where(Func predicate) + { + return new WhereListIterator(_source, CombinePredicates(_predicate, predicate)); + } + } + + internal class WhereSelectArrayIterator : Iterator + { + private readonly TSource[] _source; + private readonly Func _predicate; + private readonly Func _selector; + private int _index; + + public WhereSelectArrayIterator(TSource[] source, Func predicate, Func selector) + { + Debug.Assert(source != null); + Debug.Assert(predicate != null); + Debug.Assert(selector != null); + _source = source; + _predicate = predicate; + _selector = selector; + } + + public override Iterator Clone() + { + return new WhereSelectArrayIterator(_source, _predicate, _selector); + } + + public override bool MoveNext() + { + if (state == 1) + { + while (_index < _source.Length) + { + TSource item = _source[_index]; + _index++; + if (_predicate(item)) + { + current = _selector(item); + return true; + } + } + Dispose(); + } + return false; + } + + public override IEnumerable Select(Func selector) + { + return new WhereSelectArrayIterator(_source, _predicate, CombineSelectors(_selector, selector)); + } + } + + internal class WhereSelectListIterator : Iterator + { + private readonly List _source; + private readonly Func _predicate; + private readonly Func _selector; + private List.Enumerator _enumerator; + + public WhereSelectListIterator(List source, Func predicate, Func selector) + { + Debug.Assert(source != null); + Debug.Assert(predicate != null); + Debug.Assert(selector != null); + _source = source; + _predicate = predicate; + _selector = selector; + } + + public override Iterator Clone() + { + return new WhereSelectListIterator(_source, _predicate, _selector); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + while (_enumerator.MoveNext()) + { + TSource item = _enumerator.Current; + if (_predicate(item)) + { + current = _selector(item); + return true; + } + } + Dispose(); + break; + } + return false; + } + + public override IEnumerable Select(Func selector) + { + return new WhereSelectListIterator(_source, _predicate, CombineSelectors(_selector, selector)); + } + } + + internal class WhereSelectEnumerableIterator : Iterator + { + private readonly IEnumerable _source; + private readonly Func _predicate; + private readonly Func _selector; + private IEnumerator _enumerator; + + public WhereSelectEnumerableIterator(IEnumerable source, Func predicate, Func selector) + { + Debug.Assert(source != null); + Debug.Assert(predicate != null); + Debug.Assert(selector != null); + _source = source; + _predicate = predicate; + _selector = selector; + } + + public override Iterator Clone() + { + return new WhereSelectEnumerableIterator(_source, _predicate, _selector); + } + + public override void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + base.Dispose(); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + while (_enumerator.MoveNext()) + { + TSource item = _enumerator.Current; + if (_predicate(item)) + { + current = _selector(item); + return true; + } + } + Dispose(); + break; + } + return false; + } + + public override IEnumerable Select(Func selector) + { + return new WhereSelectEnumerableIterator(_source, _predicate, CombineSelectors(_selector, selector)); + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Zip.cs b/src/System.Linq/src/System/Linq/Zip.cs new file mode 100644 index 000000000000..969d75d589ff --- /dev/null +++ b/src/System.Linq/src/System/Linq/Zip.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Zip(this IEnumerable first, IEnumerable second, Func resultSelector) + { + if (first == null) throw Error.ArgumentNull("first"); + if (second == null) throw Error.ArgumentNull("second"); + if (resultSelector == null) throw Error.ArgumentNull("resultSelector"); + return ZipIterator(first, second, resultSelector); + } + + private static IEnumerable ZipIterator(IEnumerable first, IEnumerable second, Func resultSelector) + { + using (IEnumerator e1 = first.GetEnumerator()) + using (IEnumerator e2 = second.GetEnumerator()) + while (e1.MoveNext() && e2.MoveNext()) + yield return resultSelector(e1.Current, e2.Current); + } + } +}