In [None]:
if 'google.colab' in str(get_ipython()):
    !pip uninstall tensorflow -y
    !pip install tensorflow==2.3.1 tensorflow-quantum neptune-client
    !rm -rf quantum-gans
    !git clone https://github.com/WiktorJ/quantum-gans
    !cd quantum-gans; pip install .
    neptun_token = "" # put manually for the time being
else:
    import subprocess
    def get_var(varname):
        CMD = 'echo $(source ~/.bash_profile; echo $%s)' % varname
        p = subprocess.Popen(CMD, stdout=subprocess.PIPE, shell=True, executable='/bin/bash')
        return p.stdout.readlines()[0].strip()
    neptun_token = get_var('NEPTUNE_API_TOKEN').decode("utf-8") 

In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Master switch for Neptune experiment tracking: passed to the Trainer and
# guards the neptune.init / neptune.stop cells.
use_neptune = True

In [None]:
import neptune
import tensorflow as tf

import io
import cirq
import numpy as np
from qsgenerator import circuits
from qsgenerator.quwgans import circuits as quwgans_circuits
from qsgenerator.quwgans.training import Trainer
from qsgenerator.phase.circuits import PhaseCircuitBuilder
from qsgenerator.evaluators.circuit_evaluator import CircuitEvaluator
from qsgenerator.phase.analitical import get_theta_v, get_theta_w, get_theta_r, get_g_parameters_provider


In [None]:
# Circuit-size hyperparameters.
# generator_layers: layer count given to circuits.build_circuit (generic generator only).
# data_bus_size:    number of qubits in the 1 x N GridQubit row.
# rank:             forwarded to Trainer(rank=...).
generator_layers = 3
data_bus_size = 5
rank = 2

In [None]:
# Experiment flags.
# real_phase:             in the cells shown here it is only recorded in the Neptune parameters.
# generic_generator:      True -> circuits.build_circuit generator; False -> PhaseCircuitBuilder generator.
# zxz:                    forwarded to build_ground_state_circuit(zxz=...) for the generator.
# all_gates_parametrized: forwarded to the PhaseCircuitBuilder that builds the generator.
real_phase = True
generic_generator = False
zxz = False
all_gates_parametrized = False

In [None]:
# One row of `data_bus_size` grid qubits; both names are bound to the same list.
qubits = cirq.GridQubit.rect(1, data_bus_size)
data_qubits = qubits

In [None]:
# Build the target ("real") ground-state circuit on the data qubits, keeping
# the circuit, its symbols and the symbol dictionary.
# NOTE(review): all_gates_parametrized is hard-coded to False here instead of
# using the notebook-level flag of the same name -- confirm this is intended.
builder = PhaseCircuitBuilder(all_gates_parametrized=False)
real, real_symbols, symbols_dict_real = builder.build_ground_state_circuit(qubits=data_qubits)

In [None]:
# Obtain the discriminator description for the real circuit.
# NOTE(review): neither returned value is referenced again in the cells shown here.
pauli_strings, qubit_to_string_index = quwgans_circuits.get_discriminator(real)

In [None]:
# Choose the generator circuit: the fully parametrized ground-state circuit by
# default, or the generic layered circuit when `generic_generator` is set.
if not generic_generator:
    builder = PhaseCircuitBuilder(all_gates_parametrized=all_gates_parametrized)
    gen, gs, symbols_dict_gen = builder.build_ground_state_circuit(
        qubits=data_qubits,
        full_parametrization=True,
        zxz=zxz,
    )
else:
    gen, gs = circuits.build_circuit(generator_layers, data_qubits, "g")
    # The generic builder returns no symbol dictionary; keep the name defined.
    symbols_dict_gen = {}

In [None]:
# Print a label, then show the real circuit as the cell's output.
print("REAL GROUND STATE")
real

In [None]:
# Print a label, then show the generator circuit as the cell's output.
print("GENERATOR")
gen

In [None]:
# g values and the real-parameter provider; both are handed to the Trainer.
g_values = [-0.8]
real_values_provider = get_g_parameters_provider()

In [None]:
# Adam optimizer handed to trainer.train(...).
opt = tf.keras.optimizers.Adam(
    learning_rate=0.1,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

In [None]:
# Create the quwgans Trainer from the real circuit and its symbols, the
# generator circuit and its symbols, the g values and the real-parameter
# provider. Neptune logging follows the notebook-level `use_neptune` switch.
trainer = Trainer(real, 
                  real_symbols,
                  gen,
                  gs,
                  g_values,
                  real_values_provider,
                  rank=rank,
                  use_neptune=use_neptune)

In [None]:
# Training schedule; all three are passed positionally to trainer.train(...).
epochs = 5
gen_iteration = 1
snapshot_interval_epochs = 2

In [None]:
# Register the run with Neptune: hyperparameters first, then the generator
# circuit and its initial weights as text artifacts.
if use_neptune:
    neptune.init(
        project_qualified_name='wiktor.jurasz/thesis-em',
        api_token=neptun_token,
    )

    # Hyperparameters attached to the experiment.
    neptun_params = {
        'generator_layers': generator_layers,
        'size': data_bus_size,
        'rank': rank,
        'real_phase': real_phase,
        'generic_generator': generic_generator,
        'zxz': zxz,
        'all_gates_parametrized': all_gates_parametrized,
        'g_values': g_values,
        'gen_iteration': gen_iteration,
        'epochs': epochs
    }
    neptune.create_experiment(name=None, description=None, params=neptun_params)

    # String form of the generator circuit.
    neptune.log_artifact(io.StringIO(str(gen)), "gen.txt")

    # Initial generator weights: each entry of trainer.gen_weights is turned
    # into (el[0] as numpy, el[1], el[2] as a list) and the list is stringified.
    neptune.log_artifact(
        io.StringIO(str([
            (el[0].numpy(), el[1], list(el[2].numpy()))
            for el in trainer.gen_weights
        ])),
        'init_gen_weights.txt',
    )

In [None]:
# Run training with the optimizer and schedule defined in earlier cells.
# NOTE(review): json_result is not used in the cells shown here.
json_result = trainer.train(opt, epochs, gen_iteration, snapshot_interval_epochs, plot=True)

In [None]:
def get_all_states_and_fidelty_for_real(gen_evaluator, real_evaluator, g, size):
    """Fetch the states produced by the generator and by the real evaluator.

    Despite its name, this helper computes no fidelity; it only gathers the
    two state collections.

    Args:
        gen_evaluator: Object exposing ``get_all_states_from_params()`` for
            the generator circuit.
        real_evaluator: Object exposing ``get_all_states_from_params()`` for
            the real circuit.
        g: Unused; kept for signature compatibility.
        size: Unused; kept for signature compatibility.

    Returns:
        A 2-tuple ``(generated_states, real_states)``.
    """
    # The generator is queried before the real evaluator.
    generated_states = gen_evaluator.get_all_states_from_params()
    real_states = real_evaluator.get_all_states_from_params()
    return (generated_states, real_states)

In [None]:
# Shorthand handles for the evaluators held by the trainer.
gen_evaluator = trainer.gen_evaluator
real_evaluator = trainer.real_evaluator

In [None]:
# Show whatever the trainer's fidelity helper returns as the cell output.
trainer.get_fidelty_for_real()

In [None]:
# Show the states obtained from the generator evaluator.
gen_evaluator.get_all_states_from_params()

In [None]:
# Show the states obtained from the real-circuit evaluator.
real_evaluator.get_all_states_from_params()

In [None]:
# Stop the Neptune experiment, but only when tracking was enabled.
if use_neptune:
    neptune.stop()