# In [None]:
#%%
# note: timestamp can't use "/" character for h5 saving.
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"
END_OPT_STRING = "\n" + "=" * 60 + "\n"
import numpy as np
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)  # supress warnings
import h5py


import ECD_control.ECD_optimization.tf_quantum as tfq
from ECD_control.ECD_optimization.visualization import VisualizationMixin
import qutip as qt
import datetime
import time
class BatchOptimizer(VisualizationMixin):

    # a block is defined as the unitary: CD(beta)D(alpha)R_phi(theta)
    def __init__(
        self,
        optimization_type="state transfer",
        target_unitary=None,
        P_cav1=None,
        P_cav2=None,  # EG
        N_cav1=None,
        N_cav2=None,  # EG
        initial_states=None,
        target_states=None,
        expectation_operators=None,
        target_expectation_values=None,
        N_multistart=10,
        N_blocks=20,
        term_fid=0.99,  # can set >1 to force run all epochs
        dfid_stop=1e-4,  # can be set= -1 to force run all epochs
        learning_rate=0.01,
        epoch_size=10,
        epochs=100,
        beta_scale=1.0,
        gamma_scale=1.0,  # EG
        alpha1_scale=1.0,
        alpha2_scale=1.0,
        theta_scale=np.pi,
        use_etas=False,
        use_displacements=False,
        no_CD_end=False,
        beta_mask=None,
        gamma_mask=None,
        phi_mask=None,
        eta_mask=None,
        theta_mask=None,
        alpha1_mask=None,
        alpha2_mask=None,
        name="ECD_control",
        filename=None,
        comment="",
        use_phase=False,  # include the phase in the optimization cost function. Important for unitaries.
        timestamps=[],
        initial_params=None,
        **kwargs
    ):
        """
        Set up a two-cavity batch optimizer.

        The scalar settings (N_multistart, N_blocks, term_fid, dfid_stop,
        learning_rate, epoch_size, epochs, the *_scale values, use_etas,
        use_displacements, use_phase, no_CD_end, name, comment,
        initial_params, N_cav1, N_cav2) are stored in ``self.parameters``;
        any extra ``kwargs`` are merged into that dictionary as well.

        Parameters with non-obvious handling:
            optimization_type: "state transfer" and "analysis" build the
                initial/target state tensors; "calculation" builds nothing;
                "expectation" raises ``Exception``; anything else raises
                ``ValueError``.
            target_unitary: when ``tfq.qt2tf`` returns a non-None tensor for
                it, the target states are ``target_unitary @ initial_states``
                and ``target_states`` is ignored.
            P_cav1, P_cav2, expectation_operators, target_expectation_values:
                accepted but not used by this method.
            *_mask: forwarded to ``_construct_optimization_masks``, which
                currently only accepts ``None``.
            filename: output file name; falls back to ``name`` when None or
                empty.  The extension is forced to ``.h5``.
            timestamps: existing optimization timestamps; copied so that the
                (mutable) default list is never shared between instances.
            initial_params: ``None`` to randomize, otherwise a 7-tuple
                ``(betas, gammas, alphas1, alphas2, phis, etas, thetas)``
                passed on to ``set_tf_vars``.
        """
        # Local import: the file-level import block does not provide os.
        import os

        self.parameters = {
            "optimization_type": optimization_type,
            "N_multistart": N_multistart,
            "N_blocks": N_blocks,
            "term_fid": term_fid,
            "dfid_stop": dfid_stop,
            "no_CD_end": no_CD_end,
            "learning_rate": learning_rate,
            "epoch_size": epoch_size,
            "epochs": epochs,
            "beta_scale": beta_scale,
            "gamma_scale": gamma_scale,  # EG
            "alpha1_scale": alpha1_scale,
            "alpha2_scale": alpha2_scale,
            "theta_scale": theta_scale,
            "use_etas": use_etas,
            "use_displacements": use_displacements,
            "use_phase": use_phase,
            "name": name,
            "comment": comment,
            "initial_params": initial_params,
        }
        self.parameters.update(kwargs)

        opt_type = self.parameters["optimization_type"]
        if opt_type == "state transfer" or opt_type == "analysis":
            # set fidelity function
            self.batch_fidelities = (
                self.batch_state_transfer_fidelities
                if self.parameters["use_phase"]
                else self.batch_state_transfer_fidelities_real_part
            )

            self.initial_states = tf.stack(
                [tfq.qt2tf(state) for state in initial_states]
            )

            self.target_unitary = tfq.qt2tf(target_unitary)

            # if self.target_unitary is not None: TODO
            #     raise Exception("Need to fix target_unitary multi-state transfer generation!")

            self.target_states = (
                tf.stack([tfq.qt2tf(state) for state in target_states])
                if self.target_unitary is None
                else self.target_unitary @ self.initial_states
            )

            # store dag to avoid having to take the adjoint on every evaluation
            self.target_states_dag = tf.linalg.adjoint(self.target_states)

        elif opt_type == "expectation":
            raise Exception("Need to implement expectation optimization")
        elif opt_type == "calculation":
            # using functions but not doing opt
            pass
        else:
            # NOTE: the message lists 'unitary', but no branch above accepts it.
            raise ValueError(
                "optimization_type must be one of {'state transfer', 'unitary', 'expectation', 'analysis', 'calculation'}"
            )

        self.parameters["N_cav1"] = N_cav1
        self.parameters["N_cav2"] = N_cav2

        print(initial_params)
        if initial_params is None:
            self.randomize_and_set_vars()
        else:
            betas, gammas, alphas1, alphas2, phis, etas, thetas = initial_params
            self.set_tf_vars(
                betas=betas,
                gammas=gammas,
                alphas1=alphas1,
                alphas2=alphas2,
                phis=phis,
                etas=etas,
                thetas=thetas,
            )

        self._construct_needed_matrices()
        self._construct_optimization_masks(
            beta_mask,
            gamma_mask,
            alpha1_mask,
            alpha2_mask,
            phi_mask,
            eta_mask,
            theta_mask,
        )

        # opt data will be a dictionary of dictionaries used to store optimization data
        # the dictionary will be addressed by timestamps of optimization.
        # each opt will append to opt_data a dictionary
        # this dictionary will contain optimization parameters and results

        # Copy: optimize() appends to this list, and storing the default
        # argument by reference would share one list across all instances.
        self.timestamps = list(timestamps)

        self.filename = (
            filename
            if (filename is not None and filename != "")
            else self.parameters["name"]
        )
        # Force a ".h5" extension.  splitext only looks at the last path
        # component, so dots in directory names (e.g. "./run") are left alone.
        root, ext = os.path.splitext(self.filename)
        if ext != ".h5":
            self.filename = root + ".h5"


    def _construct_needed_matrices(self):
        """
        Precompute the per-mode operators used to build displacements.

        For each cavity mode the position and momentum matrices are obtained
        from ``tfq.position`` / ``tfq.momentum`` and eigendecomposed once, so
        that later exponentials only need to act on the eigenvalues.  The
        diagonal of the commutator q @ p - p @ q is stored as well.

        When the optimization type is "unitary", a projector is additionally
        built per mode: the identity on the cavity with the diagonal entries
        from index P_cav onward set to zero, tensored with a 2x2 identity.
        """
        dim1 = self.parameters["N_cav1"]
        dim2 = self.parameters["N_cav2"]

        pos1, mom1 = tfq.position(dim1), tfq.momentum(dim1)
        pos2, mom2 = tfq.position(dim2), tfq.momentum(dim2)

        # Pre-diagonalize every quadrature operator
        self._eig_q1, self._U_q1 = tf.linalg.eigh(pos1)
        self._eig_p1, self._U_p1 = tf.linalg.eigh(mom1)
        self._eig_q2, self._U_q2 = tf.linalg.eigh(pos2)
        self._eig_p2, self._U_p2 = tf.linalg.eigh(mom2)

        self._qp1_comm = tf.linalg.diag_part(pos1 @ mom1 - mom1 @ pos1)
        self._qp2_comm = tf.linalg.diag_part(pos2 @ mom2 - mom2 @ pos2)

        if self.parameters["optimization_type"] != "unitary":
            return

        def _truncated_projector(keep, dim):
            # identity with entries [keep, dim) on the diagonal zeroed out
            truncated = np.array(qt.identity(dim))
            for level in range(keep, dim):
                truncated[level, level] = 0
            return tfq.qt2tf(qt.tensor(qt.identity(2), qt.Qobj(truncated)))

        # first mode
        self.P1_matrix = _truncated_projector(self.parameters["P_cav1"], dim1)
        # second mode
        self.P2_matrix = _truncated_projector(self.parameters["P_cav2"], dim2)

    def _construct_optimization_masks(
        self, beta_mask=None,gamma_mask = None, alpha1_mask=None, alpha2_mask=None, phi_mask=None,eta_mask=None, theta_mask=None
    ):
        """
        Build the default float32 masks and store them on the instance.

        A mask has one entry per optimization variable; ``optimize`` uses it
        to stop the gradient wherever the entry is 0.  Only the defaults are
        supported: passing any mask other than ``None`` raises ``Exception``.

        Defaults:
            beta / gamma: ones of shape (N_blocks, N_multistart); the last
                row is zeroed when ``no_CD_end`` is set.
            alpha1 / alpha2: ones of shape (1, N_multistart).
            phi: ones of shape (N_blocks, N_multistart) with the first row
                zeroed.
            eta / theta: ones of shape (N_blocks, N_multistart).
        """
        n_blocks = self.parameters["N_blocks"]
        n_starts = self.parameters["N_multistart"]

        def _default(user_mask, rows):
            # TODO: add mask to self.parameters for saving if it's non standard!
            if user_mask is not None:
                raise Exception(
                    "need to implement non-standard masks for batch optimization"
                )
            return np.ones(shape=(rows, n_starts), dtype=np.float32)

        def _default_cd(user_mask):
            mask = _default(user_mask, n_blocks)
            if self.parameters["no_CD_end"]:
                mask[-1, :] = 0  # don't optimize final CD
            return mask

        betas = _default_cd(beta_mask)
        gammas = _default_cd(gamma_mask)
        alphas1 = _default(alpha1_mask, 1)
        alphas2 = _default(alpha2_mask, 1)
        phis = _default(phi_mask, n_blocks)
        phis[0, :] = 0  # stop gradient on first phi entry
        etas = _default(eta_mask, n_blocks)
        thetas = _default(theta_mask, n_blocks)

        # assign only after every mask has been validated
        self.beta_mask = betas
        self.gamma_mask = gammas
        self.alpha1_mask = alphas1
        self.alpha2_mask = alphas2
        self.phi_mask = phis
        self.eta_mask = etas
        self.theta_mask = thetas

    @tf.function
    def batch_construct_displacement_operators(self, alphas1, alphas2):
        """
        Build two-mode displacement operators for a batch of amplitudes.

        Args:
            alphas1: rank-2 complex tensor of displacement amplitudes for the
                first mode.  Callers in this file pass shape
                (N_blocks, N_multistart) or (1, N_multistart).
            alphas2: amplitudes for the second mode, same leading shape as
                ``alphas1``.

        Returns:
            complex64 tensor of shape (L, M, d1 * d2, d1 * d2), where (L, M)
            is the shape of the amplitude inputs and d1, d2 are the sizes of
            the pre-diagonalized mode-1 / mode-2 operators.  Entry [l, m] is
            the Kronecker product (mode 1) x (mode 2) of the two single-mode
            displacement operators.
        """
        sqrt2 = tf.math.sqrt(tf.constant(2, dtype=tf.complex64))

        def _scaled_column(part, like):
            # sqrt(2) * part with a trailing axis, so it broadcasts against
            # the eigenvalue vectors.  The target shape is taken from the
            # amplitude tensor the part belongs to.
            return tf.reshape(
                sqrt2 * tf.cast(part, dtype=tf.complex64),
                [like.shape[0], like.shape[1], 1],
            )

        re_a1 = _scaled_column(tf.math.real(alphas1), alphas1)
        im_a1 = _scaled_column(tf.math.imag(alphas1), alphas1)
        re_a2 = _scaled_column(tf.math.real(alphas2), alphas2)
        im_a2 = _scaled_column(tf.math.imag(alphas2), alphas2)

        # Exponentiate diagonal matrices
        # first mode
        expm_q1 = tf.linalg.diag(tf.math.exp(1j * im_a1 * self._eig_q1))
        expm_p1 = tf.linalg.diag(tf.math.exp(-1j * re_a1 * self._eig_p1))
        expm_c1 = tf.linalg.diag(tf.math.exp(-0.5 * re_a1 * im_a1 * self._qp1_comm))
        # second mode
        expm_q2 = tf.linalg.diag(tf.math.exp(1j * im_a2 * self._eig_q2))
        expm_p2 = tf.linalg.diag(tf.math.exp(-1j * re_a2 * self._eig_p2))
        expm_c2 = tf.linalg.diag(tf.math.exp(-0.5 * re_a2 * im_a2 * self._qp2_comm))

        # Apply Baker-Campbell-Hausdorff to each
        disp1s = tf.cast(
            self._U_q1
            @ expm_q1
            @ tf.linalg.adjoint(self._U_q1)
            @ self._U_p1
            @ expm_p1
            @ tf.linalg.adjoint(self._U_p1)
            @ expm_c1,
            dtype=tf.complex64,
        )
        disp2s = tf.cast(
            self._U_q2
            @ expm_q2
            @ tf.linalg.adjoint(self._U_q2)
            @ self._U_p2
            @ expm_p2
            @ tf.linalg.adjoint(self._U_p2)
            @ expm_c2,
            dtype=tf.complex64,
        )

        # Batched Kronecker product, done with tensor ops only.  Going
        # through .numpy() is not possible on the symbolic tensors seen while
        # tracing a tf.function, and it would also disconnect the result from
        # the gradient tape, so the amplitudes could never be optimized.
        #   kron(A, B)[i * d2 + k, j * d2 + p] = A[i, j] * B[k, p]
        n_layers = alphas1.shape[0]
        n_starts = alphas1.shape[1]
        dim1 = self._U_q1.shape[-1]
        dim2 = self._U_q2.shape[-1]
        kron_disps = tf.einsum("lmij,lmkp->lmikjp", disp1s, disp2s)
        kron_disps = tf.reshape(
            kron_disps, [n_layers, n_starts, dim1 * dim2, dim1 * dim2]
        )
        return tf.cast(kron_disps, dtype=tf.complex64)

    @tf.function
    def batch_construct_block_operators(
        self, betas_rho, betas_angle,gammas_rho, gammas_angle,
        alphas1_rho, alphas1_angle, alphas2_rho, alphas2_angle,
        phis, etas, thetas
    ):
        """
        Assemble the batch of block operators for the given parameters.

        The beta/gamma and alpha1/alpha2 arguments are given in polar form
        (magnitude ``*_rho`` and phase ``*_angle``); beta and gamma are the
        conditional displacements of modes 1 and 2 (halved before use) and
        alpha1/alpha2 form the final displacement.  phis, etas and thetas
        are the rotation angles of each block.

        Returns:
            The qubit-rotation / conditional-displacement blocks followed,
            along axis 0, by the block-diagonal final displacement.
        """
        imag_unit = tf.constant(1j, dtype=tf.complex64)

        def _c64(tensor):
            return tf.cast(tensor, dtype=tf.complex64)

        def _amplitude(rho, angle, divisor=None):
            # rho * exp(i * angle), with rho optionally divided first
            magnitude = _c64(rho)
            if divisor is not None:
                magnitude = magnitude / divisor
            return magnitude * tf.math.exp(imag_unit * _c64(angle))

        def _per_block_scalar(tensor):
            # append two singleton axes so the value scales a whole matrix
            return _c64(
                tf.reshape(tensor, [tensor.shape[0], tensor.shape[1], 1, 1])
            )

        # conditional displacements (complex amplitudes, halved)
        two_c = tf.constant(2, dtype=tf.complex64)
        cond_amp1 = _amplitude(betas_rho, betas_angle, two_c)
        cond_amp2 = _amplitude(gammas_rho, gammas_angle, two_c)

        # final displacement amplitudes
        final_amp1 = _amplitude(alphas1_rho, alphas1_angle)
        final_amp2 = _amplitude(alphas2_rho, alphas2_angle)

        ds_end = self.batch_construct_displacement_operators(final_amp1, final_amp2)
        ds_g = self.batch_construct_displacement_operators(cond_amp1, cond_amp2)
        ds_e = tf.linalg.adjoint(ds_g)

        # phi -> phi - pi/2 and theta -> theta/2
        half_pi = tf.constant(np.pi, dtype=tf.float32) / tf.constant(
            2, dtype=tf.float32
        )
        shifted_phis = _per_block_scalar(phis - half_pi)
        eta_c = _per_block_scalar(etas)
        half_thetas = _per_block_scalar(thetas / tf.constant(2, dtype=tf.float32))

        # e^{i phi} and its conjugate
        phase = tf.math.exp(imag_unit * shifted_phis)
        phase_conj = tf.linalg.adjoint(phase)
        cos = tf.math.cos(half_thetas)
        sin = tf.math.sin(half_thetas)
        cos_e = tf.math.cos(eta_c)
        sin_e = tf.math.sin(eta_c)

        # quadrants of each block; the names follow the layout
        #   (ul, ur)
        #   (ll, lr)
        # of the block matrix built below.
        ll = imag_unit * (cos + imag_unit * sin * cos_e) * ds_g
        ul = phase * sin * sin_e * ds_e
        lr = phase_conj * sin * sin_e * ds_g
        ur = imag_unit * (cos - imag_unit * sin * cos_e) * ds_e

        upper_rows = tf.concat([ul, ur], 3)
        lower_rows = tf.concat([ll, lr], 3)
        rotation_blocks = -1j * tf.concat([upper_rows, lower_rows], 2)

        # final block matrix: the same displacement in both diagonal quadrants
        empty = tf.zeros_like(ds_end)
        final_block = tf.concat(
            [tf.concat([ds_end, empty], 3), tf.concat([empty, ds_end], 3)], 2
        )

        return tf.concat([rotation_blocks, final_block], 0)


    @tf.function
    def batch_state_transfer_fidelities_real_part(
        self, betas_rho, betas_angle,gammas_rho, gammas_angle,
        alphas1_rho, alphas1_angle, alphas2_rho, alphas2_angle,
        phis, etas, thetas
    ):
        """
        Fidelity per multistart, computed from the real part of the overlap.

        The block operators for the given parameters are applied one after
        another to every initial state.  The real part of the overlap with
        the stored target states is averaged over the states (axis 1) and
        then squared.

        Returns:
            float32 tensor of fidelities (squeezed).
        """
        block_ops = self.batch_construct_block_operators(
            betas_rho, betas_angle,gammas_rho, gammas_angle,
            alphas1_rho, alphas1_angle, alphas2_rho, alphas2_angle,
            phis, etas, thetas
        )
        # one copy of the initial states per multistart
        n_starts = self.parameters["N_multistart"]
        states = tf.stack([self.initial_states for _ in range(n_starts)])
        # m: multistart, s: multiple states
        for block in block_ops:
            states = tf.einsum("mij,msjk->msik", block, states)
        # target_states_dag broadcasts over the multistart axis
        real_overlaps = tf.math.real(self.target_states_dag @ states)
        # Average over states first and squeeze afterwards: axis 1 would no
        # longer exist for single-state transfer if squeezed beforehand.
        mean_overlaps = tf.squeeze(tf.reduce_mean(real_overlaps, axis=1))
        # the overlaps are already real, so no conjugate is needed
        return tf.cast(mean_overlaps * mean_overlaps, dtype=tf.float32)

    def optimize(self, do_prints=True):
        """
        Run the Adam optimization over all multistarts.

        Each epoch performs ``epoch_size`` Adam steps on the loss
        ``sum(log(1 - fid)) / N_multistart``.  After every epoch the run
        stops if any fidelity exceeds ``term_fid`` or if no fidelity improved
        by more than ``dfid_stop``; otherwise it continues for ``epochs``
        epochs.  A KeyboardInterrupt stops the run cleanly.

        Args:
            do_prints: when True, print a progress line after every epoch.

        Returns:
            The start timestamp of this optimization (also appended to
            ``self.timestamps``).
        """
        timestamp = datetime.datetime.now().strftime(TIMESTAMP_FORMAT)
        self.timestamps.append(timestamp)
        print("Start time: " + timestamp)
        start_time = time.time()
        optimizer = tf.optimizers.Adam(self.parameters["learning_rate"])

        # Variables handed to the optimizer.  Alphas / etas are only
        # included when they are trainable.
        variables = [
            self.betas_rho,
            self.betas_angle,
            self.gammas_rho,
            self.gammas_angle,
        ]
        if self.parameters["use_displacements"]:
            variables += [
                self.alphas1_rho,
                self.alphas1_angle,
                self.alphas2_rho,
                self.alphas2_angle,
            ]
        variables.append(self.phis)
        if self.parameters["use_etas"]:
            variables.append(self.etas)
        variables.append(self.thetas)

        @tf.function
        def entry_stop_gradients(target, mask):
            # entries with mask == 0 keep their value but get no gradient
            mask_h = tf.abs(mask - 1)
            return tf.stop_gradient(mask_h * target) + mask * target

        @tf.function
        def loss_fun(fids):
            # I think it's important that the log is taken before the avg
            losses = tf.math.log(1 - fids)
            avg_loss = tf.reduce_sum(losses) / self.parameters["N_multistart"]
            return avg_loss

        def callback_fun(obj, fids, dfids, epoch):
            elapsed_time_s = time.time() - start_time
            time_per_epoch = elapsed_time_s / epoch if epoch != 0 else 0.0
            epochs_left = self.parameters["epochs"] - epoch
            expected_time_remaining = epochs_left * time_per_epoch
            avg_fid = tf.reduce_sum(fids) / self.parameters["N_multistart"]
            max_fid = tf.reduce_max(fids)
            avg_dfid = tf.reduce_sum(dfids) / self.parameters["N_multistart"]
            max_dfid = tf.reduce_max(dfids)
            extra_string = " (real part)" if self.parameters["use_phase"] else ""
            if do_prints:
                print(
                    "\r Epoch: %d / %d Max Fid: %.6f Avg Fid: %.6f Max dFid: %.6f Avg dFid: %.6f"
                    % (
                        epoch,
                        self.parameters["epochs"],
                        max_fid,
                        avg_fid,
                        max_dfid,
                        avg_dfid,
                    )
                    + " Elapsed time: "
                    + str(datetime.timedelta(seconds=elapsed_time_s))
                    + " Remaing time: "
                    + str(datetime.timedelta(seconds=expected_time_remaining))
                    + extra_string,
                    end="",
                )

        initial_fids = self.batch_fidelities(
            self.betas_rho,
            self.betas_angle,
            self.gammas_rho,
            self.gammas_angle,
            self.alphas1_rho,
            self.alphas1_angle,
            self.alphas2_rho,
            self.alphas2_angle,
            self.phis,
            self.etas,
            self.thetas,
        )
        fids = initial_fids
        # Defined up front so the interrupt handler and the summary below
        # never touch an unbound name (interrupt in the first epoch,
        # epochs == 0, or epoch_size == 0).
        new_fids = fids
        dfids = tf.zeros_like(fids)
        epoch = 0
        callback_fun(self, fids, dfids, 0)
        try:  # will catch keyboard interrupt
            for epoch in range(1, self.parameters["epochs"] + 1):
                for _ in range(self.parameters["epoch_size"]):
                    with tf.GradientTape() as tape:
                        betas_rho = entry_stop_gradients(self.betas_rho, self.beta_mask)
                        betas_angle = entry_stop_gradients(
                            self.betas_angle, self.beta_mask
                        )
                        gammas_rho = entry_stop_gradients(
                            self.gammas_rho, self.gamma_mask
                        )
                        gammas_angle = entry_stop_gradients(
                            self.gammas_angle, self.gamma_mask
                        )
                        if self.parameters["use_displacements"]:
                            alphas1_rho = entry_stop_gradients(
                                self.alphas1_rho, self.alpha1_mask
                            )
                            alphas1_angle = entry_stop_gradients(
                                self.alphas1_angle, self.alpha1_mask
                            )
                            alphas2_rho = entry_stop_gradients(
                                self.alphas2_rho, self.alpha2_mask
                            )
                            alphas2_angle = entry_stop_gradients(
                                self.alphas2_angle, self.alpha2_mask
                            )
                        else:
                            alphas1_rho = self.alphas1_rho
                            alphas1_angle = self.alphas1_angle
                            alphas2_rho = self.alphas2_rho
                            alphas2_angle = self.alphas2_angle
                        phis = entry_stop_gradients(self.phis, self.phi_mask)
                        if self.parameters["use_etas"]:
                            etas = entry_stop_gradients(self.etas, self.eta_mask)
                        else:
                            etas = self.etas
                        thetas = entry_stop_gradients(self.thetas, self.theta_mask)
                        new_fids = self.batch_fidelities(
                            betas_rho,
                            betas_angle,
                            gammas_rho,
                            gammas_angle,
                            alphas1_rho,
                            alphas1_angle,
                            alphas2_rho,
                            alphas2_angle,
                            phis,
                            etas,
                            thetas,
                        )
                        new_loss = loss_fun(new_fids)
                    # Taken outside the tape context so the gradient
                    # computation itself is not recorded.
                    dloss_dvar = tape.gradient(new_loss, variables)
                    optimizer.apply_gradients(zip(dloss_dvar, variables))
                dfids = new_fids - fids
                fids = new_fids
                callback_fun(self, fids, dfids, epoch)
                condition_fid = tf.greater(fids, self.parameters["term_fid"])
                condition_dfid = tf.greater(dfids, self.parameters["dfid_stop"])
                if tf.reduce_any(condition_fid):
                    print("\n\n Optimization stopped. Term fidelity reached.\n")
                    termination_reason = "term_fid"
                    break
                if not tf.reduce_any(condition_dfid):
                    print("\n max dFid: %6f" % tf.reduce_max(dfids).numpy())
                    print("dFid stop: %6f" % self.parameters["dfid_stop"])
                    print(
                        "\n\n Optimization stopped.  No dfid is greater than dfid_stop\n"
                    )
                    termination_reason = "dfid"
                    break
            else:
                # Only reached when the loop ran out of epochs without a
                # break, so a stop criterion met on the very last epoch is
                # not relabelled as "epochs".
                termination_reason = "epochs"
                print(
                    "\n\nOptimization stopped.  Reached maximum number of epochs. Terminal fidelity not reached.\n"
                )
        except KeyboardInterrupt:
            print("\n max dFid: %6f" % tf.reduce_max(dfids).numpy())
            print("dFid stop: %6f" % self.parameters["dfid_stop"])
            print("\n\n Optimization stopped on keyboard interrupt")
            termination_reason = "keyboard_interrupt"

        # self._save_termination_reason(timestamp, termination_reason)
        timestamp_end = datetime.datetime.now().strftime(TIMESTAMP_FORMAT)
        elapsed_time_s = time.time() - start_time
        epoch_time_s = elapsed_time_s / epoch if epoch else 0.0
        # one epoch consists of epoch_size Adam steps
        epoch_size = self.parameters["epoch_size"]
        step_time_s = epoch_time_s / epoch_size if epoch_size else 0.0
        print("all data saved as: " + self.filename)
        print("termination reason: " + termination_reason)
        print("optimization timestamp (start time): " + timestamp)
        print("timestamp (end time): " + timestamp_end)
        print("elapsed time: " + str(datetime.timedelta(seconds=elapsed_time_s)))
        print(
            "Time per epoch (epoch size = %d): " % self.parameters["epoch_size"]
            + str(datetime.timedelta(seconds=epoch_time_s))
        )
        print(
            "Time per Adam step (N_multistart = %d, N_cav1 = %d, N_cav2 = %d): "
            % (
                self.parameters["N_multistart"],
                self.parameters["N_cav1"],
                self.parameters["N_cav2"],
            )
            + str(datetime.timedelta(seconds=step_time_s))
        )
        print(END_OPT_STRING)
        return timestamp

    def randomize_and_set_vars(self):
        """
        Draw random starting parameters and store them as TensorFlow tensors.

        Magnitudes are uniform in [0, scale), phases and rotation angles are
        uniform in [-pi, pi), and thetas are uniform in
        [-theta_scale, theta_scale).  The first row of phis is set to 0, and
        with ``no_CD_end`` the last row of the beta/gamma values is zeroed.

        Betas, gammas, phis and thetas always become trainable variables.
        The alphas are trainable only with ``use_displacements`` (otherwise
        constant zeros), and etas only with ``use_etas`` (otherwise constant
        pi/2).
        """
        settings = self.parameters
        beta_scale = settings["beta_scale"]
        gamma_scale = settings["gamma_scale"]
        alpha1_scale = settings["alpha1_scale"]
        alpha2_scale = settings["alpha2_scale"]
        theta_scale = settings["theta_scale"]

        per_block = (settings["N_blocks"], settings["N_multistart"])
        per_start = (1, settings["N_multistart"])

        def _magnitudes(scale, shape):
            return np.random.uniform(0, scale, size=shape)

        def _angles(shape):
            return np.random.uniform(-np.pi, np.pi, size=shape)

        def _trainable(values, label):
            return tf.Variable(values, dtype=tf.float32, trainable=True, name=label)

        def _fixed_zeros():
            return tf.constant(np.zeros(shape=per_start), dtype=tf.float32)

        # The draws below are kept in a fixed order so that seeding
        # np.random reproduces the same starting point.
        betas_rho = _magnitudes(beta_scale, per_block)
        betas_angle = _angles(per_block)
        gammas_rho = _magnitudes(gamma_scale, per_block)
        gammas_angle = _angles(per_block)
        if settings["use_displacements"]:
            alphas1_rho = _magnitudes(alpha1_scale, per_start)
            alphas1_angle = _angles(per_start)
            alphas2_rho = _magnitudes(alpha2_scale, per_start)
            alphas2_angle = _angles(per_start)
        phis = _angles(per_block)
        if settings["use_etas"]:
            etas = _angles(per_block)
        thetas = np.random.uniform(-1 * theta_scale, theta_scale, size=per_block)

        phis[0] = 0  # everything is relative to first phi
        if settings["no_CD_end"]:
            for values in (betas_rho, betas_angle, gammas_rho, gammas_angle):
                values[-1] = 0

        self.betas_rho = _trainable(betas_rho, "betas_rho")
        self.betas_angle = _trainable(betas_angle, "betas_angle")
        self.gammas_rho = _trainable(gammas_rho, "gammas_rho")
        self.gammas_angle = _trainable(gammas_angle, "gammas_angle")

        if settings["use_displacements"]:
            self.alphas1_rho = _trainable(alphas1_rho, "alphas1_rho")
            self.alphas1_angle = _trainable(alphas1_angle, "alphas1_angle")
            self.alphas2_rho = _trainable(alphas2_rho, "alphas2_rho")
            self.alphas2_angle = _trainable(alphas2_angle, "alphas2_angle")
        else:
            self.alphas1_rho = _fixed_zeros()
            self.alphas1_angle = _fixed_zeros()
            self.alphas2_rho = _fixed_zeros()
            self.alphas2_angle = _fixed_zeros()

        self.phis = _trainable(phis, "phis")
        if settings["use_etas"]:
            self.etas = _trainable(etas, "etas")
        else:
            self.etas = tf.constant(
                (np.pi / 2.0) * np.ones_like(phis), dtype=tf.float32,
            )

        self.thetas = _trainable(thetas, "thetas")

    def get_numpy_vars(
        self,
        betas_rho=None,
        betas_angle=None,
        gammas_rho=None,
        gammas_angle=None,
        alphas1_rho=None,
        alphas1_angle=None,
        alphas2_rho=None,
        alphas2_angle=None,
        phis=None,
        etas=None,
        thetas=None,
    ):
        """
        Convert optimization variables to NumPy arrays.

        Every argument left as ``None`` is replaced by the attribute of the
        same name on ``self``.  Magnitude/phase pairs are combined into
        complex amplitudes, and phis, etas and thetas are wrapped into
        [-pi, pi).

        Returns:
            Tuple ``(betas, gammas, alphas1, alphas2, phis, etas, thetas)``,
            each transposed relative to the stored variable.
        """

        def _resolve(given, attribute):
            # look the attribute up only when no explicit value was passed
            return getattr(self, attribute) if given is None else given

        betas_rho = _resolve(betas_rho, "betas_rho")
        betas_angle = _resolve(betas_angle, "betas_angle")
        gammas_rho = _resolve(gammas_rho, "gammas_rho")
        gammas_angle = _resolve(gammas_angle, "gammas_angle")
        alphas1_rho = _resolve(alphas1_rho, "alphas1_rho")
        alphas1_angle = _resolve(alphas1_angle, "alphas1_angle")
        alphas2_rho = _resolve(alphas2_rho, "alphas2_rho")
        alphas2_angle = _resolve(alphas2_angle, "alphas2_angle")
        phis = _resolve(phis, "phis")
        etas = _resolve(etas, "etas")
        thetas = _resolve(thetas, "thetas")

        def _polar_to_complex(rho, angle):
            return rho.numpy() * np.exp(1j * angle.numpy())

        def _wrapped(angles):
            # map into the range [-pi, pi)
            return (angles.numpy() + np.pi) % (2 * np.pi) - np.pi

        betas = _polar_to_complex(betas_rho, betas_angle)
        gammas = _polar_to_complex(gammas_rho, gammas_angle)
        alphas1 = _polar_to_complex(alphas1_rho, alphas1_angle)
        alphas2 = _polar_to_complex(alphas2_rho, alphas2_angle)

        # transposed, so the multistart axis comes first
        return (
            betas.T,
            gammas.T,
            alphas1.T,
            alphas2.T,
            _wrapped(phis).T,
            _wrapped(etas).T,
            _wrapped(thetas).T,
        )

    