-
Notifications
You must be signed in to change notification settings - Fork 5.1k
Minor tensor fixes #115125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Minor tensor fixes #115125
Changes from all commits
2819cc9
4303b30
4010f66
4b6b9dc
c6ae9a0
1eba5fd
c07953f
8948b81
966a31e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -133,18 +133,14 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R | |
ThrowHelper.ThrowArgument_ConcatenateTooFewTensors(); | ||
|
||
if (dimension < -1 || dimension > tensors[0].Rank) | ||
ThrowHelper.ThrowArgument_InvalidAxis(); | ||
ThrowHelper.ThrowArgument_InvalidDimension(); | ||
|
||
// Calculate total space needed. | ||
nint totalLength = 0; | ||
for (int i = 0; i < tensors.Length; i++) | ||
totalLength += tensors[i].FlattenedLength; | ||
Tensor<T> tensor; | ||
|
||
nint sumOfAxis = 0; | ||
// If axis != -1, make sure all dimensions except the one to concatenate on match. | ||
if (dimension != -1) | ||
{ | ||
sumOfAxis = tensors[0].Lengths[dimension]; | ||
nint sumOfAxis = tensors[0].Lengths[dimension]; | ||
for (int i = 1; i < tensors.Length; i++) | ||
{ | ||
if (tensors[0].Rank != tensors[i].Rank) | ||
|
@@ -157,22 +153,31 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R | |
ThrowHelper.ThrowArgument_InvalidConcatenateShape(); | ||
} | ||
} | ||
sumOfAxis += tensors[i].Lengths[dimension]; | ||
checked | ||
{ | ||
sumOfAxis += tensors[i].Lengths[dimension]; | ||
} | ||
} | ||
} | ||
|
||
Tensor<T> tensor; | ||
if (dimension == -1) | ||
{ | ||
tensor = Tensor.Create<T>([totalLength]); | ||
} | ||
else | ||
{ | ||
nint[] lengths = new nint[tensors[0].Rank]; | ||
tensors[0].Lengths.CopyTo(lengths); | ||
lengths[dimension] = sumOfAxis; | ||
tensor = Tensor.Create<T>(lengths); | ||
} | ||
else | ||
{ | ||
// Calculate total space needed. | ||
nint totalLength = 0; | ||
for (int i = 0; i < tensors.Length; i++) | ||
michaelgsharp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
checked | ||
{ | ||
totalLength += tensors[i].FlattenedLength; | ||
} | ||
} | ||
|
||
tensor = Tensor.Create<T>([totalLength]); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't think this is correct and it's now calling the I think this should stay We just need to ensure that adding the combined tensors flattened lengths together doesn't overflow the I also think that this represents a UX issue with the I think we may want to disambiguate as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed back to nint. But it still was calling the right overload, not the |
||
} | ||
|
||
ConcatenateOnDimension(dimension, tensors, tensor); | ||
return tensor; | ||
|
@@ -201,7 +206,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension | |
ThrowHelper.ThrowArgument_ConcatenateTooFewTensors(); | ||
|
||
if (dimension < -1 || dimension > tensors[0].Rank) | ||
ThrowHelper.ThrowArgument_InvalidAxis(); | ||
ThrowHelper.ThrowArgument_InvalidDimension(); | ||
|
||
// Calculate total space needed. | ||
nint totalLength = 0; | ||
|
@@ -212,11 +217,12 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension | |
if (dimension != -1) | ||
{ | ||
nint sumOfAxis = tensors[0].Lengths[dimension]; | ||
int rank = tensors[0].Rank; | ||
for (int i = 1; i < tensors.Length; i++) | ||
{ | ||
if (tensors[0].Rank != tensors[i].Rank) | ||
if (rank != tensors[i].Rank) | ||
ThrowHelper.ThrowArgument_InvalidConcatenateShape(); | ||
for (int j = 0; j < tensors[0].Rank; j++) | ||
for (int j = 0; j < rank; j++) | ||
{ | ||
michaelgsharp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (j != dimension) | ||
{ | ||
|
@@ -228,7 +234,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension | |
} | ||
|
||
// Make sure the destination tensor has the correct shape. | ||
nint[] lengths = new nint[tensors[0].Rank]; | ||
nint[] lengths = new nint[rank]; | ||
tensors[0].Lengths.CopyTo(lengths); | ||
lengths[dimension] = sumOfAxis; | ||
|
||
|
@@ -339,18 +345,17 @@ public static Tensor<T> Create<T>(T[] array, int start, scoped ReadOnlySpan<nint | |
/// <returns>A new tensor that contains elements copied from <paramref name="enumerable" />.</returns> | ||
public static Tensor<T> Create<T>(IEnumerable<T> enumerable, bool pinned = false) | ||
{ | ||
T[] array = enumerable.ToArray(); | ||
|
||
if (pinned) | ||
{ | ||
T[] array = enumerable.ToArray(); | ||
|
||
Tensor<T> tensor = CreateUninitialized<T>([array.Length], pinned); | ||
array.CopyTo(tensor._values); | ||
|
||
return tensor; | ||
} | ||
else | ||
{ | ||
T[] array = enumerable.ToArray(); | ||
return Create(array); | ||
} | ||
} | ||
|
@@ -364,18 +369,17 @@ public static Tensor<T> Create<T>(IEnumerable<T> enumerable, scoped ReadOnlySpan | |
/// <returns>A new tensor that contains elements copied from <paramref name="enumerable" /> and with the specified <paramref name="lengths" /> and <paramref name="strides" />.</returns> | ||
public static Tensor<T> Create<T>(IEnumerable<T> enumerable, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides, bool pinned = false) | ||
{ | ||
T[] array = enumerable.ToArray(); | ||
|
||
if (pinned) | ||
{ | ||
T[] array = enumerable.ToArray(); | ||
|
||
Tensor<T> tensor = CreateUninitialized<T>(lengths, strides, pinned); | ||
array.CopyTo(tensor._values); | ||
|
||
return tensor; | ||
} | ||
else | ||
{ | ||
T[] array = enumerable.ToArray(); | ||
return Create(array, lengths, strides); | ||
} | ||
} | ||
|
@@ -620,20 +624,8 @@ public static bool EqualsAny<T>(in ReadOnlyTensorSpan<T> x, T y) | |
/// <param name="value">Value to update in the <paramref name="tensor"/>.</param> | ||
public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T> tensor, scoped in ReadOnlyTensorSpan<bool> filter, T value) | ||
{ | ||
if (filter.Lengths.Length != tensor.Lengths.Length) | ||
ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter)); | ||
|
||
Span<T> srcSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength); | ||
Span<bool> filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength); | ||
|
||
for (int i = 0; i < filterSpan.Length; i++) | ||
{ | ||
if (filterSpan[i]) | ||
{ | ||
srcSpan[i] = value; | ||
} | ||
} | ||
|
||
TensorOperation.ValidateCompatibility(filter, tensor); | ||
TensorOperation.Invoke<TensorOperation.FilteredUpdate<T>, bool, T, T>(filter, value, tensor); | ||
return ref tensor; | ||
} | ||
|
||
|
@@ -646,24 +638,8 @@ public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T> | |
/// <param name="values">Values to update in the <paramref name="tensor"/>.</param> | ||
public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T> tensor, scoped in ReadOnlyTensorSpan<bool> filter, scoped in ReadOnlyTensorSpan<T> values) | ||
{ | ||
if (filter.Lengths.Length != tensor.Lengths.Length) | ||
ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter)); | ||
if (values.Rank != 1) | ||
ThrowHelper.ThrowArgument_1DTensorRequired(nameof(values)); | ||
|
||
Span<T> dstSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength); | ||
Span<bool> filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength); | ||
Span<T> valuesSpan = MemoryMarshal.CreateSpan(ref values._reference, (int)values._shape.LinearLength); | ||
|
||
int index = 0; | ||
for (int i = 0; i < filterSpan.Length; i++) | ||
{ | ||
if (filterSpan[i]) | ||
{ | ||
dstSpan[i] = valuesSpan[index++]; | ||
} | ||
} | ||
|
||
TensorOperation.ValidateCompatibility(filter, values, tensor); | ||
TensorOperation.Invoke<TensorOperation.FilteredUpdate<T>, bool, T, T>(filter, values, tensor); | ||
return ref tensor; | ||
} | ||
#endregion | ||
|
@@ -1409,6 +1385,9 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan | |
} | ||
else | ||
{ | ||
if (!dimensions.IsEmpty && dimensions.Length != tensor.Lengths.Length) | ||
ThrowHelper.ThrowArgument_PermuteAxisOrder(); | ||
|
||
scoped Span<nint> newLengths = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<nint> lengthsRentedBuffer); | ||
scoped Span<nint> newStrides = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<nint> stridesRentedBuffer); | ||
scoped Span<int> newLinearOrder = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<int> linearOrderRentedBuffer); | ||
|
@@ -1426,11 +1405,12 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan | |
} | ||
else | ||
{ | ||
if (dimensions.Length != tensor.Lengths.Length) | ||
ThrowHelper.ThrowArgument_PermuteAxisOrder(); | ||
|
||
for (int i = 0; i < dimensions.Length; i++) | ||
{ | ||
if (dimensions[i] >= tensor.Lengths.Length || dimensions[i] < 0) | ||
{ | ||
ThrowHelper.ThrowArgument_InvalidDimension(); | ||
} | ||
newLengths[i] = tensor.Lengths[dimensions[i]]; | ||
newStrides[i] = tensor.Strides[dimensions[i]]; | ||
newLinearOrder[i] = tensor._shape.LinearRankOrder[dimensions[i]]; | ||
|
@@ -1467,7 +1447,8 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len | |
|
||
nint[] newLengths = lengths.ToArray(); | ||
// Calculate wildcard info. | ||
if (lengths.Contains(-1)) | ||
int wildcardIndex = lengths.IndexOf(-1); | ||
if (wildcardIndex >= 0) | ||
{ | ||
if (lengths.Count(-1) > 1) | ||
ThrowHelper.ThrowArgument_OnlyOneWildcard(); | ||
|
@@ -1479,7 +1460,7 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len | |
tempTotal /= lengths[i]; | ||
} | ||
} | ||
newLengths[lengths.IndexOf(-1)] = tempTotal; | ||
newLengths[wildcardIndex] = tempTotal; | ||
} | ||
|
||
nint tempLinear = TensorPrimitives.Product(newLengths); | ||
|
@@ -1538,8 +1519,8 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read | |
} | ||
|
||
nint[] newLengths = lengths.ToArray(); | ||
// Calculate wildcard info. | ||
if (lengths.Contains(-1)) | ||
int wildcardIndex = lengths.IndexOf(-1); | ||
if (wildcardIndex >= 0) | ||
{ | ||
if (lengths.Count(-1) > 1) | ||
ThrowHelper.ThrowArgument_OnlyOneWildcard(); | ||
|
@@ -1551,7 +1532,7 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read | |
tempTotal /= lengths[i]; | ||
} | ||
} | ||
newLengths[lengths.IndexOf(-1)] = tempTotal; | ||
newLengths[wildcardIndex] = tempTotal; | ||
|
||
} | ||
|
||
|
@@ -1615,7 +1596,8 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten | |
|
||
nint[] newLengths = lengths.ToArray(); | ||
// Calculate wildcard info. | ||
if (lengths.Contains(-1)) | ||
int wildcardIndex = lengths.IndexOf(-1); | ||
if (wildcardIndex >= 0) | ||
{ | ||
if (lengths.Count(-1) > 1) | ||
ThrowHelper.ThrowArgument_OnlyOneWildcard(); | ||
|
@@ -1627,7 +1609,7 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten | |
tempTotal /= lengths[i]; | ||
} | ||
} | ||
newLengths[lengths.IndexOf(-1)] = tempTotal; | ||
newLengths[wildcardIndex] = tempTotal; | ||
|
||
} | ||
|
||
|
@@ -1701,12 +1683,7 @@ public static Tensor<T> Resize<T>(Tensor<T> tensor, ReadOnlySpan<nint> lengths) | |
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param> | ||
public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> destination) | ||
{ | ||
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref tensor.AsTensorSpan()._reference, tensor._start), (int)tensor._values.Length - tensor._start); | ||
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength); | ||
if (ospan.Length >= span.Length) | ||
span.CopyTo(ospan); | ||
else | ||
span.Slice(0, ospan.Length).CopyTo(ospan); | ||
ResizeTo(tensor.AsReadOnlyTensorSpan(), destination); | ||
} | ||
|
||
/// <summary> | ||
|
@@ -1717,12 +1694,7 @@ public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> dest | |
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param> | ||
public static void ResizeTo<T>(scoped in TensorSpan<T> tensor, in TensorSpan<T> destination) | ||
{ | ||
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength); | ||
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength); | ||
if (ospan.Length >= span.Length) | ||
span.CopyTo(ospan); | ||
else | ||
span.Slice(0, ospan.Length).CopyTo(ospan); | ||
ResizeTo(tensor.AsReadOnlyTensorSpan(), destination); | ||
} | ||
|
||
/// <summary> | ||
|
@@ -1890,6 +1862,8 @@ public static ref readonly TensorSpan<T> SetSlice<T>(this in TensorSpan<T> tenso | |
/// <param name="dimension">The axis to split on.</param> | ||
public static Tensor<T>[] Split<T>(scoped in ReadOnlyTensorSpan<T> tensor, int splitCount, nint dimension) | ||
{ | ||
if (dimension < 0 || dimension >= tensor.Rank) | ||
ThrowHelper.ThrowArgument_AxisLargerThanRank(); | ||
if (tensor.Lengths[(int)dimension] % splitCount != 0) | ||
ThrowHelper.ThrowArgument_SplitNotSplitEvenly(); | ||
|
||
|
@@ -2221,8 +2195,10 @@ public static Tensor<T> StackAlongDimension<T>(int dimension, params ReadOnlySpa | |
ThrowHelper.ThrowArgument_StackShapesNotSame(); | ||
} | ||
|
||
if (dimension < 0) | ||
dimension = tensors[0].Rank - dimension; | ||
// We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension | ||
// with our call to Unsqueeze. | ||
if (dimension < 0 || dimension > tensors[0].Rank) | ||
michaelgsharp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ThrowHelper.ThrowArgument_AxisLargerThanRank(); | ||
|
||
Tensor<T>[] outputs = new Tensor<T>[tensors.Length]; | ||
for (int i = 0; i < tensors.Length; i++) | ||
|
@@ -2259,8 +2235,10 @@ public static ref readonly TensorSpan<T> StackAlongDimension<T>(scoped ReadOnlyS | |
ThrowHelper.ThrowArgument_StackShapesNotSame(); | ||
} | ||
|
||
if (dimension < 0) | ||
dimension = tensors[0].Rank - dimension; | ||
// We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension | ||
// with our call to Unsqueeze. | ||
if (dimension < 0 || dimension > tensors[0].Rank) | ||
michaelgsharp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ThrowHelper.ThrowArgument_AxisLargerThanRank(); | ||
|
||
Tensor<T>[] outputs = new Tensor<T>[tensors.Length]; | ||
for (int i = 0; i < tensors.Length; i++) | ||
|
Uh oh!
There was an error while loading. Please reload this page.