Skip to content

Commit

Permalink
More retries to avoid empty tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed May 30, 2024
1 parent 164ddbe commit 1a04143
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function FiniteDifferences.to_vec(t::T) where {T<:TensorKit.TrivialTensorMap}
return vec, x -> T(from_vec(x), codomain(t), domain(t))
end
function FiniteDifferences.to_vec(t::AbstractTensorMap)
vec = mapreduce(vcat, blocks(t); init=eltype(t)[]) do (c, b)
vec = mapreduce(vcat, blocks(t); init=real(eltype(t)[])) do (c, b)
if scalartype(t) <: Real
return reshape(b, :) .* sqrt(dim(c))
else
Expand Down Expand Up @@ -221,13 +221,19 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),

@timedtestset "tensorcontract!" begin
for _ in 1:5
k1 = rand(1:3)
k2 = rand(1:2)
k3 = rand(1:2)
V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))
V3 = map(v -> rand(Bool) ? v' : v, rand(V, k3))

d = 0
local V1, V2, V3
# retry a couple times to make sure there is at least some nonzero elements
for _ in 1:10
k1 = rand(1:3)
k2 = rand(1:2)
k3 = rand(1:2)
V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))
V3 = map(v -> rand(Bool) ? v' : v, rand(V, k3))
d = min(dim(V1 V2), dim(V1' V2), dim(V2 V3), dim(V2' V3))
d > 0 && break
end
ipA = randindextuple(k1 + k2)
pA = _repartition(invperm(linearize(ipA)), k1)
ipB = randindextuple(k2 + k3)
Expand Down

0 comments on commit 1a04143

Please sign in to comment.