In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pyqpanda as pq
import pyvqnet as pv

In [5]:
# Measurement specification handed to MeasurePauliSum inside `pqc`:
# one entry per wire 0..9, each a single Pauli-X term with coefficient 1.
obs_list = [
    dict(wires=[wire], observables=["X"], coefficient=[1])
    for wire in range(10)
]


def pqc(input, param, qubits, cbits, machine):
    """Build the parameterised circuit and return one value per measured wire.

    Gate sequence (in insertion order):
      1. H on ``qubits[10]``, then CNOT from ``qubits[10]`` onto each of
         ``qubits[0]`` .. ``qubits[9]``.
      2. For every element of ``input``: RZ(input[k]) on ``qubits[k % 10]``
         (inputs longer than 10 wrap around onto the same wires).
      3. RZ(param[k]) on ``qubits[k]`` for k = 0..9.
      4. A fixed RZ(-k * 2*pi / 10) on ``qubits[k]`` for k = 0..9.

    The program is then handed to ``MeasurePauliSum`` together with the
    module-level ``obs_list``; each value ``v`` it yields is mapped to
    ``2 * v * v``.

    Args:
        input: indexable sequence of rotation angles (any length).
        param: indexable sequence with at least 10 rotation angles.
        qubits: indexable collection with at least 11 qubits.
        cbits: unused here; kept because the caller supplies it.
        machine: forwarded unchanged to ``MeasurePauliSum``.

    Returns:
        list with ``2 * v * v`` for every value produced by the measurement.

    NOTE(review): after step 1 the register looks like a GHZ-type state and
    every later gate is a diagonal RZ, so single-qubit X expectations appear
    to vanish for all inputs/parameters (the smoke-test cell in this notebook
    prints an all-zero row). Confirm whether a mixing layer is missing.
    """
    n_data = 10
    circuit = pq.QProg()
    ancilla = qubits[n_data]

    # Fan the ancilla out onto every data wire.
    circuit.insert(pq.H(ancilla))
    for wire in range(n_data):
        circuit.insert(pq.CNOT(ancilla, qubits[wire]))

    # Data encoding: one RZ per input element, wrapping modulo the wire count.
    for pos in range(len(input)):
        circuit.insert(pq.RZ(qubits[pos % n_data], input[pos]))

    # Trainable rotations.
    for wire in range(n_data):
        circuit.insert(pq.RZ(qubits[wire], param[wire]))

    # Fixed, wire-dependent phase offsets.
    for wire in range(n_data):
        circuit.insert(pq.RZ(qubits[wire], -wire * 2 * np.pi / n_data))

    expectations = pv.qnn.measure.MeasurePauliSum(machine, circuit, obs_list, qubits)
    outputs = []
    for value in expectations:
        outputs.append(2 * value * value)
    return outputs


class Model(pv.nn.module.Module):
    """Single-layer model that feeds arctan-scaled inputs into the `pqc` circuit.

    The only sub-module is a pyvqnet ``QuantumLayer`` wrapping ``pqc`` with
    finite-difference differentiation.
    """

    def __init__(self):
        super().__init__()
        # Positional arguments 10, "CPU", 11 are presumably the parameter
        # count, simulator backend and qubit count (pqc reads param[0..9] and
        # qubits[0..10]) -- confirm against the QuantumLayer signature.
        self.fc = pv.qnn.quantumlayer.QuantumLayer(
            pqc,
            10,
            "CPU",
            11,
            diff_method="finite_diff",
        )

    def forward(self, x):
        """Map ``x`` to ``2 * atan(x)`` and run it through the quantum layer."""
        angles = 2 * pv.tensor.atan(x)
        return self.fc(angles)

In [3]:
# Hyperparameters shared by the cells below.
epoch = 1_000  # upper bound on training iterations; the loop may break earlier
batch = 16  # number of random samples (and labels) in the toy dataset

In [10]:
# Smoke test: build the model, report its parameter count, and run a single
# forward pass on a random toy batch.
#
# The toy data is drawn from a seeded generator so that Restart & Run All
# reproduces the same inputs and labels; previously every run produced a
# different dataset, which made the saved outputs impossible to reproduce.
# NOTE(review): the QuantumLayer's initial parameters are created by pyvqnet,
# not by this generator -- seed pyvqnet separately if fully deterministic
# runs are required.
SEED = 0
rng = np.random.default_rng(SEED)

m = Model()
print(sum(p.numel() for p in m.parameters()))

# `batch` samples with 10 features drawn uniformly from [0, 1), plus one
# integer class label in 0..9 per sample.
x = rng.random((batch, 10)).astype("float32")
y = rng.integers(0, 10, size=(batch,), dtype="int64")

y_pred = m(x)
print(y_pred[0])
print(y)
print(y_pred.argmax(1, False).to_numpy())

10
[0.,0.,0.,0.,0.,0.,0.,0.,0.,0.]
[4 9 5 8 9 9 5 5 7 7 3 8 8 8 1 2]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [75]:
# Fit the model on the toy batch with Adam and cross-entropy, stopping as soon
# as every sample is classified correctly (or after `epoch` iterations).
los = pv.nn.loss.CrossEntropyLoss()
opt = pv.optim.Adam(m.parameters())

for e in range(1, epoch + 1):
    opt.zero_grad()

    # Emit the line prefix before the forward pass so progress is visible
    # while the circuit is being evaluated.
    print(f"epoch {e} \t loss ", end="")
    y_pred = m(x)
    loss = los(y, y_pred)  # called with the labels first, then the predictions
    print(loss.item())

    # Per-epoch report: labels, predicted classes, number of hits, blank line.
    y_p = y_pred.argmax(1, False).to_numpy()
    s = sum(y_p == y)
    print(y, y_p, f"correct: {s}", "", sep="\n")

    # Perfect fit reached: skip the (now pointless) parameter update.
    if s == len(y):
        print("done")
        break

    loss.backward()
    opt._step()

epoch 1 	 loss 2.40496826171875
[0 2 3 4 5 2 9 2 4 8 5 1 8 3 4 6]
[4 0 6 4 4 6 4 6 6 6 5 0 9 4 6 6]
correct: 3

epoch 2 	 loss 2.4025261402130127
[0 2 3 4 5 2 9 2 4 8 5 1 8 3 4 6]
[4 0 6 4 4 6 4 6 6 6 5 0 9 4 6 6]
correct: 3

epoch 3 	 loss 2.4000866413116455
[0 2 3 4 5 2 9 2 4 8 5 1 8 3 4 6]
[4 0 6 4 4 6 4 6 6 6 5 0 9 4 6 6]
correct: 3

epoch 4 	 loss 2.3976500034332275
[0 2 3 4 5 2 9 2 4 8 5 1 8 3 4 6]
[4 0 6 4 4 6 4 6 6 6 5 0 9 4 6 6]
correct: 3

epoch 5 	 loss 2.3952157497406006
[0 2 3 4 5 2 9 2 4 8 5 1 8 3 4 6]
[4 0 6 4 4 6 4 6 6 6 5 0 9 4 6 6]
correct: 3

epoch 6 	 loss 2.3927829265594482
[0 2 3 4 5 2 9 2 4 8 5 1 8 3 4 6]
[4 0 6 4 4 6 4 6 6 6 5 0 9 4 6 6]
correct: 3

epoch 7 	 loss 2.390350580215454
[0 2 3 4 5 2 9 2 4 8 5 1 8 3 4 6]
[4 0 6 4 4 6 4 6 6 6 5 0 9 4 6 6]
correct: 3

epoch 8 	 loss 2.3879175186157227
[0 2 3 4 5 2 9 2 4 8 5 1 8 3 4 6]
[4 0 6 4 4 6 4 6 6 6 5 0 9 4 6 6]
correct: 3

epoch 9 	 loss 2.385484218597412
[0 2 3 4 5 2 9 2 4 8 5 1 8 3 4 6]
[4 0 6 4 4 6 4 6 6 6 5 0

KeyboardInterrupt: 