<a href="https://colab.research.google.com/github/No-Qubit-Left-Behind/Control-Engineering-in-TF/blob/master/TF_GRAPE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Propagator

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
%tensorflow_version 2.x
import tensorflow as tf
import numpy as np
import time
print(tf.__version__)

2.2.0-rc2


In [8]:
class PropagatorVL:
    def __init__(self, no_of_steps, delta_t):
        self.delta_t = delta_t
        self.duration = no_of_steps * delta_t
        """
            contraction_array determines the neccessity for the extra
            matrix multiplication step in the recursive method self.propagate()
            when the intermediate computation array has length not divisible
            by 2
        """
        self.contraction_array = self.gen_contraction_array(no_of_steps)

        self.x = tf.constant(
            [[0, 1], [1, 0]], dtype=tf.complex128
        )
        self.y = tf.constant(
            [[0 + 0j, 0 - 1j], [0 + 1j, 0 + 0j]], dtype=tf.complex128
        )
        
        """
            Van Loan generators for control robustness
        """
        xL = np.block([
            [self.x.numpy(), self.x.numpy()],
            [np.zeros((2, 2)), self.x.numpy()]          
        ])
        yL = np.block([
            [self.y.numpy(), self.y.numpy()],
            [np.zeros((2, 2)), self.y.numpy()]          
        ])

        self.generators =  tf.stack([
                                     tf.constant(xL, dtype=tf.complex128),
                                     tf.constant(yL, dtype=tf.complex128)
                                     ])

        self.ctrl_amplitudes = tf.Variable(
            tf.zeros([no_of_steps, 2], dtype=tf.float64), dtype=tf.float64
        )

    @staticmethod
    def gen_contraction_array(no_of_intervals):
        if no_of_intervals > 1:
            return (
                [bool(np.mod(no_of_intervals, 2))] +
                PropagatorVL.gen_contraction_array(
                    np.floor(no_of_intervals / 2)
                )
            )
        return []

    """
        exponentials() computes a vector matrix exponential after multiplying
        each self.ctrl_amplitudes row with a the vector of matrices in
        self.generators
    """
    def exponentials(self):
        regularized_amplitudes = 1 / np.sqrt(2) * tf.math.tanh(
            self.ctrl_amplitudes
        )

        exponents = tf.linalg.tensordot(
            tf.cast(regularized_amplitudes, dtype=tf.complex128),
            -2 * np.pi *(0 + 1j) * self.delta_t * self.generators, 1
        )
        return tf.linalg.expm(exponents)
    
    """
        propagate  computes the final propagator by recursively multiplying
        each odd element in the list of matrices with each even element --
        if the length of the array is not divisible by 2 an extra computation
        step is added
    """
    def propagate(self):
        step_exps = self.exponentials()
        for is_odd in self.contraction_array:
            if is_odd:
                odd_exp = step_exps[-1, :, :]
                step_exps = tf.linalg.matmul(
                    step_exps[1::2, :, :], step_exps[0:-1:2, :, :]
                )
                step_exps = tf.concat([
                    step_exps[0:-1, :, :],
                    [tf.linalg.matmul(odd_exp, step_exps[-1, :, :])]
                ], 0)
            else:
                step_exps = tf.linalg.matmul(
                    step_exps[1::2, :, :], step_exps[0::2, :, :]
                )
        return tf.squeeze(step_exps)

    """
        __call__ computes the final propagator fidelity squared with the
        identity operator
    """
    @tf.function
    def target(self):
        propagator = self.propagate()
        tr = tf.linalg.trace(tf.linalg.matmul(self.x, propagator[0:2, 0:2]))
        """
            infidelity part in the target
        """
        infidelity = 1 - tf.math.real(tr * tf.math.conj(tr)) / (2 ** 2)
        """
            robustness term in the target
        """
        norm_squared = 1 / ((2 * np.pi * self.duration) ** 2) / 2 * (
            tf.math.real(
                tf.linalg.trace(
                    tf.linalg.matmul(
                        propagator[0:2, 3:4],
                        propagator[0:2, 3:4],
                        adjoint_b=True
                    )
                )
            )
        )

        return 0.5 * infidelity + 0.5 * norm_squared

propagatorVL = PropagatorVL(1000, 0.001)

optimizer = tf.keras.optimizers.Adam(0.01)

propagatorVL.ctrl_amplitudes.assign(
    tf.random.uniform([1000, 2], -1, 1, dtype=tf.float64)
)

def optimization_step():
    with tf.GradientTape() as tape:
        current_target = propagatorVL.target()
    gradients = tape.gradient(current_target, [propagatorVL.ctrl_amplitudes])
    optimizer.apply_gradients(zip(gradients, [propagatorVL.ctrl_amplitudes]))
    return current_target

steps = range(1000)
for step in steps:
    current_target = optimization_step()
    print('step %2d: target=%2.5f' %
          (step, current_target))
    
propagatorVL.ctrl_amplitudes.numpy()

step  0: target=0.49827
step  1: target=0.49578
step  2: target=0.49223
step  3: target=0.48764
step  4: target=0.48202
step  5: target=0.47534
step  6: target=0.46763
step  7: target=0.45887
step  8: target=0.44909
step  9: target=0.43830
step 10: target=0.42654
step 11: target=0.41383
step 12: target=0.40019
step 13: target=0.38569
step 14: target=0.37040
step 15: target=0.35440
step 16: target=0.33776
step 17: target=0.32055
step 18: target=0.30290
step 19: target=0.28491
step 20: target=0.26671
step 21: target=0.24840
step 22: target=0.23010
step 23: target=0.21197
step 24: target=0.19413
step 25: target=0.17669
step 26: target=0.15978
step 27: target=0.14354
step 28: target=0.12806
step 29: target=0.11343
step 30: target=0.09976
step 31: target=0.08710
step 32: target=0.07551
step 33: target=0.06503
step 34: target=0.05567
step 35: target=0.04743
step 36: target=0.04030
step 37: target=0.03423
step 38: target=0.02917
step 39: target=0.02506
step 40: target=0.02182
step 41: target=

array([[ 0.98620155,  2.03369273],
       [ 0.6296251 ,  1.76085528],
       [ 1.46656031,  1.70931244],
       ...,
       [ 0.8067391 , -2.17949372],
       [ 1.20877341, -1.72462396],
       [-0.12352296, -1.791441  ]])