From 0ccab5ab1039daaff8d2ecfff4eb67eee0629699 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 30 May 2024 10:25:11 +0200 Subject: [PATCH] More changes to fix AD tests --- test/ad.jl | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 305ba22d..e75566fa 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -228,28 +228,26 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), k1 = rand(1:3) k2 = rand(1:2) k3 = rand(1:2) - V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) - V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) - V3 = map(v -> rand(Bool) ? v' : v, rand(V, k3)) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1])) d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) d > 0 && break end - ipA = randindextuple(k1 + k2) - pA = _repartition(invperm(linearize(ipA)), k1) - ipB = randindextuple(k2 + k3) - pB = _repartition(invperm(linearize(ipB)), k2) - pAB = randindextuple(k1 + k3) + ipA = randindextuple(length(V1) + length(V2)) + pA = _repartition(invperm(linearize(ipA)), length(V1)) + ipB = randindextuple(length(V2) + length(V3)) + pB = _repartition(invperm(linearize(ipB)), length(V2)) + pAB = randindextuple(length(V1) + length(V3)) α = randn(T) β = randn(T) for conjA in (:N, :C), conjB in (:N, :C) A = TensorMap(randn, T, - permute(prod(V1) ← prod(conjA === :C ? conj.(V2) : V2), - ipA)) + permute(V1 ← (conjA === :C ? prod(conj, V2) : V2), ipA)) B = TensorMap(randn, T, - permute(prod(conjB === :C ? conj.(V2) : V2) ← prod(V3), - ipB)) + permute((conjB === :C ? prod(conj, V2) : V2) ← V3, ipB)) C = _randomize!(TensorOperations.tensoralloc_contract(T, pAB, A, pA, conjA, B, pB, conjB,