Skip to content

Commit

Permalink
More changes to fix AD tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed May 30, 2024
1 parent 1a04143 commit 0ccab5a
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,28 +228,26 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
k1 = rand(1:3)
k2 = rand(1:2)
k3 = rand(1:2)
V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))
V3 = map(v -> rand(Bool) ? v' : v, rand(V, k3))
V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1]))
V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1]))
V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1]))
d = min(dim(V1 V2), dim(V1' V2), dim(V2 V3), dim(V2' V3))
d > 0 && break
end
ipA = randindextuple(k1 + k2)
pA = _repartition(invperm(linearize(ipA)), k1)
ipB = randindextuple(k2 + k3)
pB = _repartition(invperm(linearize(ipB)), k2)
pAB = randindextuple(k1 + k3)
ipA = randindextuple(length(V1) + length(V2))
pA = _repartition(invperm(linearize(ipA)), length(V1))
ipB = randindextuple(length(V2) + length(V3))
pB = _repartition(invperm(linearize(ipB)), length(V2))
pAB = randindextuple(length(V1) + length(V3))

α = randn(T)
β = randn(T)

for conjA in (:N, :C), conjB in (:N, :C)
A = TensorMap(randn, T,
permute(prod(V1) prod(conjA === :C ? conj.(V2) : V2),
ipA))
permute(V1 (conjA === :C ? prod(conj, V2) : V2), ipA))
B = TensorMap(randn, T,
permute(prod(conjB === :C ? conj.(V2) : V2) prod(V3),
ipB))
permute((conjB === :C ? prod(conj, V2) : V2) V3, ipB))
C = _randomize!(TensorOperations.tensoralloc_contract(T, pAB, A, pA,
conjA,
B, pB, conjB,
Expand Down

0 comments on commit 0ccab5a

Please sign in to comment.