Skip to content

Minor tensor fixes #115125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 2, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -177,8 +177,8 @@
<data name="ThrowArgument_InPlaceInvalidShape" xml:space="preserve">
<value>In place operations require the same shape for both tensors</value>
</data>
<data name="ThrowArgument_InvalidAxis" xml:space="preserve">
<value>Invalid axis provided. Must be greater then or equal to 0 and less than the tensor rank.</value>
<data name="ThrowArgument_InvalidDimension" xml:space="preserve">
<value>Invalid dimension provided. Must be greater then or equal to 0 and less than the tensor rank.</value>
</data>
<data name="ThrowArgument_InvalidConcatenateShape" xml:space="preserve">
<value>The tensors must have the same shape, except in the dimension corresponding to axis.</value>
Original file line number Diff line number Diff line change
@@ -133,18 +133,14 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
ThrowHelper.ThrowArgument_ConcatenateTooFewTensors();

if (dimension < -1 || dimension > tensors[0].Rank)
ThrowHelper.ThrowArgument_InvalidAxis();
ThrowHelper.ThrowArgument_InvalidDimension();

// Calculate total space needed.
nint totalLength = 0;
for (int i = 0; i < tensors.Length; i++)
totalLength += tensors[i].FlattenedLength;
Tensor<T> tensor;

nint sumOfAxis = 0;
// If axis != -1, make sure all dimensions except the one to concatenate on match.
if (dimension != -1)
{
sumOfAxis = tensors[0].Lengths[dimension];
nint sumOfAxis = tensors[0].Lengths[dimension];
for (int i = 1; i < tensors.Length; i++)
{
if (tensors[0].Rank != tensors[i].Rank)
@@ -157,22 +153,31 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
ThrowHelper.ThrowArgument_InvalidConcatenateShape();
}
}
sumOfAxis += tensors[i].Lengths[dimension];
checked
{
sumOfAxis += tensors[i].Lengths[dimension];
}
}
}

Tensor<T> tensor;
if (dimension == -1)
{
tensor = Tensor.Create<T>([totalLength]);
}
else
{
nint[] lengths = new nint[tensors[0].Rank];
tensors[0].Lengths.CopyTo(lengths);
lengths[dimension] = sumOfAxis;
tensor = Tensor.Create<T>(lengths);
}
else
{
// Calculate total space needed.
nint totalLength = 0;
for (int i = 0; i < tensors.Length; i++)
{
checked
{
totalLength += tensors[i].FlattenedLength;
}
}

tensor = Tensor.Create<T>([totalLength]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think this is correct and it's now calling the Tensor.Create<T>(T[] array) overload, rather than the Tensor.Create<T>(scoped ReadOnlySpan<nint> lengths, bool pinned = false) overload

I think this should stay nint and we can rely on Tensor.Create<T>(scoped ReadOnlySpan<nint> lengths, bool pinned = false) validating that it can be allocated by the underlying tensor storage.

We just need to ensure that adding the combined tensors flattened lengths together doesn't overflow the nint.


I also think that this represents a UX issue with the Create APIs.

I think we may want to disambiguate as CreateFromShape or similar so that a user passing in an nint[] isn't confused on whether its going to create a Tensor<nint> or a Tensor<T> where nint is the lengths.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed back to nint.

But it still was calling the right overload, not the Tensor.Create<T>(T[] array) one since T != int.

}

ConcatenateOnDimension(dimension, tensors, tensor);
return tensor;
@@ -201,7 +206,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
ThrowHelper.ThrowArgument_ConcatenateTooFewTensors();

if (dimension < -1 || dimension > tensors[0].Rank)
ThrowHelper.ThrowArgument_InvalidAxis();
ThrowHelper.ThrowArgument_InvalidDimension();

// Calculate total space needed.
nint totalLength = 0;
@@ -212,11 +217,12 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
if (dimension != -1)
{
nint sumOfAxis = tensors[0].Lengths[dimension];
int rank = tensors[0].Rank;
for (int i = 1; i < tensors.Length; i++)
{
if (tensors[0].Rank != tensors[i].Rank)
if (rank != tensors[i].Rank)
ThrowHelper.ThrowArgument_InvalidConcatenateShape();
for (int j = 0; j < tensors[0].Rank; j++)
for (int j = 0; j < rank; j++)
{
if (j != dimension)
{
@@ -228,7 +234,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
}

// Make sure the destination tensor has the correct shape.
nint[] lengths = new nint[tensors[0].Rank];
nint[] lengths = new nint[rank];
tensors[0].Lengths.CopyTo(lengths);
lengths[dimension] = sumOfAxis;

@@ -339,18 +345,17 @@ public static Tensor<T> Create<T>(T[] array, int start, scoped ReadOnlySpan<nint
/// <returns>A new tensor that contains elements copied from <paramref name="enumerable" />.</returns>
public static Tensor<T> Create<T>(IEnumerable<T> enumerable, bool pinned = false)
{
T[] array = enumerable.ToArray();

if (pinned)
{
T[] array = enumerable.ToArray();

Tensor<T> tensor = CreateUninitialized<T>([array.Length], pinned);
array.CopyTo(tensor._values);

return tensor;
}
else
{
T[] array = enumerable.ToArray();
return Create(array);
}
}
@@ -364,18 +369,17 @@ public static Tensor<T> Create<T>(IEnumerable<T> enumerable, scoped ReadOnlySpan
/// <returns>A new tensor that contains elements copied from <paramref name="enumerable" /> and with the specified <paramref name="lengths" /> and <paramref name="strides" />.</returns>
public static Tensor<T> Create<T>(IEnumerable<T> enumerable, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides, bool pinned = false)
{
T[] array = enumerable.ToArray();

if (pinned)
{
T[] array = enumerable.ToArray();

Tensor<T> tensor = CreateUninitialized<T>(lengths, strides, pinned);
array.CopyTo(tensor._values);

return tensor;
}
else
{
T[] array = enumerable.ToArray();
return Create(array, lengths, strides);
}
}
@@ -620,20 +624,8 @@ public static bool EqualsAny<T>(in ReadOnlyTensorSpan<T> x, T y)
/// <param name="value">Value to update in the <paramref name="tensor"/>.</param>
public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T> tensor, scoped in ReadOnlyTensorSpan<bool> filter, T value)
{
if (filter.Lengths.Length != tensor.Lengths.Length)
ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter));

Span<T> srcSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
Span<bool> filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength);

for (int i = 0; i < filterSpan.Length; i++)
{
if (filterSpan[i])
{
srcSpan[i] = value;
}
}

TensorOperation.ValidateCompatibility(filter, tensor);
TensorOperation.Invoke<TensorOperation.FilteredUpdate<T>, bool, T, T>(filter, value, tensor);
return ref tensor;
}

@@ -646,24 +638,8 @@ public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T>
/// <param name="values">Values to update in the <paramref name="tensor"/>.</param>
public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T> tensor, scoped in ReadOnlyTensorSpan<bool> filter, scoped in ReadOnlyTensorSpan<T> values)
{
if (filter.Lengths.Length != tensor.Lengths.Length)
ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter));
if (values.Rank != 1)
ThrowHelper.ThrowArgument_1DTensorRequired(nameof(values));

Span<T> dstSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
Span<bool> filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength);
Span<T> valuesSpan = MemoryMarshal.CreateSpan(ref values._reference, (int)values._shape.LinearLength);

int index = 0;
for (int i = 0; i < filterSpan.Length; i++)
{
if (filterSpan[i])
{
dstSpan[i] = valuesSpan[index++];
}
}

TensorOperation.ValidateCompatibility(filter, values, tensor);
TensorOperation.Invoke<TensorOperation.FilteredUpdate<T>, bool, T, T>(filter, values, tensor);
return ref tensor;
}
#endregion
@@ -1409,6 +1385,9 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan
}
else
{
if (!dimensions.IsEmpty && dimensions.Length != tensor.Lengths.Length)
ThrowHelper.ThrowArgument_PermuteAxisOrder();

scoped Span<nint> newLengths = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<nint> lengthsRentedBuffer);
scoped Span<nint> newStrides = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<nint> stridesRentedBuffer);
scoped Span<int> newLinearOrder = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<int> linearOrderRentedBuffer);
@@ -1426,11 +1405,12 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan
}
else
{
if (dimensions.Length != tensor.Lengths.Length)
ThrowHelper.ThrowArgument_PermuteAxisOrder();

for (int i = 0; i < dimensions.Length; i++)
{
if (dimensions[i] >= tensor.Lengths.Length || dimensions[i] < 0)
{
ThrowHelper.ThrowArgument_InvalidDimension();
}
newLengths[i] = tensor.Lengths[dimensions[i]];
newStrides[i] = tensor.Strides[dimensions[i]];
newLinearOrder[i] = tensor._shape.LinearRankOrder[dimensions[i]];
@@ -1467,7 +1447,8 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len

nint[] newLengths = lengths.ToArray();
// Calculate wildcard info.
if (lengths.Contains(-1))
int wildcardIndex = lengths.IndexOf(-1);
if (wildcardIndex >= 0)
{
if (lengths.Count(-1) > 1)
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1479,7 +1460,7 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len
tempTotal /= lengths[i];
}
}
newLengths[lengths.IndexOf(-1)] = tempTotal;
newLengths[wildcardIndex] = tempTotal;
}

nint tempLinear = TensorPrimitives.Product(newLengths);
@@ -1538,8 +1519,8 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read
}

nint[] newLengths = lengths.ToArray();
// Calculate wildcard info.
if (lengths.Contains(-1))
int wildcardIndex = lengths.IndexOf(-1);
if (wildcardIndex >= 0)
{
if (lengths.Count(-1) > 1)
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1551,7 +1532,7 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read
tempTotal /= lengths[i];
}
}
newLengths[lengths.IndexOf(-1)] = tempTotal;
newLengths[wildcardIndex] = tempTotal;

}

@@ -1615,7 +1596,8 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten

nint[] newLengths = lengths.ToArray();
// Calculate wildcard info.
if (lengths.Contains(-1))
int wildcardIndex = lengths.IndexOf(-1);
if (wildcardIndex >= 0)
{
if (lengths.Count(-1) > 1)
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1627,7 +1609,7 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
tempTotal /= lengths[i];
}
}
newLengths[lengths.IndexOf(-1)] = tempTotal;
newLengths[wildcardIndex] = tempTotal;

}

@@ -1701,12 +1683,7 @@ public static Tensor<T> Resize<T>(Tensor<T> tensor, ReadOnlySpan<nint> lengths)
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> destination)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref tensor.AsTensorSpan()._reference, tensor._start), (int)tensor._values.Length - tensor._start);
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
if (ospan.Length >= span.Length)
span.CopyTo(ospan);
else
span.Slice(0, ospan.Length).CopyTo(ospan);
ResizeTo(tensor.AsReadOnlyTensorSpan(), destination);
}

/// <summary>
@@ -1717,12 +1694,7 @@ public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> dest
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
public static void ResizeTo<T>(scoped in TensorSpan<T> tensor, in TensorSpan<T> destination)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
if (ospan.Length >= span.Length)
span.CopyTo(ospan);
else
span.Slice(0, ospan.Length).CopyTo(ospan);
ResizeTo(tensor.AsReadOnlyTensorSpan(), destination);
}

/// <summary>
@@ -1890,6 +1862,8 @@ public static ref readonly TensorSpan<T> SetSlice<T>(this in TensorSpan<T> tenso
/// <param name="dimension">The axis to split on.</param>
public static Tensor<T>[] Split<T>(scoped in ReadOnlyTensorSpan<T> tensor, int splitCount, nint dimension)
{
if (dimension < 0 || dimension >= tensor.Rank)
ThrowHelper.ThrowArgument_AxisLargerThanRank();
if (tensor.Lengths[(int)dimension] % splitCount != 0)
ThrowHelper.ThrowArgument_SplitNotSplitEvenly();

@@ -2221,8 +2195,10 @@ public static Tensor<T> StackAlongDimension<T>(int dimension, params ReadOnlySpa
ThrowHelper.ThrowArgument_StackShapesNotSame();
}

if (dimension < 0)
dimension = tensors[0].Rank - dimension;
// We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
// with our call to Unsqueeze.
if (dimension < 0 || dimension > tensors[0].Rank)
ThrowHelper.ThrowArgument_AxisLargerThanRank();

Tensor<T>[] outputs = new Tensor<T>[tensors.Length];
for (int i = 0; i < tensors.Length; i++)
@@ -2259,8 +2235,10 @@ public static ref readonly TensorSpan<T> StackAlongDimension<T>(scoped ReadOnlyS
ThrowHelper.ThrowArgument_StackShapesNotSame();
}

if (dimension < 0)
dimension = tensors[0].Rank - dimension;
// We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
// with our call to Unsqueeze.
if (dimension < 0 || dimension > tensors[0].Rank)
ThrowHelper.ThrowArgument_AxisLargerThanRank();

Tensor<T>[] outputs = new Tensor<T>[tensors.Length];
for (int i = 0; i < tensors.Length; i++)
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.