diff --git a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx index 208b9802175091..6561dfdafbfa90 100644 --- a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx +++ b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx @@ -177,8 +177,8 @@ In place operations require the same shape for both tensors - - Invalid axis provided. Must be greater then or equal to 0 and less than the tensor rank. + + Invalid dimension provided. Must be greater then or equal to 0 and less than the tensor rank. The tensors must have the same shape, except in the dimension corresponding to axis. diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs index aaf9612ae0a0b3..8d4a670a1bd9b6 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs @@ -133,18 +133,14 @@ public static Tensor ConcatenateOnDimension(int dimension, params scoped R ThrowHelper.ThrowArgument_ConcatenateTooFewTensors(); if (dimension < -1 || dimension > tensors[0].Rank) - ThrowHelper.ThrowArgument_InvalidAxis(); + ThrowHelper.ThrowArgument_InvalidDimension(); - // Calculate total space needed. - nint totalLength = 0; - for (int i = 0; i < tensors.Length; i++) - totalLength += tensors[i].FlattenedLength; + Tensor tensor; - nint sumOfAxis = 0; // If axis != -1, make sure all dimensions except the one to concatenate on match. if (dimension != -1) { - sumOfAxis = tensors[0].Lengths[dimension]; + nint sumOfAxis = tensors[0].Lengths[dimension]; for (int i = 1; i < tensors.Length; i++) { if (tensors[0].Rank != tensors[i].Rank) @@ -157,22 +153,31 @@ public static Tensor ConcatenateOnDimension(int dimension, params scoped R ThrowHelper.ThrowArgument_InvalidConcatenateShape(); } } - sumOfAxis += tensors[i].Lengths[dimension]; + checked + { + sumOfAxis += tensors[i].Lengths[dimension]; + } } - } - Tensor tensor; - if (dimension == -1) - { - tensor = Tensor.Create([totalLength]); - } - else - { nint[] lengths = new nint[tensors[0].Rank]; tensors[0].Lengths.CopyTo(lengths); lengths[dimension] = sumOfAxis; tensor = Tensor.Create(lengths); } + else + { + // Calculate total space needed. + nint totalLength = 0; + for (int i = 0; i < tensors.Length; i++) + { + checked + { + totalLength += tensors[i].FlattenedLength; + } + } + + tensor = Tensor.Create([totalLength]); + } ConcatenateOnDimension(dimension, tensors, tensor); return tensor; @@ -201,7 +206,7 @@ public static ref readonly TensorSpan ConcatenateOnDimension(int dimension ThrowHelper.ThrowArgument_ConcatenateTooFewTensors(); if (dimension < -1 || dimension > tensors[0].Rank) - ThrowHelper.ThrowArgument_InvalidAxis(); + ThrowHelper.ThrowArgument_InvalidDimension(); // Calculate total space needed. nint totalLength = 0; @@ -212,11 +217,12 @@ public static ref readonly TensorSpan ConcatenateOnDimension(int dimension if (dimension != -1) { nint sumOfAxis = tensors[0].Lengths[dimension]; + int rank = tensors[0].Rank; for (int i = 1; i < tensors.Length; i++) { - if (tensors[0].Rank != tensors[i].Rank) + if (rank != tensors[i].Rank) ThrowHelper.ThrowArgument_InvalidConcatenateShape(); - for (int j = 0; j < tensors[0].Rank; j++) + for (int j = 0; j < rank; j++) { if (j != dimension) { @@ -228,7 +234,7 @@ public static ref readonly TensorSpan ConcatenateOnDimension(int dimension } // Make sure the destination tensor has the correct shape. - nint[] lengths = new nint[tensors[0].Rank]; + nint[] lengths = new nint[rank]; tensors[0].Lengths.CopyTo(lengths); lengths[dimension] = sumOfAxis; @@ -339,10 +345,10 @@ public static Tensor Create(T[] array, int start, scoped ReadOnlySpanA new tensor that contains elements copied from . public static Tensor Create(IEnumerable enumerable, bool pinned = false) { + T[] array = enumerable.ToArray(); + if (pinned) { - T[] array = enumerable.ToArray(); - Tensor tensor = CreateUninitialized([array.Length], pinned); array.CopyTo(tensor._values); @@ -350,7 +356,6 @@ public static Tensor Create(IEnumerable enumerable, bool pinned = false } else { - T[] array = enumerable.ToArray(); return Create(array); } } @@ -364,10 +369,10 @@ public static Tensor Create(IEnumerable enumerable, scoped ReadOnlySpan /// A new tensor that contains elements copied from and with the specified and . public static Tensor Create(IEnumerable enumerable, scoped ReadOnlySpan lengths, scoped ReadOnlySpan strides, bool pinned = false) { + T[] array = enumerable.ToArray(); + if (pinned) { - T[] array = enumerable.ToArray(); - Tensor tensor = CreateUninitialized(lengths, strides, pinned); array.CopyTo(tensor._values); @@ -375,7 +380,6 @@ public static Tensor Create(IEnumerable enumerable, scoped ReadOnlySpan } else { - T[] array = enumerable.ToArray(); return Create(array, lengths, strides); } } @@ -620,20 +624,8 @@ public static bool EqualsAny(in ReadOnlyTensorSpan x, T y) /// Value to update in the . public static ref readonly TensorSpan FilteredUpdate(in this TensorSpan tensor, scoped in ReadOnlyTensorSpan filter, T value) { - if (filter.Lengths.Length != tensor.Lengths.Length) - ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter)); - - Span srcSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength); - Span filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength); - - for (int i = 0; i < filterSpan.Length; i++) - { - if (filterSpan[i]) - { - srcSpan[i] = value; - } - } - + TensorOperation.ValidateCompatibility(filter, tensor); + TensorOperation.Invoke, bool, T, T>(filter, value, tensor); return ref tensor; } @@ -646,24 +638,8 @@ public static ref readonly TensorSpan FilteredUpdate(in this TensorSpan /// Values to update in the . public static ref readonly TensorSpan FilteredUpdate(in this TensorSpan tensor, scoped in ReadOnlyTensorSpan filter, scoped in ReadOnlyTensorSpan values) { - if (filter.Lengths.Length != tensor.Lengths.Length) - ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter)); - if (values.Rank != 1) - ThrowHelper.ThrowArgument_1DTensorRequired(nameof(values)); - - Span dstSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength); - Span filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength); - Span valuesSpan = MemoryMarshal.CreateSpan(ref values._reference, (int)values._shape.LinearLength); - - int index = 0; - for (int i = 0; i < filterSpan.Length; i++) - { - if (filterSpan[i]) - { - dstSpan[i] = valuesSpan[index++]; - } - } - + TensorOperation.ValidateCompatibility(filter, values, tensor); + TensorOperation.Invoke, bool, T, T>(filter, values, tensor); return ref tensor; } #endregion @@ -1409,6 +1385,9 @@ public static Tensor PermuteDimensions(this Tensor tensor, ReadOnlySpan } else { + if (!dimensions.IsEmpty && dimensions.Length != tensor.Lengths.Length) + ThrowHelper.ThrowArgument_PermuteAxisOrder(); + scoped Span newLengths = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer lengthsRentedBuffer); scoped Span newStrides = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer stridesRentedBuffer); scoped Span newLinearOrder = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer linearOrderRentedBuffer); @@ -1426,11 +1405,12 @@ public static Tensor PermuteDimensions(this Tensor tensor, ReadOnlySpan } else { - if (dimensions.Length != tensor.Lengths.Length) - ThrowHelper.ThrowArgument_PermuteAxisOrder(); - for (int i = 0; i < dimensions.Length; i++) { + if (dimensions[i] >= tensor.Lengths.Length || dimensions[i] < 0) + { + ThrowHelper.ThrowArgument_InvalidDimension(); + } newLengths[i] = tensor.Lengths[dimensions[i]]; newStrides[i] = tensor.Strides[dimensions[i]]; newLinearOrder[i] = tensor._shape.LinearRankOrder[dimensions[i]]; @@ -1467,7 +1447,8 @@ public static Tensor Reshape(this Tensor tensor, ReadOnlySpan len nint[] newLengths = lengths.ToArray(); // Calculate wildcard info. - if (lengths.Contains(-1)) + int wildcardIndex = lengths.IndexOf(-1); + if (wildcardIndex >= 0) { if (lengths.Count(-1) > 1) ThrowHelper.ThrowArgument_OnlyOneWildcard(); @@ -1479,7 +1460,7 @@ public static Tensor Reshape(this Tensor tensor, ReadOnlySpan len tempTotal /= lengths[i]; } } - newLengths[lengths.IndexOf(-1)] = tempTotal; + newLengths[wildcardIndex] = tempTotal; } nint tempLinear = TensorPrimitives.Product(newLengths); @@ -1538,8 +1519,8 @@ public static TensorSpan Reshape(in this TensorSpan tensor, scoped Read } nint[] newLengths = lengths.ToArray(); - // Calculate wildcard info. - if (lengths.Contains(-1)) + int wildcardIndex = lengths.IndexOf(-1); + if (wildcardIndex >= 0) { if (lengths.Count(-1) > 1) ThrowHelper.ThrowArgument_OnlyOneWildcard(); @@ -1551,7 +1532,7 @@ public static TensorSpan Reshape(in this TensorSpan tensor, scoped Read tempTotal /= lengths[i]; } } - newLengths[lengths.IndexOf(-1)] = tempTotal; + newLengths[wildcardIndex] = tempTotal; } @@ -1615,7 +1596,8 @@ public static ReadOnlyTensorSpan Reshape(in this ReadOnlyTensorSpan ten nint[] newLengths = lengths.ToArray(); // Calculate wildcard info. - if (lengths.Contains(-1)) + int wildcardIndex = lengths.IndexOf(-1); + if (wildcardIndex >= 0) { if (lengths.Count(-1) > 1) ThrowHelper.ThrowArgument_OnlyOneWildcard(); @@ -1627,7 +1609,7 @@ public static ReadOnlyTensorSpan Reshape(in this ReadOnlyTensorSpan ten tempTotal /= lengths[i]; } } - newLengths[lengths.IndexOf(-1)] = tempTotal; + newLengths[wildcardIndex] = tempTotal; } @@ -1701,12 +1683,7 @@ public static Tensor Resize(Tensor tensor, ReadOnlySpan lengths) /// Destination with the desired new shape. public static void ResizeTo(scoped in Tensor tensor, in TensorSpan destination) { - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref tensor.AsTensorSpan()._reference, tensor._start), (int)tensor._values.Length - tensor._start); - Span ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength); - if (ospan.Length >= span.Length) - span.CopyTo(ospan); - else - span.Slice(0, ospan.Length).CopyTo(ospan); + ResizeTo(tensor.AsReadOnlyTensorSpan(), destination); } /// @@ -1717,12 +1694,7 @@ public static void ResizeTo(scoped in Tensor tensor, in TensorSpan dest /// Destination with the desired new shape. public static void ResizeTo(scoped in TensorSpan tensor, in TensorSpan destination) { - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength); - Span ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength); - if (ospan.Length >= span.Length) - span.CopyTo(ospan); - else - span.Slice(0, ospan.Length).CopyTo(ospan); + ResizeTo(tensor.AsReadOnlyTensorSpan(), destination); } /// @@ -1890,6 +1862,8 @@ public static ref readonly TensorSpan SetSlice(this in TensorSpan tenso /// The axis to split on. public static Tensor[] Split(scoped in ReadOnlyTensorSpan tensor, int splitCount, nint dimension) { + if (dimension < 0 || dimension >= tensor.Rank) + ThrowHelper.ThrowArgument_AxisLargerThanRank(); if (tensor.Lengths[(int)dimension] % splitCount != 0) ThrowHelper.ThrowArgument_SplitNotSplitEvenly(); @@ -2221,8 +2195,10 @@ public static Tensor StackAlongDimension(int dimension, params ReadOnlySpa ThrowHelper.ThrowArgument_StackShapesNotSame(); } - if (dimension < 0) - dimension = tensors[0].Rank - dimension; + // We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension + // with our call to Unsqueeze. + if (dimension < 0 || dimension > tensors[0].Rank) + ThrowHelper.ThrowArgument_AxisLargerThanRank(); Tensor[] outputs = new Tensor[tensors.Length]; for (int i = 0; i < tensors.Length; i++) @@ -2259,8 +2235,10 @@ public static ref readonly TensorSpan StackAlongDimension(scoped ReadOnlyS ThrowHelper.ThrowArgument_StackShapesNotSame(); } - if (dimension < 0) - dimension = tensors[0].Rank - dimension; + // We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension + // with our call to Unsqueeze. + if (dimension < 0 || dimension > tensors[0].Rank) + ThrowHelper.ThrowArgument_AxisLargerThanRank(); Tensor[] outputs = new Tensor[tensors.Length]; for (int i = 0; i < tensors.Length; i++) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs index d372602f2fa429..d3b702d050cb7f 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs @@ -191,8 +191,8 @@ ref destination xRentedBuffer.Dispose(); } - public static void Invoke(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, in TensorSpan destination) - where TOperation : TensorOperation.IBinaryOperation_Tensor_Tensor + public static void Invoke(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, in TensorSpan destination) + where TOperation : TensorOperation.IBinaryOperation_Tensor_Tensor { scoped Span xIndexes = RentedBuffer.Create(destination.Rank, x.Strides, out nint xLinearOffset, out RentedBuffer xRentedBuffer); scoped Span yIndexes = RentedBuffer.Create(destination.Rank, y.Strides, out nint yLinearOffset, out RentedBuffer yRentedBuffer); @@ -216,6 +216,10 @@ ref Unsafe.Add(ref destination._reference, destinationLinearOffset) destinationRentedBuffer.Dispose(); } + public static void Invoke(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, in TensorSpan destination) + where TOperation : TensorOperation.IBinaryOperation_Tensor_Tensor + => Invoke(in x, in y, in destination); + public static void Invoke(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, ref TResult result) where TOperation : TensorOperation.IBinaryOperation_Tensor_Tensor { @@ -248,8 +252,8 @@ public static void Invoke(in ReadOnlyTensorSpan public static void Invoke(in ReadOnlyTensorSpan x, int y, in TensorSpan destination) where TOperation : TensorOperation.IBinaryOperation_Tensor_Int32 => Invoke(in x, y, in destination); - public static void Invoke(in ReadOnlyTensorSpan x, T2 y, in TensorSpan destination) - where TOperation : TensorOperation.IBinaryOperation_Tensor_Scalar + public static void Invoke(in ReadOnlyTensorSpan x, TArg2 y, in TensorSpan destination) + where TOperation : TensorOperation.IBinaryOperation_Tensor_Scalar { scoped Span xIndexes = RentedBuffer.Create(destination.Rank, x.Strides, out nint xLinearOffset, out RentedBuffer xRentedBuffer); scoped Span destinationIndexes = RentedBuffer.Create(destination.Rank, destination.Strides, out nint destinationLinearOffset, out RentedBuffer destinationRentedBuffer); @@ -292,8 +296,8 @@ ref Unsafe.Add(ref destination._reference, destinationLinearOffset) destinationRentedBuffer.Dispose(); } - public static void Invoke(in ReadOnlyTensorSpan x, T2 y, ref TResult result) - where TOperation : TensorOperation.IBinaryOperation_Tensor_Scalar + public static void Invoke(in ReadOnlyTensorSpan x, TArg2 y, ref TResult result) + where TOperation : TensorOperation.IBinaryOperation_Tensor_Scalar { scoped Span xIndexes = RentedBuffer.Create(x.Rank, x.Strides, out nint xLinearOffset, out RentedBuffer xRentedBuffer); @@ -335,7 +339,7 @@ public static void ValidateCompatibility(in ReadOnlyTensorSpan(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, in TensorSpan destination) + public static void ValidateCompatibility(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, in TensorSpan destination) { // can do bidirectional validation between x and y, that result can then be broadcast to destination if (TensorShape.AreCompatible(x._shape, y._shape, true)) @@ -2153,6 +2157,48 @@ public static void Invoke(Span destination, T value) } } + public readonly struct FilteredUpdate + : IBinaryOperation_Tensor_Scalar, + IBinaryOperation_Tensor_Tensor + { + public static void Invoke(ref readonly bool x, ref readonly T y, ref T destination) + { + if (x) + { + destination = y; + } + } + public static void Invoke(ReadOnlySpan x, ReadOnlySpan y, Span destination) + { + for (int i = 0; i < x.Length; i++) + { + if (x[i]) + { + destination[i] = y[i]; + } + } + } + + public static void Invoke(ref readonly bool x, T y, ref T destination) + { + if (x) + { + destination = y; + } + } + + public static void Invoke(ReadOnlySpan x, T y, Span destination) + { + for (int i = 0; i < x.Length; i++) + { + if (x[i]) + { + destination[i] = y; + } + } + } + } + public readonly struct GreaterThan : IBinaryOperation_Tensor_Scalar, IBinaryOperation_Tensor_Tensor @@ -2531,10 +2577,15 @@ public interface IBinaryOperation_Scalar_Tensor static abstract void Invoke(T1 x, ReadOnlySpan y, Span destination); } + public interface IBinaryOperation_Tensor_Tensor + { + static abstract void Invoke(ref readonly T1 x, ref readonly T2 y, ref TResult destination); + static abstract void Invoke(ReadOnlySpan x, ReadOnlySpan y, Span destination); + } + public interface IBinaryOperation_Tensor_Tensor + : IBinaryOperation_Tensor_Tensor { - static abstract void Invoke(ref readonly T x, ref readonly T y, ref TResult destination); - static abstract void Invoke(ReadOnlySpan x, ReadOnlySpan y, Span destination); } public interface IOperation diff --git a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs index 81338921a5810d..f19d78077b9680 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs @@ -197,9 +197,9 @@ public static void ThrowArgument_ConcatenateTooFewTensors() } [DoesNotReturn] - public static void ThrowArgument_InvalidAxis() + public static void ThrowArgument_InvalidDimension() { - throw new ArgumentException(SR.ThrowArgument_InvalidAxis); + throw new ArgumentException(SR.ThrowArgument_InvalidDimension); } [DoesNotReturn]