Skip to content

Commit

Permalink
[NDTensors] Optimize permutedims (#1288)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 16, 2023
1 parent a00ef70 commit faa0357
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 10 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ SparseArrays = "1.6"
SplitApplyCombine = "1.2.2"
StaticArrays = "0.12, 1.0"
Strided = "2"
StridedViews = "0.2"
StridedViews = "0.2.2"
TimerOutputs = "0.5.5"
TupleTools = "1.2.0"
VectorInterface = "0.4.2"
Expand Down
12 changes: 8 additions & 4 deletions NDTensors/src/array/permutedims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ function permutedims(E::Exposed{<:Array}, perm)
end

function permutedims!(Edest::Exposed{<:Array}, Esrc::Exposed{<:Array}, perm)
@strided unexpose(Edest) .= permutedims(Esrc, perm)
return unexpose(Edest)
a_dest = unexpose(Edest)
a_src = unexpose(Esrc)
@strided a_dest .= permutedims(a_src, perm)
return a_dest
end

function permutedims!(Edest::Exposed{<:Array}, Esrc::Exposed{<:Array}, perm, f)
@strided unexpose(Edest) .= f.(unexpose(Edest), permutedims(Esrc, perm))
return unexpose(Edest)
a_dest = unexpose(Edest)
a_src = unexpose(Esrc)
@strided a_dest .= f.(a_dest, permutedims(a_src, perm))
return a_dest
end
11 changes: 9 additions & 2 deletions NDTensors/src/dense/densetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,20 @@ end
# Maybe allocate output data.
# TODO: Remove this in favor of `map!`
# applied to `PermutedDimsArray`.
function permutedims!!(R::DenseTensor, T::DenseTensor, perm, f::Function=(r, t) -> t)
function permutedims!!(R::DenseTensor, T::DenseTensor, perm, f::Function)
Base.checkdims_perm(R, T, perm)
RR = convert(promote_type(typeof(R), typeof(T)), R)
permutedims!(RR, T, perm, f)
return RR
end

function permutedims!!(R::DenseTensor, T::DenseTensor, perm)
Base.checkdims_perm(R, T, perm)
RR = convert(promote_type(typeof(R), typeof(T)), R)
permutedims!(RR, T, perm)
return RR
end

# TODO: call permutedims!(R,T,perm,(r,t)->t)?
function permutedims!(
R::DenseTensor{<:Number,N,StoreT}, T::DenseTensor{<:Number,N,StoreT}, perm::NTuple{N,Int}
Expand All @@ -216,7 +223,7 @@ function permutedims!(
) where {N}
RA = array(R)
TA = array(T)
RA .= permutedims(expose(TA), perm)
permutedims!(expose(RA), expose(TA), perm)
return R
end

Expand Down
4 changes: 4 additions & 0 deletions NDTensors/src/lib/Unwrap/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[deps]
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
21 changes: 18 additions & 3 deletions NDTensors/src/tensoroperations/generic_tensor_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,31 @@ end
# and return the result of the permutation.
# Similar to `BangBang.jl` notation:
# https://juliafolds.github.io/BangBang.jl/stable/.
function permutedims!!(output_tensor::Tensor, tensor::Tensor, perm, f::Function=(r, t) -> t)
function permutedims!!(output_tensor::Tensor, tensor::Tensor, perm, f::Function)
Base.checkdims_perm(output_tensor, tensor, perm)
permutedims!(output_tensor, tensor, perm, f)
return output_tensor
end

function permutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function=(r, t) -> t)
# Equivalent to `permutedims!!(output_tensor, tensor, perm, (r, t) -> t)`
function permutedims!!(output_tensor::Tensor, tensor::Tensor, perm)
Base.checkdims_perm(output_tensor, tensor, perm)
permutedims!(output_tensor, tensor, perm)
return output_tensor
end

function permutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function)
Base.checkdims_perm(output_tensor, tensor, perm)
error(
"`permutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function` not implemented for `typeof(output_tensor) = $(typeof(output_tensor))`, `typeof(tensor) = $(typeof(tensor))`, `perm = $perm`, and `f = $f`.",
)
return output_tensor
end

function permutedims!(output_tensor::Tensor, tensor::Tensor, perm)
Base.checkdims_perm(output_tensor, tensor, perm)
error(
"`permutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function=(r, t) -> t)` not implemented for `typeof(output_tensor) = $(typeof(output_tensor))`, `typeof(tensor) = $(typeof(tensor))`, `perm = $perm`, and `f = $f`.",
"`permutedims!(output_tensor::Tensor, tensor::Tensor, perm` not implemented for `typeof(output_tensor) = $(typeof(output_tensor))`, `typeof(tensor) = $(typeof(tensor))`, and `perm = $perm`.",
)
return output_tensor
end
Expand Down

0 comments on commit faa0357

Please sign in to comment.