In [1]:
# Test of the circuit from "Crossing a topological phase transition with a quantum computer" (https://arxiv.org/pdf/1910.05351.pdf)

In [6]:
from typing import Any

import sympy
import math
import tensorflow as tf
import tensorflow_quantum as tfq

import cirq
import numpy as np
from cirq import GridQubit, ops
from qsgenerator.utils import map_to_radians
from qsgenerator.phase.analitical import construct_hamiltonian, get_theta_v, get_theta_w, get_theta_r
from qsgenerator.phase.circuits import build_ground_state_circuit, build_u1_gate, build_u_gate
from qsgenerator.states.simple_state_circuits import build_x_rotation_state


In [222]:
size = 3# size of quantum circuit excluding boundary qubits

In [231]:
g = -0.5 # G parameter from the paper 

In [232]:
H = construct_hamiltonian(size, g)
lam, V = np.linalg.eigh(H)

# ground state wavefunction
psi = V[:, 0] / np.linalg.norm(V[:, 0])

In [233]:
real, real_symbols = build_ground_state_circuit(size=size)

In [234]:
real

In [235]:
resolver = cirq.ParamResolver({'theta_v': get_theta_v(g), 'theta_w': get_theta_w(g), 'theta_r': get_theta_r(g)})

In [236]:
resolved = cirq.resolve_parameters(real, resolver)

In [237]:
final_state = cirq.final_state_vector(resolved)

In [238]:
final_state

array([ 0.22222227+0.j,  0.31426978+0.j,  0.22222224+0.j, -0.15713488+0.j,
       -0.1111111 +0.j, -0.15713485+0.j,  0.22222221+0.j, -0.15713486+0.j,
       -0.11111113+0.j, -0.15713489+0.j, -0.11111112+0.j,  0.07856744+0.j,
       -0.11111112+0.j, -0.15713486+0.j,  0.22222224+0.j, -0.15713488+0.j,
       -0.15713489+0.j, -0.22222227+0.j, -0.15713489+0.j,  0.11111113+0.j,
        0.07856743+0.j,  0.11111113+0.j, -0.15713486+0.j,  0.11111113+0.j,
       -0.15713486+0.j, -0.22222225+0.j, -0.15713485+0.j,  0.11111112+0.j,
       -0.15713486+0.j, -0.22222225+0.j,  0.31426978+0.j, -0.22222227+0.j],
      dtype=complex64)

In [143]:
first_qubit_zero_mask = [1 if len(final_state)/2 > i else 0 for i in range(len(final_state)) ]
last_qubit_zero_mask = [1 if i %2 == 0 else 0 for i in range(len(final_state))]

In [146]:
real_state = first_qubit_zero_mask * final_state * last_qubit_zero_mask
real_state

array([0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j,
       0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j,
       0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j,
       0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j])

In [147]:
norm_state = real_state/np.linalg.norm(real_state)
norm_state

  norm_state = real_state/np.linalg.norm(real_state)


array([nan+nanj, nan+nanj, nan+nanj, nan+nanj, nan+nanj, nan+nanj,
       nan+nanj, nan+nanj, nan+nanj, nan+nanj, nan+nanj, nan+nanj,
       nan+nanj, nan+nanj, nan+nanj, nan+nanj, nan+nanj, nan+nanj,
       nan+nanj, nan+nanj, nan+nanj, nan+nanj, nan+nanj, nan+nanj,
       nan+nanj, nan+nanj, nan+nanj, nan+nanj, nan+nanj, nan+nanj,
       nan+nanj, nan+nanj])

In [148]:
first_qubit = next(q for q in real.all_qubits() if q.col == 0)
last_qubit = next(q for q in real.all_qubits() if q.col == len(real.all_qubits()) - 1)


In [149]:
realm = real + cirq.measure(first_qubit) + cirq.measure(last_qubit) 
resolvedm = cirq.resolve_parameters(realm, resolver)

In [171]:
cirq.Simulator().simulate(resolved)

measurements: (no measurements)
output vector: 0.707|00001⟩ + 0.707|11110⟩

In [166]:
fs = cirq.final_state_vector(resolvedm) 
fs

array([0.       +0.j, 0.9999999+0.j, 0.       +0.j, 0.       +0.j,
       0.       +0.j, 0.       +0.j, 0.       +0.j, 0.       +0.j,
       0.       +0.j, 0.       +0.j, 0.       +0.j, 0.       +0.j,
       0.       +0.j, 0.       +0.j, 0.       +0.j, 0.       +0.j,
       0.       +0.j, 0.       +0.j, 0.       +0.j, 0.       +0.j,
       0.       +0.j, 0.       +0.j, 0.       +0.j, 0.       +0.j,
       0.       +0.j, 0.       +0.j, 0.       +0.j, 0.       +0.j,
       0.       +0.j, 0.       +0.j, 0.       +0.j, 0.       +0.j],
      dtype=complex64)

In [170]:
cirq.partial_trace_of_state_vector_as_mixture(fs, [1,2,3]) 

((0.49999991059303284,
  array([1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
        dtype=complex64)),
 (0.49999991059303284,
  array([0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j],
        dtype=complex64)))

In [169]:
psi

array([1., 0., 0., 0., 0., 0., 0., 0.])

In [98]:
# TODO: plug in 0 instead of tracing-out
# TODO: renormalize
partials = cirq.partial_trace_of_state_vector_as_mixture(final_state, [4])
partials

((0.49999991059303284, array([1.+0.j, 0.+0.j], dtype=complex64)),
 (0.49999991059303284, array([0.+0.j, 1.+0.j], dtype=complex64)))

In [18]:
for p in partials:
    print(p[0], cirq.fidelity(p[1], psi))

0.2499997466802597 0.29109976691641815
0.2499997466802597 0.1830890225454351
0.2499997764825821 0.2089002146716102
0.2499997764825821 0.3169109914231965


In [110]:
g = 0.8

q1, q2 = cirq.GridQubit.rect(1, 2)
cirq.unitary(build_u1_gate(q1, q2, theta_r))

array([[ 0.47140452+0.j,  0.52704628+0.j,  0.47140452+0.j,
         0.52704628+0.j],
       [ 0.52704628+0.j, -0.47140452+0.j,  0.52704628+0.j,
        -0.47140452+0.j],
       [ 0.52704628+0.j,  0.47140452+0.j, -0.52704628+0.j,
        -0.47140452+0.j],
       [-0.47140452+0.j,  0.52704628+0.j,  0.47140452+0.j,
        -0.52704628+0.j]])

In [111]:
size
x_rot, x_rot_symbols = build_x_rotation_state(size=size)

In [112]:
x_resolver = cirq.ParamResolver({'r0': get_theta_v(g), 'r1': get_theta_w(g), 'r2': get_theta_r(g)})

In [113]:
x_resolved = cirq.resolve_parameters(x_rot, x_resolver)

In [109]:
g05 = cirq.final_state_vector(x_resolved)

In [114]:
g1 = cirq.final_state_vector(x_resolved)

In [115]:
g1

array([ 0.56851923+0.j        ,  0.        -0.63562375j,
        0.        -0.2542495j , -0.2842596 +0.j        ,
        0.        -0.21715502j, -0.24278669+0.j        ,
       -0.09711468+0.j        ,  0.        +0.10857751j], dtype=complex64)

In [116]:
g05

array([ 0.2542495 +0.j        ,  0.        -0.2842596j ,
        0.        -0.56851923j, -0.63562375+0.j        ,
        0.        -0.09711468j, -0.10857751+0.j        ,
       -0.21715502+0.j        ,  0.        +0.24278669j], dtype=complex64)

In [117]:
cirq.fidelity(g1, g05)

0.5555555126335463