In [1]:
# Basic Tools
import time
import numpy as np 
import sys
sys.path.append("..")
from QuantumTrain.util import *

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.transforms as transforms

# TorchQuantum
import torchquantum as tq
# Plotting
import matplotlib.pyplot as plt

# Seed every RNG this notebook draws from (weight initialisation, DataLoader
# shuffling) so that "Restart Kernel -> Run All" reproduces the same run.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  return torch._C._cuda_getDeviceCount() > 0


In [2]:

### Classical target model initialization ###

# Define the CNN model
class CNNModel(nn.Module):
    """Small two-conv / two-linear CNN used as the classical target network.

    Every operation, including pooling and flattening, is registered as a
    module attribute instead of being called functionally, so that code which
    walks the module tree sees the complete layer sequence. No activation
    function is applied between layers.

    The ``12 * 4 * 4`` input size of ``fc1`` corresponds to single-channel
    28x28 images: 28 -> 24 (conv 5) -> 12 (pool) -> 8 (conv 5) -> 4 (pool).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=12, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Collapse (channels, height, width) into one feature axis
        self.flatten = nn.Flatten()
        # Fully connected head producing 10 class scores
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=20)
        self.fc2 = nn.Linear(in_features=20, out_features=10)

    def forward(self, x):
        """Run ``x`` through the layers in registration order and return the raw scores."""
        pipeline = (
            self.conv1,
            self.pool1,
            self.conv2,
            self.pool2,
            self.flatten,
            self.fc1,
            self.fc2,
        )
        for layer in pipeline:
            x = layer(x)
        return x


# Instantiate the classical target model (no loss function is created here).
# This instance is what the qubit estimation, the network-config extraction
# and the QuantumTrain wrapper receive further down in the notebook.
model = CNNModel()

In [3]:
# Derive the quantum-side sizing from the classical model. The captured output
# of this cell reports the model's parameter count and the required qubit number.
# NOTE(review): both helpers presumably come from the wildcard import of
# QuantumTrain.util; the exact meaning of `nw_list_normal` and `network_config`
# is not visible in this notebook - confirm against that module.
n_qubit, nw_list_normal = required_qubits_estimation(model)
network_config          = network_config_extract(model)

# of NN parameters:  6690
Required qubit number:  13


In [4]:

### Training setting ########################

# Optimisation hyper-parameters
step = 1e-4         # Learning rate passed to the optimizer
batch_size = 1000   # Number of samples per training step
num_epochs = 10     # Number of passes over the training set
q_depth = 16        # Depth of the quantum circuit (number of variational layers)

# Preprocessing: convert images to tensors, then normalise with mean 0.5 / std 0.5
to_tensor = transforms.ToTensor()
normalise = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([to_tensor, normalise])

# Arguments shared by the train and test splits of MNIST
mnist_common = dict(root='./data', transform=transform, download=True)

# Training split is reshuffled every epoch; the test split keeps its order
train_dataset = datasets.MNIST(train=True, **mnist_common)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(train=False, **mnist_common)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Build the QuantumTrain wrapper around the classical model, then move it to
# the selected device (GPU when available, otherwise CPU).
model_qt = QuantumTrain(model, n_qubit, nw_list_normal, q_depth, device, network_config)
model_qt = model_qt.to(device)

# Loss on the raw class scores
criterion = nn.CrossEntropyLoss()

# Only the quantum circuit and the mapping network are handed to Adam, one
# parameter group each; both groups share the settings passed below.
trained_submodules = (model_qt.QuantumNN, model_qt.MappingNetwork)
param_groups = [{'params': submodule.parameters()} for submodule in trained_submodules]
optimizer = optim.Adam(param_groups, lr=step, weight_decay=1e-5, eps=1e-6)

# Halve the learning rate once the monitored loss has stopped decreasing for
# more than 5 consecutive scheduler steps ('min' because the loss is minimised).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm
# against the installed version before upgrading.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    verbose=True,
)


In [5]:

def _count_trainable(module):
    """Return the total number of elements in `module`'s parameters that require gradients."""
    total = 0
    for tensor in module.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total


# Trainable parameter counts for each sub-module and for the whole wrapper
num_trainable_params_MM = _count_trainable(model_qt.MappingNetwork)
num_trainable_params_QNN = _count_trainable(model_qt.QuantumNN)
num_trainable_params = _count_trainable(model_qt)

for caption, count in (
    ("# of trainable parameter in Mapping model: ", num_trainable_params_MM),
    ("# of trainable parameter in QNN model: ", num_trainable_params_QNN),
    ("# of trainable parameter in full model: ", num_trainable_params),
):
    print(caption, count)


# of trainable parameter in Mapping model:  249
# of trainable parameter in QNN model:  1248
# of trainable parameter in full model:  8187


In [6]:


#############################################
### Training loop ###########################

### (Optional) Start from pretrained model ##
# model_qt = torch.load('L16/tq_mm_acc_99_bsf')
# model_qt.eval()  # Set the model to evaluation mode
#############################################

loss_list = []   # per-batch training loss, stored as Python floats
acc_list = []    # per-batch training accuracy in percent
acc_best = 0     # best per-batch accuracy seen so far
for epoch in range(num_epochs):
    model_qt.train()
    train_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        since_batch = time.time()

        images, labels = images.to(device), labels.to(device)  # Move data to the training device
        optimizer.zero_grad()

        # Forward pass
        outputs = model_qt(images)

        # CrossEntropyLoss takes integer class indices directly, so the labels
        # are passed as-is instead of being expanded to one-hot vectors first.
        loss = criterion(outputs, labels)

        # Accuracy of this batch only (the counters do not accumulate across batches).
        # detach() keeps the prediction out of the autograd graph.
        _, predicted = torch.max(outputs.detach(), 1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
        acc = 100 * correct / total

        # Read the scalar loss once and reuse it for both the log and the epoch mean
        loss_value = loss.item()
        loss_list.append(loss_value)
        acc_list.append(acc)
        train_loss += loss_value

        if acc > acc_best:
            # Checkpointing is currently disabled; a torch.save(model_qt, ...) call
            # here would keep the best-so-far model.
            acc_best = acc

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}, batch time: {time.time() - since_batch:.2f}, accuracy:  {(acc):.2f}%")

    # The scheduler monitors the mean training loss of the epoch
    train_loss /= len(train_loader)
    scheduler.step(train_loss)
    
#############################################

Epoch [1/10], Step [1/60], Loss: 22.4013, batch time: 0.55, accuracy:  7.40%
Epoch [1/10], Step [2/60], Loss: 22.1687, batch time: 0.38, accuracy:  7.40%
Epoch [1/10], Step [3/60], Loss: 20.4757, batch time: 0.28, accuracy:  7.10%
Epoch [1/10], Step [4/60], Loss: 20.7899, batch time: 0.29, accuracy:  7.90%
Epoch [1/10], Step [5/60], Loss: 20.1095, batch time: 0.29, accuracy:  7.80%
Epoch [1/10], Step [6/60], Loss: 18.8639, batch time: 0.29, accuracy:  7.90%
Epoch [1/10], Step [7/60], Loss: 19.0451, batch time: 0.29, accuracy:  8.30%
Epoch [1/10], Step [8/60], Loss: 19.0837, batch time: 0.29, accuracy:  7.70%
Epoch [1/10], Step [9/60], Loss: 18.2830, batch time: 0.29, accuracy:  8.30%
Epoch [1/10], Step [10/60], Loss: 18.5170, batch time: 0.30, accuracy:  6.70%
Epoch [1/10], Step [11/60], Loss: 17.5140, batch time: 0.30, accuracy:  6.10%
Epoch [1/10], Step [12/60], Loss: 17.3420, batch time: 0.29, accuracy:  8.40%
Epoch [1/10], Step [13/60], Loss: 16.6696, batch time: 0.30, accuracy:  7

In [None]:
# Dump the current gradient tensor of every named parameter in the wrapper
for name, param in model_qt.named_parameters():
    grad = param.grad
    print(f"Gradient of {name}: {grad}")