# CHAPTER 12 - Quantum Generative Adversarial Networks - Qiskit Code

*Note*: You may skip the following four cells if you have already installed the right versions of all the libraries mentioned in *Appendix D*. This will likely NOT be the case if you are running this notebook on a cloud service such as Google Colab.

In [None]:
pip install scikit-learn==1.2.1

In [None]:
pip install qiskit==0.39.2

In [None]:
pip install qiskit_machine_learning==0.5.0

In [None]:
pip install matplotlib==3.2.2

In [None]:
import numpy as np

from qiskit import *
from qiskit.utils import algorithm_globals

# Fix the seeds of both random number generators used in this notebook:
# Qiskit's shared `algorithm_globals` RNG and NumPy's legacy global RNG
# (the latter drives the sampling of `real_data` below).
seed = 1234
algorithm_globals.random_seed = seed
np.random.seed(seed=seed)

In [None]:
# Synthetic "real" dataset: N independent draws from a Binomial(n, p)
# distribution, so every sample is an integer in {0, 1, ..., n}.
N = 1000  # number of samples
n = 3     # number of trials per draw (largest possible value)
p = 0.5   # success probability of each trial

real_data = np.random.binomial(n=n, p=p, size=N)

In [None]:
from qiskit_machine_learning.algorithms import QGAN
from qiskit.utils import QuantumInstance

ncycles = 3000  # Number of training cycles (passed as num_epochs).
bsize = 100  # Batch size.

# QGAN discretizes the interval `bounds` into 2**num_qubits grid points. To
# make that grid exactly the integers 0, 1, ..., n that `real_data` can take,
# we need 2**nqubits == n + 1, i.e. n + 1 must be a power of two
# (the default n = 3 gives nqubits = 2).
nqubits = (n + 1).bit_length() - 1
assert 2 ** nqubits == n + 1, "n + 1 must be a power of two for an integer grid"

# Quantum instance on which the QGAN will run.
quantum_instance = QuantumInstance(
    backend=Aer.get_backend('statevector_simulator'))

# Create the QGAN object. Per the QGAN documentation, training finishes early
# if the relative entropy reaches `tol_rel_ent` or below.
qgan = QGAN(data=real_data,
            num_qubits=[nqubits],
            batch_size=bsize,
            num_epochs=ncycles,
            bounds=[0, n],
            seed=seed,
            tol_rel_ent=0.001)

In [None]:
# Train the QGAN on `quantum_instance`. The number of epochs is set by
# `num_epochs=ncycles` when the QGAN was created (training may stop earlier
# through `tol_rel_ent`).
result = qgan.run(quantum_instance)

In [None]:
import matplotlib.pyplot as plt

# Generator and discriminator losses recorded during training, one value
# per training cycle; cycles are numbered starting from 1.
cycles = np.arange(1, len(qgan.g_loss) + 1)

fig, ax = plt.subplots()
ax.plot(cycles, qgan.g_loss, label="Generator")
ax.plot(cycles, qgan.d_loss, label="Discriminator")
ax.set_title("Loss function evolution")
ax.set_xlabel("Cycle")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# Sample the trained generator: `samples_g` holds the generated points and
# `prob_g` the probability of each one.
samples_g, prob_g = qgan.generator.get_output(qgan.quantum_instance,
                                             shots=10000)

# Empirical probability of each value 0, 1, ..., n in the real data
# (every value of `real_data` lies in that range, so the length is n + 1).
real_distr = (np.bincount(real_data, minlength=n + 1) / N).tolist()

# Each generated point has one coordinate per dimension, and this QGAN is
# one-dimensional (num_qubits has a single entry). Place the generated bars
# at the returned values instead of assuming `prob_g` is ordered 0..n.
gen_values = [sample[0] for sample in samples_g]

fig, ax = plt.subplots()
ax.bar(range(n + 1), real_distr, width=0.7, color="royalblue",
       label="Real distribution")
ax.bar(gen_values, prob_g, width=0.5, color="black",
       label="Generated distribution")
ax.set_title("Real vs. generated distribution")
ax.set_xlabel("Value")
ax.set_ylabel("Probability")
ax.set_xticks(range(n + 1))
ax.legend()  # labels above are only shown once a legend is drawn
plt.show()