diff --git a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx
index 208b9802175091..6561dfdafbfa90 100644
--- a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx
+++ b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx
@@ -177,8 +177,8 @@
In place operations require the same shape for both tensors
-
- Invalid axis provided. Must be greater then or equal to 0 and less than the tensor rank.
+
+ Invalid dimension provided. Must be greater then or equal to 0 and less than the tensor rank.
The tensors must have the same shape, except in the dimension corresponding to axis.
diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs
index aaf9612ae0a0b3..8d4a670a1bd9b6 100644
--- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs
+++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs
@@ -133,18 +133,14 @@ public static Tensor ConcatenateOnDimension(int dimension, params scoped R
ThrowHelper.ThrowArgument_ConcatenateTooFewTensors();
if (dimension < -1 || dimension > tensors[0].Rank)
- ThrowHelper.ThrowArgument_InvalidAxis();
+ ThrowHelper.ThrowArgument_InvalidDimension();
- // Calculate total space needed.
- nint totalLength = 0;
- for (int i = 0; i < tensors.Length; i++)
- totalLength += tensors[i].FlattenedLength;
+ Tensor tensor;
- nint sumOfAxis = 0;
// If axis != -1, make sure all dimensions except the one to concatenate on match.
if (dimension != -1)
{
- sumOfAxis = tensors[0].Lengths[dimension];
+ nint sumOfAxis = tensors[0].Lengths[dimension];
for (int i = 1; i < tensors.Length; i++)
{
if (tensors[0].Rank != tensors[i].Rank)
@@ -157,22 +153,31 @@ public static Tensor ConcatenateOnDimension(int dimension, params scoped R
ThrowHelper.ThrowArgument_InvalidConcatenateShape();
}
}
- sumOfAxis += tensors[i].Lengths[dimension];
+ checked
+ {
+ sumOfAxis += tensors[i].Lengths[dimension];
+ }
}
- }
- Tensor tensor;
- if (dimension == -1)
- {
- tensor = Tensor.Create([totalLength]);
- }
- else
- {
nint[] lengths = new nint[tensors[0].Rank];
tensors[0].Lengths.CopyTo(lengths);
lengths[dimension] = sumOfAxis;
tensor = Tensor.Create(lengths);
}
+ else
+ {
+ // Calculate total space needed.
+ nint totalLength = 0;
+ for (int i = 0; i < tensors.Length; i++)
+ {
+ checked
+ {
+ totalLength += tensors[i].FlattenedLength;
+ }
+ }
+
+ tensor = Tensor.Create([totalLength]);
+ }
ConcatenateOnDimension(dimension, tensors, tensor);
return tensor;
@@ -201,7 +206,7 @@ public static ref readonly TensorSpan ConcatenateOnDimension(int dimension
ThrowHelper.ThrowArgument_ConcatenateTooFewTensors();
if (dimension < -1 || dimension > tensors[0].Rank)
- ThrowHelper.ThrowArgument_InvalidAxis();
+ ThrowHelper.ThrowArgument_InvalidDimension();
// Calculate total space needed.
nint totalLength = 0;
@@ -212,11 +217,12 @@ public static ref readonly TensorSpan ConcatenateOnDimension(int dimension
if (dimension != -1)
{
nint sumOfAxis = tensors[0].Lengths[dimension];
+ int rank = tensors[0].Rank;
for (int i = 1; i < tensors.Length; i++)
{
- if (tensors[0].Rank != tensors[i].Rank)
+ if (rank != tensors[i].Rank)
ThrowHelper.ThrowArgument_InvalidConcatenateShape();
- for (int j = 0; j < tensors[0].Rank; j++)
+ for (int j = 0; j < rank; j++)
{
if (j != dimension)
{
@@ -228,7 +234,7 @@ public static ref readonly TensorSpan ConcatenateOnDimension(int dimension
}
// Make sure the destination tensor has the correct shape.
- nint[] lengths = new nint[tensors[0].Rank];
+ nint[] lengths = new nint[rank];
tensors[0].Lengths.CopyTo(lengths);
lengths[dimension] = sumOfAxis;
@@ -339,10 +345,10 @@ public static Tensor Create(T[] array, int start, scoped ReadOnlySpanA new tensor that contains elements copied from .
public static Tensor Create(IEnumerable enumerable, bool pinned = false)
{
+ T[] array = enumerable.ToArray();
+
if (pinned)
{
- T[] array = enumerable.ToArray();
-
Tensor tensor = CreateUninitialized([array.Length], pinned);
array.CopyTo(tensor._values);
@@ -350,7 +356,6 @@ public static Tensor Create(IEnumerable enumerable, bool pinned = false
}
else
{
- T[] array = enumerable.ToArray();
return Create(array);
}
}
@@ -364,10 +369,10 @@ public static Tensor Create(IEnumerable enumerable, scoped ReadOnlySpan
/// A new tensor that contains elements copied from and with the specified and .
public static Tensor Create(IEnumerable enumerable, scoped ReadOnlySpan lengths, scoped ReadOnlySpan strides, bool pinned = false)
{
+ T[] array = enumerable.ToArray();
+
if (pinned)
{
- T[] array = enumerable.ToArray();
-
Tensor tensor = CreateUninitialized(lengths, strides, pinned);
array.CopyTo(tensor._values);
@@ -375,7 +380,6 @@ public static Tensor Create(IEnumerable enumerable, scoped ReadOnlySpan
}
else
{
- T[] array = enumerable.ToArray();
return Create(array, lengths, strides);
}
}
@@ -620,20 +624,8 @@ public static bool EqualsAny(in ReadOnlyTensorSpan x, T y)
/// Value to update in the .
public static ref readonly TensorSpan FilteredUpdate(in this TensorSpan tensor, scoped in ReadOnlyTensorSpan filter, T value)
{
- if (filter.Lengths.Length != tensor.Lengths.Length)
- ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter));
-
- Span srcSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
- Span filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength);
-
- for (int i = 0; i < filterSpan.Length; i++)
- {
- if (filterSpan[i])
- {
- srcSpan[i] = value;
- }
- }
-
+ TensorOperation.ValidateCompatibility(filter, tensor);
+ TensorOperation.Invoke, bool, T, T>(filter, value, tensor);
return ref tensor;
}
@@ -646,24 +638,8 @@ public static ref readonly TensorSpan FilteredUpdate(in this TensorSpan
/// Values to update in the .
public static ref readonly TensorSpan FilteredUpdate(in this TensorSpan tensor, scoped in ReadOnlyTensorSpan filter, scoped in ReadOnlyTensorSpan values)
{
- if (filter.Lengths.Length != tensor.Lengths.Length)
- ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter));
- if (values.Rank != 1)
- ThrowHelper.ThrowArgument_1DTensorRequired(nameof(values));
-
- Span dstSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
- Span filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength);
- Span valuesSpan = MemoryMarshal.CreateSpan(ref values._reference, (int)values._shape.LinearLength);
-
- int index = 0;
- for (int i = 0; i < filterSpan.Length; i++)
- {
- if (filterSpan[i])
- {
- dstSpan[i] = valuesSpan[index++];
- }
- }
-
+ TensorOperation.ValidateCompatibility(filter, values, tensor);
+ TensorOperation.Invoke, bool, T, T>(filter, values, tensor);
return ref tensor;
}
#endregion
@@ -1409,6 +1385,9 @@ public static Tensor PermuteDimensions(this Tensor tensor, ReadOnlySpan
}
else
{
+ if (!dimensions.IsEmpty && dimensions.Length != tensor.Lengths.Length)
+ ThrowHelper.ThrowArgument_PermuteAxisOrder();
+
scoped Span newLengths = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer lengthsRentedBuffer);
scoped Span newStrides = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer stridesRentedBuffer);
scoped Span newLinearOrder = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer linearOrderRentedBuffer);
@@ -1426,11 +1405,12 @@ public static Tensor PermuteDimensions(this Tensor tensor, ReadOnlySpan
}
else
{
- if (dimensions.Length != tensor.Lengths.Length)
- ThrowHelper.ThrowArgument_PermuteAxisOrder();
-
for (int i = 0; i < dimensions.Length; i++)
{
+ if (dimensions[i] >= tensor.Lengths.Length || dimensions[i] < 0)
+ {
+ ThrowHelper.ThrowArgument_InvalidDimension();
+ }
newLengths[i] = tensor.Lengths[dimensions[i]];
newStrides[i] = tensor.Strides[dimensions[i]];
newLinearOrder[i] = tensor._shape.LinearRankOrder[dimensions[i]];
@@ -1467,7 +1447,8 @@ public static Tensor Reshape(this Tensor tensor, ReadOnlySpan len
nint[] newLengths = lengths.ToArray();
// Calculate wildcard info.
- if (lengths.Contains(-1))
+ int wildcardIndex = lengths.IndexOf(-1);
+ if (wildcardIndex >= 0)
{
if (lengths.Count(-1) > 1)
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1479,7 +1460,7 @@ public static Tensor Reshape(this Tensor tensor, ReadOnlySpan len
tempTotal /= lengths[i];
}
}
- newLengths[lengths.IndexOf(-1)] = tempTotal;
+ newLengths[wildcardIndex] = tempTotal;
}
nint tempLinear = TensorPrimitives.Product(newLengths);
@@ -1538,8 +1519,8 @@ public static TensorSpan Reshape(in this TensorSpan tensor, scoped Read
}
nint[] newLengths = lengths.ToArray();
- // Calculate wildcard info.
- if (lengths.Contains(-1))
+ int wildcardIndex = lengths.IndexOf(-1);
+ if (wildcardIndex >= 0)
{
if (lengths.Count(-1) > 1)
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1551,7 +1532,7 @@ public static TensorSpan Reshape(in this TensorSpan tensor, scoped Read
tempTotal /= lengths[i];
}
}
- newLengths[lengths.IndexOf(-1)] = tempTotal;
+ newLengths[wildcardIndex] = tempTotal;
}
@@ -1615,7 +1596,8 @@ public static ReadOnlyTensorSpan Reshape(in this ReadOnlyTensorSpan ten
nint[] newLengths = lengths.ToArray();
// Calculate wildcard info.
- if (lengths.Contains(-1))
+ int wildcardIndex = lengths.IndexOf(-1);
+ if (wildcardIndex >= 0)
{
if (lengths.Count(-1) > 1)
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1627,7 +1609,7 @@ public static ReadOnlyTensorSpan Reshape(in this ReadOnlyTensorSpan ten
tempTotal /= lengths[i];
}
}
- newLengths[lengths.IndexOf(-1)] = tempTotal;
+ newLengths[wildcardIndex] = tempTotal;
}
@@ -1701,12 +1683,7 @@ public static Tensor Resize(Tensor tensor, ReadOnlySpan lengths)
/// Destination with the desired new shape.
public static void ResizeTo(scoped in Tensor tensor, in TensorSpan destination)
{
- ReadOnlySpan span = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref tensor.AsTensorSpan()._reference, tensor._start), (int)tensor._values.Length - tensor._start);
- Span ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
- if (ospan.Length >= span.Length)
- span.CopyTo(ospan);
- else
- span.Slice(0, ospan.Length).CopyTo(ospan);
+ ResizeTo(tensor.AsReadOnlyTensorSpan(), destination);
}
///
@@ -1717,12 +1694,7 @@ public static void ResizeTo(scoped in Tensor tensor, in TensorSpan dest
/// Destination with the desired new shape.
public static void ResizeTo(scoped in TensorSpan tensor, in TensorSpan destination)
{
- ReadOnlySpan span = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
- Span ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
- if (ospan.Length >= span.Length)
- span.CopyTo(ospan);
- else
- span.Slice(0, ospan.Length).CopyTo(ospan);
+ ResizeTo(tensor.AsReadOnlyTensorSpan(), destination);
}
///
@@ -1890,6 +1862,8 @@ public static ref readonly TensorSpan SetSlice(this in TensorSpan tenso
/// The axis to split on.
public static Tensor[] Split(scoped in ReadOnlyTensorSpan tensor, int splitCount, nint dimension)
{
+ if (dimension < 0 || dimension >= tensor.Rank)
+ ThrowHelper.ThrowArgument_AxisLargerThanRank();
if (tensor.Lengths[(int)dimension] % splitCount != 0)
ThrowHelper.ThrowArgument_SplitNotSplitEvenly();
@@ -2221,8 +2195,10 @@ public static Tensor StackAlongDimension(int dimension, params ReadOnlySpa
ThrowHelper.ThrowArgument_StackShapesNotSame();
}
- if (dimension < 0)
- dimension = tensors[0].Rank - dimension;
+ // We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
+ // with our call to Unsqueeze.
+ if (dimension < 0 || dimension > tensors[0].Rank)
+ ThrowHelper.ThrowArgument_AxisLargerThanRank();
Tensor[] outputs = new Tensor[tensors.Length];
for (int i = 0; i < tensors.Length; i++)
@@ -2259,8 +2235,10 @@ public static ref readonly TensorSpan StackAlongDimension(scoped ReadOnlyS
ThrowHelper.ThrowArgument_StackShapesNotSame();
}
- if (dimension < 0)
- dimension = tensors[0].Rank - dimension;
+ // We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
+ // with our call to Unsqueeze.
+ if (dimension < 0 || dimension > tensors[0].Rank)
+ ThrowHelper.ThrowArgument_AxisLargerThanRank();
Tensor[] outputs = new Tensor[tensors.Length];
for (int i = 0; i < tensors.Length; i++)
diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs
index d372602f2fa429..d3b702d050cb7f 100644
--- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs
+++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs
@@ -191,8 +191,8 @@ ref destination
xRentedBuffer.Dispose();
}
- public static void Invoke(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, in TensorSpan destination)
- where TOperation : TensorOperation.IBinaryOperation_Tensor_Tensor
+ public static void Invoke(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, in TensorSpan destination)
+ where TOperation : TensorOperation.IBinaryOperation_Tensor_Tensor
{
scoped Span xIndexes = RentedBuffer.Create(destination.Rank, x.Strides, out nint xLinearOffset, out RentedBuffer xRentedBuffer);
scoped Span yIndexes = RentedBuffer.Create(destination.Rank, y.Strides, out nint yLinearOffset, out RentedBuffer yRentedBuffer);
@@ -216,6 +216,10 @@ ref Unsafe.Add(ref destination._reference, destinationLinearOffset)
destinationRentedBuffer.Dispose();
}
+ public static void Invoke(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, in TensorSpan destination)
+ where TOperation : TensorOperation.IBinaryOperation_Tensor_Tensor
+ => Invoke(in x, in y, in destination);
+
public static void Invoke(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, ref TResult result)
where TOperation : TensorOperation.IBinaryOperation_Tensor_Tensor
{
@@ -248,8 +252,8 @@ public static void Invoke(in ReadOnlyTensorSpan
public static void Invoke(in ReadOnlyTensorSpan x, int y, in TensorSpan destination)
where TOperation : TensorOperation.IBinaryOperation_Tensor_Int32 => Invoke(in x, y, in destination);
- public static void Invoke(in ReadOnlyTensorSpan x, T2 y, in TensorSpan destination)
- where TOperation : TensorOperation.IBinaryOperation_Tensor_Scalar
+ public static void Invoke(in ReadOnlyTensorSpan x, TArg2 y, in TensorSpan destination)
+ where TOperation : TensorOperation.IBinaryOperation_Tensor_Scalar
{
scoped Span xIndexes = RentedBuffer.Create(destination.Rank, x.Strides, out nint xLinearOffset, out RentedBuffer xRentedBuffer);
scoped Span destinationIndexes = RentedBuffer.Create(destination.Rank, destination.Strides, out nint destinationLinearOffset, out RentedBuffer destinationRentedBuffer);
@@ -292,8 +296,8 @@ ref Unsafe.Add(ref destination._reference, destinationLinearOffset)
destinationRentedBuffer.Dispose();
}
- public static void Invoke(in ReadOnlyTensorSpan x, T2 y, ref TResult result)
- where TOperation : TensorOperation.IBinaryOperation_Tensor_Scalar
+ public static void Invoke(in ReadOnlyTensorSpan x, TArg2 y, ref TResult result)
+ where TOperation : TensorOperation.IBinaryOperation_Tensor_Scalar
{
scoped Span xIndexes = RentedBuffer.Create(x.Rank, x.Strides, out nint xLinearOffset, out RentedBuffer xRentedBuffer);
@@ -335,7 +339,7 @@ public static void ValidateCompatibility(in ReadOnlyTensorSpan(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, in TensorSpan destination)
+ public static void ValidateCompatibility(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpan y, in TensorSpan destination)
{
// can do bidirectional validation between x and y, that result can then be broadcast to destination
if (TensorShape.AreCompatible(x._shape, y._shape, true))
@@ -2153,6 +2157,48 @@ public static void Invoke(Span destination, T value)
}
}
+ public readonly struct FilteredUpdate
+ : IBinaryOperation_Tensor_Scalar,
+ IBinaryOperation_Tensor_Tensor
+ {
+ public static void Invoke(ref readonly bool x, ref readonly T y, ref T destination)
+ {
+ if (x)
+ {
+ destination = y;
+ }
+ }
+ public static void Invoke(ReadOnlySpan x, ReadOnlySpan y, Span destination)
+ {
+ for (int i = 0; i < x.Length; i++)
+ {
+ if (x[i])
+ {
+ destination[i] = y[i];
+ }
+ }
+ }
+
+ public static void Invoke(ref readonly bool x, T y, ref T destination)
+ {
+ if (x)
+ {
+ destination = y;
+ }
+ }
+
+ public static void Invoke(ReadOnlySpan x, T y, Span destination)
+ {
+ for (int i = 0; i < x.Length; i++)
+ {
+ if (x[i])
+ {
+ destination[i] = y;
+ }
+ }
+ }
+ }
+
public readonly struct GreaterThan
: IBinaryOperation_Tensor_Scalar,
IBinaryOperation_Tensor_Tensor
@@ -2531,10 +2577,15 @@ public interface IBinaryOperation_Scalar_Tensor
static abstract void Invoke(T1 x, ReadOnlySpan y, Span destination);
}
+ public interface IBinaryOperation_Tensor_Tensor
+ {
+ static abstract void Invoke(ref readonly T1 x, ref readonly T2 y, ref TResult destination);
+ static abstract void Invoke(ReadOnlySpan x, ReadOnlySpan y, Span destination);
+ }
+
public interface IBinaryOperation_Tensor_Tensor
+ : IBinaryOperation_Tensor_Tensor
{
- static abstract void Invoke(ref readonly T x, ref readonly T y, ref TResult destination);
- static abstract void Invoke(ReadOnlySpan x, ReadOnlySpan y, Span destination);
}
public interface IOperation
diff --git a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs
index 81338921a5810d..f19d78077b9680 100644
--- a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs
+++ b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs
@@ -197,9 +197,9 @@ public static void ThrowArgument_ConcatenateTooFewTensors()
}
[DoesNotReturn]
- public static void ThrowArgument_InvalidAxis()
+ public static void ThrowArgument_InvalidDimension()
{
- throw new ArgumentException(SR.ThrowArgument_InvalidAxis);
+ throw new ArgumentException(SR.ThrowArgument_InvalidDimension);
}
[DoesNotReturn]