In [6]:
import torch
import pennylane as qml

from qulearn.hat_basis import HatBasis
from qulearn.qlayer import (HatBasisQFE,
                            MeasurementLayer,
                            MeasurementType)

In [7]:
num_qubits = 5
num_nodes = 2**num_qubits
a = -1.0
b = 1.0
hat_basis = HatBasis(a=a, b=b, num_nodes=num_nodes)

embed = HatBasisQFE(wires=num_qubits, basis=hat_basis, sqrt=True, normalize=False)
obs = qml.PauliZ(0)
model = MeasurementLayer(embed, observables=obs, measurement_type=MeasurementType.Expectation)
drawer = qml.draw(model.qnode, show_all_wires=True, expansion_strategy="device")
x = torch.tensor([0.0])
print(drawer(x))

0: ──────────────────────╭U(M2)─┤  <Z>
1: ───────────────╭U(M1)─╰U(M2)─┤     
2: ────────╭U(M1)─╰U(M1)────────┤     
3: ─╭U(M0)─╰U(M1)───────────────┤     
4: ─╰U(M0)──────────────────────┤     


In [34]:
import torch
import tntorch as tn
def zkron(t1, t2):
    c1 = t1.cores
    c2 = t2.cores
    c3 = [torch.kron(A, B) for A, B in zip(c1, c2)]
    
    t3 = tn.Tensor(c3)
    return t3

In [42]:
import tntorch as tn
import numpy as np

t1 = tn.randn([2]*3)
t2 = tn.ones([2]*3)

T1 = t1.numpy().reshape((2**3))
T2 = t2.numpy().reshape((2**3))

cores = t1.cores + t2.cores
t3 = tn.Tensor(cores)
T3 = t3.numpy().reshape((2**6))

T3_ = np.kron(T1, T2)
delta = abs(T3_ - T3)
delta = np.linalg.norm(delta)
print(delta)

t4 = zkron(t1, t2)
T4 = t4.numpy().reshape((2**6))

print(t4)
print(T3)
print("=========")
print(T4)

0.0
3D TT tensor:

  4   4   4
  |   |   |
 (0) (1) (2)
 / \ / \ / \
1   2   2   1

[ 0.15855342  0.15855342  0.15855342  0.15855342  0.15855342  0.15855342
  0.15855342  0.15855342 -0.19791673 -0.19791673 -0.19791673 -0.19791673
 -0.19791673 -0.19791673 -0.19791673 -0.19791673 -0.10040016 -0.10040016
 -0.10040016 -0.10040016 -0.10040016 -0.10040016 -0.10040016 -0.10040016
 -0.13793476 -0.13793476 -0.13793476 -0.13793476 -0.13793476 -0.13793476
 -0.13793476 -0.13793476 -0.46558505 -0.46558505 -0.46558505 -0.46558505
 -0.46558505 -0.46558505 -0.46558505 -0.46558505  0.07590834  0.07590834
  0.07590834  0.07590834  0.07590834  0.07590834  0.07590834  0.07590834
 -0.7543369  -0.7543369  -0.7543369  -0.7543369  -0.7543369  -0.7543369
 -0.7543369  -0.7543369   0.29843152  0.29843152  0.29843152  0.29843152
  0.29843152  0.29843152  0.29843152  0.29843152]
[ 0.15855342  0.15855342 -0.19791673 -0.19791673  0.15855342  0.15855342
 -0.19791673 -0.19791673 -0.10040016 -0.10040016 -0.13793476 -0.

In [43]:
print(t1.numpy().reshape((2**3)))
print(t2.numpy().reshape((2**3)))
print(T3_)
print(T4)

[ 0.15855342 -0.19791673 -0.10040016 -0.13793476 -0.46558505  0.07590834
 -0.7543369   0.29843152]
[1. 1. 1. 1. 1. 1. 1. 1.]
[ 0.15855342  0.15855342  0.15855342  0.15855342  0.15855342  0.15855342
  0.15855342  0.15855342 -0.19791673 -0.19791673 -0.19791673 -0.19791673
 -0.19791673 -0.19791673 -0.19791673 -0.19791673 -0.10040016 -0.10040016
 -0.10040016 -0.10040016 -0.10040016 -0.10040016 -0.10040016 -0.10040016
 -0.13793476 -0.13793476 -0.13793476 -0.13793476 -0.13793476 -0.13793476
 -0.13793476 -0.13793476 -0.46558505 -0.46558505 -0.46558505 -0.46558505
 -0.46558505 -0.46558505 -0.46558505 -0.46558505  0.07590834  0.07590834
  0.07590834  0.07590834  0.07590834  0.07590834  0.07590834  0.07590834
 -0.7543369  -0.7543369  -0.7543369  -0.7543369  -0.7543369  -0.7543369
 -0.7543369  -0.7543369   0.29843152  0.29843152  0.29843152  0.29843152
  0.29843152  0.29843152  0.29843152  0.29843152]
[ 0.15855342  0.15855342 -0.19791673 -0.19791673  0.15855342  0.15855342
 -0.19791673 -0.1979167