In [None]:
if 'google.colab' in str(get_ipython()):
    !pip uninstall tensorflow -y
    !pip install tensorflow==2.3.1 tensorflow-quantum
    !rm -rf quantum-gans
    !git clone https://github.com/WiktorJ/quantum-gans
    !cd quantum-gans; pip install .

In [None]:
# Re-import edited modules (e.g. qsgenerator) automatically before executing code,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import math
import random
import tensorflow as tf

import cirq
import numpy as np
from qsgenerator.utils import map_to_radians
from qsgenerator.qugans import circuits
from qsgenerator.qugans.training import Trainer
from qsgenerator.phase.circuits import build_ground_state_circuit
from qsgenerator.phase.analitical import  get_ground_state_for_g
from qsgenerator.states.simple_state_circuits import build_x_rotation_state
from qsgenerator.states.simple_rotation_generators import get_binary_x_rotation_provider 
from qsgenerator.phase.analitical import construct_hamiltonian, get_theta_v, get_theta_w, get_theta_r, get_g_parameters_provider
from qsgenerator.evaluators.circuit_evaluator import CircuitEvaluator

In [None]:
# Circuit sizes: generator depth, discriminator depth, and the data-bus size.
# data_bus_size is also used later as the state size in the evaluation helpers.
generator_layers, discriminator_layers, data_bus_size = 2, 5, 5

In [None]:
# Experiment switches:
#   real_phase          -> target is the phase ground-state circuit (else an x-rotation state)
#   generic_generator   -> keep the generic generator from build_gan_circuits
#   all_layers_labeling -> forwarded to circuits.build_gan_circuits
real_phase, generic_generator, all_layers_labeling = True, True, False

In [None]:
# Build the generic GAN circuits. Roles below are inferred from how the values are
# used later (confirm in circuits.build_gan_circuits): gen/disc are circuits,
# gs/ds their parameter symbols (sized for the weight vectors), ls the label
# symbol, data_qubits the data-bus qubits, out_qubit the qubit handed to Trainer.
(gen, gs,
 disc, ds,
 ls,
 data_qubits, out_qubit) = circuits.build_gan_circuits(
    generator_layers,
    discriminator_layers,
    data_bus_size,
    all_layers_labeling=all_layers_labeling,
    use_gen_label_qubit=True,
)

In [None]:
# Optionally swap the generic generator for the fully parametrised ground-state
# circuit from qsgenerator.phase, built on the same data qubits.
# NOTE(review): `ls` (label symbol) still comes from build_gan_circuits; confirm the
# ground-state circuit is compatible with the label angle appended in get_gen_for_g.
if not generic_generator:
    gen, gs = build_ground_state_circuit(qubits=data_qubits, full_parametrization=True)

In [None]:
# The "real" (target) circuit: the phase ground-state circuit when real_phase is set,
# otherwise the simple x-rotation state. Both builders share the `qubits=` keyword.
real, real_symbols = (
    build_ground_state_circuit if real_phase else build_x_rotation_state
)(qubits=data_qubits)

In [None]:
# Keep a copy of the bare generator (used later by CircuitEvaluator), then extend
# `gen` in place with the discriminator circuit for the Trainer.
# NOTE: not idempotent. Re-running this cell copies the already-extended `gen` into
# `pure_gen` and appends `disc` a second time; re-run from the circuit-building cells.
pure_gen = gen.copy()
gen.append([disc])

In [None]:
# Keep a copy of the bare target circuit (used later by CircuitEvaluator), then extend
# `real` in place with the discriminator circuit for the Trainer.
# NOTE: not idempotent. Re-running this cell copies the already-extended `real` into
# `pure_real` and appends `disc` a second time; re-run from the circuit-building cells.
pure_real = real.copy()
real.append([disc])

In [None]:
# Show the target ("real") circuit without the discriminator attached. The header
# follows real_phase, because real_phase=False builds an x-rotation state, not a ground state.
print("REAL GROUND STATE" if real_phase else "REAL X-ROTATION STATE")
pure_real

In [None]:
# Show the generator circuit without the discriminator attached.
print("GENERATOR")
pure_gen

In [None]:
# Show the discriminator circuit on its own.
print("DISCRIMINATOR")
disc

In [None]:
# Seed every RNG this notebook can reach (stdlib, NumPy, TensorFlow). Seeding only
# NumPy left any TF- or `random`-based randomness during training unreproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Scale of the Gaussian noise added to the all-zero generator starting point.
eps = 1e-2
# Generator weights: zeros plus small N(0, eps) noise, one entry per generator symbol.
init_gen_weights = np.zeros(len(gs)) + np.random.normal(scale=eps, size=(len(gs),))
# Discriminator weights: standard-normal draws, one per discriminator symbol.
# Drawn after the generator's noise, so the seeded values match the original order.
init_disc_weights = np.random.normal(size=(len(ds),))

gen_weights = tf.Variable(init_gen_weights, dtype=tf.float32)
disc_weights = tf.Variable(init_disc_weights, dtype=tf.float32)

In [None]:
class CustomScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Exponentially decaying learning rate clamped at a floor.

    lr(step) = max(100 ** (-(step + step_offset) / warmup_steps), min_lr)

    Despite the parameter name there is no warm-up phase. The rate starts at
    100 ** (-step_offset / warmup_steps) (about 0.79 with the defaults) and shrinks
    100-fold every ``warmup_steps`` steps. It is held at ``min_lr`` once it would
    fall below it (from step 3800 with the defaults).

    Args:
        warmup_steps: number of steps over which the rate shrinks by a factor of 100.
        step_offset: constant added to ``step`` before decaying (default 200).
        min_lr: lower bound on the returned rate (default 0.01).
    """

    def __init__(self, warmup_steps=4000, step_offset=200, min_lr=0.01):
        super().__init__()
        self.warmup_steps = warmup_steps
        self.step_offset = step_offset
        self.min_lr = min_lr

    def __call__(self, step):
        # Keras hands `step` in as a tensor, possibly while tracing a graph. Python's
        # builtin max() would need bool(tensor), which fails in graph mode, so use TF ops.
        # Cast so integer step counters also work with tf.exp.
        step = tf.cast(step, tf.float32)
        decay_rate = math.log(100) / self.warmup_steps
        lr = tf.exp(-(step + self.step_offset) * decay_rate)
        return tf.maximum(lr, self.min_lr)

    def get_config(self):
        # The base class raises NotImplementedError; this override lets optimizers
        # using the schedule be serialised.
        return {
            "warmup_steps": self.warmup_steps,
            "step_offset": self.step_offset,
            "min_lr": self.min_lr,
        }

In [None]:
# Adam driven by CustomScheduler. beta_2 and epsilon override the Keras defaults
# (0.999 and 1e-7); beta_1 is the default value, spelled out for clarity.
learning_rate = CustomScheduler()
opt = tf.keras.optimizers.Adam(
    learning_rate=learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

In [None]:
# g values handed to the Trainer (presumably the points at which real states are
# sampled during training; confirm in Trainer).
g_values = [-1]
# Alternative real-state provider for the x-rotation target (real_phase=False); unused here.
# x_rotations = get_binary_x_rotation_provider({0: '100', 1: '011', 2: '101'})

In [None]:
# Wire circuits, symbols and qubits into the Trainer. Positional order follows
# Trainer's signature (defined in qsgenerator.qugans.training).
# NOTE(review): get_g_parameters_provider() pairs with the ground-state target
# (real_phase=True). The x-rotation target presumably needs a different provider;
# confirm before flipping real_phase.
trainer = Trainer(
    g_values,
    data_bus_size,
    disc, gen, real,
    out_qubit,
    ds, gs, real_symbols, ls,
    real_values_provider=get_g_parameters_provider(),
    use_analytical_expectation=True,
)

In [None]:
# Training schedule passed straight through to Trainer.train.
# NOTE(review): snapshot_interval_epochs (10) exceeds epochs (1), so periodic
# snapshots may never be taken; confirm against Trainer.train.
epochs, disc_iteration, gen_iteration, snapshot_interval_epochs = 1, 10, 1, 10
results = trainer.train(
    disc_weights,
    gen_weights,
    opt,
    epochs=epochs,
    disc_iteration=disc_iteration,
    gen_iteration=gen_iteration,
    snapshot_interval_epochs=snapshot_interval_epochs,
)

In [None]:
def get_gen_for_g(g, gen_weights):
    """Return the generator parameter vector for label value ``g``.

    The result is ``gen_weights`` (flattened by ``np.append``) followed by one extra
    entry, ``map_to_radians(g)``. That trailing entry presumably binds the label
    symbol, which ``gen_symbols`` lists last.
    """
    return np.append(gen_weights, map_to_radians(g))

def get_states_and_fidelty_for_real(gen_evaluator,
                                    real_evaluator,
                                    gen_weights,
                                    g,
                                    size,
                                    real_provider):
    """Return ``(generated_state, real_state, fidelity)`` at label value ``g``.

    The generated state is evaluated from ``gen_weights`` plus the label angle for
    ``g``. The real state is evaluated from ``real_provider(g)``. Their fidelity is
    computed with ``cirq.fidelity``.

    NOTE(review): ``size`` is accepted but unused. Unlike the ground/compare helpers,
    no ``list(range(size))`` is passed to ``get_state_from_params``; confirm whether
    the generated state should be restricted the same way. The "fidelty" spelling
    is kept so existing calls keep working.
    """
    # Tuple elements are evaluated left to right: generator first, then real.
    states = (
        gen_evaluator.get_state_from_params(get_gen_for_g(g, gen_weights)),
        real_evaluator.get_state_from_params(real_provider(g)),
    )
    return (*states, cirq.fidelity(*states))

def get_states_and_fidelty_for_ground(gen_evaluator, g, gen_weights, size):
    """Return ``(generated_state, analytical_ground_state, fidelity)`` at ``g``.

    The generated state is evaluated with qubit indices ``0..size-1`` passed to
    ``get_state_from_params``. The reference state comes from
    ``get_ground_state_for_g(g, size)`` (qsgenerator.phase.analitical).
    """
    params = get_gen_for_g(g, gen_weights)
    generated_state = gen_evaluator.get_state_from_params(params, list(range(size)))
    exact_ground = get_ground_state_for_g(g, size)
    return generated_state, exact_ground, cirq.fidelity(generated_state, exact_ground)

def compare_generated_for_g(gen_evaluator, g1, g2, gen_weights, size):
    """Return ``(state_at_g1, state_at_g2, fidelity)`` for two generated states.

    Both states come from the same ``gen_weights``, differing only in the label
    angle, and are evaluated with qubit indices ``0..size-1``.
    """
    # A fresh index list per call; g1 is evaluated before g2.
    first, second = (
        gen_evaluator.get_state_from_params(get_gen_for_g(label, gen_weights), list(range(size)))
        for label in (g1, g2)
    )
    return first, second, cirq.fidelity(first, second)

In [None]:
# Sanity check of the analytical ground states: fidelity between g = -0.25 and g = -0.0.
# Size comes from data_bus_size rather than a literal 5, so it stays in sync with the
# circuits and the evaluation helpers if the bus size is changed.
ground1 = get_ground_state_for_g(-0.25, data_bus_size)
ground2 = get_ground_state_for_g(-0.0, data_bus_size)
ground1, ground2, cirq.fidelity(ground1, ground2)

In [None]:
# Hard-coded discriminator weights from an earlier training run (provenance not
# recorded in this notebook). No later cell shown here reads this variable.
trained_disc_weights = tf.Variable(np.array([ 1.25888796e+01,  1.10409822e+01,  1.27487049e+01,  1.32927475e+01,
       -3.20522385e+01,  2.98508596e+00, -7.54223883e-01,  8.97036648e+00,
        8.98472309e+00, -2.77423954e+00,  8.90891266e+00,  5.72837019e+00,
        6.03105211e+00, -4.64482594e+00, -1.10843427e-01,  7.78598115e-02,
        3.00343895e+00,  3.83781940e-01,  6.08641243e+00,  8.64131927e+00,
       -2.17593918e+01,  1.45857897e+01,  2.36893883e+01,  8.75363445e+00,
       -2.42768993e+01,  1.27688437e+01,  2.53628349e+00,  1.39768391e+01,
        1.40961084e+01, -2.04474068e+01,  4.71392822e+00,  1.66926212e+01,
        1.88311214e+01,  9.17525005e+00,  1.16109962e+01,  1.08004580e+01,
        1.10795708e+01,  5.81477690e+00, -5.50215101e+00,  2.22007637e+01,
        2.25015125e+01,  3.21826210e+01, -1.80058708e+01, -7.85126591e+00,
       -7.77073908e+00,  1.40237570e+01,  3.14071465e+01, -1.07477732e+01,
        1.55852342e+00,  2.63786411e+01,  1.31890945e+01,  7.14759350e+00,
        5.46145630e+00, -1.79730053e+01,  7.40563774e+00,  1.06135674e+01,
       -1.55095673e+00, -2.66580944e+01,  6.51995277e+00, -2.37151980e+00,
        3.51896515e+01,  1.50027342e+01,  1.80733763e-02, -1.75052185e+01,
        9.60706902e+00, -6.06413984e+00,  5.35433817e+00, -7.88707399e+00,
        1.27363043e+01,  1.02446747e+01]), dtype=tf.float32)

# Hard-coded generator weights from an earlier training run (provenance not recorded
# in this notebook). The evaluation cells below use these, not `results` from the
# training cell.
# NOTE(review): presumably one value per generator symbol in `gs` (get_gen_for_g
# appends the label angle after them); confirm the length matches len(gs).
trained_gen_weights = np.array([
      3.4344875812530518,
      5.743110179901123,
      8.645724296569824,
      -1.5180160999298096,
      7.943936824798584,
      -8.427645683288574,
      -0.7205737233161926,
      -5.19883918762207,
      -1.3178930282592773,
      0.9257112741470337,
      -7.529347896575928,
      -1.465057373046875,
      -2.3816378116607666,
      8.390764236450195,
      -17.026357650756836,
      -0.0793728306889534,
      1.5116246938705444,
      -4.665868759155273,
      1.3489595651626587,
      7.882970333099365,
      3.115827798843384,
      -12.186253547668457,
      -8.703598022460938,
      -6.145135879516602,
      -4.282535076141357,
      -13.794487953186035,
      -2.590550422668457,
      4.893281936645508,
      -5.5978102684021,
      -1.531538963317871,
      -6.545640468597412,
      -0.7774460315704346,
      0.5502520203590393,
      -11.596064567565918
    ])

# Generator symbols in parameter order: the generator symbols followed by the label
# symbol. This matches get_gen_for_g, which appends the label angle last.
gen_symbols = gs + (ls,)
# Earlier alternatives, kept commented out (e.g. for the x-rotation target):
# gen_symbols = gs
# real_symbols = ('r0', 'r1', 'r2')

In [None]:
# Evaluators over the bare generator / target circuits (no discriminator). The
# state/fidelity helpers use them to turn parameter vectors into states.
gen_evaluator = CircuitEvaluator(pure_gen, gen_symbols)
real_evaluator = CircuitEvaluator(pure_real, real_symbols)

In [None]:
# Generated vs. real state (and their fidelity) at g = -0.5, using the trained weights.
g = -0.5
get_states_and_fidelty_for_real(
    gen_evaluator,
    real_evaluator,
    trained_gen_weights,
    g,
    data_bus_size,
    get_g_parameters_provider(),
)

In [None]:
# Generated vs. real state (and their fidelity) at g = -1, the value used in training.
g = -1
get_states_and_fidelty_for_real(
    gen_evaluator,
    real_evaluator,
    trained_gen_weights,
    g,
    data_bus_size,
    get_g_parameters_provider(),
)

In [None]:
# Fidelity between generated states at two nearby g values (-0.9 vs -0.8).
compare_generated_for_g(
    gen_evaluator,
    -0.9,
    -0.8,
    trained_gen_weights,
    data_bus_size,
)

In [None]:
# Generated state vs. the analytical ground state (and their fidelity) at g = 0.
g = 0
get_states_and_fidelty_for_ground(
    gen_evaluator,
    g,
    trained_gen_weights,
    data_bus_size,
)