# Analyzing LatentQGAN: A Hybrid Quantum-Classical GAN with Autoencoders

Generative models have achieved significant success in capturing complex data distributions and generating realistic samples across various domains. However, training these models, particularly Generative Adversarial Networks (GANs), remains computationally demanding, especially for high-dimensional data. Quantum computing offers a promising avenue to address these challenges by leveraging principles such as superposition and entanglement. In this context, LatentQGAN integrates classical and quantum computing to enhance GAN training, enabling the generation of more expressive and complex data distributions beyond the capabilities of purely classical approaches.

In [None]:
!pip install qiskit qiskit-ibm-runtime pylatexenc qiskit-aer qiskit_machine_learning
!pip install torchinfo

In [2]:
from qiskit import QuantumCircuit
from qiskit.circuit import ParameterVector
from qiskit_machine_learning.connectors import TorchConnector
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.gradients import ParamShiftSamplerGradient
from qiskit.primitives import StatevectorSampler as Sampler


import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from torchinfo import summary
from rich.console import Console

# use the GPU when torch can see one, otherwise stay on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

Using device: cuda


# Latent-QGAN

## Quantum Generator

In [3]:
class QGenerator(nn.Module):
  """
  Generator built from several small parameterized circuits.

  Every sub-circuit is wrapped in a SamplerQNN + TorchConnector pair so that it
  behaves like a torch module; each one receives its own slice of the noise
  tensor as input rotation angles.
  """

  def __init__(self, num_circuits=5, num_qubits=4, num_layers=7):
    """
    Build the sub-circuits and their torch wrappers.
    Args:
        num_circuits: how many independent circuits are created.
        num_qubits: width (in qubits) of every circuit.
        num_layers: how many RY + CZ layers follow the input encoding.
    """
    super().__init__()

    self.num_circuits = num_circuits
    self.num_qubits = num_qubits
    self.num_layers = num_layers

    # a single sampler shared by all sub-circuits (2**11 = 2048 shots by default)
    self.sampler = Sampler(default_shots=2**11)

    # per circuit: one vector of input angles (noise) and one of trainable angles
    self.alpha_params = []
    self.theta_params = []
    for idx in range(num_circuits):
      self.alpha_params.append(ParameterVector(f'alpha_{idx}', num_qubits))
      self.theta_params.append(ParameterVector(f'theta_{idx}', num_layers * num_qubits))

    # the circuits themselves
    self.qc_list = list(map(self.__create_param_circuit, range(num_circuits)))

    # wrap every circuit as a QNN and expose it to torch
    connectors = []
    for idx in range(num_circuits):
      qnn = SamplerQNN(
        circuit=self.qc_list[idx],
        sampler=self.sampler,
        input_params=self.alpha_params[idx].params,
        weight_params=self.theta_params[idx].params,
        sparse=False,
        gradient=ParamShiftSamplerGradient(self.sampler),
        input_gradients=True
      )
      # initial trainable angles: uniform in [0, 2*pi) scaled down by 0.1
      initial_weights = torch.rand((self.num_layers*self.num_qubits))*(2 * torch.pi) * 0.1
      connectors.append(TorchConnector(qnn, initial_weights))
    self.generators = torch.nn.ModuleList(connectors)

  def __create_param_circuit(self, circuit_id):
    """Return the measured circuit of sub-generator `circuit_id`: input RYs, then RY + CZ-chain layers."""
    width = self.num_qubits
    alphas = self.alpha_params[circuit_id]
    thetas = self.theta_params[circuit_id]
    circuit = QuantumCircuit(width, width)

    # input encoding: one RY per qubit driven by the noise angles
    for wire in range(width):
      circuit.ry(alphas[wire], wire)

    # variational part: a column of trainable RYs, then CZ between neighbours
    for depth in range(self.num_layers):
      offset = depth * width
      for wire in range(width):
        circuit.ry(thetas[offset + wire], wire)
      for control in range(width - 1):
        circuit.cz(control, control + 1)

    circuit.measure(range(width), range(width))
    return circuit

  def __post_selection_and_norm(self, results):
    """
    Keep the outcomes whose binary index ends in '0' and rescale them to sum to 1.
    Args:
        results: tensor indexed as [batch, circuit, outcome].
    Returns:
        Tensor with the selected outcomes on the last axis, normalized along it.
    """
    # outcome indices whose least-significant bit is 0, i.e. the even ones
    kept = list(range(0, 2 ** self.num_qubits, 2))
    selected = results[:, :, kept]

    # the small constant guards against dividing by zero
    totals = selected.sum(dim=2, keepdim=True)
    return selected / (totals + 1e-10)


  def forward(self, noise):
    """
    Run every sub-circuit on its slice of the noise and post-process the outputs.
    Args:
        noise: tensor indexed as [batch, circuit, qubit].
    Returns:
        Torch tensor: post-selected, normalized outputs stacked as [batch, circuit, ...].
    """
    per_circuit = []
    for idx, generator in enumerate(self.generators):
      per_circuit.append(generator(noise[:, idx, :]))

    # put the circuit axis right after the batch axis
    stacked = torch.stack(per_circuit, dim=1)

    return self.__post_selection_and_norm(stacked)

## Classical Discriminator

In [5]:
class Discriminator(nn.Module):
  """Fully-connected classifier: flatten -> 64 -> 16 -> 1, with a sigmoid on the output."""

  def __init__(self, input_size):
    """
    Args:
        input_size: number of features per sample once the input is flattened.
    """
    super().__init__()

    # hidden widths of the two LeakyReLU layers
    hidden_sizes = (64, 16)

    stack = [nn.Flatten(start_dim=1)]
    previous = input_size
    for width in hidden_sizes:
      stack.append(nn.Linear(previous, width))
      stack.append(nn.LeakyReLU())
      previous = width

    # single output squashed into (0, 1)
    stack.append(nn.Linear(previous, 1))
    stack.append(nn.Sigmoid())

    self.layers = nn.Sequential(*stack)

  def forward(self, x):
    """Return one score per sample, shape [batch, 1]."""
    scores = self.layers(x)
    return scores

## AutoEncoder

In [4]:
class AutoEncoder(nn.Module):
  """
  Convolutional autoencoder that maps 1x28x28 images to a [1, 5, 8] latent and back.

  The latent is laid out as 5 rows of 8 values; `encode` rescales every row so
  that it sums to 1, which is the same normalization QGenerator applies to the
  8 post-selected values of each of its sub-circuits.
  """

  def __init__(self):
    super(AutoEncoder, self).__init__()

    # three 5x5 convolutions followed by two linear layers -> 40 features
    # (the flattened convolution output has 160 features for 28x28 inputs)
    self.__encoder = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, stride=1, padding=0),
      nn.ReLU(),
      nn.Conv2d(in_channels=4, out_channels=8, kernel_size=5, stride=1, padding=0),
      nn.ReLU(),
      nn.Conv2d(in_channels=8, out_channels=4, kernel_size=5, stride=(4,2), padding=(1,0)),
      nn.ReLU(),
      nn.Flatten(start_dim=1),
      nn.Linear(160,40),
      nn.ReLU(),
      nn.Linear(40,40),
      nn.ReLU())

    # two linear layers up to 10x20x20, then two transposed convolutions back to 1x28x28
    self.__decoder = nn.Sequential(
      nn.Flatten(start_dim=1),
      nn.Linear(40,400),
      nn.ReLU(),
      nn.Linear(400,4000),
      nn.ReLU(),
      nn.Unflatten(dim=1, unflattened_size=(10,20,20)),
      nn.ConvTranspose2d(in_channels=10, out_channels=10, kernel_size=5, stride=1, padding=0),
      nn.ReLU(),
      nn.ConvTranspose2d(in_channels=10, out_channels=1, kernel_size=5, stride=1, padding=0),
      nn.Sigmoid())


  # extract latent representation of input image
  def encode(self, x):
    """
    Encode a batch of images.
    Args:
        x: image batch; the encoder layers require 1x28x28 inputs.
    Returns:
        Tensor of shape [batch, 1, 5, 8] whose rows (last axis) each sum to 1
        (rows that are entirely zero stay zero).
    """
    result = self.__encoder(x).view(x.shape[0],1,5,8)

    # normalize s.t. rows sum up to 1: the sum has to run over the last axis
    # (the 8 entries of a row). Summing over dim=2 would normalize the columns
    # instead and would not match the generator's per-circuit normalization.
    # NOTE: checkpoints trained with the column normalization must be retrained.
    result = result / (torch.sum(result, dim=3, keepdim=True) + 1e-8)

    return result

  # reconstruct image from input latent
  def decode(self, x):
    """Map a latent with 40 values per sample back to a [batch, 1, 28, 28] image in (0, 1)."""
    return self.__decoder(x)

  # perform complete forward pass (encoding and decoding)
  def forward(self, x):
    """Encode and immediately decode `x`."""
    return self.decode(self.encode(x))

## Viewing the Size of the Models

In [10]:
console = Console()

# report layout shared by the three torchinfo summaries
summary_opts = dict(
  col_names=['input_size','output_size','num_params'],
  row_settings=('var_names',),
  verbose=0,
)

# instantiate each model together with a dummy input of the shape it expects
gen_noise = torch.randn((1,5,4))
netQG = QGenerator()

input_size = (1,1,28,28)
autoenc_noise = torch.randn(input_size)
autoenc = AutoEncoder()

disc_noise = torch.randn(1,40)
netD = Discriminator(40)

# collect the per-layer statistics (# of parameters, input/output sizes)
G_stats = summary(model=netQG, input_data=gen_noise, **summary_opts)
D_stats = summary(model=netD, input_data=disc_noise, **summary_opts)
autoenc_stats = summary(model=autoenc, input_data=autoenc_noise, **summary_opts)

# print the three reports with a separator line between consecutive ones
for position, stats in enumerate((G_stats, D_stats, autoenc_stats)):
  if position > 0:
    print('\n+-+-+-+-+-+-+-+-+-+-\n')
  console.print(stats)


+-+-+-+-+-+-+-+-+-+-




+-+-+-+-+-+-+-+-+-+-



# Training

## Importing Dataset (MNIST)

In [11]:
# every image is resized to 28 and converted to a tensor
preprocess = transforms.Compose([
  transforms.Resize(28),
  transforms.ToTensor(),
])

# MNIST train and test splits, downloaded into ./data when missing
mnist_kwargs = dict(root='./data', download=True, transform=preprocess)
trainset = datasets.MNIST(train=True, **mnist_kwargs)
testset = datasets.MNIST(train=False, **mnist_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.2MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 489kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.52MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.06MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






## Training the Autoencoder

In [12]:
def load_autoencoder(file_path='./autoenc_checkpoint.pth', load_eval_mode=True):
  """
  Create an AutoEncoder (and its AdamW optimizer) and restore them from a checkpoint.

  If the checkpoint cannot be read, a message is printed and the objects keep
  whatever state they had at that point (freshly initialized when the file is missing).
  Args:
      file_path: path of the checkpoint written by the autoencoder training loop.
      load_eval_mode: when True only the model is returned; otherwise the
          optimizer, stored epoch and stored best test loss are returned too.
  Returns:
      `autoenc` if `load_eval_mode` else `(autoenc, optimizer, epoch, best_test_loss)`.
  """
  autoenc = AutoEncoder()
  optimizer = torch.optim.AdamW(autoenc.parameters(), lr=0.001, weight_decay=1e-4)
  epoch = 0
  best_test_loss = float('inf')
  try:
    # load the checkpoint once, from the path that was asked for, and
    # set the model's and optimizer's state dicts
    checkpoint = torch.load(file_path, weights_only=True, map_location=device)
    autoenc.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    # float() also accepts a 0-dim tensor, in case the checkpoint stored one
    best_test_loss = float(checkpoint['test_loss'])

    # move optimizer tensors to device
    for param in optimizer.state.values():
      if isinstance(param, torch.Tensor):
        param.data = param.data.to(device)
      if isinstance(param, dict):
        for sub_param in param.values():
          if isinstance(sub_param, torch.Tensor):
            sub_param.data = sub_param.data.to(device)
  except FileNotFoundError:
    print('ERROR: NO CHECKPOINT FILE FOUND. Using default Autoencoder.')
  except Exception as err:
    # keep the best-effort behaviour, but do not hide what went wrong and do
    # not swallow KeyboardInterrupt/SystemExit
    print('ERROR WHILE LOADING PRE-TRAINED AUTOENCODER WEIGHTS: Using default Autoencoder.')
    print(f'Cause: {err!r}')

  if load_eval_mode:
    return autoenc
  else:
    return autoenc, optimizer, epoch, best_test_loss

In [None]:
# define the train and test data loaders
batch_size=20
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)
tot_batches = len(trainloader)

# load pre-trained model if available
# (the stored epoch is not used to resume: the loop below always starts at 0)
autoenc, optimizer, epoch, best_test_loss = load_autoencoder(load_eval_mode=False)
autoenc = autoenc.to(device)

# define loss function (the optimizer comes from load_autoencoder)
loss_f = nn.BCELoss()

epochs = 100
log_after_batches = 1000

for epoch in range(epochs):
  print(f'EPOCH {epoch+1}\n========')
  autoenc.train()
  for batch_id, (images, labels) in enumerate(trainloader):
    # move data to device
    images = images.to(device)

    # perform forward pass
    decoded = autoenc(images)

    # compute reconstruction loss against the input itself
    loss = loss_f(decoded, images)

    # update autoencoder
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # log progress: loss plus the first image of the batch next to its reconstruction
    if batch_id%log_after_batches == 0:
      print(f'[{(batch_id)}/{tot_batches}] loss: {loss.item():.4f}')
      fig, axes = plt.subplots(1, 2)
      axes[0].imshow(images.detach().cpu()[0].permute(1,2,0).numpy(), cmap='gray')
      axes[0].set_title(f'Original {labels[0].item()}')
      axes[0].axis("off")
      axes[1].imshow(decoded.detach().cpu()[0].permute(1,2,0).numpy(), cmap='gray')
      axes[1].set_title(f'Reconstructed {labels[0].item()}')
      axes[1].axis("off")
      plt.show()


  # evaluation on test set
  with torch.no_grad():
    autoenc.eval()
    test_loss = 0.0
    for images, _ in testloader:
      images = images.to(device)

      encoded = autoenc.encode(images)
      decoded = autoenc.decode(encoded)
      test_loss += loss_f(decoded, images).item()

    # average of the per-batch losses
    test_loss /= len(testloader)

    # save checkpoint if results are new best
    if test_loss < best_test_loss:
      print('NEW BEST')
      best_test_loss = test_loss
      # store the test loss that triggered the save, not the loss of the last
      # training batch, so that a reloaded `best_test_loss` is comparable
      torch.save({'epoch': epoch,
                  'model_state_dict': autoenc.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'test_loss': test_loss},
                  'autoenc_checkpoint.pth')
    print(f'test loss: {test_loss:.6f},   current_best: {best_test_loss:.6f}')
  print()

## Training the QGAN

In [14]:
def load_qgan(file_path='./qgan_checkpoint.pth', load_eval_mode=True):
  """
  Create the quantum generator and the discriminator (with their optimizers)
  and restore them from a checkpoint.

  If the checkpoint cannot be read, a message is printed and the objects keep
  whatever state they had at that point (freshly initialized when the file is missing).
  Args:
      file_path: path of the checkpoint written by the QGAN training loop.
      load_eval_mode: when True only the two models are returned; otherwise
          the optimizers are returned as well.
  Returns:
      `(netQG, netD)` if `load_eval_mode` else `(netQG, gen_optim, netD, disc_optim)`.
  """
  netD = Discriminator(40)
  netQG = QGenerator(num_circuits=5, num_qubits=4, num_layers=7)
  gen_optim = optim.Adam(netQG.parameters(), lr=0.001, betas=(0.5, 0.9))
  disc_optim = optim.SGD(netD.parameters(), lr=0.001)

  try:
    qgan_checkpoint = torch.load(file_path, weights_only=True, map_location=device)
    netQG.load_state_dict(qgan_checkpoint['gen_state_dict'])
    gen_optim.load_state_dict(qgan_checkpoint['gen_optim_state_dict'])
    netD.load_state_dict(qgan_checkpoint['disc_state_dict'])
    disc_optim.load_state_dict(qgan_checkpoint['disc_optim_state_dict'])

    # move optimizers tensors to device (same treatment for both optimizers)
    for optimizer in (gen_optim, disc_optim):
      for param in optimizer.state.values():
        if isinstance(param, torch.Tensor):
          param.data = param.data.to(device)
        if isinstance(param, dict):
          for sub_param in param.values():
            if isinstance(sub_param, torch.Tensor):
              sub_param.data = sub_param.data.to(device)

  except FileNotFoundError:
    print('ERROR: NO CHECKPOINT FILE FOUND. Using default QGAN.')
  except Exception as err:
    # keep the best-effort behaviour, but do not hide what went wrong and do
    # not swallow KeyboardInterrupt/SystemExit
    print('ERROR WHILE LOADING PRE-TRAINED QGAN WEIGHTS: Using default QGAN.')
    print(f'Cause: {err!r}')

  if load_eval_mode:
    return netQG, netD
  else:
    return netQG, gen_optim, netD, disc_optim

In [None]:
batch_size=2
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

# load pre-trained models if available
autoenc = load_autoencoder(load_eval_mode=True)
netQG, gen_optim, netD, disc_optim = load_qgan(load_eval_mode=False)

# the autoencoder is only used as a fixed encoder/decoder here
autoenc.eval()
autoenc = autoenc.to(device)
netD = netD.to(device)
netQG = netQG.to(device)

loss_fn = nn.BCELoss()

# Training loop
num_epochs=5
total_batches=50
log_after_batches=10

for epoch in range(1,num_epochs+1):
  print(f"Epoch {epoch}/{num_epochs}\n = = = = = = = = = = = = = = =")
  for i, (images, _) in enumerate(trainloader):
    # load data to device
    images = images.to(device)

    # size everything from the batch actually received, so a smaller final
    # batch cannot cause a shape mismatch between data and labels
    cur_batch_size = images.size(0)

    ########################
    # discriminator update #
    ########################
    netD.train()

    # sample random noise, uniform in [0, 2*pi)
    noise = torch.rand((cur_batch_size, netQG.num_circuits,netQG.num_qubits), device=device)*(2 * torch.pi)

    # generate fake data
    latent_fake = netQG(noise).unsqueeze(1).to(device)

    # extract latent of true images; the autoencoder is not trained here, so
    # no graph is built and no gradients are accumulated in its parameters
    with torch.no_grad():
      latent_true = autoenc.encode(images)

    # create batch with real and fake data (fake part detached: this step
    # must not update the generator)
    batch = torch.cat((latent_true, latent_fake.detach()))

    # label smoothing (0.9 instead of 1.0; 0.1 instead of 0.)
    real_labels = torch.full((cur_batch_size, 1), 0.9, device=device)
    fake_labels = torch.full((cur_batch_size, 1), 0.1, device=device)
    true_labels = torch.cat((real_labels, fake_labels), dim=0)

    # forward pass of the discriminator
    disc_output = netD(batch)

    # compute discriminator loss
    disc_loss = loss_fn(disc_output, true_labels)

    # update discriminator
    disc_optim.zero_grad()
    disc_loss.backward()
    disc_optim.step()


    ####################
    # generator update #
    ####################
    netD.eval()

    # compute generator loss: -log D(G(z)), averaged over the batch
    gen_loss = (-torch.log(netD(latent_fake))).mean()

    # update generator
    gen_optim.zero_grad()
    gen_loss.backward()
    gen_optim.step()

    # log results and show generated images
    if i%log_after_batches==0:
      with torch.no_grad():
        fake_img = autoenc.decode(latent_fake)

      fig, axes = plt.subplots(1, 2)
      axes[0].imshow(fake_img[0].detach().cpu().permute(1,2,0).numpy(), cmap='gray')
      axes[0].axis("off")
      axes[1].imshow(fake_img[1].detach().cpu().permute(1,2,0).numpy(), cmap='gray')
      axes[1].axis("off")
      plt.show()
      print(f'epoch {epoch}/{num_epochs} - [{i}/{total_batches}]\t GEN loss: {gen_loss.item():.4f}, DISC loss: {disc_loss.item():.4f}')

    # stop the epoch once batch index `total_batches` has been processed
    if i==total_batches:
      break

  # save model checkpoint
  torch.save({'gen_state_dict': netQG.state_dict(),
              'gen_optim_state_dict': gen_optim.state_dict(),
              'disc_state_dict': netD.state_dict(),
              'disc_optim_state_dict': disc_optim.state_dict()},
              'qgan_checkpoint.pth')
  print(f'GEN loss: {gen_loss.item():.4f}, DISC loss: {disc_loss.item():.4f}\n')

# Demo

In [None]:
# create q-GAN (loading the pre-trained one if a checkpoint is available)
netQG, netD = load_qgan(load_eval_mode=True)

# create autoencoder (loading the pre-trained one if a checkpoint is available)
autoencoder = load_autoencoder(load_eval_mode=True)

# pure inference: no gradients are needed
with torch.no_grad():
  # sample random noise from the same range used during training, [0, 2*pi)
  noise = torch.rand((1, netQG.num_circuits, netQG.num_qubits))*(2 * torch.pi)

  # perform forward pass of q-generator
  result = netQG(noise)

  # evaluate results with discriminator
  disc_output = netD(result)

  # map to pixel-space using the decoder
  decoded = autoencoder.decode(result)

# show resulting image
plt.imshow(decoded[0].detach().cpu().permute(1,2,0).numpy(), cmap='gray')
plt.show()