# Dataset preparation

In [1]:
import torch
import pennylane as qml
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from data_utils import mnist_preparation 
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
labels = [0,1,2,3]
# Download MNIST and prepare transforms
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.Compose([
                                transforms.Resize((16, 16)),  # Resize to 16x16
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))  # Normalize
                             ]))

# Build the train/test dataloaders restricted to the selected labels
train_dataloader, test_dataloader = mnist_preparation(dataset=mnist, labels=labels, train_test_ratio=0.8, batch_size=64)

# Pure and mixed states
Before diving into mid-circuit measurements, a short recap on pure and mixed states is useful. In simple terms, a state is **pure** if there is no ignorance about its preparation, while it is **mixed** if the preparation is known only probabilistically. The so-called **density operator** is defined as:

$$
\hat{\rho} = \sum \limits_i p_i | \phi_i \rangle \langle \phi_i |
$$ 

where $\sum \limits_i p_i = 1$ and each $p_i$ is the probability associated with the state $|\phi_i \rangle$. The density operator is useful for describing mixed states. For a pure state, a single state has probability 1, so that $\hat{\rho} = | \phi \rangle \langle \phi |$. 

Writing the density operator as $\hat{\rho} = \sum \limits_{k=1}^l p_k |\psi_k \rangle \langle \psi_k|$, it can be represented as a matrix in a basis $\{ |i \rangle \}$: the $ij$-th element of the matrix is $\rho_{ij} = \langle i | \hat{\rho} | j \rangle$. In this matrix representation the diagonal entries are the so-called *population terms* and the off-diagonal entries are the *coherence terms*.

The density operator: 
- is Hermitian, meaning that on the diagonal $\rho_{ii} = \rho_{ii}^*$ (the diagonal entries are real) and off the diagonal $\rho_{ij}= \rho_{ji}^*$;
- has unit trace, $\text{tr}[\rho] = 1$; 
- is positive semidefinite, i.e. for any vector $|\phi \rangle$ in the Hilbert space $\langle \phi |\rho | \phi \rangle \geq 0$. 

A simple criterion to check whether a state is pure or mixed is to compute $\text{tr}[\rho^2]$: it equals 1 for a pure state and is strictly less than 1 for a mixed state.
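The purity criterion can be checked numerically. The following is a minimal sketch, assuming NumPy and single-qubit examples chosen only for illustration (the helper name `purity` is hypothetical, not part of the notebook):

```python
import numpy as np

def purity(rho):
    """Return tr[rho^2] for a density matrix given as a 2-D array."""
    return float(np.real(np.trace(rho @ rho)))

# Pure state |+> = (|0> + |1>)/sqrt(2), so rho = |+><+|
plus = np.array([1.0, 1.0]) / np.sqrt(2)
rho_pure = np.outer(plus, plus.conj())

# Equal classical mixture of |0> and |1>: rho = (|0><0| + |1><1|)/2
rho_mixed = 0.5 * np.diag([1.0, 0.0]) + 0.5 * np.diag([0.0, 1.0])

print("purity of |+><+|:", purity(rho_pure))    # close to 1 (pure)
print("purity of I/2   :", purity(rho_mixed))   # close to 0.5 (mixed)
```

Note that $|+\rangle\langle +|$ has non-zero coherence terms, while the mixture has the same populations but vanishing coherences.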



# Mid circuit measurements
A measurement $M$ is a process that maps a valid quantum state $\rho$ (density operator) to a classical probabilistic mixture of post measurement quantum states $\rho_i$, specified by $M$: 

$$
M[\rho] = \sum \limits_{i=1}^n p_i \rho_i
$$

where $n$ is the number of possible outcomes and $p_i$ the probability to measure outcome $i$ associated to $\rho_i$. This describes the probabilistic mixture after the measurement if we do not record the outcome. If the outcome is recorded, we no longer have a probabilistic mixture but the state $\rho_i$ associated to the filtered outcome $i$. 

Considering projective measurements, and calling $\Pi_i$ the projector associated with measurement outcome $i$, with all projectors summing to the identity, the post-measurement states are given by: 

$$
\rho_i = \frac{\Pi_i \rho \Pi_i}{\text{tr}[\Pi_i \rho]}
$$

where $p_i = \text{tr}[\Pi_i \rho]$. If we do not record the outcome, the system ends up in the state: 

$$
M[\rho] = \sum \limits_{i=1}^n \Pi_i \rho \Pi_i
$$
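
As a minimal sketch of the two cases (recorded and unrecorded outcome), assume a single qubit prepared in $|+\rangle$ and measured in the computational basis. The helper `measure_projective` and the choice of state are illustrative assumptions, not taken from the PennyLane demo:

```python
import numpy as np

# Projectors onto |0> and |1> for a single qubit; they sum to the identity
projectors = [np.diag([1.0, 0.0]), np.diag([0.0, 1.0])]

# Input state rho = |+><+|
plus = np.array([1.0, 1.0]) / np.sqrt(2)
rho = np.outer(plus, plus.conj())

def measure_projective(rho, projectors):
    """Return outcome probabilities p_i and normalized post-measurement states rho_i."""
    probs, posts = [], []
    for P in projectors:
        p = float(np.real(np.trace(P @ rho)))
        probs.append(p)
        # rho_i is only defined for outcomes that can actually occur
        posts.append(P @ rho @ P / p if p > 0 else None)
    return probs, posts

probs, posts = measure_projective(rho, projectors)

# Unrecorded outcome: M[rho] = sum_i Pi_i rho Pi_i = sum_i p_i rho_i
rho_unrecorded = sum(P @ rho @ P for P in projectors)

print("probabilities:", probs)
print("state after an unrecorded measurement:\n", rho_unrecorded)
```

In this example the unrecorded measurement removes the coherence terms of $|+\rangle\langle +|$ and leaves only the populations.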

Although it is not stated explicitly, I suppose that in [PennyLane](https://pennylane.ai/qml/demos/tutorial_mcm_introduction/) the $\Pi_i$ can more generally be read as the elements of a Positive Operator-Valued Measure (POVM), defined as $\Pi_i = M_i^{\dagger}M_i$ with $\sum_i \Pi_i = \mathbf{I}$. The POVM elements $\Pi_i$ are positive semidefinite, which is sufficient to guarantee that the probability of obtaining outcome $i$ is non-negative. In this scenario the operators $M_i$ defining the $\Pi_i$ describe a generalized measurement: a collection of operators such that, for pure states, the probability of measurement outcome $i$ is given by:

$$
p(i) = \langle \psi | M_i^{\dagger} M_i | \psi \rangle
$$

The state of the system after the measurement, given that the outcome $i$ was obtained, is:

$$
|\psi'\rangle = \frac{M_i |\psi\rangle}{\sqrt{\langle \psi | M_i^{\dagger} M_i | \psi \rangle}}
$$

For mixed states instead the probability of the measurement outcome $i$ is given by: 
$$
p(i) = \text{tr}[\rho M_i^{\dagger} M_i]
$$

The state of the system after the measurement, given that the outcome $i$ was obtained, is:

$$
\rho' = \frac{M_i \rho M_i^{\dagger}}{\text{tr}[\rho M_i^{\dagger} M_i]}
$$
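
A small numerical sketch of a generalized measurement follows. The operators `M0` and `M1` (a noisy measurement in the computational basis with a hypothetical error strength `eps`) and the input state are assumptions chosen for illustration; the post-measurement state is computed as $M_i \rho M_i^{\dagger}$ divided by the outcome probability:

```python
import numpy as np

eps = 0.1  # hypothetical error strength of the noisy measurement

# Measurement operators of a noisy computational-basis measurement (illustrative choice)
M0 = np.diag([np.sqrt(1 - eps), np.sqrt(eps)])
M1 = np.diag([np.sqrt(eps), np.sqrt(1 - eps)])
operators = [M0, M1]

# Completeness relation: sum_i M_i^dagger M_i = I
completeness = sum(M.conj().T @ M for M in operators)

# A valid mixed input state (unit trace, positive, tr[rho^2] < 1)
rho = np.array([[0.75, 0.25],
                [0.25, 0.25]])

def generalized_measurement(rho, M):
    """Return p(i) = tr[rho M^dagger M] and the normalized post-measurement state."""
    p = float(np.real(np.trace(rho @ M.conj().T @ M)))
    rho_post = M @ rho @ M.conj().T / p
    return p, rho_post

p0, rho0 = generalized_measurement(rho, M0)
p1, rho1 = generalized_measurement(rho, M1)

print("completeness holds:", np.allclose(completeness, np.eye(2)))
print("p(0), p(1):", p0, p1)
```

Unlike projectors, these $M_i$ are not idempotent, so repeating the measurement can change the state again.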

**Projective measurements** are a specific class of generalized measurements where $M_i = P_i = | i \rangle  \langle i |$, $\sum_i P_i = \mathbf{I}$ and $P_iP_j = \delta_{ij} P_i$, so in particular $P_i^2 = P_i$. In this case, for pure states the probability of obtaining outcome $i$ is: 

$$
p(i) = \langle \psi |P_i| \psi \rangle
$$

and the final state is: 

$$
|\psi' \rangle = \frac{P_i |\psi \rangle}{ \sqrt{\langle \psi |P_i| \psi \rangle}}
$$

For mixed states the probability of obtaining outcome $i$ is given by: 

$$
p(i) = \text{tr}[\rho P_i]
$$

and the final state is: 

$$
\rho' = \frac{P_i \rho P_i}{\text{tr}[\rho P_i]}
$$
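
To connect these formulas with a mid-circuit measurement, the sketch below measures only the first qubit of a two-qubit pure state. It assumes that the projector for outcome $i$ on the first qubit is extended to the full system as $|i\rangle\langle i| \otimes \mathbf{I}$; the Bell state and the helper `project` are illustrative choices:

```python
import numpy as np

# Bell state (|00> + |11>)/sqrt(2) on two qubits
psi = np.array([1.0, 0.0, 0.0, 1.0]) / np.sqrt(2)

# Projectors acting on the first qubit only: |i><i| tensor identity
P0 = np.kron(np.diag([1.0, 0.0]), np.eye(2))
P1 = np.kron(np.diag([0.0, 1.0]), np.eye(2))

def project(psi, P):
    """Return p = <psi|P|psi> and the normalized post-measurement state P|psi>/sqrt(p)."""
    p = float(np.real(psi.conj() @ P @ psi))
    psi_post = P @ psi / np.sqrt(p)
    return p, psi_post

p0, psi0 = project(psi, P0)  # outcome 0 on the first qubit
p1, psi1 = project(psi, P1)  # outcome 1 on the first qubit

print("p(0) =", p0, "post-measurement state:", psi0)
print("p(1) =", p1, "post-measurement state:", psi1)
```

Recording outcome 0 leaves the pair in $|00\rangle$ and outcome 1 leaves it in $|11\rangle$, so measuring one qubit also fixes the state of the other.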

# Original model
The state of the system is initialized using one image. If the image has size 16x16, 8 qubits are needed: they provide $2^8 = 256$ basis states, so each of the 256 pixel values of the image can define the amplitude of one basis state. In PennyLane the basis states are in lexicographic order, which is listed below for completeness:  

 

In [1]:
# lexicographic order
states = []
for i in range(256):
    states.append(format(i, '08b'))  
print(states)

['00000000', '00000001', '00000010', '00000011', '00000100', '00000101', '00000110', '00000111', '00001000', '00001001', '00001010', '00001011', '00001100', '00001101', '00001110', '00001111', '00010000', '00010001', '00010010', '00010011', '00010100', '00010101', '00010110', '00010111', '00011000', '00011001', '00011010', '00011011', '00011100', '00011101', '00011110', '00011111', '00100000', '00100001', '00100010', '00100011', '00100100', '00100101', '00100110', '00100111', '00101000', '00101001', '00101010', '00101011', '00101100', '00101101', '00101110', '00101111', '00110000', '00110001', '00110010', '00110011', '00110100', '00110101', '00110110', '00110111', '00111000', '00111001', '00111010', '00111011', '00111100', '00111101', '00111110', '00111111', '01000000', '01000001', '01000010', '01000011', '01000100', '01000101', '01000110', '01000111', '01001000', '01001001', '01001010', '01001011', '01001100', '01001101', '01001110', '01001111', '01010000', '01010001', '01010010', '01

Considering that $\text{dim}\mathcal{H}_8 = 2^8$, that the computational basis for an 8-qubit system is $\{ |00000000\rangle, |00000001 \rangle, \dots, |11111111 \rangle \}$ and can be synthetically expressed as $\{ |i \rangle \}_{i = 1}^{2^8}$, the general initial state can be expressed as: 

$$
|\psi \rangle = \sum \limits_{i=1}^{256} \alpha_i |i \rangle
$$

where $\alpha_i$ are complex coefficients such that $\sum \limits_{i=1}^{256} |\alpha_i|^2 = 1$. 
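
The sketch below shows one way an image could be turned into such an amplitude vector. It is an assumption for illustration only: the random array stands in for a resized MNIST image, and dividing by the Euclidean norm is a common choice that is not confirmed to be what `mnist_preparation` does.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((16, 16))  # stand-in for a resized 16x16 image (assumption)

# Flatten row by row: pixel (row, col) goes to index 16 * row + col
amplitudes = image.reshape(-1)

# Normalize so that the squared magnitudes sum to 1, as required for a quantum state
amplitudes = amplitudes / np.linalg.norm(amplitudes)

num_qubits = int(np.log2(amplitudes.size))

# Basis-state label of a given pixel, in lexicographic order
row, col = 3, 5
label = format(16 * row + col, f"0{num_qubits}b")

print("number of qubits:", num_qubits)
print("sum of |alpha_i|^2:", np.sum(np.abs(amplitudes) ** 2))
print(f"pixel ({row}, {col}) sets the amplitude of |{label}>")
```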

Each layer of the original model consists of parametrized rotations generated by the Pauli matrices $\sigma_X, \sigma_Y, \sigma_Z$, i.e. the [RX](https://docs.pennylane.ai/en/stable/code/api/pennylane.RX.html), [RY](https://docs.pennylane.ai/en/stable/code/api/pennylane.RY.html) and [RZ](https://docs.pennylane.ai/en/stable/code/api/pennylane.RZ.html) gates applied to each qubit, followed by [CNOT](https://docs.pennylane.ai/en/stable/code/api/pennylane.CNOT.html) gates that create entanglement. 

As a reminder, here is a very short recap on entanglement. Consider a system composed of two subsystems A and B; for pure states it is described by $|\phi \rangle_{AB}$. The state is **separable** iff: 

$$
|\phi \rangle_{AB} = |\phi \rangle_A \otimes |\phi \rangle_B
$$

For example, $\frac{|00\rangle_{AB} + |01 \rangle_{AB}}{\sqrt{2}}$ is separable because it can be written as $\frac{|0\rangle_{A} (|0\rangle_{B} + |1 \rangle_{B})}{\sqrt{2}}$.

If the state is not separable it is **entangled**. In general, deciding whether a state is entangled is not easy, which is why there exist several criteria that work under specific circumstances. One example of entangled states are the [Bell states](https://en.wikipedia.org/wiki/Bell_state), four maximally entangled states that form an orthonormal basis of the two-qubit Hilbert space.
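
For pure bipartite states, one such criterion is the Schmidt rank: reshaping the amplitudes into a $d_A \times d_B$ matrix and counting its non-zero singular values gives 1 for a separable state and more than 1 for an entangled one. The sketch below is a minimal illustration assuming NumPy; the helper name `schmidt_rank` and the tolerance are hypothetical choices:

```python
import numpy as np

def schmidt_rank(psi, dim_a, dim_b, tol=1e-12):
    """Count the non-zero Schmidt coefficients of a bipartite pure state."""
    coeffs = np.asarray(psi).reshape(dim_a, dim_b)  # entry [a, b] is the amplitude of |a>|b>
    singular_values = np.linalg.svd(coeffs, compute_uv=False)
    return int(np.sum(singular_values > tol))

# (|00> + |01>)/sqrt(2) = |0> (|0> + |1>)/sqrt(2): separable
product_state = np.array([1.0, 1.0, 0.0, 0.0]) / np.sqrt(2)

# Bell state (|00> + |11>)/sqrt(2): entangled
bell_state = np.array([1.0, 0.0, 0.0, 1.0]) / np.sqrt(2)

print("Schmidt rank of the product state:", schmidt_rank(product_state, 2, 2))
print("Schmidt rank of the Bell state   :", schmidt_rank(bell_state, 2, 2))
```

This test applies to pure states only; it does not decide separability of mixed states.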

For mixed states we can instead express **separability** operationally through Local Operations and Classical Communication (LOCC): A and B prepare states locally and coordinate via classical communication to create the full state: 

$$
\rho_{AB} = \sum \limits_i p_i \rho_A^i \otimes \rho_B^i
$$

Here too, if a state is not separable, then it is **entangled**. 

In [None]:
NUM_QUBITS = 8
NUM_LAYERS = 3

# get the device
dev = qml.device("default.qubit", wires=NUM_QUBITS)

# circuit using the strongly entangling layer ansatz
@qml.qnode(dev, interface="torch")
def circuit_block(params, state=None):

    # Load the initial state if provided
    # (qml.StatePrep replaces qml.QubitStateVector, which is deprecated in recent PennyLane releases)
    if state is not None:
        qml.StatePrep(state, wires=range(NUM_QUBITS))

    # Real quantum encoding (using amplitude encoding)
    #if state is not None: qml.AmplitudeEmbedding(features=state, wires=range(NUM_QUBITS))

    #qml.StronglyEntanglingLayers(params, wires=range(NUM_QUBITS), ranges = [1]*params.shape[0])

    # Quantum circuit
    for i in range(NUM_LAYERS):

        # Rotation layer
        for j in range(NUM_QUBITS):
            qml.RX(params[i, j, 0], wires=j)
            qml.RY(params[i, j, 1], wires=j)
            qml.RZ(params[i, j, 2], wires=j)

        # Entangling layer (ring of CNOTs, the last qubit controls the first)
        for j in range(NUM_QUBITS):
            qml.CNOT(wires=[j, (j + 1) % NUM_QUBITS])


    # Return the state vector
    return qml.state()

# define the general circuit: circuit block followed by the measurement
def circuit(params, state):

    # apply first small block
    state = circuit_block(params, state)

    # return probability of measuring |0> in the first qubit
    return measure(state)

# define function that outputs the probability of measuring |0> in the first qubit
def measure(state):

    # compute the probability of measuring |0> in the first qubit
    prob = torch.sum(torch.abs(state[:,:2**(NUM_QUBITS-1)])**2, dim = 1)

    # cast to float32
    prob = prob.type(torch.float32)

    return prob
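
The slicing in `measure` relies on the lexicographic ordering: wire 0 is the leftmost bit of each basis label, so the first $2^{n-1}$ amplitudes are exactly those with the first qubit in $|0\rangle$. The following is a NumPy analogue on 3 qubits, written only as an illustrative sketch (the name `measure_np` and the example batch are assumptions, not part of the model):

```python
import numpy as np

NUM_QUBITS_DEMO = 3  # small size chosen for illustration

def measure_np(state):
    """NumPy analogue of measure: P(first qubit = 0) for a batch of state vectors."""
    half = 2 ** (NUM_QUBITS_DEMO - 1)
    return np.sum(np.abs(state[:, :half]) ** 2, axis=1)

# Labels of the first half of the basis: all of them start with '0'
first_half = [format(i, f"0{NUM_QUBITS_DEMO}b") for i in range(2 ** (NUM_QUBITS_DEMO - 1))]

# Batch of three states: |000>, |100> and (|000> + |111>)/sqrt(2)
batch = np.zeros((3, 2 ** NUM_QUBITS_DEMO), dtype=complex)
batch[0, 0] = 1.0
batch[1, 4] = 1.0
batch[2, 0] = batch[2, 7] = 1.0 / np.sqrt(2)

probs = measure_np(batch)

print("labels in the first half:", first_half)
print("P(first qubit = 0) per state:", probs)
```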