In [6]:
import sys
sys.path.append('../')
import cirq
import tensorflow_quantum as tfq
import tensorflow as tf
import numpy as np

from encode_state import EncodeState
from input_circuits import InputCircuits
from loss import DiscriminationLoss

Here we will simply print the circuits and loss function used for state discrimination.

In [18]:
# Number of qubits shared by the input-state circuits and the discrimination ansatz.
n = 4

# `state_creator` produces the |a> / |b> input circuits used in the next cell;
# `circuit_creator` builds the parameterised discrimination circuit shown here.
state_creator = InputCircuits(n)
circuit_creator = EncodeState(n)

discrimination_circuit = circuit_creator.discrimination_circuit()

# transpose=True draws one column per qubit, which keeps the layered ansatz readable.
circuit_diagram = discrimination_circuit.to_text_diagram(transpose=True)
print(circuit_diagram)

(0, 0)     (0, 1)      (1, 0)      (1, 1)
│          │           │           │
X^layer0_0 X^layer0_3  X^layer0_6  X^layer0_9
│          │           │           │
Y^layer0_1 Y^layer0_4  Y^layer0_7  Y^layer0_10
│          │           │           │
Z^layer0_2 Z^layer0_5  Z^layer0_8  Z^layer0_11
│          │           │           │
@──────────X^layer0_12 │           │
│          │           │           │
│          @───────────X^layer0_13 │
│          │           │           │
│          │           @───────────X^layer0_14
│          │           │           │
X──────────┼───────────┼───────────@^layer0_15
│          │           │           │
│          │           │           M('m0')
│          │           │           │
X^layer1_0 X^layer1_3  X^layer1_6  │
│          │           │           │
Y^layer1_1 Y^layer1_4  Y^layer1_7  │
│          │           │           │
Z^layer1_2 Z^layer1_5  Z^layer1_8  │
│          │           │           │
@──────────X^layer1_9  │           │
│          │   

The states to be discriminated can also be written as two circuits:

In [8]:
circuits, labels = state_creator.discrimination_circuits_labels(total_states=10, mu_a=0.9)

# Label convention: 1 marks an |a> state, 0 marks a |b> state.
# list.index picks the first circuit carrying each label.
a_position = labels.index(1)
a_circuit = circuits[a_position]
b_position = labels.index(0)
b_circuit = circuits[b_position]

print(a_circuit, '\n\n')
print(b_circuit)

(0, 0): ───Ry(0.691π)───

(0, 1): ───Z────────────

(1, 0): ───Z────────────

(1, 1): ───Z──────────── 


(0, 0): ───X───iSwap──────────
               │
(0, 1): ───────iSwap^-0.337───

(1, 0): ───Z──────────────────

(1, 1): ───Z──────────────────


The states were originally specified as $|a \rangle: (\sqrt{1- a^2}, 0, a, 0)$, $|b \rangle: (0, \pm \frac{1}{\sqrt{2}}, \frac{1}{\sqrt{2}}, 0)$.

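To create these states with circuits, we prepare $|a\rangle$ with $\textrm{R}_y(2 \arcsin(a))(\textrm{q0})$, and $|b\rangle$ with $\textrm{X}(\textrm{q0})$ followed by $\textrm{ISWAP}(\textrm{q0}, \textrm{q1})^{2 \arcsin(b) / \pi}$.
Since $\textrm{R}_y(\theta)|0\rangle = \cos(\theta/2)|0\rangle + \sin(\theta/2)|1\rangle$, the angle $\theta = 2\arcsin(a)$ places amplitude $a$ on $|10\rangle$ and $\sqrt{1-a^2}$ on $|00\rangle$.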

## Loss Function
In this scheme we arbitrarily label the measurement outcomes: $|00\rangle: a, |01\rangle: b, |10\rangle: a, |11\rangle: \textrm{inconclusive}$.
We can penalise inconclusive and erroneous outcomes separately in the loss function, but for simplicity we will weight them equally.

So for each circuit, $P(|00\rangle) + P(|10\rangle)$ is the probability that the circuit records the input as an $a$ state,
$P(|01\rangle)$ is the same for $b$, and $P(|11\rangle)$ is the inconclusive probability.
So, depending on the input state, the loss is a weighted sum of these probabilities, with error weight $\alpha_{err}$ and inconclusive weight $\alpha_{inc}$ (`w_error` and `w_inconclusive` in `DiscriminationLoss`).
When the input state is $a$, the loss is $\alpha_{err} P(|01\rangle) + \alpha_{inc} P(|11\rangle)$.
When the input state is $b$, the loss is $\alpha_{err} \left[P(|00\rangle) + P(|10\rangle)\right] + \alpha_{inc} P(|11\rangle)$.

The loss function for this implementation is defined in `loss.py`, printed below:

In [16]:
# Show the source of loss.py (which defines DiscriminationLoss) inline in the notebook.
loss_path = '../loss.py'
with open(loss_path, 'r') as loss_file:
    contents = loss_file.read()
print(contents)

import tensorflow as tf
from typing import List


class DiscriminationLoss:

    def __init__(self, w_error: float, w_inconclusive: float):
        self.w_inconclusive = w_inconclusive
        self.w_error = w_error

    @tf.function
    def discrimination_loss(self, y_label, y_measurement):
        y_label = tf.cast(y_label, tf.float32)
        error, inconclusive = tf.map_fn(lambda x: self.measurement_to_loss(x[0], x[1]), (y_measurement, y_label))
        loss_vec = tf.add(error, inconclusive)
        return tf.reduce_mean(loss_vec)

    def measurement_to_loss(self, measurement: tf.Tensor, label: tf.Tensor):
        measurement = tf.squeeze(measurement)
        probs = self.m_outcome_to_probs(measurement)
        return self.probs_to_err_inc(label, probs)

    def probs_to_err_inc(self, label: tf.Tensor, probs: tf.Tensor):
        # 1 == a, 0 == b
        fn_a = lambda: tf.gather(probs, 1)
        fn_b = lambda: tf.add(tf.gather(probs, 0), tf.gather(probs, 2))
        error = tf.cas