In [None]:
if 'google.colab' in str(get_ipython()):
    !pip uninstall tensorflow -y
    !pip install tensorflow==2.3.1 tensorflow-quantum
    !rm -rf quantum-gans
    !git clone https://github.com/WiktorJ/quantum-gans
    !cd quantum-gans; pip install .

In [13]:
# Automatically reload imported modules (e.g. edits to the qsgenerator package)
# before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [34]:
import math
import random
import tensorflow as tf

import cirq
import numpy as np
from qsgenerator.utils import map_to_radians
from qsgenerator.qugans import circuits
from qsgenerator.qugans.training import Trainer
from qsgenerator.phase.circuits import build_ground_state_circuit
from qsgenerator.phase.analitical import  get_ground_state_for_g
from qsgenerator.states.simple_state_circuits import build_x_rotation_state
from qsgenerator.states.simple_rotation_generators import get_binary_x_rotation_provider, get_arcsin_x_rotation_provider
from qsgenerator.phase.analitical import construct_hamiltonian, get_theta_v, get_theta_w, get_theta_r, get_g_parameters_provider
from qsgenerator.evaluators.circuit_evaluator import CircuitEvaluator

In [47]:
# Circuit sizes: layer counts for the generator / discriminator ansatz, and the width
# of the data bus (presumably the number of data qubits — the evaluated states below
# have 2**5 amplitudes).
generator_layers, discriminator_layers = 1, 2
data_bus_size = 5

In [48]:
# Experiment switches.
real_phase = False         # True: ground-state circuit as target; False: x-rotation state
generic_generator = True   # False: replace the generator with a ground-state circuit
# Labeling options forwarded to circuits.build_gan_circuits — all disabled here.
all_layers_labeling = full_layer_labeling = use_gen_label_qubit = False

In [96]:
# Build the generator and discriminator circuits together with their symbol tuples.
# `gs` / `ds` size the generator / discriminator weight vectors below; `ls` is
# presumably the label symbols (passed to Trainer) — TODO confirm in qugans.circuits.
# `data_qubits` carry the generated state; `out_qubit` is handed to the Trainer.
gen, gs, disc, ds, ls, data_qubits, out_qubit = circuits.build_gan_circuits(
    generator_layers, 
    discriminator_layers, 
    data_bus_size, 
    all_layers_labeling=all_layers_labeling,
    full_layer_labeling=full_layer_labeling,
    use_gen_label_qubit=use_gen_label_qubit)

In [97]:
# Optionally replace the generic generator (and its symbols) with a fully
# parametrised ground-state circuit on the same data qubits.
if not generic_generator:
    gen, gs = build_ground_state_circuit(qubits=data_qubits, full_parametrization=True)

In [98]:
# Target ("real") circuit on the generator's data qubits: the ground-state preparation
# circuit when real_phase is set, otherwise the x-rotation state.
real, real_symbols = (
    build_ground_state_circuit if real_phase else build_x_rotation_state
)(qubits=data_qubits)

In [99]:
# Keep a copy of the bare generator (used later for state evaluation), then append the
# discriminator so `gen` becomes the generator+discriminator circuit given to Trainer.
# NOTE(review): `gen.append` mutates in place — re-running only this cell appends the
# discriminator again; re-run the circuit-building cells first.
pure_gen = gen.copy()
gen.append([disc])

In [100]:
# Same for the target: keep the bare real circuit for evaluation, and append the
# discriminator so `real` is the real+discriminator circuit given to Trainer.
# NOTE(review): in-place append — re-running only this cell appends `disc` twice.
pure_real = real.copy()
real.append([disc])

In [101]:
# Display the target circuit without the discriminator.
print("REAL GROUND STATE")
pure_real

REAL GROUND STATE


In [102]:
# Display the generator circuit without the discriminator.
print("GENERATOR")
pure_gen

GENERATOR


In [56]:
# Display the discriminator circuit on its own.
print("DISCRIMINATOR")
disc

DISCRIMINATOR


In [57]:
# Seed every RNG source the notebook imports, so that weight initialisation and any
# TensorFlow-side randomness during training reproduce on Restart & Run All.
np.random.seed(0)
random.seed(0)
tf.random.set_seed(0)

# Std-dev of the noise around the all-zero generator initialisation.
eps = 1e-2
# Generator: parameters start near zero (N(0, eps**2) noise), one per symbol in `gs`.
init_gen_weights = np.random.normal(scale=eps, size=(len(gs),))
# Discriminator: standard-normal parameters, one per symbol in `ds`.
init_disc_weights = np.random.normal(size=(len(ds),))

gen_weights = tf.Variable(init_gen_weights, dtype=tf.float32)
disc_weights = tf.Variable(init_disc_weights, dtype=tf.float32)

In [58]:
class CustomScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Exponentially decaying learning rate with a constant floor.

        lr(step) = max(exp(-(step + step_offset) * ln(decay_factor) / warmup_steps), min_lr)

    The rate shrinks by ``decay_factor`` every ``warmup_steps`` steps and is clamped at
    ``min_lr``. Despite its name, ``warmup_steps`` is a decay horizon; there is no
    warm-up phase. With the defaults the rate starts at ~0.79 and reaches the 0.01 floor
    at step 3800.

    Args:
        warmup_steps: number of steps over which the rate drops by ``decay_factor``.
        step_offset: constant added to the step before decaying (shifts the start point).
        decay_factor: multiplicative drop applied over ``warmup_steps`` steps.
        min_lr: lower bound on the returned learning rate.
    """

    def __init__(self, warmup_steps=4000, step_offset=200, decay_factor=100.0, min_lr=0.01):
        super().__init__()
        self.warmup_steps = warmup_steps
        self.step_offset = step_offset
        self.decay_factor = decay_factor
        self.min_lr = min_lr

    def __call__(self, step):
        # TF ops instead of Python `max`: a Tensor cannot be used as a Python bool when
        # the optimizer runs inside a `tf.function` graph.
        step = tf.cast(step, tf.float32)
        rate = math.log(self.decay_factor) / self.warmup_steps
        decayed = tf.exp(-(step + self.step_offset) * rate)
        return tf.maximum(decayed, self.min_lr)

    def get_config(self):
        # Required by LearningRateSchedule for optimizer/model serialization.
        return {
            "warmup_steps": self.warmup_steps,
            "step_offset": self.step_offset,
            "decay_factor": self.decay_factor,
            "min_lr": self.min_lr,
        }

In [59]:
# Adam driven by the custom decaying learning-rate schedule.
learning_rate = CustomScheduler()
opt = tf.keras.optimizers.Adam(
    learning_rate=learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

In [72]:
# Label values the real-state parameter provider (and the Trainer) are built for.
g_values = [0.1]
# Rotation angles for the x-rotation target state. The second argument is assumed to be
# the number of data qubits (one angle per qubit), so it follows data_bus_size instead
# of a hard-coded 5 — TODO confirm against get_arcsin_x_rotation_provider.
# Alternative provider: get_binary_x_rotation_provider({0: '100', 1: '011', 2: '101'})
x_rotations = get_arcsin_x_rotation_provider(g_values, data_bus_size)

In [63]:
# Bundle the circuits, symbol tuples and the real-state parameter provider into a Trainer.
# NOTE(review): `use_analytical_expectation=True` presumably evaluates the discriminator
# output exactly rather than by sampling — confirm in qsgenerator.qugans.training.
trainer = Trainer(g_values, 
                  data_bus_size, 
                  disc, 
                  gen, 
                  real, 
                  out_qubit, 
                  ds, 
                  gs, 
                  real_symbols, 
                  ls,
                  real_values_provider = x_rotations,
                  use_analytical_expectation=True)

In [None]:
# Run adversarial training; presumably `disc_iteration` / `gen_iteration` are the
# discriminator / generator update steps per epoch — TODO confirm in Trainer.train.
# NOTE(review): with epochs=1 and snapshot_interval_epochs=10, no snapshot may be
# recorded, depending on how Trainer counts epochs.
epochs = 1
disc_iteration = 10
gen_iteration = 1
snapshot_interval_epochs = 10
results = trainer.train(disc_weights,
      gen_weights, 
      opt, 
      epochs=epochs, 
      disc_iteration=disc_iteration, 
      gen_iteration=gen_iteration,
      snapshot_interval_epochs=snapshot_interval_epochs)

In [68]:
def get_gen_for_g(g, gen_weights):
    """Return the generator parameter vector for label ``g``.

    ``g`` is converted with ``map_to_radians`` and appended after the flattened
    ``gen_weights``.
    """
    return np.append(gen_weights, map_to_radians(g))

def get_states_and_fidelty_for_real(gen_evaluator,
                                    real_evaluator,
                                    gen_weights,
                                    g,
                                    size,
                                    real_provider):
    """Compare the generated state for label ``g`` with the real (target) state.

    Args:
        gen_evaluator: evaluator wrapping the generator circuit.
        real_evaluator: evaluator wrapping the real (target) circuit.
        gen_weights: generator weights; the label value for ``g`` is appended via
            ``get_gen_for_g``.
        g: label value.
        size: not used here; kept for signature parity with the other helpers.
        real_provider: callable mapping ``g`` to the real circuit's parameters.

    Returns:
        Tuple ``(generated_state, real_state, fidelity)``.
    """
    generator_params = get_gen_for_g(g, gen_weights)
    generated_state = gen_evaluator.get_state_from_params(generator_params)
    target_params = real_provider(g)
    real_state = real_evaluator.get_state_from_params(target_params)
    fidelity = cirq.fidelity(generated_state, real_state)
    return generated_state, real_state, fidelity

def get_states_and_fidelty_for_ground(gen_evaluator, g, gen_weights, size):
    """Compare the generated state for label ``g`` with ``get_ground_state_for_g(g, size)``.

    The generated state is read out in qubit order ``0 .. size-1``.

    Returns:
        Tuple ``(generated_state, ground_state, fidelity)``.
    """
    qubit_order = list(range(size))
    generated_state = gen_evaluator.get_state_from_params(
        get_gen_for_g(g, gen_weights), qubit_order)
    ground_state = get_ground_state_for_g(g, size)
    return generated_state, ground_state, cirq.fidelity(generated_state, ground_state)

def compare_generated_for_g(gen_evaluator, g1, g2, gen_weights, size):
    """Generate states for labels ``g1`` and ``g2`` and return them with their fidelity.

    Both states are read out in qubit order ``0 .. size-1``.

    Returns:
        Tuple ``(state_for_g1, state_for_g2, fidelity)``.
    """
    def state_for(label):
        return gen_evaluator.get_state_from_params(
            get_gen_for_g(label, gen_weights), list(range(size)))

    first = state_for(g1)
    second = state_for(g2)
    return first, second, cirq.fidelity(first, second)

In [129]:
# Discriminator weights from a trained run (70 values), stored as a float32 tf.Variable.
# NOTE(review): not read by the visible cells; its length must match len(ds) for the
# circuit configuration above — confirm before reuse.
trained_disc_weights = tf.Variable(np.array([ 1.25888796e+01,  1.10409822e+01,  1.27487049e+01,  1.32927475e+01,
       -3.20522385e+01,  2.98508596e+00, -7.54223883e-01,  8.97036648e+00,
        8.98472309e+00, -2.77423954e+00,  8.90891266e+00,  5.72837019e+00,
        6.03105211e+00, -4.64482594e+00, -1.10843427e-01,  7.78598115e-02,
        3.00343895e+00,  3.83781940e-01,  6.08641243e+00,  8.64131927e+00,
       -2.17593918e+01,  1.45857897e+01,  2.36893883e+01,  8.75363445e+00,
       -2.42768993e+01,  1.27688437e+01,  2.53628349e+00,  1.39768391e+01,
        1.40961084e+01, -2.04474068e+01,  4.71392822e+00,  1.66926212e+01,
        1.88311214e+01,  9.17525005e+00,  1.16109962e+01,  1.08004580e+01,
        1.10795708e+01,  5.81477690e+00, -5.50215101e+00,  2.22007637e+01,
        2.25015125e+01,  3.21826210e+01, -1.80058708e+01, -7.85126591e+00,
       -7.77073908e+00,  1.40237570e+01,  3.14071465e+01, -1.07477732e+01,
        1.55852342e+00,  2.63786411e+01,  1.31890945e+01,  7.14759350e+00,
        5.46145630e+00, -1.79730053e+01,  7.40563774e+00,  1.06135674e+01,
       -1.55095673e+00, -2.66580944e+01,  6.51995277e+00, -2.37151980e+00,
        3.51896515e+01,  1.50027342e+01,  1.80733763e-02, -1.75052185e+01,
        9.60706902e+00, -6.06413984e+00,  5.35433817e+00, -7.88707399e+00,
        1.27363043e+01,  1.02446747e+01]), dtype=tf.float32)

# Generator weights from a trained run (14 values). They are passed directly to
# gen_evaluator (built with gen_symbols) in the state-comparison cells.
# NOTE(review): length should match len(gen_symbols) — confirm for the current circuit.
trained_gen_weights = np.array([
    -0.20239973068237305,
    4.103419780731201,
    4.835892677307129,
    -5.277573585510254,
    7.324734210968018,
    -0.006883351132273674,
    -5.015811920166016,
    -12.468819618225098,
    -2.2864558696746826,
    6.433105945587158,
    6.776863098144531,
    -4.772155284881592,
    -4.78993034362793,
    -8.97536563873291,
])

# Discriminator weights copied from a saved run's JSON "disc_weights" field. They are
# kept as a valid Python list for reference (bare `"disc_weights": [...]` lines are a
# syntax error), and no cell reads them.
# NOTE(review): 60 entries vs. 70 in trained_disc_weights — presumably a different
# discriminator configuration; confirm before using.
snapshot_disc_weights = [
    6.056931018829346,
    -0.15719473361968994,
    7.582610130310059,
    2.3493590354919434,
    20.549972534179688,
    -6.741025924682617,
    -11.768488883972168,
    -1.7888803482055664,
    -3.6880688667297363,
    0.23714135587215424,
    0.26757222414016724,
    10.884228706359863,
    26.334882736206055,
    12.091363906860352,
    -0.33033716678619385,
    21.02811622619629,
    6.088003158569336,
    3.49955677986145,
    1.9190806150436401,
    5.089786052703857,
    1.5454983711242676,
    1.5706307888031006,
    -10.008342742919922,
    6.43993616104126,
    25.079999923706055,
    15.897772789001465,
    -8.204235076904297,
    1.4552363157272339,
    27.395193099975586,
    8.439077377319336,
    1.9011459350585938,
    7.813032150268555,
    -7.695796489715576,
    15.171481132507324,
    0.49958866834640503,
    23.002290725708008,
    -1.2182766199111938,
    3.0128731727600098,
    19.245874404907227,
    13.645651817321777,
    -1.571423888206482,
    26.03333282470703,
    29.370750427246094,
    4.49747896194458,
    14.91414737701416,
    2.9845504760742188,
    4.903937816619873,
    -4.884820461273193,
    36.793392181396484,
    10.524744987487793,
    18.778743743896484,
    29.2131290435791,
    24.37220001220703,
    23.11734390258789,
    24.74367332458496,
    11.340784072875977,
    30.13662338256836,
    41.193572998046875,
    23.201465606689453,
    9.845664024353027,
]
# Symbols used to evaluate the trained generator: the generator's own symbol tuple `gs`.
gen_symbols = gs 
# NOTE(review): an earlier experiment used explicit target symbols,
# real_symbols = ('r0', 'r1', 'r2'); kept only as a reference.

In [130]:
# State-vector evaluators for the bare generator and bare target circuits
# (without the discriminator appended).
gen_evaluator = CircuitEvaluator(pure_gen, gen_symbols)
real_evaluator = CircuitEvaluator(pure_real, real_symbols)
# Target rotation angles for the first training label. Use g_values explicitly:
# no global `g` exists yet at this point on a fresh kernel.
x_rotations(g_values[0])

[0.2003348423231196,
 0.4027158415806616,
 0.609385308030795,
 0.8230336921349761,
 1.0471975511965976]

In [131]:
# Compare the trained generator's state with the target state for g = 0.1.
# Distinct result names keep the `real` circuit (given to Trainer) intact, so the
# Trainer cell can still be re-run afterwards.
# NOTE(review): this passes trained_gen_weights as-is, whereas
# get_states_and_fidelty_for_real appends a label value via get_gen_for_g — confirm
# which parameter layout gen_evaluator expects.
g = 0.1
generated_state = gen_evaluator.get_state_from_params(trained_gen_weights)
real_state = real_evaluator.get_state_from_params(x_rotations(g))
generated_state, real_state, cirq.fidelity(generated_state, real_state)

(array([ 7.38144875e-01+2.54046521e-03j,  1.46673818e-03-4.26168144e-01j,
         1.10874989e-03-3.22152823e-01j, -1.85995027e-01-6.40137063e-04j,
         7.98939320e-04-2.32135817e-01j, -1.34023681e-01-4.61267831e-04j,
        -1.01312384e-01-3.48685717e-04j, -2.01313800e-04+5.84927313e-02j,
         5.18570247e-04-1.50673181e-01j, -8.69912058e-02-2.99396692e-04j,
        -6.57591745e-02-2.26322634e-04j, -1.30667424e-04+3.79660763e-02j,
        -4.73845303e-02-1.63082819e-04j, -9.41559119e-05+2.73574702e-02j,
        -7.11751782e-05+2.06803028e-02j,  1.19397789e-02+4.10930079e-05j,
        -2.55326304e-04-7.41863474e-02j, -4.28315103e-02+1.47412720e-04j,
        -3.23775820e-02+1.11433576e-04j,  6.43361927e-05+1.86932068e-02j,
        -2.33305302e-02+8.02964205e-05j,  4.63591678e-05+1.34698888e-02j,
         3.50442351e-05+1.01822773e-02j,  5.87874139e-03-2.02327974e-05j,
        -1.51432259e-02+5.21182701e-05j,  3.00904994e-05+8.74294620e-03j,
         2.27462788e-05+6.60904590e-03

In [93]:
# Inspect the target state alone for g = 0.1.
g = 0.1
real_evaluator.get_state_from_params(x_rotations(g))

array([ 0.7381492 +0.j        ,  0.        -0.42617065j,
        0.        -0.32215473j, -0.18599613+0.j        ,
        0.        -0.23213719j, -0.13402447+0.j        ,
       -0.10131298+0.j        ,  0.        +0.05849308j,
        0.        -0.15067407j, -0.08699172+0.j        ,
       -0.06575956+0.j        ,  0.        +0.0379663j ,
       -0.04738481+0.j        ,  0.        +0.02735763j,
        0.        +0.02068043j,  0.01193985+0.j        ,
        0.        -0.07418679j, -0.04283176+0.j        ,
       -0.03237777+0.j        ,  0.        +0.01869331j,
       -0.02333067+0.j        ,  0.        +0.01346997j,
        0.        +0.01018234j,  0.00587878+0.j        ,
       -0.01514332+0.j        ,  0.        +0.008743j  ,
        0.        +0.00660909j,  0.00381576+0.j        ,
        0.        +0.00476235j,  0.00274955+0.j        ,
        0.00207846+0.j        ,  0.        -0.0012j    ], dtype=complex64)

In [73]:
# Generated vs. target state and their fidelity for g = 0.1, with the label value
# appended to the generator weights via get_gen_for_g.
g = 0.1
get_states_and_fidelty_for_real(gen_evaluator, real_evaluator, trained_gen_weights, g, data_bus_size, x_rotations)

(array([ 4.0448181e-02-3.6572769e-02j,  3.0413866e-02+3.7496548e-02j,
        -4.3743864e-02-4.7772691e-02j,  5.7309680e-02+2.1482296e-03j,
        -3.5664630e-01-3.0501288e-01j,  2.8659377e-01-3.0083203e-01j,
        -2.3218048e-01-5.0678837e-01j,  4.6374798e-01+1.6889630e-01j,
         3.1307053e-03-3.0828714e-03j,  2.5769311e-03+2.9141973e-03j,
        -3.6852700e-03-3.6957406e-03j,  4.6208757e-03-2.3607048e-05j,
        -4.4369027e-03-3.7551146e-02j,  3.3409685e-02-2.1400079e-03j,
         1.3960926e-02-4.2690683e-02j,  1.8217541e-02+3.5348885e-02j,
        -9.5909908e-03-3.4613612e-03j,  3.5155120e-03-8.3150957e-03j,
        -4.0399656e-03+1.1418272e-02j, -4.6823034e-03-9.6473601e-03j,
        -1.8995151e-02+8.5669011e-02j, -7.4838422e-02-2.0860633e-02j,
        -6.3235864e-02+8.2860529e-02j, -1.2882898e-02-9.1382317e-02j,
        -7.4565585e-04-3.4494483e-04j,  3.4036499e-04-6.4286019e-04j,
        -4.0417822e-04+8.8829501e-04j, -3.0835066e-04-8.0714206e-04j,
        -6.2797572e-

In [190]:
# Same comparison at g = -1, with target parameters from get_g_parameters_provider().
# NOTE(review): real_evaluator wraps the x-rotation circuit when real_phase is False;
# confirm these parameters are meant for that circuit.
g = -1
get_states_and_fidelty_for_real(gen_evaluator, real_evaluator, trained_gen_weights, g, data_bus_size, get_g_parameters_provider())

(array([-0.02435654-0.1902056j , -0.19515288-0.11648373j,
         0.08345419-0.1903887j , -0.18535236-0.08096562j,
         0.11612582+0.04735206j, -0.06836013+0.19240332j,
         0.14600399-0.11264499j, -0.00122743+0.13270715j,
         0.00031277+0.1505127j ,  0.16298561+0.11961195j,
        -0.08987801+0.16151407j,  0.13619669+0.08167708j,
         0.11782363+0.02000505j,  0.01544211+0.19913372j,
         0.08575617-0.15900977j,  0.02598283+0.12305039j,
         0.05198056+0.16467279j,  0.1923159 +0.07155653j,
        -0.04311047+0.18240288j,  0.17686254+0.04210223j,
        -0.10968473-0.02181027j,  0.03130385-0.17920177j,
        -0.11122639+0.12064118j, -0.02152339-0.11628234j,
        -0.01267695+0.19866441j,  0.17221147+0.14199674j,
        -0.10956728+0.17222396j,  0.17286973+0.11913121j,
         0.12042966+0.01432219j, -0.03035793+0.17915641j,
         0.11047583-0.12090723j,  0.03273152+0.12426975j], dtype=complex64),
 array([ 0.1767766 +0.j,  0.1767766 +0.j,  0.1767766 

In [None]:
# Fidelity between the generator's states for two nearby labels (g = -0.9 vs. -0.8).
compare_generated_for_g(gen_evaluator, -0.9, -0.8, trained_gen_weights, data_bus_size)

In [None]:
# Compare the generated state at g = 0 with the state from get_ground_state_for_g.
g = 0
get_states_and_fidelty_for_ground(gen_evaluator, g, trained_gen_weights, data_bus_size)