In [3]:
import collections
import torch
from torch import cat, no_grad, manual_seed
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
import tqdm
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from QuICT_ml.ansatz_library import QNNLayer
from QuICT_ml.utils.encoding import *
from QuICT_ml.utils.ml_utils import *
from QuICT_ml.model.QNN import QuantumNet

In [27]:
# Global configuration for the notebook.
SEED = 42        # random seed shared by NumPy and PyTorch
EPOCH = 50       # total number of training epochs
BATCH_SIZE = 64  # number of samples per optimisation step
LR = 0.001       # learning rate for gradient descent

# Seed both RNGs by *calling* the seeding functions. Assigning to
# ``np.random.seed`` would replace the function with an int and seed nothing.
np.random.seed(SEED)
manual_seed(SEED)

In [12]:
def select_first_per_class(dataset, labels, n_per_class):
    """Restrict ``dataset`` in place to the first ``n_per_class`` samples of each label.

    Args:
        dataset: object exposing ``data`` and ``targets`` attributes that can be
            indexed with an integer index array (here a torchvision MNIST dataset).
        labels: iterable of class labels to keep.
        n_per_class: maximum number of samples kept for each label.

    Returns:
        Tuple ``(data, targets)`` of the filtered dataset. Samples are grouped
        by label, in the order the labels were given.
    """
    keep = np.concatenate(
        [np.where(dataset.targets == label)[0][:n_per_class] for label in labels]
    )
    dataset.data = dataset.data[keep]
    dataset.targets = dataset.targets[keep]
    return dataset.data, dataset.targets


# Kept from the original cell; the data loaders below use BATCH_SIZE instead.
batch_size = 64

# Training split: at most the first 1024 images of each of the digits 0 and 1.
X_train = datasets.MNIST(root="./data/", train=True, download=True)
n_samples = 1024
train_X, train_Y = select_first_per_class(X_train, range(2), n_samples)

# Test split: at most the first 512 images of each of the digits 0 and 1.
n_samples = 512
X_test = datasets.MNIST(root="./data/", train=False, download=True)
test_X, test_Y = select_first_per_class(X_test, range(2), n_samples)

print("Training examples: ", len(train_Y))
print("Testing examples: ", len(test_Y))

Training examples:  2048
Testing examples:  1024


In [13]:
def downscale(X, resize):
    """Resize images to ``resize`` and divide the pixel values by 255.

    Args:
        X: image tensor accepted by ``transforms.Resize``.
        resize: target size passed to ``transforms.Resize`` (e.g. ``(4, 4)``).

    Returns:
        The resized tensor divided by 255.0.
    """
    resizer = transforms.Resize(size=resize)
    return resizer(X) / 255.0

# Shrink both splits to 4x4 images (train first, then test).
resized_train_X, resized_test_X = (
    downscale(split, (4, 4)) for split in (train_X, test_X)
)



In [14]:
def remove_conflict(X, Y, resize):
    """Drop images whose pixel pattern appears with more than one label.

    Images are keyed by their flattened pixel values, so identical images are
    merged into a single entry; an entry is kept only when every occurrence
    carried the same label.

    Args:
        X: iterable of image tensors (``.numpy()`` is called on each, so they
            must be CPU tensors).
        Y: iterable of scalar label tensors aligned with ``X``.
        resize: shape each kept image is reshaped to.

    Returns:
        Tuple ``(X, Y)`` of tensors with the surviving images and their labels,
        in order of first appearance.
    """
    labels_by_image = collections.defaultdict(set)
    for image, label in zip(X, Y):
        pixel_key = tuple(image.numpy().flatten())
        labels_by_image[pixel_key].add(label.item())

    # Keep only the pixel patterns that map to exactly one label.
    unambiguous = [
        (np.array(pixel_key).reshape(resize), next(iter(seen_labels)))
        for pixel_key, seen_labels in labels_by_image.items()
        if len(seen_labels) == 1
    ]
    kept_images = [image for image, _ in unambiguous]
    kept_labels = [label for _, label in unambiguous]
    return (
        torch.from_numpy(np.array(kept_images)),
        torch.from_numpy(np.array(kept_labels)),
    )

# Remove contradictory (and duplicate) 4x4 images from each split.
nocon_train_X, nocon_train_Y = remove_conflict(
    X=resized_train_X, Y=train_Y, resize=(4, 4)
)
nocon_test_X, nocon_test_Y = remove_conflict(
    X=resized_test_X, Y=test_Y, resize=(4, 4)
)

# Report how many examples survive the filtering.
print("Remaining training examples: ", len(nocon_train_Y))
print("Remaining testing examples: ", len(nocon_test_Y))


Remaining training examples:  1462
Remaining testing examples:  629


In [15]:
def binary_img(X, threshold):
    """Binarise ``X``: 1 where a value is strictly above ``threshold``, else 0.

    Args:
        X: tensor of pixel values.
        threshold: cut-off compared against every element.

    Returns:
        ``torch.int`` tensor of the same shape as ``X``.
    """
    return (X > threshold).type(torch.int)

# Pixels strictly brighter than this value become 1, all others 0.
threshold = 0.5
bin_train_X, bin_test_X = (
    binary_img(split, threshold) for split in (nocon_train_X, nocon_test_X)
)

In [16]:
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook does not crash on a machine without CUDA.
# NOTE(review): assumes the QuICT_ml components accept a CPU torch.device — confirm.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Move the binarised images and their labels onto the selected device.
# These names replace the raw MNIST tensors bound to them earlier.
train_X = bin_train_X.to(device)
train_Y = nocon_train_Y.to(device)
test_X = bin_test_X.to(device)
test_Y = nocon_test_Y.to(device)

In [18]:
def qubit_encoding(X, device):
    """Encode every image in ``X`` with a ``Qubit`` encoder and collect the ansatzes.

    Args:
        X: non-empty indexable collection of 2-D images; the number of qubits
            is the pixel count (rows * columns) of the first image.
        device: device handed to the ``Qubit`` encoder.

    Returns:
        List holding ``encoder.ansatz`` as read after encoding each image.
    """
    rows, cols = X[0].shape[0], X[0].shape[1]
    encoder = Qubit(rows * cols, device)
    # NOTE(review): one encoder instance is reused for all images; this assumes
    # each encoding() call binds a fresh object to ``.ansatz`` — confirm.
    encoded = []
    for image in X:
        encoder.encoding(image)
        encoded.append(encoder.ansatz)
    return encoded

In [19]:
# Build the encoding circuits for both splits (train first, then test).
ansatz_train_X, ansatz_test_X = (
    qubit_encoding(split, device) for split in (train_X, test_X)
)

In [20]:
# Small demonstration: one "XX" layer over data qubits 0-3 with readout qubit 4.
demo_data_qubits = list(range(4))
pqc = QNNLayer(demo_data_qubits, 4, device=device)
# nn.Parameter tracks gradients by default (requires_grad=True).
params = nn.Parameter(torch.rand(1, 4, device=device))
model_circuit = pqc.circuit_layer(["XX"], params)
model_circuit.draw()

<Figure size 848.056x645 with 1 Axes>

In [21]:
# Full-size model ansatz: 16 data qubits (one per pixel) plus readout qubit 16.
data_qubits = list(range(16))
readout_qubit = 16
layers = ["XX", "ZZ"]

pqc = QNNLayer(data_qubits, readout_qubit, device=device)
# One row of parameters per layer, one column per data qubit -> shape (2, 16).
params = nn.Parameter(
    torch.rand(len(layers), len(data_qubits), device=device)
)
model_ansatz = pqc(layers, params)

In [33]:
# Wrap the on-device tensors as datasets.
train_dataset = data.TensorDataset(train_X, train_Y)
test_dataset = data.TensorDataset(test_X, test_Y)

# Both loaders share the same settings: shuffled, fixed-size batches, with the
# final incomplete batch dropped.
# NOTE(review): for the test loader this means each validation pass sees a
# different random subset of the test set — confirm this is intended.
loader_options = dict(batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
train_loader = data.DataLoader(dataset=train_dataset, **loader_options)
test_loader = data.DataLoader(dataset=test_dataset, **loader_options)

In [24]:
# Quantum classifier over 16 data qubits using the layer types chosen above.
net = QuantumNet(16, layers, encoding="qubit", device=device)
# Adam over all model parameters in a single parameter group.
optim = torch.optim.Adam(net.parameters(), lr=LR)

In [35]:
def loss_func(y_true, y_pred):
    """Hinge loss and number of correct predictions for a batch.

    Both inputs are mapped through ``2 * v - 1`` before the loss is computed.
    NOTE(review): presumably ``y_true`` holds 0/1 labels and ``y_pred`` lies in
    [0, 1], so both land in [-1, 1] — confirm against the model output.

    Args:
        y_true: tensor of target labels.
        y_pred: tensor of model outputs, same shape as ``y_true``.

    Returns:
        Tuple ``(loss, correct)``: the mean hinge loss as a tensor, and the
        count (Python int) of samples whose rescaled prediction has the same
        sign as the rescaled label.
    """
    signed_true = 2 * y_true.type(torch.float32) - 1.0
    signed_pred = 2 * y_pred - 1.0
    margin = signed_true * signed_pred
    hinge = torch.clamp(1 - margin, min=0.0)
    # A prediction is correct when the margin is strictly positive.
    n_correct = int(torch.count_nonzero(margin > 0))
    return torch.mean(hinge), n_correct

In [36]:
# Train for EPOCH epochs; after each epoch evaluate on the test loader.
for ep in range(EPOCH):
    # ---- Training ----
    net.train()
    loader = tqdm.tqdm(
        train_loader, desc="Training epoch {}".format(ep + 1), leave=True
    )
    for it, (x_train, y_train) in enumerate(loader):
        optim.zero_grad()
        y_pred = net(x_train)

        loss, correct = loss_func(y_train, y_pred)
        accuracy = correct / len(y_train)
        loss.backward()
        optim.step()
        loader.set_postfix(
            it=it,
            loss="{:.3f}".format(loss),
            accuracy="{:.3f}".format(accuracy),
        )

    # ---- Validation ----
    net.eval()
    loader_val = tqdm.tqdm(
        test_loader, desc="Validating epoch {}".format(ep + 1), leave=True
    )
    loss_val_list = []
    total_correct = 0
    total_seen = 0
    # No parameter update happens here, so disable autograd to avoid building
    # (and holding memory for) a computation graph on every validation batch.
    with no_grad():
        for it, (x_test, y_test) in enumerate(loader_val):
            y_pred = net(x_test)
            loss_val, correct = loss_func(y_test, y_pred)
            loss_val_list.append(loss_val.detach().cpu().numpy())
            total_correct += correct
            total_seen += len(y_test)
            accuracy_val = correct / len(y_test)
            loader_val.set_postfix(
                it=it,
                loss="{:.3f}".format(loss_val),
                accuracy="{:.3f}".format(accuracy_val),
            )
    avg_loss = np.mean(loss_val_list)
    # Divide by the number of samples actually evaluated rather than assuming
    # every batch holds exactly BATCH_SIZE samples.
    avg_acc = total_correct / total_seen
    print("Validation Average Loss: {}, Accuracy: {}".format(avg_loss, avg_acc))

Training epoch 1: 100%|██████████| 22/22 [06:46<00:00, 18.48s/it, accuracy=0.578, it=21, loss=0.881]
Validating epoch 1: 100%|██████████| 9/9 [02:47<00:00, 18.57s/it, accuracy=0.547, it=8, loss=0.950]


Validation Average Loss: 0.9631057381629944, Accuracy: 0.5416666666666666


Training epoch 2: 100%|██████████| 22/22 [06:48<00:00, 18.59s/it, accuracy=0.688, it=21, loss=0.866]
Validating epoch 2: 100%|██████████| 9/9 [02:47<00:00, 18.59s/it, accuracy=0.594, it=8, loss=0.915]


Validation Average Loss: 0.8825179934501648, Accuracy: 0.6145833333333334


Training epoch 3: 100%|██████████| 22/22 [06:48<00:00, 18.56s/it, accuracy=0.828, it=21, loss=0.786]
Validating epoch 3: 100%|██████████| 9/9 [02:47<00:00, 18.58s/it, accuracy=0.766, it=8, loss=0.809]


Validation Average Loss: 0.82808518409729, Accuracy: 0.7482638888888888


Training epoch 4: 100%|██████████| 22/22 [06:44<00:00, 18.41s/it, accuracy=0.781, it=21, loss=0.775]
Validating epoch 4: 100%|██████████| 9/9 [02:27<00:00, 16.40s/it, accuracy=0.781, it=8, loss=0.759]


Validation Average Loss: 0.7848742604255676, Accuracy: 0.7430555555555556


Training epoch 5: 100%|██████████| 22/22 [06:01<00:00, 16.41s/it, accuracy=0.766, it=21, loss=0.720]
Validating epoch 5: 100%|██████████| 9/9 [02:27<00:00, 16.37s/it, accuracy=0.688, it=8, loss=0.811]


Validation Average Loss: 0.7364639639854431, Accuracy: 0.7534722222222222


Training epoch 6: 100%|██████████| 22/22 [06:01<00:00, 16.42s/it, accuracy=0.781, it=21, loss=0.666]
Validating epoch 6: 100%|██████████| 9/9 [02:27<00:00, 16.40s/it, accuracy=0.750, it=8, loss=0.689]


Validation Average Loss: 0.7043659687042236, Accuracy: 0.7534722222222222


Training epoch 7: 100%|██████████| 22/22 [06:01<00:00, 16.42s/it, accuracy=0.781, it=21, loss=0.615]
Validating epoch 7: 100%|██████████| 9/9 [02:27<00:00, 16.44s/it, accuracy=0.750, it=8, loss=0.686]


Validation Average Loss: 0.690216064453125, Accuracy: 0.75


Training epoch 8: 100%|██████████| 22/22 [05:46<00:00, 15.75s/it, accuracy=0.750, it=21, loss=0.658]
Validating epoch 8: 100%|██████████| 9/9 [02:16<00:00, 15.17s/it, accuracy=0.750, it=8, loss=0.694]


Validation Average Loss: 0.6622174382209778, Accuracy: 0.7569444444444444


Training epoch 9: 100%|██████████| 22/22 [05:32<00:00, 15.12s/it, accuracy=0.781, it=21, loss=0.633]
Validating epoch 9: 100%|██████████| 9/9 [02:16<00:00, 15.13s/it, accuracy=0.766, it=8, loss=0.597]


Validation Average Loss: 0.6417844891548157, Accuracy: 0.765625


Training epoch 10: 100%|██████████| 22/22 [05:34<00:00, 15.19s/it, accuracy=0.719, it=21, loss=0.678]
Validating epoch 10: 100%|██████████| 9/9 [02:16<00:00, 15.19s/it, accuracy=0.781, it=8, loss=0.596]


Validation Average Loss: 0.6441090106964111, Accuracy: 0.7534722222222222


Training epoch 11: 100%|██████████| 22/22 [05:33<00:00, 15.17s/it, accuracy=0.766, it=21, loss=0.630]
Validating epoch 11: 100%|██████████| 9/9 [02:16<00:00, 15.21s/it, accuracy=0.703, it=8, loss=0.660]


Validation Average Loss: 0.6095855236053467, Accuracy: 0.7673611111111112


Training epoch 12: 100%|██████████| 22/22 [05:39<00:00, 15.44s/it, accuracy=0.859, it=21, loss=0.538]
Validating epoch 12: 100%|██████████| 9/9 [02:17<00:00, 15.28s/it, accuracy=0.812, it=8, loss=0.539]


Validation Average Loss: 0.6099399924278259, Accuracy: 0.7569444444444444


Training epoch 13: 100%|██████████| 22/22 [05:31<00:00, 15.06s/it, accuracy=0.703, it=21, loss=0.677]
Validating epoch 13: 100%|██████████| 9/9 [02:17<00:00, 15.24s/it, accuracy=0.766, it=8, loss=0.596]


Validation Average Loss: 0.6009268760681152, Accuracy: 0.7586805555555556


Training epoch 14: 100%|██████████| 22/22 [05:35<00:00, 15.25s/it, accuracy=0.812, it=21, loss=0.531]
Validating epoch 14: 100%|██████████| 9/9 [02:15<00:00, 15.07s/it, accuracy=0.766, it=8, loss=0.607]


Validation Average Loss: 0.5798186659812927, Accuracy: 0.765625


Training epoch 15: 100%|██████████| 22/22 [05:32<00:00, 15.12s/it, accuracy=0.812, it=21, loss=0.517]
Validating epoch 15: 100%|██████████| 9/9 [02:17<00:00, 15.25s/it, accuracy=0.844, it=8, loss=0.467]


Validation Average Loss: 0.5966569781303406, Accuracy: 0.75


Training epoch 16: 100%|██████████| 22/22 [05:32<00:00, 15.11s/it, accuracy=0.781, it=21, loss=0.538]
Validating epoch 16: 100%|██████████| 9/9 [02:17<00:00, 15.30s/it, accuracy=0.766, it=8, loss=0.582]


Validation Average Loss: 0.5938238501548767, Accuracy: 0.7482638888888888


Training epoch 17: 100%|██████████| 22/22 [05:32<00:00, 15.09s/it, accuracy=0.719, it=21, loss=0.655]
Validating epoch 17: 100%|██████████| 9/9 [02:16<00:00, 15.18s/it, accuracy=0.797, it=8, loss=0.532]


Validation Average Loss: 0.5877219438552856, Accuracy: 0.7517361111111112


Training epoch 18: 100%|██████████| 22/22 [05:32<00:00, 15.11s/it, accuracy=0.781, it=21, loss=0.546]
Validating epoch 18: 100%|██████████| 9/9 [02:18<00:00, 15.38s/it, accuracy=0.797, it=8, loss=0.538]


Validation Average Loss: 0.5777895450592041, Accuracy: 0.7534722222222222


Training epoch 19: 100%|██████████| 22/22 [05:36<00:00, 15.29s/it, accuracy=0.781, it=21, loss=0.528]
Validating epoch 19: 100%|██████████| 9/9 [02:18<00:00, 15.40s/it, accuracy=0.703, it=8, loss=0.620]


Validation Average Loss: 0.5636030435562134, Accuracy: 0.7569444444444444


Training epoch 20: 100%|██████████| 22/22 [05:33<00:00, 15.14s/it, accuracy=0.719, it=21, loss=0.615]
Validating epoch 20: 100%|██████████| 9/9 [02:16<00:00, 15.15s/it, accuracy=0.750, it=8, loss=0.624]


Validation Average Loss: 0.5698352456092834, Accuracy: 0.7517361111111112


Training epoch 21: 100%|██████████| 22/22 [05:34<00:00, 15.19s/it, accuracy=0.750, it=21, loss=0.548]
Validating epoch 21: 100%|██████████| 9/9 [02:17<00:00, 15.31s/it, accuracy=0.781, it=8, loss=0.508]


Validation Average Loss: 0.5708842873573303, Accuracy: 0.7534722222222222


Training epoch 22: 100%|██████████| 22/22 [05:33<00:00, 15.14s/it, accuracy=0.797, it=21, loss=0.446]
Validating epoch 22: 100%|██████████| 9/9 [02:17<00:00, 15.24s/it, accuracy=0.797, it=8, loss=0.504]


Validation Average Loss: 0.5518158674240112, Accuracy: 0.7638888888888888


Training epoch 23: 100%|██████████| 22/22 [05:29<00:00, 14.98s/it, accuracy=0.812, it=21, loss=0.473]
Validating epoch 23: 100%|██████████| 9/9 [02:16<00:00, 15.18s/it, accuracy=0.766, it=8, loss=0.576]


Validation Average Loss: 0.5694606900215149, Accuracy: 0.7517361111111112


Training epoch 24: 100%|██████████| 22/22 [05:30<00:00, 15.02s/it, accuracy=0.859, it=21, loss=0.376]
Validating epoch 24: 100%|██████████| 9/9 [02:16<00:00, 15.16s/it, accuracy=0.766, it=8, loss=0.549]


Validation Average Loss: 0.5571860074996948, Accuracy: 0.7569444444444444


Training epoch 25: 100%|██████████| 22/22 [05:34<00:00, 15.18s/it, accuracy=0.766, it=21, loss=0.513]
Validating epoch 25: 100%|██████████| 9/9 [02:19<00:00, 15.48s/it, accuracy=0.703, it=8, loss=0.674]


Validation Average Loss: 0.558228611946106, Accuracy: 0.7569444444444444


Training epoch 26: 100%|██████████| 22/22 [05:32<00:00, 15.11s/it, accuracy=0.781, it=21, loss=0.527]
Validating epoch 26: 100%|██████████| 9/9 [02:17<00:00, 15.32s/it, accuracy=0.781, it=8, loss=0.530]


Validation Average Loss: 0.5427374839782715, Accuracy: 0.7673611111111112


Training epoch 27: 100%|██████████| 22/22 [05:35<00:00, 15.24s/it, accuracy=0.781, it=21, loss=0.499]
Validating epoch 27: 100%|██████████| 9/9 [02:16<00:00, 15.20s/it, accuracy=0.828, it=8, loss=0.434]


Validation Average Loss: 0.5501111745834351, Accuracy: 0.7638888888888888


Training epoch 28: 100%|██████████| 22/22 [05:34<00:00, 15.22s/it, accuracy=0.734, it=21, loss=0.546]
Validating epoch 28: 100%|██████████| 9/9 [02:16<00:00, 15.13s/it, accuracy=0.750, it=8, loss=0.551]


Validation Average Loss: 0.5559328198432922, Accuracy: 0.7552083333333334


Training epoch 29: 100%|██████████| 22/22 [05:34<00:00, 15.19s/it, accuracy=0.797, it=21, loss=0.428]
Validating epoch 29: 100%|██████████| 9/9 [02:18<00:00, 15.38s/it, accuracy=0.719, it=8, loss=0.620]


Validation Average Loss: 0.5573146939277649, Accuracy: 0.7552083333333334


Training epoch 30: 100%|██████████| 22/22 [05:34<00:00, 15.22s/it, accuracy=0.797, it=21, loss=0.439]
Validating epoch 30: 100%|██████████| 9/9 [02:18<00:00, 15.42s/it, accuracy=0.641, it=8, loss=0.708]


Validation Average Loss: 0.566136360168457, Accuracy: 0.7482638888888888


Training epoch 31: 100%|██████████| 22/22 [05:34<00:00, 15.19s/it, accuracy=0.844, it=21, loss=0.336]
Validating epoch 31: 100%|██████████| 9/9 [02:16<00:00, 15.14s/it, accuracy=0.688, it=8, loss=0.691]


Validation Average Loss: 0.5548200607299805, Accuracy: 0.7569444444444444


Training epoch 32: 100%|██████████| 22/22 [05:41<00:00, 15.53s/it, accuracy=0.859, it=21, loss=0.406]
Validating epoch 32: 100%|██████████| 9/9 [02:38<00:00, 17.61s/it, accuracy=0.781, it=8, loss=0.546]


Validation Average Loss: 0.5477726459503174, Accuracy: 0.7604166666666666


Training epoch 33: 100%|██████████| 22/22 [11:26<00:00, 31.19s/it, accuracy=0.797, it=21, loss=0.485]
Validating epoch 33: 100%|██████████| 9/9 [06:14<00:00, 41.57s/it, accuracy=0.828, it=8, loss=0.409]


Validation Average Loss: 0.5467122197151184, Accuracy: 0.7569444444444444


Training epoch 34: 100%|██████████| 22/22 [18:32<00:00, 50.57s/it, accuracy=0.828, it=21, loss=0.382]
Validating epoch 34: 100%|██████████| 9/9 [09:55<00:00, 66.16s/it, accuracy=0.812, it=8, loss=0.468]


Validation Average Loss: 0.5385504961013794, Accuracy: 0.7638888888888888


Training epoch 35: 100%|██████████| 22/22 [28:52<00:00, 78.74s/it, accuracy=0.844, it=21, loss=0.382]
Validating epoch 35: 100%|██████████| 9/9 [15:42<00:00, 104.73s/it, accuracy=0.797, it=8, loss=0.501]


Validation Average Loss: 0.5439577102661133, Accuracy: 0.7604166666666666


Training epoch 36: 100%|██████████| 22/22 [37:11<00:00, 101.44s/it, accuracy=0.750, it=21, loss=0.568]
Validating epoch 36: 100%|██████████| 9/9 [15:24<00:00, 102.67s/it, accuracy=0.719, it=8, loss=0.635]


Validation Average Loss: 0.533050000667572, Accuracy: 0.765625


Training epoch 37: 100%|██████████| 22/22 [37:19<00:00, 101.78s/it, accuracy=0.828, it=21, loss=0.449]
Validating epoch 37: 100%|██████████| 9/9 [15:39<00:00, 104.36s/it, accuracy=0.812, it=8, loss=0.510]


Validation Average Loss: 0.528014063835144, Accuracy: 0.7690972222222222


Training epoch 38: 100%|██████████| 22/22 [37:29<00:00, 102.27s/it, accuracy=0.781, it=21, loss=0.457]
Validating epoch 38: 100%|██████████| 9/9 [15:11<00:00, 101.24s/it, accuracy=0.750, it=8, loss=0.562]


Validation Average Loss: 0.5430059432983398, Accuracy: 0.7604166666666666


Training epoch 39: 100%|██████████| 22/22 [37:50<00:00, 103.22s/it, accuracy=0.797, it=21, loss=0.412]
Validating epoch 39: 100%|██████████| 9/9 [14:47<00:00, 98.62s/it, accuracy=0.688, it=8, loss=0.644] 


Validation Average Loss: 0.5509097576141357, Accuracy: 0.7552083333333334


Training epoch 40: 100%|██████████| 22/22 [37:58<00:00, 103.56s/it, accuracy=0.891, it=21, loss=0.361]
Validating epoch 40: 100%|██████████| 9/9 [15:09<00:00, 101.07s/it, accuracy=0.859, it=8, loss=0.429]


Validation Average Loss: 0.5355576276779175, Accuracy: 0.7760416666666666


Training epoch 41: 100%|██████████| 22/22 [37:09<00:00, 101.32s/it, accuracy=0.781, it=21, loss=0.499]
Validating epoch 41: 100%|██████████| 9/9 [15:34<00:00, 103.87s/it, accuracy=0.781, it=8, loss=0.554]


Validation Average Loss: 0.5570106506347656, Accuracy: 0.7638888888888888


Training epoch 42: 100%|██████████| 22/22 [36:09<00:00, 98.60s/it, accuracy=0.859, it=21, loss=0.378] 
Validating epoch 42: 100%|██████████| 9/9 [15:44<00:00, 104.93s/it, accuracy=0.797, it=8, loss=0.448]


Validation Average Loss: 0.5389440059661865, Accuracy: 0.7708333333333334


Training epoch 43: 100%|██████████| 22/22 [38:08<00:00, 104.03s/it, accuracy=0.812, it=21, loss=0.419]
Validating epoch 43: 100%|██████████| 9/9 [15:03<00:00, 100.39s/it, accuracy=0.828, it=8, loss=0.449]


Validation Average Loss: 0.533496081829071, Accuracy: 0.7743055555555556


Training epoch 44: 100%|██████████| 22/22 [38:06<00:00, 103.92s/it, accuracy=0.812, it=21, loss=0.432]
Validating epoch 44: 100%|██████████| 9/9 [15:32<00:00, 103.66s/it, accuracy=0.844, it=8, loss=0.443]


Validation Average Loss: 0.533491849899292, Accuracy: 0.7725694444444444


Training epoch 45: 100%|██████████| 22/22 [37:32<00:00, 102.38s/it, accuracy=0.812, it=21, loss=0.414]
Validating epoch 45: 100%|██████████| 9/9 [15:22<00:00, 102.53s/it, accuracy=0.797, it=8, loss=0.507]


Validation Average Loss: 0.5345349907875061, Accuracy: 0.7743055555555556


Training epoch 46: 100%|██████████| 22/22 [37:33<00:00, 102.41s/it, accuracy=0.875, it=21, loss=0.356]
Validating epoch 46:  78%|███████▊  | 7/9 [13:04<04:00, 120.18s/it, accuracy=0.859, it=6, loss=0.411]